diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 2397de3d624a..f1d04059d627 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -128,6 +128,9 @@ class LatentPreviewMethod(enum.Enum):
 
 parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
 
+parser.add_argument("--type-conformance", type=str, choices=['all', 'images', 'error', 'none'], default='none', help="Level of type conformance for input/output values between nodes. 'all' will attempt to convert every mismatched value to the expected type, 'images' converts only image values, 'error' will raise an error if the types do not match, and 'none' (the default) leaves values unchanged.")
+
+
 # The default built-in provider hosted under web/
 DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
 
diff --git a/execution.py b/execution.py
index 0a2e62e7ec47..b361a57eec78 100644
--- a/execution.py
+++ b/execution.py
@@ -13,6 +13,130 @@ import comfy.model_management
 
+from comfy.cli_args import args
+
+
+def validate_image_shape(image, node_type, output_name):
+    image_len = len(image.shape)
+
+    def type_check():
+        if image_len != 4 or image.shape[-1] > 4 or image.shape[-1] < 3:
+            msg = f"Image Shape Error: Node: {node_type}, Output: {output_name}, [{image.shape}] does not match the expected RGB format torch.Size([B, H, W, 3]) or the RGBA format torch.Size([B, H, W, 4])"
+            if args.type_conformance == "error":
+                raise ValueError(msg)
+            print(msg)
+            return False
+        return True
+
+    if not type_check() and args.type_conformance in ["all", "images"]:
+        transforms = {
+            "HW": lambda t: t.unsqueeze(0).unsqueeze(-1).expand(-1, -1, -1, 3),
+            "BHW": lambda t: t.unsqueeze(-1).expand(-1, -1, -1, 3),
+            "HWC": lambda t: t.unsqueeze(0),
+        }
+
+        if image_len == 2:
+            # HW -> add Batch and Channel dimensions
+            return transforms["HW"](image)
+
+        if image_len == 3:
+            if image.shape[-1] > 4 or image.shape[-1] < 3:
+                # BHW -> add Channel dimension
+                return transforms["BHW"](image)
+
+            if image.shape[-1] == 3 or image.shape[-1] == 4:
+                # HW3 or HW4 -> add Batch dimension --- can misbehave in edge cases where image is BHW but W is 3 or 4 px
+                return transforms["HWC"](image)
+
+    return image
+
+
+def generate_type_validator(valid_types, validator=None):
+    def validate_type(node_type, output_name, value):
+        if not isinstance(value, valid_types) and validator is None:
+            print(f"TypeError in Node: {node_type}. Expected {output_name} to be one of type(s): {[t.__name__ for t in valid_types]}, got {type(value).__name__} instead")
+
+            if args.type_conformance == "all":
+                return valid_types[0](value)
+
+            if args.type_conformance == "error":
+                raise TypeError(f"Expected: [{output_name}], to be one of type(s): {[vtype.__name__ for vtype in valid_types]}, got {type(value).__name__} instead")
+
+            return value
+
+        if isinstance(value, valid_types) and validator is not None:
+            return validator(value, node_type, output_name)
+
+        return value
+
+    return validate_type
+
+
+def input_validation(input_data_all, obj):
+    validation_funcs = {
+        "IMAGE": generate_type_validator((torch.Tensor,), validator=validate_image_shape),
+        "INT": generate_type_validator((int,)),
+        "FLOAT": generate_type_validator((float,)),
+        "STRING": generate_type_validator((str,)),
+    }
+    input_types = obj.INPUT_TYPES()
+    for _, v in input_types.items():
+        if isinstance(v, dict):
+            for k2, v2 in v.items():
+                # v2 is e.g. ("IMAGE",) or ("INT", {...}); only validate known types that were actually supplied
+                if k2 in input_data_all and isinstance(v2[0], str) and v2[0] in validation_funcs:
+                    input_data_all[k2] = [validation_funcs[v2[0]](obj.__class__.__name__, k2, x) for x in input_data_all[k2]]
+
+    return input_data_all
+
+
+def output_validation(results, obj):
+    validation_funcs = {
+        "IMAGE": generate_type_validator((torch.Tensor,), validator=validate_image_shape),
+        "INT": generate_type_validator((int,)),
+        "FLOAT": generate_type_validator((float,)),
+        "STRING": generate_type_validator((str,)),
+    }
+    if not hasattr(obj, "RETURN_TYPES"):
+        return results
+
+    return_names = getattr(obj, "RETURN_NAMES", obj.RETURN_TYPES)
+    formatted_results = []
+
+    for result in results:
+        # UI nodes can return dicts; anything that does not line up with RETURN_TYPES is passed through untouched
+        if not isinstance(result, (tuple, list)) or len(result) != len(obj.RETURN_TYPES):
+            formatted_results.append(result)
+            continue
+        validated = []
+        for i, value in enumerate(result):
+            return_type = obj.RETURN_TYPES[i]
+            if return_type in validation_funcs:
+                value = validation_funcs[return_type](obj.__class__.__name__, return_names[i], value)
+            validated.append(value)
+        formatted_results.append(tuple(validated))
+
+    return formatted_results
+
+
 def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
     valid_inputs = class_def.INPUT_TYPES()
     input_data_all = {}
@@ -28,7 +152,7 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
             input_data_all[x] = obj
         else:
            if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
-                input_data_all[x] = [input_data]
+                input_data_all[x] = [input_data]
 
     if "hidden" in valid_inputs:
         h = valid_inputs["hidden"]
@@ -39,9 +163,13 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
                 input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
             if h[x] == "UNIQUE_ID":
                 input_data_all[x] = [unique_id]
+
     return input_data_all
 
+
 def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
+    input_data_all = input_validation(input_data_all, obj)
+
     # check if node wants the lists
     input_is_list = False
     if hasattr(obj, "INPUT_IS_LIST"):
@@ -73,14 +201,16 @@ def slice_dict(d, i):
             if allow_interrupt:
                 nodes.before_node_execution()
             results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
+
+    results = output_validation(results, obj)
     return results
 
+
 def get_output_data(obj, input_data_all):
-    
     results = []
     uis = []
     return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
-    
+
     for r in return_values:
         if isinstance(r, dict):
             if 'ui' in r:
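
For reference, the shape rules that validate_image_shape applies under the 'all' and 'images' conformance levels can be exercised on their own. The snippet below is an illustrative sketch, not part of the patch: it re-declares the same three transforms and shows the BHWC shapes they produce for 2D, 3D channel-less, and 3D channel-last tensors.

import torch

# Same lambdas as in validate_image_shape (re-declared here for a standalone demo).
transforms = {
    "HW":  lambda t: t.unsqueeze(0).unsqueeze(-1).expand(-1, -1, -1, 3),  # [H, W]    -> [1, H, W, 3]
    "BHW": lambda t: t.unsqueeze(-1).expand(-1, -1, -1, 3),               # [B, H, W] -> [B, H, W, 3]
    "HWC": lambda t: t.unsqueeze(0),                                      # [H, W, C] -> [1, H, W, C]
}

print(transforms["HW"](torch.rand(64, 64)).shape)      # torch.Size([1, 64, 64, 3])
print(transforms["BHW"](torch.rand(2, 64, 64)).shape)  # torch.Size([2, 64, 64, 3])
print(transforms["HWC"](torch.rand(64, 64, 3)).shape)  # torch.Size([1, 64, 64, 3])

Note that expand() returns a broadcasted view rather than a copy, so a downstream node that writes into such a tensor in place would need a .contiguous() copy first.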