Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,30 @@

import comfy.model_management

def force_bhw3(image):
#convert [CHW, BCHW, CWH] to BHW3
was_list = False

while isinstance(image, list):
was_list = True
image = image[0]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this only going to fix the first image in the list? If we get a list of images, we should be fixing all the images in the list (which may all have different dimensions). (I would probably make this change outside of the call to force_bhw3 so that it applies to any other type validation we add in the future.)

Copy link
Contributor Author

@shawnington shawnington Jul 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sometimes the image tensor was wrapped in an extraneous list, it did not seem to be related to batch size in any way, its just a difference in formatting from the ways it comes for input, and the way its structured for output. I have no idea why it's like that, but the data does not change between.

Granted it's a very ugly hack, could have gone for an if instead of a while. I just really really wanted to get rid of the list wrapper before you know, I added it back for formatting reasons at the end if it was removed.

I could be wrong. That whole function is likely to evolve considerably as I start to expose it to a wide variety of edge cases, such as 3x3x3x3, and also take into consideration the suggestions and further discord discussion we have had.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not frequently used, but ComfyUI does have functionality for "list" outputs/inputs that are different than batches. Specifically, if a node returns a batch of 5 images, those 5 images will all be passed to the following node for one execution. If a node returns a list of 5 images (each with a batch size of 1), the following node will actually have its execution function called 5 times. In order to support this, outputs are usually passed around wrapped in a list. I believe that's what you were seeing.

To continue to support that functionality, it's important that we process each entry in the list the same way. Someone in Discord/Matrix might have suggestions for real nodes that make use of that functionality so you can test it. I think some of the nodes used for making X/Y plots use it.


if len(image.shape) == 3:
#add batch dimension
image = image.unsqueeze(0)

if image.shape[1] == 3:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if somebody just happens to have a 3x3 image somewhere? This will corrupt it.

I imagine if check [2] and [3] are both >4 then the conversion is confident. But then the conversion only happens on larger images, and quietly doesn't apply to smaller images.
... This type of thing is a good example of why it might be better to warn than to try to autocorrect - I don't think there's actually a 100% reliable detection of format in all cases, just "99% of the time it's right" heuristic checks like this one.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Also mentioned in Discord, but copying here for documentation)
+1 for just warning (and eventually erroring) in any of the slightly ambiguous cases. If we're going to auto fix anything, it should just be the BWH case (which is easy to differentiate due to the tensor only having 3 dimensions). That's probably the most common error anyway.

#BCHW color
image = image.permute(0, 2, 3, 1)
return image if not was_list else [image]

if image.shape[1] == 1:
#BCWH black and white
image = image.permute(0, 3, 2, 1).expand(-1, -1, -1, 3)
return image if not was_list else [image]

return image if not was_list else [image]

def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
valid_inputs = class_def.INPUT_TYPES()
input_data_all = {}
Expand Down Expand Up @@ -42,6 +66,14 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
return input_data_all

def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
#Ensure image inputs are in BHW3 format
input_types = obj.INPUT_TYPES()
for _, v in input_types.items():
if isinstance(v, dict):
for k2, v2 in v.items():
if v2[0] == "IMAGE":
input_data_all[k2] = [force_bhw3(x) for x in input_data_all[k2]]
Copy link
Collaborator

@guill guill Jul 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should only validate things coming out of nodes. map_node_over_list is used for some cases that aren't actual node execution and there may be awkward side-effects. Additionally, if we're going to throw a warning/error, that warning/error would be attributed to a node that isn't the real culprit.

As long as we're validating all outputs, the only additional thing this protects people from is a node with output using the silly __ne__ trick. In that case, we can't even be sure that it was intended to be an image -- it might have been an audio clip.

In either case, validating the input in this way still wouldn't catch undeclared inputs (i.e. any node taking a variable number of inputs).

Copy link
Contributor Author

@shawnington shawnington Jul 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are all evaluated on execution, so you might be right, or it might work, if the variable number of nodes is codified at execution time, it shouldn't be an issue. All the information one what is being put in is gleaned from the states run through execution.py, and I made sure not to touch the get_input_data function for this reason. It needs rigorous testing with all the different kind of node presentations it can run up against. However since it mainly parses data once its already in a state where its linked to other node outputs/inputs I think it will work, unless there is something I am not understanding about how there variable input nodes work, which is possible.

Also the issue with evaluating after map_nodes_over_list , is that all the type info is stripped from the output after that, so another variable would need to be added to the return that includes the type information and index position of that type, or the type information would need to be included in the output, and subsequent code would need a rewrite to handle that information being present. Brighter minds than me can probably figure out a way, but this was the most logical way I could think of after tearing apart the output at various stages.

Also, side note. If someone manages to pass audio through the IMAGE pipe with this in place, I'd not only be curious, Id wonder why they chose the image pipe. If they do, we can just call it an undocumented feature, lol

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are all evaluated on execution, so you might be right, or it might work, if the variable number of nodes is codified at execution time, it shouldn't be an issue.

To be clear, this is definitely evaluated on execution, but this function is also called at other times -- like calling IS_CHANGED. In that function, it's totally legitimate to have an IMAGE with a value of None (in fact, that's the expected value since only constant inputs are available when IS_CHANGED is called before the graph has begun execution.

I would put the validation right after the call to get_output_data in recursive_execute. At that point, you have information about the node that was executed like its output types.


# check if node wants the lists
input_is_list = False
if hasattr(obj, "INPUT_IS_LIST"):
Expand Down Expand Up @@ -73,6 +105,26 @@ def slice_dict(d, i):
if allow_interrupt:
nodes.before_node_execution()
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))

#Ensure IMAGE outputs conform to BHWC
return_indexs = {}
formated_results = []

if hasattr(obj, "RETURN_NAMES") and hasattr(obj, "RETURN_TYPES"):
for i, t in enumerate(obj.RETURN_TYPES):
return_indexs[i] = t

for i, r in enumerate(results[0]):
if return_indexs[i] == "IMAGE":
print(f"Result: {force_bhw3(r).shape}")
formated_results.append(force_bhw3(r))
else:
formated_results.append(r)

results = [tuple(formated_results)]

del formated_results

return results

def get_output_data(obj, input_data_all):
Expand Down