Skip to content

Commit fd31b2a

Browse files
committed
Sync connection tracking and graceful shutdown
* Implement sync server connection tracking. * Add ServerConnection.close() call for exising connections on server shutdown. This is useful for cleanly terminating/restarting the server process. Issue #1488
1 parent bb78c20 commit fd31b2a

File tree

2 files changed

+121
-2
lines changed

2 files changed

+121
-2
lines changed

src/websockets/sync/server.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,12 +238,20 @@ def __init__(
238238
socket: socket.socket,
239239
handler: Callable[[socket.socket, Any], None],
240240
logger: LoggerLike | None = None,
241+
*,
242+
connections: set[ServerConnection] | None = None,
241243
) -> None:
242244
self.socket = socket
243245
self.handler = handler
244246
if logger is None:
245247
logger = logging.getLogger("websockets.server")
246248
self.logger = logger
249+
250+
# _connections tracks active connections
251+
if connections is None:
252+
connections = set()
253+
self._connections = connections
254+
247255
if sys.platform != "win32":
248256
self.shutdown_watcher, self.shutdown_notifier = os.pipe()
249257

@@ -285,15 +293,36 @@ def serve_forever(self) -> None:
285293
thread = threading.Thread(target=self.handler, args=(sock, addr))
286294
thread.start()
287295

288-
def shutdown(self) -> None:
296+
def shutdown(
297+
self, *, code: CloseCode = CloseCode.NORMAL_CLOSURE, reason: str = ""
298+
) -> None:
289299
"""
290300
See :meth:`socketserver.BaseServer.shutdown`.
291301
302+
Shuts down the server and closes existing connections. Optional arguments
303+
``code`` and ``reason`` can be used to provide additional information to
304+
the clients, e.g.,::
305+
306+
server.shutdown(reason="scheduled_maintenance")
307+
308+
Args:
309+
code: Closing code, defaults to ``CloseCode.NORMAL_CLOSURE``.
310+
reason: Closing reason, default to empty string.
311+
292312
"""
293313
self.socket.close()
294314
if sys.platform != "win32":
295315
os.write(self.shutdown_notifier, b"x")
296316

