atomic.cpp: implement some atomic wait operations.

Instead of only waiting while the value is bitwise equal to some expected value,
a wait can now use other comparisons, such as less-than, greater-than, or even set-bit count.
This is a draft and untested; hopefully it doesn't break anything.
This commit is contained in:
Nekotekina 2020-11-11 10:25:22 +03:00
parent 829a697c39
commit 7cd1e767be
2 changed files with 207 additions and 44 deletions

View File

@ -20,6 +20,7 @@
#include <random>
#include "asm.hpp"
#include "endian.hpp"
// Total number of entries, should be a power of 2.
static constexpr std::size_t s_hashtable_size = 1u << 18;
@ -30,13 +31,18 @@ static thread_local bool(*s_tls_wait_cb)(const void* data) = [](const void*){ re
// Callback for notification functions for optimizations
static thread_local void(*s_tls_notify_cb)(const void* data, u64 progress) = [](const void*, u64){};
// Test whether the given modifier flag bit is set in a packed op value.
// op values carry op_flag modifier bits in their high bits, so this is a
// plain bitwise intersection test on the underlying representation.
static inline bool operator &(atomic_wait::op lhs, atomic_wait::op_flag rhs)
{
	const u8 packed = static_cast<u8>(lhs);
	const u8 flag_bits = static_cast<u8>(rhs);
	return (packed & flag_bits) != 0;
}
// Compare data in memory against the old value using the comparison op encoded in the upper bits of the size argument; returns true while the wait condition holds
template <bool CheckCb = true>
static NEVER_INLINE bool
#ifdef _WIN32
__vectorcall
#endif
ptr_cmp(const void* data, u32 size, __m128i old128, __m128i mask128, atomic_wait::info* ext = nullptr)
ptr_cmp(const void* data, u32 _size, __m128i old128, __m128i mask128, atomic_wait::info* ext = nullptr)
{
if constexpr (CheckCb)
{
@ -46,32 +52,138 @@ ptr_cmp(const void* data, u32 size, __m128i old128, __m128i mask128, atomic_wait
}
}
const u64 old_value = _mm_cvtsi128_si64(old128);
const u64 mask = _mm_cvtsi128_si64(mask128);
using atomic_wait::op;
using atomic_wait::op_flag;
const u8 size = static_cast<u8>(_size);
const op flag{static_cast<u8>(_size >> 8)};
bool result = false;
switch (size)
if (size <= 8)
{
case 1: result = (reinterpret_cast<const atomic_t<u8>*>(data)->load() & mask) == (old_value & mask); break;
case 2: result = (reinterpret_cast<const atomic_t<u16>*>(data)->load() & mask) == (old_value & mask); break;
case 4: result = (reinterpret_cast<const atomic_t<u32>*>(data)->load() & mask) == (old_value & mask); break;
case 8: result = (reinterpret_cast<const atomic_t<u64>*>(data)->load() & mask) == (old_value & mask); break;
case 16:
{
const auto v0 = std::bit_cast<__m128i>(atomic_storage<u128>::load(*reinterpret_cast<const u128*>(data)));
const auto v1 = _mm_xor_si128(v0, old128);
const auto v2 = _mm_and_si128(v1, mask128);
const auto v3 = _mm_packs_epi16(v2, v2);
u64 new_value = 0;
u64 old_value = _mm_cvtsi128_si64(old128);
u64 mask = _mm_cvtsi128_si64(mask128) & (UINT64_MAX >> ((64 - size * 8) & 63));
result = _mm_cvtsi128_si64(v3) == 0;
break;
switch (size)
{
case 1: new_value = reinterpret_cast<const atomic_t<u8>*>(data)->load(); break;
case 2: new_value = reinterpret_cast<const atomic_t<u16>*>(data)->load(); break;
case 4: new_value = reinterpret_cast<const atomic_t<u32>*>(data)->load(); break;
case 8: new_value = reinterpret_cast<const atomic_t<u64>*>(data)->load(); break;
default:
{
fprintf(stderr, "ptr_cmp(): bad size (arg=0x%x)" HERE "\n", _size);
std::abort();
}
}
if (flag & op_flag::bit_not)
{
new_value = ~new_value;
}
if (!mask) [[unlikely]]
{
new_value = 0;
old_value = 0;
}
else
{
if (flag & op_flag::byteswap)
{
switch (size)
{
case 2:
{
new_value = stx::se_storage<u16>::swap(static_cast<u16>(new_value));
old_value = stx::se_storage<u16>::swap(static_cast<u16>(old_value));
mask = stx::se_storage<u16>::swap(static_cast<u16>(mask));
break;
}
case 4:
{
new_value = stx::se_storage<u32>::swap(static_cast<u32>(new_value));
old_value = stx::se_storage<u32>::swap(static_cast<u32>(old_value));
mask = stx::se_storage<u32>::swap(static_cast<u32>(mask));
break;
}
case 8:
{
new_value = stx::se_storage<u64>::swap(new_value);
old_value = stx::se_storage<u64>::swap(old_value);
mask = stx::se_storage<u64>::swap(mask);
}
default:
{
break;
}
}
}
// Make most significant bit sign bit
const auto shv = std::countl_zero(mask);
new_value &= mask;
old_value &= mask;
new_value <<= shv;
old_value <<= shv;
}
s64 news = new_value;
s64 olds = old_value;
u64 newa = news < 0 ? (0ull - new_value) : new_value;
u64 olda = olds < 0 ? (0ull - old_value) : old_value;
switch (op{static_cast<u8>(static_cast<u8>(flag) & 0xf)})
{
case op::eq: result = old_value == new_value; break;
case op::slt: result = olds < news; break;
case op::sgt: result = olds > news; break;
case op::ult: result = old_value < new_value; break;
case op::ugt: result = old_value > new_value; break;
case op::alt: result = olda < newa; break;
case op::agt: result = olda > newa; break;
case op::pop:
{
// Count is taken from least significant byte and ignores some flags
const u64 count = _mm_cvtsi128_si64(old128) & 0xff;
u64 bitc = new_value;
bitc = (bitc & 0xaaaaaaaaaaaaaaaa) / 2 + (bitc & 0x5555555555555555);
bitc = (bitc & 0xcccccccccccccccc) / 4 + (bitc & 0x3333333333333333);
bitc = (bitc & 0xf0f0f0f0f0f0f0f0) / 16 + (bitc & 0x0f0f0f0f0f0f0f0f);
bitc = (bitc & 0xff00ff00ff00ff00) / 256 + (bitc & 0x00ff00ff00ff00ff);
bitc = ((bitc & 0xffff0000ffff0000) >> 16) + (bitc & 0x0000ffff0000ffff);
bitc = (bitc >> 32) + bitc;
result = count < bitc;
break;
}
default:
{
fmt::raw_error("ptr_cmp(): unrecognized atomic wait operation.");
}
}
}
default:
else if (size == 16 && (flag == op::eq || flag == (op::eq | op_flag::inverse)))
{
fprintf(stderr, "ptr_cmp(): bad size (size=%u)" HERE "\n", size);
std::abort();
u128 new_value = atomic_storage<u128>::load(*reinterpret_cast<const u128*>(data));
u128 old_value = std::bit_cast<u128>(old128);
u128 mask = std::bit_cast<u128>(mask128);
// TODO
result = !((old_value ^ new_value) & mask);
}
else if (size == 16)
{
fmt::raw_error("ptr_cmp(): no alternative operations are supported for 16-byte atomic wait yet.");
}
if (flag & op_flag::inverse)
{
result = !result;
}
// Check other wait variables if provided
@ -101,16 +213,8 @@ __vectorcall
#endif
cmp_mask(u32 size1, __m128i mask1, __m128i val1, u32 size2, __m128i mask2, __m128i val2)
{
// In force wake up, one of the size arguments is zero (obsolete)
const u32 size = std::min(size1, size2);
if (!size) [[unlikely]]
{
return 2;
}
// Compare only masks, new value is not available in this mode
if ((size1 | size2) == umax)
if (size1 == umax)
{
// Simple mask overlap
const auto v0 = _mm_and_si128(mask1, mask2);
@ -121,6 +225,17 @@ cmp_mask(u32 size1, __m128i mask1, __m128i val1, u32 size2, __m128i mask2, __m12
// Generate masked value inequality bits
const auto v0 = _mm_and_si128(_mm_and_si128(mask1, mask2), _mm_xor_si128(val1, val2));
using atomic_wait::op;
using atomic_wait::op_flag;
const u8 size = std::min<u8>(static_cast<u8>(size2), static_cast<u8>(size1));
const op flag{static_cast<u8>(size2 >> 8)};
if (flag != op::eq && flag != (op::eq | op_flag::inverse))
{
fmt::raw_error("cmp_mask(): no operations are supported for notification with forced value yet.");
}
if (size <= 8)
{
// Generate sized mask
@ -128,14 +243,14 @@ cmp_mask(u32 size1, __m128i mask1, __m128i val1, u32 size2, __m128i mask2, __m12
if (!(_mm_cvtsi128_si64(v0) & mask))
{
return 0;
return flag & op_flag::inverse ? 2 : 0;
}
}
else if (size == 16)
{
if (!_mm_cvtsi128_si64(_mm_packs_epi16(v0, v0)))
{
return 0;
return flag & op_flag::inverse ? 2 : 0;
}
}
else
@ -145,7 +260,7 @@ cmp_mask(u32 size1, __m128i mask1, __m128i val1, u32 size2, __m128i mask2, __m12
}
// Use force wake-up
return 2;
return flag & op_flag::inverse ? 0 : 2;
}
static atomic_t<u64> s_min_tsc{0};
@ -227,7 +342,8 @@ namespace atomic_wait
// Temporarily reduced unique tsc stamp to 48 bits to make space for refs (TODO)
u64 tsc0 : 48 = 0;
u64 link : 16 = 0;
u16 size{};
u8 size{};
u8 flag{};
atomic_t<u16> refs{};
atomic_t<u32> sync{};
@ -262,6 +378,7 @@ namespace atomic_wait
tsc0 = 0;
link = 0;
size = 0;
flag = 0;
sync = 0;
#ifdef USE_STD
@ -868,7 +985,8 @@ atomic_wait_engine::wait(const void* data, u32 size, __m128i old_value, u64 time
// Store some info for notifiers (some may be unused)
cond->link = 0;
cond->size = static_cast<u16>(size);
cond->size = static_cast<u8>(size);
cond->flag = static_cast<u8>(size >> 8);
cond->mask = mask;
cond->oldv = old_value;
cond->tsc0 = stamp0;
@ -877,7 +995,8 @@ atomic_wait_engine::wait(const void* data, u32 size, __m128i old_value, u64 time
{
// Extensions point to original cond_id, copy remaining info
cond_ext[i]->link = cond_id;
cond_ext[i]->size = static_cast<u16>(ext[i].size);
cond_ext[i]->size = static_cast<u8>(ext[i].size);
cond_ext[i]->flag = static_cast<u8>(ext[i].size >> 8);
cond_ext[i]->mask = ext[i].mask;
cond_ext[i]->oldv = ext[i].old;
cond_ext[i]->tsc0 = stamp0;
@ -1058,7 +1177,7 @@ alert_sema(u32 cond_id, const void* data, u64 info, u32 size, __m128i mask, __m1
u32 cmp_res = 0;
if (cond->sync && (!size ? (!info || cond->tid == info) : (cond->ptr == data && ((cmp_res = cmp_mask(size, mask, new_value, cond->size, cond->mask, cond->oldv))))))
if (cond->sync && (!size ? (!info || cond->tid == info) : (cond->ptr == data && ((cmp_res = cmp_mask(size, mask, new_value, cond->size | (cond->flag << 8), cond->mask, cond->oldv))))))
{
// Redirect if necessary
const auto _old = cond;

View File

@ -14,14 +14,56 @@ enum class atomic_wait_timeout : u64
inf = 0xffffffffffffffff,
};
// Unused externally
// Various extensions for atomic_t::wait
namespace atomic_wait
{
// Max number of simultaneous atomic variables to wait on (can be extended if really necessary)
constexpr uint max_list = 8;
struct root_info;
struct sema_handle;
// Base comparison operation for an atomic wait: the thread keeps waiting
// while the chosen predicate holds between the loaded value and the given
// old value. Only the low 4 bits are used for the base op; modifier bits
// (see op_flag) live in the high bits of the same byte.
enum class op : u8
{
	eq  = 0, // Wait while value is bitwise equal to
	slt = 1, // Wait while signed value is less than
	sgt = 2, // Wait while signed value is greater than
	ult = 3, // Wait while unsigned value is less than
	ugt = 4, // Wait while unsigned value is greater than
	alt = 5, // Wait while absolute value is less than
	agt = 6, // Wait while absolute value is greater than
	pop = 7, // Wait while set bit count of the value is less than

	__max
};

static_assert(static_cast<u8>(op::__max) == 8);
// Modifier bits that can be OR'ed into an atomic_wait::op value.
// They deliberately start at bit 4: the low 4 bits of the byte hold the
// base op itself (ptr_cmp extracts the base op with "& 0xf"), so op and
// op_flag can share one byte packed above the size argument.
enum class op_flag : u8
{
	inverse = 1 << 4, // Perform inverse operation (negate the result)
	bit_not = 1 << 5, // Perform bitwise NOT on loaded value before operation
	byteswap = 1 << 6, // Perform byteswap on both arguments and masks when applicable
};
// NOTE(review): op_ne is empty (no flag bits set), which makes it behave
// exactly like plain op::eq. If the name is meant to express "wait while
// NOT equal", it presumably should be op_flag::inverse — confirm intent
// before relying on this constant.
constexpr op_flag op_ne = {};

// Endianness helpers: select byteswap only when the native byte order
// differs from the desired interpretation, so ordered comparisons (slt,
// ult, ...) see the value in the intended byte order.
constexpr op_flag op_be = std::endian::native == std::endian::little ? op_flag::byteswap : op_flag{0};
constexpr op_flag op_le = std::endian::native == std::endian::little ? op_flag{0} : op_flag::byteswap;
// Combine a base op with modifier flags (or flags with each other).
// All three overloads intentionally yield an op: the combined byte is what
// gets packed into the upper bits of the size argument passed to the wait
// engine (sizeof(T) | (flags << 8)).
constexpr op operator |(op_flag lhs, op_flag rhs)
{
	return static_cast<op>(static_cast<u8>(lhs) | static_cast<u8>(rhs));
}

constexpr op operator |(op_flag lhs, op rhs)
{
	return static_cast<op>(static_cast<u8>(lhs) | static_cast<u8>(rhs));
}

constexpr op operator |(op lhs, op_flag rhs)
{
	return static_cast<op>(static_cast<u8>(lhs) | static_cast<u8>(rhs));
}
struct info
{
const void* data;
@ -114,24 +156,24 @@ namespace atomic_wait
return *this;
}
template <uint Index, typename T2, std::size_t Align, typename U>
template <uint Index, op Flags = op::eq, typename T2, std::size_t Align, typename U>
constexpr void set(atomic_t<T2, Align>& var, U value)
{
static_assert(Index < Max);
m_info[Index].data = &var.raw();
m_info[Index].size = sizeof(T2);
m_info[Index].size = sizeof(T2) | (static_cast<u8>(Flags) << 8);
m_info[Index].template set_value<T2>(value);
m_info[Index].mask = _mm_set1_epi64x(-1);
}
template <uint Index, typename T2, std::size_t Align, typename U, typename V>
template <uint Index, op Flags = op::eq, typename T2, std::size_t Align, typename U, typename V>
constexpr void set(atomic_t<T2, Align>& var, U value, V mask)
{
static_assert(Index < Max);
m_info[Index].data = &var.raw();
m_info[Index].size = sizeof(T2);
m_info[Index].size = sizeof(T2) | (static_cast<u8>(Flags) << 8);
m_info[Index].template set_value<T2>(value);
m_info[Index].template set_mask<T2>(mask);
}
@ -1387,34 +1429,36 @@ public:
}
// Timeout is discouraged
template <atomic_wait::op Flags = atomic_wait::op::eq>
void wait(type old_value, atomic_wait_timeout timeout = atomic_wait_timeout::inf) const noexcept
{
if constexpr (sizeof(T) <= 8)
{
const __m128i old = _mm_cvtsi64_si128(std::bit_cast<get_uint_t<sizeof(T)>>(old_value));
atomic_wait_engine::wait(&m_data, sizeof(T), old, static_cast<u64>(timeout), _mm_set1_epi64x(-1));
atomic_wait_engine::wait(&m_data, sizeof(T) | (static_cast<u8>(Flags) << 8), old, static_cast<u64>(timeout), _mm_set1_epi64x(-1));
}
else if constexpr (sizeof(T) == 16)
{
const __m128i old = std::bit_cast<__m128i>(old_value);
atomic_wait_engine::wait(&m_data, sizeof(T), old, static_cast<u64>(timeout), _mm_set1_epi64x(-1));
atomic_wait_engine::wait(&m_data, sizeof(T) | (static_cast<u8>(Flags) << 8), old, static_cast<u64>(timeout), _mm_set1_epi64x(-1));
}
}
// Overload with mask (only selected bits are checked), timeout is discouraged
template <atomic_wait::op Flags = atomic_wait::op::eq>
void wait(type old_value, type mask_value, atomic_wait_timeout timeout = atomic_wait_timeout::inf)
{
if constexpr (sizeof(T) <= 8)
{
const __m128i old = _mm_cvtsi64_si128(std::bit_cast<get_uint_t<sizeof(T)>>(old_value));
const __m128i mask = _mm_cvtsi64_si128(std::bit_cast<get_uint_t<sizeof(T)>>(mask_value));
atomic_wait_engine::wait(&m_data, sizeof(T), old, static_cast<u64>(timeout), mask);
atomic_wait_engine::wait(&m_data, sizeof(T) | (static_cast<u8>(Flags) << 8), old, static_cast<u64>(timeout), mask);
}
else if constexpr (sizeof(T) == 16)
{
const __m128i old = std::bit_cast<__m128i>(old_value);
const __m128i mask = std::bit_cast<__m128i>(mask_value);
atomic_wait_engine::wait(&m_data, sizeof(T), old, static_cast<u64>(timeout), mask);
atomic_wait_engine::wait(&m_data, sizeof(T) | (static_cast<u8>(Flags) << 8), old, static_cast<u64>(timeout), mask);
}
}