From 69aa6740bbdee9b9dcaf26022deef6a2e4879b4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E9=9B=AA=E5=B3=B0?= Date: Tue, 22 Oct 2024 17:24:46 +0800 Subject: [PATCH] support model General-Lite-2K and model Matting --- README.md | 8 ++++++-- README_CN.md | 8 ++++++-- birefnet/config.py | 21 +++++++++++++++------ birefnetNode.py | 10 ++++++---- pyproject.toml | 2 +- 5 files changed, 34 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 1e00ed8..3cf57ba 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,9 @@ Support the use of new and old versions of BiRefNet models - General: A pre-trained model for general use cases. - General-Lite: A light pre-trained model for general use cases. +- General-Lite-2K: A light pre-trained model for general use cases in high resolution (2560x1440). - Portrait: A pre-trained model for human portraits. +- Matting: A pre-trained model for general trimap-free matting use. - DIS: A pre-trained model for dichotomous image segmentation (DIS). - HRSOD: A pre-trained model for high-resolution salient object detection (HRSOD). - COD: A pre-trained model for concealed object detection (COD). 
@@ -36,7 +38,9 @@ Model files go here (when use AutoDownloadBiRefNetModel automatically downloaded If necessary, they can be downloaded from: - [General](https://huggingface.co/ZhengPeng7/BiRefNet/resolve/main/model.safetensors) ➔ `model.safetensors` must be renamed `General.safetensors` - [General-Lite](https://huggingface.co/ZhengPeng7/BiRefNet_T/resolve/main/model.safetensors) ➔ `model.safetensors` must be renamed `General-Lite.safetensors` +- [General-Lite-2K](https://huggingface.co/ZhengPeng7/BiRefNet_lite-2K/resolve/main/model.safetensors) ➔ `model.safetensors` must be renamed `General-Lite-2K.safetensors` - [Portrait](https://huggingface.co/ZhengPeng7/BiRefNet-portrait/resolve/main/model.safetensors) ➔ `model.safetensors` must be renamed `Portrait.safetensors` +- [Matting](https://huggingface.co/ZhengPeng7/BiRefNet-matting/resolve/main/model.safetensors) ➔ `model.safetensors` must be renamed `Matting.safetensors` - [DIS](https://huggingface.co/ZhengPeng7/BiRefNet-DIS5K/resolve/main/model.safetensors) ➔ `model.safetensors` must be renamed `DIS.safetensors` - [HRSOD](https://huggingface.co/ZhengPeng7/BiRefNet-HRSOD/resolve/main/model.safetensors) ➔ `model.safetensors` must be renamed `HRSOD.safetensors` - [COD](https://huggingface.co/ZhengPeng7/BiRefNet-COD/resolve/main/model.safetensors) ➔ `model.safetensors` must be renamed `COD.safetensors` @@ -50,8 +54,8 @@ Some models on GitHub: - [BiRefNet-ep480.pth](https://huggingface.co/ViperYX/BiRefNet/resolve/main/BiRefNet-ep480.pth) ## Weight Models (Optional) -- [swin_large_patch4_window12_384_22kto1k.pth](https://huggingface.co/ViperYX/BiRefNet/resolve/main/swin_large_patch4_window12_384_22kto1k.pth)(not General-Lite model) -- [swin_tiny_patch4_window7_224_22kto1k_finetune.pth](https://drive.google.com/drive/folders/1cmce_emsS8A5ha5XT2c_CZiJzlLM81ms)(just General-Lite model) +- 
[swin_large_patch4_window12_384_22kto1k.pth](https://huggingface.co/ViperYX/BiRefNet/resolve/main/swin_large_patch4_window12_384_22kto1k.pth)(for all models except General-Lite and General-Lite-2K) +- [swin_tiny_patch4_window7_224_22kto1k_finetune.pth](https://drive.google.com/drive/folders/1cmce_emsS8A5ha5XT2c_CZiJzlLM81ms)(only for the General-Lite and General-Lite-2K models) ## Nodes diff --git a/README_CN.md b/README_CN.md index be6274f..246cffc 100644 --- a/README_CN.md +++ b/README_CN.md @@ -25,7 +25,9 @@ - General: 用于一般用例的预训练模型。 - General-Lite: 用于一般用例的轻量级预训练模型。 +- General-Lite-2K: 用于一般用例的轻量级预训练模型,适用于高分辨率图像(最佳分辨率2560x1440)。 - Portrait: 人物肖像预训练模型。 +- Matting: 一种用于通用无trimap抠图(matting)的预训练模型。 - DIS: 一种用于二分图像分割(DIS)的预训练模型。 - HRSOD: 一种用于高分辨率显著目标检测(HRSOD)的预训练模型。 - COD: 一种用于隐蔽目标检测(COD)的预训练模型。 @@ -36,7 +38,9 @@ 也可以手动下载模型: - [General](https://huggingface.co/ZhengPeng7/BiRefNet/resolve/main/model.safetensors) ➔ `model.safetensors` 重命名为 `General.safetensors` - [General-Lite](https://huggingface.co/ZhengPeng7/BiRefNet_T/resolve/main/model.safetensors) ➔ `model.safetensors` 重命名为 `General-Lite.safetensors` +- [General-Lite-2K](https://huggingface.co/ZhengPeng7/BiRefNet_lite-2K/resolve/main/model.safetensors) ➔ `model.safetensors` 重命名为 `General-Lite-2K.safetensors` - [Portrait](https://huggingface.co/ZhengPeng7/BiRefNet-portrait/resolve/main/model.safetensors) ➔ `model.safetensors` 重命名为 `Portrait.safetensors` +- [Matting](https://huggingface.co/ZhengPeng7/BiRefNet-matting/resolve/main/model.safetensors) ➔ `model.safetensors` 重命名为 `Matting.safetensors` - [DIS](https://huggingface.co/ZhengPeng7/BiRefNet-DIS5K/resolve/main/model.safetensors) ➔ `model.safetensors` 重命名为 `DIS.safetensors` - [HRSOD](https://huggingface.co/ZhengPeng7/BiRefNet-HRSOD/resolve/main/model.safetensors) ➔ `model.safetensors` 重命名为 `HRSOD.safetensors` - [COD](https://huggingface.co/ZhengPeng7/BiRefNet-COD/resolve/main/model.safetensors) ➔ `model.safetensors` 重命名为 `COD.safetensors` @@ -52,8 +56,8 @@ GitHub上的模型: ## 权重模型(非必须) 
下载放在`models/BiRefNet` -- [swin_large_patch4_window12_384_22kto1k.pth](https://huggingface.co/ViperYX/BiRefNet/resolve/main/swin_large_patch4_window12_384_22kto1k.pth)(非General-Lite模型) -- [swin_tiny_patch4_window7_224_22kto1k_finetune.pth](https://drive.google.com/drive/folders/1cmce_emsS8A5ha5XT2c_CZiJzlLM81ms)(仅General-Lite模型) +- [swin_large_patch4_window12_384_22kto1k.pth](https://huggingface.co/ViperYX/BiRefNet/resolve/main/swin_large_patch4_window12_384_22kto1k.pth)(非General-Lite和General-Lite-2K模型) +- [swin_tiny_patch4_window7_224_22kto1k_finetune.pth](https://drive.google.com/drive/folders/1cmce_emsS8A5ha5XT2c_CZiJzlLM81ms)(仅General-Lite和General-Lite-2K模型) ## 节点 diff --git a/birefnet/config.py b/birefnet/config.py index 5ed667b..78e4cc6 100644 --- a/birefnet/config.py +++ b/birefnet/config.py @@ -8,20 +8,29 @@ class Config: def __init__(self, bb_index: int = 6) -> None: # PATH settings # Make up your file system as: SYS_HOME_DIR/codes/dis/BiRefNet, SYS_HOME_DIR/datasets/dis/xx, SYS_HOME_DIR/weights/xx - # self.sys_home_dir = [os.path.expanduser('~'), '/mnt/data'][1] # Default, custom + # self.sys_home_dir = [os.path.expanduser('~'), '/mnt/data'][0] # Default, custom # self.data_root_dir = os.path.join(self.sys_home_dir, 'datasets/dis') # TASK settings self.task = ['DIS5K', 'COD', 'HRSOD', 'General', 'General-2K', 'Matting'][0] + self.testsets = { + # Benchmarks + 'DIS5K': ','.join(['DIS-VD', 'DIS-TE1', 'DIS-TE2', 'DIS-TE3', 'DIS-TE4']), + 'COD': ','.join(['CHAMELEON', 'NC4K', 'TE-CAMO', 'TE-COD10K']), + 'HRSOD': ','.join(['DAVIS-S', 'TE-HRSOD', 'TE-UHRSD', 'DUT-OMRON', 'TE-DUTS']), + # Practical use + 'General': ','.join(['DIS-VD', 'TE-P3M-500-NP']), + 'General-2K': ','.join(['DIS-VD', 'TE-P3M-500-NP']), + 'Matting': ','.join(['TE-P3M-500-NP', 'TE-AM-2k']), + }[self.task] + # datasets_all = '+'.join([ds for ds in (os.listdir(os.path.join(self.data_root_dir, self.task)) if os.path.isdir(os.path.join(self.data_root_dir, self.task)) else []) if ds not in 
self.testsets.split(',')]) self.training_set = { 'DIS5K': ['DIS-TR', 'DIS-TR+DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4'][0], 'COD': 'TR-COD10K+TR-CAMO', 'HRSOD': ['TR-DUTS', 'TR-HRSOD', 'TR-UHRSD', 'TR-DUTS+TR-HRSOD', 'TR-DUTS+TR-UHRSD', 'TR-HRSOD+TR-UHRSD', 'TR-DUTS+TR-HRSOD+TR-UHRSD'][5], - # 'General': '+'.join([ds for ds in os.listdir(os.path.join(self.data_root_dir, self.task)) if ds not in ['DIS-VD', 'TE-P3M-500-NP']]), # leave DIS-VD,TE-P3M-500-NP for evaluation. - 'General': 'DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4+DIS-TR+TR-HRSOD+TE-HRSOD+TR-HRS10K+TE-HRS10K+TR-UHRSD+TE-UHRSD+TR-P3M-10k+TE-P3M-500-P+TR-humans+DIS-VD-ori', - # 'General-2K': '+'.join([ds for ds in os.listdir(os.path.join(self.data_root_dir, self.task)) if ds not in ['DIS-VD', 'TE-P3M-500-NP']]), - 'General-2K': 'DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4+DIS-TR+TR-HRSOD+TE-HRSOD+TR-HRS10K+TE-HRS10K+TR-UHRSD+TE-UHRSD+TR-P3M-10k+TE-P3M-500-P+TR-humans+DIS-VD-ori', - 'Matting': 'TR-P3M-10k+TE-P3M-500-NP+TR-humans+TR-Distrinctions-646', + 'General': 'DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4+DIS-TR+TR-HRSOD+TE-HRSOD+TR-HRS10K+TE-HRS10K+TR-UHRSD+TE-UHRSD+TR-P3M-10k+TE-P3M-500-P+TR-humans+DIS-VD-ori', # datasets_all + 'General-2K': 'DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4+DIS-TR+TR-HRSOD+TE-HRSOD+TR-HRS10K+TE-HRS10K+TR-UHRSD+TE-UHRSD+TR-P3M-10k+TE-P3M-500-P+TR-humans+DIS-VD-ori', # datasets_all + 'Matting': 'TR-P3M-10k+TE-P3M-500-NP+TR-humans+TR-Distrinctions-646', # datasets_all }[self.task] self.prompt4loc = ['dense', 'sparse'][0] diff --git a/birefnetNode.py b/birefnetNode.py index cd5bfaf..69ceaa9 100644 --- a/birefnetNode.py +++ b/birefnetNode.py @@ -20,14 +20,16 @@ usage_to_weights_file = { 'General': 'BiRefNet', 'General-Lite': 'BiRefNet_T', + 'General-Lite-2K': 'BiRefNet_lite-2K', 'Portrait': 'BiRefNet-portrait', + 'Matting': 'BiRefNet-matting', 'DIS': 'BiRefNet-DIS5K', 'HRSOD': 'BiRefNet-HRSOD', 'COD': 'BiRefNet-COD', 'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs' } -modelNameList = ['General', 'General-Lite', 'Portrait', 'DIS', 'HRSOD', 
'COD', 'DIS-TR_TEs'] +modelNameList = ['General', 'General-Lite', 'General-Lite-2K', 'Portrait', 'Matting', 'DIS', 'HRSOD', 'COD', 'DIS-TR_TEs'] def get_model_path(model_name): @@ -94,7 +96,7 @@ def INPUT_TYPES(cls): DESCRIPTION = "Auto download BiRefNet model from huggingface to models/BiRefNet/{model_name}.safetensors" def load_model(self, model_name, device): - bb_index = 3 if model_name == "General-Lite" else 6 + bb_index = 3 if model_name == "General-Lite" or model_name == "General-Lite-2K" else 6 biRefNet_model = BiRefNet(bb_pretrained=False, bb_index=bb_index) model_file_name = f'{model_name}.safetensors' model_full_path = folder_paths.get_full_path(models_dir_key, model_file_name) @@ -138,7 +140,7 @@ def load_model(self, model, device, use_weight=False): biRefNet_model = OldBiRefNet(bb_pretrained=use_weight) else: version = VERSION[1] - bb_index = 3 if model == "General-Lite.safetensors" else 6 + bb_index = 3 if model == "General-Lite.safetensors" or model == "General-Lite-2K.safetensors" else 6 biRefNet_model = BiRefNet(bb_pretrained=use_weight, bb_index=bb_index) model_path = folder_paths.get_full_path(models_dir_key, model) @@ -199,7 +201,7 @@ def rem_bg(self, model, images): image = apply_mask_to_image(image.cpu(), mask.cpu()) _images.append(image) - _masks.append(mask) + _masks.append(mask.squeeze(0)) out_images = torch.cat(_images, dim=0) out_masks = torch.cat(_masks, dim=0) diff --git a/pyproject.toml b/pyproject.toml index 3829342..054f78d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "comfyui_birefnet_ll" description = "Sync with version of BiRefNet. NODES:AutoDownloadBiRefNetModel, LoadRembgByBiRefNetModel, RembgByBiRefNet." -version = "1.0.3" +version = "1.0.4" license = {file = "LICENSE"} dependencies = ["numpy<2", "opencv-python", "scipy", "timm"]