src/Pipeline: Refactor ComputeProgram

• Split up ComputeProgram::emit() into a few smaller functions.
• Calculate the subgroup count in C++, and pass it in as a parameter.
• Add a firstSubgroup, this is currently always 0, but will be used in a later change.
• Pass the workgroup ID as parameters instead of through Data. Data now holds fields common for all workgroups.

This refactoring prepares the code for migrating to coroutines.

Bug: b/131672705
Change-Id: Id3492adc0a7aedc3f16c0e37f135294862c55700
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/30848
Tested-by: Ben Clayton <bclayton@google.com>
Presubmit-Ready: Ben Clayton <bclayton@google.com>
Reviewed-by: Nicolas Capens <nicolascapens@google.com>
Kokoro-Presubmit: kokoro <noreply+kokoro@google.com>
diff --git a/src/Pipeline/ComputeProgram.cpp b/src/Pipeline/ComputeProgram.cpp
index d02f340..6daaa3a 100644
--- a/src/Pipeline/ComputeProgram.cpp
+++ b/src/Pipeline/ComputeProgram.cpp
@@ -45,30 +45,11 @@
 		shader->emitEpilog(&routine);
 	}
 
-	void ComputeProgram::emit()
+	void ComputeProgram::setWorkgroupBuiltins(Int workgroupID[3])
 	{
-		routine.descriptorSets = data + OFFSET(Data, descriptorSets);
-		routine.descriptorDynamicOffsets = data + OFFSET(Data, descriptorDynamicOffsets);
-		routine.pushConstants = data + OFFSET(Data, pushConstants);
-		routine.constants = *Pointer<Pointer<Byte>>(data + OFFSET(Data, constants));
-		routine.workgroupMemory = *Pointer<Pointer<Byte>>(data + OFFSET(Data, workgroupMemory));
-
-		auto &modes = shader->getModes();
-
-		int localSize[3] = {modes.WorkgroupSizeX, modes.WorkgroupSizeY, modes.WorkgroupSizeZ};
-
-		const int subgroupSize = SIMD::Width;
-
-		// Total number of invocations required to execute this workgroup.
-		int numInvocations = localSize[X] * localSize[Y] * localSize[Z];
-
-		Int4 numWorkgroups = *Pointer<Int4>(data + OFFSET(Data, numWorkgroups));
-		Int4 workgroupID = *Pointer<Int4>(data + OFFSET(Data, workgroupID));
-		Int4 workgroupSize = Int4(localSize[X], localSize[Y], localSize[Z], 0);
-		Int numSubgroups = (numInvocations + subgroupSize - 1) / subgroupSize;
-
 		setInputBuiltin(spv::BuiltInNumWorkgroups, [&](const SpirvShader::BuiltinMapping& builtin, Array<SIMD::Float>& value)
 		{
+			auto numWorkgroups = *Pointer<Int4>(data + OFFSET(Data, numWorkgroups));
 			for (uint32_t component = 0; component < builtin.SizeInComponents; component++)
 			{
 				value[builtin.FirstComponent + component] =
@@ -81,12 +62,13 @@
 			for (uint32_t component = 0; component < builtin.SizeInComponents; component++)
 			{
 				value[builtin.FirstComponent + component] =
-						As<SIMD::Float>(SIMD::Int(Extract(workgroupID, component)));
+					As<SIMD::Float>(SIMD::Int(workgroupID[component]));
 			}
 		});
 
 		setInputBuiltin(spv::BuiltInWorkgroupSize, [&](const SpirvShader::BuiltinMapping& builtin, Array<SIMD::Float>& value)
 		{
+			auto workgroupSize = *Pointer<Int4>(data + OFFSET(Data, workgroupSize));
 			for (uint32_t component = 0; component < builtin.SizeInComponents; component++)
 			{
 				value[builtin.FirstComponent + component] =
@@ -97,13 +79,15 @@
 		setInputBuiltin(spv::BuiltInNumSubgroups, [&](const SpirvShader::BuiltinMapping& builtin, Array<SIMD::Float>& value)
 		{
 			ASSERT(builtin.SizeInComponents == 1);
-			value[builtin.FirstComponent] = As<SIMD::Float>(SIMD::Int(numSubgroups));
+			auto subgroupsPerWorkgroup = *Pointer<Int>(data + OFFSET(Data, subgroupsPerWorkgroup));
+			value[builtin.FirstComponent] = As<SIMD::Float>(SIMD::Int(subgroupsPerWorkgroup));
 		});
 
 		setInputBuiltin(spv::BuiltInSubgroupSize, [&](const SpirvShader::BuiltinMapping& builtin, Array<SIMD::Float>& value)
 		{
 			ASSERT(builtin.SizeInComponents == 1);
-			value[builtin.FirstComponent] = As<SIMD::Float>(SIMD::Int(subgroupSize));
+			auto invocationsPerSubgroup = *Pointer<Int>(data + OFFSET(Data, invocationsPerSubgroup));
+			value[builtin.FirstComponent] = As<SIMD::Float>(SIMD::Int(invocationsPerSubgroup));
 		});
 
 		setInputBuiltin(spv::BuiltInSubgroupLocalInvocationId, [&](const SpirvShader::BuiltinMapping& builtin, Array<SIMD::Float>& value)
@@ -111,56 +95,94 @@
 			ASSERT(builtin.SizeInComponents == 1);
 			value[builtin.FirstComponent] = As<SIMD::Float>(SIMD::Int(0, 1, 2, 3));
 		});
+	}
 
-		For(Int subgroupIndex = 0, subgroupIndex < numSubgroups, subgroupIndex++)
+	void ComputeProgram::setSubgroupBuiltins(Int workgroupID[3], SIMD::Int localInvocationIndex, Int subgroupIndex)
+	{
+		Int4 numWorkgroups = *Pointer<Int4>(data + OFFSET(Data, numWorkgroups));
+		Int4 workgroupSize = *Pointer<Int4>(data + OFFSET(Data, workgroupSize));
+
+		// TODO: Fix Int4 swizzles so we can just use workgroupSize.x, workgroupSize.y.
+		Int workgroupSizeX = Extract(workgroupSize, X);
+		Int workgroupSizeY = Extract(workgroupSize, Y);
+
+		SIMD::Int localInvocationID[3];
 		{
+			SIMD::Int idx = localInvocationIndex;
+			localInvocationID[Z] = idx / SIMD::Int(workgroupSizeX * workgroupSizeY);
+			idx -= localInvocationID[Z] * SIMD::Int(workgroupSizeX * workgroupSizeY); // modulo
+			localInvocationID[Y] = idx / SIMD::Int(workgroupSizeX);
+			idx -= localInvocationID[Y] * SIMD::Int(workgroupSizeX); // modulo
+			localInvocationID[X] = idx;
+		}
+
+		setInputBuiltin(spv::BuiltInLocalInvocationIndex, [&](const SpirvShader::BuiltinMapping& builtin, Array<SIMD::Float>& value)
+		{
+			ASSERT(builtin.SizeInComponents == 1);
+			value[builtin.FirstComponent] = As<SIMD::Float>(localInvocationIndex);
+		});
+
+		setInputBuiltin(spv::BuiltInSubgroupId, [&](const SpirvShader::BuiltinMapping& builtin, Array<SIMD::Float>& value)
+		{
+			ASSERT(builtin.SizeInComponents == 1);
+			value[builtin.FirstComponent] = As<SIMD::Float>(SIMD::Int(subgroupIndex));
+		});
+
+		setInputBuiltin(spv::BuiltInLocalInvocationId, [&](const SpirvShader::BuiltinMapping& builtin, Array<SIMD::Float>& value)
+		{
+			for (uint32_t component = 0; component < builtin.SizeInComponents; component++)
+			{
+				value[builtin.FirstComponent + component] =
+					As<SIMD::Float>(localInvocationID[component]);
+			}
+		});
+
+		setInputBuiltin(spv::BuiltInGlobalInvocationId, [&](const SpirvShader::BuiltinMapping& builtin, Array<SIMD::Float>& value)
+		{
+			SIMD::Int wgID = 0;
+			wgID = Insert(wgID, workgroupID[X], X);
+			wgID = Insert(wgID, workgroupID[Y], Y);
+			wgID = Insert(wgID, workgroupID[Z], Z);
+			auto localBase = workgroupSize * wgID;
+			for (uint32_t component = 0; component < builtin.SizeInComponents; component++)
+			{
+				auto globalInvocationID = SIMD::Int(Extract(localBase, component)) + localInvocationID[component];
+				value[builtin.FirstComponent + component] = As<SIMD::Float>(globalInvocationID);
+			}
+		});
+	}
+
+	void ComputeProgram::emit()
+	{
+		routine.descriptorSets = data + OFFSET(Data, descriptorSets);
+		routine.descriptorDynamicOffsets = data + OFFSET(Data, descriptorDynamicOffsets);
+		routine.pushConstants = data + OFFSET(Data, pushConstants);
+		routine.constants = *Pointer<Pointer<Byte>>(data + OFFSET(Data, constants));
+		routine.workgroupMemory = *Pointer<Pointer<Byte>>(data + OFFSET(Data, workgroupMemory));
+
+		Int workgroupX = Arg<1>();
+		Int workgroupY = Arg<2>();
+		Int workgroupZ = Arg<3>();
+		Int firstSubgroup = Arg<4>();
+		Int subgroupCount = Arg<5>();
+
+		Int invocationsPerWorkgroup = *Pointer<Int>(data + OFFSET(Data, invocationsPerWorkgroup));
+
+		Int workgroupID[3] = {workgroupX, workgroupY, workgroupZ};
+		setWorkgroupBuiltins(workgroupID);
+
+		For(Int i = 0, i < subgroupCount, i++)
+		{
+			auto subgroupIndex = firstSubgroup + i;
+
 			// TODO: Replace SIMD::Int(0, 1, 2, 3) with SIMD-width equivalent
 			auto localInvocationIndex = SIMD::Int(subgroupIndex * SIMD::Width) + SIMD::Int(0, 1, 2, 3);
 
-			// Disable lanes where (invocationIDs >= numInvocations)
-			auto activeLaneMask = CmpLT(localInvocationIndex, SIMD::Int(numInvocations));
+			// Disable lanes where (invocationIDs >= invocationsPerWorkgroup)
+			auto activeLaneMask = CmpLT(localInvocationIndex, SIMD::Int(invocationsPerWorkgroup));
 
-			SIMD::Int localInvocationID[3];
-			{
-				SIMD::Int idx = localInvocationIndex;
-				localInvocationID[Z] = idx / SIMD::Int(localSize[X] * localSize[Y]);
-				idx -= localInvocationID[Z] * SIMD::Int(localSize[X] * localSize[Y]); // modulo
-				localInvocationID[Y] = idx / SIMD::Int(localSize[X]);
-				idx -= localInvocationID[Y] * SIMD::Int(localSize[X]); // modulo
-				localInvocationID[X] = idx;
-			}
+			setSubgroupBuiltins(workgroupID, localInvocationIndex, subgroupIndex);
 
-			setInputBuiltin(spv::BuiltInLocalInvocationIndex, [&](const SpirvShader::BuiltinMapping& builtin, Array<SIMD::Float>& value)
-			{
-				ASSERT(builtin.SizeInComponents == 1);
-				value[builtin.FirstComponent] = As<SIMD::Float>(localInvocationIndex);
-			});
-
-			setInputBuiltin(spv::BuiltInSubgroupId, [&](const SpirvShader::BuiltinMapping& builtin, Array<SIMD::Float>& value)
-			{
-				ASSERT(builtin.SizeInComponents == 1);
-				value[builtin.FirstComponent] = As<SIMD::Float>(SIMD::Int(subgroupIndex));
-			});
-
-			setInputBuiltin(spv::BuiltInLocalInvocationId, [&](const SpirvShader::BuiltinMapping& builtin, Array<SIMD::Float>& value)
-			{
-				for (uint32_t component = 0; component < builtin.SizeInComponents; component++)
-				{
-					value[builtin.FirstComponent + component] = As<SIMD::Float>(localInvocationID[component]);
-				}
-			});
-
-			setInputBuiltin(spv::BuiltInGlobalInvocationId, [&](const SpirvShader::BuiltinMapping& builtin, Array<SIMD::Float>& value)
-			{
-				auto localBase = workgroupID * workgroupSize;
-				for (uint32_t component = 0; component < builtin.SizeInComponents; component++)
-				{
-					auto globalInvocationID = SIMD::Int(Extract(localBase, component)) + localInvocationID[component];
-					value[builtin.FirstComponent + component] = As<SIMD::Float>(globalInvocationID);
-				}
-			});
-
-			// Process numLanes of the workgroup.
 			shader->emit(&routine, activeLaneMask, descriptorSets);
 		}
 	}
@@ -182,7 +204,13 @@
 		PushConstantStorage const &pushConstants,
 		uint32_t groupCountX, uint32_t groupCountY, uint32_t groupCountZ)
 	{
-		auto runWorkgroup = (void(*)(void*))(routine->getEntry());
+		auto runWorkgroup = (void(*)(void*, int, int, int, int, int))(routine->getEntry());
+
+		auto &modes = shader->getModes();
+
+		auto invocationsPerSubgroup = SIMD::Width;
+		auto invocationsPerWorkgroup = modes.WorkgroupSizeX * modes.WorkgroupSizeY * modes.WorkgroupSizeZ;
+		auto subgroupsPerWorkgroup = (invocationsPerWorkgroup + invocationsPerSubgroup - 1) / invocationsPerSubgroup;
 
 		// We're sharing a buffer here across all workgroups.
 		// We can only do this because we know workgroups are executed
@@ -196,6 +224,13 @@
 		data.numWorkgroups[Y] = groupCountY;
 		data.numWorkgroups[Z] = groupCountZ;
 		data.numWorkgroups[3] = 0;
+		data.workgroupSize[X] = modes.WorkgroupSizeX;
+		data.workgroupSize[Y] = modes.WorkgroupSizeY;
+		data.workgroupSize[Z] = modes.WorkgroupSizeZ;
+		data.workgroupSize[3] = 0;
+		data.invocationsPerSubgroup = invocationsPerSubgroup;
+		data.invocationsPerWorkgroup = invocationsPerWorkgroup;
+		data.subgroupsPerWorkgroup = subgroupsPerWorkgroup;
 		data.pushConstants = pushConstants;
 		data.constants = &sw::constants;
 		data.workgroupMemory = workgroupMemory.data();
@@ -203,16 +238,14 @@
 		// TODO(bclayton): Split work across threads.
 		for (uint32_t groupZ = 0; groupZ < groupCountZ; groupZ++)
 		{
-			data.workgroupID[Z] = groupZ;
 			for (uint32_t groupY = 0; groupY < groupCountY; groupY++)
 			{
-				data.workgroupID[Y] = groupY;
 				for (uint32_t groupX = 0; groupX < groupCountX; groupX++)
 				{
-					data.workgroupID[X] = groupX;
-					runWorkgroup(&data);
+					runWorkgroup(&data, groupX, groupY, groupZ, 0, subgroupsPerWorkgroup);
 				}
 			}
 		}
 	}
-}
+
+} // namespace sw
diff --git a/src/Pipeline/ComputeProgram.hpp b/src/Pipeline/ComputeProgram.hpp
index 59a2315..493d89b 100644
--- a/src/Pipeline/ComputeProgram.hpp
+++ b/src/Pipeline/ComputeProgram.hpp
@@ -37,7 +37,13 @@
 	struct Constants;
 
 	// ComputeProgram builds a SPIR-V compute shader.
