Only process active fragment samples

This change refactors the handling of which samples we should iterate
over, by putting their indices into a container object. This enables the
use of range-based for loops to have a more elegant syntax and avoid the
confusion around sampleId -1 having special meaning and the begin and
end iteration markers still requiring checking the sample mask on each
loop.

Bug: b/194521425
Change-Id: Ib6fbbb3e89c3a5501311ebd81859608df44d1bd0
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/56008
Kokoro-Result: kokoro <noreply+kokoro@google.com>
Reviewed-by: Sean Risser <srisser@google.com>
Reviewed-by: Nicolas Capens <nicolascapens@google.com>
Tested-by: Nicolas Capens <nicolascapens@google.com>
diff --git a/src/Device/QuadRasterizer.cpp b/src/Device/QuadRasterizer.cpp
index c56d4da..6c5397c 100644
--- a/src/Device/QuadRasterizer.cpp
+++ b/src/Device/QuadRasterizer.cpp
@@ -198,10 +198,6 @@
 						Short4 mask = CmpGT(xxxx, xLeft[i]) & CmpGT(xRight[i], xxxx);
 						cMask[q] = SignMask(PackSigned(mask, mask)) & 0x0000000F;
 					}
-					else
-					{
-						cMask[q] = 0;
-					}
 				}
 
 				quad(cBuffer, zBuffer, sBuffer, cMask, x, y);
diff --git a/src/Pipeline/PixelProgram.cpp b/src/Pipeline/PixelProgram.cpp
index cf815a4..809a03c 100644
--- a/src/Pipeline/PixelProgram.cpp
+++ b/src/Pipeline/PixelProgram.cpp
@@ -13,22 +13,31 @@
 // limitations under the License.
 
 #include "PixelProgram.hpp"
-#include "Constants.hpp"
 
