Skip to content

Commit

Permalink
optimize code
Browse files Browse the repository at this point in the history
  • Loading branch information
lldacing committed Jan 3, 2025
1 parent 868b271 commit 9b7262e
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 53 deletions.
100 changes: 51 additions & 49 deletions birefnetNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from birefnet.models.birefnet import BiRefNet
from birefnet_old.models.birefnet import BiRefNet as OldBiRefNet
from birefnet.utils import check_state_dict
from .util import tensor_to_pil, apply_mask_to_image, normalize_mask, refine_foreground
from .util import tensor_to_pil, apply_mask_to_image, normalize_mask, refine_foreground, filter_mask, add_mask_as_alpha
deviceType = model_management.get_torch_device().type

models_dir_key = "birefnet"
Expand Down Expand Up @@ -61,12 +61,12 @@ class ImagePreprocessor():
def __init__(self, resolution) -> None:
self.transform_image = transforms.Compose([
transforms.Resize(resolution),
transforms.ToTensor(),
# transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
self.transform_image_old = transforms.Compose([
transforms.Resize(resolution),
transforms.ToTensor(),
# transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0]),
])

Expand Down Expand Up @@ -195,6 +195,7 @@ def INPUT_TYPES(cls):
"blur_size_two": ("INT", {"default": 7, "min": 1, "max": 255, "step": 2, }),
"fill_color": ("BOOLEAN", {"default": False}),
"color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}),
"mask_threshold": ("FLOAT", {"default": 0.000, "min": 0.0, "max": 1.0, "step": 0.001, }),
}
}

Expand All @@ -203,57 +204,58 @@ def INPUT_TYPES(cls):
FUNCTION = "rem_bg"
CATEGORY = "rembg/BiRefNet"

def rem_bg(self, model, images, upscale_method='bilinear', width=1024, height=1024, blur_size=91, blur_size_two=7, fill_color=False, color=None):
def rem_bg(self, model, images, upscale_method='bilinear', width=1024, height=1024, blur_size=91, blur_size_two=7, fill_color=False, color=None, mask_threshold=0.000):
model, version = model
model_device_type = next(model.parameters()).device.type
_images = []
_masks = []
b, h, w, c = images.shape
image_bchw = images.permute(0, 3, 1, 2)

for image in images:
h, w, c = image.shape
pil_image = tensor_to_pil(image)

image_preproc = ImagePreprocessor(resolution=(width, height))
if VERSION[0] == version:
im_tensor = image_preproc.old_proc(pil_image).unsqueeze(0)
else:
im_tensor = image_preproc.proc(pil_image).unsqueeze(0)

del image_preproc
image_preproc = ImagePreprocessor(resolution=(1024, 1024))
if VERSION[0] == version:
im_tensor = image_preproc.old_proc(image_bchw)
else:
im_tensor = image_preproc.proc(image_bchw)

_mask_bchw = []
for each_image in im_tensor:
with torch.no_grad():
mask = model(im_tensor.to(model_device_type))[-1].sigmoid().cpu()

# 遮罩大小需还原为与原图一致
mask = comfy.utils.common_upscale(mask, w, h, upscale_method, "disabled")

# (1, 1, h, w)
mask = normalize_mask(mask)
# (c, h, w) => (c, h, w)
_image_masked = refine_foreground(image.permute(2, 0, 1), mask.squeeze(0), r1=blur_size, r2=blur_size_two).squeeze(0)
# (c, h, w) => (h, w, c)
_image_masked = _image_masked.permute(1, 2, 0)
if fill_color and color is not None:
r = torch.full([h, w, 1], ((color >> 16) & 0xFF) / 0xFF)
g = torch.full([h, w, 1], ((color >> 8) & 0xFF) / 0xFF)
b = torch.full([h, w, 1], (color & 0xFF) / 0xFF)
# (h, w, 3)
background_color = torch.cat((r, g, b), dim=-1)
# (h, w, 1)
apply_mask = mask.squeeze(0).permute(1, 2, 0).expand_as(_image_masked)
_image_masked = _image_masked * apply_mask + background_color * (1 - apply_mask)
# (h, w, 3)=>(1, h, w,3)
image = _image_masked.unsqueeze(0)
del background_color, apply_mask
else:
# image的非mask对应部分设为透明 => (1, h, w, 4)
image = apply_mask_to_image(_image_masked.cpu(), mask.cpu())

_images.append(image)
_masks.append(mask.squeeze(0))

out_images = torch.cat(_images, dim=0)
out_masks = torch.cat(_masks, dim=0)
each_mask = model(each_image.unsqueeze(0).to(model_device_type))[-1].sigmoid().cpu()
_mask_bchw.append(each_mask)
del each_mask

