SpirvShader: Handle multiple entry points

Only one will be used.

Tests: dEQP-VK.spirv_assembly.instruction.compute.multiple_shaders.*
Bug: b/132341142
Change-Id: I75588f3281f325dc1753222a1f89476267c56ad3
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/31011
Kokoro-Presubmit: kokoro <noreply+kokoro@google.com>
Presubmit-Ready: Ben Clayton <bclayton@google.com>
Tested-by: Ben Clayton <bclayton@google.com>
Reviewed-by: Nicolas Capens <nicolascapens@google.com>
diff --git a/src/Pipeline/SpirvShader.cpp b/src/Pipeline/SpirvShader.cpp
index c84be55..cdc0683 100644
--- a/src/Pipeline/SpirvShader.cpp
+++ b/src/Pipeline/SpirvShader.cpp
@@ -357,10 +357,14 @@
 
 	volatile int SpirvShader::serialCounter = 1;    // Start at 1, 0 is invalid shader.
 
-	SpirvShader::SpirvShader(InsnStore const &insns, vk::RenderPass *renderPass, uint32_t subpassIndex)
-			: insns{insns}, inputs{MAX_INTERFACE_COMPONENTS},
-			  outputs{MAX_INTERFACE_COMPONENTS},
-			  serialID{serialCounter++}, modes{}
+	SpirvShader::SpirvShader(
+			VkPipelineShaderStageCreateInfo const *createInfo,
+			InsnStore const &insns,
+			vk::RenderPass *renderPass,
+			uint32_t subpassIndex)
+				: insns{insns}, inputs{MAX_INTERFACE_COMPONENTS},
+				outputs{MAX_INTERFACE_COMPONENTS},
+				serialID{serialCounter++}, modes{}
 	{
 		ASSERT(insns.size() > 0);
 
@@ -378,9 +382,9 @@
 		}
 
 		// Simplifying assumptions (to be satisfied by earlier transformations)
-		// - There is exactly one entrypoint in the module, and it's the one we want
 		// - The only input/output OpVariables present are those used by the entrypoint
 
+		Object::ID entryPointFunctionId;
 		Block::ID currentBlock;
 		InsnIterator blockStart;
 
