LLVM DSL: reimplement fmuladd, force hw fma if present

This commit is contained in:
Nekotekina 2021-09-01 20:43:57 +03:00
parent 2acb6ed60d
commit 1685769bd9
2 changed files with 75 additions and 32 deletions

View File

@ -2679,6 +2679,63 @@ struct llvm_fabs
}
};
template <typename A1, typename A2, typename A3, typename T = llvm_common_t<A1, A2, A3>>
struct llvm_fmuladd
{
using type = T;
llvm_expr_t<A1> a1;
llvm_expr_t<A2> a2;
llvm_expr_t<A3> a3;
bool strict_fma;
static_assert(llvm_value_t<T>::is_float, "llvm_fmuladd<>: invalid type");
static constexpr bool is_ok = llvm_value_t<T>::is_float;
llvm::Value* eval(llvm::IRBuilder<>* ir) const
{
llvm::Value* v1 = a1.eval(ir);
llvm::Value* v2 = a2.eval(ir);
llvm::Value* v3 = a3.eval(ir);
if (llvm::isa<llvm::Constant>(v1) && llvm::isa<llvm::Constant>(v2) && llvm::isa<llvm::Constant>(v3))
{
return llvm::ConstantFoldInstruction(ir->CreateIntrinsic(llvm::Intrinsic::fma, {v1->getType()}, {v1, v2, v3}), llvm::DataLayout(""));
}
return ir->CreateIntrinsic(strict_fma ? llvm::Intrinsic::fma : llvm::Intrinsic::fmuladd, {v1->getType()}, {v1, v2, v3});
}
llvm_match_tuple<A1, A2, A3> match(llvm::Value*& value) const
{
llvm::Value* v1 = {};
llvm::Value* v2 = {};
llvm::Value* v3 = {};
if (auto i = llvm::dyn_cast_or_null<llvm::CallInst>(value); i && i->getIntrinsicID() == (strict_fma ? llvm::Intrinsic::fma : llvm::Intrinsic::fmuladd))
{
v1 = i->getOperand(0);
v2 = i->getOperand(1);
v3 = i->getOperand(2);
if (auto r1 = a1.match(v1); v1)
{
if (auto r2 = a2.match(v2); v2)
{
if (auto r3 = a3.match(v3); v3)
{
return std::tuple_cat(r1, r2, r3);
}
}
}
}
value = nullptr;
return {};
}
};
class cpu_translator
{
protected:
@ -2990,6 +3047,20 @@ public:
return llvm_fabs<T>{std::forward<T>(a)};
}
// Optionally opportunistic hardware FMA, can be used if results are identical for all possible input values
template <typename T, typename U, typename V, typename = std::enable_if_t<llvm_fmuladd<T, U, V>::is_ok>>
static auto fmuladd(T&& a, U&& b, V&& c, bool strict_fma)
{
return llvm_fmuladd<T, U, V>{std::forward<T>(a), std::forward<U>(b), std::forward<V>(c), strict_fma};
}
// Opportunistic hardware FMA, can be used if results are identical for all possible input values
template <typename T, typename U, typename V, typename = std::enable_if_t<llvm_fmuladd<T, U, V>::is_ok>>
auto fmuladd(T&& a, U&& b, V&& c)
{
return llvm_fmuladd<T, U, V>{std::forward<T>(a), std::forward<U>(b), std::forward<V>(c), m_use_fma};
}
template <typename... Types>
llvm::Function* get_intrinsic(llvm::Intrinsic::ID id)
{
@ -2997,18 +3068,6 @@ public:
return llvm::Intrinsic::getDeclaration(_module, id, {get_type<Types>()...});
}
// Opportunistic hardware FMA, can be used if results are identical for all possible input values
template <typename T>
auto fmuladd(T a, T b, T c)
{
value_t<typename T::type> result;
const auto av = a.eval(m_ir);
const auto bv = b.eval(m_ir);
const auto cv = c.eval(m_ir);
result.value = m_ir->CreateCall(get_intrinsic<typename T::type>(llvm::Intrinsic::fmuladd), {av, bv, cv});
return result;
}
// TODO: Support doubles
template <typename T, typename = std::enable_if_t<llvm_value_t<typename T::type>::esize == 32u && llvm_value_t<typename T::type>::is_float>>
auto fre(T a)

View File

@ -7751,11 +7751,7 @@ public:
const auto [a, b, c] = get_vrs<f64[2]>(op.ra, op.rb, op.rt);
if (g_cfg.core.llvm_accurate_dfma)
{
value_t<f64[2]> r;
r.value = m_ir->CreateCall(get_intrinsic<f64[2]>(llvm::Intrinsic::fma), {a.value, b.value, c.value});
set_vr(op.rt, r);
}
set_vr(op.rt, fmuladd(a, b, c, true));
else
set_vr(op.rt, a * b + c);
}
@ -7765,11 +7761,7 @@ public:
const auto [a, b, c] = get_vrs<f64[2]>(op.ra, op.rb, op.rt);
if (g_cfg.core.llvm_accurate_dfma)
{
value_t<f64[2]> r;
r.value = m_ir->CreateCall(get_intrinsic<f64[2]>(llvm::Intrinsic::fma), {a.value, b.value, eval(-c).value});
set_vr(op.rt, r);
}
set_vr(op.rt, fmuladd(a, b, -c, true));
else
set_vr(op.rt, a * b - c);
}
@ -7779,11 +7771,7 @@ public:
const auto [a, b, c] = get_vrs<f64[2]>(op.ra, op.rb, op.rt);
if (g_cfg.core.llvm_accurate_dfma)
{
value_t<f64[2]> r;
r.value = m_ir->CreateCall(get_intrinsic<f64[2]>(llvm::Intrinsic::fma), {eval(-a).value, b.value, c.value});
set_vr(op.rt, r);
}
set_vr(op.rt, fmuladd(-a, b, c, true));
else
set_vr(op.rt, c - (a * b));
}
@ -7793,11 +7781,7 @@ public:
const auto [a, b, c] = get_vrs<f64[2]>(op.ra, op.rb, op.rt);
if (g_cfg.core.llvm_accurate_dfma)
{
value_t<f64[2]> r;
r.value = m_ir->CreateCall(get_intrinsic<f64[2]>(llvm::Intrinsic::fma), {a.value, b.value, c.value});
set_vr(op.rt, -r);
}
set_vr(op.rt, -fmuladd(a, b, c, true));
else
set_vr(op.rt, -(a * b + c));
}