Skip to content

Commit

Permalink
sync birefnet lib
Browse files Browse the repository at this point in the history
  • Loading branch information
刘雪峰 committed Sep 10, 2024
1 parent c0ada33 commit 113f8ca
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 23 deletions.
49 changes: 33 additions & 16 deletions birefnet/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import math

# CUR_DIR = os.path.dirname(__file__)
import folder_paths


class Config:
Expand All @@ -11,15 +11,20 @@ def __init__(self, bb_index: int = 6) -> None:
# if os.name == 'nt':
# self.sys_home_dir = os.environ['USERPROFILE'] # For windows system
# else:
# self.sys_home_dir = os.environ['HOME'] # For Linux system
# self.sys_home_dir = [os.environ['HOME'], '/mnt/data'][1] # For Linux system
# https://drive.google.com/drive/folders/1hZW6tAGPJwo9mPS7qGGGdpxuvuXiyoMJ
# self.data_root_dir = os.path.join(self.sys_home_dir, 'datasets/dis')

# TASK settings
self.task = ['DIS5K', 'COD', 'HRSOD', 'General', 'Matting'][0]
self.task = ['DIS5K', 'COD', 'HRSOD', 'General', 'General-2K', 'Matting'][0]
self.training_set = {
'DIS5K': ['DIS-TR', 'DIS-TR+DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4'][0],
'COD': 'TR-COD10K+TR-CAMO',
'HRSOD': ['TR-DUTS', 'TR-HRSOD', 'TR-UHRSD', 'TR-DUTS+TR-HRSOD', 'TR-DUTS+TR-UHRSD', 'TR-HRSOD+TR-UHRSD', 'TR-DUTS+TR-HRSOD+TR-UHRSD'][5],
'General': 'DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4+DIS-TR+TR-HRSOD+TE-HRSOD+TR-HRS10K+TE-HRS10K+TR-UHRSD+TE-UHRSD+TR-P3M-10k+TE-P3M-500-NP+TE-P3M-500-P+TR-humans', # leave DIS-VD for evaluation.
# 'General': '+'.join([ds for ds in os.listdir(os.path.join(self.data_root_dir, self.task)) if ds not in ['DIS-VD']]), # leave DIS-VD for evaluation.
'General': 'DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4+DIS-TR+TR-HRSOD+TE-HRSOD+TR-HRS10K+TE-HRS10K+TR-UHRSD+TE-UHRSD+TR-P3M-10k+TE-P3M-500-NP+TE-P3M-500-P+TR-humans+DIS-VD-ori',
# 'General-2K': '+'.join([ds for ds in os.listdir(os.path.join(self.data_root_dir, self.task)) if ds not in ['DIS-VD', 'DIS-VD-ori']]),
'General-2K': 'DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4+DIS-TR+TR-HRSOD+TE-HRSOD+TR-HRS10K+TE-HRS10K+TR-UHRSD+TE-UHRSD+TR-P3M-10k+TE-P3M-500-NP+TE-P3M-500-P+TR-humans',
'Matting': 'TR-P3M-10k+TE-P3M-500-NP+TR-humans+TR-Distrinctions-646',
}[self.task]
self.prompt4loc = ['dense', 'sparse'][0]
Expand Down Expand Up @@ -47,17 +52,18 @@ def __init__(self, bb_index: int = 6) -> None:
# TRAINING settings
self.batch_size = 4
self.finetune_last_epochs = [
('IoU', 0),
0,
{
'DIS5K': ('IoU', -30),
'COD': ('IoU', -20),
'HRSOD': ('IoU', -20),
'General': ('MAE', -10),
'Matting': ('MAE', -10),
'DIS5K': -40,
'COD': -20,
'HRSOD': -20,
'General': -20,
'General-2K': -20,
'Matting': -20,
}[self.task]
][1] # choose 0 to skip
self.lr = (1e-4 if 'DIS5K' in self.task else 1e-5) * math.sqrt(self.batch_size / 4) # DIS needs high lr to converge faster. Adapt the lr linearly
self.size = 1024
self.size = (1024, 1024) if self.task not in ['General-2K'] else (2560, 1440) # wid, hei
self.num_workers = max(4, self.batch_size) # will be decrease to min(it, batch_size) at the initialization of the data_loader

# Backbone settings
Expand Down Expand Up @@ -137,8 +143,8 @@ def __init__(self, bb_index: int = 6) -> None:
self.lambda_adv_d = 3. * (self.lambda_adv_g > 0)