+#include "Constants.hpp"
 #include "SamplerCore.hpp"
 #include "Device/Primitive.hpp"
 #include "Device/Renderer.hpp"
 
 namespace sw {
 
+PixelProgram::PixelProgram(
+    const PixelProcessor::State &state,
+    const vk::PipelineLayout *pipelineLayout,
+    const SpirvShader *spirvShader,
+    const vk::DescriptorSet::Bindings &descriptorSets)
+    : PixelRoutine(state, pipelineLayout, spirvShader, descriptorSets)
+{
+}
+
 // Union all cMask and return it as 4 booleans
-Int4 PixelProgram::maskAny(Int cMask[4]) const
+Int4 PixelProgram::maskAny(Int cMask[4], const SampleSet &samples)
 {
 	// See if at least 1 sample is used
-	Int maskUnion = cMask[0];
-	for(auto i = 1u; i < state.multiSampleCount; i++)
+	Int maskUnion = 0;
+	for(unsigned int q : samples)
 	{
-		maskUnion |= cMask[i];
+		maskUnion |= cMask[q];
 	}
 
 	// Convert to 4 booleans
@@ -40,13 +49,13 @@
 }
 
 // Union all cMask/sMask/zMask and return it as 4 booleans
-Int4 PixelProgram::maskAny(Int cMask[4], Int sMask[4], Int zMask[4]) const
+Int4 PixelProgram::maskAny(Int cMask[4], Int sMask[4], Int zMask[4], const SampleSet &samples)
 {
 	// See if at least 1 sample is used
-	Int maskUnion = cMask[0] & sMask[0] & zMask[0];
-	for(auto i = 1u; i < state.multiSampleCount; i++)
+	Int maskUnion = 0;
+	for(unsigned int q : samples)
 	{
-		maskUnion |= (cMask[i] & sMask[i] & zMask[i]);
+		maskUnion |= (cMask[q] & sMask[q] & zMask[q]);
 	}
 
 	// Convert to 4 booleans
@@ -57,19 +66,7 @@
 	return mask;
 }
 
-Int4 PixelProgram::maskAny(Int cMask, Int sMask, Int zMask) const
-{
-	Int maskUnion = cMask & sMask & zMask;
-
-	// Convert to 4 booleans
-	Int4 laneBits = Int4(1, 2, 4, 8);
-	Int4 laneShiftsToMSB = Int4(31, 30, 29, 28);
-	Int4 mask(maskUnion);
-	mask = ((mask & laneBits) << laneShiftsToMSB) >> Int4(31);
-	return mask;
-}
-
-void PixelProgram::setBuiltins(Int &x, Int &y, Float4 (&z)[4], Float4 &w, Int cMask[4], int sampleId)
+void PixelProgram::setBuiltins(Int &x, Int &y, Float4 (&z)[4], Float4 &w, Int cMask[4], const SampleSet &samples)
 {
 	routine.setImmutableInputBuiltins(spirvShader);
 
@@ -79,20 +76,25 @@
 	float y0 = 0.5f;
 	float x1 = 1.5f;
 	float y1 = 1.5f;
-	if((state.multiSampleCount > 1) && (sampleId >= 0))
+
+	// "When Sample Shading is enabled, the x and y components of FragCoord reflect the
+	//  location of one of the samples corresponding to the shader invocation. Otherwise,
+	//  the x and y components of FragCoord reflect the location of the center of the fragment."
+	if(state.sampleShadingEnabled && state.multiSampleCount > 1)
 	{
-		x0 = Constants::VkSampleLocations4[sampleId][0];
-		y0 = Constants::VkSampleLocations4[sampleId][1];
+		x0 = Constants::VkSampleLocations4[samples[0]][0];
+		y0 = Constants::VkSampleLocations4[samples[0]][1];
 		x1 = 1.0f + x0;
 		y1 = 1.0f + y0;
 	}
+
 	routine.fragCoord[0] = SIMD::Float(Float(x)) + SIMD::Float(x0, x1, x0, x1);
 	routine.fragCoord[1] = SIMD::Float(Float(y)) + SIMD::Float(y0, y0, y1, y1);
 	routine.fragCoord[2] = z[0];  // sample 0
 	routine.fragCoord[3] = w;
 
 	routine.invocationsPerSubgroup = SIMD::Width;
-	routine.helperInvocation = ~maskAny(cMask);
+	routine.helperInvocation = ~maskAny(cMask, samples);
 	routine.windowSpacePosition[0] = x + SIMD::Int(0, 1, 0, 1);
 	routine.windowSpacePosition[1] = y + SIMD::Int(0, 0, 1, 1);
 	routine.viewID = *Pointer<Int>(data + OFFSET(DrawData, viewID));
@@ -133,11 +135,8 @@
 	});
 }
 