317+
# Close all connections
318+
conns = list(self._connections)
319+
for conn in conns:
320+
try:
321+
conn.close(code=code, reason=reason)
322+
except Exception as exc:
323+
debug_msg = f"Could not close {conn.id}: {exc}"
324+
self.logger.debug(debug_msg, exc_info=exc)
325+
297326
def fileno(self) -> int:
298327
"""
299328
See :meth:`socketserver.BaseServer.fileno`.
@@ -516,6 +545,24 @@ def handler(websocket):
516545
do_handshake_on_connect=False,
517546
)
518547

548+
# Stores active ServerConnection instances, used by the server to handle graceful
549+
# shutdown in Server.shutdown()
550+
connections: set[ServerConnection] = set()
551+
552+
def on_connection_created(connection: ServerConnection) -> None:
553+
# Invoked from conn_handler() to add a new ServerConnection instance to
554+
# Server._connections
555+
connections.add(connection)
556+
557+
def on_connection_closed(connection: ServerConnection) -> None:
558+
# Invoked from conn_handler() to remove a closed ServerConnection instance from
559+
# Server._connections. Keeping only active references in the set is important
560+
# for avoiding memory leaks.
561+
try:
562+
connections.remove(connection)
563+
except KeyError: # pragma: no cover
564+
pass
565+
519566
# Define request handler
520567

521568
def conn_handler(sock: socket.socket, addr: Any) -> None:
@@ -581,6 +628,7 @@ def protocol_select_subprotocol(
581628
close_timeout=close_timeout,
582629
max_queue=max_queue,
583630
)
631+
on_connection_created(connection)
584632
except Exception:
585633
sock.close()
586634
return
@@ -595,11 +643,13 @@ def protocol_select_subprotocol(
595643
)
596644
except TimeoutError:
597645
connection.close_socket()
646+
on_connection_closed(connection)
598647
connection.recv_events_thread.join()
599648
return
600649
except Exception:
601650
connection.logger.error("opening handshake failed", exc_info=True)
602651
connection.close_socket()
652+
on_connection_closed(connection)
603653
connection.recv_events_thread.join()
604654
return
605655

@@ -610,16 +660,18 @@ def protocol_select_subprotocol(
610660
except Exception:
611661
connection.logger.error("connection handler failed", exc_info=True)
612662
connection.close(CloseCode.INTERNAL_ERROR)
663+
on_connection_closed(connection)
613664
else:
614665
connection.close()
666+
on_connection_closed(connection)
615667

616668
except Exception: # pragma: no cover
617669
# Don't leak sockets on unexpected errors.
618670
sock.close()
619671

620672
# Initialize server
621673

622-
return Server(sock, conn_handler, logger)
674+
return Server(sock, conn_handler, logger, connections=connections)
623675

624676

625677
def unix_serve(

tests/sync/test_server.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import http
44
import logging
55
import socket
6+
import threading
67
import time
78
import unittest
89

@@ -12,6 +13,7 @@
1213
InvalidStatus,
1314
NegotiationError,
1415
)
16+
from websockets import CloseCode, State
1517
from websockets.http11 import Request, Response
1618
from websockets.sync.client import connect, unix_connect
1719
from websockets.sync.server import *
@@ -338,6 +340,71 @@ def test_junk_handshake(self):
338340
["invalid HTTP request line: HELO relay.invalid"],
339341
)
340342

343+
def test_initialize_server_without_tracking_connections(self):
344+
"""Call Server() constructor without 'connections' arg."""
345+
with socket.create_server(("localhost", 0)) as sock:
346+
server = Server(socket=sock, handler=handler)
347+
self.assertIsInstance(
348+
server._connections, set, "Server._connections property not initialized"
349+
)
350+
351+
def test_connections_is_empty_after_disconnects(self):
352+
"""Clients are added to Server._connections, and removed when disconnected."""
353+
with run_server() as server:
354+
connections: set[ServerConnection] = server._connections
355+
with connect(get_uri(server)) as client:
356+
self.assertEqual(len(connections), 1)
357+
time.sleep(0.5)
358+
self.assertEqual(len(connections), 0)
359+
360+
def test_shutdown_calls_close_for_all_connections(self):
361+
"""Graceful shutdown with broken ServerConnection.close() implementations."""
362+
CLIENTS_TO_LAUNCH = 3
363+
364+
connections_attempted = 0
365+
366+
class ServerConnectionWithBrokenClose(ServerConnection):
367+
close_method_called = False
368+
369+
def close(self, code=CloseCode.NORMAL_CLOSURE, reason=""):
370+
"""Custom close method that intentionally fails."""
371+
372+
# Do not increment the counter when calling .close() multiple times
373+
if self.close_method_called:
374+
return
375+
self.close_method_called = True
376+
377+
nonlocal connections_attempted
378+
connections_attempted += 1
379+
raise Exception("broken close method")
380+
381+
clients: set[threading.Thread] = set()
382+
with run_server(create_connection=ServerConnectionWithBrokenClose) as server:
383+
384+
def client():
385+
with connect(get_uri(server)) as client:
386+
time.sleep(1)
387+
388+
for i in range(CLIENTS_TO_LAUNCH):
389+
client_thread = threading.Thread(target=client)
390+
client_thread.start()
391+
clients.add(client_thread)
392+
time.sleep(0.2)
393+
self.assertEqual(
394+
len(server._connections),
395+
CLIENTS_TO_LAUNCH,
396+
"not all clients connected to the server yet, increase sleep duration",
397+
)
398+
server.shutdown()
399+
while len(clients) > 0:
400+
client = clients.pop()
401+
client.join()
402+
self.assertEqual(
403+
connections_attempted,
404+
CLIENTS_TO_LAUNCH,
405+
"server did not call ServerConnection.close() on all connections",
406+
)
407+
341408

342409
class SecureServerTests(EvalShellMixin, unittest.TestCase):
343410
def test_connection(self):

0 commit comments

Comments
 (0)