5
5
6
6
import asyncio
7
7
import atexit
8
+ import contextvars
8
9
import io
9
10
import os
10
11
import sys
11
12
import threading
12
13
import traceback
13
14
import warnings
14
15
from binascii import b2a_hex
15
- from collections import deque
16
+ from collections import defaultdict , deque
16
17
from io import StringIO , TextIOBase
17
18
from threading import local
18
19
from typing import Any , Callable , Deque , Dict , Optional
@@ -412,7 +413,7 @@ def __init__(
412
413
name : str {'stderr', 'stdout'}
413
414
the name of the standard stream to replace
414
415
pipe : object
415
- the pip object
416
+ the pipe object
416
417
echo : bool
417
418
whether to echo output
418
419
watchfd : bool (default, True)
@@ -446,13 +447,19 @@ def __init__(
446
447
self .pub_thread = pub_thread
447
448
self .name = name
448
449
self .topic = b"stream." + name .encode ()
449
- self .parent_header = {}
450
+ self ._parent_header : contextvars .ContextVar [Dict [str , Any ]] = contextvars .ContextVar (
451
+ "parent_header"
452
+ )
453
+ self ._parent_header .set ({})
454
+ self ._thread_to_parent = {}
455
+ self ._thread_to_parent_header = {}
456
+ self ._parent_header_global = {}
450
457
self ._master_pid = os .getpid ()
451
458
self ._flush_pending = False
452
459
self ._subprocess_flush_pending = False
453
460
self ._io_loop = pub_thread .io_loop
454
461
self ._buffer_lock = threading .RLock ()
455
- self ._buffer = StringIO ( )
462
+ self ._buffers = defaultdict ( StringIO )
456
463
self .echo = None
457
464
self ._isatty = bool (isatty )
458
465
self ._should_watch = False
@@ -495,6 +502,30 @@ def __init__(
495
502
msg = "echo argument must be a file-like object"
496
503
raise ValueError (msg )
497
504
505
+ @property
506
+ def parent_header (self ):
507
+ try :
508
+ # asyncio-specific
509
+ return self ._parent_header .get ()
510
+ except LookupError :
511
+ try :
512
+ # thread-specific
513
+ identity = threading .current_thread ().ident
514
+ # retrieve the outermost (oldest ancestor,
515
+ # discounting the kernel thread) thread identity
516
+ while identity in self ._thread_to_parent :
517
+ identity = self ._thread_to_parent [identity ]
518
+ # use the header of the oldest ancestor
519
+ return self ._thread_to_parent_header [identity ]
520
+ except KeyError :
521
+ # global (fallback)
522
+ return self ._parent_header_global
523
+
524
+ @parent_header .setter
525
+ def parent_header (self , value ):
526
+ self ._parent_header_global = value
527
+ return self ._parent_header .set (value )
528
+
498
529
def isatty (self ):
499
530
"""Return a bool indicating whether this is an 'interactive' stream.
500
531
@@ -598,28 +629,28 @@ def _flush(self):
598
629
if self .echo is not sys .__stderr__ :
599
630
print (f"Flush failed: { e } " , file = sys .__stderr__ )
600
631
601
- data = self ._flush_buffer ()
602
- if data :
603
- # FIXME: this disables Session's fork-safe check,
604
- # since pub_thread is itself fork-safe.
605
- # There should be a better way to do this.
606
- self .session .pid = os .getpid ()
607
- content = {"name" : self .name , "text" : data }
608
- msg = self .session .msg ("stream" , content , parent = self . parent_header )
609
-
610
- # Each transform either returns a new
611
- # message or None. If None is returned,
612
- # the message has been 'used' and we return.
613
- for hook in self ._hooks :
614
- msg = hook (msg )
615
- if msg is None :
616
- return
617
-
618
- self .session .send (
619
- self .pub_thread ,
620
- msg ,
621
- ident = self .topic ,
622
- )
632
+ for parent , data in self ._flush_buffers ():
633
+ if data :
634
+ # FIXME: this disables Session's fork-safe check,
635
+ # since pub_thread is itself fork-safe.
636
+ # There should be a better way to do this.
637
+ self .session .pid = os .getpid ()
638
+ content = {"name" : self .name , "text" : data }
639
+ msg = self .session .msg ("stream" , content , parent = parent )
640
+
641
+ # Each transform either returns a new
642
+ # message or None. If None is returned,
643
+ # the message has been 'used' and we return.
644
+ for hook in self ._hooks :
645
+ msg = hook (msg )
646
+ if msg is None :
647
+ return
648
+
649
+ self .session .send (
650
+ self .pub_thread ,
651
+ msg ,
652
+ ident = self .topic ,
653
+ )
623
654
624
655
def write (self , string : str ) -> Optional [int ]: # type:ignore[override]
625
656
"""Write to current stream after encoding if necessary
@@ -630,6 +661,7 @@ def write(self, string: str) -> Optional[int]: # type:ignore[override]
630
661
number of items from input parameter written to stream.
631
662
632
663
"""
664
+ parent = self .parent_header
633
665
634
666
if not isinstance (string , str ):
635
667
msg = f"write() argument must be str, not { type (string )} " # type:ignore[unreachable]
@@ -649,7 +681,7 @@ def write(self, string: str) -> Optional[int]: # type:ignore[override]
649
681
is_child = not self ._is_master_process ()
650
682
# only touch the buffer in the IO thread to avoid races
651
683
with self ._buffer_lock :
652
- self ._buffer .write (string )
684
+ self ._buffers [ frozenset ( parent . items ())] .write (string )
653
685
if is_child :
654
686
# mp.Pool cannot be trusted to flush promptly (or ever),
655
687
# and this helps.
@@ -675,19 +707,20 @@ def writable(self):
675
707
"""Test whether the stream is writable."""
676
708
return True
677
709
678
- def _flush_buffer (self ):
710
+ def _flush_buffers (self ):
679
711
"""clear the current buffer and return the current buffer data."""
680
- buf = self ._rotate_buffer ()
681
- data = buf .getvalue ()
682
- buf .close ()
683
- return data
712
+ buffers = self ._rotate_buffers ()
713
+ for frozen_parent , buffer in buffers .items ():
714
+ data = buffer .getvalue ()
715
+ buffer .close ()
716
+ yield dict (frozen_parent ), data
684
717
685
- def _rotate_buffer (self ):
718
+ def _rotate_buffers (self ):
686
719
"""Returns the current buffer and replaces it with an empty buffer."""
687
720
with self ._buffer_lock :
688
- old_buffer = self ._buffer
689
- self ._buffer = StringIO ( )
690
- return old_buffer
721
+ old_buffers = self ._buffers
722
+ self ._buffers = defaultdict ( StringIO )
723
+ return old_buffers
691
724
692
725
@property
693
726
def _hooks (self ):
0 commit comments