Avoid recompiling identical SPIR-V code

We were creating SpirvShader objects for every shader stage of the
pipeline, each with their own unique serial ID. This caused us to
compile the same SPIR-V code over an over again when multiple pipelines
are created from the same shader module(s).

This change essentially moves the serial ID to the shader module. Things
that still require us to recompile code from the same shader module are
the entry point specification, and specialization constants. The former
is taken into account by using a 64-bit ID consisting of the module ID
and entry point ID. For the latter we assume any use of specialization
constants will result in a unique SPIR-V binary. This is conservative
and may still lead to unnecessary recompiles.

This change also minimizes the state passed to SpirvShader, to prevent
specialization on state not taken into account by the routine caches.

Bug: b/135609394
Tests: dEQP-VK.pipeline.render_to_image.core.*.huge.*
Change-Id: I204e812265067462f8019af9f6b7b3067ef5dc7f
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/33109
Presubmit-Ready: Nicolas Capens <nicolascapens@google.com>
Kokoro-Presubmit: kokoro <noreply+kokoro@google.com>
Tested-by: Nicolas Capens <nicolascapens@google.com>
Reviewed-by: Alexis Hétu <sugoi@google.com>
diff --git a/src/Device/PixelProcessor.hpp b/src/Device/PixelProcessor.hpp
index 0ab9d88..2dd9a01 100644
--- a/src/Device/PixelProcessor.hpp
+++ b/src/Device/PixelProcessor.hpp
@@ -34,7 +34,7 @@
 		{
 			unsigned int computeHash();
 
-			int shaderID;
+			uint64_t shaderID;
 
 			VkCompareOp depthCompareMode;
 			bool depthWriteEnable;
diff --git a/src/Pipeline/SpirvShader.cpp b/src/Pipeline/SpirvShader.cpp
index 8a28a0a..b6866a2 100644
--- a/src/Pipeline/SpirvShader.cpp
+++ b/src/Pipeline/SpirvShader.cpp
@@ -366,16 +366,16 @@
 
 	} // namespace SIMD
 
-	std::atomic<int> SpirvShader::serialCounter(1);    // Start at 1, 0 is invalid shader.
-
 	SpirvShader::SpirvShader(
-			VkPipelineShaderStageCreateInfo const *createInfo,
+			uint32_t codeSerialID,
+			VkShaderStageFlagBits pipelineStage,
+			const char *entryPointName,
 			InsnStore const &insns,
 			vk::RenderPass *renderPass,
 			uint32_t subpassIndex)
 				: insns{insns}, inputs{MAX_INTERFACE_COMPONENTS},
 				outputs{MAX_INTERFACE_COMPONENTS},
-				serialID{serialCounter++}, modes{}
+				codeSerialID(codeSerialID), modes{}
 	{
 		ASSERT(insns.size() > 0);
 
@@ -411,7 +411,7 @@
 				auto id = Object::ID(insn.word(2));
 				auto name = insn.string(3);
 				auto stage = executionModelToStage(executionModel);
-				if (stage == createInfo->stage && strcmp(name, createInfo->pName) == 0)
+				if (stage == pipelineStage && strcmp(name, entryPointName) == 0)
 				{
 					ASSERT_MSG(entryPointFunctionId == 0, "Duplicate entry point with name '%s' and stage %d", name, int(stage));
 					entryPointFunctionId = id;
@@ -982,7 +982,7 @@
 			}
 		}
 
-		ASSERT_MSG(entryPointFunctionId != 0, "Entry point '%s' not found", createInfo->pName);
+		ASSERT_MSG(entryPointFunctionId != 0, "Entry point '%s' not found", entryPointName);
 		AssignBlockFields();
 	}
 
diff --git a/src/Pipeline/SpirvShader.hpp b/src/Pipeline/SpirvShader.hpp
index b692254..23f7723 100644
--- a/src/Pipeline/SpirvShader.hpp
+++ b/src/Pipeline/SpirvShader.hpp
@@ -540,15 +540,19 @@
 
 		static_assert(sizeof(ImageInstruction) == sizeof(uint32_t), "ImageInstruction must be 32-bit");
 
-		int getSerialID() const
+		// This method is for retrieving an ID that uniquely identifies the
+		// shader entry point represented by this object.
+		uint64_t getSerialID() const
 		{
-			return serialID;
+			return  ((uint64_t)entryPointBlockId.value() << 32) | codeSerialID;
 		}
 
-		SpirvShader(VkPipelineShaderStageCreateInfo const *createInfo,
-					InsnStore const &insns,
-					vk::RenderPass *renderPass,
-					uint32_t subpassIndex);
+		SpirvShader(uint32_t codeSerialID,
+		            VkShaderStageFlagBits stage,
+		            const char *entryPointName,
+		            InsnStore const &insns,
+		            vk::RenderPass *renderPass,
+		            uint32_t subpassIndex);
 
 		struct Modes
 		{
@@ -740,8 +744,7 @@
 		}
 
 	private:
-		const int serialID;
-		static std::atomic<int> serialCounter;
+		const uint32_t codeSerialID;
 		Modes modes;
 		HandleMap<Type> types;
 		HandleMap<Object> defs;