# PATH settings - inactive
# self.data_root_dir = os.path.join(self.sys_home_dir, 'datasets/dis')
# self.weights_root_dir = os.path.join(self.sys_home_dir, 'weights')
# https://drive.google.com/drive/folders/1cmce_emsS8A5ha5XT2c_CZiJzlLM81ms
# self.weights_root_dir = os.path.join(self.sys_home_dir, 'weights/cv')
# self.weights = {
# 'pvt_v2_b2': os.path.join(self.weights_root_dir, 'pvt_v2_b2.pth'),
# 'pvt_v2_b5': os.path.join(self.weights_root_dir, ['pvt_v2_b5.pth', 'pvt_v2_b5_22k.pth'][0]),
Expand All @@ -149,6 +155,17 @@ def __init__(self, bb_index: int = 6) -> None:
# 'pvt_v2_b0': os.path.join(self.weights_root_dir, ['pvt_v2_b0.pth'][0]),
# 'pvt_v2_b1': os.path.join(self.weights_root_dir, ['pvt_v2_b1.pth'][0]),
# }
# weight_paths_name = "birefnet"
# self.weights = {
# 'pvt_v2_b2': folder_paths.get_full_path(weight_paths_name, 'pvt_v2_b2.pth'),
# 'pvt_v2_b5': folder_paths.get_full_path(weight_paths_name, ['pvt_v2_b5.pth', 'pvt_v2_b5_22k.pth'][0]),
# 'swin_v1_b': folder_paths.get_full_path(weight_paths_name, ['swin_base_patch4_window12_384_22kto1k.pth', 'swin_base_patch4_window12_384_22k.pth'][0]),
# 'swin_v1_l': folder_paths.get_full_path(weight_paths_name, ['swin_large_patch4_window12_384_22kto1k.pth', 'swin_large_patch4_window12_384_22k.pth'][0]),
# 'swin_v1_t': folder_paths.get_full_path(weight_paths_name, ['swin_tiny_patch4_window7_224_22kto1k_finetune.pth'][0]),
# 'swin_v1_s': folder_paths.get_full_path(weight_paths_name, ['swin_small_patch4_window7_224_22kto1k_finetune.pth'][0]),
# 'pvt_v2_b0': folder_paths.get_full_path(weight_paths_name, ['pvt_v2_b0.pth'][0]),
# 'pvt_v2_b1': folder_paths.get_full_path(weight_paths_name, ['pvt_v2_b1.pth'][0]),
# }
self.weights = {}

# Callbacks - inactive
Expand All @@ -161,11 +178,12 @@ def __init__(self, bb_index: int = 6) -> None:

self.batch_size_valid = 1
self.rand_seed = 7
# run_sh_file = [f for f in os.listdir(CUR_DIR) if 'train.sh' == f] + [os.path.join(CUR_DIR, '..', f) for f in os.listdir('..') if 'train.sh' == f]
# run_sh_file = [f for f in os.listdir('.') if 'train.sh' == f] + [os.path.join('..', f) for f in os.listdir('..') if 'train.sh' == f]
# if run_sh_file:
# with open(run_sh_file[0], 'r') as f:
# with open(run_sh_file[0], 'r') as f:
# lines = f.readlines()
# self.save_last = int([l.strip() for l in lines if '"{}")'.format(self.task) in l and 'val_last=' in l][0].split('val_last=')[-1].split()[0])
# self.save_step = int([l.strip() for l in lines if '"{}")'.format(self.task) in l and 'step=' in l][0].split('step=')[-1].split()[0])

def print_task(self) -> None:
# Return task for choosing settings in shell scripts.
Expand All @@ -174,4 +192,3 @@ def print_task(self) -> None:
# if __name__ == '__main__':
# config = Config()
# config.print_task()

17 changes: 10 additions & 7 deletions birefnet/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self, datasets, image_size, is_train=True):
self.size_train = image_size
self.size_test = image_size
self.keep_size = not config.size
self.data_size = (config.size, config.size)
self.data_size = config.size
self.is_train = is_train
self.load_all = config.load_all
self.device = config.device
Expand All @@ -45,12 +45,12 @@ def __init__(self, datasets, image_size, is_train=True):
if self.is_train and config.auxiliary_classification:
self.cls_name2id = {_name: _id for _id, _name in enumerate(class_labels_TR_sorted)}
self.transform_image = transforms.Compose([
transforms.Resize(self.data_size),
transforms.Resize(self.data_size[::-1]),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
][self.load_all or self.keep_size:])
self.transform_label = transforms.Compose([
transforms.Resize(self.data_size),
transforms.Resize(self.data_size[::-1]),
transforms.ToTensor(),
][self.load_all or self.keep_size:])
dataset_root = os.path.join(config.data_root_dir, config.task)
Expand All @@ -73,15 +73,18 @@ def __init__(self, datasets, image_size, is_train=True):
print('Not exists:', p_gt)

if len(self.label_paths) != len(self.image_paths):
set_image_paths = set([os.path.splitext(p.split(os.sep)[-1])[0] for p in self.image_paths])
set_label_paths = set([os.path.splitext(p.split(os.sep)[-1])[0] for p in self.label_paths])
print('diff:', set_image_paths - set_label_paths)
raise ValueError(f"There are different numbers of images ({len(self.label_paths)}) and labels ({len(self.image_paths)})")

if self.load_all:
self.images_loaded, self.labels_loaded = [], []
self.class_labels_loaded = []
# for image_path, label_path in zip(self.image_paths, self.label_paths):
for image_path, label_path in tqdm(zip(self.image_paths, self.label_paths), total=len(self.image_paths)):
_image = path_to_image(image_path, size=(config.size, config.size), color_type='rgb')
_label = path_to_image(label_path, size=(config.size, config.size), color_type='gray')
_image = path_to_image(image_path, size=config.size, color_type='rgb')
_label = path_to_image(label_path, size=config.size, color_type='gray')
self.images_loaded.append(_image)
self.labels_loaded.append(_label)
self.class_labels_loaded.append(
Expand All @@ -95,8 +98,8 @@ def __getitem__(self, index):
label = self.labels_loaded[index]
class_label = self.class_labels_loaded[index] if self.is_train and config.auxiliary_classification else -1
else:
image = path_to_image(self.image_paths[index], size=(config.size, config.size), color_type='rgb')
label = path_to_image(self.label_paths[index], size=(config.size, config.size), color_type='gray')
image = path_to_image(self.image_paths[index], size=config.size, color_type='rgb')
label = path_to_image(self.label_paths[index], size=config.size, color_type='gray')
class_label = self.cls_name2id[self.label_paths[index].split('/')[-1].split('#')[3]] if self.is_train and config.auxiliary_classification else -1

# loading image and label
Expand Down

0 comments on commit 113f8ca

Please sign in to comment.