support BirefNet_HR and float16
lldacing committed Feb 9, 2025
1 parent b9fd2e4 commit 8bda575
Showing 6 changed files with 32 additions and 19 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -24,6 +24,7 @@ Support the use of new and old versions of BiRefNet models
### The available newest models are:

- General: A pre-trained model for general use cases.
+- General-HR: A pre-trained model for general use cases which shows great performance on higher resolution images (2048x2048).
- General-Lite: A light pre-trained model for general use cases.
- General-Lite-2K: A light pre-trained model for general use cases in high resolution (2560x1440).
- Portrait: A pre-trained model for human portraits.
@@ -37,6 +38,7 @@ Model files go here (when use AutoDownloadBiRefNetModel automatically downloaded

If necessary, they can be downloaded from:
- [General](https://huggingface.co/ZhengPeng7/BiRefNet/resolve/main/model.safetensors)`model.safetensors` must be renamed `General.safetensors`
+- [General-HR](https://huggingface.co/ZhengPeng7/BiRefNet_HR/resolve/main/model.safetensors)`model.safetensors` must be renamed `General-HR.safetensors`
- [General-Lite](https://huggingface.co/ZhengPeng7/BiRefNet_T/resolve/main/model.safetensors)`model.safetensors` must be renamed `General-Lite.safetensors`
- [General-Lite-2K](https://huggingface.co/ZhengPeng7/BiRefNet_lite-2K/resolve/main/model.safetensors)`model.safetensors` must be renamed `General-Lite-2K.safetensors`
- [Portrait](https://huggingface.co/ZhengPeng7/BiRefNet-portrait/resolve/main/model.safetensors)`model.safetensors` must be renamed `Portrait.safetensors`
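
The manual-download step described in this README hunk comes down to saving `model.safetensors` under the model's display name. A minimal sketch of that rename, assuming the file has already been downloaded into the models folder; the helper name `install_downloaded_model` is hypothetical and not part of this repository:

```python
from pathlib import Path


def install_downloaded_model(models_dir: Path, model_name: str) -> Path:
    """Rename a downloaded model.safetensors to {model_name}.safetensors.

    Hypothetical helper for illustration only; not part of this repository.
    """
    source = models_dir / "model.safetensors"
    target = models_dir / f"{model_name}.safetensors"
    if not source.is_file():
        raise FileNotFoundError(f"expected a downloaded file at {source}")
    if target.exists():
        # Refuse to overwrite an existing model file silently.
        raise FileExistsError(f"{target} already exists; remove it first")
    return source.rename(target)
```

For example, `install_downloaded_model(Path("models/BiRefNet"), "General-HR")` would leave `General-HR.safetensors` in that folder, which is the file name `AutoDownloadBiRefNetModel` builds from `model_name`.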
6 changes: 4 additions & 2 deletions README_CN.md
@@ -24,19 +24,21 @@
### 最新的模型:

- General: 用于一般用例的预训练模型。
+- General-HR: 用于一般用例的预训练模型,在更高分辨率的图像上表现出色(训练分辨率2048x2048)。
- General-Lite: 用于一般用例的轻量级预训练模型。
-- General-Lite-2K: 用于一般用例的轻量级预训练模型,适用于高分辨率图像(最佳分辨率2560x1440).
+- General-Lite-2K: 用于一般用例的轻量级预训练模型,适用于高分辨率图像(最佳分辨率2560x1440)
- Portrait: 人物肖像预训练模型。
- Matting: 一种使用无trimap matting的预训练模型。
- DIS: 一种用于二分图像分割(DIS)的预训练模型。
- HRSOD: 一种用于高分辨率显著目标检测(HRSOD)的预训练模型。
- COD: 一种用于隐蔽目标检测(COD)的预训练模型。
- DIS-TR_TEs: 具有大量数据集的预训练模型。

-模型文件放在`${comfyui_rootpath}/models/BiRefNet`(当使用AutoDownloadBiRefNetModel时,则会自动下载模型).
+模型文件放在`${comfyui_rootpath}/models/BiRefNet`(当使用AutoDownloadBiRefNetModel时,则会自动下载模型)

也可以手动下载模型:
- [General](https://huggingface.co/ZhengPeng7/BiRefNet/resolve/main/model.safetensors)`model.safetensors` 重命名为 `General.safetensors`
+- [General-HR](https://huggingface.co/ZhengPeng7/BiRefNet_HR/resolve/main/model.safetensors)`model.safetensors` 重命名为 `General-HR.safetensors`
- [General-Lite](https://huggingface.co/ZhengPeng7/BiRefNet_T/resolve/main/model.safetensors)`model.safetensors` 重命名为 `General-Lite.safetensors`
- [General-Lite-2K](https://huggingface.co/ZhengPeng7/BiRefNet_lite-2K/resolve/main/model.safetensors)`model.safetensors` 重命名为 `General-Lite-2K.safetensors`
- [Portrait](https://huggingface.co/ZhengPeng7/BiRefNet-portrait/resolve/main/model.safetensors)`model.safetensors` 重命名为 `Portrait.safetensors`
5 changes: 1 addition & 4 deletions birefnet/config.py
@@ -61,7 +61,7 @@ def __init__(self, bb_index: int = 6) -> None:
'DIS5K': -40,
'COD': -20,
'HRSOD': -20,
-'General': -40,
+'General': -20,
'General-2K': -20,
'Matting': -20,
}[self.task]
@@ -154,9 +154,6 @@ def __init__(self, bb_index: int = 6) -> None:
self.lambdas_cls = {
'ce': 5.0
}
-# Adv
-self.lambda_adv_g = 10. * 0 # turn to 0 to avoid adv training
-self.lambda_adv_d = 3. * (self.lambda_adv_g > 0)

# PATH settings - inactive
# https://drive.google.com/drive/folders/1cmce_emsS8A5ha5XT2c_CZiJzlLM81ms
6 changes: 3 additions & 3 deletions birefnet/models/backbones/swin_v1.py
@@ -68,9 +68,9 @@ def window_reverse(windows, window_size, H, W):
Returns:
x: (B, H, W, C)
"""
-B = int(windows.shape[0] / (H * W / window_size / window_size))
-x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
-x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+C = int(windows.shape[-1])
+x = windows.view(-1, H // window_size, W // window_size, window_size, window_size, C)
+x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
return x
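
The new `window_reverse` reads the channel count from the tensor and lets `view()` infer the batch dimension, rather than computing `B` through a floating-point division. The round trip below is a minimal sketch of that behaviour. `window_partition` is written here in the usual Swin Transformer form and is an assumption about the rest of `swin_v1.py`, which this diff does not show:

```python
import torch


def window_partition(x: torch.Tensor, window_size: int) -> torch.Tensor:
    """Split (B, H, W, C) into windows of shape (num_windows * B, ws, ws, C)."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)


def window_reverse(windows: torch.Tensor, window_size: int, H: int, W: int) -> torch.Tensor:
    """Merge windows back into (B, H, W, C); the batch size is inferred by view()."""
    C = int(windows.shape[-1])
    x = windows.view(-1, H // window_size, W // window_size, window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)


# Round trip on a small tensor: partition into 4x4 windows, then merge back.
x = torch.arange(2 * 8 * 8 * 3, dtype=torch.float32).view(2, 8, 8, 3)
restored = window_reverse(window_partition(x, 4), 4, 8, 8)
print(restored.shape, torch.equal(restored, x))
```

With the channel dimension fixed and the batch dimension left as `-1`, the function uses only integer arithmetic on `H`, `W`, and `window_size`.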


30 changes: 21 additions & 9 deletions birefnetNode.py
@@ -9,7 +9,7 @@
from birefnet.models.birefnet import BiRefNet
from birefnet_old.models.birefnet import BiRefNet as OldBiRefNet
from birefnet.utils import check_state_dict
-from .util import tensor_to_pil, apply_mask_to_image, normalize_mask, refine_foreground, filter_mask, add_mask_as_alpha
+from .util import refine_foreground, filter_mask, add_mask_as_alpha
deviceType = model_management.get_torch_device().type

models_dir_key = "birefnet"
@@ -18,6 +18,7 @@

usage_to_weights_file = {
'General': 'BiRefNet',
+'General-HR': 'BiRefNet_HR',
'General-Lite': 'BiRefNet_T',
'General-Lite-2K': 'BiRefNet_lite-2K',
'Portrait': 'BiRefNet-portrait',
@@ -28,7 +29,7 @@
'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs'
}

-modelNameList = ['General', 'General-Lite', 'General-Lite-2K', 'Portrait', 'Matting', 'DIS', 'HRSOD', 'COD', 'DIS-TR_TEs']
+modelNameList = ['General', 'General-HR', 'General-Lite', 'General-Lite-2K', 'Portrait', 'Matting', 'DIS', 'HRSOD', 'COD', 'DIS-TR_TEs']
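
Each display name maps to a Hugging Face repository name, and the README links follow one URL pattern. A minimal sketch of how a display name could resolve to a download URL and a local file name follows. It covers only the entries visible in this diff, and the helper names `model_download_url` and `local_file_name` are hypothetical; the repository's actual download code is not shown here:

```python
# Subset of the mapping shown in this diff: display name -> Hugging Face repo name.
usage_to_weights_file = {
    "General": "BiRefNet",
    "General-HR": "BiRefNet_HR",
    "General-Lite": "BiRefNet_T",
    "General-Lite-2K": "BiRefNet_lite-2K",
    "Portrait": "BiRefNet-portrait",
}


def model_download_url(model_name: str) -> str:
    """Build the safetensors URL for a display name (hypothetical helper)."""
    repo = usage_to_weights_file[model_name]
    return f"https://huggingface.co/ZhengPeng7/{repo}/resolve/main/model.safetensors"


def local_file_name(model_name: str) -> str:
    """File name expected under models/BiRefNet for a display name (hypothetical helper)."""
    return f"{model_name}.safetensors"


print(model_download_url("General-HR"))
print(local_file_name("General-HR"))
```

A new model such as General-HR needs an entry in both `usage_to_weights_file` and `modelNameList`, which is what this hunk adds.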


def get_model_path(model_name):
@@ -90,6 +91,11 @@ def old_proc(self, image) -> torch.Tensor:
VERSION = ["old", "v1"]
old_models_name = ["BiRefNet-DIS_ep580.pth", "BiRefNet-ep480.pth"]

+torch_dtype={
+"float16": torch.float16,
+"float32": torch.float32,
+"bfloat16": torch.bfloat16,
+}

class AutoDownloadBiRefNetModel:

@@ -99,6 +105,9 @@ def INPUT_TYPES(cls):
"required": {
"model_name": (modelNameList,),
"device": (["AUTO", "CPU"],)
+},
+"optional": {
+"dtype": (["float32", "float16"], {"default": "float32"})
}
}

@@ -108,7 +117,7 @@ def INPUT_TYPES(cls):
CATEGORY = "image/BiRefNet"
DESCRIPTION = "Auto download BiRefNet model from huggingface to models/BiRefNet/{model_name}.safetensors"

-def load_model(self, model_name, device):
+def load_model(self, model_name, device, dtype="float32"):
bb_index = 3 if model_name == "General-Lite" or model_name == "General-Lite-2K" else 6
biRefNet_model = BiRefNet(bb_pretrained=False, bb_index=bb_index)
model_file_name = f'{model_name}.safetensors'
@@ -122,7 +131,7 @@ def load_model(self, model_name, device):
device_type = "cpu"
state_dict = safetensors.torch.load_file(model_full_path, device=device_type)
biRefNet_model.load_state_dict(state_dict)
-biRefNet_model.to(device_type)
+biRefNet_model.to(device_type, dtype=torch_dtype[dtype])
biRefNet_model.eval()
return [(biRefNet_model, VERSION[1])]

@@ -137,7 +146,8 @@ def INPUT_TYPES(cls):
"device": (["AUTO", "CPU"], )
},
"optional": {
-"use_weight": ("BOOLEAN", {"default": False})
+"use_weight": ("BOOLEAN", {"default": False}),
+"dtype": (["float32", "float16"], {"default": "float32"})
}
}

@@ -147,7 +157,7 @@ def INPUT_TYPES(cls):
CATEGORY = "rembg/BiRefNet"
DESCRIPTION = "Load BiRefNet model from folder models/BiRefNet or the path of birefnet configured in the extra YAML file"

-def load_model(self, model, device, use_weight=False):
+def load_model(self, model, device, use_weight=False, dtype="float32"):
if model in old_models_name:
version = VERSION[0]
biRefNet_model = OldBiRefNet(bb_pretrained=use_weight)
@@ -168,7 +178,7 @@ def load_model(self, model, device, use_weight=False):
state_dict = check_state_dict(state_dict)

biRefNet_model.load_state_dict(state_dict)
-biRefNet_model.to(device_type)
+biRefNet_model.to(device_type, dtype=torch_dtype[dtype])
biRefNet_model.eval()
return [(biRefNet_model, version)]

@@ -211,7 +221,9 @@ def INPUT_TYPES(cls):

def get_mask(self, model, images, width=1024, height=1024, upscale_method='bilinear', mask_threshold=0.000):
model, version = model
-model_device_type = next(model.parameters()).device.type
+one_torch = next(model.parameters())
+model_device_type = one_torch.device.type
+model_dtype = one_torch.dtype
b, h, w, c = images.shape
image_bchw = images.permute(0, 3, 1, 2)

@@ -226,7 +238,7 @@ def get_mask(self, model, images, width=1024, height=1024, upscale_method='bilin
_mask_bchw = []
for each_image in im_tensor:
with torch.no_grad():
-each_mask = model(each_image.unsqueeze(0).to(model_device_type))[-1].sigmoid().cpu()
+each_mask = model(each_image.unsqueeze(0).to(model_device_type, dtype=model_dtype))[-1].sigmoid().cpu().float()
_mask_bchw.append(each_mask)
del each_mask
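
The dtype handling in this file follows one pattern: cast the model once at load time, read the dtype back from a parameter at inference time, cast each input to match, and return the mask as float32. The sketch below shows that flow under stated assumptions. `ToyMaskModel` is a hypothetical stand-in for BiRefNet, and bfloat16 is used only so the example runs on CPU-only machines (the node itself offers float32 and float16):

```python
import torch
from torch import nn

torch_dtype = {
    "float16": torch.float16,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}


class ToyMaskModel(nn.Module):
    """Hypothetical stand-in for BiRefNet: returns a list whose last item is a logit map."""

    def __init__(self) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(1, 3, 1, 1))

    def forward(self, x: torch.Tensor) -> list:
        # Weighted sum over the channel axis gives a (B, 1, H, W) logit map.
        return [(x * self.weight).sum(dim=1, keepdim=True)]


# Load time: move the model to its device and cast its parameters in one call.
model = ToyMaskModel().to("cpu", dtype=torch_dtype["bfloat16"]).eval()

# Inference time: recover device and dtype from any parameter of the model.
one_param = next(model.parameters())
image = torch.rand(1, 3, 8, 8)  # float32 input in BCHW layout

with torch.no_grad():
    # Match the input to the model, then hand the mask back as float32 on the CPU.
    logits = model(image.to(one_param.device.type, dtype=one_param.dtype))[-1]
    mask = logits.sigmoid().cpu().float()

print(mask.dtype, mask.shape)
```

Casting the result back with `.float()` keeps downstream mask code working on float32 tensors regardless of the dtype chosen for the model.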

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,7 +1,7 @@
[project]
name = "comfyui_birefnet_ll"
description = "Sync with version of BiRefNet. NODES:AutoDownloadBiRefNetModel, LoadRembgByBiRefNetModel, RembgByBiRefNet, RembgByBiRefNetAdvanced, GetMaskByBiRefNet, BlurFusionForegroundEstimation."
-version = "1.1.0"
+version = "1.1.1"
license = {file = "LICENSE"}
dependencies = ["numpy", "opencv-python", "timm"]
