@@ -238,12 +238,20 @@ def __init__(
238
238
socket : socket .socket ,
239
239
handler : Callable [[socket .socket , Any ], None ],
240
240
logger : LoggerLike | None = None ,
241
+ * ,
242
+ connections : set [ServerConnection ] | None = None ,
241
243
) -> None :
242
244
self .socket = socket
243
245
self .handler = handler
244
246
if logger is None :
245
247
logger = logging .getLogger ("websockets.server" )
246
248
self .logger = logger
249
+
250
+ # _connections tracks active connections
251
+ if connections is None :
252
+ connections = set ()
253
+ self ._connections = connections
254
+
247
255
if sys .platform != "win32" :
248
256
self .shutdown_watcher , self .shutdown_notifier = os .pipe ()
249
257
@@ -285,15 +293,36 @@ def serve_forever(self) -> None:
285
293
thread = threading .Thread (target = self .handler , args = (sock , addr ))
286
294
thread .start ()
287
295
288
- def shutdown (self ) -> None :
296
+ def shutdown (
297
+ self , * , code : CloseCode = CloseCode .NORMAL_CLOSURE , reason : str = ""
298
+ ) -> None :
289
299
"""
290
300
See :meth:`socketserver.BaseServer.shutdown`.
291
301
302
+ Shuts down the server and closes existing connections. Optional arguments
303
+ ``code`` and ``reason`` can be used to provide additional information to
304
+ the clients, e.g.,::
305
+
306
+ server.shutdown(reason="scheduled_maintenance")
307
+
308
+ Args:
309
+ code: Closing code, defaults to ``CloseCode.NORMAL_CLOSURE``.
310
+ reason: Closing reason, default to empty string.
311
+
292
312
"""
293
313
self .socket .close ()
294
314
if sys .platform != "win32" :
295
315
os .write (self .shutdown_notifier , b"x" )
296
316
317
+ # Close all connections
318
+ conns = list (self ._connections )
319
+ for conn in conns :
320
+ try :
321
+ conn .close (code = code , reason = reason )
322
+ except Exception as exc :
323
+ debug_msg = f"Could not close { conn .id } : { exc } "
324
+ self .logger .debug (debug_msg , exc_info = exc )
325
+
297
326
def fileno (self ) -> int :
298
327
"""
299
328
See :meth:`socketserver.BaseServer.fileno`.
@@ -516,6 +545,24 @@ def handler(websocket):
516
545
do_handshake_on_connect = False ,
517
546
)
518
547
548
+ # Stores active ServerConnection instances, used by the server to handle graceful
549
+ # shutdown in Server.shutdown()
550
+ connections : set [ServerConnection ] = set ()
551
+
552
+ def on_connection_created (connection : ServerConnection ) -> None :
553
+ # Invoked from conn_handler() to add a new ServerConnection instance to
554
+ # Server._connections
555
+ connections .add (connection )
556
+
557
+ def on_connection_closed (connection : ServerConnection ) -> None :
558
+ # Invoked from conn_handler() to remove a closed ServerConnection instance from
559
+ # Server._connections. Keeping only active references in the set is important
560
+ # for avoiding memory leaks.
561
+ try :
562
+ connections .remove (connection )
563
+ except KeyError : # pragma: no cover
564
+ pass
565
+
519
566
# Define request handler
520
567
521
568
def conn_handler (sock : socket .socket , addr : Any ) -> None :
@@ -581,6 +628,7 @@ def protocol_select_subprotocol(
581
628
close_timeout = close_timeout ,
582
629
max_queue = max_queue ,
583
630
)
631
+ on_connection_created (connection )
584
632
except Exception :
585
633
sock .close ()
586
634
return
@@ -595,11 +643,13 @@ def protocol_select_subprotocol(
595
643
)
596
644
except TimeoutError :
597
645
connection .close_socket ()
646
+ on_connection_closed (connection )
598
647
connection .recv_events_thread .join ()
599
648
return
600
649
except Exception :
601
650
connection .logger .error ("opening handshake failed" , exc_info = True )
602
651
connection .close_socket ()
652
+ on_connection_closed (connection )
603
653
connection .recv_events_thread .join ()
604
654
return
605
655
@@ -610,16 +660,18 @@ def protocol_select_subprotocol(
610
660
except Exception :
611
661
connection .logger .error ("connection handler failed" , exc_info = True )
612
662
connection .close (CloseCode .INTERNAL_ERROR )
663
+ on_connection_closed (connection )
613
664
else :
614
665
connection .close ()
666
+ on_connection_closed (connection )
615
667
616
668
except Exception : # pragma: no cover
617
669
# Don't leak sockets on unexpected errors.
618
670
sock .close ()
619
671
620
672
# Initialize server
621
673
622
- return Server (sock , conn_handler , logger )
674
+ return Server (sock , conn_handler , logger , connections = connections )
623
675
624
676
625
677
def unix_serve (
0 commit comments