Implement FMA() which always returns a fused multiply-add result

This Reactor intrinsic falls back to calling std::fma() when the backend
or the CPU does not support FMA instructions. The rr::Caps::fmaIsFast()
function can be used to check whether FMA instructions are highly likely
to be emitted instead of that fallback.

This intrinsic must be used for algorithms that demand the precision of
FMA operations. rr::MulAdd() cannot be relied upon to produce FMA
results, even if fmaIsFast() is true.

Bug: b/214591655
Change-Id: Ide069606cd9fe7cc40fa42e68719b24a0046405d
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/62109
Kokoro-Result: kokoro <noreply+kokoro@google.com>
Tested-by: Nicolas Capens <nicolascapens@google.com>
Reviewed-by: Alexis Hétu <sugoi@google.com>
diff --git a/src/Reactor/EmulatedIntrinsics.cpp b/src/Reactor/EmulatedIntrinsics.cpp
index e748455..d259cea 100644
--- a/src/Reactor/EmulatedIntrinsics.cpp
+++ b/src/Reactor/EmulatedIntrinsics.cpp
@@ -56,6 +56,18 @@
 	return result;
 }
 
+// Applies a three-argument scalar function to each of the four lanes of a vector type
+template<typename Func, typename T>
+RValue<T> call4(Func func, const RValue<T> &x, const RValue<T> &y, const RValue<T> &z)
+{
+	T lanes;
+	for(int i = 0; i < 4; i++)  // Runs while the routine is being built: emits one Call per lane, in lane order.
+	{
+		lanes = Insert(lanes, Call(func, Extract(x, i), Extract(y, i), Extract(z, i)), i);
+	}
+	return lanes;
+}
+
 template<typename T, typename EL = UnderlyingTypeT<T>>
 void gather(T &out, RValue<Pointer<EL>> base, RValue<Int4> offsets, RValue<Int4> mask, unsigned int alignment, bool zeroMaskedLanes)
 {
@@ -278,5 +290,10 @@
 	return call4(fmodf, lhs, rhs);
 }
 
+RValue<Float4> FMA(RValue<Float4> x, RValue<Float4> y, RValue<Float4> z)
+{
+	return call4(fmaf, x, y, z);  // Calls the C library's fmaf() once for each of the four lanes.
+}
+
 }  // namespace emulated
 }  // namespace rr
diff --git a/src/Reactor/EmulatedIntrinsics.hpp b/src/Reactor/EmulatedIntrinsics.hpp
index 89aa894..4ca6224 100644
--- a/src/Reactor/EmulatedIntrinsics.hpp
+++ b/src/Reactor/EmulatedIntrinsics.hpp
@@ -53,6 +53,7 @@
 RValue<Int> MaxAtomic(RValue<Pointer<Int>> x, RValue<Int> y, std::memory_order memoryOrder);
 RValue<UInt> MaxAtomic(RValue<Pointer<UInt>> x, RValue<UInt> y, std::memory_order memoryOrder);
 RValue<Float4> FRem(RValue<Float4> lhs, RValue<Float4> rhs);
+RValue<Float4> FMA(RValue<Float4> x, RValue<Float4> y, RValue<Float4> z);  // Per-lane fmaf() fallback.
 
 }  // namespace emulated
 }  // namespace rr
diff --git a/src/Reactor/LLVMJIT.cpp b/src/Reactor/LLVMJIT.cpp
index 11ca7e4..d762518 100644
--- a/src/Reactor/LLVMJIT.cpp
+++ b/src/Reactor/LLVMJIT.cpp
@@ -486,6 +486,7 @@
 			functions.try_emplace("logf", reinterpret_cast<void *>(logf));
 			functions.try_emplace("exp2f", reinterpret_cast<void *>(exp2f));
 			functions.try_emplace("log2f", reinterpret_cast<void *>(log2f));
+			functions.try_emplace("fmaf", reinterpret_cast<void *>(fmaf));
 
 			functions.try_emplace("fmod", reinterpret_cast<void *>(static_cast<double (*)(double, double)>(fmod)));
 			functions.try_emplace("sin", reinterpret_cast<void *>(static_cast<double (*)(double)>(sin)));
diff --git a/src/Reactor/LLVMReactor.cpp b/src/Reactor/LLVMReactor.cpp
index 685fbad..30edadf 100644
--- a/src/Reactor/LLVMReactor.cpp
+++ b/src/Reactor/LLVMReactor.cpp
@@ -3174,6 +3174,12 @@
 	return RValue<Float4>(V(jit->builder->CreateCall(func, { V(x.value()), V(y.value()), V(z.value()) })));
 }
 
