From 7fc240858e2ebdab127e4512040495b298b8a133 Mon Sep 17 00:00:00 2001 From: JohnCorby Date: Thu, 13 Jan 2022 20:52:04 -0800 Subject: [PATCH] message handler shenanigans --- Mirror/Runtime/MessagePacking.cs | 8 +-- Mirror/Runtime/NetworkClient.cs | 12 +++- Mirror/Runtime/NetworkConnection.cs | 2 +- Mirror/Runtime/NetworkDiagnostics.cs | 4 +- Mirror/Runtime/NetworkServer.cs | 12 +++- QSB/Messaging/QSBMessage.cs | 8 +-- QSB/Messaging/QSBMessageManager.cs | 104 +++++++++++++-------------- 7 files changed, 82 insertions(+), 68 deletions(-) diff --git a/Mirror/Runtime/MessagePacking.cs b/Mirror/Runtime/MessagePacking.cs index 4c5254b2..00ae6d8b 100644 --- a/Mirror/Runtime/MessagePacking.cs +++ b/Mirror/Runtime/MessagePacking.cs @@ -22,7 +22,7 @@ namespace Mirror - HeaderSize - Batcher.HeaderSize; - public static ushort GetId() where T : struct, NetworkMessage + public static ushort GetId() where T : NetworkMessage { // paul: 16 bits is enough to avoid collisions // - keeps the message size small @@ -34,7 +34,7 @@ namespace Mirror // -> NetworkWriter passed as arg so that we can use .ToArraySegment // and do an allocation free send before recycling it. public static void Pack(T message, NetworkWriter writer) - where T : struct, NetworkMessage + where T : NetworkMessage { ushort msgType = GetId(); writer.WriteUShort(msgType); @@ -64,7 +64,7 @@ namespace Mirror // version for handlers with channelId internal static NetworkMessageDelegate WrapHandler(Action handler, bool requireAuthentication) - where T : struct, NetworkMessage + where T : NetworkMessage where C : NetworkConnection => (conn, reader, channelId) => { @@ -129,7 +129,7 @@ namespace Mirror // TODO obsolete this some day to always use the channelId version. // all handlers in this version are wrapped with 1 extra action. internal static NetworkMessageDelegate WrapHandler(Action handler, bool requireAuthentication) - where T : struct, NetworkMessage + where T : NetworkMessage where C : NetworkConnection { // wrap action as channelId version, call original diff --git a/Mirror/Runtime/NetworkClient.cs b/Mirror/Runtime/NetworkClient.cs index 5494b9bd..bae93763 100644 --- a/Mirror/Runtime/NetworkClient.cs +++ b/Mirror/Runtime/NetworkClient.cs @@ -424,7 +424,7 @@ namespace Mirror // send //////////////////////////////////////////////////////////////// /// Send a NetworkMessage to the server over the given channel. public static void Send(T message, int channelId = Channels.Reliable) - where T : struct, NetworkMessage + where T : NetworkMessage { if (connection != null) { @@ -454,6 +454,16 @@ namespace Mirror handlers[msgType] = MessagePacking.WrapHandler((Action) HandlerWrapped, requireAuthentication); } + public static void RegisterHandlerSafe(Action handler) + where T : NetworkMessage + { + var msgType = MessagePacking.GetId(); + if (!handlers.ContainsKey(msgType)) + { + handlers[msgType] = MessagePacking.WrapHandler((NetworkConnection _, T msg) => handler(msg), true); + } + } + /// Replace a handler for a particular message type. Should require authentication by default. // RegisterHandler throws a warning (as it should) if a handler is assigned twice // Use of ReplaceHandler makes it clear the user intended to replace the handler diff --git a/Mirror/Runtime/NetworkConnection.cs b/Mirror/Runtime/NetworkConnection.cs index 3a2d6a57..77b77ddd 100644 --- a/Mirror/Runtime/NetworkConnection.cs +++ b/Mirror/Runtime/NetworkConnection.cs @@ -125,7 +125,7 @@ namespace Mirror // Send stage one: NetworkMessage /// Send a NetworkMessage to this connection over the given channel. public void Send(T message, int channelId = Channels.Reliable) - where T : struct, NetworkMessage + where T : NetworkMessage { using (PooledNetworkWriter writer = NetworkWriterPool.GetWriter()) { diff --git a/Mirror/Runtime/NetworkDiagnostics.cs b/Mirror/Runtime/NetworkDiagnostics.cs index 1cdc96fd..64fcea46 100644 --- a/Mirror/Runtime/NetworkDiagnostics.cs +++ b/Mirror/Runtime/NetworkDiagnostics.cs @@ -41,7 +41,7 @@ namespace Mirror } internal static void OnSend(T message, int channel, int bytes, int count) - where T : struct, NetworkMessage + where T : NetworkMessage { if (count > 0 && OutMessageEvent != null) { @@ -51,7 +51,7 @@ namespace Mirror } internal static void OnReceive(T message, int channel, int bytes) - where T : struct, NetworkMessage + where T : NetworkMessage { if (InMessageEvent != null) { diff --git a/Mirror/Runtime/NetworkServer.cs b/Mirror/Runtime/NetworkServer.cs index ec0d20d3..04927e76 100644 --- a/Mirror/Runtime/NetworkServer.cs +++ b/Mirror/Runtime/NetworkServer.cs @@ -261,7 +261,7 @@ namespace Mirror // send //////////////////////////////////////////////////////////////// /// Send a message to all clients, even those that haven't joined the world yet (non ready) public static void SendToAll(T message, int channelId = Channels.Reliable, bool sendToReadyOnly = false) - where T : struct, NetworkMessage + where T : NetworkMessage { if (!active) { @@ -616,6 +616,16 @@ namespace Mirror handlers[msgType] = MessagePacking.WrapHandler(handler, requireAuthentication); } + public static void RegisterHandlerSafe(Action handler) + where T : NetworkMessage + { + var msgType = MessagePacking.GetId(); + if (!handlers.ContainsKey(msgType)) + { + handlers[msgType] = MessagePacking.WrapHandler((NetworkConnection _, T msg) => handler(msg), true); + } + } + /// Replace a handler for message type T. Most should require authentication. public static void ReplaceHandler(Action handler, bool requireAuthentication = true) where T : struct, NetworkMessage diff --git a/QSB/Messaging/QSBMessage.cs b/QSB/Messaging/QSBMessage.cs index 39a7609b..864d56b2 100644 --- a/QSB/Messaging/QSBMessage.cs +++ b/QSB/Messaging/QSBMessage.cs @@ -1,10 +1,10 @@ -using QuantumUNET.Messages; +using Mirror; using QuantumUNET.Transport; using System; namespace QSB.Messaging { - public abstract class QSBMessage : QMessageBase + public abstract class QSBMessage : NetworkMessage { /// /// set automatically by Send @@ -19,7 +19,7 @@ namespace QSB.Messaging /// /// call the base method when overriding /// - public override void Serialize(QNetworkWriter writer) + public virtual void Serialize(QNetworkWriter writer) { writer.Write(From); writer.Write(To); @@ -31,7 +31,7 @@ namespace QSB.Messaging /// note: no constructor is called before this, /// so fields won't be initialized. /// - public override void Deserialize(QNetworkReader reader) + public virtual void Deserialize(QNetworkReader reader) { From = reader.ReadUInt32(); To = reader.ReadUInt32(); diff --git a/QSB/Messaging/QSBMessageManager.cs b/QSB/Messaging/QSBMessageManager.cs index 75c8650f..938053fa 100644 --- a/QSB/Messaging/QSBMessageManager.cs +++ b/QSB/Messaging/QSBMessageManager.cs @@ -1,4 +1,5 @@ -using OWML.Common; +using Mirror; +using OWML.Common; using QSB.ClientServerStateSync; using QSB.ClientServerStateSync.Messages; using QSB.Player; @@ -12,6 +13,7 @@ using QuantumUNET.Messages; using System; using System.Collections.Generic; using System.Linq; +using System.Reflection; using System.Runtime.CompilerServices; using System.Runtime.Serialization; @@ -21,100 +23,95 @@ namespace QSB.Messaging { #region inner workings - private static readonly Dictionary _msgTypeToType = new(); - private static readonly Dictionary _typeToMsgType = new(); + private static readonly Type[] _types; static QSBMessageManager() { - var types = typeof(QSBMessageRaw).GetDerivedTypes() + _types = typeof(QSBMessageRaw).GetDerivedTypes() .Concat(typeof(QSBMessage).GetDerivedTypes()) .ToArray(); - for (var i = 0; i < types.Length; i++) + foreach (var type in _types) { - var msgType = (short)(QMsgType.Highest + 1 + i); - if (msgType >= short.MaxValue) - { - DebugLog.ToConsole("Hey, uh, maybe don't create 32,767 events? You really should never be seeing this." + - "If you are, something has either gone terrible wrong or QSB somehow needs more events that classes in Outer Wilds." + - "In either case, I guess something has gone terribly wrong...", MessageType.Error); - } - - _msgTypeToType.Add(msgType, types[i]); - _typeToMsgType.Add(types[i], msgType); - // call static constructor of message if needed - RuntimeHelpers.RunClassConstructor(types[i].TypeHandle); + RuntimeHelpers.RunClassConstructor(type.TypeHandle); } } public static void Init() { - foreach (var (msgType, type) in _msgTypeToType) + DebugLog.DebugWrite("REGISTERING MESSAGES"); + + var NetworkServer_RegisterHandlerSafe = typeof(NetworkServer).GetMethod(nameof(NetworkServer.RegisterHandlerSafe)); + var NetworkClient_RegisterHandlerSafe = typeof(NetworkClient).GetMethod(nameof(NetworkClient.RegisterHandlerSafe)); + var OnServerReceiveRaw = typeof(QSBMessageManager).GetMethod(nameof(QSBMessageManager.OnServerReceiveRaw)); + var OnClientReceiveRaw = typeof(QSBMessageManager).GetMethod(nameof(QSBMessageManager.OnClientReceiveRaw)); + var OnServerReceive = typeof(QSBMessageManager).GetMethod(nameof(QSBMessageManager.OnServerReceive)); + var OnClientReceive = typeof(QSBMessageManager).GetMethod(nameof(QSBMessageManager.OnClientReceive)); + + foreach (var type in _types) { + MethodInfo OnServerReceive2; + MethodInfo OnClientReceive2; + if (typeof(QSBMessageRaw).IsAssignableFrom(type)) { - QNetworkServer.RegisterHandlerSafe(msgType, OnServerReceiveRaw); - QNetworkManager.singleton.client.RegisterHandlerSafe(msgType, OnClientReceiveRaw); + OnServerReceive2 = OnServerReceiveRaw; + OnClientReceive2 = OnClientReceiveRaw; } else { - QNetworkServer.RegisterHandlerSafe(msgType, OnServerReceive); - QNetworkManager.singleton.client.RegisterHandlerSafe(msgType, OnClientReceive); + OnServerReceive2 = OnServerReceive; + OnClientReceive2 = OnClientReceive; } + + var serverHandler = OnServerReceive2.MakeGenericMethod(type).CreateDelegate(typeof(Action<>)); + var clientHandler = OnClientReceive2.MakeGenericMethod(type).CreateDelegate(typeof(Action<>)); + DebugLog.DebugWrite($"server handler = {serverHandler}"); + DebugLog.DebugWrite($"client handler = {clientHandler}"); + NetworkServer_RegisterHandlerSafe.MakeGenericMethod(type).Invoke(null, new object[] { serverHandler }); + NetworkClient_RegisterHandlerSafe.MakeGenericMethod(type).Invoke(null, new object[] { clientHandler }); } } - private static void OnServerReceiveRaw(QNetworkMessage netMsg) + private static void OnServerReceiveRaw(T msg) + where T : QSBMessageRaw { - var msgType = netMsg.MsgType; - var msg = (QSBMessageRaw)FormatterServices.GetUninitializedObject(_msgTypeToType[msgType]); - netMsg.ReadMessage(msg); - - QNetworkServer.SendToAll(msgType, msg); + NetworkServer.SendToAll(msg); } - private static void OnClientReceiveRaw(QNetworkMessage netMsg) + private static void OnClientReceiveRaw(T msg) + where T : QSBMessageRaw { - var msgType = netMsg.MsgType; - var msg = (QSBMessageRaw)FormatterServices.GetUninitializedObject(_msgTypeToType[msgType]); - netMsg.ReadMessage(msg); - msg.OnReceive(); } - private static void OnServerReceive(QNetworkMessage netMsg) + private static void OnServerReceive(T msg) + where T : QSBMessage { - var msgType = netMsg.MsgType; - var msg = (QSBMessage)FormatterServices.GetUninitializedObject(_msgTypeToType[msgType]); - netMsg.ReadMessage(msg); - if (msg.To == uint.MaxValue) { - QNetworkServer.SendToAll(msgType, msg); + NetworkServer.SendToAll(msg); } else if (msg.To == 0) { - QNetworkServer.localConnection.Send(msgType, msg); + NetworkServer.localConnection.Send(msg); } else { - var conn = QNetworkServer.connections.FirstOrDefault(x => msg.To == x.GetPlayerId()); + var conn = NetworkServer.connections.Values.FirstOrDefault(x => msg.To == x.identity.netId); if (conn == null) { DebugLog.ToConsole($"SendTo unknown player! id: {msg.To}, message: {msg}", MessageType.Error); return; } - conn.Send(msgType, msg); + conn.Send(msg); } } - private static void OnClientReceive(QNetworkMessage netMsg) + private static void OnClientReceive(T msg) + where T : QSBMessage { - var msgType = netMsg.MsgType; - var msg = (QSBMessage)FormatterServices.GetUninitializedObject(_msgTypeToType[msgType]); - netMsg.ReadMessage(msg); - if (PlayerTransformSync.LocalInstance == null) { DebugLog.ToConsole($"Warning - Tried to handle message {msg} before local player was established.", MessageType.Warning); @@ -162,15 +159,13 @@ namespace QSB.Messaging public static void SendRaw(this M msg) where M : QSBMessageRaw { - var msgType = _typeToMsgType[typeof(M)]; - QNetworkManager.singleton.client.Send(msgType, msg); + NetworkClient.Send(msg); } - public static void ServerSendRaw(this M msg, QNetworkConnection conn) + public static void ServerSendRaw(this M msg, NetworkConnectionToClient conn) where M : QSBMessageRaw { - var msgType = _typeToMsgType[typeof(M)]; - conn.Send(msgType, msg); + conn.Send(msg); } public static void Send(this M msg) @@ -183,8 +178,7 @@ namespace QSB.Messaging } msg.From = QSBPlayerManager.LocalPlayerId; - var msgType = _typeToMsgType[typeof(M)]; - QNetworkManager.singleton.client.Send(msgType, msg); + NetworkClient.Send(msg); } public static void SendMessage(this T worldObject, M msg) @@ -200,7 +194,7 @@ namespace QSB.Messaging /// message that will be sent to every client.
/// no checks are performed on the message. it is just sent and received. /// - public abstract class QSBMessageRaw : QMessageBase + public abstract class QSBMessageRaw : NetworkMessage { public abstract void OnReceive(); public override string ToString() => GetType().Name;