@@ -390,6 +394,20 @@
 
 			switch (opcode)
 			{
+			case spv::OpEntryPoint:
+			{
+				auto executionModel = spv::ExecutionModel(insn.word(1));
+				auto id = Object::ID(insn.word(2));
+				auto name = insn.string(3);
+				auto stage = executionModelToStage(executionModel);
+				if (stage == createInfo->stage && strcmp(name, createInfo->pName) == 0)
+				{
+					ASSERT_MSG(entryPointFunctionId == 0, "Duplicate entry point with name '%s' and stage %d", name, int(stage));
+					entryPointFunctionId = id;
+				}
+				break;
+			}
+
 			case spv::OpExecutionMode:
 				ProcessExecutionMode(insn);
 				break;
@@ -650,27 +668,38 @@
 			case spv::OpMemoryModel:
 				break; // Memory model does not affect our code generation until we decide to do Vulkan Memory Model support.
 
-			case spv::OpEntryPoint:
-				break;
 			case spv::OpFunction:
-				ASSERT(mainBlockId.value() == 0); // Multiple functions found
-				// Scan forward to find the function's label.
-				for (auto it = insn; it != end() && mainBlockId.value() == 0; it++)
+			{
+				auto functionId = Object::ID(insn.word(2));
+				if (functionId == entryPointFunctionId)
 				{
-					switch (it.opcode())
+					// Scan forward to find the function's label.
+					for (auto it = insn; it != end() && entryPointBlockId == 0; it++)
 					{
-					case spv::OpFunction:
-					case spv::OpFunctionParameter:
-						break;
-					case spv::OpLabel:
-						mainBlockId = Block::ID(it.word(1));
-						break;
-					default:
-						WARN("Unexpected opcode '%s' following OpFunction", OpcodeName(it.opcode()).c_str());
+						switch (it.opcode())
+						{
+						case spv::OpFunction:
+						case spv::OpFunctionParameter:
+							break;
+						case spv::OpLabel:
+							entryPointBlockId = Block::ID(it.word(1));
+							break;
+						default:
+							WARN("Unexpected opcode '%s' following OpFunction", OpcodeName(it.opcode()).c_str());
+						}
 					}
 				}
-				ASSERT(mainBlockId.value() != 0); // Function's OpLabel not found
+				else
+				{
+					// All non-entry point functions should be inlined into an
+					// entry point function.
+					// This isn't the target entry point, so must be another
+					// entry point that we are not interested in. Just skip it.
+					for (; insn != end() && insn.opcode() != spv::OpFunctionEnd; insn++) {}
+				}
+
 				break;
+			}
 			case spv::OpFunctionEnd:
 				// Due to preprocessing, the entrypoint and its function provide no value.
 				break;
@@ -678,7 +707,7 @@
 			{
 				// We will only support the GLSL 450 extended instruction set, so no point in tracking the ID we assign it.
 				// Valid shaders will not attempt to import any other instruction sets.
-				auto ext = reinterpret_cast<char const *>(insn.wordPointer(2));
+				auto ext = insn.string(2);
 				if (0 != strcmp("GLSL.std.450", ext))
 				{
 					UNSUPPORTED("SPIR-V Extension: %s", ext);
@@ -895,7 +924,7 @@
 
 			case spv::OpExtension:
 			{
-				auto ext = reinterpret_cast<char const *>(insn.wordPointer(1));
+				auto ext = insn.string(1);
 				// Part of core SPIR-V 1.3. Vulkan 1.1 implementations must also accept the pre-1.3
 				// extension per Appendix A, `Vulkan Environment for SPIR-V`.
 				if (!strcmp(ext, "SPV_KHR_storage_buffer_storage_class")) break;
@@ -909,6 +938,7 @@
 			}
 		}
 
+		ASSERT_MSG(entryPointFunctionId != 0, "Entry point '%s' not found", createInfo->pName);
 		AssignBlockIns();
 	}
 
@@ -927,7 +957,7 @@
 	void SpirvShader::AssignBlockIns()
 	{
 		Block::Set reachable;
-		TraverseReachableBlocks(mainBlockId, reachable);
+		TraverseReachableBlocks(entryPointBlockId, reachable);
 
 		for (auto &it : blocks)
 		{
@@ -1887,8 +1917,8 @@
 			EmitInstruction(insn, &state);
 		}
 
-		// Emit all the blocks starting from mainBlockId.
-		EmitBlocks(mainBlockId, &state);
+		// Emit all the blocks starting from entryPointBlockId.
+		EmitBlocks(entryPointBlockId, &state);
 	}
 
 	void SpirvShader::EmitBlocks(Block::ID id, EmitState *state, Block::ID ignore /* = 0 */) const
@@ -1979,7 +2009,7 @@
 			return; // Already generated this block.
 		}
 
-		if (blockId != mainBlockId)
+		if (blockId != entryPointBlockId)
 		{
 			// Set the activeLaneMask.
 			SIMD::Int activeLaneMask(0);
@@ -5795,6 +5825,31 @@
 		return it->second;
 	}
 
