Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use zmq-anyio #1291

Open
wants to merge 26 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,30 +32,44 @@ jobs:
python-version: "3.11"
- os: ubuntu-latest
python-version: "3.12"
- os: windows-latest
python-version: "3.10"
- os: windows-latest
python-version: "3.11"
- os: windows-latest
python-version: "3.12"
steps:
- name: Checkout
uses: actions/checkout@v4

- name: Base Setup
uses: jupyterlab/maintainer-tools/.github/actions/base-setup@v1
- uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Install hatch
run: |
python --version
python -m pip install hatch

- name: Run the tests
timeout-minutes: 15
if: ${{ !startsWith( matrix.python-version, 'pypy' ) && !startsWith(matrix.os, 'windows') }}
run: |
hatch run cov:test --cov-fail-under 50 || hatch run test:test --lf
PYTHONTRACEMALLOC=20 hatch run cov:test --cov-fail-under 50 || PYTHONTRACEMALLOC=20 hatch run test:test --lf

- name: Run the tests on pypy
timeout-minutes: 15
if: ${{ startsWith( matrix.python-version, 'pypy' ) }}
run: |
hatch run test:nowarn || hatch run test:nowarn --lf
PYTHONTRACEMALLOC=20 hatch run test:nowarn || PYTHONTRACEMALLOC=20 hatch run test:nowarn --lf

- name: Run the tests on Windows
timeout-minutes: 15
if: ${{ startsWith(matrix.os, 'windows') }}
run: |
hatch run cov:nowarn || hatch run test:nowarn --lf
hatch run test:pip list
hatch run test:python --version
PYTHONTRACEMALLOC=20 hatch run test:pytest -v
Carreau marked this conversation as resolved.
Show resolved Hide resolved

- name: Check Launcher
run: |
Expand Down Expand Up @@ -138,7 +152,7 @@ jobs:

- name: Run the tests
timeout-minutes: 15
run: pytest -W default -vv || pytest --vv -W default --lf
run: PYTHONTRACEMALLOC=20 pytest -W default -vv || PYTHONTRACEMALLOC=20 pytest --vv -W default --lf

test_miniumum_versions:
name: Test Minimum Versions
Expand Down
4 changes: 2 additions & 2 deletions ipykernel/debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ async def _send_request(self, msg):
self.log.debug("DEBUGPYCLIENT:")
self.log.debug(self.routing_id)
self.log.debug(buf)
await self.debugpy_socket.send_multipart((self.routing_id, buf))
await self.debugpy_socket.asend_multipart((self.routing_id, buf))

async def _wait_for_response(self):
# Since events are never pushed to the message_queue
Expand Down Expand Up @@ -437,7 +437,7 @@ async def start(self):
(self.shell_socket.getsockopt(ROUTING_ID)),
)

msg = await self.shell_socket.recv_multipart()
msg = await self.shell_socket.arecv_multipart()
ident, msg = self.session.feed_identities(msg, copy=True)
try:
msg = self.session.deserialize(msg, content=True, copy=True)
Expand Down
2 changes: 1 addition & 1 deletion ipykernel/inprocess/ipkernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class InProcessKernel(IPythonKernel):
_underlying_iopub_socket = Instance(DummySocket, (False,))
iopub_thread: IOPubThread = Instance(IOPubThread) # type:ignore[assignment]

shell_socket = Instance(DummySocket, (True,)) # type:ignore[arg-type]
shell_socket = Instance(DummySocket, (True,))

@default("iopub_thread")
def _default_iopub_thread(self):
Expand Down
2 changes: 1 addition & 1 deletion ipykernel/inprocess/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

class Session(_Session):
async def recv(self, socket, copy=True):
return await socket.recv_multipart()
return await socket.arecv_multipart()

