Skip to content

Commit cac11a0

Browse files
davidbrochartCarreaupre-commit-ci[bot]
authored
Use zmq-anyio (#1291)
Co-authored-by: M Bussonnier <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent dcdc56e commit cac11a0

23 files changed

+381
-364
lines changed

docs/api/ipykernel.rst

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -134,12 +134,6 @@ Submodules
134134
:show-inheritance:
135135

136136

137-
.. automodule:: ipykernel.trio_runner
138-
:members:
139-
:undoc-members:
140-
:show-inheritance:
141-
142-
143137
.. automodule:: ipykernel.zmqshell
144138
:members:
145139
:undoc-members:

ipykernel/debugger.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ async def _send_request(self, msg):
242242
self.log.debug("DEBUGPYCLIENT:")
243243
self.log.debug(self.routing_id)
244244
self.log.debug(buf)
245-
await self.debugpy_socket.send_multipart((self.routing_id, buf))
245+
await self.debugpy_socket.asend_multipart((self.routing_id, buf)).wait()
246246

247247
async def _wait_for_response(self):
248248
# Since events are never pushed to the message_queue
@@ -438,7 +438,7 @@ async def start(self):
438438
(self.shell_socket.getsockopt(ROUTING_ID)),
439439
)
440440

441-
msg = await self.shell_socket.recv_multipart()
441+
msg = await self.shell_socket.arecv_multipart().wait()
442442
ident, msg = self.session.feed_identities(msg, copy=True)
443443
try:
444444
msg = self.session.deserialize(msg, content=True, copy=True)

ipykernel/heartbeat.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,17 @@ def _bind_socket(self):
9292
def run(self):
9393
"""Run the heartbeat thread."""
9494
self.name = "Heartbeat"
95-
self.socket = self.context.socket(zmq.ROUTER)
96-
self.socket.linger = 1000
95+
9796
try:
97+
self.socket = self.context.socket(zmq.ROUTER)
98+
self.socket.linger = 1000
9899
self._bind_socket()
99100
except Exception:
100-
self.socket.close()
101-
raise
101+
try:
102+
self.socket.close()
103+
except Exception:
104+
pass
105+
return
102106

103107
while True:
104108
try:

ipykernel/inprocess/ipkernel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class InProcessKernel(IPythonKernel):
5454
_underlying_iopub_socket = Instance(DummySocket, (False,))
5555
iopub_thread: IOPubThread = Instance(IOPubThread) # type:ignore[assignment]
5656

57-
shell_socket = Instance(DummySocket, (True,)) # type:ignore[arg-type]
57+
shell_socket = Instance(DummySocket, (True,))
5858

5959
@default("iopub_thread")
6060
def _default_iopub_thread(self):
@@ -207,7 +207,7 @@ def enable_pylab(self, gui=None, import_all=True, welcome_message=False):
207207
"""Activate pylab support at runtime."""
208208
if not gui:
209209
gui = self.kernel.gui
210-
return super().enable_pylab(gui, import_all, welcome_message)
210+
return super().enable_pylab(gui, import_all, welcome_message) # type: ignore[call-arg]
211211

212212

213213
InteractiveShellABC.register(InProcessInteractiveShell)

