From e3a7de4b8b0b1be1a24a2d6ae87ef1994c2fc2d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E9=9B=AA=E5=B3=B0?= Date: Thu, 19 Sep 2024 16:54:59 +0800 Subject: [PATCH] sync code from BiRefNet repo --- README.md | 3 ++ README_CN.md | 3 ++ birefnet/config.py | 55 ++++++++++++++++------------ birefnet/dataset.py | 2 +- birefnet/image_proc.py | 7 +--- birefnet/models/backbones/swin_v1.py | 2 +- 6 files changed, 41 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index 2ef45a9..1e00ed8 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,7 @@ [中文文档](README_CN.md) + +Support the use of new and old versions of BiRefNet models + ## Preview ![save api extended](doc/base.png) ![save api extended](doc/video.gif) diff --git a/README_CN.md b/README_CN.md index d390cbb..be6274f 100644 --- a/README_CN.md +++ b/README_CN.md @@ -1,4 +1,7 @@ [English](README.md) + +支持使用新老版本BiRefNet模型进行抠图 + ## 预览 ![save api extended](doc/base.png) ![save api extended](doc/video.gif) diff --git a/birefnet/config.py b/birefnet/config.py index 9d67b89..5ed667b 100644 --- a/birefnet/config.py +++ b/birefnet/config.py @@ -8,11 +8,7 @@ class Config: def __init__(self, bb_index: int = 6) -> None: # PATH settings # Make up your file system as: SYS_HOME_DIR/codes/dis/BiRefNet, SYS_HOME_DIR/datasets/dis/xx, SYS_HOME_DIR/weights/xx - # if os.name == 'nt': - # self.sys_home_dir = os.environ['USERPROFILE'] # For windows system - # else: - # self.sys_home_dir = [os.environ['HOME'], '/mnt/data'][1] # For Linux system - # https://drive.google.com/drive/folders/1hZW6tAGPJwo9mPS7qGGGdpxuvuXiyoMJ + # self.sys_home_dir = [os.path.expanduser('~'), '/mnt/data'][1] # Default, custom # self.data_root_dir = os.path.join(self.sys_home_dir, 'datasets/dis') # TASK settings @@ -21,10 +17,10 @@ def __init__(self, bb_index: int = 6) -> None: 'DIS5K': ['DIS-TR', 'DIS-TR+DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4'][0], 'COD': 'TR-COD10K+TR-CAMO', 'HRSOD': ['TR-DUTS', 'TR-HRSOD', 'TR-UHRSD', 'TR-DUTS+TR-HRSOD', 'TR-DUTS+TR-UHRSD', 'TR-HRSOD+TR-UHRSD', 'TR-DUTS+TR-HRSOD+TR-UHRSD'][5], - # 'General': '+'.join([ds for ds in os.listdir(os.path.join(self.data_root_dir, self.task)) if ds not in ['DIS-VD']]), # leave DIS-VD for evaluation. - 'General': 'DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4+DIS-TR+TR-HRSOD+TE-HRSOD+TR-HRS10K+TE-HRS10K+TR-UHRSD+TE-UHRSD+TR-P3M-10k+TE-P3M-500-NP+TE-P3M-500-P+TR-humans+DIS-VD-ori', - # 'General-2K': '+'.join([ds for ds in os.listdir(os.path.join(self.data_root_dir, self.task)) if ds not in ['DIS-VD', 'DIS-VD-ori']]), - 'General-2K': 'DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4+DIS-TR+TR-HRSOD+TE-HRSOD+TR-HRS10K+TE-HRS10K+TR-UHRSD+TE-UHRSD+TR-P3M-10k+TE-P3M-500-NP+TE-P3M-500-P+TR-humans', + # 'General': '+'.join([ds for ds in os.listdir(os.path.join(self.data_root_dir, self.task)) if ds not in ['DIS-VD', 'TE-P3M-500-NP']]), # leave DIS-VD,TE-P3M-500-NP for evaluation. + 'General': 'DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4+DIS-TR+TR-HRSOD+TE-HRSOD+TR-HRS10K+TE-HRS10K+TR-UHRSD+TE-UHRSD+TR-P3M-10k+TE-P3M-500-P+TR-humans+DIS-VD-ori', + # 'General-2K': '+'.join([ds for ds in os.listdir(os.path.join(self.data_root_dir, self.task)) if ds not in ['DIS-VD', 'TE-P3M-500-NP']]), + 'General-2K': 'DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4+DIS-TR+TR-HRSOD+TE-HRSOD+TR-HRS10K+TE-HRS10K+TR-UHRSD+TE-UHRSD+TR-P3M-10k+TE-P3M-500-P+TR-humans+DIS-VD-ori', 'Matting': 'TR-P3M-10k+TE-P3M-500-NP+TR-humans+TR-Distrinctions-646', }[self.task] self.prompt4loc = ['dense', 'sparse'][0] @@ -57,7 +53,7 @@ def __init__(self, bb_index: int = 6) -> None: 'DIS5K': -40, 'COD': -20, 'HRSOD': -20, - 'General': -20, + 'General': -40, 'General-2K': -20, 'Matting': -20, }[self.task] @@ -105,29 +101,40 @@ def __init__(self, bb_index: int = 6) -> None: self.lr_decay_epochs = [1e5] # Set to negative N to decay the lr in the last N-th epoch. self.lr_decay_rate = 0.5 # Loss - if self.task not in ['Matting']: + if self.task in ['Matting']: self.lambdas_pix_last = { - # not 0 means opening this loss - # original rate -- 1 : 30 : 1.5 : 0.2, bce x 30 - 'bce': 30 * 1, # high performance - 'iou': 0.5 * 1, # 0 / 255 - 'iou_patch': 0.5 * 0, # 0 / 255, win_size = (64, 64) - 'mae': 30 * 0, - 'mse': 30 * 0, # can smooth the saliency map + 'bce': 30 * 1, + 'iou': 0.5 * 0, + 'iou_patch': 0.5 * 0, + 'mae': 100 * 1, + 'mse': 30 * 0, 'triplet': 3 * 0, 'reg': 100 * 0, - 'ssim': 10 * 1, # help contours, - 'cnt': 5 * 0, # help contours - 'structure': 5 * 0, # structure loss from codes of MVANet. A little improvement on DIS-TE[1,2,3], a bit more decrease on DIS-TE4. + 'ssim': 10 * 1, + 'cnt': 5 * 0, + 'structure': 5 * 0, + } + elif self.task in ['General', 'General-2K']: + self.lambdas_pix_last = { + 'bce': 30 * 1, + 'iou': 0.5 * 1, + 'iou_patch': 0.5 * 0, + 'mae': 100 * 1, + 'mse': 30 * 0, + 'triplet': 3 * 0, + 'reg': 100 * 0, + 'ssim': 10 * 1, + 'cnt': 5 * 0, + 'structure': 5 * 0, } else: self.lambdas_pix_last = { # not 0 means opening this loss # original rate -- 1 : 30 : 1.5 : 0.2, bce x 30 - 'bce': 30 * 0, # high performance - 'iou': 0.5 * 0, # 0 / 255 + 'bce': 30 * 1, # high performance + 'iou': 0.5 * 1, # 0 / 255 'iou_patch': 0.5 * 0, # 0 / 255, win_size = (64, 64) - 'mae': 100 * 1, + 'mae': 30 * 0, 'mse': 30 * 0, # can smooth the saliency map 'triplet': 3 * 0, 'reg': 100 * 0, diff --git a/birefnet/dataset.py b/birefnet/dataset.py index 3d0535d..e7c9505 100644 --- a/birefnet/dataset.py +++ b/birefnet/dataset.py @@ -75,7 +75,7 @@ def __init__(self, datasets, image_size, is_train=True): if len(self.label_paths) != len(self.image_paths): set_image_paths = set([os.path.splitext(p.split(os.sep)[-1])[0] for p in self.image_paths]) set_label_paths = set([os.path.splitext(p.split(os.sep)[-1])[0] for p in self.label_paths]) - print('diff:', set_image_paths - set_label_paths) + print('Path diff:', set_image_paths - set_label_paths) raise ValueError(f"There are different numbers of images ({len(self.label_paths)}) and labels ({len(self.image_paths)})") if self.load_all: diff --git a/birefnet/image_proc.py b/birefnet/image_proc.py index 415c3e6..d976d7e 100644 --- a/birefnet/image_proc.py +++ b/birefnet/image_proc.py @@ -48,7 +48,7 @@ def preproc(image, label, preproc_methods=['flip']): if 'enhance' in preproc_methods: image = color_enhance(image) if 'pepper' in preproc_methods: - label = random_pepper(label) + image = random_pepper(image) return image, label @@ -112,8 +112,5 @@ def random_pepper(img, N=0.0015): for i in range(noiseNum): randX = random.randint(0, img.shape[0] - 1) randY = random.randint(0, img.shape[1] - 1) - if random.randint(0, 1) == 0: - img[randX, randY] = 0 - else: - img[randX, randY] = 255 + img[randX, randY] = random.randint(0, 1) * 255 return Image.fromarray(img) \ No newline at end of file diff --git a/birefnet/models/backbones/swin_v1.py b/birefnet/models/backbones/swin_v1.py index 10642e3..591599b 100644 --- a/birefnet/models/backbones/swin_v1.py +++ b/birefnet/models/backbones/swin_v1.py @@ -394,7 +394,7 @@ def forward(self, x, H, W): mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 mask_windows = mask_windows.view(-1, self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)).to(x.dtype) for blk in self.blocks: blk.H, blk.W = H, W