diff --git a/src/Vulkan/VkPipeline.cpp b/src/Vulkan/VkPipeline.cpp
index 971c7ec..a7f2a41 100644
--- a/src/Vulkan/VkPipeline.cpp
+++ b/src/Vulkan/VkPipeline.cpp
@@ -152,8 +152,7 @@
 	return 0;
 }
 
-// preprocessSpirv applies and freezes specializations into constants, inlines
-// all functions and performs constant folding.
+// preprocessSpirv applies and freezes specializations into constants, and inlines all functions.
 std::vector<uint32_t> preprocessSpirv(
 		std::vector<uint32_t> const &code,
 		VkSpecializationInfo const *specializationInfo)
@@ -451,12 +450,16 @@
 			UNIMPLEMENTED("pStage->flags");
 		}
 
-		auto module = vk::Cast(pStage->module);
+		const ShaderModule *module = vk::Cast(pStage->module);
 		auto code = preprocessSpirv(module->getCode(), pStage->pSpecializationInfo);
 
+		// If the pipeline has specialization constants, assume they're unique and
+		// use a new serial ID so the shader gets recompiled.
+		uint32_t codeSerialID = (pStage->pSpecializationInfo ? ShaderModule::nextSerialID() : module->getSerialID());
+
 		// FIXME (b/119409619): use an allocator here so we can control all memory allocations
 		// TODO: also pass in any pipeline state which will affect shader compilation
-		auto spirvShader = new sw::SpirvShader{pStage, code, vk::Cast(pCreateInfo->renderPass), pCreateInfo->subpass};
+		auto spirvShader = new sw::SpirvShader(codeSerialID, pStage->stage, pStage->pName, code, vk::Cast(pCreateInfo->renderPass), pCreateInfo->subpass);
 
 		switch (pStage->stage)
 		{
@@ -542,16 +545,21 @@
 
 void ComputePipeline::compileShaders(const VkAllocationCallbacks* pAllocator, const VkComputePipelineCreateInfo* pCreateInfo)
 {
-	auto module = vk::Cast(pCreateInfo->stage.module);
+	auto &stage = pCreateInfo->stage;
+	const ShaderModule *module = vk::Cast(stage.module);
 
-	auto code = preprocessSpirv(module->getCode(), pCreateInfo->stage.pSpecializationInfo);
+	auto code = preprocessSpirv(module->getCode(), stage.pSpecializationInfo);
 
 	ASSERT_OR_RETURN(code.size() > 0);
 
 	ASSERT(shader == nullptr);
 
-	// FIXME(b/119409619): use allocator.
-	shader = new sw::SpirvShader(&pCreateInfo->stage, code, nullptr, 0);
+	// If the pipeline has specialization constants, assume they're unique and
+	// use a new serial ID so the shader gets recompiled.
+	uint32_t codeSerialID = (stage.pSpecializationInfo ? ShaderModule::nextSerialID() : module->getSerialID());
+
+	// TODO(b/119409619): use allocator.
+	shader = new sw::SpirvShader(codeSerialID, stage.stage, stage.pName, code, nullptr, 0);
 	vk::DescriptorSet::Bindings descriptorSets;  // FIXME(b/129523279): Delay code generation until invoke time.
 	program = new sw::ComputeProgram(shader, layout, descriptorSets);
 	program->generate();
diff --git a/src/Vulkan/VkShaderModule.cpp b/src/Vulkan/VkShaderModule.cpp
index b48c909..ce5e831 100644
--- a/src/Vulkan/VkShaderModule.cpp
+++ b/src/Vulkan/VkShaderModule.cpp
@@ -19,7 +19,10 @@
 namespace vk
 {
 
-ShaderModule::ShaderModule(const VkShaderModuleCreateInfo* pCreateInfo, void* mem) : code(reinterpret_cast<uint32_t*>(mem))
+std::atomic<uint32_t> ShaderModule::serialCounter(1);    // Start at 1, 0 is invalid shader.
+
+ShaderModule::ShaderModule(const VkShaderModuleCreateInfo* pCreateInfo, void* mem)
+	: serialID(nextSerialID()), code(reinterpret_cast<uint32_t*>(mem))
 {
 	memcpy(code, pCreateInfo->pCode, pCreateInfo->codeSize);
 	wordCount = static_cast<uint32_t>(pCreateInfo->codeSize / sizeof(uint32_t));
diff --git a/src/Vulkan/VkShaderModule.hpp b/src/Vulkan/VkShaderModule.hpp
index 0bc1309..ba30b59 100644
--- a/src/Vulkan/VkShaderModule.hpp
+++ b/src/Vulkan/VkShaderModule.hpp
@@ -16,6 +16,8 @@
 #define VK_SHADER_MODULE_HPP_
 
 #include "VkObject.hpp"
+
+#include <atomic>
 #include <vector>
 
 namespace rr
@@ -37,7 +39,13 @@
 	// guts' operations, and this copy.
 	std::vector<uint32_t> getCode() const { return std::vector<uint32_t>{ code, code + wordCount };}
 
+	uint32_t getSerialID() const { return serialID; }
+	static uint32_t nextSerialID() { return serialCounter++; }
+
 private:
+	const uint32_t serialID;
+	static std::atomic<uint32_t> serialCounter;
+
 	uint32_t* code = nullptr;
 	uint32_t wordCount = 0;
 };