mask_bchw = torch.cat(_mask_bchw, dim=0)
del _mask_bchw
# 遮罩大小需还原为与原图一致
mask = comfy.utils.common_upscale(mask_bchw, w, h, 'bilinear', "disabled")
# (b, 1, h, w)
if mask_threshold > 0:
out_masks = filter_mask(mask, threshold=mask_threshold)
else:
out_masks = normalize_mask(mask)

# (b, c, h, w)
_image_masked = refine_foreground(image_bchw, out_masks, r1=blur_size, r2=blur_size_two)
# (b, c, h, w) => (b, h, w, c)
_image_masked = _image_masked.permute(0, 2, 3, 1)
if fill_color and color is not None:
r = torch.full([b, h, w, 1], ((color >> 16) & 0xFF) / 0xFF)
g = torch.full([b, h, w, 1], ((color >> 8) & 0xFF) / 0xFF)
b = torch.full([b, h, w, 1], (color & 0xFF) / 0xFF)
# (b, h, w, 3)
background_color = torch.cat((r, g, b), dim=-1)
# (b, 1, h, w) => (b, h, w, 3)
apply_mask = out_masks.permute(0, 2, 3, 1).expand_as(_image_masked)
out_images = _image_masked * apply_mask + background_color * (1 - apply_mask)
# (b, h, w, 3)=>(b, h, w, 3)
del background_color, apply_mask
out_masks = out_masks.squeeze(1)
else:
# (b, 1, h, w) => (b, h, w)
out_masks = out_masks.squeeze(1)
# image的非mask对应部分设为透明 => (b, h, w, 4)
out_images = add_mask_as_alpha(_image_masked.cpu(), out_masks.cpu())

del _image_masked

return out_images, out_masks

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "comfyui_birefnet_ll"
description = "Sync with version of BiRefNet. NODES:AutoDownloadBiRefNetModel, LoadRembgByBiRefNetModel, RembgByBiRefNet."
version = "1.0.6"
version = "1.0.7"
license = {file = "LICENSE"}
dependencies = ["numpy", "opencv-python", "timm"]

Expand Down
28 changes: 25 additions & 3 deletions util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,15 @@ def refine_foreground(image_tensor, mask_tensor, r1=90, r2=7):
if r2 % 2 == 0:
r2 += 1

estimated_foreground = FB_blur_fusion_foreground_estimator_2(image_tensor, mask_tensor, r1=r1, r2=r2)
return estimated_foreground
return FB_blur_fusion_foreground_estimator_2(image_tensor, mask_tensor, r1=r1, r2=r2)[0]


def FB_blur_fusion_foreground_estimator_2(image_tensor, alpha_tensor, r1=90, r2=7):
# https://github.com/Photoroom/fast-foreground-estimation
if alpha_tensor.dim() == 3:
alpha_tensor = alpha_tensor.unsqueeze(0) # Add batch
F, blur_B = FB_blur_fusion_foreground_estimator(image_tensor, image_tensor, image_tensor, alpha_tensor, r=r1)
return FB_blur_fusion_foreground_estimator(image_tensor, F, blur_B, alpha_tensor, r=r2)[0]
return FB_blur_fusion_foreground_estimator(image_tensor, F, blur_B, alpha_tensor, r=r2)


def FB_blur_fusion_foreground_estimator(image_tensor, F_tensor, B_tensor, alpha_tensor, r=90):
Expand Down Expand Up @@ -103,3 +102,26 @@ def normalize_mask(mask_tensor):
normalized_mask = (mask_tensor - min_val) / (max_val - min_val)

return normalized_mask

def add_mask_as_alpha(image, mask):
"""
将 (b, h, w) 形状的 mask 添加为 (b, h, w, 3) 形状的 image 的第 4 个通道(alpha 通道)。
"""
# 检查输入形状
assert image.dim() == 4 and image.size(-1) == 3, "The shape of image should be (b, h, w, 3)."
assert mask.dim() == 3, "The shape of mask should be (b, h, w)"
assert image.size(0) == mask.size(0) and image.size(1) == mask.size(1) and image.size(2) == mask.size(2), "The batch, height, and width dimensions of the image and mask must be consistent"

# 将 mask 扩展为 (b, h, w, 1)
mask = mask[..., None]

image = image * mask
# 将 image 和 mask 拼接为 (b, h, w, 4)
image_with_alpha = torch.cat([image, mask], dim=-1)

return image_with_alpha

def filter_mask(mask, threshold=4e-3):
mask_binary = mask > threshold
filtered_mask = mask * mask_binary
return filtered_mask

0 comments on commit 9b7262e

Please sign in to comment.