SpirvShader: Handle multiple entry points
Only one will be used.
Tests: dEQP-VK.spirv_assembly.instruction.compute.multiple_shaders.*
Bug: b/132341142
Change-Id: I75588f3281f325dc1753222a1f89476267c56ad3
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/31011
Kokoro-Presubmit: kokoro <noreply+kokoro@google.com>
Presubmit-Ready: Ben Clayton <bclayton@google.com>
Tested-by: Ben Clayton <bclayton@google.com>
Reviewed-by: Nicolas Capens <nicolascapens@google.com>
diff --git a/src/Pipeline/SpirvShader.cpp b/src/Pipeline/SpirvShader.cpp
index c84be55..cdc0683 100644
--- a/src/Pipeline/SpirvShader.cpp
+++ b/src/Pipeline/SpirvShader.cpp
@@ -357,10 +357,14 @@
volatile int SpirvShader::serialCounter = 1; // Start at 1, 0 is invalid shader.
- SpirvShader::SpirvShader(InsnStore const &insns, vk::RenderPass *renderPass, uint32_t subpassIndex)
- : insns{insns}, inputs{MAX_INTERFACE_COMPONENTS},
- outputs{MAX_INTERFACE_COMPONENTS},
- serialID{serialCounter++}, modes{}
+ SpirvShader::SpirvShader(
+ VkPipelineShaderStageCreateInfo const *createInfo,
+ InsnStore const &insns,
+ vk::RenderPass *renderPass,
+ uint32_t subpassIndex)
+ : insns{insns}, inputs{MAX_INTERFACE_COMPONENTS},
+ outputs{MAX_INTERFACE_COMPONENTS},
+ serialID{serialCounter++}, modes{}
{
ASSERT(insns.size() > 0);
@@ -378,9 +382,9 @@
}
// Simplifying assumptions (to be satisfied by earlier transformations)
- // - There is exactly one entrypoint in the module, and it's the one we want
// - The only input/output OpVariables present are those used by the entrypoint
+ Object::ID entryPointFunctionId;
Block::ID currentBlock;
InsnIterator blockStart;
@@ -390,6 +394,20 @@
switch (opcode)
{
+ case spv::OpEntryPoint:
+ {
+ auto executionModel = spv::ExecutionModel(insn.word(1));
+ auto id = Object::ID(insn.word(2));
+ auto name = insn.string(3);
+ auto stage = executionModelToStage(executionModel);
+ if (stage == createInfo->stage && strcmp(name, createInfo->pName) == 0)
+ {
+ ASSERT_MSG(entryPointFunctionId == 0, "Duplicate entry point with name '%s' and stage %d", name, int(stage));
+ entryPointFunctionId = id;
+ }
+ break;
+ }
+
case spv::OpExecutionMode:
ProcessExecutionMode(insn);
break;
@@ -650,27 +668,38 @@
case spv::OpMemoryModel:
break; // Memory model does not affect our code generation until we decide to do Vulkan Memory Model support.
- case spv::OpEntryPoint:
- break;
case spv::OpFunction:
- ASSERT(mainBlockId.value() == 0); // Multiple functions found
- // Scan forward to find the function's label.
- for (auto it = insn; it != end() && mainBlockId.value() == 0; it++)
+ {
+ auto functionId = Object::ID(insn.word(2));
+ if (functionId == entryPointFunctionId)
{
- switch (it.opcode())
+ // Scan forward to find the function's label.
+ for (auto it = insn; it != end() && entryPointBlockId == 0; it++)
{
- case spv::OpFunction:
- case spv::OpFunctionParameter:
- break;
- case spv::OpLabel:
- mainBlockId = Block::ID(it.word(1));
- break;
- default:
- WARN("Unexpected opcode '%s' following OpFunction", OpcodeName(it.opcode()).c_str());
+ switch (it.opcode())
+ {
+ case spv::OpFunction:
+ case spv::OpFunctionParameter:
+ break;
+ case spv::OpLabel:
+ entryPointBlockId = Block::ID(it.word(1));
+ break;
+ default:
+ WARN("Unexpected opcode '%s' following OpFunction", OpcodeName(it.opcode()).c_str());
+ }
}
}
- ASSERT(mainBlockId.value() != 0); // Function's OpLabel not found
+ else
+ {
+ // All non-entry point functions should be inlined into an
+ // entry point function.
+ // This isn't the target entry point, so must be another
+ // entry point that we are not interested in. Just skip it.
+ for (; insn != end() && insn.opcode() != spv::OpFunctionEnd; insn++) {}
+ }
+
break;
+ }
case spv::OpFunctionEnd:
// Due to preprocessing, the entrypoint and its function provide no value.
break;
@@ -678,7 +707,7 @@
{
// We will only support the GLSL 450 extended instruction set, so no point in tracking the ID we assign it.
// Valid shaders will not attempt to import any other instruction sets.
- auto ext = reinterpret_cast<char const *>(insn.wordPointer(2));
+ auto ext = insn.string(2);
if (0 != strcmp("GLSL.std.450", ext))
{
UNSUPPORTED("SPIR-V Extension: %s", ext);
@@ -895,7 +924,7 @@
case spv::OpExtension:
{
- auto ext = reinterpret_cast<char const *>(insn.wordPointer(1));
+ auto ext = insn.string(1);
// Part of core SPIR-V 1.3. Vulkan 1.1 implementations must also accept the pre-1.3
// extension per Appendix A, `Vulkan Environment for SPIR-V`.
if (!strcmp(ext, "SPV_KHR_storage_buffer_storage_class")) break;
@@ -909,6 +938,7 @@
}
}
+ ASSERT_MSG(entryPointFunctionId != 0, "Entry point '%s' not found", createInfo->pName);
AssignBlockIns();
}
@@ -927,7 +957,7 @@
void SpirvShader::AssignBlockIns()
{
Block::Set reachable;
- TraverseReachableBlocks(mainBlockId, reachable);
+ TraverseReachableBlocks(entryPointBlockId, reachable);
for (auto &it : blocks)
{
@@ -1887,8 +1917,8 @@
EmitInstruction(insn, &state);
}
- // Emit all the blocks starting from mainBlockId.
- EmitBlocks(mainBlockId, &state);
+ // Emit all the blocks starting from entryPointBlockId.
+ EmitBlocks(entryPointBlockId, &state);
}
void SpirvShader::EmitBlocks(Block::ID id, EmitState *state, Block::ID ignore /* = 0 */) const
@@ -1979,7 +2009,7 @@
return; // Already generated this block.
}
- if (blockId != mainBlockId)
+ if (blockId != entryPointBlockId)
{
// Set the activeLaneMask.
SIMD::Int activeLaneMask(0);
@@ -5795,6 +5825,31 @@
return it->second;
}
+ VkShaderStageFlagBits SpirvShader::executionModelToStage(spv::ExecutionModel model)
+ {
+ switch (model)
+ {
+ case spv::ExecutionModelVertex: return VK_SHADER_STAGE_VERTEX_BIT;
+ // case spv::ExecutionModelTessellationControl: return VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT;
+ // case spv::ExecutionModelTessellationEvaluation: return VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT;
+ // case spv::ExecutionModelGeometry: return VK_SHADER_STAGE_GEOMETRY_BIT;
+ case spv::ExecutionModelFragment: return VK_SHADER_STAGE_FRAGMENT_BIT;
+ case spv::ExecutionModelGLCompute: return VK_SHADER_STAGE_COMPUTE_BIT;
+ // case spv::ExecutionModelKernel: return VkShaderStageFlagBits(0); // Not supported by vulkan.
+ // case spv::ExecutionModelTaskNV: return VK_SHADER_STAGE_TASK_BIT_NV;
+ // case spv::ExecutionModelMeshNV: return VK_SHADER_STAGE_MESH_BIT_NV;
+ // case spv::ExecutionModelRayGenerationNV: return VK_SHADER_STAGE_RAYGEN_BIT_NV;
+ // case spv::ExecutionModelIntersectionNV: return VK_SHADER_STAGE_INTERSECTION_BIT_NV;
+ // case spv::ExecutionModelAnyHitNV: return VK_SHADER_STAGE_ANY_HIT_BIT_NV;
+ // case spv::ExecutionModelClosestHitNV: return VK_SHADER_STAGE_CLOSEST_HIT_BIT_NV;
+ // case spv::ExecutionModelMissNV: return VK_SHADER_STAGE_MISS_BIT_NV;
+ // case spv::ExecutionModelCallableNV: return VK_SHADER_STAGE_CALLABLE_BIT_NV;
+ default:
+ UNSUPPORTED("ExecutionModel: %d", int(model));
+ return VkShaderStageFlagBits(0);
+ }
+ }
+
SpirvRoutine::SpirvRoutine(vk::PipelineLayout const *pipelineLayout) :
pipelineLayout(pipelineLayout)
{
diff --git a/src/Pipeline/SpirvShader.hpp b/src/Pipeline/SpirvShader.hpp
index 3f71fb5..21a9d5a 100644
--- a/src/Pipeline/SpirvShader.hpp
+++ b/src/Pipeline/SpirvShader.hpp
@@ -295,6 +295,11 @@
return &iter[n];
}
+ const char* string(uint32_t n) const
+ {
+ return reinterpret_cast<const char*>(wordPointer(n));
+ }
+
bool operator==(InsnIterator const &other) const
{
return iter == other.iter;
@@ -538,7 +543,10 @@
return serialID;
}
- SpirvShader(InsnStore const &insns, vk::RenderPass *renderPass, uint32_t subpassIndex);
+ SpirvShader(VkPipelineShaderStageCreateInfo const *createInfo,
+ InsnStore const &insns,
+ vk::RenderPass *renderPass,
+ uint32_t subpassIndex);
struct Modes
{
@@ -727,7 +735,7 @@
HandleMap<Type> types;
HandleMap<Object> defs;
HandleMap<Block> blocks;
- Block::ID mainBlockId; // Block of the entry point function.
+ Block::ID entryPointBlockId; // Block of the entry point function.
// Walks all reachable the blocks starting from id adding them to
// reachable.
@@ -980,6 +988,7 @@
static sw::FilterType convertFilterMode(const vk::Sampler *sampler);
static sw::MipmapType convertMipmapMode(const vk::Sampler *sampler);
static sw::AddressingMode convertAddressingMode(int coordinateIndex, VkSamplerAddressMode addressMode, VkImageViewType imageViewType);
+ static VkShaderStageFlagBits executionModelToStage(spv::ExecutionModel model);
};
class SpirvRoutine
diff --git a/src/Vulkan/VkPipeline.cpp b/src/Vulkan/VkPipeline.cpp
index d0cd19e..eaedcac 100644
--- a/src/Vulkan/VkPipeline.cpp
+++ b/src/Vulkan/VkPipeline.cpp
@@ -459,7 +459,7 @@
// FIXME (b/119409619): use an allocator here so we can control all memory allocations
// TODO: also pass in any pipeline state which will affect shader compilation
- auto spirvShader = new sw::SpirvShader{code, Cast(pCreateInfo->renderPass), pCreateInfo->subpass};
+ auto spirvShader = new sw::SpirvShader{pStage, code, Cast(pCreateInfo->renderPass), pCreateInfo->subpass};
switch (pStage->stage)
{
@@ -552,7 +552,7 @@
ASSERT(shader == nullptr);
// FIXME(b/119409619): use allocator.
- shader = new sw::SpirvShader(code, nullptr, 0);
+ shader = new sw::SpirvShader(&pCreateInfo->stage, code, nullptr, 0);
vk::DescriptorSet::Bindings descriptorSets; // FIXME(b/129523279): Delay code generation until invoke time.
program = new sw::ComputeProgram(shader, layout, descriptorSets);
program->generate();