def send(
self,
Expand Down
6 changes: 5 additions & 1 deletion ipykernel/inprocess/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,8 @@ async def poll(self, timeout=0):
return statistics.current_buffer_used != 0

def close(self):
pass
if self.is_shell:
self.in_send_stream.close()
self.in_receive_stream.close()
self.out_send_stream.close()
self.out_receive_stream.close()
108 changes: 38 additions & 70 deletions ipykernel/iostream.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@
from binascii import b2a_hex
from collections import defaultdict, deque
from io import StringIO, TextIOBase
from threading import Event, Thread, local
from threading import local
from typing import Any, Callable

import zmq
from anyio import create_task_group, run, sleep, to_thread
import zmq_anyio
from anyio import sleep
from jupyter_client.session import extract_header

from .thread import BaseThread

# -----------------------------------------------------------------------------
# Globals
# -----------------------------------------------------------------------------
Expand All @@ -37,38 +40,6 @@
# -----------------------------------------------------------------------------


class _IOPubThread(Thread):
"""A thread for a IOPub."""

def __init__(self, tasks, **kwargs):
"""Initialize the thread."""
super().__init__(name="IOPub", **kwargs)
self._tasks = tasks
self.pydev_do_not_trace = True
self.is_pydev_daemon_thread = True
self.daemon = True
self.__stop = Event()

def run(self):
"""Run the thread."""
self.name = "IOPub"
run(self._main)

async def _main(self):
async with create_task_group() as tg:
for task in self._tasks:
tg.start_soon(task)
await to_thread.run_sync(self.__stop.wait)
tg.cancel_scope.cancel()

def stop(self):
"""Stop the thread.

This method is threadsafe.
"""
self.__stop.set()


class IOPubThread:
"""An object for sending IOPub messages in a background thread

Expand All @@ -78,7 +49,7 @@ class IOPubThread:
whose IO is always run in a thread.
"""

def __init__(self, socket, pipe=False):
def __init__(self, socket: zmq_anyio.Socket, pipe=False):
"""Create IOPub thread

Parameters
Expand All @@ -91,10 +62,7 @@ def __init__(self, socket, pipe=False):
"""
# ensure all of our sockets as sync zmq.Sockets
# don't create async wrappers until we are within the appropriate coroutines
self.socket: zmq.Socket[bytes] | None = zmq.Socket(socket)
if self.socket.context is None:
# bug in pyzmq, shadow socket doesn't always inherit context attribute
self.socket.context = socket.context # type:ignore[unreachable]
self.socket: zmq_anyio.Socket = socket
self._context = socket.context

self.background_socket = BackgroundSocket(self)
Expand All @@ -108,14 +76,16 @@ def __init__(self, socket, pipe=False):
self._event_pipe_gc_lock: threading.Lock = threading.Lock()
self._event_pipe_gc_seconds: float = 10
self._setup_event_pipe()
tasks = [self._handle_event, self._run_event_pipe_gc]
tasks = [self._handle_event, self._run_event_pipe_gc, self.socket.start]
if pipe:
tasks.append(self._handle_pipe_msgs)
self.thread = _IOPubThread(tasks)
self.thread = BaseThread(name="IOPub", daemon=True)
for task in tasks:
self.thread.start_soon(task)

def _setup_event_pipe(self):
"""Create the PULL socket listening for events that should fire in this thread."""
self._pipe_in0 = self._context.socket(zmq.PULL, socket_class=zmq.Socket)
self._pipe_in0 = self._context.socket(zmq.PULL)
self._pipe_in0.linger = 0

_uuid = b2a_hex(os.urandom(16)).decode("ascii")
Expand Down Expand Up @@ -150,7 +120,7 @@ def _event_pipe(self):
except AttributeError:
# new thread, new event pipe
# create sync base socket
event_pipe = self._context.socket(zmq.PUSH, socket_class=zmq.Socket)
event_pipe = self._context.socket(zmq.PUSH)
event_pipe.linger = 0
event_pipe.connect(self._event_interface)
self._local.event_pipe = event_pipe
Expand All @@ -169,30 +139,28 @@ async def _handle_event(self):
Whenever *an* event arrives on the event stream,
*all* waiting events are processed in order.
"""
# create async wrapper within coroutine
pipe_in = zmq.asyncio.Socket(self._pipe_in0)
try:
while True:
await pipe_in.recv()
# freeze event count so new writes don't extend the queue
# while we are processing
n_events = len(self._events)
for _ in range(n_events):
event_f = self._events.popleft()
event_f()
except Exception:
if self.thread.__stop.is_set():
return
raise
pipe_in = zmq_anyio.Socket(self._pipe_in0)
async with pipe_in:
try:
while True:
await pipe_in.arecv()
# freeze event count so new writes don't extend the queue
# while we are processing
n_events = len(self._events)
for _ in range(n_events):
event_f = self._events.popleft()
event_f()
except Exception:
if self.thread.stopped.is_set():
return
raise

def _setup_pipe_in(self):
"""setup listening pipe for IOPub from forked subprocesses"""
ctx = self._context

# use UUID to authenticate pipe messages
self._pipe_uuid = os.urandom(16)

self._pipe_in1 = ctx.socket(zmq.PULL, socket_class=zmq.Socket)
self._pipe_in1 = zmq_anyio.Socket(self._context.socket(zmq.PULL))
self._pipe_in1.linger = 0

try:
Expand All @@ -210,18 +178,18 @@ def _setup_pipe_in(self):
async def _handle_pipe_msgs(self):
"""handle pipe messages from a subprocess"""
# create async wrapper within coroutine
self._async_pipe_in1 = zmq.asyncio.Socket(self._pipe_in1)
try:
while True:
await self._handle_pipe_msg()
except Exception:
if self.thread.__stop.is_set():
return
raise
async with self._pipe_in1:
try:
while True:
await self._handle_pipe_msg()
except Exception:
if self.thread.stopped.is_set():
return
raise

async def _handle_pipe_msg(self, msg=None):
"""handle a pipe message from a subprocess"""
msg = msg or await self._async_pipe_in1.recv_multipart()
msg = msg or await self._pipe_in1.arecv_multipart()
if not self._pipe_flag or not self._is_main_process():
return
if msg[0] != self._pipe_uuid:
Expand Down
9 changes: 5 additions & 4 deletions ipykernel/ipkernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from dataclasses import dataclass

import comm
import zmq.asyncio
import zmq_anyio
from anyio import TASK_STATUS_IGNORED, create_task_group, to_thread
from anyio.abc import TaskStatus
from IPython.core import release
Expand Down Expand Up @@ -76,7 +76,7 @@ class IPythonKernel(KernelBase):
help="Set this flag to False to deactivate the use of experimental IPython completion APIs.",
).tag(config=True)

debugpy_socket = Instance(zmq.asyncio.Socket, allow_none=True)
debugpy_socket = Instance(zmq_anyio.Socket, allow_none=True)

user_module = Any()

Expand Down Expand Up @@ -212,7 +212,8 @@ def __init__(self, **kwargs):
}

async def process_debugpy(self):
async with create_task_group() as tg:
assert self.debugpy_socket is not None
async with self.debug_shell_socket, self.debugpy_socket, create_task_group() as tg:
tg.start_soon(self.receive_debugpy_messages)
tg.start_soon(self.poll_stopped_queue)
await to_thread.run_sync(self.debugpy_stop.wait)
Expand All @@ -235,7 +236,7 @@ async def receive_debugpy_message(self, msg=None):

if msg is None:
assert self.debugpy_socket is not None
msg = await self.debugpy_socket.recv_multipart()
msg = await self.debugpy_socket.arecv_multipart()
# The first frame is the socket id, we can drop it
frame = msg[1].decode("utf-8")
self.log.debug("Debugpy received: %s", frame)
Expand Down
Loading
Loading