Add support for OpUMulExtended, OpSMulExtended

- Make the existing LLVMReactor lowering support for MulHigh on non-x86
  available on x86 as well, as we don't have a good intrinsics-based
  implementation of 4x 32-bit mul highs. At some point in the future we can
  rework this to use some shuffles and a pair of pmuludq.
- Plumb through Int4 and UInt4 variants of MulHigh
- Implement SPIRV OpUMulExtended, OpSMulExtended in terms of MulHigh

Bug: b/126873455
Change-Id: I25ba0a69691e7a6f7a5542ec4a90a44ba8f68331
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/25929
Reviewed-by: Nicolas Capens <nicolascapens@google.com>
Kokoro-Presubmit: kokoro <noreply+kokoro@google.com>
Tested-by: Chris Forbes <chrisforbes@google.com>
diff --git a/src/Pipeline/SpirvShader.cpp b/src/Pipeline/SpirvShader.cpp
index 4789108..b828f31 100644
--- a/src/Pipeline/SpirvShader.cpp
+++ b/src/Pipeline/SpirvShader.cpp
@@ -233,7 +233,9 @@
 			case spv::OpBitwiseAnd:
 			case spv::OpLogicalOr:
 			case spv::OpLogicalAnd:
-				// Instructions that yield an ssavalue.
+			case spv::OpUMulExtended:
+			case spv::OpSMulExtended:
+				// Instructions that yield an intermediate value
 			{
 				TypeID typeId = insn.word(1);
 				ObjectID resultId = insn.word(2);
@@ -907,6 +909,8 @@
 			case spv::OpBitwiseAnd:
 			case spv::OpLogicalOr:
 			case spv::OpLogicalAnd:
+			case spv::OpUMulExtended:
+			case spv::OpSMulExtended:
 				EmitBinaryOp(insn, routine);
 				break;
 
@@ -1200,10 +1204,11 @@
 	{
 		auto &type = getType(insn.word(1));
 		auto &dst = routine->createIntermediate(insn.word(2), type.sizeInComponents);
+		auto &lhsType = getType(getObject(insn.word(3)).type);
 		auto srcLHS = GenericValue(this, routine, insn.word(3));
 		auto srcRHS = GenericValue(this, routine, insn.word(4));
 
-		for (auto i = 0u; i < type.sizeInComponents; i++)
+		for (auto i = 0u; i < lhsType.sizeInComponents; i++)
 		{
 			auto lhs = srcLHS[i];
 			auto rhs = srcRHS[i];
@@ -1257,6 +1262,17 @@
 			case spv::OpLogicalAnd:
 				dst.emplace(i, As<SIMD::Float>(As<SIMD::UInt>(lhs) & As<SIMD::UInt>(rhs)));
 				break;
+			case spv::OpSMulExtended:
+				// Extended ops: result is a structure containing two members of the same type as lhs & rhs.
+				// In our flat view then, component i is the i'th component of the first member;
+				// component i + N is the i'th component of the second member.
+				dst.emplace(i, As<SIMD::Float>(As<SIMD::Int>(lhs) * As<SIMD::Int>(rhs)));
+				dst.emplace(i + lhsType.sizeInComponents, As<SIMD::Float>(MulHigh(As<SIMD::Int>(lhs), As<SIMD::Int>(rhs))));
+				break;
+			case spv::OpUMulExtended:
+				dst.emplace(i, As<SIMD::Float>(As<SIMD::UInt>(lhs) * As<SIMD::UInt>(rhs)));
+				dst.emplace(i + lhsType.sizeInComponents, As<SIMD::Float>(MulHigh(As<SIMD::UInt>(lhs), As<SIMD::UInt>(rhs))));
+				break;
 			default:
 				UNIMPLEMENTED("Unhandled binary operator %s", OpcodeName(insn.opcode()).c_str());
 			}
diff --git a/src/Reactor/LLVMReactor.cpp b/src/Reactor/LLVMReactor.cpp
index e589d17..6831082 100644
--- a/src/Reactor/LLVMReactor.cpp
+++ b/src/Reactor/LLVMReactor.cpp
@@ -352,30 +352,6 @@
 		return ::builder->CreateAdd(lhs, rhs);
 	}
 
-	llvm::Value *lowerMulHigh(llvm::Value *x, llvm::Value *y, bool sext)
-	{
-		llvm::VectorType *ty = llvm::cast<llvm::VectorType>(x->getType());
-		llvm::VectorType *extTy = llvm::VectorType::getExtendedElementVectorType(ty);
-
-		llvm::Value *extX, *extY;
-		if (sext)
-		{
-			extX = ::builder->CreateSExt(x, extTy);
-			extY = ::builder->CreateSExt(y, extTy);
-		}
-		else
-		{
-			extX = ::builder->CreateZExt(x, extTy);
-			extY = ::builder->CreateZExt(y, extTy);
-		}
-
-		llvm::Value *mult = ::builder->CreateMul(extX, extY);
-
-		llvm::IntegerType *intTy = llvm::cast<llvm::IntegerType>(ty->getElementType());
-		llvm::Value *mulh = ::builder->CreateAShr(mult, intTy->getIntegerBitWidth());
-		return ::builder->CreateTrunc(mulh, ty);
-	}
-
 	llvm::Value *lowerPack(llvm::Value *x, llvm::Value *y, bool isSigned)
 	{
 		llvm::VectorType *srcTy = llvm::cast<llvm::VectorType>(x->getType());
@@ -447,6 +423,30 @@
 	}
 #endif  // !defined(__i386__) && !defined(__x86_64__)
 #endif  // REACTOR_LLVM_VERSION >= 7
+
+	llvm::Value *lowerMulHigh(llvm::Value *x, llvm::Value *y, bool sext)  // Per-lane high half of x * y; sext selects signed (true) or unsigned (false) widening.
+	{
+		llvm::VectorType *ty = llvm::cast<llvm::VectorType>(x->getType());
+		llvm::VectorType *extTy = llvm::VectorType::getExtendedElementVectorType(ty);  // Same lane count, each lane twice as wide.
+
+		llvm::Value *extX, *extY;
+		if (sext)
+		{
+			extX = ::builder->CreateSExt(x, extTy);
+			extY = ::builder->CreateSExt(y, extTy);
+		}
+		else
+		{
+			extX = ::builder->CreateZExt(x, extTy);
+			extY = ::builder->CreateZExt(y, extTy);
+		}
+
+		llvm::Value *mult = ::builder->CreateMul(extX, extY);  // Product computed in the widened lanes.
+
+		llvm::IntegerType *intTy = llvm::cast<llvm::IntegerType>(ty->getElementType());
+		llvm::Value *mulh = ::builder->CreateAShr(mult, intTy->getBitWidth());  // Shift the high half down; AShr is fine for the unsigned path too, since the truncation below keeps only the low half.
+		return ::builder->CreateTrunc(mulh, ty);  // Narrow back to the original vector type.
+	}
 }
 
 namespace rr
