55from pathlib import Path
66import sys
77import threading
8+ import traceback
89from typing import Callable , Iterable
910import uuid
11+ from warnings import warn
1012
1113import asyncio
1214import nest_asyncio
1315import aiohttp
16+ from PIL import Image
1417
1518nest_asyncio .apply ()
1619
@@ -443,6 +446,7 @@ def __init__(self):
443446 self ._queue_empty_callback = None
444447 self ._queue_remaining_callbacks = [self ._when_empty_callback ]
445448 self ._watch_display_node = None
449+ self ._watch_display_node_preview = None
446450 self ._watch_display_task = None
447451 self .queue_remaining = 0
448452
@@ -463,6 +467,7 @@ async def _watch(self):
463467 async with session .ws_connect (f'{ client .client .base_url } ws' , params = {'clientId' : _client_id }) as ws :
464468 self .queue_remaining = 0
465469 executing = False
470+ progress_data = None
466471 async for msg in ws :
467472 # print(msg.type)
468473 if msg .type == aiohttp .WSMsgType .TEXT :
@@ -513,15 +518,25 @@ async def _watch(self):
513518 if self ._watch_display_node :
514519 print (f'Queue remaining: { self .queue_remaining } ' )
515520 elif msg ['type' ] == 'progress' :
516- # TODO: https://github.com/comfyanonymous/ComfyUI/issues/2425
517- data = msg ['data' ]
518- _print_progress (data ['value' ], data ['max' ])
521+ # See ComfyUI::main.hijack_progress
522+ # 'prompt_id', 'node': https://github.com/comfyanonymous/ComfyUI/issues/2425
523+ progress_data = msg ['data' ]
524+ # TODO: Node
525+ _print_progress (progress_data ['value' ], progress_data ['max' ])
519526 elif msg .type == aiohttp .WSMsgType .BINARY :
520- pass
527+ event = client .BinaryEvent .from_bytes (msg .data )
528+ if event .type == client .BinaryEventTypes .PREVIEW_IMAGE :
529+ prompt_id = progress_data .get ('prompt_id' )
530+ if prompt_id is not None :
531+ task : Task = self ._tasks .get (prompt_id )
532+ task ._set_node_preview (progress_data ['node' ], event .to_object (), self ._watch_display_node_preview )
533+ else :
534+ warn (f'Cannot get preview node, please update the ComfyUI server to at least 66831eb6e96cd974fb2d0fc4f299b23c6af16685 (2024-01-02)' )
521535 elif msg .type in (aiohttp .WSMsgType .CLOSED , aiohttp .WSMsgType .ERROR ):
522536 break
523537 except Exception as e :
524538 print (f'ComfyScript: Failed to watch, will retry in 5 seconds: { e } ' )
539+ traceback .print_exc ()
525540 await asyncio .sleep (5 )
526541 '''
527542 {'type': 'status', 'data': {'status': {'exec_info': {'queue_remaining': 0}}, 'sid': 'adc24049-b013-4a58-956b-edbc591dc6e2'}}
@@ -539,19 +554,27 @@ async def _watch(self):
539554 {'type': 'executing', 'data': {'node': None, 'prompt_id': '3328f0c8-9368-4070-90e7-087e854fe315'}}
540555 '''
541556
542- def start_watch (self , display_node : bool = True , display_task : bool = True ):
557+ def start_watch (self , display_node : bool = True , display_task : bool = True , display_node_preview : bool = True ):
543558 '''
544559 - `display_node`: When an output node is finished, display its result.
545560 - `display_task`: When a task is finished (all output nodes are finished), display all the results.
546561
547562 `load()` will `start_watch()` by default.
563+
564+ ## Previewing
565+ Previewing is disabled by default. Pass `--preview-method auto` to ComfyUI to enable previewing.
566+
567+ The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) (for SD1.x and SD2.x) and [taesdxl_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesdxl_decoder.pth) (for SDXL) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
568+
569+ The default maximum preview resolution is 512x512. The only way to change it is to modify ComfyUI::MAX_PREVIEW_RESOLUTION.
548570 '''
549571
550- if display_node or display_task :
572+ if display_node or display_task or display_node_preview :
551573 try :
552574 import IPython
553575 self ._watch_display_node = display_node
554576 self ._watch_display_task = display_task
577+ self ._watch_display_node_preview = display_node_preview
555578 except ImportError :
556579 print ('ComfyScript: IPython is not available, cannot display task results' )
557580
@@ -567,13 +590,14 @@ def remove_queue_remaining_callback(self, callback: Callable[[int], None]):
567590 if callback in self ._queue_remaining_callbacks :
568591 self ._queue_remaining_callbacks .remove (callback )
569592
570- def watch_display (self , display_node : bool = True , display_task : bool = True ):
593+ def watch_display (self , display_node : bool = True , display_task : bool = True , display_node_preview : bool = True ):
571594 '''
572595 - `display_node`: When an output node is finished, display its result.
573596 - `display_task`: When a task is finished (all output nodes are finished), display all the results.
574597 '''
575598 self ._watch_display_node = display_node
576599 self ._watch_display_task = display_task
600+ self ._watch_display_node_preview = display_node_preview
577601
578602 async def _put (self , workflow : data .NodeOutput | Iterable [data .NodeOutput ] | Workflow , source = None ) -> Task | None :
579603 global _client_id
@@ -685,13 +709,23 @@ def __init__(self, prompt_id: str, number: int, id: data.IdManager):
685709 self ._id = id
686710 self ._new_outputs = {}
687711 self ._fut = asyncio .Future ()
712+ self ._node_preview_callbacks : list [Callable [[Task , str , Image .Image ]]] = []
688713
689714 def __str__ (self ):
690715 return f'Task { self .number } ({ self .prompt_id } )'
691716
692717 def __repr__ (self ):
693718 return f'Task(n={ self .number } , id={ self .prompt_id } )'
694719
720+ def _set_node_preview (self , node_id : str , preview : Image .Image , display : bool ):
721+ for callback in self ._node_preview_callbacks :
722+ callback (self , node_id , preview )
723+
724+ if display :
725+ from IPython .display import display
726+
727+ display (preview , clear = True )
728+
695729 async def _set_result_threadsafe (self , node_id : str | None , output : dict , display_result : bool = False ) -> None :
696730 if node_id is not None :
697731 self ._new_outputs [node_id ] = output
@@ -781,6 +815,12 @@ def wait_result(self, output: data.NodeOutput) -> data.Result | None:
781815 # def __await__(self):
782816 # return self._wait().__await__()
783817
818+ def add_preview_callback (self , callback : Callable [[Task , str , Image .Image ], None ]):
819+ self ._node_preview_callbacks .append (callback )
820+
821+ def remove_preview_callback (self , callback : Callable [[Task , str , Image .Image ], None ]):
822+ self ._node_preview_callbacks .remove (callback )
823+
784824 def done (self ) -> bool :
785825 """Return True if the task is done.
786826
0 commit comments