diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 29589b4abade..6c803626e013 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -5,7 +5,19 @@ from nodes import MAX_RESOLUTION +def tensor_to_rgb(img): + # convert from bw to rgb : cwh -> bchw + if img.shape[1] != 3: + if len(img.shape) == 3: + img = img.unsqueeze(0) + img = img.permute(0, 1, 3, 2).repeat(1, 3, 1, 1) + + return img + def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False): + destination = tensor_to_rgb(destination) + source = tensor_to_rgb(source) + source = source.to(destination.device) if resize_source: source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")