mirror of
https://github.com/RPCS3/rpcs3.git
synced 2025-03-13 07:14:49 +00:00
atomic.cpp: implement some atomic wait operations.
Instead of plain waiting while equal to some value, it can be something like less, or greater, or even bitcount. But it's a draft and untested. Hopefully doesn't break anything.
This commit is contained in:
parent
829a697c39
commit
7cd1e767be
@ -20,6 +20,7 @@
|
||||
#include <random>
|
||||
|
||||
#include "asm.hpp"
|
||||
#include "endian.hpp"
|
||||
|
||||
// Total number of entries, should be a power of 2.
|
||||
static constexpr std::size_t s_hashtable_size = 1u << 18;
|
||||
@ -30,13 +31,18 @@ static thread_local bool(*s_tls_wait_cb)(const void* data) = [](const void*){ re
|
||||
// Callback for notification functions for optimizations
|
||||
static thread_local void(*s_tls_notify_cb)(const void* data, u64 progress) = [](const void*, u64){};
|
||||
|
||||
static inline bool operator &(atomic_wait::op lhs, atomic_wait::op_flag rhs)
|
||||
{
|
||||
return !!(static_cast<u8>(lhs) & static_cast<u8>(rhs));
|
||||
}
|
||||
|
||||
// Compare data in memory with old value, and return true if they are equal
|
||||
template <bool CheckCb = true>
|
||||
static NEVER_INLINE bool
|
||||
#ifdef _WIN32
|
||||
__vectorcall
|
||||
#endif
|
||||
ptr_cmp(const void* data, u32 size, __m128i old128, __m128i mask128, atomic_wait::info* ext = nullptr)
|
||||
ptr_cmp(const void* data, u32 _size, __m128i old128, __m128i mask128, atomic_wait::info* ext = nullptr)
|
||||
{
|
||||
if constexpr (CheckCb)
|
||||
{
|
||||
@ -46,32 +52,138 @@ ptr_cmp(const void* data, u32 size, __m128i old128, __m128i mask128, atomic_wait
|
||||
}
|
||||
}
|
||||
|
||||
const u64 old_value = _mm_cvtsi128_si64(old128);
|
||||
const u64 mask = _mm_cvtsi128_si64(mask128);
|
||||
using atomic_wait::op;
|
||||
using atomic_wait::op_flag;
|
||||
|
||||
const u8 size = static_cast<u8>(_size);
|
||||
const op flag{static_cast<u8>(_size >> 8)};
|
||||
|
||||
bool result = false;
|
||||
|
||||
switch (size)
|
||||
if (size <= 8)
|
||||
{
|
||||
case 1: result = (reinterpret_cast<const atomic_t<u8>*>(data)->load() & mask) == (old_value & mask); break;
|
||||
case 2: result = (reinterpret_cast<const atomic_t<u16>*>(data)->load() & mask) == (old_value & mask); break;
|
||||
case 4: result = (reinterpret_cast<const atomic_t<u32>*>(data)->load() & mask) == (old_value & mask); break;
|
||||
case 8: result = (reinterpret_cast<const atomic_t<u64>*>(data)->load() & mask) == (old_value & mask); break;
|
||||
case 16:
|
||||
{
|
||||
const auto v0 = std::bit_cast<__m128i>(atomic_storage<u128>::load(*reinterpret_cast<const u128*>(data)));
|
||||
const auto v1 = _mm_xor_si128(v0, old128);
|
||||
const auto v2 = _mm_and_si128(v1, mask128);
|
||||
const auto v3 = _mm_packs_epi16(v2, v2);
|
||||
u64 new_value = 0;
|
||||
u64 old_value = _mm_cvtsi128_si64(old128);
|
||||
u64 mask = _mm_cvtsi128_si64(mask128) & (UINT64_MAX >> ((64 - size * 8) & 63));
|
||||
|
||||
result = _mm_cvtsi128_si64(v3) == 0;
|
||||
break;
|
||||
switch (size)
|
||||
{
|
||||
case 1: new_value = reinterpret_cast<const atomic_t<u8>*>(data)->load(); break;
|
||||
case 2: new_value = reinterpret_cast<const atomic_t<u16>*>(data)->load(); break;
|
||||
case 4: new_value = reinterpret_cast<const atomic_t<u32>*>(data)->load(); break;
|
||||
case 8: new_value = reinterpret_cast<const atomic_t<u64>*>(data)->load(); break;
|
||||
default:
|
||||
{
|
||||
fprintf(stderr, "ptr_cmp(): bad size (arg=0x%x)" HERE "\n", _size);
|
||||
std::abort();
|
||||
}
|
||||
}
|
||||
|
||||
if (flag & op_flag::bit_not)
|
||||
{
|
||||
new_value = ~new_value;
|
||||
}
|
||||
|
||||
if (!mask) [[unlikely]]
|
||||
{
|
||||
new_value = 0;
|
||||
old_value = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (flag & op_flag::byteswap)
|
||||
{
|
||||
switch (size)
|
||||
{
|
||||
case 2:
|
||||
{
|
||||
new_value = stx::se_storage<u16>::swap(static_cast<u16>(new_value));
|
||||
old_value = stx::se_storage<u16>::swap(static_cast<u16>(old_value));
|
||||
mask = stx::se_storage<u16>::swap(static_cast<u16>(mask));
|
||||
break;
|
||||
}
|
||||
case 4:
|
||||
{
|
||||
new_value = stx::se_storage<u32>::swap(static_cast<u32>(new_value));
|
||||
old_value = stx::se_storage<u32>::swap(static_cast<u32>(old_value));
|
||||
mask = stx::se_storage<u32>::swap(static_cast<u32>(mask));
|
||||
break;
|
||||
}
|
||||
case 8:
|
||||
{
|
||||
new_value = stx::se_storage<u64>::swap(new_value);
|
||||
old_value = stx::se_storage<u64>::swap(old_value);
|
||||
mask = stx::se_storage<u64>::swap(mask);
|
||||
}
|
||||
default:
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Make most significant bit sign bit
|
||||
const auto shv = std::countl_zero(mask);
|
||||
new_value &= mask;
|
||||
old_value &= mask;
|
||||
new_value <<= shv;
|
||||
old_value <<= shv;
|
||||
}
|
||||
|
||||
s64 news = new_value;
|
||||
s64 olds = old_value;
|
||||
|
||||
u64 newa = news < 0 ? (0ull - new_value) : new_value;
|
||||
u64 olda = olds < 0 ? (0ull - old_value) : old_value;
|
||||
|
||||
switch (op{static_cast<u8>(static_cast<u8>(flag) & 0xf)})
|
||||
{
|
||||
case op::eq: result = old_value == new_value; break;
|
||||
case op::slt: result = olds < news; break;
|
||||
case op::sgt: result = olds > news; break;
|
||||
case op::ult: result = old_value < new_value; break;
|
||||
case op::ugt: result = old_value > new_value; break;
|
||||
case op::alt: result = olda < newa; break;
|
||||
case op::agt: result = olda > newa; break;
|
||||
case op::pop:
|
||||
{
|
||||
// Count is taken from least significant byte and ignores some flags
|
||||
const u64 count = _mm_cvtsi128_si64(old128) & 0xff;
|
||||
|
||||
u64 bitc = new_value;
|
||||
bitc = (bitc & 0xaaaaaaaaaaaaaaaa) / 2 + (bitc & 0x5555555555555555);
|
||||
bitc = (bitc & 0xcccccccccccccccc) / 4 + (bitc & 0x3333333333333333);
|
||||
bitc = (bitc & 0xf0f0f0f0f0f0f0f0) / 16 + (bitc & 0x0f0f0f0f0f0f0f0f);
|
||||
bitc = (bitc & 0xff00ff00ff00ff00) / 256 + (bitc & 0x00ff00ff00ff00ff);
|
||||
bitc = ((bitc & 0xffff0000ffff0000) >> 16) + (bitc & 0x0000ffff0000ffff);
|
||||
bitc = (bitc >> 32) + bitc;
|
||||
|
||||
result = count < bitc;
|
||||
break;
|
||||
}
|
||||
default:
|
||||
{
|
||||
fmt::raw_error("ptr_cmp(): unrecognized atomic wait operation.");
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
else if (size == 16 && (flag == op::eq || flag == (op::eq | op_flag::inverse)))
|
||||
{
|
||||
fprintf(stderr, "ptr_cmp(): bad size (size=%u)" HERE "\n", size);
|
||||
std::abort();
|
||||
u128 new_value = atomic_storage<u128>::load(*reinterpret_cast<const u128*>(data));
|
||||
u128 old_value = std::bit_cast<u128>(old128);
|
||||
u128 mask = std::bit_cast<u128>(mask128);
|
||||
|
||||
// TODO
|
||||
result = !((old_value ^ new_value) & mask);
|
||||
}
|
||||
else if (size == 16)
|
||||
{
|
||||
fmt::raw_error("ptr_cmp(): no alternative operations are supported for 16-byte atomic wait yet.");
|
||||
}
|
||||
|
||||
if (flag & op_flag::inverse)
|
||||
{
|
||||
result = !result;
|
||||
}
|
||||
|
||||
// Check other wait variables if provided
|
||||
@ -101,16 +213,8 @@ __vectorcall
|
||||
#endif
|
||||
cmp_mask(u32 size1, __m128i mask1, __m128i val1, u32 size2, __m128i mask2, __m128i val2)
|
||||
{
|
||||
// In force wake up, one of the size arguments is zero (obsolete)
|
||||
const u32 size = std::min(size1, size2);
|
||||
|
||||
if (!size) [[unlikely]]
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
|
||||
// Compare only masks, new value is not available in this mode
|
||||
if ((size1 | size2) == umax)
|
||||
if (size1 == umax)
|
||||
{
|
||||
// Simple mask overlap
|
||||
const auto v0 = _mm_and_si128(mask1, mask2);
|
||||
@ -121,6 +225,17 @@ cmp_mask(u32 size1, __m128i mask1, __m128i val1, u32 size2, __m128i mask2, __m12
|
||||
// Generate masked value inequality bits
|
||||
const auto v0 = _mm_and_si128(_mm_and_si128(mask1, mask2), _mm_xor_si128(val1, val2));
|
||||
|
||||
using atomic_wait::op;
|
||||
using atomic_wait::op_flag;
|
||||
|
||||
const u8 size = std::min<u8>(static_cast<u8>(size2), static_cast<u8>(size1));
|
||||
const op flag{static_cast<u8>(size2 >> 8)};
|
||||
|
||||
if (flag != op::eq && flag != (op::eq | op_flag::inverse))
|
||||
{
|
||||
fmt::raw_error("cmp_mask(): no operations are supported for notification with forced value yet.");
|
||||
}
|
||||
|
||||
if (size <= 8)
|
||||
{
|
||||
// Generate sized mask
|
||||
@ -128,14 +243,14 @@ cmp_mask(u32 size1, __m128i mask1, __m128i val1, u32 size2, __m128i mask2, __m12
|
||||
|
||||
if (!(_mm_cvtsi128_si64(v0) & mask))
|
||||
{
|
||||
return 0;
|
||||
return flag & op_flag::inverse ? 2 : 0;
|
||||
}
|
||||
}
|
||||
else if (size == 16)
|
||||
{
|
||||
if (!_mm_cvtsi128_si64(_mm_packs_epi16(v0, v0)))
|
||||
{
|
||||
return 0;
|
||||
return flag & op_flag::inverse ? 2 : 0;
|
||||
}
|
||||
}
|
||||
else
|
||||
@ -145,7 +260,7 @@ cmp_mask(u32 size1, __m128i mask1, __m128i val1, u32 size2, __m128i mask2, __m12
|
||||
}
|
||||
|
||||
// Use force wake-up
|
||||
return 2;
|
||||
return flag & op_flag::inverse ? 0 : 2;
|
||||
}
|
||||
|
||||
static atomic_t<u64> s_min_tsc{0};
|
||||
@ -227,7 +342,8 @@ namespace atomic_wait
|
||||
// Temporarily reduced unique tsc stamp to 48 bits to make space for refs (TODO)
|
||||
u64 tsc0 : 48 = 0;
|
||||
u64 link : 16 = 0;
|
||||
u16 size{};
|
||||
u8 size{};
|
||||
u8 flag{};
|
||||
atomic_t<u16> refs{};
|
||||
atomic_t<u32> sync{};
|
||||
|
||||
@ -262,6 +378,7 @@ namespace atomic_wait
|
||||
tsc0 = 0;
|
||||
link = 0;
|
||||
size = 0;
|
||||
flag = 0;
|
||||
sync = 0;
|
||||
|
||||
#ifdef USE_STD
|
||||
@ -868,7 +985,8 @@ atomic_wait_engine::wait(const void* data, u32 size, __m128i old_value, u64 time
|
||||
|
||||
// Store some info for notifiers (some may be unused)
|
||||
cond->link = 0;
|
||||
cond->size = static_cast<u16>(size);
|
||||
cond->size = static_cast<u8>(size);
|
||||
cond->flag = static_cast<u8>(size >> 8);
|
||||
cond->mask = mask;
|
||||
cond->oldv = old_value;
|
||||
cond->tsc0 = stamp0;
|
||||
@ -877,7 +995,8 @@ atomic_wait_engine::wait(const void* data, u32 size, __m128i old_value, u64 time
|
||||
{
|
||||
// Extensions point to original cond_id, copy remaining info
|
||||
cond_ext[i]->link = cond_id;
|
||||
cond_ext[i]->size = static_cast<u16>(ext[i].size);
|
||||
cond_ext[i]->size = static_cast<u8>(ext[i].size);
|
||||
cond_ext[i]->flag = static_cast<u8>(ext[i].size >> 8);
|
||||
cond_ext[i]->mask = ext[i].mask;
|
||||
cond_ext[i]->oldv = ext[i].old;
|
||||
cond_ext[i]->tsc0 = stamp0;
|
||||
@ -1058,7 +1177,7 @@ alert_sema(u32 cond_id, const void* data, u64 info, u32 size, __m128i mask, __m1
|
||||
|
||||
u32 cmp_res = 0;
|
||||
|
||||
if (cond->sync && (!size ? (!info || cond->tid == info) : (cond->ptr == data && ((cmp_res = cmp_mask(size, mask, new_value, cond->size, cond->mask, cond->oldv))))))
|
||||
if (cond->sync && (!size ? (!info || cond->tid == info) : (cond->ptr == data && ((cmp_res = cmp_mask(size, mask, new_value, cond->size | (cond->flag << 8), cond->mask, cond->oldv))))))
|
||||
{
|
||||
// Redirect if necessary
|
||||
const auto _old = cond;
|
||||
|
@ -14,14 +14,56 @@ enum class atomic_wait_timeout : u64
|
||||
inf = 0xffffffffffffffff,
|
||||
};
|
||||
|
||||
// Unused externally
|
||||
// Various extensions for atomic_t::wait
|
||||
namespace atomic_wait
|
||||
{
|
||||
// Max number of simultaneous atomic variables to wait on (can be extended if really necessary)
|
||||
constexpr uint max_list = 8;
|
||||
|
||||
struct root_info;
|
||||
struct sema_handle;
|
||||
|
||||
enum class op : u8
|
||||
{
|
||||
eq, // Wait while value is bitwise equal to
|
||||
slt, // Wait while signed value is less than
|
||||
sgt, // Wait while signed value is greater than
|
||||
ult, // Wait while unsigned value is less than
|
||||
ugt, // Wait while unsigned value is greater than
|
||||
alt, // Wait while absolute value is less than
|
||||
agt, // Wait while absolute value is greater than
|
||||
pop, // Wait while set bit count of the value is less than
|
||||
__max
|
||||
};
|
||||
|
||||
static_assert(static_cast<u8>(op::__max) == 8);
|
||||
|
||||
enum class op_flag : u8
|
||||
{
|
||||
inverse = 1 << 4, // Perform inverse operation (negate the result)
|
||||
bit_not = 1 << 5, // Perform bitwise NOT on loaded value before operation
|
||||
byteswap = 1 << 6, // Perform byteswap on both arguments and masks when applicable
|
||||
};
|
||||
|
||||
constexpr op_flag op_ne = {};
|
||||
constexpr op_flag op_be = std::endian::native == std::endian::little ? op_flag::byteswap : op_flag{0};
|
||||
constexpr op_flag op_le = std::endian::native == std::endian::little ? op_flag{0} : op_flag::byteswap;
|
||||
|
||||
constexpr op operator |(op_flag lhs, op_flag rhs)
|
||||
{
|
||||
return op{static_cast<u8>(static_cast<u8>(lhs) | static_cast<u8>(rhs))};
|
||||
}
|
||||
|
||||
constexpr op operator |(op_flag lhs, op rhs)
|
||||
{
|
||||
return op{static_cast<u8>(static_cast<u8>(lhs) | static_cast<u8>(rhs))};
|
||||
}
|
||||
|
||||
constexpr op operator |(op lhs, op_flag rhs)
|
||||
{
|
||||
return op{static_cast<u8>(static_cast<u8>(lhs) | static_cast<u8>(rhs))};
|
||||
}
|
||||
|
||||
struct info
|
||||
{
|
||||
const void* data;
|
||||
@ -114,24 +156,24 @@ namespace atomic_wait
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <uint Index, typename T2, std::size_t Align, typename U>
|
||||
template <uint Index, op Flags = op::eq, typename T2, std::size_t Align, typename U>
|
||||
constexpr void set(atomic_t<T2, Align>& var, U value)
|
||||
{
|
||||
static_assert(Index < Max);
|
||||
|
||||
m_info[Index].data = &var.raw();
|
||||
m_info[Index].size = sizeof(T2);
|
||||
m_info[Index].size = sizeof(T2) | (static_cast<u8>(Flags) << 8);
|
||||
m_info[Index].template set_value<T2>(value);
|
||||
m_info[Index].mask = _mm_set1_epi64x(-1);
|
||||
}
|
||||
|
||||
template <uint Index, typename T2, std::size_t Align, typename U, typename V>
|
||||
template <uint Index, op Flags = op::eq, typename T2, std::size_t Align, typename U, typename V>
|
||||
constexpr void set(atomic_t<T2, Align>& var, U value, V mask)
|
||||
{
|
||||
static_assert(Index < Max);
|
||||
|
||||
m_info[Index].data = &var.raw();
|
||||
m_info[Index].size = sizeof(T2);
|
||||
m_info[Index].size = sizeof(T2) | (static_cast<u8>(Flags) << 8);
|
||||
m_info[Index].template set_value<T2>(value);
|
||||
m_info[Index].template set_mask<T2>(mask);
|
||||
}
|
||||
@ -1387,34 +1429,36 @@ public:
|
||||
}
|
||||
|
||||
// Timeout is discouraged
|
||||
template <atomic_wait::op Flags = atomic_wait::op::eq>
|
||||
void wait(type old_value, atomic_wait_timeout timeout = atomic_wait_timeout::inf) const noexcept
|
||||
{
|
||||
if constexpr (sizeof(T) <= 8)
|
||||
{
|
||||
const __m128i old = _mm_cvtsi64_si128(std::bit_cast<get_uint_t<sizeof(T)>>(old_value));
|
||||
atomic_wait_engine::wait(&m_data, sizeof(T), old, static_cast<u64>(timeout), _mm_set1_epi64x(-1));
|
||||
atomic_wait_engine::wait(&m_data, sizeof(T) | (static_cast<u8>(Flags) << 8), old, static_cast<u64>(timeout), _mm_set1_epi64x(-1));
|
||||
}
|
||||
else if constexpr (sizeof(T) == 16)
|
||||
{
|
||||
const __m128i old = std::bit_cast<__m128i>(old_value);
|
||||
atomic_wait_engine::wait(&m_data, sizeof(T), old, static_cast<u64>(timeout), _mm_set1_epi64x(-1));
|
||||
atomic_wait_engine::wait(&m_data, sizeof(T) | (static_cast<u8>(Flags) << 8), old, static_cast<u64>(timeout), _mm_set1_epi64x(-1));
|
||||
}
|
||||
}
|
||||
|
||||
// Overload with mask (only selected bits are checked), timeout is discouraged
|
||||
template <atomic_wait::op Flags = atomic_wait::op::eq>
|
||||
void wait(type old_value, type mask_value, atomic_wait_timeout timeout = atomic_wait_timeout::inf)
|
||||
{
|
||||
if constexpr (sizeof(T) <= 8)
|
||||
{
|
||||
const __m128i old = _mm_cvtsi64_si128(std::bit_cast<get_uint_t<sizeof(T)>>(old_value));
|
||||
const __m128i mask = _mm_cvtsi64_si128(std::bit_cast<get_uint_t<sizeof(T)>>(mask_value));
|
||||
atomic_wait_engine::wait(&m_data, sizeof(T), old, static_cast<u64>(timeout), mask);
|
||||
atomic_wait_engine::wait(&m_data, sizeof(T) | (static_cast<u8>(Flags) << 8), old, static_cast<u64>(timeout), mask);
|
||||
}
|
||||
else if constexpr (sizeof(T) == 16)
|
||||
{
|
||||
const __m128i old = std::bit_cast<__m128i>(old_value);
|
||||
const __m128i mask = std::bit_cast<__m128i>(mask_value);
|
||||
atomic_wait_engine::wait(&m_data, sizeof(T), old, static_cast<u64>(timeout), mask);
|
||||
atomic_wait_engine::wait(&m_data, sizeof(T) | (static_cast<u8>(Flags) << 8), old, static_cast<u64>(timeout), mask);
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user