ipykernel/inprocess/session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ async def recv( # type: ignore[override]
1212
mode, content, copy have no effect, but are present for superclass compatibility
1313
1414
"""
15-
return await socket.recv_multipart()
15+
return await socket.arecv_multipart().wait()
1616

1717
def send(
1818
self,

ipykernel/inprocess/socket.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,4 +65,8 @@ async def poll(self, timeout=0):
6565
return statistics.current_buffer_used != 0
6666

6767
def close(self):
68-
pass
68+
if self.is_shell:
69+
self.in_send_stream.close()
70+
self.in_receive_stream.close()
71+
self.out_send_stream.close()
72+
self.out_receive_stream.close()

ipykernel/iostream.py

Lines changed: 37 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from typing import Any, Callable
2121

2222
import zmq
23+
import zmq_anyio
2324
from anyio import sleep
2425
from jupyter_client.session import extract_header
2526

@@ -48,7 +49,7 @@ class IOPubThread:
4849
whose IO is always run in a thread.
4950
"""
5051

51-
def __init__(self, socket, pipe=False):
52+
def __init__(self, socket: zmq_anyio.Socket, pipe: bool = False):
5253
"""Create IOPub thread
5354
5455
Parameters
@@ -61,10 +62,7 @@ def __init__(self, socket, pipe=False):
6162
"""
6263
# ensure all of our sockets as sync zmq.Sockets
6364
# don't create async wrappers until we are within the appropriate coroutines
64-
self.socket: zmq.Socket[bytes] | None = zmq.Socket(socket)
65-
if self.socket.context is None:
66-
# bug in pyzmq, shadow socket doesn't always inherit context attribute
67-
self.socket.context = socket.context # type:ignore[unreachable]
65+
self.socket: zmq_anyio.Socket = socket
6866
self._context = socket.context
6967

7068
self.background_socket = BackgroundSocket(self)
@@ -78,7 +76,7 @@ def __init__(self, socket, pipe=False):
7876
self._event_pipe_gc_lock: threading.Lock = threading.Lock()
7977
self._event_pipe_gc_seconds: float = 10
8078
self._setup_event_pipe()
81-
tasks = [self._handle_event, self._run_event_pipe_gc]
79+
tasks = [self._handle_event, self._run_event_pipe_gc, self.socket.start]
8280
if pipe:
8381
tasks.append(self._handle_pipe_msgs)
8482
self.thread = BaseThread(name="IOPub", daemon=True)
@@ -87,7 +85,7 @@ def __init__(self, socket, pipe=False):
8785

8886
def _setup_event_pipe(self):
8987
"""Create the PULL socket listening for events that should fire in this thread."""
90-
self._pipe_in0 = self._context.socket(zmq.PULL, socket_class=zmq.Socket)
88+
self._pipe_in0 = self._context.socket(zmq.PULL)
9189
self._pipe_in0.linger = 0
9290

9391
_uuid = b2a_hex(os.urandom(16)).decode("ascii")
@@ -99,11 +97,11 @@ async def _run_event_pipe_gc(self):
9997
while True:
10098
await sleep(self._event_pipe_gc_seconds)
10199
try:
102-
await self._event_pipe_gc()
100+
self._event_pipe_gc()
103101
except Exception as e:
104102
print(f"Exception in IOPubThread._event_pipe_gc: {e}", file=sys.__stderr__)
105103

106-
async def _event_pipe_gc(self):
104+
def _event_pipe_gc(self):
107105
"""run a single garbage collection on event pipes"""
108106
if not self._event_pipes:
109107
# don't acquire the lock if there's nothing to do
@@ -122,7 +120,7 @@ def _event_pipe(self):
122120
except AttributeError:
123121
# new thread, new event pipe
124122
# create sync base socket
125-
event_pipe = self._context.socket(zmq.PUSH, socket_class=zmq.Socket)
123+
event_pipe = self._context.socket(zmq.PUSH)
126124
event_pipe.linger = 0
127125
event_pipe.connect(self._event_interface)
128126
self._local.event_pipe = event_pipe
@@ -141,30 +139,28 @@ async def _handle_event(self):
141139
Whenever *an* event arrives on the event stream,
142140
*all* waiting events are processed in order.
143141
"""
144-
# create async wrapper within coroutine
145-
pipe_in = zmq.asyncio.Socket(self._pipe_in0)
146-
try:
147-
while True:
148-
await pipe_in.recv()
149-
# freeze event count so new writes don't extend the queue
150-
# while we are processing
151-
n_events = len(self._events)
152-
for _ in range(n_events):
153-
event_f = self._events.popleft()
154-
event_f()
155-
except Exception:
156-
if self.thread.stopped.is_set():
157-
return
158-
raise
142+
pipe_in = zmq_anyio.Socket(self._pipe_in0)
143+
async with pipe_in:
144+
try:
145+
while True:
146+
await pipe_in.arecv().wait()
147+
# freeze event count so new writes don't extend the queue
148+
# while we are processing
149+
n_events = len(self._events)
150+
for _ in range(n_events):
151+
event_f = self._events.popleft()
152+
event_f()
153+
except Exception:
154+
if self.thread.stopped.is_set():
155+
return
156+
raise
159157

160158
def _setup_pipe_in(self):
161159
"""setup listening pipe for IOPub from forked subprocesses"""
162-
ctx = self._context
163-
164160
# use UUID to authenticate pipe messages
165161
self._pipe_uuid = os.urandom(16)
166162

167-
self._pipe_in1 = ctx.socket(zmq.PULL, socket_class=zmq.Socket)
163+
self._pipe_in1 = zmq_anyio.Socket(self._context.socket(zmq.PULL))
168164
self._pipe_in1.linger = 0
169165

170166
try:
@@ -181,19 +177,18 @@ def _setup_pipe_in(self):
181177

182178
async def _handle_pipe_msgs(self):
183179
"""handle pipe messages from a subprocess"""
184-
# create async wrapper within coroutine
185-
self._async_pipe_in1 = zmq.asyncio.Socket(self._pipe_in1)
186-
try:
187-
while True:
188-
await self._handle_pipe_msg()
189-
except Exception:
190-
if self.thread.stopped.is_set():
191-
return
192-
raise
180+
async with self._pipe_in1:
181+
try:
182+
while True:
183+
await self._handle_pipe_msg()
184+
except Exception:
185+
if self.thread.stopped.is_set():
186+
return
187+
raise
193188

194189
async def _handle_pipe_msg(self, msg=None):
195190
"""handle a pipe message from a subprocess"""
196-
msg = msg or await self._async_pipe_in1.recv_multipart()
191+
msg = msg or await self._pipe_in1.arecv_multipart().wait()
197192
if not self._pipe_flag or not self._is_main_process():
198193
return
199194
if msg[0] != self._pipe_uuid:
@@ -246,7 +241,10 @@ def close(self):
246241
"""Close the IOPub thread."""
247242
if self.closed:
248243
return
249-
self._pipe_in0.close()
244+
try:
245+
self._pipe_in0.close()
246+
except Exception:
247+
pass
250248
if self._pipe_flag:
251249
self._pipe_in1.close()
252250
if self.socket is not None:

ipykernel/ipkernel.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from dataclasses import dataclass
1313

1414
import comm
15-
import zmq.asyncio
15+
import zmq_anyio
1616
from anyio import TASK_STATUS_IGNORED, create_task_group, to_thread
1717
from anyio.abc import TaskStatus
1818
from IPython.core import release
@@ -93,7 +93,7 @@ class IPythonKernel(KernelBase):
9393
help="Set this flag to False to deactivate the use of experimental IPython completion APIs.",
9494
).tag(config=True)
9595

96-
debugpy_socket = Instance(zmq.asyncio.Socket, allow_none=True)
96+
debugpy_socket = Instance(zmq_anyio.Socket, allow_none=True)
9797

9898
user_module = Any()
9999

@@ -229,7 +229,8 @@ def __init__(self, **kwargs):
229229
}
230230

231231
async def process_debugpy(self):
232-
async with create_task_group() as tg:
232+
assert self.debugpy_socket is not None
233+
async with self.debug_shell_socket, self.debugpy_socket, create_task_group() as tg:
233234
tg.start_soon(self.receive_debugpy_messages)
234235
tg.start_soon(self.poll_stopped_queue)
235236
await to_thread.run_sync(self.debugpy_stop.wait)
@@ -252,7 +253,7 @@ async def receive_debugpy_message(self, msg=None):
252253

253254
if msg is None:
254255
assert self.debugpy_socket is not None
255-
msg = await self.debugpy_socket.recv_multipart()
256+
msg = await self.debugpy_socket.arecv_multipart().wait()
256257
# The first frame is the socket id, we can drop it
257258
frame = msg[1].decode("utf-8")
258259
self.log.debug("Debugpy received: %s", frame)

0 commit comments

Comments
 (0)