+RValue<Float4> FMA(RValue<Float4> x, RValue<Float4> y, RValue<Float4> z)
+{
+	llvm::Function *fmaIntrinsic = llvm::Intrinsic::getDeclaration(jit->module.get(), llvm::Intrinsic::fma, { T(Float4::type()) });  // llvm.fma declared for the Float4 type
+	return RValue<Float4>(V(jit->builder->CreateCall(fmaIntrinsic, { V(x.value()), V(y.value()), V(z.value()) })));
+}
+
 RValue<Float4> Abs(RValue<Float4> x)
 {
 	auto func = llvm::Intrinsic::getDeclaration(jit->module.get(), llvm::Intrinsic::fabs, { V(x.value())->getType() });
diff --git a/src/Reactor/Reactor.hpp b/src/Reactor/Reactor.hpp
index f198a40..d4ca19a 100644
--- a/src/Reactor/Reactor.hpp
+++ b/src/Reactor/Reactor.hpp
@@ -2344,6 +2344,8 @@
 
 // Computes `x * y + z`, which may be fused into one operation to produce a higher-precision result.
 RValue<Float4> MulAdd(RValue<Float4> x, RValue<Float4> y, RValue<Float4> z);
+// Always computes a fused `x * y + z`, emulating it when needed. Caps::fmaIsFast indicates whether an FMA instruction is likely to be emitted.
+RValue<Float4> FMA(RValue<Float4> x, RValue<Float4> y, RValue<Float4> z);
 
 RValue<Float4> Abs(RValue<Float4> x);
 RValue<Float4> Max(RValue<Float4> x, RValue<Float4> y);
diff --git a/src/Reactor/SubzeroReactor.cpp b/src/Reactor/SubzeroReactor.cpp
index 4fa6bb2..07c9ed8 100644
--- a/src/Reactor/SubzeroReactor.cpp
+++ b/src/Reactor/SubzeroReactor.cpp
@@ -3947,6 +3947,12 @@
 	return x * y + z;
 }
 
+RValue<Float4> FMA(RValue<Float4> x, RValue<Float4> y, RValue<Float4> z)
+{
+	// TODO(b/214591655): Use FMA instructions when available. For now this always takes the emulated path.
+	return emulated::FMA(x, y, z);
+}
+
 RValue<Float4> Abs(RValue<Float4> x)
 {
 	// TODO: Optimize.
diff --git a/tests/ReactorUnitTests/ReactorUnitTests.cpp b/tests/ReactorUnitTests/ReactorUnitTests.cpp
index 543a9d1..19056db 100644
--- a/tests/ReactorUnitTests/ReactorUnitTests.cpp
+++ b/tests/ReactorUnitTests/ReactorUnitTests.cpp
@@ -1231,6 +1231,39 @@
 	}
 }
 
+TEST(ReactorUnitTests, FMA)
+{
+	Function<Void(Pointer<Float4>, Pointer<Float4>, Pointer<Float4>, Pointer<Float4>)> function;  // (result, x, y, z)
+	{
+		Pointer<Float4> r = function.Arg<0>();
+		Pointer<Float4> x = function.Arg<1>();
+		Pointer<Float4> y = function.Arg<2>();
+		Pointer<Float4> z = function.Arg<3>();
+		*r = FMA(*x, *y, *z);
+	}
+
+	auto routine = function(testName().c_str());
+	auto callable = (void (*)(float4 *, float4 *, float4 *, float4 *))routine->getEntry();
+
+	float x[] = { 0.0f, 2.0f, 4.0f, 1.00000011920929f };
+	float y[] = { 0.0f, 3.0f, 0.0f, 53400708.0f };
+	float z[] = { 0.0f, 0.0f, 7.0f, -53400708.0f };  // Last case: a separately rounded multiply and add yields a different result.
+
+	for(size_t i = 0; i < std::size(x); i++)
+	{
+		float4 x_in = { x[i], x[i], x[i], x[i] };
+		float4 y_in = { y[i], y[i], y[i], y[i] };
+		float4 z_in = { z[i], z[i], z[i], z[i] };
+		float4 r_out;
+		callable(&r_out, &x_in, &y_in, &z_in);
+		float expected = fmaf(x[i], y[i], z[i]);  // Host fmaf() serves as the reference result.
+		for(size_t lane = 0; lane < 4; lane++)  // Every lane got the same inputs, so every lane must match.
+		{
+			EXPECT_FLOAT_EQ(r_out[lane], expected);
+		}
+	}
+}
+
 TEST(ReactorUnitTests, FAbs)
 {
 	Function<Void(Pointer<Float4>, Pointer<Float4>)> function;