Per sample shading

This cl introduces per sample shading in the fragment shader.
Rather than call the fragment shader multiple times per sample,
this cl adds a potential loop in the fragment shader where each
sample is processes in one of the loop's iteration.

- Each multisample related loop now processes either all samples,
  like before, or the current sample, if per sample shading is
  enabled
- A new per sample PixelProgram::maskAny() function was added
- emitEpilog() now has an option not to clear phis in order to be
  able to only clear them on the last sample
- The routine's fragCoord values are set per sample, with the
  proper sample offsets
- Similarly, the xxxx and yyyy values used for interpolation are
  now offset with the proper sample offsets when per sample
  shading is enabled

Bug: b/171415086
Change-Id: Ibd0c1bad23e2d81f7fa97240ebb50f88f1fee36e
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/51733
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Alexis Hétu <sugoi@google.com>
Tested-by: Alexis Hétu <sugoi@google.com>
Kokoro-Result: kokoro <noreply+kokoro@google.com>
diff --git a/src/Pipeline/ComputeProgram.cpp b/src/Pipeline/ComputeProgram.cpp
index b1cdb77..1016663 100644
--- a/src/Pipeline/ComputeProgram.cpp
+++ b/src/Pipeline/ComputeProgram.cpp
@@ -57,6 +57,7 @@
 	shader->emitProlog(&routine);
 	emit(&routine);
 	shader->emitEpilog(&routine);
+	shader->clearPhis(&routine);
 }
 
 void ComputeProgram::setWorkgroupBuiltins(Pointer<Byte> data, SpirvRoutine *routine, Int workgroupID[3])
diff --git a/src/Pipeline/PixelProgram.cpp b/src/Pipeline/PixelProgram.cpp
index 0deffa9..5066996 100644
--- a/src/Pipeline/PixelProgram.cpp
+++ b/src/Pipeline/PixelProgram.cpp
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "PixelProgram.hpp"
+#include "Constants.hpp"
 
 #include "SamplerCore.hpp"
 #include "Device/Primitive.hpp"
@@ -56,14 +57,37 @@
 	return mask;
 }
 
-void PixelProgram::setBuiltins(Int &x, Int &y, Float4 (&z)[4], Float4 &w, Int cMask[4])
+Int4 PixelProgram::maskAny(Int cMask, Int sMask, Int zMask) const
+{
+	Int maskUnion = cMask & sMask & zMask;
+
+	// Convert to 4 booleans
+	Int4 laneBits = Int4(1, 2, 4, 8);
+	Int4 laneShiftsToMSB = Int4(31, 30, 29, 28);
+	Int4 mask(maskUnion);
+	mask = ((mask & laneBits) << laneShiftsToMSB) >> Int4(31);
+	return mask;
+}
+
+void PixelProgram::setBuiltins(Int &x, Int &y, Float4 (&z)[4], Float4 &w, Int cMask[4], int sampleId)
 {
 	routine.setImmutableInputBuiltins(spirvShader);
 
 	// TODO(b/146486064): Consider only assigning these to the SpirvRoutine iff
 	// they are ever going to be read.
-	routine.fragCoord[0] = SIMD::Float(Float(x)) + SIMD::Float(0.5f, 1.5f, 0.5f, 1.5f);
-	routine.fragCoord[1] = SIMD::Float(Float(y)) + SIMD::Float(0.5f, 0.5f, 1.5f, 1.5f);
+	float x0 = 0.5f;
+	float y0 = 0.5f;
+	float x1 = 1.5f;
+	float y1 = 1.5f;
+	if((state.multiSampleCount > 1) && (sampleId >= 0))
+	{
+		x0 = Constants::VkSampleLocations4[sampleId][0];
+		y0 = Constants::VkSampleLocations4[sampleId][1];
+		x1 = 1.0f + x0;
+		y1 = 1.0f + y0;
+	}
+	routine.fragCoord[0] = SIMD::Float(Float(x)) + SIMD::Float(x0, x1, x0, x1);
+	routine.fragCoord[1] = SIMD::Float(Float(y)) + SIMD::Float(y0, y0, y1, y1);
 	routine.fragCoord[2] = z[0];  // sample 0
 	routine.fragCoord[3] = w;
 
@@ -109,8 +133,11 @@
 	});
 }
 
