Skip to content

Commit ad2d7db

Browse files
feat: support Output widget by implementing a custom shell
1 parent 1a5d94c commit ad2d7db

File tree

4 files changed

+229
-1
lines changed

4 files changed

+229
-1
lines changed

solara/server/kernel.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from zmq.eventloop.zmqstream import ZMQStream
1717

1818
import solara
19+
from solara.server.shell import SolaraInteractiveShell
1920

2021
from . import settings, websocket
2122

@@ -263,6 +264,9 @@ def __init__(self):
263264
comm_msg_types = ["comm_open", "comm_msg", "comm_close"]
264265
for msg_type in comm_msg_types:
265266
self.shell_handlers[msg_type] = getattr(self.comm_manager, msg_type)
267+
self.shell = SolaraInteractiveShell()
268+
self.shell.display_pub.session = self.session
269+
self.shell.display_pub.pub_socket = self.iopub_socket
266270

267271
async def _flush_control_queue(self):
268272
pass
@@ -275,3 +279,11 @@ def pre_handler_hook(self, *args):
275279

276280
def post_handler_hook(self, *args):
277281
pass
282+
283+
def set_parent(self, ident, parent, channel="shell"):
284+
"""Overridden from parent to tell the display hook and output streams
285+
about the parent message.
286+
"""
287+
super().set_parent(ident, parent, channel)
288+
if channel == "shell":
289+
self.shell.set_parent(parent)

solara/server/patch.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import ipykernel.kernelbase
1111
import IPython.display
1212
import ipywidgets
13+
from IPython.core.interactiveshell import InteractiveShell
1314

1415
from . import app, reload, settings
1516
from .utils import pdb_guard
@@ -28,7 +29,7 @@ class FakeIPython:
2829
def __init__(self, context: app.AppContext):
2930
self.context = context
3031
self.kernel = context.kernel
31-
self.display_pub = mock.MagicMock()
32+
self.display_pub = self.kernel.shell.display_pub
3233
# needed for the pyplot interface of matplotlib
3334
# (although we don't really support it)
3435
self.events = mock.MagicMock()
@@ -68,6 +69,11 @@ def kernel_instance_dispatch(cls, *args, **kwargs):
6869
return context.kernel
6970

7071

72+
def interactive_shell_instance_dispatch(cls, *args, **kwargs):
73+
context = app.get_current_context()
74+
return context.kernel.shell
75+
76+
7177
def kernel_initialized_dispatch(cls):
7278
try:
7379
app.get_current_context()
@@ -261,13 +267,17 @@ def patch():
261267
# variable has type "Callable[[VarArg(Any), KwArg(Any)], Any]")
262268
# not sure why we cannot reproduce that locally
263269
ipykernel.kernelbase.Kernel.instance = classmethod(kernel_instance_dispatch) # type: ignore
270+
InteractiveShell.instance = classmethod(interactive_shell_instance_dispatch) # type: ignore
264271
# on CI we get a mypy error:
265272
# solara/server/patch.py:211: error: Cannot assign to a method
266273
# solara/server/patch.py:211: error: Incompatible types in assignment (expression has type "classmethod[Any]", variable has type "Callable[[], Any]")
267274
# not sure why we cannot reproduce that locally
268275
ipykernel.kernelbase.Kernel.initialized = classmethod(kernel_initialized_dispatch) # type: ignore
269276
ipywidgets.widgets.widget.get_ipython = get_ipython
277+
278+
# TODO: find a way to actually monkeypatch get_ipython
270279
IPython.get_ipython = get_ipython
280+
ipywidgets.widgets.widget_output.get_ipython = get_ipython
271281

272282
def model_id_debug(self: ipywidgets.widgets.widget.Widget):
273283
from ipyvue.ForceLoad import force_load_instance

