diff --git a/birefnetNode.py b/birefnetNode.py index cd5bfaf..ce2c63d 100644 --- a/birefnetNode.py +++ b/birefnetNode.py @@ -199,7 +199,7 @@ def rem_bg(self, model, images): image = apply_mask_to_image(image.cpu(), mask.cpu()) _images.append(image) - _masks.append(mask) + _masks.append(mask.squeeze(0)) out_images = torch.cat(_images, dim=0) out_masks = torch.cat(_masks, dim=0)