LLVM DSL: implement absd and match helpers

Matcheable expression absd(a, b) (absolute difference).
This commit is contained in:
Nekotekina 2021-09-05 20:33:19 +03:00
parent 2268aa9093
commit 67b3fc70f8
2 changed files with 61 additions and 1 deletions

View File

@ -2974,6 +2974,13 @@ public:
return {};
}
template <typename T, typename = llvm_common_t<T>>
static auto match_expr(llvm::Value* v, llvm::Module* _m, T&& expr)
{
auto r = expr.match(v, _m);
return std::tuple_cat(std::make_tuple(v != nullptr), r);
}
template <typename T, typename U, typename = llvm_common_t<T, U>>
auto match_expr(T&& arg, U&& expr) -> decltype(std::tuple_cat(std::make_tuple(false), expr.match(std::declval<llvm::Value*&>(), nullptr)))
{
@ -2989,6 +2996,26 @@ public:
return (pred(llvm_placeholder_t<Types>{}) || ...);
}
template <typename T, typename F>
struct expr_t
{
using type = llvm_common_t<T>;
T a;
F match;
llvm::Value* eval(llvm::IRBuilder<>* ir) const
{
return a.eval(ir);
}
};
template <typename T, typename F>
static auto expr(T&& expr, F matcher)
{
return expr_t<T, F>{std::forward<T>(expr), std::move(matcher)};
}
template <typename T, typename = std::enable_if_t<is_llvm_cmp<std::decay_t<T>>::value>>
static auto fcmp_ord(T&& cmp_expr)
{
@ -3190,6 +3217,39 @@ public:
return llvm_fmuladd<T, U, V>{std::forward<T>(a), std::forward<U>(b), std::forward<V>(c), m_use_fma};
}
// Absolute difference
template <typename T, typename U, typename CT = llvm_common_t<T, U>>
static auto absd(T&& a, U&& b)
{
return expr(max(a, b) - min(a, b), [](llvm::Value*& value, llvm::Module* _m) -> llvm_match_tuple<T, U>
{
static const auto M = match<CT>();
if (auto [ok, _max, _min] = match_expr(value, _m, M - M); ok)
{
if (auto [ok1, a, b] = match_expr(_max.value, _m, max(M, M)); ok1 && !a.eq(b))
{
if (auto [ok2, c, d] = match_expr(_min.value, _m, min(M, M)); ok2 && !c.eq(d))
{
if ((a.eq(c) && b.eq(d)) || (a.eq(d) && b.eq(c)))
{
if (auto r1 = llvm_expr_t<T>{}.match(a.value, _m); a.eq())
{
if (auto r2 = llvm_expr_t<U>{}.match(b.value, _m); b.eq())
{
return std::tuple_cat(r1, r2);
}
}
}
}
}
}
value = nullptr;
return {};
});
}
template <typename... Types>
llvm::Function* get_intrinsic(llvm::Intrinsic::ID id)
{

View File

@ -6417,7 +6417,7 @@ public:
void ABSDB(spu_opcode_t op)
{
const auto [a, b] = get_vrs<u8[16]>(op.ra, op.rb);
set_vr(op.rt, max(a, b) - min(a, b));
set_vr(op.rt, absd(a, b));
}
template <typename T>