From 03ff6c2da0159e712d1dc3ae0c9886ed22eec428 Mon Sep 17 00:00:00 2001 From: Ivan R Date: Wed, 25 Sep 2024 03:40:34 +0500 Subject: [PATCH 1/2] Fix mask redundant dims --- birefnetNode.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/birefnetNode.py b/birefnetNode.py index cd5bfaf..8021fdf 100644 --- a/birefnetNode.py +++ b/birefnetNode.py @@ -199,11 +199,14 @@ def rem_bg(self, model, images): image = apply_mask_to_image(image.cpu(), mask.cpu()) _images.append(image) - _masks.append(mask) + _masks.append(mask.squeeze(0)) out_images = torch.cat(_images, dim=0) out_masks = torch.cat(_masks, dim=0) + if out_masks.shape[0] == 1: + out_masks = out_masks.squeeze(0) + return out_images, out_masks From 09bb7452d9efee7a707808ecb51a8a5d071b295a Mon Sep 17 00:00:00 2001 From: Ivan R Date: Tue, 1 Oct 2024 21:50:42 +0500 Subject: [PATCH 2/2] Revert fix for batch size 1 --- birefnetNode.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/birefnetNode.py b/birefnetNode.py index 8021fdf..ce2c63d 100644 --- a/birefnetNode.py +++ b/birefnetNode.py @@ -204,9 +204,6 @@ def rem_bg(self, model, images): out_images = torch.cat(_images, dim=0) out_masks = torch.cat(_masks, dim=0) - if out_masks.shape[0] == 1: - out_masks = out_masks.squeeze(0) - return out_images, out_masks