diff --git a/Source/Core/Common/TraversalClient.cpp b/Source/Core/Common/TraversalClient.cpp index 071bccbc1c..e34d2cbf2f 100644 --- a/Source/Core/Common/TraversalClient.cpp +++ b/Source/Core/Common/TraversalClient.cpp @@ -48,32 +48,32 @@ void TraversalClient::ReconnectToServer() } m_ServerAddress.port = m_port; - m_State = Connecting; + m_State = State::Connecting; TraversalPacket hello = {}; - hello.type = TraversalPacketHelloFromClient; + hello.type = TraversalPacketType::HelloFromClient; hello.helloFromClient.protoVersion = TraversalProtoVersion; SendTraversalPacket(hello); if (m_Client) m_Client->OnTraversalStateChanged(); } -static ENetAddress MakeENetAddress(TraversalInetAddress* address) +static ENetAddress MakeENetAddress(const TraversalInetAddress& address) { - ENetAddress eaddr; - if (address->isIPV6) + ENetAddress eaddr{}; + if (address.isIPV6) { eaddr.port = 0; // no support yet :( } else { - eaddr.host = address->address[0]; - eaddr.port = ntohs(address->port); + eaddr.host = address.address[0]; + eaddr.port = ntohs(address.port); } return eaddr; } -void TraversalClient::ConnectToClient(const std::string& host) +void TraversalClient::ConnectToClient(std::string_view host) { if (host.size() > sizeof(TraversalHostId)) { @@ -81,8 +81,8 @@ void TraversalClient::ConnectToClient(const std::string& host) return; } TraversalPacket packet = {}; - packet.type = TraversalPacketConnectPlease; - memcpy(packet.connectPlease.hostId.data(), host.c_str(), host.size()); + packet.type = TraversalPacketType::ConnectPlease; + memcpy(packet.connectPlease.hostId.data(), host.data(), host.size()); m_ConnectRequestId = SendTraversalPacket(packet); m_PendingConnect = true; } @@ -129,7 +129,7 @@ void TraversalClient::HandleServerPacket(TraversalPacket* packet) u8 ok = 1; switch (packet->type) { - case TraversalPacketAck: + case TraversalPacketType::Ack: if (!packet->ack.ok) { OnFailure(FailureReason::ServerForgotAboutUs); @@ -144,8 +144,8 @@ void TraversalClient::HandleServerPacket(TraversalPacket* packet) } } break; - case TraversalPacketHelloFromServer: - if (m_State != Connecting) + case TraversalPacketType::HelloFromServer: + if (!IsConnecting()) break; if (!packet->helloFromServer.ok) { @@ -153,14 +153,14 @@ void TraversalClient::HandleServerPacket(TraversalPacket* packet) break; } m_HostId = packet->helloFromServer.yourHostId; - m_State = Connected; + m_State = State::Connected; if (m_Client) m_Client->OnTraversalStateChanged(); break; - case TraversalPacketPleaseSendPacket: + case TraversalPacketType::PleaseSendPacket: { // security is overrated. - ENetAddress addr = MakeENetAddress(&packet->pleaseSendPacket.address); + ENetAddress addr = MakeENetAddress(packet->pleaseSendPacket.address); if (addr.port != 0) { char message[] = "Hello from Dolphin Netplay..."; @@ -176,8 +176,8 @@ void TraversalClient::HandleServerPacket(TraversalPacket* packet) } break; } - case TraversalPacketConnectReady: - case TraversalPacketConnectFailed: + case TraversalPacketType::ConnectReady: + case TraversalPacketType::ConnectFailed: { if (!m_PendingConnect || packet->connectReady.requestId != m_ConnectRequestId) break; @@ -187,8 +187,8 @@ void TraversalClient::HandleServerPacket(TraversalPacket* packet) if (!m_Client) break; - if (packet->type == TraversalPacketConnectReady) - m_Client->OnConnectReady(MakeENetAddress(&packet->connectReady.address)); + if (packet->type == TraversalPacketType::ConnectReady) + m_Client->OnConnectReady(MakeENetAddress(packet->connectReady.address)); else m_Client->OnConnectFailed(packet->connectFailed.reason); break; @@ -197,10 +197,10 @@ void TraversalClient::HandleServerPacket(TraversalPacket* packet) WARN_LOG_FMT(NETPLAY, "Received unknown packet with type {}", packet->type); break; } - if (packet->type != TraversalPacketAck) + if (packet->type != TraversalPacketType::Ack) { TraversalPacket ack = {}; - ack.type = TraversalPacketAck; + ack.type = TraversalPacketType::Ack; ack.requestId = packet->requestId; ack.ack.ok = ok; @@ -214,7 +214,7 @@ void TraversalClient::HandleServerPacket(TraversalPacket* packet) void TraversalClient::OnFailure(FailureReason reason) { - m_State = Failure; + m_State = State::Failure; m_FailureReason = reason; if (m_Client) @@ -257,10 +257,10 @@ void TraversalClient::HandleResends() void TraversalClient::HandlePing() { const u32 now = enet_time_get(); - if (m_State == Connected && now - m_PingTime >= 500) + if (IsConnected() && now - m_PingTime >= 500) { TraversalPacket ping = {}; - ping.type = TraversalPacketPing; + ping.type = TraversalPacketType::Ping; ping.ping.hostId = m_HostId; SendTraversalPacket(ping); m_PingTime = now; diff --git a/Source/Core/Common/TraversalClient.h b/Source/Core/Common/TraversalClient.h index 78e1d0c705..e621a8169a 100644 --- a/Source/Core/Common/TraversalClient.h +++ b/Source/Core/Common/TraversalClient.h @@ -6,6 +6,7 @@ #include #include #include +#include #include @@ -19,17 +20,17 @@ public: virtual ~TraversalClientClient() = default; virtual void OnTraversalStateChanged() = 0; virtual void OnConnectReady(ENetAddress addr) = 0; - virtual void OnConnectFailed(u8 reason) = 0; + virtual void OnConnectFailed(TraversalConnectFailedReason reason) = 0; }; class TraversalClient { public: - enum State + enum class State { Connecting, Connected, - Failure + Failure, }; enum class FailureReason { @@ -46,8 +47,12 @@ public: State GetState() const; FailureReason GetFailureReason() const; + bool HasFailed() const { return m_State == State::Failure; } + bool IsConnecting() const { return m_State == State::Connecting; } + bool IsConnected() const { return m_State == State::Connected; } + void Reset(); - void ConnectToClient(const std::string& host); + void ConnectToClient(std::string_view host); void ReconnectToServer(); void Update(); void HandleResends(); diff --git a/Source/Core/Common/TraversalProto.h b/Source/Core/Common/TraversalProto.h index 828d782626..3d8e34de73 100644 --- a/Source/Core/Common/TraversalProto.h +++ b/Source/Core/Common/TraversalProto.h @@ -1,44 +1,43 @@ // This file is public domain, in case it's useful to anyone. -comex #pragma once + #include +#include #include "Common/CommonTypes.h" -#define NETPLAY_CODE_SIZE 8 -typedef std::array TraversalHostId; -typedef u64 TraversalRequestId; +constexpr size_t NETPLAY_CODE_SIZE = 8; +using TraversalHostId = std::array; +using TraversalRequestId = u64; -enum TraversalPacketType +enum class TraversalPacketType : u8 { // [*->*] - TraversalPacketAck = 0, + Ack = 0, // [c->s] - TraversalPacketPing = 1, + Ping = 1, // [c->s] - TraversalPacketHelloFromClient = 2, + HelloFromClient = 2, // [s->c] - TraversalPacketHelloFromServer = 3, + HelloFromServer = 3, // [c->s] When connecting, first the client asks the central server... - TraversalPacketConnectPlease = 4, + ConnectPlease = 4, // [s->c] ...who asks the game host to send a UDP packet to the // client... (an ack implies success) - TraversalPacketPleaseSendPacket = 5, + PleaseSendPacket = 5, // [s->c] ...which the central server relays back to the client. - TraversalPacketConnectReady = 6, + ConnectReady = 6, // [s->c] Alternately, the server might not have heard of this host. - TraversalPacketConnectFailed = 7 + ConnectFailed = 7, }; -enum -{ - TraversalProtoVersion = 0 -}; +constexpr u8 TraversalProtoVersion = 0; -enum TraversalConnectFailedReason +enum class TraversalConnectFailedReason : u8 { - TraversalConnectFailedClientDidntRespond = 0, - TraversalConnectFailedClientFailure, - TraversalConnectFailedNoSuchClient + ClientDidntRespond = 0, + ClientFailure, + NoSuchClient, }; #pragma pack(push, 1) @@ -50,7 +49,7 @@ struct TraversalInetAddress }; struct TraversalPacket { - u8 type; + TraversalPacketType type; TraversalRequestId requestId; union { @@ -88,7 +87,7 @@ struct TraversalPacket struct { TraversalRequestId requestId; - u8 reason; + TraversalConnectFailedReason reason; } connectFailed; }; }; diff --git a/Source/Core/Common/TraversalServer.cpp b/Source/Core/Common/TraversalServer.cpp index 80d32d52fa..655e591916 100644 --- a/Source/Core/Common/TraversalServer.cpp +++ b/Source/Core/Common/TraversalServer.cpp @@ -185,8 +185,9 @@ static const char* SenderName(sockaddr_in6* addr) static void TrySend(const void* buffer, size_t size, sockaddr_in6* addr) { #if DEBUG - printf("-> %d %llu %s\n", ((TraversalPacket*)buffer)->type, - (long long)((TraversalPacket*)buffer)->requestId, SenderName(addr)); + const auto* packet = static_cast(buffer); + printf("-> %d %llu %s\n", static_cast(packet->type), + static_cast(packet->requestId), SenderName(addr)); #endif if ((size_t)sendto(sock, buffer, size, 0, (sockaddr*)addr, sizeof(*addr)) != size) { @@ -227,7 +228,7 @@ static void ResendPackets() { if (info->tries >= NUMBER_OF_TRIES) { - if (info->packet.type == TraversalPacketPleaseSendPacket) + if (info->packet.type == TraversalPacketType::PleaseSendPacket) { todoFailures.push_back(std::make_pair(info->packet.pleaseSendPacket.address, info->misc)); } @@ -245,21 +246,22 @@ static void ResendPackets() for (const auto& p : todoFailures) { TraversalPacket* fail = AllocPacket(MakeSinAddr(p.first)); - fail->type = TraversalPacketConnectFailed; + fail->type = TraversalPacketType::ConnectFailed; fail->connectFailed.requestId = p.second; - fail->connectFailed.reason = TraversalConnectFailedClientDidntRespond; + fail->connectFailed.reason = TraversalConnectFailedReason::ClientDidntRespond; } } static void HandlePacket(TraversalPacket* packet, sockaddr_in6* addr) { #if DEBUG - printf("<- %d %llu %s\n", packet->type, (long long)packet->requestId, SenderName(addr)); + printf("<- %d %llu %s\n", static_cast(packet->type), + static_cast(packet->requestId), SenderName(addr)); #endif bool packetOk = true; switch (packet->type) { - case TraversalPacketAck: + case TraversalPacketType::Ack: { auto it = outgoingPackets.find(packet->requestId); if (it == outgoingPackets.end()) @@ -267,37 +269,37 @@ static void HandlePacket(TraversalPacket* packet, sockaddr_in6* addr) OutgoingPacketInfo* info = &it->second; - if (info->packet.type == TraversalPacketPleaseSendPacket) + if (info->packet.type == TraversalPacketType::PleaseSendPacket) { TraversalPacket* ready = AllocPacket(MakeSinAddr(info->packet.pleaseSendPacket.address)); if (packet->ack.ok) { - ready->type = TraversalPacketConnectReady; + ready->type = TraversalPacketType::ConnectReady; ready->connectReady.requestId = info->misc; ready->connectReady.address = MakeInetAddress(info->dest); } else { - ready->type = TraversalPacketConnectFailed; + ready->type = TraversalPacketType::ConnectFailed; ready->connectFailed.requestId = info->misc; - ready->connectFailed.reason = TraversalConnectFailedClientFailure; + ready->connectFailed.reason = TraversalConnectFailedReason::ClientFailure; } } outgoingPackets.erase(it); break; } - case TraversalPacketPing: + case TraversalPacketType::Ping: { auto r = EvictFind(connectedClients, packet->ping.hostId, true); packetOk = r.found; break; } - case TraversalPacketHelloFromClient: + case TraversalPacketType::HelloFromClient: { u8 ok = packet->helloFromClient.protoVersion <= TraversalProtoVersion; TraversalPacket* reply = AllocPacket(*addr); - reply->type = TraversalPacketHelloFromServer; + reply->type = TraversalPacketType::HelloFromServer; reply->helloFromServer.ok = ok; if (ok) { @@ -323,32 +325,34 @@ static void HandlePacket(TraversalPacket* packet, sockaddr_in6* addr) } break; } - case TraversalPacketConnectPlease: + case TraversalPacketType::ConnectPlease: { TraversalHostId& hostId = packet->connectPlease.hostId; auto r = EvictFind(connectedClients, hostId); if (!r.found) { TraversalPacket* reply = AllocPacket(*addr); - reply->type = TraversalPacketConnectFailed; + reply->type = TraversalPacketType::ConnectFailed; reply->connectFailed.requestId = packet->requestId; - reply->connectFailed.reason = TraversalConnectFailedNoSuchClient; + reply->connectFailed.reason = TraversalConnectFailedReason::NoSuchClient; } else { TraversalPacket* please = AllocPacket(MakeSinAddr(*r.value), packet->requestId); - please->type = TraversalPacketPleaseSendPacket; + please->type = TraversalPacketType::PleaseSendPacket; please->pleaseSendPacket.address = MakeInetAddress(*addr); } break; } default: - fprintf(stderr, "received unknown packet type %d from %s\n", packet->type, SenderName(addr)); + fprintf(stderr, "received unknown packet type %d from %s\n", static_cast(packet->type), + SenderName(addr)); + break; } - if (packet->type != TraversalPacketAck) + if (packet->type != TraversalPacketType::Ack) { TraversalPacket ack = {}; - ack.type = TraversalPacketAck; + ack.type = TraversalPacketType::Ack; ack.requestId = packet->requestId; ack.ack.ok = packetOk; TrySend(&ack, sizeof(ack), addr); diff --git a/Source/Core/Core/NetPlayClient.cpp b/Source/Core/Core/NetPlayClient.cpp index a505564c74..f144f77c52 100644 --- a/Source/Core/Core/NetPlayClient.cpp +++ b/Source/Core/Core/NetPlayClient.cpp @@ -174,7 +174,7 @@ NetPlayClient::NetPlayClient(const std::string& address, const u16 port, NetPlay m_traversal_client = g_TraversalClient.get(); // If we were disconnected in the background, reconnect. - if (m_traversal_client->GetState() == TraversalClient::Failure) + if (m_traversal_client->HasFailed()) m_traversal_client->ReconnectToServer(); m_traversal_client->m_Client = this; m_host_spec = address; @@ -1755,12 +1755,13 @@ void NetPlayClient::OnTraversalStateChanged() const TraversalClient::State state = m_traversal_client->GetState(); if (m_connection_state == ConnectionState::WaitingForTraversalClientConnection && - state == TraversalClient::Connected) + state == TraversalClient::State::Connected) { m_connection_state = ConnectionState::WaitingForTraversalClientConnectReady; m_traversal_client->ConnectToClient(m_host_spec); } - else if (m_connection_state != ConnectionState::Failure && state == TraversalClient::Failure) + else if (m_connection_state != ConnectionState::Failure && + state == TraversalClient::State::Failure) { Disconnect(); m_dialog->OnTraversalError(m_traversal_client->GetFailureReason()); @@ -1779,19 +1780,19 @@ void NetPlayClient::OnConnectReady(ENetAddress addr) } // called from ---NETPLAY--- thread -void NetPlayClient::OnConnectFailed(u8 reason) +void NetPlayClient::OnConnectFailed(TraversalConnectFailedReason reason) { m_connecting = false; m_connection_state = ConnectionState::Failure; switch (reason) { - case TraversalConnectFailedClientDidntRespond: + case TraversalConnectFailedReason::ClientDidntRespond: PanicAlertFmtT("Traversal server timed out connecting to the host"); break; - case TraversalConnectFailedClientFailure: + case TraversalConnectFailedReason::ClientFailure: PanicAlertFmtT("Server rejected traversal attempt"); break; - case TraversalConnectFailedNoSuchClient: + case TraversalConnectFailedReason::NoSuchClient: PanicAlertFmtT("Invalid host"); break; default: diff --git a/Source/Core/Core/NetPlayClient.h b/Source/Core/Core/NetPlayClient.h index cced924dbb..c0ab68f7f0 100644 --- a/Source/Core/Core/NetPlayClient.h +++ b/Source/Core/Core/NetPlayClient.h @@ -125,7 +125,7 @@ public: void OnTraversalStateChanged() override; void OnConnectReady(ENetAddress addr) override; - void OnConnectFailed(u8 reason) override; + void OnConnectFailed(TraversalConnectFailedReason reason) override; bool IsFirstInGamePad(int ingame_pad) const; int NumLocalPads() const; diff --git a/Source/Core/Core/NetPlayServer.cpp b/Source/Core/Core/NetPlayServer.cpp index 503aaecdca..eafddd36e6 100644 --- a/Source/Core/Core/NetPlayServer.cpp +++ b/Source/Core/Core/NetPlayServer.cpp @@ -126,7 +126,7 @@ NetPlayServer::NetPlayServer(const u16 port, const bool forward_port, NetPlayUI* m_server = g_MainNetHost.get(); - if (g_TraversalClient->GetState() == TraversalClient::Failure) + if (g_TraversalClient->HasFailed()) g_TraversalClient->ReconnectToServer(); } else @@ -190,7 +190,7 @@ void NetPlayServer::SetupIndex() if (m_traversal_client) { - if (m_traversal_client->GetState() != TraversalClient::Connected) + if (!m_traversal_client->IsConnected()) return; session.server_id = std::string(g_TraversalClient->GetHostID().data(), 8); @@ -1149,7 +1149,7 @@ void NetPlayServer::OnTraversalStateChanged() if (!m_dialog) return; - if (state == TraversalClient::Failure) + if (state == TraversalClient::State::Failure) m_dialog->OnTraversalError(m_traversal_client->GetFailureReason()); m_dialog->OnTraversalStateChanged(state); diff --git a/Source/Core/Core/NetPlayServer.h b/Source/Core/Core/NetPlayServer.h index 16034074af..67819be9bf 100644 --- a/Source/Core/Core/NetPlayServer.h +++ b/Source/Core/Core/NetPlayServer.h @@ -134,7 +134,7 @@ private: void OnTraversalStateChanged() override; void OnConnectReady(ENetAddress) override {} - void OnConnectFailed(u8) override {} + void OnConnectFailed(TraversalConnectFailedReason) override {} void UpdatePadMapping(); void UpdateWiimoteMapping(); std::vector> GetInterfaceListInternal() const; diff --git a/Source/Core/DolphinQt/NetPlay/NetPlayDialog.cpp b/Source/Core/DolphinQt/NetPlay/NetPlayDialog.cpp index c8b19678b3..03626dbb0e 100644 --- a/Source/Core/DolphinQt/NetPlay/NetPlayDialog.cpp +++ b/Source/Core/DolphinQt/NetPlay/NetPlayDialog.cpp @@ -711,11 +711,11 @@ void NetPlayDialog::UpdateGUI() { switch (g_TraversalClient->GetState()) { - case TraversalClient::Connecting: + case TraversalClient::State::Connecting: m_hostcode_label->setText(tr("...")); m_hostcode_action_button->setEnabled(false); break; - case TraversalClient::Connected: + case TraversalClient::State::Connected: { const auto host_id = g_TraversalClient->GetHostID(); m_hostcode_label->setText( @@ -725,7 +725,7 @@ void NetPlayDialog::UpdateGUI() m_is_copy_button_retry = false; break; } - case TraversalClient::Failure: + case TraversalClient::State::Failure: m_hostcode_label->setText(tr("Error")); m_hostcode_action_button->setText(tr("Retry")); m_hostcode_action_button->setEnabled(true); @@ -1003,6 +1003,7 @@ void NetPlayDialog::OnTraversalStateChanged(TraversalClient::State state) case TraversalClient::State::Connected: case TraversalClient::State::Failure: UpdateDiscordPresence(); + break; default: break; }