Cache optimized SPIR-V binaries instead of compiled shaders

Previously the optimization of the SPIR-V binary and the compilation
into a SpirvShader object took place in the same function, and the
latter would be cached for potential reuse. This change splits the
binary optimization step from the compilation step, and caches the
result of the former instead.

Since spvtools::Optimizer processes an entire SPIR-V module, while
SpirvShader objects represent a single entry point, this will enable
subsequent changes to reuse optimized SPIR-V binaries containing
multiple entry points. This change still uses the same cache key,
which includes the entry point.

Bug: b/197982536
Change-Id: I64b185512b03beb9307b1606f505c9a854b876a1
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/57129
Tested-by: Nicolas Capens <nicolascapens@google.com>
Kokoro-Result: kokoro <noreply+kokoro@google.com>
Reviewed-by: Alexis Hétu <sugoi@google.com>
diff --git a/src/Vulkan/VkPipeline.cpp b/src/Vulkan/VkPipeline.cpp
index aae2104..0fb6f9d 100644
--- a/src/Vulkan/VkPipeline.cpp
+++ b/src/Vulkan/VkPipeline.cpp
@@ -105,10 +105,8 @@
 	return optimized;
 }
 
-std::shared_ptr<sw::SpirvShader> createShader(
+sw::SpirvBinary optimizeSpirv(
     const vk::PipelineCache::SpirvShaderKey &key,
-    const vk::ShaderModule *module,
-    bool robustBufferAccess,
     const std::shared_ptr<vk::dbg::Context> &dbgctx)
 {
 	// Do not optimize the shader if we have a debugger context.
@@ -119,13 +117,23 @@
 	auto code = preprocessSpirv(key.getInsns(), key.getSpecializationInfo(), optimize);
 	ASSERT(code.size() > 0);
 
+	return code;
+}
+
+std::shared_ptr<sw::SpirvShader> createShader(
+    const vk::PipelineCache::SpirvShaderKey &key,
+    const vk::ShaderModule *module,
+    const sw::SpirvBinary &spirv,
+    bool robustBufferAccess,
+    const std::shared_ptr<vk::dbg::Context> &dbgctx)
+{
 	// If the pipeline has specialization constants, assume they're unique and
 	// use a new serial ID so the shader gets recompiled.
 	uint32_t codeSerialID = (key.getSpecializationInfo() ? vk::ShaderModule::nextSerialID() : module->getSerialID());
 
 	// TODO(b/119409619): use allocator.
 	return std::make_shared<sw::SpirvShader>(codeSerialID, key.getPipelineStage(), key.getEntryPointName().c_str(),
-	                                         code, key.getRenderPass(), key.getSubpassIndex(), robustBufferAccess, dbgctx);
+	                                         spirv, key.getRenderPass(), key.getSubpassIndex(), robustBufferAccess, dbgctx);
 }
 
 std::shared_ptr<sw::ComputeProgram> createProgram(vk::Device *device, std::shared_ptr<sw::SpirvShader> shader, const vk::PipelineLayout *layout)
@@ -238,19 +246,24 @@
 		                                        vk::Cast(pCreateInfo->renderPass), pCreateInfo->subpass,
 		                                        pStage->pSpecializationInfo);
 		auto pipelineStage = key.getPipelineStage();
+		auto dbgctx = device->getDebuggerContext();
+
+		sw::SpirvBinary spirv;
 
 		if(pPipelineCache)
 		{
-			auto shader = pPipelineCache->getOrCreateShader(key, [&] {
-				return createShader(key, module, robustBufferAccess, device->getDebuggerContext());
+			spirv = pPipelineCache->getOrOptimizeSpirv(key, [&] {
+				return optimizeSpirv(key, dbgctx);
 			});
-			setShader(pipelineStage, shader);
 		}
 		else
 		{
-			auto shader = createShader(key, module, robustBufferAccess, device->getDebuggerContext());
-			setShader(pipelineStage, shader);
+			spirv = optimizeSpirv(key, dbgctx);
 		}
+
+		auto shader = createShader(key, module, spirv, robustBufferAccess, dbgctx);
+
+		setShader(pipelineStage, shader);
 	}
 }
 
@@ -280,21 +293,33 @@
 
 	const PipelineCache::SpirvShaderKey shaderKey(
 	    stage.stage, stage.pName, module->getCode(), nullptr, 0, stage.pSpecializationInfo);
