20
20
from typing import Any , Callable
21
21
22
22
import zmq
23
+ import zmq_anyio
23
24
from anyio import sleep
24
25
from jupyter_client .session import extract_header
25
26
@@ -48,7 +49,7 @@ class IOPubThread:
48
49
whose IO is always run in a thread.
49
50
"""
50
51
51
- def __init__ (self , socket , pipe = False ):
52
+ def __init__ (self , socket : zmq_anyio . Socket , pipe : bool = False ):
52
53
"""Create IOPub thread
53
54
54
55
Parameters
@@ -61,10 +62,7 @@ def __init__(self, socket, pipe=False):
61
62
"""
62
63
# ensure all of our sockets as sync zmq.Sockets
63
64
# don't create async wrappers until we are within the appropriate coroutines
64
- self .socket : zmq .Socket [bytes ] | None = zmq .Socket (socket )
65
- if self .socket .context is None :
66
- # bug in pyzmq, shadow socket doesn't always inherit context attribute
67
- self .socket .context = socket .context # type:ignore[unreachable]
65
+ self .socket : zmq_anyio .Socket = socket
68
66
self ._context = socket .context
69
67
70
68
self .background_socket = BackgroundSocket (self )
@@ -78,7 +76,7 @@ def __init__(self, socket, pipe=False):
78
76
self ._event_pipe_gc_lock : threading .Lock = threading .Lock ()
79
77
self ._event_pipe_gc_seconds : float = 10
80
78
self ._setup_event_pipe ()
81
- tasks = [self ._handle_event , self ._run_event_pipe_gc ]
79
+ tasks = [self ._handle_event , self ._run_event_pipe_gc , self . socket . start ]
82
80
if pipe :
83
81
tasks .append (self ._handle_pipe_msgs )
84
82
self .thread = BaseThread (name = "IOPub" , daemon = True )
@@ -87,7 +85,7 @@ def __init__(self, socket, pipe=False):
87
85
88
86
def _setup_event_pipe (self ):
89
87
"""Create the PULL socket listening for events that should fire in this thread."""
90
- self ._pipe_in0 = self ._context .socket (zmq .PULL , socket_class = zmq . Socket )
88
+ self ._pipe_in0 = self ._context .socket (zmq .PULL )
91
89
self ._pipe_in0 .linger = 0
92
90
93
91
_uuid = b2a_hex (os .urandom (16 )).decode ("ascii" )
@@ -99,11 +97,11 @@ async def _run_event_pipe_gc(self):
99
97
while True :
100
98
await sleep (self ._event_pipe_gc_seconds )
101
99
try :
102
- await self ._event_pipe_gc ()
100
+ self ._event_pipe_gc ()
103
101
except Exception as e :
104
102
print (f"Exception in IOPubThread._event_pipe_gc: { e } " , file = sys .__stderr__ )
105
103
106
- async def _event_pipe_gc (self ):
104
+ def _event_pipe_gc (self ):
107
105
"""run a single garbage collection on event pipes"""
108
106
if not self ._event_pipes :
109
107
# don't acquire the lock if there's nothing to do
@@ -122,7 +120,7 @@ def _event_pipe(self):
122
120
except AttributeError :
123
121
# new thread, new event pipe
124
122
# create sync base socket
125
- event_pipe = self ._context .socket (zmq .PUSH , socket_class = zmq . Socket )
123
+ event_pipe = self ._context .socket (zmq .PUSH )
126
124
event_pipe .linger = 0
127
125
event_pipe .connect (self ._event_interface )
128
126
self ._local .event_pipe = event_pipe
@@ -141,30 +139,28 @@ async def _handle_event(self):
141
139
Whenever *an* event arrives on the event stream,
142
140
*all* waiting events are processed in order.
143
141
"""
144
- # create async wrapper within coroutine
145
- pipe_in = zmq . asyncio . Socket ( self . _pipe_in0 )
146
- try :
147
- while True :
148
- await pipe_in .recv ()
149
- # freeze event count so new writes don't extend the queue
150
- # while we are processing
151
- n_events = len (self ._events )
152
- for _ in range (n_events ):
153
- event_f = self ._events .popleft ()
154
- event_f ()
155
- except Exception :
156
- if self .thread .stopped .is_set ():
157
- return
158
- raise
142
+ pipe_in = zmq_anyio . Socket ( self . _pipe_in0 )
143
+ async with pipe_in :
144
+ try :
145
+ while True :
146
+ await pipe_in .arecv (). wait ()
147
+ # freeze event count so new writes don't extend the queue
148
+ # while we are processing
149
+ n_events = len (self ._events )
150
+ for _ in range (n_events ):
151
+ event_f = self ._events .popleft ()
152
+ event_f ()
153
+ except Exception :
154
+ if self .thread .stopped .is_set ():
155
+ return
156
+ raise
159
157
160
158
def _setup_pipe_in (self ):
161
159
"""setup listening pipe for IOPub from forked subprocesses"""
162
- ctx = self ._context
163
-
164
160
# use UUID to authenticate pipe messages
165
161
self ._pipe_uuid = os .urandom (16 )
166
162
167
- self ._pipe_in1 = ctx . socket (zmq .PULL , socket_class = zmq . Socket )
163
+ self ._pipe_in1 = zmq_anyio . Socket ( self . _context . socket (zmq .PULL ) )
168
164
self ._pipe_in1 .linger = 0
169
165
170
166
try :
@@ -181,19 +177,18 @@ def _setup_pipe_in(self):
181
177
182
178
async def _handle_pipe_msgs (self ):
183
179
"""handle pipe messages from a subprocess"""
184
- # create async wrapper within coroutine
185
- self ._async_pipe_in1 = zmq .asyncio .Socket (self ._pipe_in1 )
186
- try :
187
- while True :
188
- await self ._handle_pipe_msg ()
189
- except Exception :
190
- if self .thread .stopped .is_set ():
191
- return
192
- raise
180
+ async with self ._pipe_in1 :
181
+ try :
182
+ while True :
183
+ await self ._handle_pipe_msg ()
184
+ except Exception :
185
+ if self .thread .stopped .is_set ():
186
+ return
187
+ raise
193
188
194
189
async def _handle_pipe_msg (self , msg = None ):
195
190
"""handle a pipe message from a subprocess"""
196
- msg = msg or await self ._async_pipe_in1 . recv_multipart ()
191
+ msg = msg or await self ._pipe_in1 . arecv_multipart (). wait ()
197
192
if not self ._pipe_flag or not self ._is_main_process ():
198
193
return
199
194
if msg [0 ] != self ._pipe_uuid :
@@ -246,7 +241,10 @@ def close(self):
246
241
"""Close the IOPub thread."""
247
242
if self .closed :
248
243
return
249
- self ._pipe_in0 .close ()
244
+ try :
245
+ self ._pipe_in0 .close ()
246
+ except Exception :
247
+ pass
250
248
if self ._pipe_flag :
251
249
self ._pipe_in1 .close ()
252
250
if self .socket is not None :
0 commit comments