Skip to content

Commit

Permalink
sync code from BiRefNet repo
Browse files Browse the repository at this point in the history
  • Loading branch information
lldacing committed Sep 19, 2024
1 parent c122308 commit e3a7de4
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 31 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
[中文文档](README_CN.md)

Support the use of new and old versions of BiRefNet models

## Preview
![save api extended](doc/base.png)
![save api extended](doc/video.gif)
Expand Down
3 changes: 3 additions & 0 deletions README_CN.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
[English](README.md)

支持使用新老版本BiRefNet模型进行抠图

## 预览
![save api extended](doc/base.png)
![save api extended](doc/video.gif)
Expand Down
55 changes: 31 additions & 24 deletions birefnet/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,7 @@ class Config:
def __init__(self, bb_index: int = 6) -> None:
# PATH settings
# Make up your file system as: SYS_HOME_DIR/codes/dis/BiRefNet, SYS_HOME_DIR/datasets/dis/xx, SYS_HOME_DIR/weights/xx
# if os.name == 'nt':
# self.sys_home_dir = os.environ['USERPROFILE'] # For windows system
# else:
# self.sys_home_dir = [os.environ['HOME'], '/mnt/data'][1] # For Linux system
# https://drive.google.com/drive/folders/1hZW6tAGPJwo9mPS7qGGGdpxuvuXiyoMJ
# self.sys_home_dir = [os.path.expanduser('~'), '/mnt/data'][1] # Default, custom
# self.data_root_dir = os.path.join(self.sys_home_dir, 'datasets/dis')