+	auto dbgctx = device->getDebuggerContext();
+
+	sw::SpirvBinary spirv;
+
 	if(pPipelineCache)
 	{
-		shader = pPipelineCache->getOrCreateShader(shaderKey, [&] {
-			return createShader(shaderKey, module, robustBufferAccess, device->getDebuggerContext());
+		spirv = pPipelineCache->getOrOptimizeSpirv(shaderKey, [&] {
+			return optimizeSpirv(shaderKey, dbgctx);
 		});
+	}
+	else
+	{
+		spirv = optimizeSpirv(shaderKey, dbgctx);
+	}
 
-		const PipelineCache::ComputeProgramKey programKey(shader->getSerialID(), layout->identifier);
+	shader = createShader(shaderKey, module, spirv, robustBufferAccess, dbgctx);
+
+	const PipelineCache::ComputeProgramKey programKey(shader->getSerialID(), layout->identifier);
+
+	if(pPipelineCache)
+	{
 		program = pPipelineCache->getOrCreateComputeProgram(programKey, [&] {
 			return createProgram(device, shader, layout);
 		});
 	}
 	else
 	{
-		shader = createShader(shaderKey, module, robustBufferAccess, device->getDebuggerContext());
-		const PipelineCache::ComputeProgramKey programKey(shader->getSerialID(), layout->identifier);
 		program = createProgram(device, shader, layout);
 	}
 }
diff --git a/src/Vulkan/VkPipelineCache.hpp b/src/Vulkan/VkPipelineCache.hpp
index 9c479e4..c959e0d 100644
--- a/src/Vulkan/VkPipelineCache.hpp
+++ b/src/Vulkan/VkPipelineCache.hpp
@@ -80,13 +80,13 @@
 		const vk::SpecializationInfo specializationInfo;
 	};
 
-	// getOrCreateShader() queries the cache for a shader with the given key.
+	// getOrOptimizeSpirv() queries the cache for a SPIR-V binary with the given key.
 	// If one is found, it is returned, otherwise create() is called, the
-	// returned shader is added to the cache, and it is returned.
+	// returned SPIR-V binary is added to the cache, and it is returned.
 	// Function must be a function of the signature:
-	//     std::shared_ptr<sw::SpirvShader>()
+	//     sw::SpirvBinary()
 	template<typename Function>
-	inline std::shared_ptr<sw::SpirvShader> getOrCreateShader(const PipelineCache::SpirvShaderKey &key, Function &&create);
+	inline sw::SpirvBinary getOrOptimizeSpirv(const PipelineCache::SpirvShaderKey &key, Function &&create);
 
 	struct ComputeProgramKey
 	{
@@ -122,7 +122,7 @@
 	uint8_t *data = nullptr;
 
 	marl::mutex spirvShadersMutex;
-	std::map<SpirvShaderKey, std::shared_ptr<sw::SpirvShader>> spirvShaders GUARDED_BY(spirvShadersMutex);
+	std::map<SpirvShaderKey, sw::SpirvBinary> spirvShaders GUARDED_BY(spirvShadersMutex);
 
 	marl::mutex computeProgramsMutex;
 	std::map<ComputeProgramKey, std::shared_ptr<sw::ComputeProgram>> computePrograms GUARDED_BY(computeProgramsMutex);
@@ -139,23 +139,31 @@
 	marl::lock lock(computeProgramsMutex);
 
 	auto it = computePrograms.find(key);
-	if(it != computePrograms.end()) { return it->second; }
+	if(it != computePrograms.end())
+	{
+		return it->second;
+	}
 
 	auto created = create();
 	computePrograms.emplace(key, created);
+
 	return created;
 }
 
 template<typename Function>
-std::shared_ptr<sw::SpirvShader> PipelineCache::getOrCreateShader(const PipelineCache::SpirvShaderKey &key, Function &&create)
+sw::SpirvBinary PipelineCache::getOrOptimizeSpirv(const PipelineCache::SpirvShaderKey &key, Function &&create)
 {
 	marl::lock lock(spirvShadersMutex);
 
 	auto it = spirvShaders.find(key);
-	if(it != spirvShaders.end()) { return it->second; }
+	if(it != spirvShaders.end())
+	{
+		return it->second;
+	}
 
 	auto created = create();
 	spirvShaders.emplace(key, created);
+
 	return created;
 }