-void PixelProgram::applyShader(Int cMask[4], Int sMask[4], Int zMask[4], int sampleId)
+void PixelProgram::executeShader(Int cMask[4], Int sMask[4], Int zMask[4], const SampleSet &samples)
 {
-	unsigned int sampleLoopInit = (sampleId >= 0) ? sampleId : 0;
-	unsigned int sampleLoopEnd = (sampleId >= 0) ? sampleId + 1 : state.multiSampleCount;
-
 	routine.descriptorSets = data + OFFSET(DrawData, descriptorSets);
 	routine.descriptorDynamicOffsets = data + OFFSET(DrawData, descriptorDynamicOffsets);
 	routine.pushConstants = data + OFFSET(DrawData, pushConstants);
@@ -158,43 +157,51 @@
 		Int4 laneBits = Int4(1, 2, 4, 8);
 
 		Int4 inputSampleMask = 0;
-		for(auto i = sampleLoopInit; i < sampleLoopEnd; i++)
+		for(unsigned int q : samples)
 		{
-			inputSampleMask |= Int4(1 << i) & CmpNEQ(Int4(cMask[i]) & laneBits, Int4(0));
+			inputSampleMask |= Int4(1 << q) & CmpNEQ(Int4(cMask[q]) & laneBits, Int4(0));
 		}
 
 		routine.getVariable(it->second.Id)[it->second.FirstComponent] = As<Float4>(inputSampleMask);
 		// Sample mask input is an array, as the spec contemplates MSAA levels higher than 32.
 		// Fill any non-zero indices with 0.
 		for(auto i = 1u; i < it->second.SizeInComponents; i++)
+		{
 			routine.getVariable(it->second.Id)[it->second.FirstComponent + i] = Float4(0);
+		}
 	}
 
 	it = spirvShader->inputBuiltins.find(spv::BuiltInSampleId);
 	if(it != spirvShader->inputBuiltins.end())
 	{
+		ASSERT(samples.size() == 1);
+		int sampleId = samples[0];
 		routine.getVariable(it->second.Id)[it->second.FirstComponent] =
-		    As<SIMD::Float>(SIMD::Int((sampleId >= 0) ? sampleId : 0));
+		    As<SIMD::Float>(SIMD::Int(sampleId));
 	}
 
 	it = spirvShader->inputBuiltins.find(spv::BuiltInSamplePosition);
 	if(it != spirvShader->inputBuiltins.end())
 	{
+		ASSERT(samples.size() == 1);
+		int sampleId = samples[0];
 		routine.getVariable(it->second.Id)[it->second.FirstComponent + 0] =
-		    SIMD::Float(((sampleId >= 0) && (state.multiSampleCount > 1)) ? Constants::VkSampleLocations4[sampleId][0] : 0.5f);
+		    SIMD::Float((state.multiSampleCount > 1) ? Constants::VkSampleLocations4[sampleId][0] : 0.5f);
 		routine.getVariable(it->second.Id)[it->second.FirstComponent + 1] =
-		    SIMD::Float(((sampleId >= 0) && (state.multiSampleCount > 1)) ? Constants::VkSampleLocations4[sampleId][1] : 0.5f);
+		    SIMD::Float((state.multiSampleCount > 1) ? Constants::VkSampleLocations4[sampleId][1] : 0.5f);
 	}
 
 	// Note: all lanes initially active to facilitate derivatives etc. Actual coverage is
 	// handled separately, through the cMask.
 	auto activeLaneMask = SIMD::Int(0xFFFFFFFF);
-	auto storesAndAtomicsMask = (sampleId >= 0) ? maskAny(cMask[sampleId], sMask[sampleId], zMask[sampleId]) : maskAny(cMask, sMask, zMask);
+	auto storesAndAtomicsMask = maskAny(cMask, sMask, zMask, samples);
 	routine.killMask = 0;
 
 	spirvShader->emit(&routine, activeLaneMask, storesAndAtomicsMask, descriptorSets, state.multiSampleCount);
 	spirvShader->emitEpilog(&routine);
-	if((sampleId < 0) || (sampleId == static_cast<int>(state.multiSampleCount - 1)))
+	// At the last invocation of the fragment shader, clear phi data.
+	// TODO(b/178662288): Automatically clear phis through SpirvRoutine lifetime reduction.
+	if(samples[0] == static_cast<int>(state.multiSampleCount - 1))
 	{
 		spirvShader->clearPhis(&routine);
 	}
@@ -215,9 +222,9 @@
 
 	if(spirvShader->getModes().ContainsKill)
 	{
-		for(auto i = sampleLoopInit; i < sampleLoopEnd; i++)
+		for(unsigned int q : samples)
 		{
-			cMask[i] &= ~routine.killMask;
+			cMask[q] &= ~routine.killMask;
 		}
 	}
 
@@ -226,9 +233,9 @@
 	{
 		auto outputSampleMask = As<SIMD::Int>(routine.getVariable(it->second.Id)[it->second.FirstComponent]);
 
-		for(auto i = sampleLoopInit; i < sampleLoopEnd; i++)
+		for(unsigned int q : samples)
 		{
-			cMask[i] &= SignMask(CmpNEQ(outputSampleMask & SIMD::Int(1 << i), SIMD::Int(0)));
+			cMask[q] &= SignMask(CmpNEQ(outputSampleMask & SIMD::Int(1 << q), SIMD::Int(0)));
 		}
 	}
 
