Skip to content

Commit 8495548

Browse files
Make outputs go to correct cell when generated in threads/asyncio (#1186)
Co-authored-by: Steven Silvester <steven.silvester@ieee.org>
1 parent e8185df commit 8495548

File tree

4 files changed

+271
-36
lines changed

4 files changed

+271
-36
lines changed

ipykernel/iostream.py

+69-36
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@
55

66
import asyncio
77
import atexit
8+
import contextvars
89
import io
910
import os
1011
import sys
1112
import threading
1213
import traceback
1314
import warnings
1415
from binascii import b2a_hex
15-
from collections import deque
16+
from collections import defaultdict, deque
1617
from io import StringIO, TextIOBase
1718
from threading import local
1819
from typing import Any, Callable, Deque, Dict, Optional
@@ -412,7 +413,7 @@ def __init__(
412413
name : str {'stderr', 'stdout'}
413414
the name of the standard stream to replace
414415
pipe : object
415-
the pip object
416+
the pipe object
416417
echo : bool
417418
whether to echo output
418419
watchfd : bool (default, True)
@@ -446,13 +447,19 @@ def __init__(
446447
self.pub_thread = pub_thread
447448
self.name = name
448449
self.topic = b"stream." + name.encode()
449-
self.parent_header = {}
450+
self._parent_header: contextvars.ContextVar[Dict[str, Any]] = contextvars.ContextVar(
451+
"parent_header"
452+
)
453+
self._parent_header.set({})
454+
self._thread_to_parent = {}
455+
self._thread_to_parent_header = {}
456+
self._parent_header_global = {}
450457
self._master_pid = os.getpid()
451458
self._flush_pending = False
452459
self._subprocess_flush_pending = False
453460
self._io_loop = pub_thread.io_loop
454461
self._buffer_lock = threading.RLock()
455-
self._buffer = StringIO()
462+
self._buffers = defaultdict(StringIO)
456463
self.echo = None
457464
self._isatty = bool(isatty)
458465
self._should_watch = False
@@ -495,6 +502,30 @@ def __init__(
495502
msg = "echo argument must be a file-like object"
496503
raise ValueError(msg)
497504

505+
@property
def parent_header(self):
    """Return the parent header that applies to the calling context.

    The lookup goes through three sources, in order:

    1. the value held by the ``_parent_header`` context variable;
    2. if that variable has no value in the current context, the header
       stored in ``_thread_to_parent_header`` for the outermost ancestor
       of the current thread, found by following ``_thread_to_parent``
       links until an identity without a recorded parent is reached;
    3. if no header is recorded for that identity, the most recently
       assigned header (``_parent_header_global``).
    """
    # 1. context-variable value (asyncio-specific)
    try:
        return self._parent_header.get()
    except LookupError:
        pass

    # 2. thread-specific value, 3. global fallback
    try:
        ident = threading.current_thread().ident
        # climb to the oldest recorded ancestor of this thread
        while ident in self._thread_to_parent:
            ident = self._thread_to_parent[ident]
        return self._thread_to_parent_header[ident]
    except KeyError:
        return self._parent_header_global

@parent_header.setter
def parent_header(self, value):
    """Store ``value`` as both the global fallback and the context-local header."""
    self._parent_header_global = value
    return self._parent_header.set(value)
528+
498529
def isatty(self):
499530
"""Return a bool indicating whether this is an 'interactive' stream.
500531
@@ -598,28 +629,28 @@ def _flush(self):
598629
if self.echo is not sys.__stderr__:
599630
print(f"Flush failed: {e}", file=sys.__stderr__)
600631

601-
data = self._flush_buffer()
602-
if data:
603-
# FIXME: this disables Session's fork-safe check,
604-
# since pub_thread is itself fork-safe.
605-
# There should be a better way to do this.
606-
self.session.pid = os.getpid()
607-
content = {"name": self.name, "text": data}
608-
msg = self.session.msg("stream", content, parent=self.parent_header)
609-
610-
# Each transform either returns a new
611-
# message or None. If None is returned,
612-
# the message has been 'used' and we return.
613-
for hook in self._hooks:
614-
msg = hook(msg)
615-
if msg is None:
616-
return
617-
618-
self.session.send(
619-
self.pub_thread,
620-
msg,
621-
ident=self.topic,
622-
)
632+
for parent, data in self._flush_buffers():
633+
if data:
634+
# FIXME: this disables Session's fork-safe check,
635+
# since pub_thread is itself fork-safe.
636+
# There should be a better way to do this.
637+
self.session.pid = os.getpid()
638+
content = {"name": self.name, "text": data}
639+
msg = self.session.msg("stream", content, parent=parent)
640+
641+
# Each transform either returns a new
642+
# message or None. If None is returned,
643+
# the message has been 'used' and we return.
644+
for hook in self._hooks:
645+
msg = hook(msg)
646+
if msg is None:
647+
return
648+
649+
self.session.send(
650+
self.pub_thread,
651+
msg,
652+
ident=self.topic,
653+
)
623654

624655
def write(self, string: str) -> Optional[int]: # type:ignore[override]
625656
"""Write to current stream after encoding if necessary
@@ -630,6 +661,7 @@ def write(self, string: str) -> Optional[int]: # type:ignore[override]
630661
number of items from input parameter written to stream.
631662
632663
"""
664+
parent = self.parent_header
633665

634666
if not isinstance(string, str):
635667
msg = f"write() argument must be str, not {type(string)}" # type:ignore[unreachable]
@@ -649,7 +681,7 @@ def write(self, string: str) -> Optional[int]: # type:ignore[override]
649681
is_child = not self._is_master_process()
650682
# only touch the buffer in the IO thread to avoid races
651683
with self._buffer_lock:
652-
self._buffer.write(string)
684+
self._buffers[frozenset(parent.items())].write(string)
653685
if is_child:
654686
# mp.Pool cannot be trusted to flush promptly (or ever),
655687
# and this helps.
@@ -675,19 +707,20 @@ def writable(self):
675707
"""Test whether the stream is writable."""
676708
return True
677709

678-
def _flush_buffer(self):
710+
def _flush_buffers(self):
679711
"""clear the current buffer and return the current buffer data."""
680-
buf = self._rotate_buffer()
681-
data = buf.getvalue()
682-
buf.close()
683-
return data
712+
buffers = self._rotate_buffers()
713+
for frozen_parent, buffer in buffers.items():
714+
data = buffer.getvalue()
715+
buffer.close()
716+
yield dict(frozen_parent), data
684717

685-
def _rotate_buffer(self):
718+
def _rotate_buffers(self):
686719
"""Returns the current buffer and replaces it with an empty buffer."""
687720
with self._buffer_lock:
688-
old_buffer = self._buffer
689-
self._buffer = StringIO()
690-
return old_buffer
721+
old_buffers = self._buffers
722+
self._buffers = defaultdict(StringIO)
723+
return old_buffers
691724

692725
@property
693726
def _hooks(self):

ipykernel/ipkernel.py

+94
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import builtins
5+
import gc
56
import getpass
67
import os
78
import signal
@@ -14,6 +15,7 @@
1415
import comm
1516
from IPython.core import release
1617
from IPython.utils.tokenutil import line_at_cursor, token_at_cursor
18+
from jupyter_client.session import extract_header
1719
from traitlets import Any, Bool, HasTraits, Instance, List, Type, observe, observe_compat
1820
from zmq.eventloop.zmqstream import ZMQStream
1921

@@ -22,6 +24,7 @@
2224
from .compiler import XCachingCompiler
2325
from .debugger import Debugger, _is_debugpy_available
2426
from .eventloops import _use_appnope
27+
from .iostream import OutStream
2528
from .kernelbase import Kernel as KernelBase
2629
from .kernelbase import _accepts_parameters
2730
from .zmqshell import ZMQInteractiveShell
@@ -151,6 +154,14 @@ def __init__(self, **kwargs):
151154

152155
appnope.nope()
153156

157+
self._new_threads_parent_header = {}
158+
self._initialize_thread_hooks()
159+
160+
if hasattr(gc, "callbacks"):
161+
# while `gc.callbacks` exists since Python 3.3, pypy does not
162+
# implement it even as of 3.9.
163+
gc.callbacks.append(self._clean_thread_parent_frames)
164+
154165
help_links = List(
155166
[
156167
{
@@ -341,6 +352,12 @@ def set_sigint_result():
341352
# restore the previous sigint handler
342353
signal.signal(signal.SIGINT, save_sigint)
343354

355+
async def execute_request(self, stream, ident, parent):
    """Handle an execute request, remembering its header for new threads.

    The header extracted from ``parent`` is stored first, so that it can be
    associated with top-level threads started afterwards; the request is
    then delegated to the base-class implementation.
    """
    self._associate_new_top_level_threads_with(extract_header(parent))
    await super().execute_request(stream, ident, parent)
360+
344361
async def do_execute(
345362
self,
346363
code,
@@ -706,6 +723,83 @@ def do_clear(self):
706723
self.shell.reset(False)
707724
return dict(status="ok")
708725

726+
def _associate_new_top_level_threads_with(self, parent_header):
    """Remember ``parent_header`` as the header for new top-level threads.

    The value is stored on ``self._new_threads_parent_header`` and replaces
    whatever header was remembered before.
    """
    self._new_threads_parent_header = parent_header
729+
730+
def _initialize_thread_hooks(self):
731+
"""Store thread hierarchy and thread-parent_header associations."""
732+
stdout = self._stdout
733+
stderr = self._stderr
734+
kernel_thread_ident = threading.get_ident()
735+
kernel = self
736+
_threading_Thread_run = threading.Thread.run
737+
_threading_Thread__init__ = threading.Thread.__init__
738+
739+
def run_closure(self: threading.Thread):
740+
"""Wrap the `threading.Thread.start` to intercept thread identity.
741+
742+
This is needed because there is no "start" hook yet, but there
743+
might be one in the future: https://bugs.python.org/issue14073
744+
745+
This is a no-op if the `self._stdout` and `self._stderr` are not
746+
sub-classes of `OutStream`.
747+
"""
748+
749+
try:
750+
parent = self._ipykernel_parent_thread_ident # type:ignore[attr-defined]
751+
except AttributeError:
752+
return
753+
for stream in [stdout, stderr]:
754+
if isinstance(stream, OutStream):
755+
if parent == kernel_thread_ident:
756+
stream._thread_to_parent_header[
757+
self.ident
758+
] = kernel._new_threads_parent_header
759+
else:
760+
stream._thread_to_parent[self.ident] = parent
761+
_threading_Thread_run(self)
762+
763+
def init_closure(self: threading.Thread, *args, **kwargs):
764+
_threading_Thread__init__(self, *args, **kwargs)
765+
self._ipykernel_parent_thread_ident = threading.get_ident() # type:ignore[attr-defined]
766+
767+
threading.Thread.__init__ = init_closure # type:ignore[method-assign]
768+
threading.Thread.run = run_closure # type:ignore[method-assign]
769+
770+
def _clean_thread_parent_frames(
    self, phase: t.Literal["start", "stop"], info: t.Dict[str, t.Any]
):
    """Drop bookkeeping entries for threads that are no longer running.

    Intended to be registered as a garbage collector callback; work is done
    only for the ``"start"`` phase and ``info`` is not used.

    Threads are enumerated because there is no "exit" hook yet, but there
    might be one in the future: https://bugs.python.org/issue14073

    Streams among ``self._stdout`` and ``self._stderr`` that are not
    instances of `OutStream` are left untouched.
    """
    # Only run before the garbage collector starts
    if phase != "start":
        return

    alive = {thread.ident for thread in threading.enumerate()}
    for stream in (self._stdout, self._stderr):
        if not isinstance(stream, OutStream):
            continue
        for attr in ("_thread_to_parent_header", "_thread_to_parent"):
            registry = getattr(stream, attr)
            # snapshot the stale keys first so the dict is not mutated
            # while it is being iterated
            stale = [ident for ident in registry if ident not in alive]
            for ident in stale:
                # tolerate the entry having vanished in the meantime
                registry.pop(ident, None)
802+
709803

710804
# This exists only for backwards compatibility - use IPythonKernel instead
711805

ipykernel/kernelbase.py

+8
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
from ipykernel.jsonutil import json_clean
6262

6363
from ._version import kernel_protocol_version
64+
from .iostream import OutStream
6465

6566

6667
def _accepts_parameters(meth, param_names):
@@ -272,6 +273,13 @@ def _parent_header(self):
272273
def __init__(self, **kwargs):
273274
"""Initialize the kernel."""
274275
super().__init__(**kwargs)
276+
277+
# Kernel application may swap stdout and stderr to OutStream,
278+
# which is the case in `IPKernelApp.init_io`, hence `sys.stdout`
279+
# can already be different from TextIO at initialization time.
280+
self._stdout: OutStream | t.TextIO = sys.stdout
281+
self._stderr: OutStream | t.TextIO = sys.stderr
282+
275283
# Build dict of handlers for message types
276284
self.shell_handlers = {}
277285
for msg_type in self.msg_types:

0 commit comments

Comments
 (0)