@@ -5715,6 +5715,18 @@
 #endif
 	}
 
+	RValue<Int4> MulHigh(RValue<Int4> x, RValue<Int4> y)  // Signed variant: forwards to lowerMulHigh with sign extension enabled.
+	{
+		// TODO: For x86, build an intrinsics version of this which uses shuffles + pmuludq.
+		return As<Int4>(V(lowerMulHigh(V(x.value), V(y.value), true)));
+	}
+
+	RValue<UInt4> MulHigh(RValue<UInt4> x, RValue<UInt4> y)  // Unsigned variant: forwards to lowerMulHigh with zero extension (sext = false).
+	{
+		// TODO: For x86, build an intrinsics version of this which uses shuffles + pmuludq.
+		return As<UInt4>(V(lowerMulHigh(V(x.value), V(y.value), false)));
+	}
+
 	RValue<Short8> PackSigned(RValue<Int4> x, RValue<Int4> y)
 	{
 #if defined(__i386__) || defined(__x86_64__)
diff --git a/src/Reactor/Reactor.hpp b/src/Reactor/Reactor.hpp
index 357d0ce..dd64731 100644
--- a/src/Reactor/Reactor.hpp
+++ b/src/Reactor/Reactor.hpp
@@ -1856,6 +1856,7 @@
 	RValue<Int4> Insert(RValue<Int4> val, RValue<Int> element, int i);
 	RValue<Int> SignMask(RValue<Int4> x);
 	RValue<Int4> Swizzle(RValue<Int4> x, unsigned char select);
+	RValue<Int4> MulHigh(RValue<Int4> x, RValue<Int4> y);
 
 	class UInt4 : public LValue<UInt4>, public XYZW<UInt4>
 	{
@@ -1930,6 +1931,7 @@
 	RValue<UInt4> CmpNLE(RValue<UInt4> x, RValue<UInt4> y);
 	RValue<UInt4> Max(RValue<UInt4> x, RValue<UInt4> y);
 	RValue<UInt4> Min(RValue<UInt4> x, RValue<UInt4> y);
+	RValue<UInt4> MulHigh(RValue<UInt4> x, RValue<UInt4> y);
 //	RValue<UInt4> RoundInt(RValue<Float4> cast);
 
 	class Half : public LValue<Half>