Skip to content

Commit

Permalink
support old models(birefnet code from hash 2b47f82c1100b7daebd5ab24a7…
Browse files Browse the repository at this point in the history
…904a5052576dfe), weight model
  • Loading branch information
刘雪峰 committed Sep 13, 2024
1 parent 8ba3553 commit 28e3ce7
Show file tree
Hide file tree
Showing 32 changed files with 2,924 additions and 25 deletions.
24 changes: 20 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
## Example
![save api extended](example/workflow_base.png)
[中文文档](README_CN.md)
## Preview
![save api extended](doc/base.png)
![save api extended](doc/video.gif)

## Install

Expand All @@ -11,11 +13,12 @@
pip install -r requirements.txt
# restart ComfyUI
```
- Via ComfyUI Manager


## Models

The available models are:
### The available newest models are:

- General: A pre-trained model for general use cases.
- General-Lite: A light pre-trained model for general use cases.
Expand All @@ -36,13 +39,26 @@ If necessary, they can be downloaded from:
- [COD](https://huggingface.co/ZhengPeng7/BiRefNet-COD/resolve/main/model.safetensors)`model.safetensors` must be renamed `COD.safetensors`
- [DIS-TR_TEs](https://huggingface.co/ZhengPeng7/BiRefNet-DIS5K-TR_TEs/resolve/main/model.safetensors)`model.safetensors` must be renamed `DIS-TR_TEs.safetensors`

Some models on GitHub:
[BiRefNet Releases](https://github.com/ZhengPeng7/BiRefNet/releases)

### Old models:
- [BiRefNet-DIS_ep580.pth](https://huggingface.co/ViperYX/BiRefNet/resolve/main/BiRefNet-DIS_ep580.pth)
- [BiRefNet-ep480.pth](https://huggingface.co/ViperYX/BiRefNet/resolve/main/BiRefNet-ep480.pth)

## Weight Models (Optional)
- [swin_large_patch4_window12_384_22kto1k.pth](https://huggingface.co/ViperYX/BiRefNet/resolve/main/swin_large_patch4_window12_384_22kto1k.pth)
- [swin_tiny_patch4_window7_224_22kto1k_finetune.pth](https://drive.google.com/drive/folders/1cmce_emsS8A5ha5XT2c_CZiJzlLM81ms), just General-Lite model


## Nodes
- AutoDownloadBiRefNetModel
- Automatically download the model into models/BiRefNet
- LoadRembgByBiRefNetModel
- Can select model from "models/BiRefNet" or the path of "birefnet" configured in the extra YAML file
- You can download model from [BiRefNet Releases](https://github.com/ZhengPeng7/BiRefNet/releases)
- You can download latest models from [BiRefNet Releases](https://github.com/ZhengPeng7/BiRefNet/releases) or old models [BiRefNet-DIS_ep580.pth](https://huggingface.co/ViperYX/BiRefNet/resolve/main/BiRefNet-DIS_ep580.pth) and [BiRefNet-ep480.pth](https://huggingface.co/ViperYX/BiRefNet/resolve/main/BiRefNet-ep480.pth)
- When param use_weight is True, need download weight model [swin_large_patch4_window12_384_22kto1k.pth](https://huggingface.co/ViperYX/BiRefNet/resolve/main/swin_large_patch4_window12_384_22kto1k.pth)
model General-Lite must use weight model [swin_tiny_patch4_window7_224_22kto1k_finetune.pth](https://drive.google.com/drive/folders/1cmce_emsS8A5ha5XT2c_CZiJzlLM81ms)
- RembgByBiRefNet

## Thanks
Expand Down
70 changes: 70 additions & 0 deletions README_CN.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
[English](README.md)
## 预览
![save api extended](doc/base.png)
![save api extended](doc/video.gif)

## 安装

- 手动安装
```shell
cd custom_nodes
git clone https://github.com/lldacing/ComfyUI_BiRefNet_ll.git
cd ComfyUI_BiRefNet_ll
pip install -r requirements.txt
# restart ComfyUI
```
- ComfyUI管理器搜索安装


## 模型

### 最新的模型:

- General: 用于一般用例的预训练模型。
- General-Lite: 用于一般用例的轻量级预训练模型。
- Portrait: 人物肖像预训练模型。
- DIS: 一种用于二分图像分割(DIS)的预训练模型。
- HRSOD: 一种用于高分辨率显著目标检测(HRSOD)的预训练模型。
- COD: 一种用于隐蔽目标检测(COD)的预训练模型。
- DIS-TR_TEs: 具有大量数据集的预训练模型。

模型文件放在`models/BiRefNet`(当使用AutoDownloadBiRefNetModel时,如果第一次运行时文件夹不存在,则会自动下载.

也可以手动下载模型:
- [General](https://huggingface.co/ZhengPeng7/BiRefNet/resolve/main/model.safetensors)`model.safetensors` 重命名为 `General.safetensors`
- [General-Lite](https://huggingface.co/ZhengPeng7/BiRefNet_T/resolve/main/model.safetensors)`model.safetensors` 重命名为 `General-Lite.safetensors`
- [Portrait](https://huggingface.co/ZhengPeng7/BiRefNet-portrait/resolve/main/model.safetensors)`model.safetensors` 重命名为 `Portrait.safetensors`
- [DIS](https://huggingface.co/ZhengPeng7/BiRefNet-DIS5K/resolve/main/model.safetensors)`model.safetensors` 重命名为 `DIS.safetensors`
- [HRSOD](https://huggingface.co/ZhengPeng7/BiRefNet-HRSOD/resolve/main/model.safetensors)`model.safetensors` 重命名为 `HRSOD.safetensors`
- [COD](https://huggingface.co/ZhengPeng7/BiRefNet-COD/resolve/main/model.safetensors)`model.safetensors` 重命名为 `COD.safetensors`
- [DIS-TR_TEs](https://huggingface.co/ZhengPeng7/BiRefNet-DIS5K-TR_TEs/resolve/main/model.safetensors)`model.safetensors` 重命名为 `DIS-TR_TEs.safetensors`


GitHub上的模型:
[BiRefNet Releases](https://github.com/ZhengPeng7/BiRefNet/releases)

### 旧模型:
- [BiRefNet-DIS_ep580.pth](https://huggingface.co/ViperYX/BiRefNet/resolve/main/BiRefNet-DIS_ep580.pth)
- [BiRefNet-ep480.pth](https://huggingface.co/ViperYX/BiRefNet/resolve/main/BiRefNet-ep480.pth)

## 权重模型(非必须)
下载放在`models/BiRefNet`
- [swin_large_patch4_window12_384_22kto1k.pth](https://huggingface.co/ViperYX/BiRefNet/resolve/main/swin_large_patch4_window12_384_22kto1k.pth)
- [swin_tiny_patch4_window7_224_22kto1k_finetune.pth](https://drive.google.com/drive/folders/1cmce_emsS8A5ha5XT2c_CZiJzlLM81ms), just General-Lite model


## 节点
- AutoDownloadBiRefNetModel
- 自动下载模型到 `models/BiRefNet`,不支持权重
- LoadRembgByBiRefNetModel
-`models/BiRefNet` 和 在extra YAML 文件中通过`birefnet`配置的路径中选择模型
- 支持 [BiRefNet Releases](https://github.com/ZhengPeng7/BiRefNet/releases) 中的新模型 和 老的模型[BiRefNet-DIS_ep580.pth](https://huggingface.co/ViperYX/BiRefNet/resolve/main/BiRefNet-DIS_ep580.pth)[BiRefNet-ep480.pth](https://huggingface.co/ViperYX/BiRefNet/resolve/main/BiRefNet-ep480.pth)
- 参数use_weight设为True时, 需要下载权重模型,General-Lite模型使用[swin_tiny_patch4_window7_224_22kto1k_finetune.pth](https://drive.google.com/drive/folders/1cmce_emsS8A5ha5XT2c_CZiJzlLM81ms),其它模型使用 [swin_large_patch4_window12_384_22kto1k.pth](https://huggingface.co/ViperYX/BiRefNet/resolve/main/swin_large_patch4_window12_384_22kto1k.pth)
- RembgByBiRefNet

## 感谢

[BiRefNet](https://github.com/zhengpeng7/birefnet)

[dimitribarbot/sd-webui-birefnet](https://github.com/dimitribarbot/sd-webui-birefnet)

23 changes: 11 additions & 12 deletions birefnet/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,18 +155,17 @@ def __init__(self, bb_index: int = 6) -> None:
# 'pvt_v2_b0': os.path.join(self.weights_root_dir, ['pvt_v2_b0.pth'][0]),
# 'pvt_v2_b1': os.path.join(self.weights_root_dir, ['pvt_v2_b1.pth'][0]),
# }
# weight_paths_name = "birefnet"
# self.weights = {
# 'pvt_v2_b2': folder_paths.get_full_path(weight_paths_name, 'pvt_v2_b2.pth'),
# 'pvt_v2_b5': folder_paths.get_full_path(weight_paths_name, ['pvt_v2_b5.pth', 'pvt_v2_b5_22k.pth'][0]),
# 'swin_v1_b': folder_paths.get_full_path(weight_paths_name, ['swin_base_patch4_window12_384_22kto1k.pth', 'swin_base_patch4_window12_384_22k.pth'][0]),
# 'swin_v1_l': folder_paths.get_full_path(weight_paths_name, ['swin_large_patch4_window12_384_22kto1k.pth', 'swin_large_patch4_window12_384_22k.pth'][0]),
# 'swin_v1_t': folder_paths.get_full_path(weight_paths_name, ['swin_tiny_patch4_window7_224_22kto1k_finetune.pth'][0]),
# 'swin_v1_s': folder_paths.get_full_path(weight_paths_name, ['swin_small_patch4_window7_224_22kto1k_finetune.pth'][0]),
# 'pvt_v2_b0': folder_paths.get_full_path(weight_paths_name, ['pvt_v2_b0.pth'][0]),
# 'pvt_v2_b1': folder_paths.get_full_path(weight_paths_name, ['pvt_v2_b1.pth'][0]),
# }
self.weights = {}
weight_paths_name = "birefnet"
self.weights = {
'pvt_v2_b2': folder_paths.get_full_path(weight_paths_name, 'pvt_v2_b2.pth'),
'pvt_v2_b5': folder_paths.get_full_path(weight_paths_name, ['pvt_v2_b5.pth', 'pvt_v2_b5_22k.pth'][0]),
'swin_v1_b': folder_paths.get_full_path(weight_paths_name, ['swin_base_patch4_window12_384_22kto1k.pth', 'swin_base_patch4_window12_384_22k.pth'][0]),
'swin_v1_l': folder_paths.get_full_path(weight_paths_name, ['swin_large_patch4_window12_384_22kto1k.pth', 'swin_large_patch4_window12_384_22k.pth'][0]),
'swin_v1_t': folder_paths.get_full_path(weight_paths_name, ['swin_tiny_patch4_window7_224_22kto1k_finetune.pth'][0]),
'swin_v1_s': folder_paths.get_full_path(weight_paths_name, ['swin_small_patch4_window7_224_22kto1k_finetune.pth'][0]),
'pvt_v2_b0': folder_paths.get_full_path(weight_paths_name, ['pvt_v2_b0.pth'][0]),
'pvt_v2_b1': folder_paths.get_full_path(weight_paths_name, ['pvt_v2_b1.pth'][0]),
}

# Callbacks - inactive
self.verbose_eval = True
Expand Down
2 changes: 1 addition & 1 deletion birefnet/models/backbones/build_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def build_backbone(bb_name, pretrained=True, params_settings=''):
return bb

def load_weights(model, model_name):
safetensors.torch.load_file
# safetensors.torch.load_file
save_model = torch.load(config.weights[model_name], map_location='cpu')
model_dict = model.state_dict()
state_dict = {k: v if v.size() == model_dict[k].size() else model_dict[k] for k, v in save_model.items() if k in model_dict.keys()}
Expand Down
40 changes: 33 additions & 7 deletions birefnetNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
from comfy import model_management
import folder_paths
from birefnet.models.birefnet import BiRefNet
from birefnet_old.models.birefnet import BiRefNet as OldBiRefNet
from birefnet.utils import check_state_dict
from .util import tensor_to_pil, apply_mask_to_image
from .util import tensor_to_pil, apply_mask_to_image, normalize_mask

deviceType = model_management.get_torch_device().type

Expand Down Expand Up @@ -71,6 +72,16 @@ def download_birefnet_model(model_name):
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
old_proc_img = transforms.Compose(
[
transforms.Resize((1024, 1024)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0]),
]
)

VERSION = ["old", "v1"]
old_models_name = ["BiRefNet-DIS_ep580.pth", "BiRefNet-ep480.pth"]


class AutoDownloadBiRefNetModel:
Expand Down Expand Up @@ -106,7 +117,7 @@ def load_model(self, model_name, device):
biRefNet_model.load_state_dict(state_dict)
biRefNet_model.to(device_type)
biRefNet_model.eval()
return biRefNet_model,
return [(biRefNet_model, VERSION[1])]


class LoadRembgByBiRefNetModel:
Expand All @@ -117,6 +128,9 @@ def INPUT_TYPES(cls):
"required": {
"model": (folder_paths.get_filename_list(models_dir_key),),
"device": (["AUTO", "CPU"], )
},
"optional": {
"use_weight": ("BOOLEAN", {"default": False})
}
}

Expand All @@ -126,8 +140,15 @@ def INPUT_TYPES(cls):
CATEGORY = "rembg/BiRefNet"
DESCRIPTION = "Load BiRefNet model from folder models/BiRefNet or the path of birefnet configured in the extra YAML file"

def load_model(self, model, device):
biRefNet_model = BiRefNet(bb_pretrained=False, bb_index=6)
def load_model(self, model, device, use_weight=False):
if model in old_models_name:
version = VERSION[0]
biRefNet_model = OldBiRefNet(bb_pretrained=use_weight)
else:
version = VERSION[1]
bb_index = 3 if model == "General-Lite.safetensors" else 6
biRefNet_model = BiRefNet(bb_pretrained=use_weight, bb_index=bb_index)

model_path = folder_paths.get_full_path(models_dir_key, model)
if device == "AUTO":
device_type = deviceType
Expand All @@ -142,7 +163,7 @@ def load_model(self, model, device):
biRefNet_model.load_state_dict(state_dict)
biRefNet_model.to(device_type)
biRefNet_model.eval()
return [biRefNet_model]
return [(biRefNet_model, version)]


class RembgByBiRefNet:
Expand All @@ -162,22 +183,27 @@ def INPUT_TYPES(cls):
CATEGORY = "rembg/BiRefNet"

def rem_bg(self, model, images):
model, version = model
_images = []
_masks = []

for image in images:
h, w, c = image.shape
pil_image = tensor_to_pil(image)

im_tensor = proc_img(pil_image).unsqueeze(0)
if VERSION[0] == version:
im_tensor = old_proc_img(pil_image).unsqueeze(0)
else:
im_tensor = proc_img(pil_image).unsqueeze(0)

with torch.no_grad():
mask = model(im_tensor.to(deviceType))[-1].sigmoid().cpu()

# 遮罩大小需还原为与原图一致
mask = comfy.utils.common_upscale(mask, w, h, 'bilinear', "disabled")

# image的mask对应部分设为透明
mask = normalize_mask(mask)
# image的非mask对应部分设为透明
image = apply_mask_to_image(image.cpu(), mask.cpu())

_images.append(image)
Expand Down
Empty file added birefnet_old/__init__.py
Empty file.
Loading

0 comments on commit 28e3ce7

Please sign in to comment.