|
7 | 7 | include("netmessages.lua")
|
8 | 8 |
|
9 | 9 | -- Initialization
|
10 |
| - |
11 | 10 | HookNetChannel(
|
12 | 11 | -- nochan prevents a net channel being passed to the attach/detach functions
|
13 | 12 | -- CNetChan::ProcessMessages doesn't use a virtual hook, so we don't need to pass the net channel
|
14 | 13 | {name = "CNetChan::ProcessMessages", nochan = true}
|
15 | 14 | )
|
16 | 15 |
|
17 |
| -local function CopyBufferEnd(dst, src) |
18 |
| - local bitsleft = src:GetNumBitsLeft() |
19 |
| - local data = src:ReadBits(bitsleft) |
20 |
| - |
21 |
| - dst:WriteBits(data) |
22 |
| -end |
23 |
| - |
24 |
| -local specialmsg |
25 |
| -local specialhandler = { |
26 |
| - DefaultCopy = function(netchan, read, write) |
27 |
| - specialmsg:ReadFromBuffer(read) |
28 |
| - specialmsg:WriteToBuffer(write) |
29 |
| - end |
30 |
| -} |
31 |
| -hook.Add("PreProcessMessages", "InFilter", function(netchan, read, write, localchan) |
32 |
| - local totalbits = read:GetNumBitsLeft() + read:GetNumBitsRead() |
| 16 | +local NET_MESSAGES_INSTANCES = {} |
33 | 17 |
|
34 |
| - local islocal = netchan == localchan |
35 |
| - if not game.IsDedicated() and ((islocal and SERVER) or (not islocal and CLIENT)) then |
36 |
| - CopyBufferEnd(write, read) |
37 |
| - return |
| 18 | +local function GetNetMessageInstance(netchan, msgtype) |
| 19 | + local handler = NET_MESSAGES_INSTANCES[msgtype] |
| 20 | + if handler == nil then |
| 21 | + handler = NetMessage(netchan, msgtype, not SERVER) |
| 22 | + NET_MESSAGES_INSTANCES[msgtype] = handler |
| 23 | + else |
| 24 | + handler:Reset() |
38 | 25 | end
|
39 | 26 |
|
40 |
| - hook.Call("BASE_PreProcessMessages", nil, netchan, read, write) |
| 27 | + return handler |
| 28 | +end |
41 | 29 |
|
42 |
| - local changeLevelState = false |
| 30 | +local NET_MESSAGES_INCOMING_COPY = { |
| 31 | + NET = {}, |
| 32 | + CLC = {}, |
| 33 | + SVC = {} |
| 34 | +} |
43 | 35 |
|
44 |
| - while read:GetNumBitsLeft() >= NET_MESSAGE_BITS do |
45 |
| - local msg = read:ReadUInt(NET_MESSAGE_BITS) |
46 |
| - |
47 |
| - if CLIENT then |
48 |
| - -- Hack to prevent changelevel crashes |
49 |
| - if msg == net_SignonState then |
50 |
| - local state = read:ReadByte() |
51 |
| - |
52 |
| - if state == SIGNONSTATE_CHANGELEVEL then |
53 |
| - changeLevelState = true |
54 |
| - --print( "[gm_sourcenet] Received changelevel packet" ) |
55 |
| - end |
56 |
| - |
57 |
| - read:Seek(read:GetNumBitsRead() - 8) |
58 |
| - end |
59 |
| - end |
| 36 | +local function GetIncomingCopyTableForMessageType(msgtype) |
| 37 | + if NET_MESSAGES.NET[msgtype] ~= nil then |
| 38 | + return NET_MESSAGES_INCOMING_COPY.NET |
| 39 | + end |
60 | 40 |
|
61 |
| - local handler = NET_MESSAGES[msg] |
62 |
| - |
63 |
| - --[[if msg ~= net_NOP and msg ~= 3 and msg ~= 9 then |
64 |
| - Msg("(in) Pre Message: " .. msg .. ", bits: " .. read:GetNumBitsRead() .. "/" .. totalbits .. "\n") |
65 |
| - end--]] |
66 |
| - |
67 |
| - if not handler then |
68 |
| - if CLIENT then |
69 |
| - handler = NET_MESSAGES.SVC[msg] |
70 |
| - else |
71 |
| - handler = NET_MESSAGES.CLC[msg] |
72 |
| - end |
73 |
| - |
74 |
| - if not handler then |
75 |
| - for i = 1, netchan:GetNetMessageNum() do |
76 |
| - local m = netchan:GetNetMessage(i) |
77 |
| - if m:GetType() == msg then |
78 |
| - handler = specialhandler |
79 |
| - specialmsg = m |
80 |
| - break |
81 |
| - end |
82 |
| - end |
83 |
| - |
84 |
| - if not handler then |
85 |
| - Msg("Unknown outgoing message: " .. msg .. "\n") |
86 |
| - |
87 |
| - write:Seek(totalbits) |
88 |
| - |
89 |
| - break |
90 |
| - end |
91 |
| - end |
92 |
| - end |
| 41 | + if CLIENT and NET_MESSAGES.SVC[msgtype] ~= nil then |
| 42 | + return NET_MESSAGES_INCOMING_COPY.SVC |
| 43 | + end |
93 | 44 |
|
94 |
| - local func = handler.IncomingCopy or handler.DefaultCopy |
| 45 | + if SERVER and NET_MESSAGES.CLC[msgtype] ~= nil then |
| 46 | + return NET_MESSAGES_INCOMING_COPY.CLC |
| 47 | + end |
95 | 48 |
|
96 |
| - local success, ret = xpcall(func, debug.traceback, netchan, read, write) |
97 |
| - if not success then |
98 |
| - print(ret) |
| 49 | + return nil |
| 50 | +end |
99 | 51 |
|
100 |
| - break |
101 |
| - elseif ret == false then |
102 |
| - --if func(netchan, read, write) == false then |
103 |
| - Msg("Failed to filter message " .. msg .. "\n") |
| 52 | +local function DefaultCopy(netchan, read, write, handler) |
| 53 | + handler:ReadFromBuffer(read) |
| 54 | + handler:WriteToBuffer(write) |
| 55 | +end |
104 | 56 |
|
105 |
| - write:Seek(totalbits) |
| 57 | +hook.Add("PreProcessMessages", "InFilter", function(netchan, read, write, localchan) |
| 58 | + local islocal = netchan == localchan |
| 59 | + if not game.IsDedicated() and ((islocal and SERVER) or (not islocal and CLIENT)) then |
| 60 | + return |
| 61 | + end |
106 | 62 |
|
107 |
| - break |
| 63 | + while read:GetNumBitsLeft() >= NET_MESSAGE_BITS do |
| 64 | + local msgtype = read:ReadUInt(NET_MESSAGE_BITS) |
| 65 | + local handler = GetNetMessageInstance(netchan, msgtype) |
| 66 | + if handler == nil then |
| 67 | + MsgC(Color(255, 0, 0), "Unknown outgoing message " .. msgtype .. " with " .. read:GetNumBitsLeft() .. " bit(s) left\n") |
| 68 | + return false |
108 | 69 | end
|
109 | 70 |
|
110 |
| - --[[if msg ~= net_NOP and msg ~= 3 and msg ~= 9 then |
111 |
| - Msg("(in) Post Message: " .. msg .. " bits: " .. read:GetNumBitsRead() .. "/" .. totalbits .. "\n") |
112 |
| - end--]] |
| 71 | + local incoming_copy_table = GetIncomingCopyTableForMessageType(msgtype) |
| 72 | + local copy_function = incoming_copy_table ~= nil and incoming_copy_table[msgtype] or DefaultCopy |
| 73 | + copy_function(netchan, read, write, handler) |
| 74 | + |
| 75 | + --MsgC(Color(255, 255, 255), "NetMessage: " .. tostring(handler) .. "\n") |
113 | 76 | end
|
114 |
| - |
115 |
| - if CLIENT then |
116 |
| - if changeLevelState then |
117 |
| - --print("[gm_sourcenet] Server is changing level, calling PreNetChannelShutdown") |
118 |
| - hook.Call("PreNetChannelShutdown", nil, netchan, "Server Changing Level") |
119 |
| - end |
| 77 | + |
| 78 | + local bitsleft = read:GetNumBitsLeft() |
| 79 | + if bitsleft > 0 then |
| 80 | + -- Should be inocuous padding bits but just to be sure, let's copy them |
| 81 | + local data = read:ReadBits(bitsleft) |
| 82 | + write:WriteBits(data) |
120 | 83 | end
|
121 |
| -end) |
122 | 84 |
|
123 |
| -function FilterIncomingMessage(msg, func) |
124 |
| - local handler = NET_MESSAGES[msg] |
| 85 | + --MsgC(Color(0, 255, 0), "Fully parsed stream with " .. totalbits .. " bit(s) written\n") |
| 86 | + return true |
| 87 | +end) |
125 | 88 |
|
126 |
| - if not handler then |
127 |
| - if CLIENT then |
128 |
| - handler = NET_MESSAGES.SVC[msg] |
129 |
| - else |
130 |
| - handler = NET_MESSAGES.CLC[msg] |
131 |
| - end |
| 89 | +function FilterIncomingMessage(msgtype, func) |
| 90 | + local incoming_copy_table = GetIncomingCopyTableForMessageType(msgtype) |
| 91 | + if incoming_copy_table == nil then |
| 92 | + return false |
132 | 93 | end
|
133 | 94 |
|
134 |
| - if handler then |
135 |
| - handler.IncomingCopy = func |
136 |
| - end |
| 95 | + incoming_copy_table[msgtype] = func |
| 96 | + return true |
137 | 97 | end
|
138 | 98 |
|
139 |
| -function UnFilterIncomingMessage(msg) |
140 |
| - FilterIncomingMessage(msg, nil) |
| 99 | +function UnFilterIncomingMessage(msgtype) |
| 100 | + return FilterIncomingMessage(msgtype, nil) |
141 | 101 | end
|
0 commit comments