diff --git a/Utilities/Thread.cpp b/Utilities/Thread.cpp index e372d1e924..1475256ce7 100644 --- a/Utilities/Thread.cpp +++ b/Utilities/Thread.cpp @@ -1591,8 +1591,34 @@ bool handle_access_violation(u32 addr, bool is_writing, x64_context* context) // TODO: allow recovering from a page fault as a feature of PS3 virtual memory } +// Detect leaf function +static bool is_leaf_function(u64 rip) +{ +#ifdef _WIN32 + DWORD64 base = 0; + if (const auto rtf = RtlLookupFunctionEntry(rip, &base, nullptr)) + { + // Access UNWIND_INFO structure + const auto uw = (u8*)(base + rtf->UnwindData); + + // Leaf function has zero epilog size and no unwind codes + return uw[0] == 1 && uw[1] == 0 && uw[2] == 0 && uw[3] == 0; + } + + // No unwind info implies leaf function + return true; +#else + // TODO + return false; +#endif +} + +static thread_local u64 s_tls_ret_pos = 0; +static thread_local u64 s_tls_ret_addr = 0; + [[noreturn]] static void throw_access_violation(const char* cause, u64 addr) { + if (s_tls_ret_pos) *(u64*)s_tls_ret_pos = s_tls_ret_addr; // Fix stack vm::throw_access_violation(addr, cause); std::abort(); } @@ -1605,7 +1631,8 @@ static void prepare_throw_access_violation(x64_context* context, const char* cau ARG2(context) = address; // Push the exception address as a "return" address (throw_access_violation() shall not return) - *--(u64*&)(RSP(context)) = RIP(context); + s_tls_ret_addr = RIP(context); + s_tls_ret_pos = is_leaf_function(s_tls_ret_addr) ? 0 : RSP(context) -= sizeof(u64); RIP(context) = (u64)std::addressof(throw_access_violation); } @@ -1793,10 +1820,10 @@ struct thread_ctrl::internal #ifdef _WIN32 DWORD thread_id = 0; - x64_context _context{}; #endif - x64_context* thread_ctx{}; + x64_context _context{}; + x64_context* const thread_ctx = &this->_context; atomic_t interrupt{}; // Interrupt function }; @@ -2047,10 +2074,41 @@ void thread_ctrl::set_exception(std::exception_ptr e) static void _handle_interrupt(x64_context* ctx) { - g_tls_internal->thread_ctx = ctx; + // Copy context for further use (TODO: is it safe on all platforms?) + g_tls_internal->_context = *ctx; thread_ctrl::handle_interrupt(); } +static thread_local void(*s_tls_handler)() = nullptr; + +[[noreturn]] static void execute_interrupt_handler() +{ + // Fix stack for throwing + if (s_tls_ret_pos) s_tls_ret_addr = std::exchange(*(u64*)s_tls_ret_pos, s_tls_ret_addr); + + // Run either throwing or returning interrupt handler + s_tls_handler(); + + // Restore context in the case of return + const auto ctx = g_tls_internal->thread_ctx; + + if (s_tls_ret_pos) + { + RIP(ctx) = std::exchange(*(u64*)s_tls_ret_pos, s_tls_ret_addr); + RSP(ctx) += sizeof(u64); + } + else + { + RIP(ctx) = s_tls_ret_addr; + } + +#ifdef _WIN32 + RtlRestoreContext(ctx, nullptr); +#else + ::setcontext(ctx); +#endif +} + void thread_ctrl::handle_interrupt() { const auto _this = g_tls_this_thread; @@ -2075,9 +2133,16 @@ void thread_ctrl::handle_interrupt() _this->unlock(); g_tls_internal->icv.notify_one(); +#ifdef _WIN32 // Install function call - *--(u64*&)(RSP(ctx)) = RIP(ctx); - RIP(ctx) = (u64)handler; + s_tls_ret_addr = RIP(ctx); + s_tls_ret_pos = is_leaf_function(s_tls_ret_addr) ? 0 : RSP(ctx) -= sizeof(u64); + s_tls_handler = handler; + RIP(ctx) = (u64)execute_interrupt_handler; +#else + // Call handler directly (TODO: install function call preserving red zone) + return handler(); +#endif } } else @@ -2097,8 +2162,7 @@ void thread_ctrl::interrupt(void(*handler)()) VERIFY(m_data->interrupt.compare_and_swap_test(nullptr, handler)); // TODO: multiple interrupts #ifdef _WIN32 - const auto ctx = &m_data->_context; - m_data->thread_ctx = ctx; + const auto ctx = m_data->thread_ctx; const HANDLE nt = OpenThread(THREAD_ALL_ACCESS, FALSE, m_data->thread_id); VERIFY(nt);