@@ -243,23 +250,17 @@
 	}
 }
 
-Bool PixelProgram::alphaTest(Int cMask[4], int sampleId)
+Bool PixelProgram::alphaTest(Int cMask[4], const SampleSet &samples)
 {
 	if(!state.alphaToCoverage)
 	{
 		return true;
 	}
 
-	alphaToCoverage(cMask, c[0].w, sampleId);
+	alphaToCoverage(cMask, c[0].w, samples);
 
-	if(sampleId >= 0)
-	{
-		return cMask[sampleId] != 0x0;
-	}
-
-	Int pass = cMask[0];
-
-	for(unsigned int q = 1; q < state.multiSampleCount; q++)
+	Int pass = 0;
+	for(unsigned int q : samples)
 	{
 		pass = pass | cMask[q];
 	}
@@ -267,11 +268,8 @@
 	return pass != 0x0;
 }
 
-void PixelProgram::rasterOperation(Pointer<Byte> cBuffer[4], Int &x, Int sMask[4], Int zMask[4], Int cMask[4], int sampleId)
+void PixelProgram::rasterOperation(Pointer<Byte> cBuffer[4], Int &x, Int sMask[4], Int zMask[4], Int cMask[4], const SampleSet &samples)
 {
-	unsigned int sampleLoopInit = (sampleId >= 0) ? sampleId : 0;
-	unsigned int sampleLoopEnd = (sampleId >= 0) ? sampleId + 1 : state.multiSampleCount;
-
 	for(int index = 0; index < RENDERTARGETS; index++)
 	{
 		if(!state.colorWriteActive(index))
@@ -296,21 +294,18 @@
 		case VK_FORMAT_A8B8G8R8_SRGB_PACK32:
 		case VK_FORMAT_A2B10G10R10_UNORM_PACK32:
 		case VK_FORMAT_A2R10G10B10_UNORM_PACK32:
-			for(unsigned int q = sampleLoopInit; q < sampleLoopEnd; q++)
+			for(unsigned int q : samples)
 			{
-				if(state.multiSampleMask & (1 << q))
-				{
-					Pointer<Byte> buffer = cBuffer[index] + q * *Pointer<Int>(data + OFFSET(DrawData, colorSliceB[index]));
-					Vector4s color;
+				Pointer<Byte> buffer = cBuffer[index] + q * *Pointer<Int>(data + OFFSET(DrawData, colorSliceB[index]));
+				Vector4s color;
 
-					color.x = convertFixed16(c[index].x, false);
-					color.y = convertFixed16(c[index].y, false);
-					color.z = convertFixed16(c[index].z, false);
-					color.w = convertFixed16(c[index].w, false);
+				color.x = convertFixed16(c[index].x, false);
+				color.y = convertFixed16(c[index].y, false);
+				color.z = convertFixed16(c[index].z, false);
+				color.w = convertFixed16(c[index].w, false);
 
-					alphaBlend(index, buffer, color, x);
-					writeColor(index, buffer, x, color, sMask[q], zMask[q], cMask[q]);
-				}
+				alphaBlend(index, buffer, color, x);
+				writeColor(index, buffer, x, color, sMask[q], zMask[q], cMask[q]);
 			}
 			break;
 		case VK_FORMAT_R16_SFLOAT:
@@ -342,16 +337,13 @@
 		case VK_FORMAT_A8B8G8R8_SINT_PACK32:
 		case VK_FORMAT_A2B10G10R10_UINT_PACK32:
 		case VK_FORMAT_A2R10G10B10_UINT_PACK32:
-			for(unsigned int q = sampleLoopInit; q < sampleLoopEnd; q++)
+			for(unsigned int q : samples)
 			{
-				if(state.multiSampleMask & (1 << q))
-				{
-					Pointer<Byte> buffer = cBuffer[index] + q * *Pointer<Int>(data + OFFSET(DrawData, colorSliceB[index]));
-					Vector4f color = c[index];
+				Pointer<Byte> buffer = cBuffer[index] + q * *Pointer<Int>(data + OFFSET(DrawData, colorSliceB[index]));
+				Vector4f color = c[index];
 
-					alphaBlend(index, buffer, color, x);
-					writeColor(index, buffer, x, color, sMask[q], zMask[q], cMask[q]);
-				}
+				alphaBlend(index, buffer, color, x);
+				writeColor(index, buffer, x, color, sMask[q], zMask[q], cMask[q]);
 			}
 			break;
 		default:
diff --git a/src/Pipeline/PixelProgram.hpp b/src/Pipeline/PixelProgram.hpp
index d2d2301..306c3da 100644
--- a/src/Pipeline/PixelProgram.hpp
+++ b/src/Pipeline/PixelProgram.hpp
@@ -24,20 +24,17 @@
 public:
 	PixelProgram(
 	    const PixelProcessor::State &state,
-	    vk::PipelineLayout const *pipelineLayout,
-	    SpirvShader const *spirvShader,
-	    const vk::DescriptorSet::Bindings &descriptorSets)
-	    : PixelRoutine(state, pipelineLayout, spirvShader, descriptorSets)
-	{
-	}
+	    const vk::PipelineLayout *pipelineLayout,
+	    const SpirvShader *spirvShader,
+	    const vk::DescriptorSet::Bindings &descriptorSets);
 
 	virtual ~PixelProgram() {}
 
 protected:
-	virtual void setBuiltins(Int &x, Int &y, Float4 (&z)[4], Float4 &w, Int cMask[4], int sampleId);
-	virtual void applyShader(Int cMask[4], Int sMask[4], Int zMask[4], int sampleId);
-	virtual Bool alphaTest(Int cMask[4], int sampleId);
-	virtual void rasterOperation(Pointer<Byte> cBuffer[4], Int &x, Int sMask[4], Int zMask[4], Int cMask[4], int sampleId);
+	virtual void setBuiltins(Int &x, Int &y, Float4 (&z)[4], Float4 &w, Int cMask[4], const SampleSet &samples);
+	virtual void executeShader(Int cMask[4], Int sMask[4], Int zMask[4], const SampleSet &samples);
+	virtual Bool alphaTest(Int cMask[4], const SampleSet &samples);
+	virtual void rasterOperation(Pointer<Byte> cBuffer[4], Int &x, Int sMask[4], Int zMask[4], Int cMask[4], const SampleSet &samples);
 
 private:
 	// Color outputs
@@ -46,9 +43,8 @@
 	// Raster operations
 	void clampColor(Vector4f oC[RENDERTARGETS]);
 
-	Int4 maskAny(Int cMask[4]) const;
-	Int4 maskAny(Int cMask[4], Int sMask[4], Int zMask[4]) const;
-	Int4 maskAny(Int cMask, Int sMask, Int zMask) const;
+	static Int4 maskAny(Int cMask[4], const SampleSet &samples);
+	static Int4 maskAny(Int cMask[4], Int sMask[4], Int zMask[4], const SampleSet &samples);
 };
 
 }  // namespace sw