solara/server/shell.py

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
import sys
2+
from threading import local
3+
from unittest.mock import Mock
4+
5+
from IPython.core.displaypub import DisplayPublisher
6+
from IPython.core.interactiveshell import InteractiveShell, InteractiveShellABC
7+
from jupyter_client.session import Session, extract_header
8+
from traitlets import Any, CBytes, Dict, Instance, Type, default
9+
10+
11+
def encode_images(obj):
12+
# no-op in ipykernel
13+
return obj
14+
15+
16+
def json_clean(obj):
17+
# no-op in ipykernel
18+
return obj
19+
20+
21+
# based on the zmq display publisher from ipykernel
22+
# ideally this goes out of ipykernel
23+
class SolaraDisplayPublisher(DisplayPublisher):
24+
"""A display publisher that publishes data using a ZeroMQ PUB socket."""
25+
26+
session = Instance(Session, allow_none=True)
27+
pub_socket = Any(allow_none=True)
28+
parent_header = Dict({})
29+
topic = CBytes(b"display_data")
30+
31+
_thread_local = Any()
32+
33+
def set_parent(self, parent):
34+
"""Set the parent for outbound messages."""
35+
self.parent_header = extract_header(parent)
36+
37+
def _flush_streams(self):
38+
"""flush IO Streams prior to display"""
39+
sys.stdout.flush()
40+
sys.stderr.flush()
41+
42+
@default("_thread_local")
43+
def _default_thread_local(self):
44+
"""Initialize our thread local storage"""
45+
return local()
46+
47+
@property
48+
def _hooks(self):
49+
if not hasattr(self._thread_local, "hooks"):
50+
# create new list for a new thread
51+
self._thread_local.hooks = []
52+
return self._thread_local.hooks
53+
54+
def publish(
55+
self,
56+
data,
57+
metadata=None,
58+
transient=None,
59+
update=False,
60+
):
61+
"""Publish a display-data message
62+
63+
Parameters
64+
----------
65+
data : dict
66+
A mime-bundle dict, keyed by mime-type.
67+
metadata : dict, optional
68+
Metadata associated with the data.
69+
transient : dict, optional, keyword-only
70+
Transient data that may only be relevant during a live display,
71+
such as display_id.
72+
Transient data should not be persisted to documents.
73+
update : bool, optional, keyword-only
74+
If True, send an update_display_data message instead of display_data.
75+
"""
76+
self._flush_streams()
77+
if metadata is None:
78+
metadata = {}
79+
if transient is None:
80+
transient = {}
81+
self._validate_data(data, metadata)
82+
content = {}
83+
content["data"] = encode_images(data)
84+
content["metadata"] = metadata
85+
content["transient"] = transient
86+
87+
msg_type = "update_display_data" if update else "display_data"
88+
89+
# Use 2-stage process to send a message,
90+
# in order to put it through the transform
91+
# hooks before potentially sending.
92+
msg = self.session.msg(msg_type, json_clean(content), parent=self.parent_header)
93+
94+
# Each transform either returns a new
95+
# message or None. If None is returned,
96+
# the message has been 'used' and we return.
97+
for hook in self._hooks:
98+
msg = hook(msg)
99+
if msg is None:
100+
return
101+
102+
self.session.send(
103+
self.pub_socket,
104+
msg,
105+
ident=self.topic,
106+
)
107+
108+
def clear_output(self, wait=False):
109+
"""Clear output associated with the current execution (cell).
110+
111+
Parameters
112+
----------
113+
wait : bool (default: False)
114+
If True, the output will not be cleared immediately,
115+
instead waiting for the next display before clearing.
116+
This reduces bounce during repeated clear & display loops.
117+
118+
"""
119+
content = dict(wait=wait)
120+
self._flush_streams()
121+
self.session.send(
122+
self.pub_socket,
123+
"clear_output",
124+
content,
125+
parent=self.parent_header,
126+
ident=self.topic,
127+
)
128+
129+
def register_hook(self, hook):
130+
"""
131+
Registers a hook with the thread-local storage.
132+
133+
Parameters
134+
----------
135+
hook : Any callable object
136+
137+
Returns
138+
-------
139+
Either a publishable message, or `None`.
140+
The DisplayHook objects must return a message from
141+
the __call__ method if they still require the
142+
`session.send` method to be called after transformation.
143+
Returning `None` will halt that execution path, and
144+
session.send will not be called.
145+
"""
146+
self._hooks.append(hook)
147+
148+
def unregister_hook(self, hook):
149+
"""
150+
Un-registers a hook with the thread-local storage.
151+
152+
Parameters
153+
----------
154+
hook : Any callable object which has previously been
155+
registered as a hook.
156+
157+
Returns
158+
-------
159+
bool - `True` if the hook was removed, `False` if it wasn't
160+
found.
161+
"""
162+
try:
163+
self._hooks.remove(hook)
164+
return True
165+
except ValueError:
166+
return False
167+
168+
169+
class SolaraInteractiveShell(InteractiveShell):
170+
display_pub_class = Type(SolaraDisplayPublisher)
171+
history_manager = Any()
172+
173+
def set_parent(self, parent):
174+
"""Tell the children about the parent message."""
175+
self.display_pub.set_parent(parent)
176+
177+
def init_history(self):
178+
self.history_manager = Mock()
179+
180+
181+
InteractiveShellABC.register(SolaraInteractiveShell)

tests/unit/shell_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from unittest.mock import Mock
2+
3+
import IPython.display
4+
5+
from solara.server import app, kernel
6+
7+
8+
def test_shell(no_app_context):
9+
ws1 = Mock()
10+
ws2 = Mock()
11+
kernel1 = kernel.Kernel()
12+
kernel2 = kernel.Kernel()
13+
kernel1.session.websockets.add(ws1)
14+
kernel2.session.websockets.add(ws2)
15+
context1 = app.AppContext(id="1", kernel=kernel1)
16+
context2 = app.AppContext(id="2", kernel=kernel2)
17+
18+
with context1:
19+
IPython.display.display("test1")
20+
assert ws1.send.call_count == 1
21+
assert ws2.send.call_count == 0
22+
with context2:
23+
IPython.display.display("test1")
24+
assert ws1.send.call_count == 1
25+
assert ws2.send.call_count == 1

0 commit comments

Comments
 (0)