From 9b7262e5887c9a786b241e83a8e9f1e198878cf7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=88=98=E9=9B=AA=E5=B3=B0?=
Date: Fri, 3 Jan 2025 16:38:06 +0800
Subject: [PATCH] optimize code

---
 birefnetNode.py | 100 ++++++++++++++++++++++++------------------------
 pyproject.toml  |   2 +-
 util.py         |  28 ++++++++++++--
 3 files changed, 77 insertions(+), 53 deletions(-)

diff --git a/birefnetNode.py b/birefnetNode.py
index 4bce6b2..a07b92f 100644
--- a/birefnetNode.py
+++ b/birefnetNode.py
@@ -9,7 +9,7 @@
 from birefnet.models.birefnet import BiRefNet
 from birefnet_old.models.birefnet import BiRefNet as OldBiRefNet
 from birefnet.utils import check_state_dict
-from .util import tensor_to_pil, apply_mask_to_image, normalize_mask, refine_foreground
+from .util import tensor_to_pil, apply_mask_to_image, normalize_mask, refine_foreground, filter_mask, add_mask_as_alpha
 
 deviceType = model_management.get_torch_device().type
 models_dir_key = "birefnet"
@@ -61,12 +61,12 @@ class ImagePreprocessor():
     def __init__(self, resolution) -> None:
         self.transform_image = transforms.Compose([
             transforms.Resize(resolution),
-            transforms.ToTensor(),
+            # transforms.ToTensor(),
             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
         ])
         self.transform_image_old = transforms.Compose([
             transforms.Resize(resolution),
-            transforms.ToTensor(),
+            # transforms.ToTensor(),
             transforms.Normalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0]),
         ])
 
@@ -195,6 +195,7 @@ def INPUT_TYPES(cls):
                 "blur_size_two": ("INT", {"default": 7, "min": 1, "max": 255, "step": 2, }),
                 "fill_color": ("BOOLEAN", {"default": False}),
                 "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}),
+                "mask_threshold": ("FLOAT", {"default": 0.000, "min": 0.0, "max": 1.0, "step": 0.001, }),
             }
         }
 
@@ -203,57 +204,58 @@ def INPUT_TYPES(cls):
     FUNCTION = "rem_bg"
     CATEGORY = "rembg/BiRefNet"
 
-    def rem_bg(self, model, images, upscale_method='bilinear', width=1024, height=1024, blur_size=91, blur_size_two=7, fill_color=False, color=None):
+    def rem_bg(self, model, images, upscale_method='bilinear', width=1024, height=1024, blur_size=91, blur_size_two=7, fill_color=False, color=None, mask_threshold=0.000):
         model, version = model
         model_device_type = next(model.parameters()).device.type
-        _images = []
-        _masks = []
+        b, h, w, c = images.shape
+        image_bchw = images.permute(0, 3, 1, 2)
 
-        for image in images:
-            h, w, c = image.shape
-            pil_image = tensor_to_pil(image)
-
-            image_preproc = ImagePreprocessor(resolution=(width, height))
-            if VERSION[0] == version:
-                im_tensor = image_preproc.old_proc(pil_image).unsqueeze(0)
-            else:
-                im_tensor = image_preproc.proc(pil_image).unsqueeze(0)
-
-            del image_preproc
+        image_preproc = ImagePreprocessor(resolution=(1024, 1024))
+        if VERSION[0] == version:
+            im_tensor = image_preproc.old_proc(image_bchw)
+        else:
+            im_tensor = image_preproc.proc(image_bchw)
 
+        _mask_bchw = []
+        for each_image in im_tensor:
             with torch.no_grad():
-                mask = model(im_tensor.to(model_device_type))[-1].sigmoid().cpu()
-
-            # The mask must be resized back to the original image size
-            mask = comfy.utils.common_upscale(mask, w, h, upscale_method, "disabled")
-
-            # (1, 1, h, w)
-            mask = normalize_mask(mask)
-            # (c, h, w) => (c, h, w)
-            _image_masked = refine_foreground(image.permute(2, 0, 1), mask.squeeze(0), r1=blur_size, r2=blur_size_two).squeeze(0)
-            # (c, h, w) => (h, w, c)
-            _image_masked = _image_masked.permute(1, 2, 0)
-            if fill_color and color is not None:
-                r = torch.full([h, w, 1], ((color >> 16) & 0xFF) / 0xFF)
-                g = torch.full([h, w, 1], ((color >> 8) & 0xFF) / 0xFF)
-                b = torch.full([h, w, 1], (color & 0xFF) / 0xFF)
-                # (h, w, 3)
-                background_color = torch.cat((r, g, b), dim=-1)
-                # (h, w, 1)
-                apply_mask = mask.squeeze(0).permute(1, 2, 0).expand_as(_image_masked)
-                _image_masked = _image_masked * apply_mask + background_color * (1 - apply_mask)
-                # (h, w, 3) => (1, h, w, 3)
-                image = _image_masked.unsqueeze(0)
-                del background_color, apply_mask
-            else:
-                # Make the parts of the image outside the mask transparent => (1, h, w, 4)
-                image = apply_mask_to_image(_image_masked.cpu(), mask.cpu())
-
-            _images.append(image)
-            _masks.append(mask.squeeze(0))
-
-        out_images = torch.cat(_images, dim=0)
-        out_masks = torch.cat(_masks, dim=0)
+                each_mask = model(each_image.unsqueeze(0).to(model_device_type))[-1].sigmoid().cpu()
+            _mask_bchw.append(each_mask)
+            del each_mask
+
+        mask_bchw = torch.cat(_mask_bchw, dim=0)
+        del _mask_bchw
+        # The mask must be resized back to the original image size
+        mask = comfy.utils.common_upscale(mask_bchw, w, h, upscale_method, "disabled")
+        # (b, 1, h, w)
+        if mask_threshold > 0:
+            out_masks = filter_mask(mask, threshold=mask_threshold)
+        else:
+            out_masks = normalize_mask(mask)
+
+        # (b, c, h, w)
+        _image_masked = refine_foreground(image_bchw, out_masks, r1=blur_size, r2=blur_size_two)
+        # (b, c, h, w) => (b, h, w, c)
+        _image_masked = _image_masked.permute(0, 2, 3, 1)
+        if fill_color and color is not None:
+            r = torch.full([b, h, w, 1], ((color >> 16) & 0xFF) / 0xFF)
+            g = torch.full([b, h, w, 1], ((color >> 8) & 0xFF) / 0xFF)
+            b = torch.full([b, h, w, 1], (color & 0xFF) / 0xFF)
+            # (b, h, w, 3)
+            background_color = torch.cat((r, g, b), dim=-1)
+            # (b, 1, h, w) => (b, h, w, 3)
+            apply_mask = out_masks.permute(0, 2, 3, 1).expand_as(_image_masked)
+            # Composite the foreground over the solid background color
+            out_images = _image_masked * apply_mask + background_color * (1 - apply_mask)
+            del background_color, apply_mask
+            out_masks = out_masks.squeeze(1)
+        else:
+            # (b, 1, h, w) => (b, h, w)
+            out_masks = out_masks.squeeze(1)
+            # Make the parts of the image outside the mask transparent => (b, h, w, 4)
+            out_images = add_mask_as_alpha(_image_masked.cpu(), out_masks.cpu())
+
+        del _image_masked
 
         return out_images, out_masks
diff --git a/pyproject.toml b/pyproject.toml
index 0baa12a..837f232 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,7 @@
 [project]
 name = "comfyui_birefnet_ll"
 description = "Sync with version of BiRefNet. NODES:AutoDownloadBiRefNetModel, LoadRembgByBiRefNetModel, RembgByBiRefNet."
-version = "1.0.6"
+version = "1.0.7"
 license = {file = "LICENSE"}
 dependencies = ["numpy", "opencv-python", "timm"]
diff --git a/util.py b/util.py
index 2fcb654..bafa1c6 100644
--- a/util.py
+++ b/util.py
@@ -19,8 +19,7 @@ def refine_foreground(image_tensor, mask_tensor, r1=90, r2=7):
     if r2 % 2 == 0:
         r2 += 1
 
-    estimated_foreground = FB_blur_fusion_foreground_estimator_2(image_tensor, mask_tensor, r1=r1, r2=r2)
-    return estimated_foreground
+    return FB_blur_fusion_foreground_estimator_2(image_tensor, mask_tensor, r1=r1, r2=r2)[0]
 
 
 def FB_blur_fusion_foreground_estimator_2(image_tensor, alpha_tensor, r1=90, r2=7):
@@ -28,7 +27,7 @@ def FB_blur_fusion_foreground_estimator_2(image_tensor, alpha_tensor, r1=90, r2=7):
     if alpha_tensor.dim() == 3:
         alpha_tensor = alpha_tensor.unsqueeze(0)  # Add batch
     F, blur_B = FB_blur_fusion_foreground_estimator(image_tensor, image_tensor, image_tensor, alpha_tensor, r=r1)
-    return FB_blur_fusion_foreground_estimator(image_tensor, F, blur_B, alpha_tensor, r=r2)[0]
+    return FB_blur_fusion_foreground_estimator(image_tensor, F, blur_B, alpha_tensor, r=r2)
 
 
 def FB_blur_fusion_foreground_estimator(image_tensor, F_tensor, B_tensor, alpha_tensor, r=90):
@@ -103,3 +102,26 @@ def normalize_mask(mask_tensor):
     normalized_mask = (mask_tensor - min_val) / (max_val - min_val)
 
     return normalized_mask
+
+
+def add_mask_as_alpha(image, mask):
+    """
+    Add the (b, h, w) mask as the 4th (alpha) channel of the (b, h, w, 3) image.
+    """
+    # Validate input shapes
+    assert image.dim() == 4 and image.size(-1) == 3, "The shape of image should be (b, h, w, 3)."
+    assert mask.dim() == 3, "The shape of mask should be (b, h, w)"
+    assert image.size(0) == mask.size(0) and image.size(1) == mask.size(1) and image.size(2) == mask.size(2), "The batch, height, and width dimensions of the image and mask must be consistent"
+
+    # Expand mask to (b, h, w, 1)
+    mask = mask[..., None]
+
+    # Premultiply RGB by the alpha mask
+    image = image * mask
+    # Concatenate image and mask into (b, h, w, 4)
+    image_with_alpha = torch.cat([image, mask], dim=-1)
+
+    return image_with_alpha
+
+
+def filter_mask(mask, threshold=4e-3):
+    # Zero out mask values at or below the threshold; values above it pass through unchanged
+    mask_binary = mask > threshold
+    filtered_mask = mask * mask_binary
+    return filtered_mask
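
Note for reviewers: below is a minimal, self-contained smoke test (not part of the
patch) illustrating how the two new util.py helpers compose. The helper bodies are
condensed from the diff (asserts omitted); the 0.5 threshold and the 2x8x8 shapes
are arbitrary illustration values, not the node's defaults.

    import torch

    def filter_mask(mask, threshold=4e-3):
        # Values at or below the threshold are zeroed; values above pass through unchanged.
        return mask * (mask > threshold)

    def add_mask_as_alpha(image, mask):
        # image: (b, h, w, 3), mask: (b, h, w) -> premultiplied RGBA of shape (b, h, w, 4)
        mask = mask[..., None]
        return torch.cat([image * mask, mask], dim=-1)

    b, h, w = 2, 8, 8
    images = torch.rand(b, h, w, 3)      # RGB batch, as the node receives it
    mask_b1hw = torch.rand(b, 1, h, w)   # sigmoid output after common_upscale

    out_masks = filter_mask(mask_b1hw, threshold=0.5).squeeze(1)  # (b, h, w)
    rgba = add_mask_as_alpha(images, out_masks)
    assert rgba.shape == (b, h, w, 4)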