From 4d0c835df3d2cc784e0e81045631d1d3678608f5 Mon Sep 17 00:00:00 2001 From: Elad <18193363+elad335@users.noreply.github.com> Date: Mon, 23 Dec 2024 09:17:07 +0200 Subject: [PATCH] util/shared_ptr.hpp: STX pointers library fixes --- rpcs3/util/shared_ptr.hpp | 113 ++++++++++++++++++++++++++------------ 1 file changed, 78 insertions(+), 35 deletions(-) diff --git a/rpcs3/util/shared_ptr.hpp b/rpcs3/util/shared_ptr.hpp index 6e07db68fd..7685cfda91 100644 --- a/rpcs3/util/shared_ptr.hpp +++ b/rpcs3/util/shared_ptr.hpp @@ -154,6 +154,7 @@ namespace stx if (m_ptr) [[likely]] { const auto o = d(); + ensure(o->refs == 1); o->destroy.load()(o); m_ptr = nullptr; } @@ -437,11 +438,15 @@ namespace stx // Set to null void reset() noexcept { - const auto o = d(); - - if (m_ptr && !--o->refs) [[unlikely]] + if (m_ptr) [[unlikely]] { - o->destroy(o); + const auto o = d(); + + if (!--o->refs) + { + o->destroy(o); + } + m_ptr = nullptr; } } @@ -571,7 +576,7 @@ namespace stx { mutable atomic_t m_val{0}; - static shared_counter* d(uptr val) + static shared_counter* d(uptr val) noexcept { return std::launder(reinterpret_cast((val >> c_ref_size) - sizeof(shared_counter))); } @@ -581,9 +586,32 @@ namespace stx return d(m_val); } + static uptr to_val(const volatile std::remove_extent_t* ptr) noexcept + { + return (reinterpret_cast(ptr) << c_ref_size); + } + + static std::remove_extent_t* ptr_to(uptr val) noexcept + { + return reinterpret_cast*>(val >> c_ref_size); + } + template friend class atomic_ptr; + // Helper struct to check if a type is an instance of a template + template class Template> + struct is_instance_of : std::false_type {}; + + template class Template> + struct is_instance_of, Template> : std::true_type {}; + + template + static constexpr bool is_stx_pointer = false + || is_instance_of, shared_ptr>::value + || is_instance_of, single_ptr>::value + || is_instance_of, atomic_ptr>::value; + public: using element_type = std::remove_extent_t; @@ -592,11 +620,14 @@ namespace stx constexpr atomic_ptr() noexcept = default; // Optimized value construct - template requires (!(sizeof...(Args) == 1 && (std::is_same_v, shared_type> || ...)) && std::is_constructible_v) + template requires (true + && sizeof...(Args) != 0 + && !(sizeof...(Args) == 1 && (is_stx_pointer || ...)) + && std::is_constructible_v) explicit atomic_ptr(Args&&... args) noexcept { shared_type r = make_single(std::forward(args)...); - m_val = reinterpret_cast(std::exchange(r.m_ptr, nullptr)) << c_ref_size; + m_val.raw() = to_val(std::exchange(r.m_ptr, nullptr)); d()->refs.raw() += c_ref_mask; } @@ -604,32 +635,38 @@ namespace stx atomic_ptr(const shared_ptr& r) noexcept { // Obtain a ref + as many refs as an atomic_ptr can additionally reference - m_val = reinterpret_cast(r.m_ptr) << c_ref_size; - if (m_val) - d()->refs += c_ref_mask + 1; + if (uptr rval = to_val(r.m_ptr)) + { + m_val.raw() = rval; + d(rval)->refs += c_ref_mask + 1; + } } template requires same_ptr_implicit_v atomic_ptr(shared_ptr&& r) noexcept { - m_val = reinterpret_cast(r.m_ptr) << c_ref_size; - r.m_ptr = nullptr; + if (uptr rval = to_val(r.m_ptr)) + { + m_val.raw() = rval; + d(rval)->refs += c_ref_mask; + } - if (m_val) - d()->refs += c_ref_mask; + r.m_ptr = nullptr; } template requires same_ptr_implicit_v atomic_ptr(single_ptr&& r) noexcept { - m_val = reinterpret_cast(r.m_ptr) << c_ref_size; - r.m_ptr = nullptr; + if (uptr rval = to_val(r.m_ptr)) + { + m_val.raw() = rval; + d(rval)->refs += c_ref_mask; + } - if (m_val) - d()->refs += c_ref_mask; + r.m_ptr = nullptr; } - ~atomic_ptr() + ~atomic_ptr() noexcept { const uptr v = m_val.raw(); @@ -645,13 +682,13 @@ namespace stx } // Optimized value assignment - atomic_ptr& operator=(std::remove_cv_t value) noexcept + atomic_ptr& operator=(std::remove_cv_t value) noexcept requires (!is_stx_pointer) { shared_type r = make_single(std::move(value)); r.d()->refs.raw() += c_ref_mask; atomic_ptr old; - old.m_val.raw() = m_val.exchange(reinterpret_cast(std::exchange(r.m_ptr, nullptr)) << c_ref_size); + old.m_val.raw() = m_val.exchange(to_val(std::exchange(r.m_ptr, nullptr))); return *this; } @@ -704,7 +741,7 @@ namespace stx } // Set referenced pointer - r.m_ptr = std::launder(reinterpret_cast(prev >> c_ref_size)); + r.m_ptr = std::launder(ptr_to(prev)); r.d()->refs++; // Dereference if still the same pointer @@ -749,7 +786,7 @@ namespace stx // Set fake unreferenced pointer if (did_ref) { - r.m_ptr = std::launder(reinterpret_cast(prev >> c_ref_size)); + r.m_ptr = std::launder(ptr_to(prev)); } // Result temp storage @@ -805,14 +842,17 @@ namespace stx // Create an object from variadic args // If a type needs shared_type to be constructed, std::reference_wrapper can be used - template requires (!(sizeof...(Args) == 1 && (std::is_same_v, shared_type> || ...)) && std::is_constructible_v) + template requires (true + && sizeof...(Args) != 0 + && !(sizeof...(Args) == 1 && (is_stx_pointer || ...)) + && std::is_constructible_v) void store(Args&&... args) noexcept { shared_type r = make_single(std::forward(args)...); r.d()->refs.raw() += c_ref_mask; atomic_ptr old; - old.m_val.raw() = m_val.exchange(reinterpret_cast(std::exchange(r.m_ptr, nullptr)) << c_ref_size); + old.m_val.raw() = m_val.exchange(to_val(std::exchange(r.m_ptr, nullptr))); } void store(shared_type value) noexcept @@ -824,20 +864,23 @@ namespace stx } atomic_ptr old; - old.m_val.raw() = m_val.exchange(reinterpret_cast(std::exchange(value.m_ptr, nullptr)) << c_ref_size); + old.m_val.raw() = m_val.exchange(to_val(std::exchange(value.m_ptr, nullptr))); } - template requires (!(sizeof...(Args) == 1 && (std::is_same_v, shared_type> || ...)) && std::is_constructible_v) + template requires (true + && sizeof...(Args) != 0 + && !(sizeof...(Args) == 1 && (is_stx_pointer || ...)) + && std::is_constructible_v) [[nodiscard]] shared_type exchange(Args&&... args) noexcept { shared_type r = make_single(std::forward(args)...); r.d()->refs.raw() += c_ref_mask; atomic_ptr old; - old.m_val.raw() += m_val.exchange(reinterpret_cast(r.m_ptr) << c_ref_size); + old.m_val.raw() = m_val.exchange(to_val(r.m_ptr)); old.m_val.raw() += 1; - r.m_ptr = std::launder(reinterpret_cast(old.m_val >> c_ref_size)); + r.m_ptr = std::launder(ptr_to(old.m_val)); return r; } @@ -850,10 +893,10 @@ namespace stx } atomic_ptr old; - old.m_val.raw() += m_val.exchange(reinterpret_cast(value.m_ptr) << c_ref_size); + old.m_val.raw() = m_val.exchange(to_val(value.m_ptr)); old.m_val.raw() += 1; - value.m_ptr = std::launder(reinterpret_cast(old.m_val >> c_ref_size)); + value.m_ptr = std::launder(ptr_to(old.m_val)); return value; } @@ -898,10 +941,10 @@ namespace stx } atomic_ptr old_exch; - old_exch.m_val.raw() = reinterpret_cast(std::exchange(exch.m_ptr, nullptr)) << c_ref_size; + old_exch.m_val.raw() = to_val(std::exchange(exch.m_ptr, nullptr)); // Set to reset old cmp_and_old value - old.m_val.raw() = (reinterpret_cast(cmp_and_old.m_ptr) << c_ref_size) | c_ref_mask; + old.m_val.raw() = to_val(cmp_and_old.m_ptr) | c_ref_mask; if (!_val) { @@ -909,7 +952,7 @@ namespace stx } // Set referenced pointer - cmp_and_old.m_ptr = std::launder(reinterpret_cast(_val >> c_ref_size)); + cmp_and_old.m_ptr = std::launder(ptr_to(_val)); cmp_and_old.d()->refs++; // Dereference if still the same pointer @@ -977,7 +1020,7 @@ namespace stx } // Failure (return references) - old.m_val.raw() = reinterpret_cast(std::exchange(exch.m_ptr, nullptr)) << c_ref_size; + old.m_val.raw() = to_val(std::exchange(exch.m_ptr, nullptr)); return false; }