22import inspect
33import json
44import threading
5- from typing import Iterable
5+ from typing import Callable , Iterable
66import uuid
77
88import asyncio
@@ -59,8 +59,11 @@ class TaskQueue:
5959 def __init__ (self ):
6060 self ._tasks = {}
6161 self ._watch_thread = None
62+ self ._queue_empty_callback = None
63+ self ._queue_remaining_callbacks = [self ._when_empty_callback ]
6264 self ._watch_display_node = None
6365 self ._watch_display_task = None
66+ self .queue_remaining = 0
6467
6568 async def _get_history (self , prompt_id : str ) -> dict | None :
6669 async with aiohttp .ClientSession () as session :
@@ -77,48 +80,57 @@ async def _watch(self):
7780 try :
7881 async with aiohttp .ClientSession () as session :
7982 async with session .ws_connect (f'{ api .endpoint } ws' , params = {'clientId' : _client_id }) as ws :
80- queue_remaining = 0
83+ self .queue_remaining = 0
84+ executing = False
8185 async for msg in ws :
8286 # print(msg.type)
8387 if msg .type == aiohttp .WSMsgType .TEXT :
8488 msg = msg .json ()
8589 # print(msg)
8690 if msg ['type' ] == 'status' :
8791 data = msg ['data' ]
88- new_queue_remaining = data ['status' ]['exec_info' ]['queue_remaining' ]
89- if queue_remaining != new_queue_remaining :
90- queue_remaining = new_queue_remaining
91- print (f'Queue remaining: { queue_remaining } ' )
92+ queue_remaining = data ['status' ]['exec_info' ]['queue_remaining' ]
93+ if self .queue_remaining != queue_remaining :
94+ self .queue_remaining = queue_remaining
95+ if not executing :
96+ for callback in self ._queue_remaining_callbacks :
97+ callback (self .queue_remaining )
98+ print (f'Queue remaining: { self .queue_remaining } ' )
9299 elif msg ['type' ] == 'execution_start' :
93- pass
100+ executing = True
94101 elif msg ['type' ] == 'executing' :
95102 data = msg ['data' ]
96103 if data ['node' ] is None :
97104 prompt_id = data ['prompt_id' ]
98105 task : Task = self ._tasks .get (prompt_id )
99106 if task is not None :
100- history = await self ._get_history (prompt_id )
101- outputs = {}
102- if history is not None :
103- outputs = history ['outputs' ]
104- task ._set_result_threadsafe (None , outputs , self ._watch_display_task )
105- if self ._watch_display_task :
106- print (f'Queue remaining: { queue_remaining } ' )
107107 del self ._tasks [prompt_id ]
108-
109- if new_queue_remaining == 0 :
108+
109+ if self . queue_remaining == 0 :
110110 for task in self ._tasks .values ():
111111 print (f'ComfyScript: The queue is empty but { task } has not been executed' )
112112 task ._set_result_threadsafe (None , {})
113113 self ._tasks .clear ()
114+
115+ for callback in self ._queue_remaining_callbacks :
116+ callback (self .queue_remaining )
117+ executing = False
118+
119+ history = await self ._get_history (prompt_id )
120+ outputs = {}
121+ if history is not None :
122+ outputs = history ['outputs' ]
123+ task ._set_result_threadsafe (None , outputs , self ._watch_display_task )
124+ if self ._watch_display_task :
125+ print (f'Queue remaining: { self .queue_remaining } ' )
114126 elif msg ['type' ] == 'executed' :
115127 data = msg ['data' ]
116128 prompt_id = data ['prompt_id' ]
117129 task : Task = self ._tasks .get (prompt_id )
118130 if task is not None :
119131 task ._set_result_threadsafe (data ['node' ], data ['output' ], self ._watch_display_node )
120132 if self ._watch_display_node :
121- print (f'Queue remaining: { queue_remaining } ' )
133+ print (f'Queue remaining: { self . queue_remaining } ' )
122134 elif msg ['type' ] == 'progress' :
123135 data = msg ['data' ]
124136 _print_progress (data ['value' ], data ['max' ])
@@ -165,6 +177,14 @@ def start_watch(self, display_node: bool = True, display_task: bool = True):
165177 self ._watch_thread = threading .Thread (target = asyncio .run , args = (queue ._watch (),), daemon = True )
166178 self ._watch_thread .start ()
167179
180+ def add_queue_remaining_callback (self , callback : Callable [[int ], None ]):
181+ self .remove_queue_remaining_callback (callback )
182+ self ._queue_remaining_callbacks .append (callback )
183+
184+ def remove_queue_remaining_callback (self , callback : Callable [[int ], None ]):
185+ if callback in self ._queue_remaining_callbacks :
186+ self ._queue_remaining_callbacks .remove (callback )
187+
168188 def watch_display (self , display_node : bool = True , display_task : bool = True ):
169189 '''
170190 - `display_node`: When an output node is finished, display its result.
@@ -217,6 +237,35 @@ def __iadd__(self, workflow: data.NodeOutput | Iterable[data.NodeOutput] | Workf
217237 source = '' .join (inspect .findsource (outer )[0 ])
218238 return self .put (workflow , source )
219239
240+ def _when_empty_callback (self , queue_remaining : int ):
241+ if queue_remaining == 0 and self ._queue_empty_callback is not None :
242+ self ._queue_empty_callback ()
243+
244+ def when_empty (self , callback : Callable [[Workflow ], None | bool ] | None , enter_workflow : bool = True , source = None ):
245+ '''Call the callback when the queue is empty.
246+
247+ - `callback`: Return `True` to stop, `None` or `False` to continue.
248+
249+ Only one callback can be registered at a time. Use `add_queue_remaining_callback()` if you want to register multiple callbacks.
250+ '''
251+ if callback is None :
252+ self ._queue_empty_callback = None
253+ return
254+ if source is None :
255+ outer = inspect .currentframe ().f_back
256+ source = '' .join (inspect .findsource (outer )[0 ])
257+ def f (callback = callback , enter_workflow = enter_workflow , source = source ):
258+ wf = Workflow ()
259+ if enter_workflow :
260+ wf .__enter__ ()
261+ callback (wf )
262+ asyncio .run (wf ._exit (source ))
263+ else :
264+ callback (wf )
265+ self ._queue_empty_callback = f
266+ if self .queue_remaining == 0 :
267+ f ()
268+
220269 def cancel_current (self ):
221270 '''Interrupt the current task'''
222271 return asyncio .run (self ._cancel_current ())
0 commit comments