LLVM DSL: reimplement avg

This commit is contained in:
Nekotekina 2021-08-31 19:35:45 +03:00
parent 95c36221fa
commit 38dfc88e8d

View File

@ -2512,6 +2512,85 @@ struct llvm_ctpop
}
};
template <typename A1, typename A2, typename T = llvm_common_t<A1, A2>>
struct llvm_avg
{
using type = T;
llvm_expr_t<A1> a1;
llvm_expr_t<A2> a2;
static_assert(llvm_value_t<T>::is_sint || llvm_value_t<T>::is_uint, "llvm_avg<>: invalid type");
static constexpr bool is_ok = llvm_value_t<T>::is_sint || llvm_value_t<T>::is_uint;
static constexpr auto cast_op = llvm_value_t<T>::is_sint ? llvm::Instruction::SExt : llvm::Instruction::ZExt;
static llvm::Type* cast_dst_type(llvm::LLVMContext& context)
{
llvm::Type* cast_to = llvm::IntegerType::get(context, llvm_value_t<T>::esize * 2);
if constexpr (llvm_value_t<T>::is_vector != 0)
cast_to = llvm::VectorType::get(cast_to, llvm_value_t<T>::is_vector, false);
return cast_to;
}
llvm::Value* eval(llvm::IRBuilder<>* ir) const
{
const auto v1 = a1.eval(ir);
const auto v2 = a2.eval(ir);
const auto dty = cast_dst_type(ir->getContext());
const auto axt = ir->CreateCast(cast_op, v1, dty);
const auto bxt = ir->CreateCast(cast_op, v2, dty);
const auto cxt = llvm::ConstantInt::get(dty, 1, false);
const auto abc = ir->CreateAdd(ir->CreateAdd(axt, bxt), cxt);
return ir->CreateTrunc(ir->CreateLShr(abc, 1), llvm_value_t<T>::get_type(ir->getContext()));
}
llvm_match_tuple<A1, A2> match(llvm::Value*& value) const
{
llvm::Value* v1 = {};
llvm::Value* v2 = {};
const auto dty = cast_dst_type(value->getContext());
if (auto i = llvm::dyn_cast_or_null<llvm::CastInst>(value); i && i->getOpcode() == llvm::Instruction::Trunc && i->getSrcTy() == dty)
{
const auto cxt = llvm::ConstantInt::get(dty, 1, false);
if (auto j = llvm::dyn_cast_or_null<llvm::BinaryOperator>(i->getOperand(0)); j && j->getOpcode() == llvm::Instruction::LShr && j->getOperand(1) == cxt)
{
if (j = llvm::dyn_cast_or_null<llvm::BinaryOperator>(j->getOperand(0)); j && j->getOpcode() == llvm::Instruction::Add && j->getOperand(1) == cxt)
{
if (j = llvm::dyn_cast_or_null<llvm::BinaryOperator>(j->getOperand(0)); j && j->getOpcode() == llvm::Instruction::Add)
{
auto a = llvm::dyn_cast_or_null<llvm::CastInst>(j->getOperand(0));
auto b = llvm::dyn_cast_or_null<llvm::CastInst>(j->getOperand(1));
if (a && b && a->getOpcode() == cast_op && b->getOpcode() == cast_op)
{
v1 = a->getOperand(0);
v2 = b->getOperand(0);
if (auto r1 = a1.match(v1); v1)
{
if (auto r2 = a2.match(v2); v2)
{
return std::tuple_cat(r1, r2);
}
}
}
}
}
}
}
value = nullptr;
return {};
}
};
class cpu_translator
{
protected:
@ -2805,24 +2884,10 @@ public:
}
// Average: (a + b + 1) >> 1
template <typename T>
inline auto avg(T a, T b)
template <typename T, typename U, typename = std::enable_if_t<llvm_avg<T, U>::is_ok>>
static auto avg(T&& a, U&& b)
{
//return (a >> 1) + (b >> 1) + ((a | b) & 1);
value_t<typename T::type> result;
static_assert(result.is_sint || result.is_uint);
const auto cast_op = result.is_sint ? llvm::Instruction::SExt : llvm::Instruction::ZExt;
llvm::Type* cast_to = m_ir->getIntNTy(result.esize * 2);
if constexpr (result.is_vector != 0)
cast_to = llvm::VectorType::get(cast_to, result.is_vector, false);
const auto axt = m_ir->CreateCast(cast_op, a.eval(m_ir), cast_to);
const auto bxt = m_ir->CreateCast(cast_op, b.eval(m_ir), cast_to);
const auto cxt = llvm::ConstantInt::get(cast_to, 1, false);
const auto abc = m_ir->CreateAdd(m_ir->CreateAdd(axt, bxt), cxt);
result.value = m_ir->CreateTrunc(m_ir->CreateLShr(abc, 1), result.get_type(m_context));
return result;
return llvm_avg<T, U>{std::forward<T>(a), std::forward<U>(b)};
}
template <typename... Types>