Implement FMA() which always returns a fused multiply-add result This Reactor intrinsic falls back to calling std::fma() when the backend or the CPU does not support FMA instructions. The rr::Caps::fmaIsFast() function can be used to check whether FMA instructions are highly likely to be emitted instead. This intrinsic must be used for algorithms that demand the precision of FMA operations. rr::MulAdd() cannot be relied upon to produce FMA results, even if fmaIsFast() is true. Bug: b/214591655 Change-Id: Ide069606cd9fe7cc40fa42e68719b24a0046405d Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/62109 Kokoro-Result: kokoro <noreply+kokoro@google.com> Tested-by: Nicolas Capens <nicolascapens@google.com> Reviewed-by: Alexis Hétu <sugoi@google.com>
diff --git a/src/Reactor/EmulatedIntrinsics.cpp b/src/Reactor/EmulatedIntrinsics.cpp index e748455..d259cea 100644 --- a/src/Reactor/EmulatedIntrinsics.cpp +++ b/src/Reactor/EmulatedIntrinsics.cpp
@@ -56,6 +56,18 @@ return result; } +// Call three arg function on a vector type +template<typename Func, typename T> +RValue<T> call4(Func func, const RValue<T> &x, const RValue<T> &y, const RValue<T> &z) +{ + T result; + result = Insert(result, Call(func, Extract(x, 0), Extract(y, 0), Extract(z, 0)), 0); + result = Insert(result, Call(func, Extract(x, 1), Extract(y, 1), Extract(z, 1)), 1); + result = Insert(result, Call(func, Extract(x, 2), Extract(y, 2), Extract(z, 2)), 2); + result = Insert(result, Call(func, Extract(x, 3), Extract(y, 3), Extract(z, 3)), 3); + return result; +} + template<typename T, typename EL = UnderlyingTypeT<T>> void gather(T &out, RValue<Pointer<EL>> base, RValue<Int4> offsets, RValue<Int4> mask, unsigned int alignment, bool zeroMaskedLanes) { @@ -278,5 +290,10 @@ return call4(fmodf, lhs, rhs); } +RValue<Float4> FMA(RValue<Float4> x, RValue<Float4> y, RValue<Float4> z) +{ + return call4(fmaf, x, y, z); +} + } // namespace emulated } // namespace rr
diff --git a/src/Reactor/EmulatedIntrinsics.hpp b/src/Reactor/EmulatedIntrinsics.hpp index 89aa894..4ca6224 100644 --- a/src/Reactor/EmulatedIntrinsics.hpp +++ b/src/Reactor/EmulatedIntrinsics.hpp
@@ -53,6 +53,7 @@ RValue<Int> MaxAtomic(RValue<Pointer<Int>> x, RValue<Int> y, std::memory_order memoryOrder); RValue<UInt> MaxAtomic(RValue<Pointer<UInt>> x, RValue<UInt> y, std::memory_order memoryOrder); RValue<Float4> FRem(RValue<Float4> lhs, RValue<Float4> rhs); +RValue<Float4> FMA(RValue<Float4> x, RValue<Float4> y, RValue<Float4> z); } // namespace emulated } // namespace rr
diff --git a/src/Reactor/LLVMJIT.cpp b/src/Reactor/LLVMJIT.cpp index 11ca7e4..d762518 100644 --- a/src/Reactor/LLVMJIT.cpp +++ b/src/Reactor/LLVMJIT.cpp
@@ -486,6 +486,7 @@ functions.try_emplace("logf", reinterpret_cast<void *>(logf)); functions.try_emplace("exp2f", reinterpret_cast<void *>(exp2f)); functions.try_emplace("log2f", reinterpret_cast<void *>(log2f)); + functions.try_emplace("fmaf", reinterpret_cast<void *>(fmaf)); functions.try_emplace("fmod", reinterpret_cast<void *>(static_cast<double (*)(double, double)>(fmod))); functions.try_emplace("sin", reinterpret_cast<void *>(static_cast<double (*)(double)>(sin)));
diff --git a/src/Reactor/LLVMReactor.cpp b/src/Reactor/LLVMReactor.cpp index 685fbad..30edadf 100644 --- a/src/Reactor/LLVMReactor.cpp +++ b/src/Reactor/LLVMReactor.cpp
@@ -3174,6 +3174,12 @@ return RValue<Float4>(V(jit->builder->CreateCall(func, { V(x.value()), V(y.value()), V(z.value()) }))); } +RValue<Float4> FMA(RValue<Float4> x, RValue<Float4> y, RValue<Float4> z) +{ + auto func = llvm::Intrinsic::getDeclaration(jit->module.get(), llvm::Intrinsic::fma, { T(Float4::type()) }); + return RValue<Float4>(V(jit->builder->CreateCall(func, { V(x.value()), V(y.value()), V(z.value()) }))); +} + RValue<Float4> Abs(RValue<Float4> x) { auto func = llvm::Intrinsic::getDeclaration(jit->module.get(), llvm::Intrinsic::fabs, { V(x.value())->getType() });
diff --git a/src/Reactor/Reactor.hpp b/src/Reactor/Reactor.hpp index f198a40..d4ca19a 100644 --- a/src/Reactor/Reactor.hpp +++ b/src/Reactor/Reactor.hpp
@@ -2344,6 +2344,8 @@ // Computes `x * y + z`, which may be fused into one operation to produce a higher-precision result. RValue<Float4> MulAdd(RValue<Float4> x, RValue<Float4> y, RValue<Float4> z); +// Computes a fused `x * y + z` operation. Caps::fmaIsFast indicates whether it emits an FMA instruction. +RValue<Float4> FMA(RValue<Float4> x, RValue<Float4> y, RValue<Float4> z); RValue<Float4> Abs(RValue<Float4> x); RValue<Float4> Max(RValue<Float4> x, RValue<Float4> y);
diff --git a/src/Reactor/SubzeroReactor.cpp b/src/Reactor/SubzeroReactor.cpp index 4fa6bb2..07c9ed8 100644 --- a/src/Reactor/SubzeroReactor.cpp +++ b/src/Reactor/SubzeroReactor.cpp
@@ -3947,6 +3947,12 @@ return x * y + z; } +RValue<Float4> FMA(RValue<Float4> x, RValue<Float4> y, RValue<Float4> z) +{ + // TODO(b/214591655): Use FMA instructions when available. + return emulated::FMA(x, y, z); +} + RValue<Float4> Abs(RValue<Float4> x) { // TODO: Optimize.
diff --git a/tests/ReactorUnitTests/ReactorUnitTests.cpp b/tests/ReactorUnitTests/ReactorUnitTests.cpp index 543a9d1..19056db 100644 --- a/tests/ReactorUnitTests/ReactorUnitTests.cpp +++ b/tests/ReactorUnitTests/ReactorUnitTests.cpp
@@ -1231,6 +1231,39 @@ } } +TEST(ReactorUnitTests, FMA) +{ + Function<Void(Pointer<Float4>, Pointer<Float4>, Pointer<Float4>, Pointer<Float4>)> function; + { + Pointer<Float4> r = function.Arg<0>(); + Pointer<Float4> x = function.Arg<1>(); + Pointer<Float4> y = function.Arg<2>(); + Pointer<Float4> z = function.Arg<3>(); + + *r = FMA(*x, *y, *z); + } + + auto routine = function(testName().c_str()); + auto callable = (void (*)(float4 *, float4 *, float4 *, float4 *))routine->getEntry(); + + float x[] = { 0.0f, 2.0f, 4.0f, 1.00000011920929f }; + float y[] = { 0.0f, 3.0f, 0.0f, 53400708.0f }; + float z[] = { 0.0f, 0.0f, 7.0f, -53400708.0f }; + + for(size_t i = 0; i < std::size(x); i++) + { + float4 x_in = { x[i], x[i], x[i], x[i] }; + float4 y_in = { y[i], y[i], y[i], y[i] }; + float4 z_in = { z[i], z[i], z[i], z[i] }; + float4 r_out; + + callable(&r_out, &x_in, &y_in, &z_in); + + float expected = fmaf(x[i], y[i], z[i]); + EXPECT_FLOAT_EQ(r_out[0], expected); + } +} + TEST(ReactorUnitTests, FAbs) { Function<Void(Pointer<Float4>, Pointer<Float4>)> function;