# TASK settings
Expand All @@ -21,10 +17,10 @@ def __init__(self, bb_index: int = 6) -> None:
'DIS5K': ['DIS-TR', 'DIS-TR+DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4'][0],
'COD': 'TR-COD10K+TR-CAMO',
'HRSOD': ['TR-DUTS', 'TR-HRSOD', 'TR-UHRSD', 'TR-DUTS+TR-HRSOD', 'TR-DUTS+TR-UHRSD', 'TR-HRSOD+TR-UHRSD', 'TR-DUTS+TR-HRSOD+TR-UHRSD'][5],
# 'General': '+'.join([ds for ds in os.listdir(os.path.join(self.data_root_dir, self.task)) if ds not in ['DIS-VD']]), # leave DIS-VD for evaluation.
'General': 'DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4+DIS-TR+TR-HRSOD+TE-HRSOD+TR-HRS10K+TE-HRS10K+TR-UHRSD+TE-UHRSD+TR-P3M-10k+TE-P3M-500-NP+TE-P3M-500-P+TR-humans+DIS-VD-ori',
# 'General-2K': '+'.join([ds for ds in os.listdir(os.path.join(self.data_root_dir, self.task)) if ds not in ['DIS-VD', 'DIS-VD-ori']]),
'General-2K': 'DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4+DIS-TR+TR-HRSOD+TE-HRSOD+TR-HRS10K+TE-HRS10K+TR-UHRSD+TE-UHRSD+TR-P3M-10k+TE-P3M-500-NP+TE-P3M-500-P+TR-humans',
# 'General': '+'.join([ds for ds in os.listdir(os.path.join(self.data_root_dir, self.task)) if ds not in ['DIS-VD', 'TE-P3M-500-NP']]), # leave DIS-VD,TE-P3M-500-NP for evaluation.
'General': 'DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4+DIS-TR+TR-HRSOD+TE-HRSOD+TR-HRS10K+TE-HRS10K+TR-UHRSD+TE-UHRSD+TR-P3M-10k+TE-P3M-500-P+TR-humans+DIS-VD-ori',
# 'General-2K': '+'.join([ds for ds in os.listdir(os.path.join(self.data_root_dir, self.task)) if ds not in ['DIS-VD', 'TE-P3M-500-NP']]),
'General-2K': 'DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4+DIS-TR+TR-HRSOD+TE-HRSOD+TR-HRS10K+TE-HRS10K+TR-UHRSD+TE-UHRSD+TR-P3M-10k+TE-P3M-500-P+TR-humans+DIS-VD-ori',
'Matting': 'TR-P3M-10k+TE-P3M-500-NP+TR-humans+TR-Distrinctions-646',
}[self.task]
self.prompt4loc = ['dense', 'sparse'][0]
Expand Down Expand Up @@ -57,7 +53,7 @@ def __init__(self, bb_index: int = 6) -> None:
'DIS5K': -40,
'COD': -20,
'HRSOD': -20,
'General': -20,
'General': -40,
'General-2K': -20,
'Matting': -20,
}[self.task]
Expand Down Expand Up @@ -105,29 +101,40 @@ def __init__(self, bb_index: int = 6) -> None:
self.lr_decay_epochs = [1e5] # Set to negative N to decay the lr in the last N-th epoch.
self.lr_decay_rate = 0.5
# Loss
if self.task not in ['Matting']:
if self.task in ['Matting']:
self.lambdas_pix_last = {
# not 0 means opening this loss
# original rate -- 1 : 30 : 1.5 : 0.2, bce x 30
'bce': 30 * 1, # high performance
'iou': 0.5 * 1, # 0 / 255
'iou_patch': 0.5 * 0, # 0 / 255, win_size = (64, 64)
'mae': 30 * 0,
'mse': 30 * 0, # can smooth the saliency map
'bce': 30 * 1,
'iou': 0.5 * 0,
'iou_patch': 0.5 * 0,
'mae': 100 * 1,
'mse': 30 * 0,
'triplet': 3 * 0,
'reg': 100 * 0,
'ssim': 10 * 1, # help contours,
'cnt': 5 * 0, # help contours
'structure': 5 * 0, # structure loss from codes of MVANet. A little improvement on DIS-TE[1,2,3], a bit more decrease on DIS-TE4.
'ssim': 10 * 1,
'cnt': 5 * 0,
'structure': 5 * 0,
}
elif self.task in ['General', 'General-2K']:
self.lambdas_pix_last = {
'bce': 30 * 1,
'iou': 0.5 * 1,
'iou_patch': 0.5 * 0,
'mae': 100 * 1,
'mse': 30 * 0,
'triplet': 3 * 0,
'reg': 100 * 0,
'ssim': 10 * 1,
'cnt': 5 * 0,
'structure': 5 * 0,
}
else:
self.lambdas_pix_last = {
# not 0 means opening this loss
# original rate -- 1 : 30 : 1.5 : 0.2, bce x 30
'bce': 30 * 0, # high performance
'iou': 0.5 * 0, # 0 / 255
'bce': 30 * 1, # high performance
'iou': 0.5 * 1, # 0 / 255
'iou_patch': 0.5 * 0, # 0 / 255, win_size = (64, 64)
'mae': 100 * 1,
'mae': 30 * 0,
'mse': 30 * 0, # can smooth the saliency map
'triplet': 3 * 0,
'reg': 100 * 0,
Expand Down
2 changes: 1 addition & 1 deletion birefnet/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(self, datasets, image_size, is_train=True):
if len(self.label_paths) != len(self.image_paths):
set_image_paths = set([os.path.splitext(p.split(os.sep)[-1])[0] for p in self.image_paths])
set_label_paths = set([os.path.splitext(p.split(os.sep)[-1])[0] for p in self.label_paths])
print('diff:', set_image_paths - set_label_paths)
print('Path diff:', set_image_paths - set_label_paths)
raise ValueError(f"There are different numbers of images ({len(self.label_paths)}) and labels ({len(self.image_paths)})")

if self.load_all:
Expand Down
7 changes: 2 additions & 5 deletions birefnet/image_proc.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def preproc(image, label, preproc_methods=['flip']):
if 'enhance' in preproc_methods:
image = color_enhance(image)
if 'pepper' in preproc_methods:
label = random_pepper(label)
image = random_pepper(image)
return image, label


Expand Down Expand Up @@ -112,8 +112,5 @@ def random_pepper(img, N=0.0015):
for i in range(noiseNum):
randX = random.randint(0, img.shape[0] - 1)
randY = random.randint(0, img.shape[1] - 1)
if random.randint(0, 1) == 0:
img[randX, randY] = 0
else:
img[randX, randY] = 255
img[randX, randY] = random.randint(0, 1) * 255
return Image.fromarray(img)
2 changes: 1 addition & 1 deletion birefnet/models/backbones/swin_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ def forward(self, x, H, W):
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)).to(x.dtype)

for blk in self.blocks:
blk.H, blk.W = H, W
Expand Down

0 comments on commit e3a7de4

Please sign in to comment.