Refactor GroupNonUniformArithmetic handling
The implementation of GroupNonUniformArithmetic instructions used to
rely on a pimpl struct containing a static template function. It's
not very elegant and causes code bloat. It's also too closely tied to
the SpirvShader class, which we want to refactor into a parsing-only
class.
This change reduces the template function to performing the operation
on scalarized components. It doesn't need access to SpirvShader nor
EmitState and takes one fewer argument.
Bug: b/247020580
Change-Id: I24955b42d84a3a31f139bc0a3aacc09e82f58fee
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/68868
Tested-by: Nicolas Capens <nicolascapens@google.com>
Reviewed-by: Alexis Hétu <sugoi@google.com>
Kokoro-Result: kokoro <noreply+kokoro@google.com>
diff --git a/src/Pipeline/SpirvShader.hpp b/src/Pipeline/SpirvShader.hpp
index 2e57c5d..1e775f7 100644
--- a/src/Pipeline/SpirvShader.hpp
+++ b/src/Pipeline/SpirvShader.hpp
@@ -1589,7 +1589,6 @@
struct Impl
{
struct Debugger;
- struct Group;
Debugger *debugger = nullptr;
};
Impl impl;
diff --git a/src/Pipeline/SpirvShaderGroup.cpp b/src/Pipeline/SpirvShaderGroup.cpp
index 19c2c05..a081437 100644
--- a/src/Pipeline/SpirvShaderGroup.cpp
+++ b/src/Pipeline/SpirvShaderGroup.cpp
@@ -18,63 +18,51 @@
namespace sw {
-struct SpirvShader::Impl::Group
+// Template function to perform a binary group operation.
+// |TYPE| should be the type of the binary operation (as a SIMD::<ScalarType>).
+// |I| should be a type suitable to initialize the identity value.
+// |APPLY| should be a callable object that takes two RValue<TYPE> parameters
+// and returns a new RValue<TYPE> corresponding to the operation's result.
+template<typename TYPE, typename I, typename APPLY>
+static RValue<TYPE> BinaryOperation(
+ spv::GroupOperation operation,
+ RValue<SIMD::UInt> value,
+ RValue<SIMD::UInt> mask,
+ const I identityValue,
+ APPLY &&apply)
{
- // Template function to perform a binary operation.
- // |TYPE| should be the type of the binary operation (as a SIMD::<ScalarType>).
- // |I| should be a type suitable to initialize the identity value.
- // |APPLY| should be a callable object that takes two RValue<TYPE> parameters
- // and returns a new RValue<TYPE> corresponding to the operation's result.
- template<typename TYPE, typename I, typename APPLY>
- static void BinaryOperation(
- const SpirvShader *shader,
- const SpirvShader::InsnIterator &insn,
- const SpirvShader::EmitState *state,
- Intermediate &dst,
- const I identityValue,
- APPLY &&apply)
+ auto identity = TYPE(identityValue);
+ SIMD::UInt v_uint = (value & mask) | (As<SIMD::UInt>(identity) & ~mask);
+ TYPE v = As<TYPE>(v_uint);
+
+ switch(operation)
{
- SpirvShader::Operand value(shader, state, insn.word(5));
- auto &type = shader->getType(SpirvShader::Type::ID(insn.word(1)));
- for(auto i = 0u; i < type.componentCount; i++)
+ case spv::GroupOperationReduce:
{
- auto mask = As<SIMD::UInt>(state->activeLaneMask()); // Considers helper invocations active. See b/151137030
- auto identity = TYPE(identityValue);
- SIMD::UInt v_uint = (value.UInt(i) & mask) | (As<SIMD::UInt>(identity) & ~mask);
- TYPE v = As<TYPE>(v_uint);
- switch(spv::GroupOperation(insn.word(4)))
- {
- case spv::GroupOperationReduce:
- {
- // NOTE: floating-point add and multiply are not really commutative so
- // ensure that all values in the final lanes are identical
- TYPE v2 = apply(v.xxzz, v.yyww); // [xy] [xy] [zw] [zw]
- TYPE v3 = apply(v2.xxxx, v2.zzzz); // [xyzw] [xyzw] [xyzw] [xyzw]
- dst.move(i, v3);
- }
- break;
- case spv::GroupOperationInclusiveScan:
- {
- TYPE v2 = apply(v, Shuffle(v, identity, 0x4012) /* [id, v.y, v.z, v.w] */); // [x] [xy] [yz] [zw]
- TYPE v3 = apply(v2, Shuffle(v2, identity, 0x4401) /* [id, id, v2.x, v2.y] */); // [x] [xy] [xyz] [xyzw]
- dst.move(i, v3);
- }
- break;
- case spv::GroupOperationExclusiveScan:
- {
- TYPE v2 = apply(v, Shuffle(v, identity, 0x4012) /* [id, v.y, v.z, v.w] */); // [x] [xy] [yz] [zw]
- TYPE v3 = apply(v2, Shuffle(v2, identity, 0x4401) /* [id, id, v2.x, v2.y] */); // [x] [xy] [xyz] [xyzw]
- auto v4 = Shuffle(v3, identity, 0x4012 /* [id, v3.x, v3.y, v3.z] */); // [i] [x] [xy] [xyz]
- dst.move(i, v4);
- }
- break;
- default:
- UNSUPPORTED("EmitGroupNonUniform op: %s Group operation: %d",
- SpirvShader::OpcodeName(type.opcode()), insn.word(4));
- }
+ // NOTE: floating-point add and multiply are not really commutative so
+ // ensure that all values in the final lanes are identical
+ TYPE v2 = apply(v.xxzz, v.yyww); // [xy] [xy] [zw] [zw]
+ return apply(v2.xxxx, v2.zzzz); // [xyzw] [xyzw] [xyzw] [xyzw]
}
+ break;
+ case spv::GroupOperationInclusiveScan:
+ {
+ TYPE v2 = apply(v, Shuffle(v, identity, 0x4012) /* [id, v.y, v.z, v.w] */); // [x] [xy] [yz] [zw]
+ return apply(v2, Shuffle(v2, identity, 0x4401) /* [id, id, v2.x, v2.y] */); // [x] [xy] [xyz] [xyzw]
+ }
+ break;
+ case spv::GroupOperationExclusiveScan:
+ {
+ TYPE v2 = apply(v, Shuffle(v, identity, 0x4012) /* [id, v.y, v.z, v.w] */); // [x] [xy] [yz] [zw]
+ TYPE v3 = apply(v2, Shuffle(v2, identity, 0x4401) /* [id, id, v2.x, v2.y] */); // [x] [xy] [xyz] [xyzw]
+ return Shuffle(v3, identity, 0x4012 /* [id, v3.x, v3.y, v3.z] */); // [i] [x] [xy] [xyz]
+ }
+ break;
+ default:
+ UNSUPPORTED("Group operation: %d", operation);
+ return identity;
}
-};
+}
SpirvShader::EmitResult SpirvShader::EmitGroupNonUniform(InsnIterator insn, EmitState *state) const
{
@@ -396,113 +384,126 @@
}
break;
- case spv::OpGroupNonUniformIAdd:
- Impl::Group::BinaryOperation<SIMD::Int>(
- this, insn, state, dst, 0,
- [](auto a, auto b) { return a + b; });
- break;
-
- case spv::OpGroupNonUniformFAdd:
- Impl::Group::BinaryOperation<SIMD::Float>(
- this, insn, state, dst, 0.0f,
- [](auto a, auto b) { return a + b; });
- break;
-
- case spv::OpGroupNonUniformIMul:
- Impl::Group::BinaryOperation<SIMD::Int>(
- this, insn, state, dst, 1,
- [](auto a, auto b) { return a * b; });
- break;
-
- case spv::OpGroupNonUniformFMul:
- Impl::Group::BinaryOperation<SIMD::Float>(
- this, insn, state, dst, 1.0f,
- [](auto a, auto b) { return a * b; });
- break;
-
- case spv::OpGroupNonUniformBitwiseAnd:
- Impl::Group::BinaryOperation<SIMD::UInt>(
- this, insn, state, dst, ~0u,
- [](auto a, auto b) { return a & b; });
- break;
-
- case spv::OpGroupNonUniformBitwiseOr:
- Impl::Group::BinaryOperation<SIMD::UInt>(
- this, insn, state, dst, 0,
- [](auto a, auto b) { return a | b; });
- break;
-
- case spv::OpGroupNonUniformBitwiseXor:
- Impl::Group::BinaryOperation<SIMD::UInt>(
- this, insn, state, dst, 0,
- [](auto a, auto b) { return a ^ b; });
- break;
-
- case spv::OpGroupNonUniformSMin:
- Impl::Group::BinaryOperation<SIMD::Int>(
- this, insn, state, dst, INT32_MAX,
- [](auto a, auto b) { return Min(a, b); });
- break;
-
- case spv::OpGroupNonUniformUMin:
- Impl::Group::BinaryOperation<SIMD::UInt>(
- this, insn, state, dst, ~0u,
- [](auto a, auto b) { return Min(a, b); });
- break;
-
- case spv::OpGroupNonUniformFMin:
- Impl::Group::BinaryOperation<SIMD::Float>(
- this, insn, state, dst, SIMD::Float::infinity(),
- [](auto a, auto b) { return NMin(a, b); });
- break;
-
- case spv::OpGroupNonUniformSMax:
- Impl::Group::BinaryOperation<SIMD::Int>(
- this, insn, state, dst, INT32_MIN,
- [](auto a, auto b) { return Max(a, b); });
- break;
-
- case spv::OpGroupNonUniformUMax:
- Impl::Group::BinaryOperation<SIMD::UInt>(
- this, insn, state, dst, 0,
- [](auto a, auto b) { return Max(a, b); });
- break;
-
- case spv::OpGroupNonUniformFMax:
- Impl::Group::BinaryOperation<SIMD::Float>(
- this, insn, state, dst, -SIMD::Float::infinity(),
- [](auto a, auto b) { return NMax(a, b); });
- break;
-
- case spv::OpGroupNonUniformLogicalAnd:
- Impl::Group::BinaryOperation<SIMD::UInt>(
- this, insn, state, dst, ~0u,
- [](auto a, auto b) {
- SIMD::UInt zero = SIMD::UInt(0);
- return CmpNEQ(a, zero) & CmpNEQ(b, zero);
- });
- break;
-
- case spv::OpGroupNonUniformLogicalOr:
- Impl::Group::BinaryOperation<SIMD::UInt>(
- this, insn, state, dst, 0,
- [](auto a, auto b) {
- SIMD::UInt zero = SIMD::UInt(0);
- return CmpNEQ(a, zero) | CmpNEQ(b, zero);
- });
- break;
-
- case spv::OpGroupNonUniformLogicalXor:
- Impl::Group::BinaryOperation<SIMD::UInt>(
- this, insn, state, dst, 0,
- [](auto a, auto b) {
- SIMD::UInt zero = SIMD::UInt(0);
- return CmpNEQ(a, zero) ^ CmpNEQ(b, zero);
- });
- break;
-
+ // The remaining instructions are GroupNonUniformArithmetic operations
default:
- UNSUPPORTED("EmitGroupNonUniform op: %s", OpcodeName(type.opcode()));
+ auto &type = getType(SpirvShader::Type::ID(insn.word(1)));
+ auto operation = static_cast<spv::GroupOperation>(insn.word(4));
+ SpirvShader::Operand value(this, state, insn.word(5));
+ auto mask = As<SIMD::UInt>(state->activeLaneMask()); // Considers helper invocations active. See b/151137030
+
+ for(uint32_t i = 0; i < type.componentCount; i++)
+ {
+ switch(insn.opcode())
+ {
+ case spv::OpGroupNonUniformIAdd:
+ dst.move(i, BinaryOperation<SIMD::Int>(
+ operation, value.UInt(i), mask, 0,
+ [](auto a, auto b) { return a + b; }));
+ break;
+ case spv::OpGroupNonUniformFAdd:
+ dst.move(i, BinaryOperation<SIMD::Float>(
+ operation, value.UInt(i), mask, 0.0f,
+ [](auto a, auto b) { return a + b; }));
+ break;
+
+ case spv::OpGroupNonUniformIMul:
+ dst.move(i, BinaryOperation<SIMD::Int>(
+ operation, value.UInt(i), mask, 1,
+ [](auto a, auto b) { return a * b; }));
+ break;
+
+ case spv::OpGroupNonUniformFMul:
+ dst.move(i, BinaryOperation<SIMD::Float>(
+ operation, value.UInt(i), mask, 1.0f,
+ [](auto a, auto b) { return a * b; }));
+ break;
+
+ case spv::OpGroupNonUniformBitwiseAnd:
+ dst.move(i, BinaryOperation<SIMD::UInt>(
+ operation, value.UInt(i), mask, ~0u,
+ [](auto a, auto b) { return a & b; }));
+ break;
+
+ case spv::OpGroupNonUniformBitwiseOr:
+ dst.move(i, BinaryOperation<SIMD::UInt>(
+ operation, value.UInt(i), mask, 0,
+ [](auto a, auto b) { return a | b; }));
+ break;
+
+ case spv::OpGroupNonUniformBitwiseXor:
+ dst.move(i, BinaryOperation<SIMD::UInt>(
+ operation, value.UInt(i), mask, 0,
+ [](auto a, auto b) { return a ^ b; }));
+ break;
+
+ case spv::OpGroupNonUniformSMin:
+ dst.move(i, BinaryOperation<SIMD::Int>(
+ operation, value.UInt(i), mask, INT32_MAX,
+ [](auto a, auto b) { return Min(a, b); }));
+ break;
+
+ case spv::OpGroupNonUniformUMin:
+ dst.move(i, BinaryOperation<SIMD::UInt>(
+ operation, value.UInt(i), mask, ~0u,
+ [](auto a, auto b) { return Min(a, b); }));
+ break;
+
+ case spv::OpGroupNonUniformFMin:
+ dst.move(i, BinaryOperation<SIMD::Float>(
+ operation, value.UInt(i), mask, SIMD::Float::infinity(),
+ [](auto a, auto b) { return NMin(a, b); }));
+ break;
+
+ case spv::OpGroupNonUniformSMax:
+ dst.move(i, BinaryOperation<SIMD::Int>(
+ operation, value.UInt(i), mask, INT32_MIN,
+ [](auto a, auto b) { return Max(a, b); }));
+ break;
+
+ case spv::OpGroupNonUniformUMax:
+ dst.move(i, BinaryOperation<SIMD::UInt>(
+ operation, value.UInt(i), mask, 0,
+ [](auto a, auto b) { return Max(a, b); }));
+ break;
+
+ case spv::OpGroupNonUniformFMax:
+ dst.move(i, BinaryOperation<SIMD::Float>(
+ operation, value.UInt(i), mask, -SIMD::Float::infinity(),
+ [](auto a, auto b) { return NMax(a, b); }));
+ break;
+
+ case spv::OpGroupNonUniformLogicalAnd:
+ dst.move(i, BinaryOperation<SIMD::UInt>(
+ operation, value.UInt(i), mask, ~0u,
+ [](auto a, auto b) {
+ SIMD::UInt zero = SIMD::UInt(0);
+ return CmpNEQ(a, zero) & CmpNEQ(b, zero);
+ }));
+ break;
+
+ case spv::OpGroupNonUniformLogicalOr:
+ dst.move(i, BinaryOperation<SIMD::UInt>(
+ operation, value.UInt(i), mask, 0,
+ [](auto a, auto b) {
+ SIMD::UInt zero = SIMD::UInt(0);
+ return CmpNEQ(a, zero) | CmpNEQ(b, zero);
+ }));
+ break;
+
+ case spv::OpGroupNonUniformLogicalXor:
+ dst.move(i, BinaryOperation<SIMD::UInt>(
+ operation, value.UInt(i), mask, 0,
+ [](auto a, auto b) {
+ SIMD::UInt zero = SIMD::UInt(0);
+ return CmpNEQ(a, zero) ^ CmpNEQ(b, zero);
+ }));
+ break;
+
+ default:
+ UNSUPPORTED("EmitGroupNonUniform op: %s", OpcodeName(type.opcode()));
+ }
+ }
+ break;
}
return EmitResult::Continue;
}