diff --git a/src/Pipeline/PixelRoutine.cpp b/src/Pipeline/PixelRoutine.cpp
index 543b1c9..031a028 100644
--- a/src/Pipeline/PixelRoutine.cpp
+++ b/src/Pipeline/PixelRoutine.cpp
@@ -44,6 +44,7 @@
 
 		// Clearing inputs to 0 is not demanded by the spec,
 		// but it makes the undefined behavior deterministic.
+		// TODO(b/155148722): Remove to detect UB.
 		for(int i = 0; i < MAX_INTERFACE_COMPONENTS; i++)
 		{
 			routine.inputs[i] = Float4(0.0f);
@@ -60,9 +61,27 @@
 {
 }
 
+PixelRoutine::SampleSet PixelRoutine::getSampleSet(int invocation) const
+{
+	unsigned int sampleBegin = perSampleShading ? invocation : 0;
+	unsigned int sampleEnd = perSampleShading ? (invocation + 1) : state.multiSampleCount;
+
+	SampleSet samples;
+
+	for(unsigned int q = sampleBegin; q < sampleEnd; q++)
+	{
+		if(state.multiSampleMask & (1 << q))
+		{
+			samples.push_back(q);
+		}
+	}
+
+	return samples;
+}
+
 void PixelRoutine::quad(Pointer<Byte> cBuffer[RENDERTARGETS], Pointer<Byte> &zBuffer, Pointer<Byte> &sBuffer, Int cMask[4], Int &x, Int &y)
 {
-	const bool earlyDepthTest = !spirvShader || spirvShader->getModes().EarlyFragmentTests;
+	const bool earlyFragmentTests = !spirvShader || spirvShader->getModes().EarlyFragmentTests;
 
 	Int zMask[4];  // Depth mask
 	Int sMask[4];  // Stencil mask
@@ -70,17 +89,20 @@
 
 	for(int invocation = 0; invocation < invocationCount; invocation++)
 	{
-		int sampleId = perSampleShading ? invocation : -1;
-		unsigned int sampleLoopInit = perSampleShading ? sampleId : 0;
-		unsigned int sampleLoopEnd = perSampleShading ? sampleId + 1 : state.multiSampleCount;
+		SampleSet samples = getSampleSet(invocation);
 
-		for(unsigned int q = sampleLoopInit; q < sampleLoopEnd; q++)
+		if(samples.empty())
+		{
+			continue;
+		}
+
+		for(unsigned int q : samples)
 		{
 			zMask[q] = cMask[q];
 			sMask[q] = cMask[q];
 		}
 
-		for(unsigned int q = sampleLoopInit; q < sampleLoopEnd; q++)
+		for(unsigned int q : samples)
 		{
 			stencilTest(sBuffer, q, x, sMask[q], cMask[q]);
 		}
@@ -92,7 +114,7 @@
 
 		if(interpolateZ())
 		{
-			for(unsigned int q = sampleLoopInit; q < sampleLoopEnd; q++)
+			for(unsigned int q : samples)
 			{
 				Float4 x = xxxx;
 
@@ -118,16 +140,16 @@
 
 		Bool depthPass = false;
 
-		if(earlyDepthTest)
+		if(earlyFragmentTests)
 		{
-			for(unsigned int q = sampleLoopInit; q < sampleLoopEnd; q++)
+			for(unsigned int q : samples)
 			{
 				depthPass = depthPass || depthTest(zBuffer, q, x, z[q], sMask[q], zMask[q], cMask[q]);
 				depthBoundsTest(zBuffer, q, x, zMask[q], cMask[q]);
 			}
 		}
 
-		If(depthPass || Bool(!earlyDepthTest))
+		If(depthPass || Bool(!earlyFragmentTests))
 		{
 			Float4 yyyy = Float4(Float(y)) + *Pointer<Float4>(primitive + OFFSET(Primitive, yQuad), 16);
 
@@ -139,7 +161,7 @@
 			{
 				Float4 WWWW(1.0e-9f);
 
-				for(unsigned int q = sampleLoopInit; q < sampleLoopEnd; q++)
+				for(unsigned int q : samples)
 				{
 					XXXX += *Pointer<Float4>(constants + OFFSET(Constants, sampleX[q]) + 16 * cMask[q]);
 					YYYY += *Pointer<Float4>(constants + OFFSET(Constants, sampleY[q]) + 16 * cMask[q]);
@@ -182,8 +204,8 @@
 
 				if(perSampleShading && (state.multiSampleCount > 1))
 				{
-					xxxx += Float4(Constants::SampleLocationsX[sampleId]);
-					yyyy += Float4(Constants::SampleLocationsY[sampleId]);
+					xxxx += Float4(Constants::SampleLocationsX[samples[0]]);
+					yyyy += Float4(Constants::SampleLocationsY[samples[0]]);
 				}
 
 				for(int interpolant = 0; interpolant < MAX_INTERFACE_COMPONENTS; interpolant++)
@@ -215,7 +237,7 @@
 					}
 				}
 
-				setBuiltins(x, y, unclampedZ, w, cMask, sampleId);
+				setBuiltins(x, y, unclampedZ, w, cMask, samples);
 
 				for(uint32_t i = 0; i < state.numClipDistances; i++)
 				{
@@ -224,12 +246,12 @@
 					                            false, true);
 
 					auto clipMask = SignMask(CmpGE(distance, SIMD::Float(0)));
-					for(auto ms = sampleLoopInit; ms < sampleLoopEnd; ms++)
+					for(unsigned int q : samples)
 					{
 						// FIXME(b/148105887): Fragments discarded by clipping do not exist at
 						// all -- they should not be counted in queries or have their Z/S effects
 						// performed when early fragment tests are enabled.
-						cMask[ms] &= clipMask;
+						cMask[q] &= clipMask;
 					}
 
 					if(spirvShader->getUsedCapabilities().ClipDistance)
@@ -264,19 +286,16 @@
 				}
 			}
 
-			Bool alphaPass = true;
-
 			if(spirvShader)
 			{
-				bool earlyFragTests = (spirvShader && spirvShader->getModes().EarlyFragmentTests);
-				applyShader(cMask, earlyFragTests ? sMask : cMask, earlyDepthTest ? zMask : cMask, sampleId);
+				executeShader(cMask, earlyFragmentTests ? sMask : cMask, earlyFragmentTests ? zMask : cMask, samples);
 			}
 
-			alphaPass = alphaTest(cMask, sampleId);
+			Bool alphaPass = alphaTest(cMask, samples);
 
 			if((spirvShader && spirvShader->getModes().ContainsKill) || state.alphaToCoverage)
 			{
-				for(unsigned int q = sampleLoopInit; q < sampleLoopEnd; q++)
+				for(unsigned int q : samples)
 				{
 					zMask[q] &= cMask[q];
 					sMask[q] &= cMask[q];
@@ -285,41 +304,35 @@
 
 			If(alphaPass)
 			{
-				if(!earlyDepthTest)
+				if(!earlyFragmentTests)
 				{
-					for(unsigned int q = sampleLoopInit; q < sampleLoopEnd; q++)
+					for(unsigned int q : samples)
 					{
 						depthPass = depthPass || depthTest(zBuffer, q, x, z[q], sMask[q], zMask[q], cMask[q]);
 						depthBoundsTest(zBuffer, q, x, zMask[q], cMask[q]);
 					}
 				}
 
-				If(depthPass || Bool(earlyDepthTest))
+				If(depthPass || Bool(earlyFragmentTests))
 				{
-					for(unsigned int q = sampleLoopInit; q < sampleLoopEnd; q++)
+					for(unsigned int q : samples)
 					{
-						if(state.multiSampleMask & (1 << q))
-						{
-							writeDepth(zBuffer, q, x, z[q], zMask[q]);
+						writeDepth(zBuffer, q, x, z[q], zMask[q]);
 
-							if(state.occlusionEnabled)
-							{
-								occlusion += *Pointer<UInt>(constants + OFFSET(Constants, occlusionCount) + 4 * (zMask[q] & sMask[q]));
-							}
+						if(state.occlusionEnabled)
+						{
+							occlusion += *Pointer<UInt>(constants + OFFSET(Constants, occlusionCount) + 4 * (zMask[q] & sMask[q]));
 						}
 					}
 
-					rasterOperation(cBuffer, x, sMask, zMask, cMask, sampleId);
+					rasterOperation(cBuffer, x, sMask, zMask, cMask, samples);
 				}
 			}
 		}
 
-		for(unsigned int q = sampleLoopInit; q < sampleLoopEnd; q++)
+		for(unsigned int q : samples)
 		{
-			if(state.multiSampleMask & (1 << q))
-			{
-				writeStencil(sBuffer, q, x, sMask[q], zMask[q], cMask[q]);
-			}
+			writeStencil(sBuffer, q, x, sMask[q], zMask[q], cMask[q]);
 		}
 	}
 }
@@ -640,7 +653,7 @@
 	}
 }
 
-void PixelRoutine::alphaToCoverage(Int cMask[4], const Float4 &alpha, int sampleId)
+void PixelRoutine::alphaToCoverage(Int cMask[4], const Float4 &alpha, const SampleSet &samples)
 {
 	static const int a2c[4] = {
 		OFFSET(DrawData, a2c0),
@@ -649,10 +662,7 @@
 		OFFSET(DrawData, a2c3),
 	};
 
-	unsigned int sampleLoopInit = (sampleId >= 0) ? sampleId : 0;
-	unsigned int sampleLoopEnd = (sampleId >= 0) ? sampleId + 1 : state.multiSampleCount;
-
-	for(unsigned int q = sampleLoopInit; q < sampleLoopEnd; q++)
+	for(unsigned int q : samples)
 	{
 		Int4 coverage = CmpNLT(alpha, *Pointer<Float4>(data + a2c[q]));
 		Int aMask = SignMask(coverage);
diff --git a/src/Pipeline/PixelRoutine.hpp b/src/Pipeline/PixelRoutine.hpp
index fd2f9a1..edb7eb4 100644
--- a/src/Pipeline/PixelRoutine.hpp
+++ b/src/Pipeline/PixelRoutine.hpp
@@ -17,6 +17,8 @@
 
 #include "Device/QuadRasterizer.hpp"
 
+#include <vector>
+
 namespace sw {
 
 class PixelShader;
@@ -33,6 +35,8 @@
 	virtual ~PixelRoutine();
 
 protected:
+	using SampleSet = std::vector<int>;
+
 	Float4 z[4];  // Multisampled z
 	Float4 w;     // Used as is
 	Float4 rhw;   // Reciprocal w
@@ -45,15 +49,15 @@
 	// Depth output
 	Float4 oDepth;
 
-	virtual void setBuiltins(Int &x, Int &y, Float4 (&z)[4], Float4 &w, Int cMask[4], int sampleId) = 0;
-	virtual void applyShader(Int cMask[4], Int sMask[4], Int zMask[4], int sampleId) = 0;
-	virtual Bool alphaTest(Int cMask[4], int sampleId) = 0;
-	virtual void rasterOperation(Pointer<Byte> cBuffer[4], Int &x, Int sMask[4], Int zMask[4], Int cMask[4], int sampleId) = 0;
+	virtual void setBuiltins(Int &x, Int &y, Float4 (&z)[4], Float4 &w, Int cMask[4], const SampleSet &samples) = 0;
+	virtual void executeShader(Int cMask[4], Int sMask[4], Int zMask[4], const SampleSet &samples) = 0;
+	virtual Bool alphaTest(Int cMask[4], const SampleSet &samples) = 0;
+	virtual void rasterOperation(Pointer<Byte> cBuffer[4], Int &x, Int sMask[4], Int zMask[4], Int cMask[4], const SampleSet &samples) = 0;
 
 	void quad(Pointer<Byte> cBuffer[4], Pointer<Byte> &zBuffer, Pointer<Byte> &sBuffer, Int cMask[4], Int &x, Int &y) override;
 
 	void alphaTest(Int &aMask, const Short4 &alpha);
-	void alphaToCoverage(Int cMask[4], const Float4 &alpha, int sampleId);
+	void alphaToCoverage(Int cMask[4], const Float4 &alpha, const SampleSet &samples);
 
 	// Raster operations
 	void alphaBlend(int index, const Pointer<Byte> &cBuffer, Vector4s &current, const Int &x);
@@ -102,6 +106,8 @@
 	const bool shaderContainsSampleQualifier;
 	const bool perSampleShading;
 	const int invocationCount;
+
+	SampleSet getSampleSet(int invocation) const;
 };
 
 }  // namespace sw