Pipeline/SpirvShader: Support more OpExtInsts
Previously `OpExtInst` and `OpExtInstImport` was hard-coded to only support `GLSL.std.450`.
We will want to support `OpenCL.Debug.100`, so put in the plumbing to properly support other extensions.
Bug: b/145351270
Change-Id: I60fbb8c45bb57b747067437641676946129b1251
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/39885
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Tested-by: Ben Clayton <bclayton@google.com>
Kokoro-Presubmit: kokoro <noreply+kokoro@google.com>
diff --git a/src/Pipeline/SpirvShader.cpp b/src/Pipeline/SpirvShader.cpp
index a4ef814..55ea1c3 100644
--- a/src/Pipeline/SpirvShader.cpp
+++ b/src/Pipeline/SpirvShader.cpp
@@ -416,13 +416,25 @@
case spv::OpExtInstImport:
{
- // We will only support the GLSL 450 extended instruction set, so no point in tracking the ID we assign it.
- // Valid shaders will not attempt to import any other instruction sets.
- auto ext = insn.string(2);
- if(0 != strcmp("GLSL.std.450", ext))
+ auto id = Extension::ID(insn.word(1));
+ auto name = insn.string(2);
+ auto ext = Extension{ Extension::Unknown };
+ for(auto it : std::initializer_list<std::pair<const char *, Extension::Name>>{
+ { "GLSL.std.450", Extension::GLSLstd450 },
+ })
{
- UNSUPPORTED("SPIR-V Extension: %s", ext);
+ if(0 == strcmp(name, it.first))
+ {
+ ext = Extension{ it.second };
+ break;
+ }
}
+ if(ext.name == Extension::Unknown)
+ {
+ UNSUPPORTED("SPIR-V Extension: %s", name);
+ break;
+ }
+ extensions.emplace(id, ext);
break;
}
case spv::OpName:
@@ -573,7 +585,6 @@
case spv::OpConvertUToF:
case spv::OpBitcast:
case spv::OpSelect:
- case spv::OpExtInst:
case spv::OpIsInf:
case spv::OpIsNan:
case spv::OpAny:
@@ -658,6 +669,18 @@
DefineResult(insn);
break;
+ case spv::OpExtInst:
+ switch(getExtension(insn.word(3)).name)
+ {
+ case Extension::GLSLstd450:
+ DefineResult(insn);
+ break;
+ default:
+ UNREACHABLE("Unexpected Extension name %d", int(getExtension(insn.word(3)).name));
+ break;
+ }
+ break;
+
case spv::OpStore:
case spv::OpAtomicStore:
case spv::OpImageWrite:
@@ -2293,6 +2316,19 @@
return EmitResult::Continue;
}
+SpirvShader::EmitResult SpirvShader::EmitExtendedInstruction(InsnIterator insn, EmitState *state) const
+{
+ auto ext = getExtension(insn.word(3));
+ switch(ext.name)
+ {
+ case Extension::GLSLstd450:
+ return EmitExtGLSLstd450(insn, state);
+ default:
+ UNREACHABLE("Unknown Extension::Name<%d>", int(ext.name));
+ }
+ return EmitResult::Continue;
+}
+
uint32_t SpirvShader::GetConstScalarInt(Object::ID id) const
{
auto &scopeObj = getObject(id);
diff --git a/src/Pipeline/SpirvShader.hpp b/src/Pipeline/SpirvShader.hpp
index d74fb14..696c47f 100644
--- a/src/Pipeline/SpirvShader.hpp
+++ b/src/Pipeline/SpirvShader.hpp
@@ -380,6 +380,20 @@
using String = std::string;
using StringID = SpirvID<std::string>;
+ class Extension
+ {
+ public:
+ using ID = SpirvID<Extension>;
+
+ enum Name
+ {
+ Unknown,
+ GLSLstd450,
+ };
+
+ Name name;
+ };
+
struct TypeOrObject
{}; // Dummy struct to represent a Type or Object.
@@ -732,6 +746,7 @@
HandleMap<Object> defs;
HandleMap<Function> functions;
std::unordered_map<StringID, String> strings;
+ HandleMap<Extension> extensions;
Function::ID entryPoint;
const bool robustBufferAccess = true;
@@ -1017,6 +1032,13 @@
return it->second;
}
+ Extension const &getExtension(Extension::ID id) const
+ {
+ auto it = extensions.find(id);
+ ASSERT_MSG(it != extensions.end(), "Unknown extension %d", id.value());
+ return it->second;
+ }
+
// Returns a SIMD::Pointer to the underlying data for the given pointer
// object.
// Handles objects of the following kinds:
@@ -1069,6 +1091,7 @@
EmitResult EmitDot(InsnIterator insn, EmitState *state) const;
EmitResult EmitSelect(InsnIterator insn, EmitState *state) const;
EmitResult EmitExtendedInstruction(InsnIterator insn, EmitState *state) const;
+ EmitResult EmitExtGLSLstd450(InsnIterator insn, EmitState *state) const;
EmitResult EmitAny(InsnIterator insn, EmitState *state) const;
EmitResult EmitAll(InsnIterator insn, EmitState *state) const;
EmitResult EmitBranch(InsnIterator insn, EmitState *state) const;
diff --git a/src/Pipeline/SpirvShaderGLSLstd450.cpp b/src/Pipeline/SpirvShaderGLSLstd450.cpp
index f6aaeca..0b0de23 100644
--- a/src/Pipeline/SpirvShaderGLSLstd450.cpp
+++ b/src/Pipeline/SpirvShaderGLSLstd450.cpp
@@ -25,7 +25,7 @@
namespace sw {
-SpirvShader::EmitResult SpirvShader::EmitExtendedInstruction(InsnIterator insn, EmitState *state) const
+SpirvShader::EmitResult SpirvShader::EmitExtGLSLstd450(InsnIterator insn, EmitState *state) const
{
auto &type = getType(insn.word(1));
auto &dst = state->createIntermediate(insn.word(2), type.sizeInComponents);