Skip to content

Commit 6e7cd33

Browse files
committed
feat(runtime): add node preview display and callbacks (#36)
1 parent c555758 commit 6e7cd33

File tree

3 files changed

+111
-8
lines changed

3 files changed

+111
-8
lines changed

pyproject.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "comfy-script"
3-
version = "0.4.6"
3+
version = "0.5.0a1"
44
description = "A Python front end and library for ComfyUI"
55
readme = "README.md"
66
# ComfyUI: >=3.8
@@ -28,6 +28,11 @@ client = [
2828

2929
# 1.5.9: https://github.com/erdewit/nest_asyncio/issues/87
3030
"nest_asyncio ~= 1.0, >= 1.5.9",
31+
32+
# Already required by ComfyUI
33+
"Pillow",
34+
35+
"aenum ~= 3.1"
3136
]
3237

3338
# Transpiler

src/comfy_script/client/__init__.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
from __future__ import annotations
2+
from dataclasses import dataclass
3+
from enum import IntEnum
4+
from io import BytesIO
25
import json
36
import os
47
from pathlib import PurePath
8+
import struct
59
import sys
610
import traceback
711
from typing import Callable
812

913
import asyncio
14+
from warnings import warn
15+
from PIL import Image
1016
import nest_asyncio
1117
import aiohttp
1218
from yarl import URL
@@ -136,6 +142,58 @@ def default(self, o):
136142
return str(o)
137143
return super().default(o)
138144

145+
class BinaryEventTypes(IntEnum):
146+
# See ComfyUI::server.BinaryEventTypes
147+
PREVIEW_IMAGE = 1
148+
UNENCODED_PREVIEW_IMAGE = 2
149+
'''Only used internally in ComfyUI.'''
150+
151+
@dataclass
152+
class BinaryEvent:
153+
type: BinaryEventTypes | int
154+
data: bytes
155+
156+
@staticmethod
157+
def from_bytes(data: bytes) -> BinaryEvent:
158+
# See ComfyUI::server.encode_bytes()
159+
type_int = struct.unpack('>I', data[:4])[0]
160+
try:
161+
type = BinaryEventTypes(type_int)
162+
except ValueError:
163+
warn(f'Unknown binary event type: {data[:4]}')
164+
type = type_int
165+
data = data[4:]
166+
return BinaryEvent(type, data)
167+
168+
def to_object(self) -> Image.Image | bytes:
169+
if self.type == BinaryEventTypes.PREVIEW_IMAGE:
170+
return _PreviewImage.from_bytes(self.data).image
171+
return self
172+
173+
class _PreviewImageFormat(IntEnum):
174+
'''`format.name` is compatible with PIL.'''
175+
JPEG = 1
176+
PNG = 2
177+
178+
@dataclass
179+
class _PreviewImage:
180+
format: _PreviewImageFormat
181+
image: Image.Image
182+
183+
@staticmethod
184+
def from_bytes(data: bytes) -> _PreviewImage:
185+
# See ComfyUI::LatentPreviewer
186+
format_int = struct.unpack('>I', data[:4])[0]
187+
format = None
188+
try:
189+
format = _PreviewImageFormat(format_int).name
190+
except ValueError:
191+
warn(f'Unknown image format: {data[:4]}')
192+
193+
image = Image.open(BytesIO(data[4:]), formats=(format,) if format is not None else None)
194+
195+
return _PreviewImage(format, image)
196+
139197
__all__ = [
140198
'client',
141199
'Client',

src/comfy_script/runtime/__init__.py

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,15 @@
55
from pathlib import Path
66
import sys
77
import threading
8+
import traceback
89
from typing import Callable, Iterable
910
import uuid
11+
from warnings import warn
1012

1113
import asyncio
1214
import nest_asyncio
1315
import aiohttp
16+
from PIL import Image
1417

1518
nest_asyncio.apply()
1619

@@ -443,6 +446,7 @@ def __init__(self):
443446
self._queue_empty_callback = None
444447
self._queue_remaining_callbacks = [self._when_empty_callback]
445448
self._watch_display_node = None
449+
self._watch_display_node_preview = None
446450
self._watch_display_task = None
447451
self.queue_remaining = 0
448452

@@ -463,6 +467,7 @@ async def _watch(self):
463467
async with session.ws_connect(f'{client.client.base_url}ws', params={'clientId': _client_id}) as ws:
464468
self.queue_remaining = 0
465469
executing = False
470+
progress_data = None
466471
async for msg in ws:
467472
# print(msg.type)
468473
if msg.type == aiohttp.WSMsgType.TEXT:
@@ -513,15 +518,25 @@ async def _watch(self):
513518
if self._watch_display_node:
514519
print(f'Queue remaining: {self.queue_remaining}')
515520
elif msg['type'] == 'progress':
516-
# TODO: https://github.com/comfyanonymous/ComfyUI/issues/2425
517-
data = msg['data']
518-
_print_progress(data['value'], data['max'])
521+
# See ComfyUI::main.hijack_progress
522+
# 'prompt_id', 'node': https://github.com/comfyanonymous/ComfyUI/issues/2425
523+
progress_data = msg['data']
524+
# TODO: Node
525+
_print_progress(progress_data['value'], progress_data['max'])
519526
elif msg.type == aiohttp.WSMsgType.BINARY:
520-
pass
527+
event = client.BinaryEvent.from_bytes(msg.data)
528+
if event.type == client.BinaryEventTypes.PREVIEW_IMAGE:
529+
prompt_id = progress_data.get('prompt_id')
530+
if prompt_id is not None:
531+
task: Task = self._tasks.get(prompt_id)
532+
task._set_node_preview(progress_data['node'], event.to_object(), self._watch_display_node_preview)
533+
else:
534+
warn(f'Cannot get preview node, please update the ComfyUI server to at least 66831eb6e96cd974fb2d0fc4f299b23c6af16685 (2024-01-02)')
521535
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR):
522536
break
523537
except Exception as e:
524538
print(f'ComfyScript: Failed to watch, will retry in 5 seconds: {e}')
539+
traceback.print_exc()
525540
await asyncio.sleep(5)
526541
'''
527542
{'type': 'status', 'data': {'status': {'exec_info': {'queue_remaining': 0}}, 'sid': 'adc24049-b013-4a58-956b-edbc591dc6e2'}}
@@ -539,19 +554,27 @@ async def _watch(self):
539554
{'type': 'executing', 'data': {'node': None, 'prompt_id': '3328f0c8-9368-4070-90e7-087e854fe315'}}
540555
'''
541556

542-
def start_watch(self, display_node: bool = True, display_task: bool = True):
557+
def start_watch(self, display_node: bool = True, display_task: bool = True, display_node_preview: bool = True):
543558
'''
544559
- `display_node`: When an output node is finished, display its result.
545560
- `display_task`: When a task is finished (all output nodes are finished), display all the results.
546561
547562
`load()` will `start_watch()` by default.
563+
564+
## Previewing
565+
Previewing is disabled by default. Pass `--preview-method auto` to ComfyUI to enable previewing.
566+
567+
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) (for SD1.x and SD2.x) and [taesdxl_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesdxl_decoder.pth) (for SDXL) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
568+
569+
The default maximum preview resolution is 512x512. The only way to change it is to modify ComfyUI::MAX_PREVIEW_RESOLUTION.
548570
'''
549571

550-
if display_node or display_task:
572+
if display_node or display_task or display_node_preview:
551573
try:
552574
import IPython
553575
self._watch_display_node = display_node
554576
self._watch_display_task = display_task
577+
self._watch_display_node_preview = display_node_preview
555578
except ImportError:
556579
print('ComfyScript: IPython is not available, cannot display task results')
557580

@@ -567,13 +590,14 @@ def remove_queue_remaining_callback(self, callback: Callable[[int], None]):
567590
if callback in self._queue_remaining_callbacks:
568591
self._queue_remaining_callbacks.remove(callback)
569592

570-
def watch_display(self, display_node: bool = True, display_task: bool = True):
593+
def watch_display(self, display_node: bool = True, display_task: bool = True, display_node_preview: bool = True):
571594
'''
572595
- `display_node`: When an output node is finished, display its result.
573596
- `display_task`: When a task is finished (all output nodes are finished), display all the results.
574597
'''
575598
self._watch_display_node = display_node
576599
self._watch_display_task = display_task
600+
self._watch_display_node_preview = display_node_preview
577601

578602
async def _put(self, workflow: data.NodeOutput | Iterable[data.NodeOutput] | Workflow, source = None) -> Task | None:
579603
global _client_id
@@ -685,13 +709,23 @@ def __init__(self, prompt_id: str, number: int, id: data.IdManager):
685709
self._id = id
686710
self._new_outputs = {}
687711
self._fut = asyncio.Future()
712+
self._node_preview_callbacks: list[Callable[[Task, str, Image.Image]]] = []
688713

689714
def __str__(self):
690715
return f'Task {self.number} ({self.prompt_id})'
691716

692717
def __repr__(self):
693718
return f'Task(n={self.number}, id={self.prompt_id})'
694719

720+
def _set_node_preview(self, node_id: str, preview: Image.Image, display: bool):
721+
for callback in self._node_preview_callbacks:
722+
callback(self, node_id, preview)
723+
724+
if display:
725+
from IPython.display import display
726+
727+
display(preview, clear=True)
728+
695729
async def _set_result_threadsafe(self, node_id: str | None, output: dict, display_result: bool = False) -> None:
696730
if node_id is not None:
697731
self._new_outputs[node_id] = output
@@ -781,6 +815,12 @@ def wait_result(self, output: data.NodeOutput) -> data.Result | None:
781815
# def __await__(self):
782816
# return self._wait().__await__()
783817

818+
def add_preview_callback(self, callback: Callable[[Task, str, Image.Image], None]):
819+
self._node_preview_callbacks.append(callback)
820+
821+
def remove_preview_callback(self, callback: Callable[[Task, str, Image.Image], None]):
822+
self._node_preview_callbacks.remove(callback)
823+
784824
def done(self) -> bool:
785825
"""Return True if the task is done.
786826

0 commit comments

Comments
 (0)