Add support for OpUMulExtended, OpSMulExtended
- Make the existing LLVMReactor lowering support for MulHigh on non-x86
available on x86 as well, as we don't have good intrinsics-based implementation
of 4x 32bit mul highs. At some point in the future we can rework this
to use some shuffles and a pair of pmuludq.
- Plumb through Int4 and UInt4 variants of MulHigh
- Implement SPIRV OpUMulExtended, OpSMulExtended in terms of MulHigh
Bug: b/126873455
Change-Id: I25ba0a69691e7a6f7a5542ec4a90a44ba8f68331
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/25929
Reviewed-by: Nicolas Capens <nicolascapens@google.com>
Kokoro-Presubmit: kokoro <noreply+kokoro@google.com>
Tested-by: Chris Forbes <chrisforbes@google.com>
diff --git a/src/Pipeline/SpirvShader.cpp b/src/Pipeline/SpirvShader.cpp
index 4789108..b828f31 100644
--- a/src/Pipeline/SpirvShader.cpp
+++ b/src/Pipeline/SpirvShader.cpp
@@ -233,7 +233,9 @@
case spv::OpBitwiseAnd:
case spv::OpLogicalOr:
case spv::OpLogicalAnd:
- // Instructions that yield an ssavalue.
+ case spv::OpUMulExtended:
+ case spv::OpSMulExtended:
+ // Instructions that yield an intermediate value
{
TypeID typeId = insn.word(1);
ObjectID resultId = insn.word(2);
@@ -907,6 +909,8 @@
case spv::OpBitwiseAnd:
case spv::OpLogicalOr:
case spv::OpLogicalAnd:
+ case spv::OpUMulExtended:
+ case spv::OpSMulExtended:
EmitBinaryOp(insn, routine);
break;
@@ -1200,10 +1204,11 @@
{
auto &type = getType(insn.word(1));
auto &dst = routine->createIntermediate(insn.word(2), type.sizeInComponents);
+ auto &lhsType = getType(getObject(insn.word(3)).type);
auto srcLHS = GenericValue(this, routine, insn.word(3));
auto srcRHS = GenericValue(this, routine, insn.word(4));
- for (auto i = 0u; i < type.sizeInComponents; i++)
+ for (auto i = 0u; i < lhsType.sizeInComponents; i++)
{
auto lhs = srcLHS[i];
auto rhs = srcRHS[i];
@@ -1257,6 +1262,17 @@
case spv::OpLogicalAnd:
dst.emplace(i, As<SIMD::Float>(As<SIMD::UInt>(lhs) & As<SIMD::UInt>(rhs)));
break;
+ case spv::OpSMulExtended:
+ // Extended ops: result is a structure containing two members of the same type as lhs & rhs.
+ // In our flat view then, component i is the i'th component of the first member;
+ // component i + N is the i'th component of the second member.
+ dst.emplace(i, As<SIMD::Float>(As<SIMD::Int>(lhs) * As<SIMD::Int>(rhs)));
+ dst.emplace(i + lhsType.sizeInComponents, As<SIMD::Float>(MulHigh(As<SIMD::Int>(lhs), As<SIMD::Int>(rhs))));
+ break;
+ case spv::OpUMulExtended:
+ dst.emplace(i, As<SIMD::Float>(As<SIMD::UInt>(lhs) * As<SIMD::UInt>(rhs)));
+ dst.emplace(i + lhsType.sizeInComponents, As<SIMD::Float>(MulHigh(As<SIMD::UInt>(lhs), As<SIMD::UInt>(rhs))));
+ break;
default:
UNIMPLEMENTED("Unhandled binary operator %s", OpcodeName(insn.opcode()).c_str());
}
diff --git a/src/Reactor/LLVMReactor.cpp b/src/Reactor/LLVMReactor.cpp
index e589d17..6831082 100644
--- a/src/Reactor/LLVMReactor.cpp
+++ b/src/Reactor/LLVMReactor.cpp
@@ -352,30 +352,6 @@
return ::builder->CreateAdd(lhs, rhs);
}
- llvm::Value *lowerMulHigh(llvm::Value *x, llvm::Value *y, bool sext)
- {
- llvm::VectorType *ty = llvm::cast<llvm::VectorType>(x->getType());
- llvm::VectorType *extTy = llvm::VectorType::getExtendedElementVectorType(ty);
-
- llvm::Value *extX, *extY;
- if (sext)
- {
- extX = ::builder->CreateSExt(x, extTy);
- extY = ::builder->CreateSExt(y, extTy);
- }
- else
- {
- extX = ::builder->CreateZExt(x, extTy);
- extY = ::builder->CreateZExt(y, extTy);
- }
-
- llvm::Value *mult = ::builder->CreateMul(extX, extY);
-
- llvm::IntegerType *intTy = llvm::cast<llvm::IntegerType>(ty->getElementType());
- llvm::Value *mulh = ::builder->CreateAShr(mult, intTy->getIntegerBitWidth());
- return ::builder->CreateTrunc(mulh, ty);
- }
-
llvm::Value *lowerPack(llvm::Value *x, llvm::Value *y, bool isSigned)
{
llvm::VectorType *srcTy = llvm::cast<llvm::VectorType>(x->getType());
@@ -447,6 +423,30 @@
}
#endif // !defined(__i386__) && !defined(__x86_64__)
#endif // REACTOR_LLVM_VERSION >= 7
+
+ llvm::Value *lowerMulHigh(llvm::Value *x, llvm::Value *y, bool sext)
+ {
+ llvm::VectorType *ty = llvm::cast<llvm::VectorType>(x->getType());
+ llvm::VectorType *extTy = llvm::VectorType::getExtendedElementVectorType(ty);
+
+ llvm::Value *extX, *extY;
+ if (sext)
+ {
+ extX = ::builder->CreateSExt(x, extTy);
+ extY = ::builder->CreateSExt(y, extTy);
+ }
+ else
+ {
+ extX = ::builder->CreateZExt(x, extTy);
+ extY = ::builder->CreateZExt(y, extTy);
+ }
+
+ llvm::Value *mult = ::builder->CreateMul(extX, extY);
+
+ llvm::IntegerType *intTy = llvm::cast<llvm::IntegerType>(ty->getElementType());
+ llvm::Value *mulh = ::builder->CreateAShr(mult, intTy->getBitWidth());
+ return ::builder->CreateTrunc(mulh, ty);
+ }
}
namespace rr
@@ -5715,6 +5715,18 @@
#endif
}
+ RValue<Int4> MulHigh(RValue<Int4> x, RValue<Int4> y)
+ {
+ // TODO: For x86, build an intrinsics version of this which uses shuffles + pmuludq.
+ return As<Int4>(V(lowerMulHigh(V(x.value), V(y.value), true)));
+ }
+
+ RValue<UInt4> MulHigh(RValue<UInt4> x, RValue<UInt4> y)
+ {
+ // TODO: For x86, build an intrinsics version of this which uses shuffles + pmuludq.
+ return As<UInt4>(V(lowerMulHigh(V(x.value), V(y.value), false)));
+ }
+
RValue<Short8> PackSigned(RValue<Int4> x, RValue<Int4> y)
{
#if defined(__i386__) || defined(__x86_64__)
diff --git a/src/Reactor/Reactor.hpp b/src/Reactor/Reactor.hpp
index 357d0ce..dd64731 100644
--- a/src/Reactor/Reactor.hpp
+++ b/src/Reactor/Reactor.hpp
@@ -1856,6 +1856,7 @@
RValue<Int4> Insert(RValue<Int4> val, RValue<Int> element, int i);
RValue<Int> SignMask(RValue<Int4> x);
RValue<Int4> Swizzle(RValue<Int4> x, unsigned char select);
+ RValue<Int4> MulHigh(RValue<Int4> x, RValue<Int4> y);
class UInt4 : public LValue<UInt4>, public XYZW<UInt4>
{
@@ -1930,6 +1931,7 @@
RValue<UInt4> CmpNLE(RValue<UInt4> x, RValue<UInt4> y);
RValue<UInt4> Max(RValue<UInt4> x, RValue<UInt4> y);
RValue<UInt4> Min(RValue<UInt4> x, RValue<UInt4> y);
+ RValue<UInt4> MulHigh(RValue<UInt4> x, RValue<UInt4> y);
// RValue<UInt4> RoundInt(RValue<Float4> cast);
class Half : public LValue<Half>