message handler shenanigans

This commit is contained in:
JohnCorby 2022-01-13 20:52:04 -08:00
parent 63efda975b
commit 7fc240858e
7 changed files with 82 additions and 68 deletions

View File

@ -22,7 +22,7 @@ namespace Mirror
- HeaderSize
- Batcher.HeaderSize;
public static ushort GetId<T>() where T : struct, NetworkMessage
public static ushort GetId<T>() where T : NetworkMessage
{
// paul: 16 bits is enough to avoid collisions
// - keeps the message size small
@ -34,7 +34,7 @@ namespace Mirror
// -> NetworkWriter passed as arg so that we can use .ToArraySegment
// and do an allocation free send before recycling it.
public static void Pack<T>(T message, NetworkWriter writer)
where T : struct, NetworkMessage
where T : NetworkMessage
{
ushort msgType = GetId<T>();
writer.WriteUShort(msgType);
@ -64,7 +64,7 @@ namespace Mirror
// version for handlers with channelId
internal static NetworkMessageDelegate WrapHandler<T, C>(Action<C, T, int> handler, bool requireAuthentication)
where T : struct, NetworkMessage
where T : NetworkMessage
where C : NetworkConnection
=> (conn, reader, channelId) =>
{
@ -129,7 +129,7 @@ namespace Mirror
// TODO obsolete this some day to always use the channelId version.
// all handlers in this version are wrapped with 1 extra action.
internal static NetworkMessageDelegate WrapHandler<T, C>(Action<C, T> handler, bool requireAuthentication)
where T : struct, NetworkMessage
where T : NetworkMessage
where C : NetworkConnection
{
// wrap action as channelId version, call original

View File

@ -424,7 +424,7 @@ namespace Mirror
// send ////////////////////////////////////////////////////////////////
/// <summary>Send a NetworkMessage to the server over the given channel.</summary>
public static void Send<T>(T message, int channelId = Channels.Reliable)
where T : struct, NetworkMessage
where T : NetworkMessage
{
if (connection != null)
{
@ -454,6 +454,16 @@ namespace Mirror
handlers[msgType] = MessagePacking.WrapHandler((Action<NetworkConnection, T>) HandlerWrapped, requireAuthentication);
}
public static void RegisterHandlerSafe<T>(Action<T> handler)
where T : NetworkMessage
{
var msgType = MessagePacking.GetId<T>();
if (!handlers.ContainsKey(msgType))
{
handlers[msgType] = MessagePacking.WrapHandler((NetworkConnection _, T msg) => handler(msg), true);
}
}
/// <summary>Replace a handler for a particular message type. Should require authentication by default.</summary>
// RegisterHandler throws a warning (as it should) if a handler is assigned twice
// Use of ReplaceHandler makes it clear the user intended to replace the handler

View File

@ -125,7 +125,7 @@ namespace Mirror
// Send stage one: NetworkMessage<T>
/// <summary>Send a NetworkMessage to this connection over the given channel.</summary>
public void Send<T>(T message, int channelId = Channels.Reliable)
where T : struct, NetworkMessage
where T : NetworkMessage
{
using (PooledNetworkWriter writer = NetworkWriterPool.GetWriter())
{

View File

@ -41,7 +41,7 @@ namespace Mirror
}
internal static void OnSend<T>(T message, int channel, int bytes, int count)
where T : struct, NetworkMessage
where T : NetworkMessage
{
if (count > 0 && OutMessageEvent != null)
{
@ -51,7 +51,7 @@ namespace Mirror
}
internal static void OnReceive<T>(T message, int channel, int bytes)
where T : struct, NetworkMessage
where T : NetworkMessage
{
if (InMessageEvent != null)
{

View File

@ -261,7 +261,7 @@ namespace Mirror
// send ////////////////////////////////////////////////////////////////
/// <summary>Send a message to all clients, even those that haven't joined the world yet (non ready)</summary>
public static void SendToAll<T>(T message, int channelId = Channels.Reliable, bool sendToReadyOnly = false)
where T : struct, NetworkMessage
where T : NetworkMessage
{
if (!active)
{
@ -616,6 +616,16 @@ namespace Mirror
handlers[msgType] = MessagePacking.WrapHandler(handler, requireAuthentication);
}
public static void RegisterHandlerSafe<T>(Action<T> handler)
where T : NetworkMessage
{
var msgType = MessagePacking.GetId<T>();
if (!handlers.ContainsKey(msgType))
{
handlers[msgType] = MessagePacking.WrapHandler((NetworkConnection _, T msg) => handler(msg), true);
}
}
/// <summary>Replace a handler for message type T. Most should require authentication.</summary>
public static void ReplaceHandler<T>(Action<NetworkConnection, T> handler, bool requireAuthentication = true)
where T : struct, NetworkMessage

View File

@ -1,10 +1,10 @@
using QuantumUNET.Messages;
using Mirror;
using QuantumUNET.Transport;
using System;
namespace QSB.Messaging
{
public abstract class QSBMessage : QMessageBase
public abstract class QSBMessage : NetworkMessage
{
/// <summary>
/// set automatically by Send
@ -19,7 +19,7 @@ namespace QSB.Messaging
/// <summary>
/// call the base method when overriding
/// </summary>
public override void Serialize(QNetworkWriter writer)
public virtual void Serialize(QNetworkWriter writer)
{
writer.Write(From);
writer.Write(To);
@ -31,7 +31,7 @@ namespace QSB.Messaging
/// note: no constructor is called before this,
/// so fields won't be initialized.
/// </summary>
public override void Deserialize(QNetworkReader reader)
public virtual void Deserialize(QNetworkReader reader)
{
From = reader.ReadUInt32();
To = reader.ReadUInt32();

View File

@ -1,4 +1,5 @@
using OWML.Common;
using Mirror;
using OWML.Common;
using QSB.ClientServerStateSync;
using QSB.ClientServerStateSync.Messages;
using QSB.Player;
@ -12,6 +13,7 @@ using QuantumUNET.Messages;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Runtime.Serialization;
@ -21,100 +23,95 @@ namespace QSB.Messaging
{
#region inner workings
private static readonly Dictionary<short, Type> _msgTypeToType = new();
private static readonly Dictionary<Type, short> _typeToMsgType = new();
private static readonly Type[] _types;
static QSBMessageManager()
{
var types = typeof(QSBMessageRaw).GetDerivedTypes()
_types = typeof(QSBMessageRaw).GetDerivedTypes()
.Concat(typeof(QSBMessage).GetDerivedTypes())
.ToArray();
for (var i = 0; i < types.Length; i++)
foreach (var type in _types)
{
var msgType = (short)(QMsgType.Highest + 1 + i);
if (msgType >= short.MaxValue)
{
DebugLog.ToConsole("Hey, uh, maybe don't create 32,767 events? You really should never be seeing this." +
"If you are, something has either gone terrible wrong or QSB somehow needs more events that classes in Outer Wilds." +
"In either case, I guess something has gone terribly wrong...", MessageType.Error);
}
_msgTypeToType.Add(msgType, types[i]);
_typeToMsgType.Add(types[i], msgType);
// call static constructor of message if needed
RuntimeHelpers.RunClassConstructor(types[i].TypeHandle);
RuntimeHelpers.RunClassConstructor(type.TypeHandle);
}
}
public static void Init()
{
foreach (var (msgType, type) in _msgTypeToType)
DebugLog.DebugWrite("REGISTERING MESSAGES");
var NetworkServer_RegisterHandlerSafe = typeof(NetworkServer).GetMethod(nameof(NetworkServer.RegisterHandlerSafe));
var NetworkClient_RegisterHandlerSafe = typeof(NetworkClient).GetMethod(nameof(NetworkClient.RegisterHandlerSafe));
var OnServerReceiveRaw = typeof(QSBMessageManager).GetMethod(nameof(QSBMessageManager.OnServerReceiveRaw));
var OnClientReceiveRaw = typeof(QSBMessageManager).GetMethod(nameof(QSBMessageManager.OnClientReceiveRaw));
var OnServerReceive = typeof(QSBMessageManager).GetMethod(nameof(QSBMessageManager.OnServerReceive));
var OnClientReceive = typeof(QSBMessageManager).GetMethod(nameof(QSBMessageManager.OnClientReceive));
foreach (var type in _types)
{
MethodInfo OnServerReceive2;
MethodInfo OnClientReceive2;
if (typeof(QSBMessageRaw).IsAssignableFrom(type))
{
QNetworkServer.RegisterHandlerSafe(msgType, OnServerReceiveRaw);
QNetworkManager.singleton.client.RegisterHandlerSafe(msgType, OnClientReceiveRaw);
OnServerReceive2 = OnServerReceiveRaw;
OnClientReceive2 = OnClientReceiveRaw;
}
else
{
QNetworkServer.RegisterHandlerSafe(msgType, OnServerReceive);
QNetworkManager.singleton.client.RegisterHandlerSafe(msgType, OnClientReceive);
OnServerReceive2 = OnServerReceive;
OnClientReceive2 = OnClientReceive;
}
var serverHandler = OnServerReceive2.MakeGenericMethod(type).CreateDelegate(typeof(Action<>));
var clientHandler = OnClientReceive2.MakeGenericMethod(type).CreateDelegate(typeof(Action<>));
DebugLog.DebugWrite($"server handler = {serverHandler}");
DebugLog.DebugWrite($"client handler = {clientHandler}");
NetworkServer_RegisterHandlerSafe.MakeGenericMethod(type).Invoke(null, new object[] { serverHandler });
NetworkClient_RegisterHandlerSafe.MakeGenericMethod(type).Invoke(null, new object[] { clientHandler });
}
}
private static void OnServerReceiveRaw(QNetworkMessage netMsg)
private static void OnServerReceiveRaw<T>(T msg)
where T : QSBMessageRaw
{
var msgType = netMsg.MsgType;
var msg = (QSBMessageRaw)FormatterServices.GetUninitializedObject(_msgTypeToType[msgType]);
netMsg.ReadMessage(msg);
QNetworkServer.SendToAll(msgType, msg);
NetworkServer.SendToAll(msg);
}
private static void OnClientReceiveRaw(QNetworkMessage netMsg)
private static void OnClientReceiveRaw<T>(T msg)
where T : QSBMessageRaw
{
var msgType = netMsg.MsgType;
var msg = (QSBMessageRaw)FormatterServices.GetUninitializedObject(_msgTypeToType[msgType]);
netMsg.ReadMessage(msg);
msg.OnReceive();
}
private static void OnServerReceive(QNetworkMessage netMsg)
private static void OnServerReceive<T>(T msg)
where T : QSBMessage
{
var msgType = netMsg.MsgType;
var msg = (QSBMessage)FormatterServices.GetUninitializedObject(_msgTypeToType[msgType]);
netMsg.ReadMessage(msg);
if (msg.To == uint.MaxValue)
{
QNetworkServer.SendToAll(msgType, msg);
NetworkServer.SendToAll(msg);
}
else if (msg.To == 0)
{
QNetworkServer.localConnection.Send(msgType, msg);
NetworkServer.localConnection.Send(msg);
}
else
{
var conn = QNetworkServer.connections.FirstOrDefault(x => msg.To == x.GetPlayerId());
var conn = NetworkServer.connections.Values.FirstOrDefault(x => msg.To == x.identity.netId);
if (conn == null)
{
DebugLog.ToConsole($"SendTo unknown player! id: {msg.To}, message: {msg}", MessageType.Error);
return;
}
conn.Send(msgType, msg);
conn.Send(msg);
}
}
private static void OnClientReceive(QNetworkMessage netMsg)
private static void OnClientReceive<T>(T msg)
where T : QSBMessage
{
var msgType = netMsg.MsgType;
var msg = (QSBMessage)FormatterServices.GetUninitializedObject(_msgTypeToType[msgType]);
netMsg.ReadMessage(msg);
if (PlayerTransformSync.LocalInstance == null)
{
DebugLog.ToConsole($"Warning - Tried to handle message {msg} before local player was established.", MessageType.Warning);
@ -162,15 +159,13 @@ namespace QSB.Messaging
public static void SendRaw<M>(this M msg)
where M : QSBMessageRaw
{
var msgType = _typeToMsgType[typeof(M)];
QNetworkManager.singleton.client.Send(msgType, msg);
NetworkClient.Send(msg);
}
public static void ServerSendRaw<M>(this M msg, QNetworkConnection conn)
public static void ServerSendRaw<M>(this M msg, NetworkConnectionToClient conn)
where M : QSBMessageRaw
{
var msgType = _typeToMsgType[typeof(M)];
conn.Send(msgType, msg);
conn.Send(msg);
}
public static void Send<M>(this M msg)
@ -183,8 +178,7 @@ namespace QSB.Messaging
}
msg.From = QSBPlayerManager.LocalPlayerId;
var msgType = _typeToMsgType[typeof(M)];
QNetworkManager.singleton.client.Send(msgType, msg);
NetworkClient.Send(msg);
}
public static void SendMessage<T, M>(this T worldObject, M msg)
@ -200,7 +194,7 @@ namespace QSB.Messaging
/// message that will be sent to every client. <br/>
/// no checks are performed on the message. it is just sent and received.
/// </summary>
public abstract class QSBMessageRaw : QMessageBase
public abstract class QSBMessageRaw : NetworkMessage
{
public abstract void OnReceive();
public override string ToString() => GetType().Name;