-	class ComputeProgram : public Function<Void(Pointer<Byte>)>
+	class ComputeProgram : public Function<Void(
+			Pointer<Byte> data,
+			Int workgroupX,
+			Int workgroupY,
+			Int workgroupZ,
+			Int firstSubgroup,
+			Int subgroupCount)>
 	{
 	public:
 		ComputeProgram(SpirvShader const *spirvShader, vk::PipelineLayout const *pipelineLayout, const vk::DescriptorSet::Bindings &descriptorSets);
@@ -59,6 +65,8 @@
 	protected:
 		void emit();
 
+		void setWorkgroupBuiltins(Int workgroupID[3]);
+		void setSubgroupBuiltins(Int workgroupID[3], SIMD::Int localInvocationIndex, Int subgroupIndex);
 		void setInputBuiltin(spv::BuiltIn id, std::function<void(const SpirvShader::BuiltinMapping& builtin, Array<SIMD::Float>& value)> cb);
 
 		Pointer<Byte> data; // argument 0
@@ -67,8 +75,11 @@
 		{
 			vk::DescriptorSet::Bindings descriptorSets;
 			vk::DescriptorSet::DynamicOffsets descriptorDynamicOffsets;
-			uint4 numWorkgroups;
-			uint4 workgroupID;
+			uint4 numWorkgroups; // [x, y, z, 0]
+			uint4 workgroupSize; // [x, y, z, 0]
+			uint32_t invocationsPerSubgroup; // SPIR-V: "SubgroupSize"
+			uint32_t subgroupsPerWorkgroup; // SPIR-V: "NumSubgroups"
+			uint32_t invocationsPerWorkgroup; // Total number of invocations per workgroup.
 			PushConstantStorage pushConstants;
 			const Constants *constants;
 			uint8_t* workgroupMemory;