diff --git a/README.md b/README.md index 28eb6b9..461c11d 100644 --- a/README.md +++ b/README.md @@ -70,7 +70,7 @@ Some models on GitHub: ## Thanks -[BiRefNet](https://github.com/zhengpeng7/birefnet) +[ZhengPeng7/BiRefNet](https://github.com/zhengpeng7/birefnet) [dimitribarbot/sd-webui-birefnet](https://github.com/dimitribarbot/sd-webui-birefnet) diff --git a/README_CN.md b/README_CN.md index 86cfe22..dc78155 100644 --- a/README_CN.md +++ b/README_CN.md @@ -71,7 +71,7 @@ GitHub上的模型: ## 感谢 -[BiRefNet](https://github.com/zhengpeng7/birefnet) +[ZhengPeng7/BiRefNet](https://github.com/zhengpeng7/birefnet) [dimitribarbot/sd-webui-birefnet](https://github.com/dimitribarbot/sd-webui-birefnet) diff --git a/birefnet/config.py b/birefnet/config.py index 78e4cc6..fb7aa62 100644 --- a/birefnet/config.py +++ b/birefnet/config.py @@ -35,9 +35,8 @@ def __init__(self, bb_index: int = 6) -> None: self.prompt4loc = ['dense', 'sparse'][0] # Faster-Training settings - self.load_all = False # Turn it on/off by your case. It may consume a lot of CPU memory. And for multi-GPU (N), it would cost N times the CPU memory to load the data. - self.use_fp16 = False # It may cause nan in training. - self.compile = True and (not self.use_fp16) # 1. Trigger CPU memory leak in some extend, which is an inherent problem of PyTorch. + self.load_all = False # Turn it on/off by your case. It may consume a lot of CPU memory. And for multi-GPU (N), it would cost N times the CPU memory to load the data. + self.compile = True # 1. Trigger CPU memory leak in some extend, which is an inherent problem of PyTorch. # Machines with > 70GB CPU memory can run the whole training on DIS5K with default setting. # 2. Higher PyTorch version may fix it: https://github.com/pytorch/pytorch/issues/119607. # 3. But compile in Pytorch > 2.0.1 seems to bring no acceleration for training. @@ -102,6 +101,7 @@ def __init__(self, bb_index: int = 6) -> None: self.freeze_bb = False self.model = [ 'BiRefNet', + 'BiRefNetC2F', ][0] # TRAINING settings - inactive @@ -200,10 +200,4 @@ def __init__(self, bb_index: int = 6) -> None: # self.save_last = int([l.strip() for l in lines if '"{}")'.format(self.task) in l and 'val_last=' in l][0].split('val_last=')[-1].split()[0]) # self.save_step = int([l.strip() for l in lines if '"{}")'.format(self.task) in l and 'step=' in l][0].split('step=')[-1].split()[0]) - def print_task(self) -> None: - # Return task for choosing settings in shell scripts. - print(self.task) -# if __name__ == '__main__': -# config = Config() -# config.print_task() diff --git a/birefnet/dataset.py b/birefnet/dataset.py index e7c9505..3fa0c8d 100644 --- a/birefnet/dataset.py +++ b/birefnet/dataset.py @@ -5,9 +5,9 @@ from torch.utils import data from torchvision import transforms -from birefnet.image_proc import preproc -from birefnet.config import Config -from birefnet.utils import path_to_image +from .image_proc import preproc +from .config import Config +from .utils import path_to_image Image.MAX_IMAGE_PIXELS = None # remove DecompressionBombWarning diff --git a/birefnet/models/backbones/build_backbone.py b/birefnet/models/backbones/build_backbone.py index 5d91a87..08b25dd 100644 --- a/birefnet/models/backbones/build_backbone.py +++ b/birefnet/models/backbones/build_backbone.py @@ -1,11 +1,10 @@ import torch import torch.nn as nn -import safetensors.torch from collections import OrderedDict from torchvision.models import vgg16, vgg16_bn, VGG16_Weights, VGG16_BN_Weights, resnet50, ResNet50_Weights -from birefnet.models.backbones.pvt_v2 import pvt_v2_b0, pvt_v2_b1, pvt_v2_b2, pvt_v2_b5 -from birefnet.models.backbones.swin_v1 import swin_v1_t, swin_v1_s, swin_v1_b, swin_v1_l -from birefnet.config import Config +from ..backbones.pvt_v2 import pvt_v2_b0, pvt_v2_b1, pvt_v2_b2, pvt_v2_b5 +from ..backbones.swin_v1 import swin_v1_t, swin_v1_s, swin_v1_b, swin_v1_l +from ...config import Config config = Config() @@ -27,8 +26,7 @@ def build_backbone(bb_name, pretrained=True, params_settings=''): return bb def load_weights(model, model_name): - # safetensors.torch.load_file - save_model = torch.load(config.weights[model_name], map_location='cpu') + save_model = torch.load(config.weights[model_name], map_location='cpu', weights_only=True) model_dict = model.state_dict() state_dict = {k: v if v.size() == model_dict[k].size() else model_dict[k] for k, v in save_model.items() if k in model_dict.keys()} # to ignore the weights with mismatched size when I modify the backbone itself. @@ -37,7 +35,7 @@ def load_weights(model, model_name): sub_item = save_model_keys[0] if len(save_model_keys) == 1 else None state_dict = {k: v if v.size() == model_dict[k].size() else model_dict[k] for k, v in save_model[sub_item].items() if k in model_dict.keys()} if not state_dict or not sub_item: - print('Weights are not successully loaded. Check the state dict of weights file.') + print('Weights are not successfully loaded. Check the state dict of weights file.') return None else: print('Found correct weights in the "{}" item of loaded state_dict.'.format(sub_item)) diff --git a/birefnet/models/backbones/pvt_v2.py b/birefnet/models/backbones/pvt_v2.py index f1910dd..3089c53 100644 --- a/birefnet/models/backbones/pvt_v2.py +++ b/birefnet/models/backbones/pvt_v2.py @@ -1,13 +1,15 @@ +import math +from functools import partial import torch import torch.nn as nn -from functools import partial -from timm.models.layers import DropPath, to_2tuple, trunc_normal_ -from timm.models.registry import register_model - -import math +try: + # version > 0.6.13 + from timm.layers import DropPath, to_2tuple, trunc_normal_ +except Exception: + from timm.models.layers import DropPath, to_2tuple, trunc_normal_ -from birefnet.config import Config +from ...config import Config config = Config() @@ -383,7 +385,6 @@ def _conv_filter(state_dict, patch_size=16): return out_dict -## @register_model class pvt_v2_b0(PyramidVisionTransformerImpr): def __init__(self, **kwargs): super(pvt_v2_b0, self).__init__( @@ -392,8 +393,6 @@ def __init__(self, **kwargs): drop_rate=0.0, drop_path_rate=0.1) - -## @register_model class pvt_v2_b1(PyramidVisionTransformerImpr): def __init__(self, **kwargs): super(pvt_v2_b1, self).__init__( @@ -401,7 +400,7 @@ def __init__(self, **kwargs): qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1) -## @register_model + class pvt_v2_b2(PyramidVisionTransformerImpr): def __init__(self, in_channels=3, **kwargs): super(pvt_v2_b2, self).__init__( @@ -409,7 +408,7 @@ def __init__(self, in_channels=3, **kwargs): qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1, in_channels=in_channels) -## @register_model + class pvt_v2_b3(PyramidVisionTransformerImpr): def __init__(self, **kwargs): super(pvt_v2_b3, self).__init__( @@ -417,7 +416,7 @@ def __init__(self, **kwargs): qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1) -## @register_model + class pvt_v2_b4(PyramidVisionTransformerImpr): def __init__(self, **kwargs): super(pvt_v2_b4, self).__init__( @@ -426,7 +425,6 @@ def __init__(self, **kwargs): drop_rate=0.0, drop_path_rate=0.1) -## @register_model class pvt_v2_b5(PyramidVisionTransformerImpr): def __init__(self, **kwargs): super(pvt_v2_b5, self).__init__( diff --git a/birefnet/models/backbones/swin_v1.py b/birefnet/models/backbones/swin_v1.py index 591599b..141cbf2 100644 --- a/birefnet/models/backbones/swin_v1.py +++ b/birefnet/models/backbones/swin_v1.py @@ -10,7 +10,11 @@ import torch.nn.functional as F import torch.utils.checkpoint as checkpoint import numpy as np -from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +try: + # version > 0.6.13 + from timm.layers import DropPath, to_2tuple, trunc_normal_ +except Exception: + from timm.models.layers import DropPath, to_2tuple, trunc_normal_ from birefnet.config import Config diff --git a/birefnet/models/birefnet.py b/birefnet/models/birefnet.py index f6c4b79..67feb0f 100644 --- a/birefnet/models/birefnet.py +++ b/birefnet/models/birefnet.py @@ -1,17 +1,31 @@ import torch import torch.nn as nn import torch.nn.functional as F +from einops import rearrange from kornia.filters import laplacian - -from birefnet.config import Config -from birefnet.dataset import class_labels_TR_sorted -from birefnet.models.backbones.build_backbone import build_backbone -from birefnet.models.modules.decoder_blocks import BasicDecBlk, ResBlk -from birefnet.models.modules.lateral_blocks import BasicLatBlk -from birefnet.models.modules.aspp import ASPP, ASPPDeformable -from birefnet.models.refinement.refiner import Refiner, RefinerPVTInChannels4, RefUNet -from birefnet.models.refinement.stem_layer import StemLayer - +from huggingface_hub import PyTorchModelHubMixin + +from ..config import Config +from ..dataset import class_labels_TR_sorted +from .backbones.build_backbone import build_backbone +from .modules.decoder_blocks import BasicDecBlk, ResBlk +from .modules.lateral_blocks import BasicLatBlk +from .modules.aspp import ASPP, ASPPDeformable +from .refinement.refiner import Refiner, RefinerPVTInChannels4, RefUNet +from .refinement.stem_layer import StemLayer + + +def image2patches(image, grid_h=2, grid_w=2, patch_ref=None, transformation='b c (hg h) (wg w) -> (b hg wg) c h w'): + if patch_ref is not None: + grid_h, grid_w = image.shape[-2] // patch_ref.shape[-2], image.shape[-1] // patch_ref.shape[-1] + patches = rearrange(image, transformation, hg=grid_h, wg=grid_w) + return patches + +def patches2image(patches, grid_h=2, grid_w=2, patch_ref=None, transformation='(b hg wg) c h w -> b c (hg h) (wg w)'): + if patch_ref is not None: + grid_h, grid_w = patch_ref.shape[-2] // patches[0].shape[-2], patch_ref.shape[-1] // patches[0].shape[-1] + image = rearrange(patches, transformation, hg=grid_h, wg=grid_w) + return image class BiRefNet(nn.Module): def __init__(self, bb_pretrained=True, bb_index=6): @@ -159,18 +173,6 @@ def __init__(self, channels): self.gdt_convs_attn_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0)) self.gdt_convs_attn_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0)) - def get_patches_batch(self, x, p): - _size_h, _size_w = p.shape[2:] - patches_batch = [] - for idx in range(x.shape[0]): - columns_x = torch.split(x[idx], split_size_or_sections=_size_w, dim=-1) - patches_x = [] - for column_x in columns_x: - patches_x += [p.unsqueeze(0) for p in torch.split(column_x, split_size_or_sections=_size_h, dim=-2)] - patch_sample = torch.cat(patches_x, dim=1) - patches_batch.append(patch_sample) - return torch.cat(patches_batch, dim=0) - def forward(self, features): if self.training and self.config.out_ref: outs_gdt_pred = [] @@ -181,7 +183,7 @@ def forward(self, features): outs = [] if self.config.dec_ipt: - patches_batch = self.get_patches_batch(x, x4) if self.split else x + patches_batch = image2patches(x, patch_ref=x4, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1) p4 = self.decoder_block4(x4) m4 = self.conv_ms_spvn_4(p4) if self.config.ms_supervision and self.training else None @@ -202,7 +204,7 @@ def forward(self, features): _p3 = _p4 + self.lateral_block4(x3) if self.config.dec_ipt: - patches_batch = self.get_patches_batch(x, _p3) if self.split else x + patches_batch = image2patches(x, patch_ref=_p3, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x _p3 = torch.cat((_p3, self.ipt_blk4(F.interpolate(patches_batch, size=x3.shape[2:], mode='bilinear', align_corners=True))), 1) p3 = self.decoder_block3(_p3) m3 = self.conv_ms_spvn_3(p3) if self.config.ms_supervision and self.training else None @@ -228,7 +230,7 @@ def forward(self, features): _p2 = _p3 + self.lateral_block3(x2) if self.config.dec_ipt: - patches_batch = self.get_patches_batch(x, _p2) if self.split else x + patches_batch = image2patches(x, patch_ref=_p2, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x _p2 = torch.cat((_p2, self.ipt_blk3(F.interpolate(patches_batch, size=x2.shape[2:], mode='bilinear', align_corners=True))), 1) p2 = self.decoder_block2(_p2) m2 = self.conv_ms_spvn_2(p2) if self.config.ms_supervision and self.training else None @@ -249,13 +251,13 @@ def forward(self, features): _p1 = _p2 + self.lateral_block2(x1) if self.config.dec_ipt: - patches_batch = self.get_patches_batch(x, _p1) if self.split else x + patches_batch = image2patches(x, patch_ref=_p1, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x _p1 = torch.cat((_p1, self.ipt_blk2(F.interpolate(patches_batch, size=x1.shape[2:], mode='bilinear', align_corners=True))), 1) _p1 = self.decoder_block1(_p1) _p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True) if self.config.dec_ipt: - patches_batch = self.get_patches_batch(x, _p1) if self.split else x + patches_batch = image2patches(x, patch_ref=_p1, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x _p1 = torch.cat((_p1, self.ipt_blk1(F.interpolate(patches_batch, size=x.shape[2:], mode='bilinear', align_corners=True))), 1) p1_out = self.conv_out1(_p1) @@ -277,3 +279,60 @@ def __init__( def forward(self, x): return self.conv_out(self.conv1(x)) + + +########### + + +class BiRefNetC2F( + nn.Module, + PyTorchModelHubMixin, + library_name="birefnet_c2f", + repo_url="https://github.com/ZhengPeng7/BiRefNet_C2F", + tags=['Image Segmentation', 'Background Removal', 'Mask Generation', 'Dichotomous Image Segmentation', 'Camouflaged Object Detection', 'Salient Object Detection'] +): + def __init__(self, bb_pretrained=True): + super(BiRefNetC2F, self).__init__() + self.config = Config() + self.epoch = 1 + self.grid = 4 + self.model_coarse = BiRefNet(bb_pretrained=True) + self.model_fine = BiRefNet(bb_pretrained=True) + self.input_mixer = nn.Conv2d(4, 3, 1, 1, 0) + self.output_mixer_merge_post = nn.Sequential(nn.Conv2d(1, 16, 3, 1, 1), nn.Conv2d(16, 1, 3, 1, 1)) + + def forward(self, x): + x_ori = x.clone() + ########## Coarse ########## + x = F.interpolate(x, size=[s//self.grid for s in self.config.size[::-1]], mode='bilinear', align_corners=True) + + if self.training: + scaled_preds, class_preds_lst = self.model_coarse(x) + else: + scaled_preds = self.model_coarse(x) + ########## Fine ########## + x_HR_patches = image2patches(x_ori, patch_ref=x, transformation='b c (hg h) (wg w) -> (b hg wg) c h w') + pred = F.interpolate(scaled_preds[-1] if not (self.config.out_ref and self.training) else scaled_preds[1][-1], size=x_ori.shape[2:], mode='bilinear', align_corners=True) + pred_patches = image2patches(pred, patch_ref=x, transformation='b c (hg h) (wg w) -> (b hg wg) c h w') + t = torch.cat([x_HR_patches, pred_patches], dim=1) + x_HR = self.input_mixer(t) + + pred_patches = image2patches(pred, patch_ref=x_HR, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') + if self.training: + scaled_preds_HR, class_preds_lst_HR = self.model_fine(x_HR) + else: + scaled_preds_HR = self.model_fine(x_HR) + if self.training: + if self.config.out_ref: + [outs_gdt_pred, outs_gdt_label], outs = scaled_preds + [outs_gdt_pred_HR, outs_gdt_label_HR], outs_HR = scaled_preds_HR + for idx_out, out_HR in enumerate(outs_HR): + outs_HR[idx_out] = self.output_mixer_merge_post(patches2image(out_HR, grid_h=self.grid, grid_w=self.grid, transformation='(b hg wg) c h w -> b c (hg h) (wg w)')) + return [([outs_gdt_pred + outs_gdt_pred_HR, outs_gdt_label + outs_gdt_label_HR], outs + outs_HR), class_preds_lst] # handle gt here + else: + return [ + scaled_preds + [self.output_mixer_merge_post(patches2image(scaled_pred_HR, grid_h=self.grid, grid_w=self.grid, transformation='(b hg wg) c h w -> b c (hg h) (wg w)')) for scaled_pred_HR in scaled_preds_HR], + class_preds_lst + ] + else: + return scaled_preds + [self.output_mixer_merge_post(patches2image(scaled_pred_HR, grid_h=self.grid, grid_w=self.grid, transformation='(b hg wg) c h w -> b c (hg h) (wg w)')) for scaled_pred_HR in scaled_preds_HR] diff --git a/birefnet/models/modules/aspp.py b/birefnet/models/modules/aspp.py index ae4961d..d910d98 100644 --- a/birefnet/models/modules/aspp.py +++ b/birefnet/models/modules/aspp.py @@ -1,8 +1,8 @@ import torch import torch.nn as nn import torch.nn.functional as F -from birefnet.models.modules.deform_conv import DeformableConv2d -from birefnet.config import Config +from ..modules.deform_conv import DeformableConv2d +from ...config import Config config = Config() diff --git a/birefnet/models/modules/decoder_blocks.py b/birefnet/models/modules/decoder_blocks.py index 439ff66..487182f 100644 --- a/birefnet/models/modules/decoder_blocks.py +++ b/birefnet/models/modules/decoder_blocks.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn -from birefnet.models.modules.aspp import ASPP, ASPPDeformable -from birefnet.config import Config +from ..modules.aspp import ASPP, ASPPDeformable +from ...config import Config config = Config() diff --git a/birefnet/models/modules/lateral_blocks.py b/birefnet/models/modules/lateral_blocks.py index 1fa8548..de907ac 100644 --- a/birefnet/models/modules/lateral_blocks.py +++ b/birefnet/models/modules/lateral_blocks.py @@ -4,7 +4,7 @@ import torch.nn.functional as F from functools import partial -from birefnet.config import Config +from ...config import Config config = Config() diff --git a/birefnet/models/modules/mlp.py b/birefnet/models/modules/mlp.py deleted file mode 100644 index 39b3568..0000000 --- a/birefnet/models/modules/mlp.py +++ /dev/null @@ -1,118 +0,0 @@ -import torch -import torch.nn as nn -from functools import partial - -from timm.models.layers import DropPath, to_2tuple, trunc_normal_ -from timm.models.registry import register_model - -import math - - -class MLPLayer(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -class Attention(nn.Module): - def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): - super().__init__() - assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." - - self.dim = dim - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 - - self.q = nn.Linear(dim, dim, bias=qkv_bias) - self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - self.sr_ratio = sr_ratio - if sr_ratio > 1: - self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) - self.norm = nn.LayerNorm(dim) - - def forward(self, x, H, W): - B, N, C = x.shape - q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) - - if self.sr_ratio > 1: - x_ = x.permute(0, 2, 1).reshape(B, C, H, W) - x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) - x_ = self.norm(x_) - kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - else: - kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - k, v = kv[0], kv[1] - - attn = (q @ k.transpose(-2, -1)) * self.scale - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - -class Block(nn.Module): - def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): - super().__init__() - self.norm1 = norm_layer(dim) - self.attn = Attention( - dim, - num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, - attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) - # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = MLPLayer(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - def forward(self, x, H, W): - x = x + self.drop_path(self.attn(self.norm1(x), H, W)) - x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) - return x - - -class OverlapPatchEmbed(nn.Module): - """ Image to Patch Embedding - """ - - def __init__(self, img_size=224, patch_size=7, stride=4, in_channels=3, embed_dim=768): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - - self.img_size = img_size - self.patch_size = patch_size - self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] - self.num_patches = self.H * self.W - self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=stride, - padding=(patch_size[0] // 2, patch_size[1] // 2)) - self.norm = nn.LayerNorm(embed_dim) - - def forward(self, x): - x = self.proj(x) - _, _, H, W = x.shape - x = x.flatten(2).transpose(1, 2) - x = self.norm(x) - return x, H, W - diff --git a/birefnet/models/refinement/refiner.py b/birefnet/models/refinement/refiner.py index 5ecf69b..c3e2211 100644 --- a/birefnet/models/refinement/refiner.py +++ b/birefnet/models/refinement/refiner.py @@ -7,12 +7,12 @@ from torchvision.models import vgg16, vgg16_bn from torchvision.models import resnet50 -from birefnet.config import Config -from birefnet.dataset import class_labels_TR_sorted -from birefnet.models.backbones.build_backbone import build_backbone -from birefnet.models.modules.decoder_blocks import BasicDecBlk -from birefnet.models.modules.lateral_blocks import BasicLatBlk -from birefnet.models.refinement.stem_layer import StemLayer +from ...config import Config +from ...dataset import class_labels_TR_sorted +from ..backbones.build_backbone import build_backbone +from ..modules.decoder_blocks import BasicDecBlk +from ..modules.lateral_blocks import BasicLatBlk +from ..refinement.stem_layer import StemLayer class RefinerPVTInChannels4(nn.Module): diff --git a/birefnet/models/refinement/stem_layer.py b/birefnet/models/refinement/stem_layer.py index a93aa69..2dc7f0f 100644 --- a/birefnet/models/refinement/stem_layer.py +++ b/birefnet/models/refinement/stem_layer.py @@ -1,5 +1,5 @@ import torch.nn as nn -from birefnet.models.modules.utils import build_act_layer, build_norm_layer +from ..modules.utils import build_act_layer, build_norm_layer class StemLayer(nn.Module): diff --git a/birefnet/utils.py b/birefnet/utils.py index d44c7d2..a6d9a29 100644 --- a/birefnet/utils.py +++ b/birefnet/utils.py @@ -26,10 +26,12 @@ def path_to_image(path, size=(1024, 1024), color_type=['rgb', 'gray'][0]): -def check_state_dict(state_dict, unwanted_prefix='_orig_mod.'): +def check_state_dict(state_dict, unwanted_prefixes=['_orig_mod.', 'module.']): for k, v in list(state_dict.items()): - if k.startswith(unwanted_prefix): - state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) + for unwanted_prefix in unwanted_prefixes: + if k.startswith(unwanted_prefix): + state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) + break return state_dict diff --git a/birefnet_old/dataset.py b/birefnet_old/dataset.py index e94167b..5c30e90 100644 --- a/birefnet_old/dataset.py +++ b/birefnet_old/dataset.py @@ -1,5 +1,5 @@ import os -import cv2 +# import cv2 from tqdm import tqdm from PIL import Image from torch.utils import data diff --git a/birefnet_old/models/backbones/build_backbone.py b/birefnet_old/models/backbones/build_backbone.py index 03c72cc..e236636 100644 --- a/birefnet_old/models/backbones/build_backbone.py +++ b/birefnet_old/models/backbones/build_backbone.py @@ -2,9 +2,9 @@ import torch.nn as nn from collections import OrderedDict from torchvision.models import vgg16, vgg16_bn, VGG16_Weights, VGG16_BN_Weights, resnet50, ResNet50_Weights -from birefnet_old.models.backbones.pvt_v2 import pvt_v2_b0, pvt_v2_b1, pvt_v2_b2, pvt_v2_b5 -from birefnet_old.models.backbones.swin_v1 import swin_v1_t, swin_v1_s, swin_v1_b, swin_v1_l -from birefnet_old.config import Config +from ..backbones.pvt_v2 import pvt_v2_b0, pvt_v2_b1, pvt_v2_b2, pvt_v2_b5 +from ..backbones.swin_v1 import swin_v1_t, swin_v1_s, swin_v1_b, swin_v1_l +from ...config import Config config = Config() diff --git a/birefnet_old/models/backbones/pvt_v2.py b/birefnet_old/models/backbones/pvt_v2.py index 58e90af..947a122 100644 --- a/birefnet_old/models/backbones/pvt_v2.py +++ b/birefnet_old/models/backbones/pvt_v2.py @@ -2,12 +2,15 @@ import torch.nn as nn from functools import partial -from timm.models.layers import DropPath, to_2tuple, trunc_normal_ -from timm.models.registry import register_model +try: + # version > 0.6.13 + from timm.layers import DropPath, to_2tuple, trunc_normal_ +except Exception: + from timm.models.layers import DropPath, to_2tuple, trunc_normal_ import math -from birefnet_old.config import Config +from ...config import Config config = Config() diff --git a/birefnet_old/models/backbones/swin_v1.py b/birefnet_old/models/backbones/swin_v1.py index 74d31cc..49a7ace 100644 --- a/birefnet_old/models/backbones/swin_v1.py +++ b/birefnet_old/models/backbones/swin_v1.py @@ -10,9 +10,13 @@ import torch.nn.functional as F import torch.utils.checkpoint as checkpoint import numpy as np -from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +try: + # version > 0.6.13 + from timm.layers import DropPath, to_2tuple, trunc_normal_ +except Exception: + from timm.models.layers import DropPath, to_2tuple, trunc_normal_ -from birefnet_old.config import Config +from ...config import Config config = Config() diff --git a/birefnet_old/models/birefnet.py b/birefnet_old/models/birefnet.py index 79aa9fd..fba83f0 100644 --- a/birefnet_old/models/birefnet.py +++ b/birefnet_old/models/birefnet.py @@ -8,15 +8,15 @@ from torchvision.models import resnet50 from kornia.filters import laplacian -from birefnet_old.config import Config -from birefnet_old.dataset import class_labels_TR_sorted -from birefnet_old.models.backbones.build_backbone import build_backbone -from birefnet_old.models.modules.decoder_blocks import BasicDecBlk, ResBlk, HierarAttDecBlk -from birefnet_old.models.modules.lateral_blocks import BasicLatBlk -from birefnet_old.models.modules.aspp import ASPP, ASPPDeformable -from birefnet_old.models.modules.ing import * -from birefnet_old.models.refinement.refiner import Refiner, RefinerPVTInChannels4, RefUNet -from birefnet_old.models.refinement.stem_layer import StemLayer +from ..config import Config +from ..dataset import class_labels_TR_sorted +from .backbones.build_backbone import build_backbone +from .modules.decoder_blocks import BasicDecBlk, ResBlk, HierarAttDecBlk +from .modules.lateral_blocks import BasicLatBlk +from .modules.aspp import ASPP, ASPPDeformable +from .modules.ing import * +from .refinement.refiner import Refiner, RefinerPVTInChannels4, RefUNet +from .refinement.stem_layer import StemLayer class BiRefNet(nn.Module): diff --git a/birefnet_old/models/modules/aspp.py b/birefnet_old/models/modules/aspp.py index cb8e4ab..ce842f7 100644 --- a/birefnet_old/models/modules/aspp.py +++ b/birefnet_old/models/modules/aspp.py @@ -1,8 +1,8 @@ import torch import torch.nn as nn import torch.nn.functional as F -from birefnet_old.models.modules.deform_conv import DeformableConv2d -from birefnet_old.config import Config +from ..modules.deform_conv import DeformableConv2d +from ...config import Config config = Config() diff --git a/birefnet_old/models/modules/decoder_blocks.py b/birefnet_old/models/modules/decoder_blocks.py index 52d0740..3d32736 100644 --- a/birefnet_old/models/modules/decoder_blocks.py +++ b/birefnet_old/models/modules/decoder_blocks.py @@ -1,8 +1,8 @@ import torch import torch.nn as nn -from birefnet_old.models.modules.aspp import ASPP, ASPPDeformable -from birefnet_old.models.modules.attentions import PSA, SGE -from birefnet_old.config import Config +from ..modules.aspp import ASPP, ASPPDeformable +from ..modules.attentions import PSA, SGE +from ...config import Config config = Config() diff --git a/birefnet_old/models/modules/ing.py b/birefnet_old/models/modules/ing.py index 771c234..b0026e9 100644 --- a/birefnet_old/models/modules/ing.py +++ b/birefnet_old/models/modules/ing.py @@ -1,5 +1,5 @@ import torch.nn as nn -from birefnet_old.models.modules.mlp import MLPLayer +from ..modules.mlp import MLPLayer class BlockA(nn.Module): diff --git a/birefnet_old/models/modules/lateral_blocks.py b/birefnet_old/models/modules/lateral_blocks.py index dc14444..de907ac 100644 --- a/birefnet_old/models/modules/lateral_blocks.py +++ b/birefnet_old/models/modules/lateral_blocks.py @@ -4,7 +4,7 @@ import torch.nn.functional as F from functools import partial -from birefnet_old.config import Config +from ...config import Config config = Config() diff --git a/birefnet_old/models/modules/mlp.py b/birefnet_old/models/modules/mlp.py index 39b3568..506bfe4 100644 --- a/birefnet_old/models/modules/mlp.py +++ b/birefnet_old/models/modules/mlp.py @@ -2,8 +2,11 @@ import torch.nn as nn from functools import partial -from timm.models.layers import DropPath, to_2tuple, trunc_normal_ -from timm.models.registry import register_model +try: + # version > 0.6.13 + from timm.layers import DropPath, to_2tuple, trunc_normal_ +except Exception: + from timm.models.layers import DropPath, to_2tuple, trunc_normal_ import math diff --git a/pyproject.toml b/pyproject.toml index d1ab829..194b4cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ name = "comfyui_birefnet_ll" description = "Sync with version of BiRefNet. NODES:AutoDownloadBiRefNetModel, LoadRembgByBiRefNetModel, RembgByBiRefNet." version = "1.0.5" license = {file = "LICENSE"} -dependencies = ["numpy<2", "opencv-python", "scipy", "timm"] +dependencies = ["numpy", "opencv-python", "timm"] [project.urls] Repository = "https://github.com/lldacing/ComfyUI_BiRefNet_ll" diff --git a/requirements.txt b/requirements.txt index 91f0ea0..293d45b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ -numpy<2 +numpy opencv-python -scipy -timm +timm \ No newline at end of file