+	VkShaderStageFlagBits SpirvShader::executionModelToStage(spv::ExecutionModel model)
+	{
+		switch (model)
+		{
+		case spv::ExecutionModelVertex:                 return VK_SHADER_STAGE_VERTEX_BIT;
+		// case spv::ExecutionModelTessellationControl:    return VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT;
+		// case spv::ExecutionModelTessellationEvaluation: return VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT;
+		// case spv::ExecutionModelGeometry:               return VK_SHADER_STAGE_GEOMETRY_BIT;
+		case spv::ExecutionModelFragment:               return VK_SHADER_STAGE_FRAGMENT_BIT;
+		case spv::ExecutionModelGLCompute:              return VK_SHADER_STAGE_COMPUTE_BIT;
+		// case spv::ExecutionModelKernel:                 return VkShaderStageFlagBits(0); // Not supported by vulkan.
+		// case spv::ExecutionModelTaskNV:                 return VK_SHADER_STAGE_TASK_BIT_NV;
+		// case spv::ExecutionModelMeshNV:                 return VK_SHADER_STAGE_MESH_BIT_NV;
+		// case spv::ExecutionModelRayGenerationNV:        return VK_SHADER_STAGE_RAYGEN_BIT_NV;
+		// case spv::ExecutionModelIntersectionNV:         return VK_SHADER_STAGE_INTERSECTION_BIT_NV;
+		// case spv::ExecutionModelAnyHitNV:               return VK_SHADER_STAGE_ANY_HIT_BIT_NV;
+		// case spv::ExecutionModelClosestHitNV:           return VK_SHADER_STAGE_CLOSEST_HIT_BIT_NV;
+		// case spv::ExecutionModelMissNV:                 return VK_SHADER_STAGE_MISS_BIT_NV;
+		// case spv::ExecutionModelCallableNV:             return VK_SHADER_STAGE_CALLABLE_BIT_NV;
+		default:
+			UNSUPPORTED("ExecutionModel: %d", int(model));
+			return VkShaderStageFlagBits(0);
+		}
+	}
+
 	SpirvRoutine::SpirvRoutine(vk::PipelineLayout const *pipelineLayout) :
 		pipelineLayout(pipelineLayout)
 	{
diff --git a/src/Pipeline/SpirvShader.hpp b/src/Pipeline/SpirvShader.hpp
index 3f71fb5..21a9d5a 100644
--- a/src/Pipeline/SpirvShader.hpp
+++ b/src/Pipeline/SpirvShader.hpp
@@ -295,6 +295,11 @@
 				return &iter[n];
 			}
 
+			const char* string(uint32_t n) const
+			{
+				return reinterpret_cast<const char*>(wordPointer(n));
+			}
+
 			bool operator==(InsnIterator const &other) const
 			{
 				return iter == other.iter;
@@ -538,7 +543,10 @@
 			return serialID;
 		}
 
-		SpirvShader(InsnStore const &insns, vk::RenderPass *renderPass, uint32_t subpassIndex);
+		SpirvShader(VkPipelineShaderStageCreateInfo const *createInfo,
+					InsnStore const &insns,
+					vk::RenderPass *renderPass,
+					uint32_t subpassIndex);
 
 		struct Modes
 		{
@@ -727,7 +735,7 @@
 		HandleMap<Type> types;
 		HandleMap<Object> defs;
 		HandleMap<Block> blocks;
-		Block::ID mainBlockId; // Block of the entry point function.
+		Block::ID entryPointBlockId; // Block of the entry point function.
 
 		// Walks all reachable the blocks starting from id adding them to
 		// reachable.
@@ -980,6 +988,7 @@
 		static sw::FilterType convertFilterMode(const vk::Sampler *sampler);
 		static sw::MipmapType convertMipmapMode(const vk::Sampler *sampler);
 		static sw::AddressingMode convertAddressingMode(int coordinateIndex, VkSamplerAddressMode addressMode, VkImageViewType imageViewType);
+		static VkShaderStageFlagBits executionModelToStage(spv::ExecutionModel model);
 	};
 
 	class SpirvRoutine
diff --git a/src/Vulkan/VkPipeline.cpp b/src/Vulkan/VkPipeline.cpp
index d0cd19e..eaedcac 100644
--- a/src/Vulkan/VkPipeline.cpp
+++ b/src/Vulkan/VkPipeline.cpp
@@ -459,7 +459,7 @@
 
 		// FIXME (b/119409619): use an allocator here so we can control all memory allocations
 		// TODO: also pass in any pipeline state which will affect shader compilation
-		auto spirvShader = new sw::SpirvShader{code, Cast(pCreateInfo->renderPass), pCreateInfo->subpass};
+		auto spirvShader = new sw::SpirvShader{pStage, code, Cast(pCreateInfo->renderPass), pCreateInfo->subpass};
 
 		switch (pStage->stage)
 		{
@@ -552,7 +552,7 @@
 	ASSERT(shader == nullptr);
 
 	// FIXME(b/119409619): use allocator.
-	shader = new sw::SpirvShader(code, nullptr, 0);
+	shader = new sw::SpirvShader(&pCreateInfo->stage, code, nullptr, 0);
 	vk::DescriptorSet::Bindings descriptorSets;  // FIXME(b/129523279): Delay code generation until invoke time.
 	program = new sw::ComputeProgram(shader, layout, descriptorSets);
 	program->generate();