sys_net: Implement sys_net_abort

This commit is contained in:
Eladash 2022-06-21 12:38:44 +03:00 committed by Ivan
parent 07ebbb6c84
commit c0369b2e10
7 changed files with 156 additions and 22 deletions

View File

@ -370,7 +370,7 @@ error_code sys_net_bnet_accept(ppu_thread& ppu, s32 s, vm::ptr<sys_net_sockaddr>
return true;
}
sock.poll_queue(ppu.id, lv2_socket::poll_t::read, [&](bs_t<lv2_socket::poll_t> events) -> bool
sock.poll_queue(idm::get_unlocked<named_thread<ppu_thread>>(ppu.id), lv2_socket::poll_t::read, [&](bs_t<lv2_socket::poll_t> events) -> bool
{
if (events & lv2_socket::poll_t::read)
{
@ -417,7 +417,7 @@ error_code sys_net_bnet_accept(ppu_thread& ppu, s32 s, vm::ptr<sys_net_sockaddr>
if (ppu.gpr[3] == static_cast<u64>(-SYS_NET_EINTR))
{
return -sys_net_error{SYS_NET_EINTR};
return -SYS_NET_EINTR;
}
if (result < 0)
@ -527,7 +527,7 @@ error_code sys_net_bnet_connect(ppu_thread& ppu, s32 s, vm::ptr<sys_net_sockaddr
return true;
}
sock.poll_queue(ppu.id, lv2_socket::poll_t::write, [&](bs_t<lv2_socket::poll_t> events) -> bool
sock.poll_queue(idm::get_unlocked<named_thread<ppu_thread>>(ppu.id), lv2_socket::poll_t::write, [&](bs_t<lv2_socket::poll_t> events) -> bool
{
if (events & lv2_socket::poll_t::write)
{
@ -802,7 +802,7 @@ error_code sys_net_bnet_recvfrom(ppu_thread& ppu, s32 s, vm::ptr<void> buf, u32
return true;
}
sock.poll_queue(ppu.id, lv2_socket::poll_t::read, [&](bs_t<lv2_socket::poll_t> events) -> bool
sock.poll_queue(idm::get_unlocked<named_thread<ppu_thread>>(ppu.id), lv2_socket::poll_t::read, [&](bs_t<lv2_socket::poll_t> events) -> bool
{
if (events & lv2_socket::poll_t::read)
{
@ -934,7 +934,7 @@ error_code sys_net_bnet_sendto(ppu_thread& ppu, s32 s, vm::cptr<void> buf, u32 l
}
// Enable write event
sock.poll_queue(ppu.id, lv2_socket::poll_t::write, [&](bs_t<lv2_socket::poll_t> events) -> bool
sock.poll_queue(idm::get_unlocked<named_thread<ppu_thread>>(ppu.id), lv2_socket::poll_t::write, [&](bs_t<lv2_socket::poll_t> events) -> bool
{
if (events & lv2_socket::poll_t::write)
{
@ -1257,7 +1257,7 @@ error_code sys_net_bnet_poll(ppu_thread& ppu, vm::ptr<sys_net_pollfd> fds, s32 n
// if (fds_buf[i].events & SYS_NET_POLLPRI) // Unimplemented
// selected += lv2_socket::poll::error;
sock->poll_queue(ppu.id, selected, [sock, selected, &fds_buf, i, &signaled, &ppu](bs_t<lv2_socket::poll_t> events)
sock->poll_queue(idm::get_unlocked<named_thread<ppu_thread>>(ppu.id), selected, [sock, selected, &fds_buf, i, &signaled, &ppu](bs_t<lv2_socket::poll_t> events)
{
if (events & selected)
{
@ -1282,6 +1282,8 @@ error_code sys_net_bnet_poll(ppu_thread& ppu, vm::ptr<sys_net_pollfd> fds, s32 n
lv2_obj::sleep(ppu, timeout);
}
bool has_timedout = false;
while (auto state = ppu.state.fetch_sub(cpu_flag::signal))
{
if (is_stopped(state))
@ -1311,7 +1313,7 @@ error_code sys_net_bnet_poll(ppu_thread& ppu, vm::ptr<sys_net_pollfd> fds, s32 n
break;
}
network_clear_queue(ppu);
has_timedout = network_clear_queue(ppu);
break;
}
}
@ -1322,6 +1324,12 @@ error_code sys_net_bnet_poll(ppu_thread& ppu, vm::ptr<sys_net_pollfd> fds, s32 n
}
std::memcpy(fds.get_ptr(), fds_buf.data(), nfds * sizeof(fds[0]));
if (!has_timedout && !signaled)
{
return -SYS_NET_EINTR;
}
return not_an_error(signaled);
}
@ -1481,7 +1489,7 @@ error_code sys_net_bnet_select(ppu_thread& ppu, s32 nfds, vm::ptr<sys_net_fd_set
sock->set_connecting(connecting[i]);
#endif
sock->poll_queue(ppu.id, selected, [sock, selected, i, &rread, &rwrite, &rexcept, &signaled, &ppu](bs_t<lv2_socket::poll_t> events)
sock->poll_queue(idm::get_unlocked<named_thread<ppu_thread>>(ppu.id), selected, [sock, selected, i, &rread, &rwrite, &rexcept, &signaled, &ppu](bs_t<lv2_socket::poll_t> events)
{
if (events & selected)
{
@ -1514,6 +1522,8 @@ error_code sys_net_bnet_select(ppu_thread& ppu, s32 nfds, vm::ptr<sys_net_fd_set
return -SYS_NET_EINVAL;
}
bool has_timedout = false;
while (auto state = ppu.state.fetch_sub(cpu_flag::signal))
{
if (is_stopped(state))
@ -1543,7 +1553,7 @@ error_code sys_net_bnet_select(ppu_thread& ppu, s32 nfds, vm::ptr<sys_net_fd_set
break;
}
network_clear_queue(ppu);
has_timedout = network_clear_queue(ppu);
break;
}
}
@ -1560,6 +1570,11 @@ error_code sys_net_bnet_select(ppu_thread& ppu, s32 nfds, vm::ptr<sys_net_fd_set
if (exceptfds)
*exceptfds = rexcept;
if (!has_timedout && !signaled)
{
return -SYS_NET_EINTR;
}
return not_an_error(signaled);
}
@ -1595,11 +1610,112 @@ error_code _sys_net_write_dump(ppu_thread& ppu, s32 id, vm::cptr<void> buf, s32
return CELL_OK;
}
error_code lv2_socket::abort_socket(s32 flags)
{
decltype(queue) qcopy;
{
std::lock_guard lock(mutex);
if (queue.empty())
{
if (flags & SYS_NET_ABORT_STRICT_CHECK)
{
// Strict error checking: ENOENT if nothing happened
return -SYS_NET_ENOENT;
}
// TODO: Abort the subsequent function called on this socket (need to investigate correct behaviour)
return CELL_OK;
}
qcopy = std::move(queue);
events.store({});
}
for (auto& [ppu, _] : qcopy)
{
sys_net.warning("lv2_socket::abort_socket(): waking up \"%s\": (func: %s, r3=0x%x, r4=0x%x, r5=0x%x, r6=0x%x)", ppu->get_name(), ppu->current_function, ppu->gpr[3], ppu->gpr[4], ppu->gpr[5], ppu->gpr[6]);
ppu->gpr[3] = static_cast<u64>(-SYS_NET_EINTR);
lv2_obj::append(ppu.get());
}
lv2_obj::awake_all();
return CELL_OK;
}
error_code sys_net_abort(ppu_thread& ppu, s32 type, u64 arg, s32 flags)
{
ppu.state += cpu_flag::wait;
sys_net.todo("sys_net_abort(type=%d, arg=0x%x, flags=0x%x)", type, arg, flags);
enum abort_type : s32
{
_socket,
resolver,
type_2, // ??
type_3, // ??
all,
};
switch (type)
{
case _socket:
{
std::lock_guard nw_lock(g_fxo->get<network_context>().s_nw_mutex);
const auto sock = idm::get<lv2_socket>(static_cast<u32>(arg));
if (!sock)
{
return -SYS_NET_EBADF;
}
return sock->abort_socket(flags);
}
case all:
{
std::vector<u32> sockets;
idm::select<lv2_socket>([&](u32 id, lv2_socket&)
{
sockets.emplace_back(id);
});
s32 failed = 0;
for (u32 id : sockets)
{
const auto sock = idm::withdraw<lv2_socket>(id);
if (!sock)
{
failed++;
continue;
}
if (sock->get_queue_size())
sys_net.error("ABORT 4");
sock->close();
sys_net.success("lv2_socket::handle_abort(): Closed socket %d", id);
}
// Ensures the socket has no lingering copy from the network thread
g_fxo->get<network_context>().s_nw_mutex.lock_unlock();
return not_an_error(::narrow<s32>(sockets.size()) - failed);
}
case resolver:
case type_2:
case type_3:
{
break;
}
default: return -SYS_NET_EINVAL;
}
return CELL_OK;
}

View File

@ -181,6 +181,11 @@ enum
SYS_NET_POLLWRBAND = 0x0100,
};
enum lv2_socket_abort_flags : s32
{
SYS_NET_ABORT_STRICT_CHECK = 1,
};
// in_addr_t type prefixed with sys_net_
using sys_net_in_addr_t = u32;

View File

@ -63,21 +63,24 @@ void lv2_socket::set_poll_event(bs_t<lv2_socket::poll_t> event)
events += event;
}
void lv2_socket::poll_queue(u32 ppu_id, bs_t<lv2_socket::poll_t> event, std::function<bool(bs_t<lv2_socket::poll_t>)> poll_cb)
void lv2_socket::poll_queue(std::shared_ptr<ppu_thread> ppu, bs_t<lv2_socket::poll_t> event, std::function<bool(bs_t<lv2_socket::poll_t>)> poll_cb)
{
set_poll_event(event);
queue.emplace_back(ppu_id, poll_cb);
queue.emplace_back(std::move(ppu), poll_cb);
}
void lv2_socket::clear_queue(u32 ppu_id)
s32 lv2_socket::clear_queue(ppu_thread* ppu)
{
std::lock_guard lock(mutex);
s32 cleared = 0;
for (auto it = queue.begin(); it != queue.end();)
{
if (it->first == ppu_id)
if (it->first.get() == ppu)
{
it = queue.erase(it);
cleared++;
continue;
}
@ -88,6 +91,8 @@ void lv2_socket::clear_queue(u32 ppu_id)
{
events.store({});
}
return cleared;
}
void lv2_socket::handle_events(const pollfd& native_pfd, [[maybe_unused]] bool unset_connecting)

View File

@ -63,8 +63,8 @@ public:
void set_lv2_id(u32 id);
bs_t<poll_t> get_events() const;
void set_poll_event(bs_t<poll_t> event);
void poll_queue(u32 ppu_id, bs_t<poll_t> event, std::function<bool(bs_t<poll_t>)> poll_cb);
void clear_queue(u32 ppu_id);
void poll_queue(std::shared_ptr<ppu_thread> ppu, bs_t<poll_t> event, std::function<bool(bs_t<poll_t>)> poll_cb);
s32 clear_queue(ppu_thread*);
void handle_events(const pollfd& native_fd, bool unset_connecting = false);
lv2_socket_family get_family() const;
@ -101,6 +101,8 @@ public:
virtual s32 poll(sys_net_pollfd& sn_pfd, pollfd& native_pfd) = 0;
virtual std::tuple<bool, bool, bool> select(bs_t<poll_t> selected, pollfd& native_pfd) = 0;
error_code abort_socket(s32 flags);
public:
// IDM data
static const u32 id_base = 24;
@ -121,8 +123,9 @@ protected:
// Events selected for polling
atomic_bs_t<poll_t> events{};
// Event processing workload (pair of thread id and the processing function)
std::vector<std::pair<u32, std::function<bool(bs_t<poll_t>)>>> queue;
std::vector<std::pair<std::shared_ptr<ppu_thread>, std::function<bool(bs_t<poll_t>)>>> queue;
// Socket options value keepers
// Non-blocking IO option

View File

@ -176,7 +176,7 @@ std::optional<s32> lv2_socket_native::connect(const sys_net_sockaddr& addr)
#ifdef _WIN32
connecting = true;
#endif
this->poll_queue(u32{0}, lv2_socket::poll_t::write, [this](bs_t<lv2_socket::poll_t> events) -> bool
this->poll_queue(nullptr, lv2_socket::poll_t::write, [this](bs_t<lv2_socket::poll_t> events) -> bool
{
if (events & lv2_socket::poll_t::write)
{
@ -854,6 +854,7 @@ std::optional<std::tuple<s32, std::vector<u8>, sys_net_sockaddr>> lv2_socket_nat
#endif
const auto result = get_last_error(!so_nbio && (flags & SYS_NET_MSG_DONTWAIT) == 0);
if (result)
{
return {{-result, {}, {}}};

View File

@ -151,12 +151,16 @@ sys_net_sockaddr native_addr_to_sys_net_addr(const ::sockaddr_storage& native_ad
return native_addr;
}
void network_clear_queue(ppu_thread& ppu)
s32 network_clear_queue(ppu_thread& ppu)
{
s32 cleared = 0;
idm::select<lv2_socket>([&](u32, lv2_socket& sock)
{
sock.clear_queue(ppu.id);
});
{
cleared += sock.clear_queue(&ppu);
});
return cleared;
}
#ifdef _WIN32

View File

@ -21,7 +21,7 @@ int get_native_error();
sys_net_error get_last_error(bool is_blocking, int native_error = 0);
sys_net_sockaddr native_addr_to_sys_net_addr(const ::sockaddr_storage& native_addr);
::sockaddr_in sys_net_addr_to_native_addr(const sys_net_sockaddr& sn_addr);
void network_clear_queue(ppu_thread& ppu);
s32 network_clear_queue(ppu_thread& ppu);
#ifdef _WIN32
void windows_poll(pollfd* fds, unsigned long nfds, int timeout, bool* connecting);