-void PixelProgram::applyShader(Int cMask[4], Int sMask[4], Int zMask[4])
+void PixelProgram::applyShader(Int cMask[4], Int sMask[4], Int zMask[4], int sampleId)
 {
+	unsigned int sampleLoopInit = (sampleId >= 0) ? sampleId : 0;
+	unsigned int sampleLoopEnd = (sampleId >= 0) ? sampleId + 1 : state.multiSampleCount;
+
 	routine.descriptorSets = data + OFFSET(DrawData, descriptorSets);
 	routine.descriptorDynamicOffsets = data + OFFSET(DrawData, descriptorDynamicOffsets);
 	routine.pushConstants = data + OFFSET(DrawData, pushConstants);
@@ -130,8 +157,8 @@
 		static_assert(SIMD::Width == 4, "Expects SIMD width to be 4");
 		Int4 laneBits = Int4(1, 2, 4, 8);
 
-		Int4 inputSampleMask = Int4(1) & CmpNEQ(Int4(cMask[0]) & laneBits, Int4(0));
-		for(auto i = 1u; i < state.multiSampleCount; i++)
+		Int4 inputSampleMask = 0;
+		for(auto i = sampleLoopInit; i < sampleLoopEnd; i++)
 		{
 			inputSampleMask |= Int4(1 << i) & CmpNEQ(Int4(cMask[i]) & laneBits, Int4(0));
 		}
@@ -146,11 +173,15 @@
 	// Note: all lanes initially active to facilitate derivatives etc. Actual coverage is
 	// handled separately, through the cMask.
 	auto activeLaneMask = SIMD::Int(0xFFFFFFFF);
-	auto storesAndAtomicsMask = maskAny(cMask, sMask, zMask);
+	auto storesAndAtomicsMask = (sampleId >= 0) ? maskAny(cMask[sampleId], sMask[sampleId], zMask[sampleId]) : maskAny(cMask, sMask, zMask);
 	routine.killMask = 0;
 
 	spirvShader->emit(&routine, activeLaneMask, storesAndAtomicsMask, descriptorSets);
 	spirvShader->emitEpilog(&routine);
+	if((sampleId < 0) || (sampleId == static_cast<int>(state.multiSampleCount - 1)))
+	{
+		spirvShader->clearPhis(&routine);
+	}
 
 	for(int i = 0; i < RENDERTARGETS; i++)
 	{
@@ -168,7 +199,7 @@
 
 	if(spirvShader->getModes().ContainsKill)
 	{
-		for(auto i = 0u; i < state.multiSampleCount; i++)
+		for(auto i = sampleLoopInit; i < sampleLoopEnd; i++)
 		{
 			cMask[i] &= ~routine.killMask;
 		}
@@ -179,7 +210,7 @@
 	{
 		auto outputSampleMask = As<SIMD::Int>(routine.getVariable(it->second.Id)[it->second.FirstComponent]);
 
-		for(auto i = 0u; i < state.multiSampleCount; i++)
+		for(auto i = sampleLoopInit; i < sampleLoopEnd; i++)
 		{
 			cMask[i] &= SignMask(CmpNEQ(outputSampleMask & SIMD::Int(1 << i), SIMD::Int(0)));
 		}
@@ -192,14 +223,19 @@
 	}
 }
 
-Bool PixelProgram::alphaTest(Int cMask[4])
+Bool PixelProgram::alphaTest(Int cMask[4], int sampleId)
 {
 	if(!state.alphaToCoverage)
 	{
 		return true;
 	}
 
-	alphaToCoverage(cMask, c[0].w);
+	alphaToCoverage(cMask, c[0].w, sampleId);
+
+	if(sampleId >= 0)
+	{
+		return cMask[sampleId] != 0x0;
+	}
 
 	Int pass = cMask[0];
 
@@ -211,8 +247,11 @@
 	return pass != 0x0;
 }
 
-void PixelProgram::rasterOperation(Pointer<Byte> cBuffer[4], Int &x, Int sMask[4], Int zMask[4], Int cMask[4])
+void PixelProgram::rasterOperation(Pointer<Byte> cBuffer[4], Int &x, Int sMask[4], Int zMask[4], Int cMask[4], int sampleId)
 {
+	unsigned int sampleLoopInit = (sampleId >= 0) ? sampleId : 0;
+	unsigned int sampleLoopEnd = (sampleId >= 0) ? sampleId + 1 : state.multiSampleCount;
+
 	for(int index = 0; index < RENDERTARGETS; index++)
 	{
 		if(!state.colorWriteActive(index))
@@ -237,7 +276,7 @@
 			case VK_FORMAT_A8B8G8R8_SRGB_PACK32:
 			case VK_FORMAT_A2B10G10R10_UNORM_PACK32:
 			case VK_FORMAT_A2R10G10B10_UNORM_PACK32:
-				for(unsigned int q = 0; q < state.multiSampleCount; q++)
+				for(unsigned int q = sampleLoopInit; q < sampleLoopEnd; q++)
 				{
 					if(state.multiSampleMask & (1 << q))
 					{
@@ -283,7 +322,7 @@
 			case VK_FORMAT_A8B8G8R8_SINT_PACK32:
 			case VK_FORMAT_A2B10G10R10_UINT_PACK32:
 			case VK_FORMAT_A2R10G10B10_UINT_PACK32:
-				for(unsigned int q = 0; q < state.multiSampleCount; q++)
+				for(unsigned int q = sampleLoopInit; q < sampleLoopEnd; q++)
 				{
 					if(state.multiSampleMask & (1 << q))
 					{
diff --git a/src/Pipeline/PixelProgram.hpp b/src/Pipeline/PixelProgram.hpp
index 40f8b93..d2d2301 100644
--- a/src/Pipeline/PixelProgram.hpp
+++ b/src/Pipeline/PixelProgram.hpp
@@ -34,10 +34,10 @@
 	virtual ~PixelProgram() {}
 
 protected:
-	virtual void setBuiltins(Int &x, Int &y, Float4 (&z)[4], Float4 &w, Int cMask[4]);
-	virtual void applyShader(Int cMask[4], Int sMask[4], Int zMask[4]);
-	virtual Bool alphaTest(Int cMask[4]);
-	virtual void rasterOperation(Pointer<Byte> cBuffer[4], Int &x, Int sMask[4], Int zMask[4], Int cMask[4]);
+	virtual void setBuiltins(Int &x, Int &y, Float4 (&z)[4], Float4 &w, Int cMask[4], int sampleId);
+	virtual void applyShader(Int cMask[4], Int sMask[4], Int zMask[4], int sampleId);
+	virtual Bool alphaTest(Int cMask[4], int sampleId);
+	virtual void rasterOperation(Pointer<Byte> cBuffer[4], Int &x, Int sMask[4], Int zMask[4], Int cMask[4], int sampleId);
 
 private:
 	// Color outputs
@@ -48,6 +48,7 @@
 
 	Int4 maskAny(Int cMask[4]) const;
 	Int4 maskAny(Int cMask[4], Int sMask[4], Int zMask[4]) const;
+	Int4 maskAny(Int cMask, Int sMask, Int zMask) const;
 };
 
 }  // namespace sw
diff --git a/src/Pipeline/PixelRoutine.cpp b/src/Pipeline/PixelRoutine.cpp
index 0653d7c..5a53a82 100644
--- a/src/Pipeline/PixelRoutine.cpp
+++ b/src/Pipeline/PixelRoutine.cpp
@@ -62,222 +62,246 @@
 	Int zMask[4];  // Depth mask
 	Int sMask[4];  // Stencil mask
 
-	for(unsigned int q = 0; q < state.multiSampleCount; q++)
+	bool perSampleShading = (state.sampleShadingEnabled && (state.minSampleShading > 0.0f)) ||
+	                        (spirvShader && spirvShader->getModes().ContainsSampleQualifier);
+	unsigned int numSampleRenders = perSampleShading ? state.multiSampleCount : 1;
+
+	for(unsigned int i = 0; i < numSampleRenders; ++i)
 	{
-		zMask[q] = cMask[q];
-		sMask[q] = cMask[q];
-	}
+		int sampleId = perSampleShading ? i : -1;
+		unsigned int sampleLoopInit = perSampleShading ? sampleId : 0;
+		unsigned int sampleLoopEnd = perSampleShading ? sampleId + 1 : state.multiSampleCount;
 
-	for(unsigned int q = 0; q < state.multiSampleCount; q++)
-	{
-		stencilTest(sBuffer, q, x, sMask[q], cMask[q]);
-	}
-
-	Float4 f;
-	Float4 rhwCentroid;
-
-	Float4 xxxx = Float4(Float(x)) + *Pointer<Float4>(primitive + OFFSET(Primitive, xQuad), 16);
-
-	if(interpolateZ())
-	{
-		for(unsigned int q = 0; q < state.multiSampleCount; q++)
+		for(unsigned int q = sampleLoopInit; q < sampleLoopEnd; q++)
 		{
-			Float4 x = xxxx;
-
-			if(state.enableMultiSampling)
-			{
-				x += *Pointer<Float4>(constants + OFFSET(Constants, X) + q * sizeof(float4));
-			}
-
-			z[q] = interpolate(x, Dz[q], z[q], primitive + OFFSET(Primitive, z), false, false);
-
-			if(state.depthBias)
-			{
-				z[q] += *Pointer<Float4>(primitive + OFFSET(Primitive, zBias), 16);
-			}
-
-			if(state.depthClamp)
-			{
-				z[q] = Min(Max(z[q], Float4(0.0f)), Float4(1.0f));
-			}
-		}
-	}
-
-	Bool depthPass = false;
-
-	if(earlyDepthTest)
-	{
-		for(unsigned int q = 0; q < state.multiSampleCount; q++)
-		{
-			depthPass = depthPass || depthTest(zBuffer, q, x, z[q], sMask[q], zMask[q], cMask[q]);
-		}
-	}
-
-	If(depthPass || Bool(!earlyDepthTest))
-	{
-		Float4 yyyy = Float4(Float(y)) + *Pointer<Float4>(primitive + OFFSET(Primitive, yQuad), 16);
-
-		// Centroid locations
-		Float4 XXXX = Float4(0.0f);
-		Float4 YYYY = Float4(0.0f);
-
-		if(state.centroid)
-		{
-			Float4 WWWW(1.0e-9f);
-
-			for(unsigned int q = 0; q < state.multiSampleCount; q++)
-			{
-				XXXX += *Pointer<Float4>(constants + OFFSET(Constants, sampleX[q]) + 16 * cMask[q]);
-				YYYY += *Pointer<Float4>(constants + OFFSET(Constants, sampleY[q]) + 16 * cMask[q]);
-				WWWW += *Pointer<Float4>(constants + OFFSET(Constants, weight) + 16 * cMask[q]);
-			}
-
-			WWWW = Rcp(WWWW, Precision::Relaxed);
-			XXXX *= WWWW;
-			YYYY *= WWWW;
-
-			XXXX += xxxx;
-			YYYY += yyyy;
+			zMask[q] = cMask[q];
+			sMask[q] = cMask[q];
 		}
 
-		if(interpolateW())
+		for(unsigned int q = sampleLoopInit; q < sampleLoopEnd; q++)
 		{
-			w = interpolate(xxxx, Dw, rhw, primitive + OFFSET(Primitive, w), false, false);
-			rhw = reciprocal(w, false, false, true);
+			stencilTest(sBuffer, q, x, sMask[q], cMask[q]);
+		}
+
+		Float4 f;
+		Float4 rhwCentroid;
+
+		Float4 xxxx = Float4(Float(x)) + *Pointer<Float4>(primitive + OFFSET(Primitive, xQuad), 16);
+
+		if(interpolateZ())
+		{
+			for(unsigned int q = sampleLoopInit; q < sampleLoopEnd; q++)
+			{
+				Float4 x = xxxx;
+
+				if(state.enableMultiSampling)
+				{
+					x -= *Pointer<Float4>(constants + OFFSET(Constants, X) + q * sizeof(float4));
+				}
+
+				z[q] = interpolate(x, Dz[q], z[q], primitive + OFFSET(Primitive, z), false, false);
+
+				if(state.depthBias)
+				{
+					z[q] += *Pointer<Float4>(primitive + OFFSET(Primitive, zBias), 16);
+				}
+
+				if(state.depthClamp)
+				{
+					z[q] = Min(Max(z[q], Float4(0.0f)), Float4(1.0f));
+				}
+			}
+		}
+
+		Bool depthPass = false;
+
+		if(earlyDepthTest)
+		{
+			for(unsigned int q = sampleLoopInit; q < sampleLoopEnd; q++)
+			{
+				depthPass = depthPass || depthTest(zBuffer, q, x, z[q], sMask[q], zMask[q], cMask[q]);
+			}
+		}
+
+		If(depthPass || Bool(!earlyDepthTest))
+		{
+			Float4 yyyy = Float4(Float(y)) + *Pointer<Float4>(primitive + OFFSET(Primitive, yQuad), 16);
+
+			// Centroid locations
+			Float4 XXXX = Float4(0.0f);
+			Float4 YYYY = Float4(0.0f);
 
 			if(state.centroid)
 			{
-				rhwCentroid = reciprocal(SpirvRoutine::interpolateAtXY(XXXX, YYYY, rhwCentroid, primitive + OFFSET(Primitive, w), false, false));
-			}
-		}
+				Float4 WWWW(1.0e-9f);
 
-		if(spirvShader)
-		{
-			for(int interpolant = 0; interpolant < MAX_INTERFACE_COMPONENTS; interpolant++)
-			{
-				auto const &input = spirvShader->inputs[interpolant];
-				if(input.Type != SpirvShader::ATTRIBTYPE_UNUSED)
+				for(unsigned int q = sampleLoopInit; q < sampleLoopEnd; q++)
 				{
-					if(input.Centroid && state.enableMultiSampling)
-					{
-						routine.inputs[interpolant] =
-						    SpirvRoutine::interpolateAtXY(XXXX, YYYY, rhwCentroid,
-						                                  primitive + OFFSET(Primitive, V[interpolant]),
-						                                  input.Flat, !input.NoPerspective);
-					}
-					else
-					{
-						routine.inputs[interpolant] =
-						    interpolate(xxxx, Dv[interpolant], rhw,
-						                primitive + OFFSET(Primitive, V[interpolant]),
-						                input.Flat, !input.NoPerspective);
-					}
+					XXXX += *Pointer<Float4>(constants + OFFSET(Constants, sampleX[q]) + 16 * cMask[q]);
+					YYYY += *Pointer<Float4>(constants + OFFSET(Constants, sampleY[q]) + 16 * cMask[q]);
+					WWWW += *Pointer<Float4>(constants + OFFSET(Constants, weight) + 16 * cMask[q]);
+				}
+
+				WWWW = Rcp(WWWW, Precision::Relaxed);
+				XXXX *= WWWW;
+				YYYY *= WWWW;
+
+				XXXX += xxxx;
+				YYYY += yyyy;
+			}
+
+			if(interpolateW())
+			{
+				w = interpolate(xxxx, Dw, rhw, primitive + OFFSET(Primitive, w), false, false);
+				rhw = reciprocal(w, false, false, true);
+
+				if(state.centroid)
+				{
+					rhwCentroid = reciprocal(SpirvRoutine::interpolateAtXY(XXXX, YYYY, rhwCentroid, primitive + OFFSET(Primitive, w), false, false));
 				}
 			}
 
-			setBuiltins(x, y, z, w, cMask);
-
-			for(uint32_t i = 0; i < state.numClipDistances; i++)
+			if(spirvShader)
 			{
-				auto distance = interpolate(xxxx, DclipDistance[i], rhw,
-				                            primitive + OFFSET(Primitive, clipDistance[i]),
-				                            false, true);
-
-				auto clipMask = SignMask(CmpGE(distance, SIMD::Float(0)));
-				for(auto ms = 0u; ms < state.multiSampleCount; ms++)
+				if(perSampleShading && (state.multiSampleCount > 1))
 				{
-					// FIXME(b/148105887): Fragments discarded by clipping do not exist at
-					// all -- they should not be counted in queries or have their Z/S effects
-					// performed when early fragment tests are enabled.
-					cMask[ms] &= clipMask;
+					xxxx += Float4(Constants::SampleLocationsX[sampleId]);
+					yyyy += Float4(Constants::SampleLocationsY[sampleId]);
 				}
 
-				if(spirvShader->getUsedCapabilities().ClipDistance)
+				for(int interpolant = 0; interpolant < MAX_INTERFACE_COMPONENTS; interpolant++)
 				{
-					auto it = spirvShader->inputBuiltins.find(spv::BuiltInClipDistance);
+					auto const &input = spirvShader->inputs[interpolant];
+					if(input.Type != SpirvShader::ATTRIBTYPE_UNUSED)
+					{
+						if(input.Centroid && state.enableMultiSampling)
+						{
+							routine.inputs[interpolant] =
+							    SpirvRoutine::interpolateAtXY(XXXX, YYYY, rhwCentroid,
+							                                  primitive + OFFSET(Primitive, V[interpolant]),
+							                                  input.Flat, !input.NoPerspective);
+						}
+						else if(perSampleShading)
+						{
+							routine.inputs[interpolant] =
+							    SpirvRoutine::interpolateAtXY(xxxx, yyyy, rhw,
+							                                  primitive + OFFSET(Primitive, V[interpolant]),
+							                                  input.Flat, !input.NoPerspective);
+						}
+						else
+						{
+							routine.inputs[interpolant] =
+							    interpolate(xxxx, Dv[interpolant], rhw,
+							                primitive + OFFSET(Primitive, V[interpolant]),
+							                input.Flat, !input.NoPerspective);
+						}
+					}
+				}
+
+				setBuiltins(x, y, z, w, cMask, sampleId);
+
+				for(uint32_t i = 0; i < state.numClipDistances; i++)
+				{
+					auto distance = interpolate(xxxx, DclipDistance[i], rhw,
+					                            primitive + OFFSET(Primitive, clipDistance[i]),
+					                            false, true);
+
+					auto clipMask = SignMask(CmpGE(distance, SIMD::Float(0)));
+					for(auto ms = sampleLoopInit; ms < sampleLoopEnd; ms++)
+					{
+						// FIXME(b/148105887): Fragments discarded by clipping do not exist at
+						// all -- they should not be counted in queries or have their Z/S effects
+						// performed when early fragment tests are enabled.
+						cMask[ms] &= clipMask;
+					}
+
+					if(spirvShader->getUsedCapabilities().ClipDistance)
+					{
+						auto it = spirvShader->inputBuiltins.find(spv::BuiltInClipDistance);
+						if(it != spirvShader->inputBuiltins.end())
+						{
+							if(i < it->second.SizeInComponents)
+							{
+								routine.getVariable(it->second.Id)[it->second.FirstComponent + i] = distance;
+							}
+						}
+					}
+				}
+
+				if(spirvShader->getUsedCapabilities().CullDistance)
+				{
+					auto it = spirvShader->inputBuiltins.find(spv::BuiltInCullDistance);
 					if(it != spirvShader->inputBuiltins.end())
 					{
-						if(i < it->second.SizeInComponents)
+						for(uint32_t i = 0; i < state.numCullDistances; i++)
 						{
-							routine.getVariable(it->second.Id)[it->second.FirstComponent + i] = distance;
+							if(i < it->second.SizeInComponents)
+							{
+								routine.getVariable(it->second.Id)[it->second.FirstComponent + i] =
+								    interpolate(xxxx, DcullDistance[i], rhw,
+								                primitive + OFFSET(Primitive, cullDistance[i]),
+								                false, true);
+							}
 						}
 					}
 				}
 			}
 
-			if(spirvShader->getUsedCapabilities().CullDistance)
+			Bool alphaPass = true;
+
+			if(spirvShader)
 			{
-				auto it = spirvShader->inputBuiltins.find(spv::BuiltInCullDistance);
-				if(it != spirvShader->inputBuiltins.end())
+				bool earlyFragTests = (spirvShader && spirvShader->getModes().EarlyFragmentTests);
+				applyShader(cMask, earlyFragTests ? sMask : cMask, earlyDepthTest ? zMask : cMask, sampleId);
+			}
+
+			alphaPass = alphaTest(cMask, sampleId);
+
+			if((spirvShader && spirvShader->getModes().ContainsKill) || state.alphaToCoverage)
+			{
+				for(unsigned int q = sampleLoopInit; q < sampleLoopEnd; q++)
 				{
-					for(uint32_t i = 0; i < state.numCullDistances; i++)
+					zMask[q] &= cMask[q];
+					sMask[q] &= cMask[q];
+				}
+			}
+
+			If(alphaPass)
+			{
+				if(!earlyDepthTest)
+				{
+					for(unsigned int q = sampleLoopInit; q < sampleLoopEnd; q++)
 					{
-						if(i < it->second.SizeInComponents)
-						{
-							routine.getVariable(it->second.Id)[it->second.FirstComponent + i] =
-							    interpolate(xxxx, DcullDistance[i], rhw,
-							                primitive + OFFSET(Primitive, cullDistance[i]),
-							                false, true);
-						}
+						depthPass = depthPass || depthTest(zBuffer, q, x, z[q], sMask[q], zMask[q], cMask[q]);
 					}
 				}
-			}
-		}
 
-		Bool alphaPass = true;
-
-		if(spirvShader)
-		{
-			bool earlyFragTests = (spirvShader && spirvShader->getModes().EarlyFragmentTests);
-			applyShader(cMask, earlyFragTests ? sMask : cMask, earlyDepthTest ? zMask : cMask);
-		}
-
-		alphaPass = alphaTest(cMask);
-
-		if((spirvShader && spirvShader->getModes().ContainsKill) || state.alphaToCoverage)
-		{
-			for(unsigned int q = 0; q < state.multiSampleCount; q++)
-			{
-				zMask[q] &= cMask[q];
-				sMask[q] &= cMask[q];
-			}
-		}
-
-		If(alphaPass)
-		{
-			if(!earlyDepthTest)
-			{
-				for(unsigned int q = 0; q < state.multiSampleCount; q++)
+				If(depthPass || Bool(earlyDepthTest))
 				{
-					depthPass = depthPass || depthTest(zBuffer, q, x, z[q], sMask[q], zMask[q], cMask[q]);
-				}
-			}
-
-			If(depthPass || Bool(earlyDepthTest))
-			{
-				for(unsigned int q = 0; q < state.multiSampleCount; q++)
-				{
-					if(state.multiSampleMask & (1 << q))
+					for(unsigned int q = sampleLoopInit; q < sampleLoopEnd; q++)
 					{
-						writeDepth(zBuffer, q, x, z[q], zMask[q]);
-
-						if(state.occlusionEnabled)
+						if(state.multiSampleMask & (1 << q))
 						{
-							occlusion += *Pointer<UInt>(constants + OFFSET(Constants, occlusionCount) + 4 * (zMask[q] & sMask[q]));
+							writeDepth(zBuffer, q, x, z[q], zMask[q]);
+
+							if(state.occlusionEnabled)
+							{
+								occlusion += *Pointer<UInt>(constants + OFFSET(Constants, occlusionCount) + 4 * (zMask[q] & sMask[q]));
+							}
 						}
 					}
-				}
 
-				rasterOperation(cBuffer, x, sMask, zMask, cMask);
+					rasterOperation(cBuffer, x, sMask, zMask, cMask, sampleId);
+				}
 			}
 		}
-	}
 
-	for(unsigned int q = 0; q < state.multiSampleCount; q++)
-	{
-		if(state.multiSampleMask & (1 << q))
+		for(unsigned int q = sampleLoopInit; q < sampleLoopEnd; q++)
 		{
-			writeStencil(sBuffer, q, x, sMask[q], zMask[q], cMask[q]);
+			if(state.multiSampleMask & (1 << q))
+			{
+				writeStencil(sBuffer, q, x, sMask[q], zMask[q], cMask[q]);
+			}
 		}
 	}
 }
@@ -546,22 +570,24 @@
 	}
 }
 
-void PixelRoutine::alphaToCoverage(Int cMask[4], const Float4 &alpha)
+void PixelRoutine::alphaToCoverage(Int cMask[4], const Float4 &alpha, int sampleId)
 {
-	Int4 coverage0 = CmpNLT(alpha, *Pointer<Float4>(data + OFFSET(DrawData, a2c0)));
-	Int4 coverage1 = CmpNLT(alpha, *Pointer<Float4>(data + OFFSET(DrawData, a2c1)));
-	Int4 coverage2 = CmpNLT(alpha, *Pointer<Float4>(data + OFFSET(DrawData, a2c2)));
-	Int4 coverage3 = CmpNLT(alpha, *Pointer<Float4>(data + OFFSET(DrawData, a2c3)));
+	static const int a2c[4] = {
+		OFFSET(DrawData, a2c0),
+		OFFSET(DrawData, a2c1),
+		OFFSET(DrawData, a2c2),
+		OFFSET(DrawData, a2c3),
+	};
 
-	Int aMask0 = SignMask(coverage0);
-	Int aMask1 = SignMask(coverage1);
-	Int aMask2 = SignMask(coverage2);
-	Int aMask3 = SignMask(coverage3);
+	unsigned int sampleLoopInit = (sampleId >= 0) ? sampleId : 0;
+	unsigned int sampleLoopEnd = (sampleId >= 0) ? sampleId + 1 : state.multiSampleCount;
 
-	cMask[0] &= aMask0;
-	cMask[1] &= aMask1;
-	cMask[2] &= aMask2;
-	cMask[3] &= aMask3;
+	for(unsigned int q = sampleLoopInit; q < sampleLoopEnd; q++)
+	{
+		Int4 coverage = CmpNLT(alpha, *Pointer<Float4>(data + a2c[q]));
+		Int aMask = SignMask(coverage);
+		cMask[q] &= aMask;
+	}
 }
 
 void PixelRoutine::writeDepth32F(Pointer<Byte> &zBuffer, int q, const Int &x, const Float4 &z, const Int &zMask)
diff --git a/src/Pipeline/PixelRoutine.hpp b/src/Pipeline/PixelRoutine.hpp
index e4dc029..46e0362 100644
--- a/src/Pipeline/PixelRoutine.hpp
+++ b/src/Pipeline/PixelRoutine.hpp
@@ -45,15 +45,15 @@
 	// Depth output
 	Float4 oDepth;
 
-	virtual void setBuiltins(Int &x, Int &y, Float4 (&z)[4], Float4 &w, Int cMask[4]) = 0;
-	virtual void applyShader(Int cMask[4], Int sMask[4], Int zMask[4]) = 0;
-	virtual Bool alphaTest(Int cMask[4]) = 0;
-	virtual void rasterOperation(Pointer<Byte> cBuffer[4], Int &x, Int sMask[4], Int zMask[4], Int cMask[4]) = 0;
+	virtual void setBuiltins(Int &x, Int &y, Float4 (&z)[4], Float4 &w, Int cMask[4], int sampleId) = 0;
+	virtual void applyShader(Int cMask[4], Int sMask[4], Int zMask[4], int sampleId) = 0;
+	virtual Bool alphaTest(Int cMask[4], int sampleId) = 0;
+	virtual void rasterOperation(Pointer<Byte> cBuffer[4], Int &x, Int sMask[4], Int zMask[4], Int cMask[4], int sampleId) = 0;
 
 	void quad(Pointer<Byte> cBuffer[4], Pointer<Byte> &zBuffer, Pointer<Byte> &sBuffer, Int cMask[4], Int &x, Int &y) override;
 
 	void alphaTest(Int &aMask, const Short4 &alpha);
-	void alphaToCoverage(Int cMask[4], const Float4 &alpha);
+	void alphaToCoverage(Int cMask[4], const Float4 &alpha, int sampleId);
 
 	// Raster operations
 	void alphaBlend(int index, const Pointer<Byte> &cBuffer, Vector4s &current, const Int &x);
diff --git a/src/Pipeline/SpirvShader.cpp b/src/Pipeline/SpirvShader.cpp
index 8f7a6fb..c5c11dd 100644
--- a/src/Pipeline/SpirvShader.cpp
+++ b/src/Pipeline/SpirvShader.cpp
@@ -2471,7 +2471,10 @@
 				break;
 		}
 	}
