diff --git a/README.md b/README.md index 3b16e51..ca43f28 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,7 @@ Support the use of new and old versions of BiRefNet models ### The available newest models are: - General: A pre-trained model for general use cases. +- General-HR: A pre-trained model for general use cases which shows great performance on higher resolution images (2048x2048). - General-Lite: A light pre-trained model for general use cases. - General-Lite-2K: A light pre-trained model for general use cases in high resolution (2560x1440). - Portrait: A pre-trained model for human portraits. @@ -37,6 +38,7 @@ Model files go here (when use AutoDownloadBiRefNetModel automatically downloaded If necessary, they can be downloaded from: - [General](https://huggingface.co/ZhengPeng7/BiRefNet/resolve/main/model.safetensors) ➔ `model.safetensors` must be renamed `General.safetensors` +- [General-HR](https://huggingface.co/ZhengPeng7/BiRefNet_HR/resolve/main/model.safetensors) ➔ `model.safetensors` must be renamed `General-HR.safetensors` - [General-Lite](https://huggingface.co/ZhengPeng7/BiRefNet_T/resolve/main/model.safetensors) ➔ `model.safetensors` must be renamed `General-Lite.safetensors` - [General-Lite-2K](https://huggingface.co/ZhengPeng7/BiRefNet_lite-2K/resolve/main/model.safetensors) ➔ `model.safetensors` must be renamed `General-Lite-2K.safetensors` - [Portrait](https://huggingface.co/ZhengPeng7/BiRefNet-portrait/resolve/main/model.safetensors) ➔ `model.safetensors` must be renamed `Portrait.safetensors` diff --git a/README_CN.md b/README_CN.md index d1794c4..c0ba527 100644 --- a/README_CN.md +++ b/README_CN.md @@ -24,8 +24,9 @@ ### 最新的模型: - General: 用于一般用例的预训练模型。 +- General-HR: 用于一般用例的预训练模型,在更高分辨率的图像上表现出色(训练分辨率2048x2048)。 - General-Lite: 用于一般用例的轻量级预训练模型。 -- General-Lite-2K: 用于一般用例的轻量级预训练模型,适用于高分辨率图像。 (最佳分辨率2560x1440). +- General-Lite-2K: 用于一般用例的轻量级预训练模型,适用于高分辨率图像(最佳分辨率2560x1440)。 - Portrait: 人物肖像预训练模型。 - Matting: 一种使用无trimap matting的预训练模型。 - DIS: 一种用于二分图像分割(DIS)的预训练模型。 @@ -33,10 +34,11 @@ - COD: 一种用于隐蔽目标检测(COD)的预训练模型。 - DIS-TR_TEs: 具有大量数据集的预训练模型。 -模型文件放在`${comfyui_rootpath}/models/BiRefNet`(当使用AutoDownloadBiRefNetModel时,则会自动下载模型). +模型文件放在`${comfyui_rootpath}/models/BiRefNet`(当使用AutoDownloadBiRefNetModel时,则会自动下载模型)。 也可以手动下载模型: - [General](https://huggingface.co/ZhengPeng7/BiRefNet/resolve/main/model.safetensors) ➔ `model.safetensors` 重命名为 `General.safetensors` +- [General-HR](https://huggingface.co/ZhengPeng7/BiRefNet_HR/resolve/main/model.safetensors) ➔ `model.safetensors` 重命名为 `General-HR.safetensors` - [General-Lite](https://huggingface.co/ZhengPeng7/BiRefNet_T/resolve/main/model.safetensors) ➔ `model.safetensors` 重命名为 `General-Lite.safetensors` - [General-Lite-2K](https://huggingface.co/ZhengPeng7/BiRefNet_lite-2K/resolve/main/model.safetensors) ➔ `model.safetensors` 重命名为 `General-Lite-2K.safetensors` - [Portrait](https://huggingface.co/ZhengPeng7/BiRefNet-portrait/resolve/main/model.safetensors) ➔ `model.safetensors` 重命名为 `Portrait.safetensors` diff --git a/birefnet/config.py b/birefnet/config.py index fb7aa62..31ca582 100644 --- a/birefnet/config.py +++ b/birefnet/config.py @@ -61,7 +61,7 @@ def __init__(self, bb_index: int = 6) -> None: 'DIS5K': -40, 'COD': -20, 'HRSOD': -20, - 'General': -40, + 'General': -20, 'General-2K': -20, 'Matting': -20, }[self.task] @@ -154,9 +154,6 @@ def __init__(self, bb_index: int = 6) -> None: self.lambdas_cls = { 'ce': 5.0 } - # Adv - self.lambda_adv_g = 10. * 0 # turn to 0 to avoid adv training - self.lambda_adv_d = 3. * (self.lambda_adv_g > 0) # PATH settings - inactive # https://drive.google.com/drive/folders/1cmce_emsS8A5ha5XT2c_CZiJzlLM81ms diff --git a/birefnet/models/backbones/swin_v1.py b/birefnet/models/backbones/swin_v1.py index 3a360ff..737afa7 100644 --- a/birefnet/models/backbones/swin_v1.py +++ b/birefnet/models/backbones/swin_v1.py @@ -68,9 +68,9 @@ def window_reverse(windows, window_size, H, W): Returns: x: (B, H, W, C) """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + C = int(windows.shape[-1]) + x = windows.view(-1, H // window_size, W // window_size, window_size, window_size, C) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C) return x diff --git a/birefnetNode.py b/birefnetNode.py index 5d3b63a..31720a5 100644 --- a/birefnetNode.py +++ b/birefnetNode.py @@ -9,7 +9,7 @@ from birefnet.models.birefnet import BiRefNet from birefnet_old.models.birefnet import BiRefNet as OldBiRefNet from birefnet.utils import check_state_dict -from .util import tensor_to_pil, apply_mask_to_image, normalize_mask, refine_foreground, filter_mask, add_mask_as_alpha +from .util import refine_foreground, filter_mask, add_mask_as_alpha deviceType = model_management.get_torch_device().type models_dir_key = "birefnet" @@ -18,6 +18,7 @@ usage_to_weights_file = { 'General': 'BiRefNet', + 'General-HR': 'BiRefNet_HR', 'General-Lite': 'BiRefNet_T', 'General-Lite-2K': 'BiRefNet_lite-2K', 'Portrait': 'BiRefNet-portrait', @@ -28,7 +29,7 @@ 'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs' } -modelNameList = ['General', 'General-Lite', 'General-Lite-2K', 'Portrait', 'Matting', 'DIS', 'HRSOD', 'COD', 'DIS-TR_TEs'] +modelNameList = ['General', 'General-HR', 'General-Lite', 'General-Lite-2K', 'Portrait', 'Matting', 'DIS', 'HRSOD', 'COD', 'DIS-TR_TEs'] def get_model_path(model_name): @@ -90,6 +91,11 @@ def old_proc(self, image) -> torch.Tensor: VERSION = ["old", "v1"] old_models_name = ["BiRefNet-DIS_ep580.pth", "BiRefNet-ep480.pth"] +torch_dtype={ + "float16": torch.float16, + "float32": torch.float32, + "bfloat16": torch.bfloat16, +} class AutoDownloadBiRefNetModel: @@ -99,6 +105,9 @@ def INPUT_TYPES(cls): "required": { "model_name": (modelNameList,), "device": (["AUTO", "CPU"],) + }, + "optional": { + "dtype": (["float32", "float16"], {"default": "float32"}) } } @@ -108,7 +117,7 @@ def INPUT_TYPES(cls): CATEGORY = "image/BiRefNet" DESCRIPTION = "Auto download BiRefNet model from huggingface to models/BiRefNet/{model_name}.safetensors" - def load_model(self, model_name, device): + def load_model(self, model_name, device, dtype="float32"): bb_index = 3 if model_name == "General-Lite" or model_name == "General-Lite-2K" else 6 biRefNet_model = BiRefNet(bb_pretrained=False, bb_index=bb_index) model_file_name = f'{model_name}.safetensors' @@ -122,7 +131,7 @@ def load_model(self, model_name, device): device_type = "cpu" state_dict = safetensors.torch.load_file(model_full_path, device=device_type) biRefNet_model.load_state_dict(state_dict) - biRefNet_model.to(device_type) + biRefNet_model.to(device_type, dtype=torch_dtype[dtype]) biRefNet_model.eval() return [(biRefNet_model, VERSION[1])] @@ -137,7 +146,8 @@ def INPUT_TYPES(cls): "device": (["AUTO", "CPU"], ) }, "optional": { - "use_weight": ("BOOLEAN", {"default": False}) + "use_weight": ("BOOLEAN", {"default": False}), + "dtype": (["float32", "float16"], {"default": "float32"}) } } @@ -147,7 +157,7 @@ def INPUT_TYPES(cls): CATEGORY = "rembg/BiRefNet" DESCRIPTION = "Load BiRefNet model from folder models/BiRefNet or the path of birefnet configured in the extra YAML file" - def load_model(self, model, device, use_weight=False): + def load_model(self, model, device, use_weight=False, dtype="float32"): if model in old_models_name: version = VERSION[0] biRefNet_model = OldBiRefNet(bb_pretrained=use_weight) @@ -168,7 +178,7 @@ def load_model(self, model, device, use_weight=False): state_dict = check_state_dict(state_dict) biRefNet_model.load_state_dict(state_dict) - biRefNet_model.to(device_type) + biRefNet_model.to(device_type, dtype=torch_dtype[dtype]) biRefNet_model.eval() return [(biRefNet_model, version)] @@ -211,7 +221,9 @@ def INPUT_TYPES(cls): def get_mask(self, model, images, width=1024, height=1024, upscale_method='bilinear', mask_threshold=0.000): model, version = model - model_device_type = next(model.parameters()).device.type + one_torch = next(model.parameters()) + model_device_type = one_torch.device.type + model_dtype = one_torch.dtype b, h, w, c = images.shape image_bchw = images.permute(0, 3, 1, 2) @@ -226,7 +238,7 @@ def get_mask(self, model, images, width=1024, height=1024, upscale_method='bilin _mask_bchw = [] for each_image in im_tensor: with torch.no_grad(): - each_mask = model(each_image.unsqueeze(0).to(model_device_type))[-1].sigmoid().cpu() + each_mask = model(each_image.unsqueeze(0).to(model_device_type, dtype=model_dtype))[-1].sigmoid().cpu().float() _mask_bchw.append(each_mask) del each_mask diff --git a/pyproject.toml b/pyproject.toml index d9a3f40..8dd673d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "comfyui_birefnet_ll" description = "Sync with version of BiRefNet. NODES:AutoDownloadBiRefNetModel, LoadRembgByBiRefNetModel, RembgByBiRefNet, RembgByBiRefNetAdvanced, GetMaskByBiRefNet, BlurFusionForegroundEstimation." -version = "1.1.0" +version = "1.1.1" license = {file = "LICENSE"} dependencies = ["numpy", "opencv-python", "timm"]