From e28707055b3694eedc6eb8f17d621defd6d492b5 Mon Sep 17 00:00:00 2001
From: Nekotekina
Date: Wed, 24 Aug 2022 19:36:37 +0300
Subject: [PATCH] Implement simd_builder for x86

ASMJIT-based tool for building vectorized loops (such as the ones in
BufferUtils.cpp)
---
 Utilities/JIT.cpp                    | 419 +++++++++++++++++++++++
 Utilities/JIT.h                      | 179 ++++++++--
 rpcs3/Emu/Cell/PPUThread.cpp         |   9 +-
 rpcs3/Emu/Cell/SPUThread.cpp         |  27 +-
 rpcs3/Emu/RSX/Common/BufferUtils.cpp | 493 +++++++++------------------
 5 files changed, 740 insertions(+), 387 deletions(-)

diff --git a/Utilities/JIT.cpp b/Utilities/JIT.cpp
index 90449bde54..6794d111bd 100644
--- a/Utilities/JIT.cpp
+++ b/Utilities/JIT.cpp
@@ -7,6 +7,7 @@
 #include "mutex.h"
 #include "util/vm.hpp"
 #include "util/asm.hpp"
+#include "util/v128.hpp"
 #include
 #include

@@ -351,6 +352,424 @@ asmjit::inline_runtime::~inline_runtime()
 	utils::memory_protect(m_data, m_size, utils::protection::rx);
 }

+#if defined(ARCH_X64)
+asmjit::simd_builder::simd_builder(CodeHolder* ch) noexcept
+	: native_asm(ch)
+{
+	_init(true);
+}
+
+void asmjit::simd_builder::_init(bool full)
+{
+	if (full && utils::has_avx512_icl())
+	{
+		v0 = x86::zmm0;
+		v1 = x86::zmm1;
+		v2 = x86::zmm2;
+		v3 = x86::zmm3;
+		v4 = x86::zmm4;
+		v5 = x86::zmm5;
+		vsize = 64;
+	}
+	else if (full && utils::has_avx2())
+	{
+		v0 = x86::ymm0;
+		v1 = x86::ymm1;
+		v2 = x86::ymm2;
+		v3 = x86::ymm3;
+		v4 = x86::ymm4;
+		v5 = x86::ymm5;
+		vsize = 32;
+	}
+	else
+	{
+		v0 = x86::xmm0;
+		v1 = x86::xmm1;
+		v2 = x86::xmm2;
+		v3 = x86::xmm3;
+		v4 = x86::xmm4;
+		v5 = x86::xmm5;
+		vsize = 16;
+	}
+
+	if (full && utils::has_avx512())
+	{
+		vmask = -1;
+	}
+	else
+	{
+		vmask = 0;
+	}
+}
+
+void asmjit::simd_builder::vec_cleanup_ret()
+{
+	if (utils::has_avx() && vsize > 16)
+		this->vzeroupper();
+	this->ret();
+}
+
+void asmjit::simd_builder::vec_set_all_zeros(const Operand& v)
+{
+	x86::Xmm reg(v.id());
+	if (utils::has_avx())
+		this->vpxor(reg, reg, reg);
+	else
+		this->xorps(reg, reg);
+}
+
+void asmjit::simd_builder::vec_set_all_ones(const Operand& v)
+{
+	x86::Xmm reg(v.id());
+	if (x86::Zmm zr(v.id()); zr == v)
+		this->vpternlogd(zr, zr, zr, 0xff);
+	else if (x86::Ymm yr(v.id()); yr == v)
+		this->vpcmpeqd(yr, yr, yr);
+	else if (utils::has_avx())
+		this->vpcmpeqd(reg, reg, reg);
+	else
+		this->pcmpeqd(reg, reg);
+}
+
+void asmjit::simd_builder::vec_set_const(const Operand& v, const v128& val)
+{
+	if (!val._u)
+		return vec_set_all_zeros(v);
+	if (!~val._u)
+		return vec_set_all_ones(v);
+
+	if (uptr(&val) < 0x8000'0000)
+	{
+		// Assume the constant comes from a code or data segment (unsafe)
+		if (x86::Zmm zr(v.id()); zr == v)
+			this->vbroadcasti32x4(zr, x86::oword_ptr(uptr(&val)));
+		else if (x86::Ymm yr(v.id()); yr == v)
+			this->vbroadcasti128(yr, x86::oword_ptr(uptr(&val)));
+		else if (utils::has_avx())
+			this->vmovaps(x86::Xmm(v.id()), x86::oword_ptr(uptr(&val)));
+		else
+			this->movaps(x86::Xmm(v.id()), x86::oword_ptr(uptr(&val)));
+	}
+	else
+	{
+		// TODO
+		fmt::throw_exception("Unexpected constant location");
+	}
+}
+
+void asmjit::simd_builder::vec_clobbering_test(u32 esize, const Operand& v, const Operand& rhs)
+{
+	if (esize == 64)
+	{
+		this->emit(x86::Inst::kIdVptestmd, x86::k0, v, rhs);
+		this->ktestw(x86::k0, x86::k0);
+	}
+	else if (esize == 32)
+	{
+		this->emit(x86::Inst::kIdVptest, v, rhs);
+	}
+	else if (esize == 16 && utils::has_sse41())
+	{
+		this->emit(x86::Inst::kIdPtest, v, rhs);
+	}
+	else
+	{
+		if (v != rhs)
+			this->emit(x86::Inst::kIdPand, v, rhs);
+		if (esize == 16)
+			this->emit(x86::Inst::kIdPacksswb, v, v);
+		this->emit(x86::Inst::kIdMovq, x86::rax, v);
+		if (esize == 16 || esize == 8)
+			this->test(x86::rax, x86::rax);
+		else if (esize == 4)
+			this->test(x86::eax, x86::eax);
+		else if (esize == 2)
+			this->test(x86::ax, x86::ax);
+		else if (esize == 1)
+			this->test(x86::al, x86::al);
+		else
+			fmt::throw_exception("Unimplemented");
+	}
+}
+
+asmjit::x86::Mem asmjit::simd_builder::ptr_scale_for_vec(u32 esize, const x86::Gp& base, const x86::Gp& index)
+{
+	switch (ensure(esize))
+	{
+	case 1: return x86::ptr(base, index, 0, 0);
+	case 2: return x86::ptr(base, index, 1, 0);
+	case 4: return x86::ptr(base, index, 2, 0);
+	case 8: return x86::ptr(base, index, 3, 0);
+	default: fmt::throw_exception("Bad esize");
+	}
+}
+
+void asmjit::simd_builder::vec_load_unaligned(u32 esize, const Operand& v, const x86::Mem& src)
+{
+	ensure(std::has_single_bit(esize));
+	ensure(std::has_single_bit(vsize));
+
+	if (esize == 2)
+	{
+		ensure(vsize >= 2);
+		if (vsize == 2)
+			vec_set_all_zeros(v);
+		if (vsize == 2 && utils::has_avx())
+			this->emit(x86::Inst::kIdVpinsrw, x86::Xmm(v.id()), x86::Xmm(v.id()), src, Imm(0));
+		else if (vsize == 2)
+			this->emit(x86::Inst::kIdPinsrw, v, src, Imm(0));
+		else if (vmask && vmask < 8)
+			this->emit(x86::Inst::kIdVmovdqu16, v, src);
+		else
+			return vec_load_unaligned(vsize, v, src);
+	}
+	else if (esize == 4)
+	{
+		ensure(vsize >= 4);
+		if (vsize == 4 && utils::has_avx())
+			this->emit(x86::Inst::kIdVmovd, x86::Xmm(v.id()), src);
+		else if (vsize == 4)
+			this->emit(x86::Inst::kIdMovd, v, src);
+		else if (vmask && vmask < 8)
+			this->emit(x86::Inst::kIdVmovdqu32, v, src);
+		else
+			return vec_load_unaligned(vsize, v, src);
+	}
+	else if (esize == 8)
+	{
+		ensure(vsize >= 8);
+		if (vsize == 8 && utils::has_avx())
+			this->emit(x86::Inst::kIdVmovq, x86::Xmm(v.id()), src);
+		else if (vsize == 8)
+			this->emit(x86::Inst::kIdMovq, v, src);
+		else if (vmask && vmask < 8)
+			this->emit(x86::Inst::kIdVmovdqu64, v, src);
+		else
+			return vec_load_unaligned(vsize, v, src);
+	}
+	else if (esize >= 16)
+	{
+		ensure(vsize >= 16);
+		if (utils::has_avx())
+			this->emit(x86::Inst::kIdVmovdqu, v, src);
+		else
+			this->emit(x86::Inst::kIdMovups, v, src);
+	}
+	else
+	{
+		fmt::throw_exception("Unimplemented");
+	}
+}
+
+void asmjit::simd_builder::vec_store_unaligned(u32 esize, const Operand& v, const x86::Mem& dst)
+{
+	ensure(std::has_single_bit(esize));
+	ensure(std::has_single_bit(vsize));
+
+	if (esize == 2)
+	{
+		ensure(vsize >= 2);
+		if (vsize == 2 && utils::has_avx())
+			this->emit(x86::Inst::kIdVpextrw, dst, x86::Xmm(v.id()), Imm(0));
+		else if (vsize == 2 && utils::has_sse41())
+			this->emit(x86::Inst::kIdPextrw, dst, v, Imm(0));
+		else if (vsize == 2)
+			this->push(x86::rax), this->pextrw(x86::eax, x86::Xmm(v.id()), 0), this->mov(dst, x86::ax), this->pop(x86::rax);
+		else if ((vmask && vmask < 8) || vsize >= 64)
+			this->emit(x86::Inst::kIdVmovdqu16, dst, v);
+		else
+			return vec_store_unaligned(vsize, v, dst);
+	}
+	else if (esize == 4)
+	{
+		ensure(vsize >= 4);
+		if (vsize == 4 && utils::has_avx())
+			this->emit(x86::Inst::kIdVmovd, dst, x86::Xmm(v.id()));
+		else if (vsize == 4)
+			this->emit(x86::Inst::kIdMovd, dst, v);
+		else if ((vmask && vmask < 8) || vsize >= 64)
+			this->emit(x86::Inst::kIdVmovdqu32, dst, v);
+		else
+			return vec_store_unaligned(vsize, v, dst);
+	}
+	else if (esize == 8)
+	{
+		ensure(vsize >= 8);
+		if (vsize == 8 && utils::has_avx())
+			this->emit(x86::Inst::kIdVmovq, dst, x86::Xmm(v.id()));
+		else if (vsize == 8)
+			this->emit(x86::Inst::kIdMovq, dst, v);
+		else if ((vmask && vmask < 8) || vsize >= 64)
+			this->emit(x86::Inst::kIdVmovdqu64, dst, v);
+		else
+			return vec_store_unaligned(vsize, v, dst);
+	}
+	else if (esize >= 16)
+	{
+		ensure(vsize >= 16);
+		if ((vmask && vmask < 8) || vsize >= 64)
+			this->emit(x86::Inst::kIdVmovdqu64, dst, v); // Not really needed
+		else if (utils::has_avx())
+			this->emit(x86::Inst::kIdVmovdqu, dst, v);
+		else
+			this->emit(x86::Inst::kIdMovups, dst, v);
+	}
+	else
+	{
+		fmt::throw_exception("Unimplemented");
+	}
+}
+
+void asmjit::simd_builder::_vec_binary_op(x86::Inst::Id sse_op, x86::Inst::Id vex_op, x86::Inst::Id evex_op, const Operand& dst, const Operand& lhs, const Operand& rhs)
+{
+	if (utils::has_avx())
+	{
+		if (vex_op == x86::Inst::kIdNone || this->_extraReg.isReg())
+		{
+			this->evex().emit(evex_op, dst, lhs, rhs);
+		}
+		else
+		{
+			this->emit(vex_op, dst, lhs, rhs);
+		}
+	}
+	else if (dst == lhs)
+	{
+		this->emit(sse_op, dst, rhs);
+	}
+	else if (dst == rhs)
+	{
+		fmt::throw_exception("Unimplemented");
+	}
+	else
+	{
+		this->emit(x86::Inst::kIdMovaps, dst, lhs);
+		this->emit(sse_op, dst, rhs);
+	}
+}
+
+void asmjit::simd_builder::vec_umin(u32 esize, const Operand& dst, const Operand& lhs, const Operand& rhs)
+{
+	using enum x86::Inst::Id;
+	if (esize == 2)
+	{
+		if (utils::has_sse41())
+			return _vec_binary_op(kIdPminuw, kIdVpminuw, kIdVpminuw, dst, lhs, rhs);
+	}
+	else if (esize == 4)
+	{
+		if (utils::has_sse41())
+			return _vec_binary_op(kIdPminud, kIdVpminud, kIdVpminud, dst, lhs, rhs);
+	}
+
+	fmt::throw_exception("Unimplemented");
+}
+
+void asmjit::simd_builder::vec_umax(u32 esize, const Operand& dst, const Operand& lhs, const Operand& rhs)
+{
+	using enum x86::Inst::Id;
+	if (esize == 2)
+	{
+		if (utils::has_sse41())
+			return _vec_binary_op(kIdPmaxuw, kIdVpmaxuw, kIdVpmaxuw, dst, lhs, rhs);
+	}
+	else if (esize == 4)
+	{
+		if (utils::has_sse41())
+			return _vec_binary_op(kIdPmaxud, kIdVpmaxud, kIdVpmaxud, dst, lhs, rhs);
+	}
+
+	fmt::throw_exception("Unimplemented");
+}
+
+void asmjit::simd_builder::vec_umin_horizontal_i128(u32 esize, const x86::Gp& dst, const Operand& src, const Operand& tmp)
+{
+	using enum x86::Inst::Id;
+	if (!utils::has_sse41())
+	{
+		fmt::throw_exception("Unimplemented");
+	}
+
+	ensure(src != tmp);
+
+	if (esize == 2)
+	{
+		this->emit(utils::has_avx() ? kIdVphminposuw : kIdPhminposuw, x86::Xmm(tmp.id()), x86::Xmm(src.id()));
+		this->emit(utils::has_avx() ? kIdVpextrw : kIdPextrw, dst, x86::Xmm(tmp.id()), Imm(0));
+	}
+	else if (esize == 4)
+	{
+		if (utils::has_avx())
+		{
+			this->vpsrldq(x86::Xmm(tmp.id()), x86::Xmm(src.id()), 8);
+			this->vpminud(x86::Xmm(tmp.id()), x86::Xmm(tmp.id()), x86::Xmm(src.id()));
+			this->vpsrldq(x86::Xmm(src.id()), x86::Xmm(tmp.id()), 4);
+			this->vpminud(x86::Xmm(src.id()), x86::Xmm(src.id()), x86::Xmm(tmp.id()));
+			this->vmovd(dst.r32(), x86::Xmm(src.id()));
+		}
+		else
+		{
+			this->movdqa(x86::Xmm(tmp.id()), x86::Xmm(src.id()));
+			this->psrldq(x86::Xmm(tmp.id()), 8);
+			this->pminud(x86::Xmm(tmp.id()), x86::Xmm(src.id()));
+			this->movdqa(x86::Xmm(src.id()), x86::Xmm(tmp.id()));
+			this->psrldq(x86::Xmm(src.id()), 4);
+			this->pminud(x86::Xmm(src.id()), x86::Xmm(tmp.id()));
+			this->movd(dst.r32(), x86::Xmm(src.id()));
+		}
+	}
+	else
+	{
+		fmt::throw_exception("Unimplemented");
+	}
+}
+
+void asmjit::simd_builder::vec_umax_horizontal_i128(u32 esize, const x86::Gp& dst, const Operand& src, const Operand& tmp)
+{
+	using enum x86::Inst::Id;
+	if (!utils::has_sse41())
+	{
+		fmt::throw_exception("Unimplemented");
+	}
+
+	ensure(src != tmp);
+
+	if (esize == 2)
+	{
+		vec_set_all_ones(x86::Xmm(tmp.id()));
+		vec_xor(esize, x86::Xmm(tmp.id()), x86::Xmm(tmp.id()), x86::Xmm(src.id()));
+		this->emit(utils::has_avx() ? kIdVphminposuw : kIdPhminposuw, x86::Xmm(tmp.id()), x86::Xmm(tmp.id()));
+		this->emit(utils::has_avx() ? kIdVpextrw : kIdPextrw, dst, x86::Xmm(tmp.id()), Imm(0));
+		this->not_(dst.r16());
+	}
+	else if (esize == 4)
+	{
+		if (utils::has_avx())
+		{
+			this->vpsrldq(x86::Xmm(tmp.id()), x86::Xmm(src.id()), 8);
+			this->vpmaxud(x86::Xmm(tmp.id()), x86::Xmm(tmp.id()), x86::Xmm(src.id()));
+			this->vpsrldq(x86::Xmm(src.id()), x86::Xmm(tmp.id()), 4);
+			this->vpmaxud(x86::Xmm(src.id()), x86::Xmm(src.id()), x86::Xmm(tmp.id()));
+			this->vmovd(dst.r32(), x86::Xmm(src.id()));
+		}
+		else
+		{
+			this->movdqa(x86::Xmm(tmp.id()), x86::Xmm(src.id()));
+			this->psrldq(x86::Xmm(tmp.id()), 8);
+			this->pmaxud(x86::Xmm(tmp.id()), x86::Xmm(src.id()));
+			this->movdqa(x86::Xmm(src.id()), x86::Xmm(tmp.id()));
+			this->psrldq(x86::Xmm(src.id()), 4);
+			this->pmaxud(x86::Xmm(src.id()), x86::Xmm(tmp.id()));
+			this->movd(dst.r32(), x86::Xmm(src.id()));
+		}
+	}
+	else
+	{
+		fmt::throw_exception("Unimplemented");
+	}
+}
+#endif /* X86 */
+
 #ifdef LLVM_AVAILABLE
 
 #include

diff --git a/Utilities/JIT.h b/Utilities/JIT.h
index 3fb5fb2bba..dd925a353b 100644
--- a/Utilities/JIT.h
+++ b/Utilities/JIT.h
@@ -51,6 +51,8 @@ using native_asm = asmjit::a64::Assembler;
 using native_args = std::array<asmjit::a64::Gp, 4>;
 #endif
 
+union v128;
+
 void jit_announce(uptr func, usz size, std::string_view name);
 
 void jit_announce(auto* func, usz size, std::string_view name)
@@ -211,40 +213,132 @@ namespace asmjit
 	}
 
 #if defined(ARCH_X64)
-	template <uint Size>
-	struct native_vec;
-
-	template <>
-	struct native_vec<16> { using type = x86::Xmm; };
-
-	template <>
-	struct native_vec<32> { using type = x86::Ymm; };
-
-	template <>
-	struct native_vec<64> { using type = x86::Zmm; };
-
-	template <uint Size>
-	using native_vec_t = typename native_vec<Size>::type;
-
-	// if (count > step) { for (; ctr < (count - step); ctr += step) {...} count -= ctr; }
-	inline void build_incomplete_loop(native_asm& c, auto ctr, auto count, u32 step, auto&& build)
+	struct simd_builder : native_asm
 	{
-		asmjit::Label body = c.newLabel();
-		asmjit::Label exit = c.newLabel();
+		Operand v0, v1, v2, v3, v4, v5;
 
-		ensure((step & (step - 1)) == 0);
-		c.cmp(count, step);
-		c.jbe(exit);
-		c.sub(count, step);
-		c.align(asmjit::AlignMode::kCode, 16);
-		c.bind(body);
-		build();
-		c.add(ctr, step);
-		c.sub(count, step);
-		c.ja(body);
-		c.add(count, step);
-		c.bind(exit);
-	}
+		uint vsize = 16;
+		uint vmask = 0;
+
+		simd_builder(CodeHolder* ch) noexcept;
+
+		void _init(bool full);
+		void vec_cleanup_ret();
+		void vec_set_all_zeros(const Operand& v);
+		void vec_set_all_ones(const Operand& v);
+		void vec_set_const(const Operand& v, const v128& value);
+		void vec_clobbering_test(u32 esize, const Operand& v, const Operand& rhs);
+
+		// return x86::ptr(base, ctr, X, 0) where X is set for esize accordingly
+		x86::Mem ptr_scale_for_vec(u32 esize, const x86::Gp& base, const x86::Gp& index);
+
+		void vec_load_unaligned(u32 esize, const Operand& v, const x86::Mem& src);
+		void vec_store_unaligned(u32 esize, const Operand& v, const x86::Mem& dst);
+		void vec_partial_move(u32 esize, const Operand& dst, const Operand& src);
+
+		void _vec_binary_op(x86::Inst::Id sse_op, x86::Inst::Id vex_op, x86::Inst::Id evex_op, const Operand& dst, const Operand& lhs, const Operand& rhs);
+
+		void vec_shuffle_xi8(const Operand& dst, const Operand& lhs, const Operand& rhs)
+		{
+			using enum x86::Inst::Id;
+			_vec_binary_op(kIdPshufb, kIdVpshufb, kIdVpshufb, dst, lhs, rhs);
+		}
+
+		void vec_xor(u32, const Operand& dst, const Operand& lhs, const Operand& rhs)
+		{
+			using enum x86::Inst::Id;
+			_vec_binary_op(kIdPxor, kIdVpxor, kIdVpxord, dst, lhs, rhs);
+		}
+
+		void vec_or(u32, const Operand& dst, const Operand& lhs, const Operand& rhs)
+		{
+			using enum x86::Inst::Id;
+			_vec_binary_op(kIdPor, kIdVpor, kIdVpord, dst, lhs, rhs);
+		}
+
+		void vec_umin(u32 esize, const Operand& dst, const Operand& lhs, const Operand& rhs);
+		void vec_umax(u32 esize, const Operand& dst, const Operand& lhs, const Operand& rhs);
+
+		void vec_umin_horizontal_i128(u32 esize, const x86::Gp& dst, const Operand& src, const Operand& tmp);
+		void vec_umax_horizontal_i128(u32 esize, const x86::Gp& dst, const Operand& src, const Operand& tmp);
+
+		simd_builder& keep_if_not_masked()
+		{
+			if (vmask && vmask < 8)
+			{
+				this->k(x86::KReg(vmask));
+			}
+
+			return *this;
+		}
+
+		simd_builder& zero_if_not_masked()
+		{
+			if (vmask && vmask < 8)
+			{
+				this->k(x86::KReg(vmask));
+				this->z();
+			}
+
+			return *this;
+		}
+
+		void build_loop(u32 esize, auto reg_ctr, auto reg_cnt, auto&& build, auto&& reduce)
+		{
+			ensure((esize & (esize - 1)) == 0);
+			ensure(esize <= vsize);
+
+			Label body = this->newLabel();
+			Label next = this->newLabel();
+			Label exit = this->newLabel();
+
+			const u32 step = vsize / esize;
+
+			this->xor_(reg_ctr.r32(), reg_ctr.r32()); // Reset counter reg
+			this->sub(reg_cnt, step);
+			this->jb(next); // If count < step, skip main loop body
+			this->align(AlignMode::kCode, 16);
+			this->bind(body);
+			build();
+			this->add(reg_ctr, step);
+			this->sub(reg_cnt, step);
+			this->ja(body);
+			this->bind(next);
+			if (!vmask)
+				reduce();
+			this->add(reg_cnt, step);
+			this->jz(exit);
+
+			if (vmask)
+			{
+				// Build single last iteration (masked)
+				static constexpr u64 all_ones = -1;
+				this->bzhi(reg_cnt, x86::Mem(uptr(&all_ones)), reg_cnt);
+				this->kmovq(x86::k7, reg_cnt);
+				vmask = 7;
+				build();
+				vmask = -1;
+				reduce();
+			}
+			else
+			{
+				// Build tail loop (reduced vector width)
+				Label body = this->newLabel();
+				this->align(AlignMode::kCode, 16);
+				this->bind(body);
+				const uint vsz = vsize / step;
+				this->_init(false);
+				vsize = vsz;
+				build();
+				this->_init(true);
+				this->inc(reg_ctr);
+				this->sub(reg_cnt, 1);
+				this->ja(body);
+			}
+
+			this->bind(exit);
+		}
+	};
 
 	// for (; count > 0; ctr++, count--)
 	inline void build_loop(native_asm& c, auto ctr, auto count, auto&& build)
@@ -262,6 +356,27 @@ namespace asmjit
 		c.ja(body);
 		c.bind(exit);
 	}
+
+	inline void maybe_flush_lbr(native_asm& c, uint count = 2)
+	{
+		// Workaround for bad LBR callstacks which happen in some situations (mainly TSX) - execute additional RETs
+		Label next = c.newLabel();
+		c.lea(x86::rcx, x86::qword_ptr(next));
+
+		for (u32 i = 0; i < count; i++)
+		{
+			c.push(x86::rcx);
+			c.sub(x86::rcx, 16);
+		}
+
+		for (u32 i = 0; i < count; i++)
+		{
+			c.ret();
+			c.align(asmjit::AlignMode::kCode, 16);
+		}
+
+		c.bind(next);
+	}
 #endif
 }

diff --git a/rpcs3/Emu/Cell/PPUThread.cpp b/rpcs3/Emu/Cell/PPUThread.cpp
index b15b06f8a5..3df51b6936 100644
--- a/rpcs3/Emu/Cell/PPUThread.cpp
+++ b/rpcs3/Emu/Cell/PPUThread.cpp
@@ -2394,14 +2394,7 @@ const auto ppu_stcx_accurate_tx = build_function_asm
-	void build_copy_data_swap_u32_avx3(native_asm& c, native_args& args)
+	template <bool Compare>
+	void build_copy_data_swap_u32(asmjit::simd_builder& c, native_args& args)
 	{
 		using namespace asmjit;
-		native_vec_t<Size> vdata{0};
-		native_vec_t<Size> vmask{1};
-		native_vec_t<Size> vcmp{2};
-		// Load and broadcast shuffle mask
-		if constexpr (!Avx)
-			c.movaps(vmask, x86::oword_ptr(uptr(&s_bswap_u32_mask)));
-		if constexpr (Size == 16 && Avx)
-			c.vmovaps(vmask, x86::oword_ptr(uptr(&s_bswap_u32_mask)));
-		if constexpr (Size >= 32)
-			c.vbroadcasti32x4(vmask, x86::oword_ptr(uptr(&s_bswap_u32_mask)));
-
-		// Clear vcmp (bitwise inequality accumulator)
-		if constexpr (Compare && Avx)
-			c.vxorps(x86::xmm2, x86::xmm2, x86::xmm2);
-		if constexpr (Compare && !Avx)
-			c.xorps(x86::xmm2, x86::xmm2);
-		c.mov(args[3].r32(), -1);
-		c.xor_(x86::eax, x86::eax);
-
-		build_incomplete_loop(c, x86::eax, args[2].r32(), Size / 4, [&]
+		if (utils::has_ssse3())
 		{
-			if constexpr (Avx)
+			c.vec_set_const(c.v1, s_bswap_u32_mask);
+		}
+
+		// Clear v2 (bitwise inequality accumulator)
+		if constexpr (Compare)
+		{
+			c.vec_set_all_zeros(c.v2);
+		}
+
+		c.build_loop(sizeof(u32), x86::eax, args[2].r32(), [&]
+		{
+			c.zero_if_not_masked().vec_load_unaligned(sizeof(u32), c.v0, c.ptr_scale_for_vec(sizeof(u32), args[1], x86::rax));
+
+			if (utils::has_ssse3())
 			{
-				c.vmovdqu32(vdata, x86::ptr(args[1], x86::rax, 2, 0, Size));
-				c.vpshufb(vdata, vdata, vmask);
-				if constexpr (Compare)
-					c.vpternlogd(vcmp, vdata, x86::ptr(args[0], x86::rax, 2, 0, Size), 0xf6); // orAxorBC
-				c.vmovdqu32(x86::ptr(args[0], x86::rax, 2, 0, Size), vdata);
+				c.vec_shuffle_xi8(c.v0, c.v0, c.v1);
 			}
 			else
 			{
-				c.movdqu(vdata, x86::oword_ptr(args[1], x86::rax, 2, 0));
-				c.pshufb(vdata, vmask);
-				if constexpr (Compare)
+				c.emit(x86::Inst::kIdMovdqa, c.v1, c.v0);
+				c.emit(x86::Inst::kIdPsrlw, c.v0, 8);
+				c.emit(x86::Inst::kIdPsllw, c.v1, 8);
+				c.emit(x86::Inst::kIdPor, c.v0, c.v1);
+				c.emit(x86::Inst::kIdPshuflw, c.v0, c.v0, 0b10110001);
+				c.emit(x86::Inst::kIdPshufhw, c.v0, c.v0, 0b10110001);
+			}
+
+			if constexpr (Compare)
+			{
+				if (utils::has_avx512())
 				{
-					c.movups(x86::xmm3, x86::oword_ptr(args[0], x86::rax, 2, 0));
-					c.xorps(x86::xmm3, vdata);
-					c.orps(vcmp, x86::xmm3);
+					c.keep_if_not_masked().emit(x86::Inst::kIdVpternlogd, c.v2, c.v0, c.ptr_scale_for_vec(sizeof(u32), args[0], x86::rax), 0xf6); // orAxorBC
+				}
+				else
+				{
+					c.zero_if_not_masked().vec_load_unaligned(sizeof(u32), c.v3, c.ptr_scale_for_vec(sizeof(u32), args[0], x86::rax));
+					c.vec_xor(sizeof(u32), c.v3, c.v3, c.v0);
+					c.vec_or(sizeof(u32), c.v2, c.v2, c.v3);
+				}
+			}
+
+			c.keep_if_not_masked().vec_store_unaligned(sizeof(u32), c.v0, c.ptr_scale_for_vec(sizeof(u32), args[0], x86::rax));
+		}, [&]
+		{
+			if constexpr (Compare)
+			{
+				if (c.vsize == 32 && c.vmask == 0)
+				{
+					// Fix for AVX2 path
+					c.vextracti128(x86::xmm0, x86::ymm2, 1);
+					c.vpor(x86::xmm2, x86::xmm2, x86::xmm0);
 				}
-				c.movups(x86::oword_ptr(args[0], x86::rax, 2, 0), vdata);
 			}
 		});
 
-		if constexpr (Avx)
+		if constexpr (Compare)
 		{
-			c.bzhi(args[3].r32(), args[3].r32(), args[2].r32());
-			c.kmovw(x86::k1, args[3].r32());
-			c.k(x86::k1).z().vmovdqu32(vdata, x86::ptr(args[1], x86::rax, 2, 0, Size));
-			c.vpshufb(vdata, vdata, vmask);
-			if constexpr (Compare)
-				c.k(x86::k1).vpternlogd(vcmp, vdata, x86::ptr(args[0], x86::rax, 2, 0, Size), 0xf6);
-			c.k(x86::k1).vmovdqu32(x86::ptr(args[0], x86::rax, 2, 0, Size), vdata);
-		}
-		else
-		{
-			build_loop(c, x86::eax, args[2].r32(), [&]
-			{
-				c.movd(vdata, x86::dword_ptr(args[1], x86::rax, 2, 0));
-				c.pshufb(vdata, vmask);
-				if constexpr (Compare)
-				{
-					c.movd(x86::xmm3, x86::dword_ptr(args[0], x86::rax, 2, 0));
-					c.pxor(x86::xmm3, vdata);
-					c.por(vcmp, x86::xmm3);
-				}
-				c.movd(x86::dword_ptr(args[0], x86::rax, 2, 0), vdata);
-			});
-		}
-
-		if (Compare)
-		{
-			if constexpr (!Avx)
-			{
-				c.ptest(vcmp, vcmp);
-			}
-			else if constexpr (Size != 64)
-			{
-				c.vptest(vcmp, vcmp);
-			}
+			if (c.vsize == 32 && c.vmask == 0)
+				c.vec_clobbering_test(16, x86::xmm2, x86::xmm2);
 			else
-			{
-				c.vptestmd(x86::k1, vcmp, vcmp);
-				c.ktestw(x86::k1, x86::k1);
-			}
-
+				c.vec_clobbering_test(c.vsize, c.v2, c.v2);
 			c.setnz(x86::al);
 		}
-
-		if constexpr (Avx)
-			c.vzeroupper();
-		c.ret();
-	}
-
-	template <bool Compare>
-	void build_copy_data_swap_u32(native_asm& c, native_args& args)
-	{
-		using namespace asmjit;
-
-		if (utils::has_avx512())
-		{
-			if (utils::has_avx512_icl())
-			{
-				build_copy_data_swap_u32_avx3<Compare, 64>(c, args);
-				return;
-			}
-
-			build_copy_data_swap_u32_avx3<Compare, 32>(c, args);
-			return;
-		}
-
-		if (utils::has_sse41())
-		{
-			build_copy_data_swap_u32_avx3<Compare, 16, false>(c, args);
-			return;
-		}
-
-		c.jmp(&copy_data_swap_u32_naive<Compare>);
+		c.vec_cleanup_ret();
 	}
 #elif defined(ARCH_ARM64)
 	template <bool Compare>
@@ -271,8 +223,8 @@ namespace
 }
 
 #if !defined(__APPLE__) || defined(ARCH_X64)
-DECLARE(copy_data_swap_u32) = build_function_asm<void(*)(u32*, const u32*, u32)>("copy_data_swap_u32", &build_copy_data_swap_u32<false>);
-DECLARE(copy_data_swap_u32_cmp) = build_function_asm<bool(*)(u32*, const u32*, u32)>("copy_data_swap_u32_cmp", &build_copy_data_swap_u32<true>);
+DECLARE(copy_data_swap_u32) = build_function_asm<void(*)(u32*, const u32*, u32), asmjit::simd_builder>("copy_data_swap_u32", &build_copy_data_swap_u32<false>);
+DECLARE(copy_data_swap_u32_cmp) = build_function_asm<bool(*)(u32*, const u32*, u32), asmjit::simd_builder>("copy_data_swap_u32_cmp", &build_copy_data_swap_u32<true>);
 #else
 DECLARE(copy_data_swap_u32) = copy_data_swap_u32_naive<false>;
 DECLARE(copy_data_swap_u32_cmp) = copy_data_swap_u32_naive<true>;
@@ -300,228 +252,123 @@ namespace
 
 struct untouched_impl
 {
-#if defined(ARCH_X64)
-	AVX3_FUNC
-	static
-	std::tuple<u16, u16, u32> upload_u16_swapped_avx3(const void *src, void *dst, u32 count)
+	template <typename T>
+	static u64 upload_untouched_naive(const be_t<T>* src, T* dst, u32 count)
 	{
-		const __m512i s_bswap_u16_mask512 = _mm512_broadcast_i64x2(s_bswap_u16_mask);
+		u32 written = 0;
+		T max_index = 0;
+		T min_index = -1;
 
-		const __m512i s_remainder_mask = _mm512_set_epi16(
-			0x20, 0x1F, 0x1E, 0x1D,
-			0x1C, 0x1B, 0x1A, 0x19,
-			0x18, 0x17, 0x16, 0x15,
-			0x14, 0x13, 0x12, 0x11,
-			0x10, 0xF, 0xE, 0xD,
-			0xC, 0xB, 0xA, 0x9,
-			0x8, 0x7, 0x6, 0x5,
-			0x4, 0x3, 0x2, 0x1);
-
-		auto src_stream = static_cast<const __m512i*>(src);
-		auto dst_stream = static_cast<__m512*>(dst);
-
-		__m512i min = _mm512_set1_epi16(-1);
-		__m512i max = _mm512_set1_epi16(0);
-
-		const auto iterations = count / 32;
-		for (unsigned n = 0; n < iterations; ++n)
-		{
-			const __m512i raw = _mm512_loadu_si512(src_stream++);
-			const __m512i value = _mm512_shuffle_epi8(raw, s_bswap_u16_mask512);
-			max = _mm512_max_epu16(max, value);
-			min = _mm512_min_epu16(min, value);
-			_mm512_store_si512(dst_stream++, value);
-		}
-
-		if ((iterations * 32) < count)
-		{
-			const u16 remainder = (count - (iterations * 32));
-			const __m512i remBroadcast = _mm512_set1_epi16(remainder);
-			const __mmask32 mask = _mm512_cmpge_epi16_mask(remBroadcast, s_remainder_mask);
-			const __m512i raw = _mm512_maskz_loadu_epi16(mask, src_stream++);
-			const __m512i value = _mm512_shuffle_epi8(raw, s_bswap_u16_mask512);
-			max = _mm512_mask_max_epu16(max, mask, max, value);
-			min = _mm512_mask_min_epu16(min, mask, min, value);
-			_mm512_mask_storeu_epi16(dst_stream++, mask, value);
-		}
-
-		__m256i tmp256 = _mm512_extracti64x4_epi64(min, 1);
-		__m256i min2 = _mm512_castsi512_si256(min);
-		min2 = _mm256_min_epu16(min2, tmp256);
-		__m128i tmp = _mm256_extracti128_si256(min2, 1);
-		__m128i min3 = _mm256_castsi256_si128(min2);
-		min3 = _mm_min_epu16(min3, tmp);
-
-		tmp256 = _mm512_extracti64x4_epi64(max, 1);
-		__m256i max2 = _mm512_castsi512_si256(max);
-		max2 = _mm256_max_epu16(max2, tmp256);
-		tmp = _mm256_extracti128_si256(max2, 1);
-		__m128i max3 = _mm256_castsi256_si128(max2);
-		max3 = _mm_max_epu16(max3, tmp);
-
-		const u16 min_index = sse41_hmin_epu16(min3);
-		const u16 max_index = sse41_hmax_epu16(max3);
-
-		return std::make_tuple(min_index, max_index, count);
-	}
-
-	AVX2_FUNC
-	static
-	std::tuple<u16, u16, u32> upload_u16_swapped_avx2(const void *src, void *dst, u32 count)
-	{
-		const __m256i shuffle_mask = _mm256_set_m128i(s_bswap_u16_mask, s_bswap_u16_mask);
-
-		auto src_stream = static_cast<const __m256i*>(src);
-		auto dst_stream = static_cast<__m256i*>(dst);
-
-		__m256i min = _mm256_set1_epi16(-1);
-		__m256i max = _mm256_set1_epi16(0);
-
-		const auto iterations = count / 16;
-		for (unsigned n = 0; n < iterations; ++n)
-		{
-			const __m256i raw = _mm256_loadu_si256(src_stream++);
-			const __m256i value = _mm256_shuffle_epi8(raw, shuffle_mask);
-			max = _mm256_max_epu16(max, value);
-			min = _mm256_min_epu16(min, value);
-			_mm256_store_si256(dst_stream++, value);
-		}
-
-		__m128i tmp = _mm256_extracti128_si256(min, 1);
-		__m128i min2 = _mm256_castsi256_si128(min);
-		min2 = _mm_min_epu16(min2, tmp);
-
-		tmp = _mm256_extracti128_si256(max, 1);
-		__m128i max2 = _mm256_castsi256_si128(max);
-		max2 = _mm_max_epu16(max2, tmp);
-
-		const u16 min_index = sse41_hmin_epu16(min2);
-		const u16 max_index = sse41_hmax_epu16(max2);
-
-		return std::make_tuple(min_index, max_index, count);
-	}
-#endif
-
-	SSE4_1_FUNC
-	static
-	std::tuple<u16, u16, u32> upload_u16_swapped_sse4_1(const void *src, void *dst, u32 count)
-	{
-		auto src_stream = static_cast<const __m128i*>(src);
-		auto dst_stream = static_cast<__m128i*>(dst);
-
-		__m128i min = _mm_set1_epi16(-1);
-		__m128i max = _mm_set1_epi16(0);
-
-		const auto iterations = count / 8;
-		for (unsigned n = 0; n < iterations; ++n)
-		{
-			const __m128i raw = _mm_loadu_si128(src_stream++);
-			const __m128i value = _mm_shuffle_epi8(raw, s_bswap_u16_mask);
-			max = _mm_max_epu16(max, value);
-			min = _mm_min_epu16(min, value);
-			_mm_store_si128(dst_stream++, value);
-		}
-
-		const u16 min_index = sse41_hmin_epu16(min);
-		const u16 max_index = sse41_hmax_epu16(max);
-
-		return std::make_tuple(min_index, max_index, count);
-	}
-
-	SSE4_1_FUNC
-	static
-	std::tuple<u32, u32, u32> upload_u32_swapped_sse4_1(const void *src, void *dst, u32 count)
-	{
-		auto src_stream = static_cast<const __m128i*>(src);
-		auto dst_stream = static_cast<__m128i*>(dst);
-
-		__m128i min = _mm_set1_epi32(~0u);
-		__m128i max = _mm_set1_epi32(0);
-
-		const auto iterations = count / 4;
-		for (unsigned n = 0; n < iterations; ++n)
-		{
-			const __m128i raw = _mm_loadu_si128(src_stream++);
-			const __m128i value = _mm_shuffle_epi8(raw, s_bswap_u32_mask);
-			max = _mm_max_epu32(max, value);
-			min = _mm_min_epu32(min, value);
-			_mm_store_si128(dst_stream++, value);
-		}
-
-		__m128i tmp = _mm_srli_si128(min, 8);
-		min = _mm_min_epu32(min, tmp);
-		tmp = _mm_srli_si128(min, 4);
-		min = _mm_min_epu32(min, tmp);
-
-		tmp = _mm_srli_si128(max, 8);
-		max = _mm_max_epu32(max, tmp);
-		tmp = _mm_srli_si128(max, 4);
-		max = _mm_max_epu32(max, tmp);
-
-		const u32 min_index = _mm_cvtsi128_si32(min);
-		const u32 max_index = _mm_cvtsi128_si32(max);
-
-		return std::make_tuple(min_index, max_index, count);
-	}
-
-	template <typename T>
-	static
-	std::tuple<T, T, u32> upload_untouched(std::span<const be_t<T>> src, std::span<T> dst)
-	{
-		T min_index, max_index;
-		u32 written;
-		u32 remaining = ::size32(src);
-
-		if (s_use_sse4_1 && remaining >= 32)
-		{
-			if constexpr (std::is_same<T, u32>::value)
-			{
-				const auto count = (remaining & ~0x3);
-				std::tie(min_index, max_index, written) = upload_u32_swapped_sse4_1(src.data(), dst.data(), count);
-			}
-			else if constexpr (std::is_same<T, u16>::value)
-			{
-				if (s_use_avx3)
-				{
-#if defined(ARCH_X64)
-
-					// Handle remainder in function
-					std::tie(min_index, max_index, written) = upload_u16_swapped_avx3(src.data(), dst.data(), remaining);
-					return std::make_tuple(min_index, max_index, written);
-				}
-				else if (s_use_avx2)
-				{
-					const auto count = (remaining & ~0xf);
-					std::tie(min_index, max_index, written) = upload_u16_swapped_avx2(src.data(), dst.data(), count);
-#endif
-				}
-				else
-				{
-					const auto count = (remaining & ~0x7);
-					std::tie(min_index, max_index, written) = upload_u16_swapped_sse4_1(src.data(), dst.data(), count);
-				}
-			}
-			else
-			{
-				fmt::throw_exception("Unreachable");
-			}
-
-			remaining -= written;
-		}
-		else
-		{
-			min_index = index_limit<T>();
-			max_index = 0;
-			written = 0;
-		}
-
-		while (remaining--)
+		while (count--)
 		{
 			T index = src[written];
 			dst[written++] = min_max(min_index, max_index, index);
 		}
 
-		return std::make_tuple(min_index, max_index, written);
+		return (u64{max_index} << 32) | u64{min_index};
+	}
+
+#if defined(ARCH_X64)
+	template <typename T>
+	static void build_upload_untouched(asmjit::simd_builder& c, native_args& args)
+	{
+		using namespace asmjit;
+
+		if (!utils::has_sse41())
+		{
+			c.jmp(&upload_untouched_naive<T>);
+			return;
+		}
+
+		static const v128 all_ones_except_low_element = gv_shuffle_left<sizeof(T)>(v128::from32p(-1));
+
+		c.vec_set_const(c.v1, sizeof(T) == 2 ? s_bswap_u16_mask : s_bswap_u32_mask);
+		c.vec_set_all_ones(c.v2); // vec min
+		c.vec_set_all_zeros(c.v3); // vec max
+		c.vec_set_const(c.v4, all_ones_except_low_element);
+
+		c.build_loop(sizeof(T), x86::eax, args[2].r32(), [&]
+		{
+			c.zero_if_not_masked().vec_load_unaligned(sizeof(T), c.v0, c.ptr_scale_for_vec(sizeof(T), args[0], x86::rax));
+
+			if (utils::has_ssse3())
+			{
+				c.vec_shuffle_xi8(c.v0, c.v0, c.v1);
+			}
+			else
+			{
+				c.emit(x86::Inst::kIdMovdqa, c.v1, c.v0);
+				c.emit(x86::Inst::kIdPsrlw, c.v0, 8);
+				c.emit(x86::Inst::kIdPsllw, c.v1, 8);
+				c.emit(x86::Inst::kIdPor, c.v0, c.v1);
+
+				if constexpr (sizeof(T) == 4)
+				{
+					c.emit(x86::Inst::kIdPshuflw, c.v0, c.v0, 0b10110001);
+					c.emit(x86::Inst::kIdPshufhw, c.v0, c.v0, 0b10110001);
+				}
+			}
+
+			c.keep_if_not_masked().vec_umax(sizeof(T), c.v3, c.v3, c.v0);
+
+			if (c.vsize < 16)
+			{
+				// In remaining loop: protect min values
+				c.vec_or(sizeof(T), c.v5, c.v0, c.v4);
+				c.vec_umin(sizeof(T), c.v2, c.v2, c.v5);
+			}
+			else
+			{
+				c.keep_if_not_masked().vec_umin(sizeof(T), c.v2, c.v2, c.v0);
+			}
+
+			c.keep_if_not_masked().vec_store_unaligned(sizeof(T), c.v0, c.ptr_scale_for_vec(sizeof(T), args[1], x86::rax));
+		}, [&]
+		{
+			// Compress to xmm, protect high values
+			if (c.vsize >= 64)
+			{
+				c.vextracti32x8(x86::ymm0, x86::zmm3, 1);
+				c.emit(sizeof(T) == 4 ? x86::Inst::kIdVpmaxud : x86::Inst::kIdVpmaxuw, x86::ymm3, x86::ymm3, x86::ymm0);
+				c.vextracti32x8(x86::ymm0, x86::zmm2, 1);
+				c.emit(sizeof(T) == 4 ? x86::Inst::kIdVpminud : x86::Inst::kIdVpminuw, x86::ymm2, x86::ymm2, x86::ymm0);
+			}
+
+			if (c.vsize >= 32)
+			{
+				c.vextracti128(x86::xmm0, x86::ymm3, 1);
+				c.emit(sizeof(T) == 4 ? x86::Inst::kIdVpmaxud : x86::Inst::kIdVpmaxuw, x86::xmm3, x86::xmm3, x86::xmm0);
+				c.vextracti128(x86::xmm0, x86::ymm2, 1);
+				c.emit(sizeof(T) == 4 ? x86::Inst::kIdVpminud : x86::Inst::kIdVpminuw, x86::xmm2, x86::xmm2, x86::xmm0);
+			}
+		});
+
+		c.vec_umax_horizontal_i128(sizeof(T), x86::rdx, c.v3, c.v0);
+		c.vec_umin_horizontal_i128(sizeof(T), x86::rax, c.v2, c.v0);
+		c.shl(x86::rdx, 32);
+		c.or_(x86::rax, x86::rdx);
+		c.vec_cleanup_ret();
+	}
+
+	static inline auto upload_xi16 = build_function_asm<u64(*)(const be_t<u16>*, u16*, u32), asmjit::simd_builder>("untouched_upload_xi16", &build_upload_untouched<u16>);
+	static inline auto upload_xi32 = build_function_asm<u64(*)(const be_t<u32>*, u32*, u32), asmjit::simd_builder>("untouched_upload_xi32", &build_upload_untouched<u32>);
+#endif
+
+	template <typename T>
+	static std::tuple<T, T, u32> upload_untouched(std::span<const be_t<T>> src, std::span<T> dst)
+	{
+		T min_index, max_index;
+		u32 count = ::size32(src);
+		u64 r;
+
+		if constexpr (sizeof(T) == 2)
+			r = upload_xi16(src.data(), dst.data(), count);
+		else
+			r = upload_xi32(src.data(), dst.data(), count);
+
+		min_index = r;
+		max_index = r >> 32;
+
+		return std::make_tuple(min_index, max_index, count);
 	}
 };
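
Notes on the new API follow; the sketches below are illustrative models, not part of the patch.

The quickest way to see the moving parts of simd_builder together is a trivial kernel. The sketch below is hypothetical (build_zero_u32 and zero_u32 are invented names): it emits a loop that zero-fills a buffer of u32 elements, passing the element size to build_loop, using keep_if_not_masked() in the body so the same lambda serves the masked AVX-512 tail, and finishing with vec_cleanup_ret for the vzeroupper bookkeeping — the same pattern BufferUtils.cpp uses above.

// Hypothetical simd_builder kernel (names invented for illustration):
// zero-fill args[1] u32 elements at args[0]. Assumes the declarations
// added to Utilities/JIT.h by this patch.
#include "Utilities/JIT.h"

static void build_zero_u32(asmjit::simd_builder& c, native_args& args)
{
	using namespace asmjit;

	c.vec_set_all_zeros(c.v0);

	c.build_loop(sizeof(u32), x86::eax, args[1].r32(), [&]
	{
		// keep_if_not_masked() makes this store honor k7 in the AVX-512 tail
		c.keep_if_not_masked().vec_store_unaligned(sizeof(u32), c.v0,
			c.ptr_scale_for_vec(sizeof(u32), args[0], x86::rax));
	}, []
	{
		// No cross-lane accumulator in this kernel, so nothing to reduce
	});

	c.vec_cleanup_ret();
}

// Built the same way the patch builds its own entry points:
static inline auto zero_u32 = build_function_asm<void(*)(u32*, u32), asmjit::simd_builder>("zero_u32", &build_zero_u32);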
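build_loop itself emits three phases: a full-width main loop, a reduce() step that compresses ymm/zmm accumulators down to xmm, and a remainder handled either as one masked AVX-512 iteration or as an element-wide scalar loop. A simplified C++ model of that control flow (illustrative and not cycle-exact; the emitted code folds comparisons into flag reuse):

// Simplified model of the control flow simd_builder::build_loop emits.
// 'step' is vsize / esize; 'build' is the body, 'reduce' the compression.
template <typename F, typename G>
void build_loop_model(unsigned count, unsigned step, bool use_kmask, F&& build, G&& reduce)
{
	unsigned ctr = 0;

	while (count > step)
	{
		build(ctr, step); // full vector width
		ctr += step;
		count -= step;
	}

	if (!use_kmask)
		reduce(); // compress wide accumulators before the width shrinks

	if (count)
	{
		if (use_kmask)
		{
			build(ctr, count); // one iteration under k7 = bzhi(~0, count)
			reduce();
		}
		else
		{
			for (; count; ctr++, count--)
				build(ctr, 1); // vsize == esize: one element per iteration
		}
	}
}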
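On AVX-512 the tail is a single masked iteration rather than a scalar loop: bzhi materializes a value with exactly 'count' low bits set and kmovq moves it into k7, which zero_if_not_masked()/keep_if_not_masked() then attach to loads and stores. A minimal intrinsics model of one such masked byte-swapping step, assuming AVX-512F/BW and BMI2:

// Model of one masked AVX-512 tail iteration of copy_data_swap_u32.
#include <immintrin.h>
#include <cstdint>

static void swap_u32_tail(uint32_t* dst, const uint32_t* src, unsigned remaining)
{
	const __mmask16 k = (__mmask16)_bzhi_u64(~0ull, remaining); // low 'remaining' bits set
	const __m512i bswap32 = _mm512_broadcast_i32x4(
		_mm_set_epi8(12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3));

	__m512i v = _mm512_maskz_loadu_epi32(k, src); // dead lanes load as zero
	v = _mm512_shuffle_epi8(v, bswap32);          // byte-swap every u32
	_mm512_mask_storeu_epi32(dst, k, v);          // store live lanes only
}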
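Without SSSE3 (no pshufb), the 32-bit byte swap is composed from word shifts plus a word swap: the psrlw/psllw/por triple reverses the bytes inside each u16, and the pshuflw/pshufhw immediate 0b10110001 selects words 1,0,3,2, exchanging the two u16 halves of every u32. An SSE2-only model:

// SSE2-only model of the SSSE3-less fallback: byte-swap every u32 in a vector.
#include <emmintrin.h>

static __m128i bswap32_sse2(__m128i v)
{
	const __m128i hi = _mm_srli_epi16(v, 8); // high byte of each u16 -> low
	const __m128i lo = _mm_slli_epi16(v, 8); // low byte of each u16 -> high
	v = _mm_or_si128(hi, lo);                // bytes swapped within each u16
	v = _mm_shufflelo_epi16(v, 0b10110001);  // swap u16 pairs in the low qword
	v = _mm_shufflehi_epi16(v, 0b10110001);  // ... and in the high qword
	return v;
}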
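vec_umin_horizontal_i128 and vec_umax_horizontal_i128 lean on phminposuw for u16 lanes, which only computes an unsigned horizontal minimum (result in word 0, index in word 1). The maximum is recovered over the bitwise complement, since max(x) == ~min(~x) for unsigned values — that is why the emitted code XORs with all-ones and ends with not_(dst.r16()). A model of both, assuming SSE4.1:

// Model of the horizontal u16 min/max used by vec_umin/umax_horizontal_i128.
#include <smmintrin.h>
#include <cstdint>

static uint16_t hmin_u16(__m128i v)
{
	return (uint16_t)_mm_cvtsi128_si32(_mm_minpos_epu16(v));
}

static uint16_t hmax_u16(__m128i v)
{
	const __m128i ones = _mm_cmpeq_epi32(v, v); // all-ones
	const __m128i inv = _mm_xor_si128(v, ones); // ~v
	return (uint16_t)~_mm_cvtsi128_si32(_mm_minpos_epu16(inv));
}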
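The "protect min values" step in build_upload_untouched's scalar-width tail deserves spelling out: there, only lane 0 of v0 carries a real index, so the dead lanes are ORed to all-ones (the constant built by gv_shuffle_left) before the running minimum, while the running maximum needs no guard because the element-wide load zeroes the dead lanes. A u32/SSE4.1 model of one tail step:

// Model of one tail step of build_upload_untouched for u32 indices.
#include <smmintrin.h>

static void tail_minmax_u32(__m128i v0, __m128i& vmin, __m128i& vmax)
{
	const __m128i ones = _mm_cmpeq_epi32(v0, v0);        // all-ones
	const __m128i guard = _mm_slli_si128(ones, 4);       // ~0 in lanes 1..3, 0 in lane 0
	vmin = _mm_min_epu32(vmin, _mm_or_si128(v0, guard)); // only lane 0 can lower the min
	vmax = _mm_max_epu32(vmax, v0);                      // zero lanes never raise the max
}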