+}
 
+void SpirvShader::clearPhis(SpirvRoutine *routine) const
+{
 	// Clear phis that are no longer used. This serves two purposes:
 	// (1) The phi rr::Variables are destructed, preventing pointless
 	//     materialization.
diff --git a/src/Pipeline/SpirvShader.hpp b/src/Pipeline/SpirvShader.hpp
index 8631351..da40075 100644
--- a/src/Pipeline/SpirvShader.hpp
+++ b/src/Pipeline/SpirvShader.hpp
@@ -784,6 +784,7 @@
 	void emitProlog(SpirvRoutine *routine) const;
 	void emit(SpirvRoutine *routine, RValue<SIMD::Int> const &activeLaneMask, RValue<SIMD::Int> const &storesAndAtomicsMask, const vk::DescriptorSet::Bindings &descriptorSets) const;
 	void emitEpilog(SpirvRoutine *routine) const;
+	void clearPhis(SpirvRoutine *routine) const;
 
 	bool containsImageWrite() const { return imageWriteEmitted; }
 
diff --git a/src/Pipeline/VertexProgram.cpp b/src/Pipeline/VertexProgram.cpp
index f0d2979..be9d916 100644
--- a/src/Pipeline/VertexProgram.cpp
+++ b/src/Pipeline/VertexProgram.cpp
@@ -83,6 +83,7 @@
 	spirvShader->emit(&routine, activeLaneMask, storesAndAtomicsMask, descriptorSets);
 
 	spirvShader->emitEpilog(&routine);
+	spirvShader->clearPhis(&routine);
 }
 
 }  // namespace sw