SpirvShader: Implement OpOuterProduct
Bug: b/126873455
Tests: dEQP-VK.glsl.matrix.outerproduct.*
Change-Id: I0189bbec3049d01f02191a6c8a7aee2c5d675af3
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/28469
Reviewed-by: Chris Forbes <chrisforbes@google.com>
Reviewed-by: Nicolas Capens <nicolascapens@google.com>
Tested-by: Ben Clayton <bclayton@google.com>
Kokoro-Presubmit: kokoro <noreply+kokoro@google.com>
diff --git a/src/Pipeline/SpirvShader.cpp b/src/Pipeline/SpirvShader.cpp
index 8f14316..80ac8ed 100644
--- a/src/Pipeline/SpirvShader.cpp
+++ b/src/Pipeline/SpirvShader.cpp
@@ -361,6 +361,7 @@
case spv::OpMatrixTimesVector:
case spv::OpVectorTimesMatrix:
case spv::OpMatrixTimesMatrix:
+ case spv::OpOuterProduct:
case spv::OpVectorExtractDynamic:
case spv::OpVectorInsertDynamic:
case spv::OpNot: // Unary ops
@@ -1623,6 +1624,9 @@
case spv::OpMatrixTimesMatrix:
return EmitMatrixTimesMatrix(insn, state);
+ case spv::OpOuterProduct:
+ return EmitOuterProduct(insn, state);
+
case spv::OpNot:
case spv::OpSNegate:
case spv::OpFNegate:
@@ -2257,6 +2261,36 @@
return EmitResult::Continue;
}
+ SpirvShader::EmitResult SpirvShader::EmitOuterProduct(InsnIterator insn, EmitState *state) const
+ {
+ auto routine = state->routine;
+ auto &type = getType(insn.word(1));
+ auto &dst = routine->createIntermediate(insn.word(2), type.sizeInComponents);
+ auto lhs = GenericValue(this, routine, insn.word(3));
+ auto rhs = GenericValue(this, routine, insn.word(4));
+ auto &lhsType = getType(lhs.type);
+ auto &rhsType = getType(rhs.type);
+
+ ASSERT(type.definition.opcode() == spv::OpTypeMatrix);
+ ASSERT(lhsType.definition.opcode() == spv::OpTypeVector);
+ ASSERT(rhsType.definition.opcode() == spv::OpTypeVector);
+ ASSERT(getType(lhsType.element).opcode() == spv::OpTypeFloat);
+ ASSERT(getType(rhsType.element).opcode() == spv::OpTypeFloat);
+
+ auto numRows = lhsType.definition.word(3);
+ auto numCols = rhsType.definition.word(3);
+
+ for (auto col = 0u; col < numCols; col++)
+ {
+ for (auto row = 0u; row < numRows; row++)
+ {
+ dst.move(col * numRows + row, lhs.Float(row) * rhs.Float(col));
+ }
+ }
+
+ return EmitResult::Continue;
+ }
+
SpirvShader::EmitResult SpirvShader::EmitUnaryOp(InsnIterator insn, EmitState *state) const
{
auto routine = state->routine;
diff --git a/src/Pipeline/SpirvShader.hpp b/src/Pipeline/SpirvShader.hpp
index c3104c0..34ac47a 100644
--- a/src/Pipeline/SpirvShader.hpp
+++ b/src/Pipeline/SpirvShader.hpp
@@ -648,6 +648,7 @@
EmitResult EmitMatrixTimesVector(InsnIterator insn, EmitState *state) const;
EmitResult EmitVectorTimesMatrix(InsnIterator insn, EmitState *state) const;
EmitResult EmitMatrixTimesMatrix(InsnIterator insn, EmitState *state) const;
+ EmitResult EmitOuterProduct(InsnIterator insn, EmitState *state) const;
EmitResult EmitVectorExtractDynamic(InsnIterator insn, EmitState *state) const;
EmitResult EmitVectorInsertDynamic(InsnIterator insn, EmitState *state) const;
EmitResult EmitUnaryOp(InsnIterator insn, EmitState *state) const;