From 113f8ca70525823b9bbf1d75e7c6c75998e373c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E9=9B=AA=E5=B3=B0?= Date: Tue, 10 Sep 2024 19:17:37 +0800 Subject: [PATCH] sync birefnet lib --- birefnet/config.py | 49 ++++++++++++++++++++++++++++++--------------- birefnet/dataset.py | 17 +++++++++------- 2 files changed, 43 insertions(+), 23 deletions(-) diff --git a/birefnet/config.py b/birefnet/config.py index 598c245..9ba6776 100644 --- a/birefnet/config.py +++ b/birefnet/config.py @@ -1,7 +1,7 @@ import os import math -# CUR_DIR = os.path.dirname(__file__) +import folder_paths class Config: @@ -11,15 +11,20 @@ def __init__(self, bb_index: int = 6) -> None: # if os.name == 'nt': # self.sys_home_dir = os.environ['USERPROFILE'] # For windows system # else: - # self.sys_home_dir = os.environ['HOME'] # For Linux system + # self.sys_home_dir = [os.environ['HOME'], '/mnt/data'][1] # For Linux system + # https://drive.google.com/drive/folders/1hZW6tAGPJwo9mPS7qGGGdpxuvuXiyoMJ + # self.data_root_dir = os.path.join(self.sys_home_dir, 'datasets/dis') # TASK settings - self.task = ['DIS5K', 'COD', 'HRSOD', 'General', 'Matting'][0] + self.task = ['DIS5K', 'COD', 'HRSOD', 'General', 'General-2K', 'Matting'][0] self.training_set = { 'DIS5K': ['DIS-TR', 'DIS-TR+DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4'][0], 'COD': 'TR-COD10K+TR-CAMO', 'HRSOD': ['TR-DUTS', 'TR-HRSOD', 'TR-UHRSD', 'TR-DUTS+TR-HRSOD', 'TR-DUTS+TR-UHRSD', 'TR-HRSOD+TR-UHRSD', 'TR-DUTS+TR-HRSOD+TR-UHRSD'][5], - 'General': 'DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4+DIS-TR+TR-HRSOD+TE-HRSOD+TR-HRS10K+TE-HRS10K+TR-UHRSD+TE-UHRSD+TR-P3M-10k+TE-P3M-500-NP+TE-P3M-500-P+TR-humans', # leave DIS-VD for evaluation. + # 'General': '+'.join([ds for ds in os.listdir(os.path.join(self.data_root_dir, self.task)) if ds not in ['DIS-VD']]), # leave DIS-VD for evaluation. + 'General': 'DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4+DIS-TR+TR-HRSOD+TE-HRSOD+TR-HRS10K+TE-HRS10K+TR-UHRSD+TE-UHRSD+TR-P3M-10k+TE-P3M-500-NP+TE-P3M-500-P+TR-humans+DIS-VD-ori', + # 'General-2K': '+'.join([ds for ds in os.listdir(os.path.join(self.data_root_dir, self.task)) if ds not in ['DIS-VD', 'DIS-VD-ori']]), + 'General-2K': 'DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4+DIS-TR+TR-HRSOD+TE-HRSOD+TR-HRS10K+TE-HRS10K+TR-UHRSD+TE-UHRSD+TR-P3M-10k+TE-P3M-500-NP+TE-P3M-500-P+TR-humans', 'Matting': 'TR-P3M-10k+TE-P3M-500-NP+TR-humans+TR-Distrinctions-646', }[self.task] self.prompt4loc = ['dense', 'sparse'][0] @@ -47,17 +52,18 @@ def __init__(self, bb_index: int = 6) -> None: # TRAINING settings self.batch_size = 4 self.finetune_last_epochs = [ - ('IoU', 0), + 0, { - 'DIS5K': ('IoU', -30), - 'COD': ('IoU', -20), - 'HRSOD': ('IoU', -20), - 'General': ('MAE', -10), - 'Matting': ('MAE', -10), + 'DIS5K': -40, + 'COD': -20, + 'HRSOD': -20, + 'General': -20, + 'General-2K': -20, + 'Matting': -20, }[self.task] ][1] # choose 0 to skip self.lr = (1e-4 if 'DIS5K' in self.task else 1e-5) * math.sqrt(self.batch_size / 4) # DIS needs high lr to converge faster. Adapt the lr linearly - self.size = 1024 + self.size = (1024, 1024) if self.task not in ['General-2K'] else (2560, 1440) # wid, hei self.num_workers = max(4, self.batch_size) # will be decrease to min(it, batch_size) at the initialization of the data_loader # Backbone settings @@ -137,8 +143,8 @@ def __init__(self, bb_index: int = 6) -> None: self.lambda_adv_d = 3. * (self.lambda_adv_g > 0) # PATH settings - inactive - # self.data_root_dir = os.path.join(self.sys_home_dir, 'datasets/dis') - # self.weights_root_dir = os.path.join(self.sys_home_dir, 'weights') + # https://drive.google.com/drive/folders/1cmce_emsS8A5ha5XT2c_CZiJzlLM81ms + # self.weights_root_dir = os.path.join(self.sys_home_dir, 'weights/cv') # self.weights = { # 'pvt_v2_b2': os.path.join(self.weights_root_dir, 'pvt_v2_b2.pth'), # 'pvt_v2_b5': os.path.join(self.weights_root_dir, ['pvt_v2_b5.pth', 'pvt_v2_b5_22k.pth'][0]), @@ -149,6 +155,17 @@ def __init__(self, bb_index: int = 6) -> None: # 'pvt_v2_b0': os.path.join(self.weights_root_dir, ['pvt_v2_b0.pth'][0]), # 'pvt_v2_b1': os.path.join(self.weights_root_dir, ['pvt_v2_b1.pth'][0]), # } + # weight_paths_name = "birefnet" + # self.weights = { + # 'pvt_v2_b2': folder_paths.get_full_path(weight_paths_name, 'pvt_v2_b2.pth'), + # 'pvt_v2_b5': folder_paths.get_full_path(weight_paths_name, ['pvt_v2_b5.pth', 'pvt_v2_b5_22k.pth'][0]), + # 'swin_v1_b': folder_paths.get_full_path(weight_paths_name, ['swin_base_patch4_window12_384_22kto1k.pth', 'swin_base_patch4_window12_384_22k.pth'][0]), + # 'swin_v1_l': folder_paths.get_full_path(weight_paths_name, ['swin_large_patch4_window12_384_22kto1k.pth', 'swin_large_patch4_window12_384_22k.pth'][0]), + # 'swin_v1_t': folder_paths.get_full_path(weight_paths_name, ['swin_tiny_patch4_window7_224_22kto1k_finetune.pth'][0]), + # 'swin_v1_s': folder_paths.get_full_path(weight_paths_name, ['swin_small_patch4_window7_224_22kto1k_finetune.pth'][0]), + # 'pvt_v2_b0': folder_paths.get_full_path(weight_paths_name, ['pvt_v2_b0.pth'][0]), + # 'pvt_v2_b1': folder_paths.get_full_path(weight_paths_name, ['pvt_v2_b1.pth'][0]), + # } self.weights = {} # Callbacks - inactive @@ -161,11 +178,12 @@ def __init__(self, bb_index: int = 6) -> None: self.batch_size_valid = 1 self.rand_seed = 7 - # run_sh_file = [f for f in os.listdir(CUR_DIR) if 'train.sh' == f] + [os.path.join(CUR_DIR, '..', f) for f in os.listdir('..') if 'train.sh' == f] + # run_sh_file = [f for f in os.listdir('.') if 'train.sh' == f] + [os.path.join('..', f) for f in os.listdir('..') if 'train.sh' == f] # if run_sh_file: - # with open(run_sh_file[0], 'r') as f: + # with open(run_sh_file[0], 'r') as f: # lines = f.readlines() # self.save_last = int([l.strip() for l in lines if '"{}")'.format(self.task) in l and 'val_last=' in l][0].split('val_last=')[-1].split()[0]) + # self.save_step = int([l.strip() for l in lines if '"{}")'.format(self.task) in l and 'step=' in l][0].split('step=')[-1].split()[0]) def print_task(self) -> None: # Return task for choosing settings in shell scripts. @@ -174,4 +192,3 @@ def print_task(self) -> None: # if __name__ == '__main__': # config = Config() # config.print_task() - diff --git a/birefnet/dataset.py b/birefnet/dataset.py index 0ee8428..3d0535d 100644 --- a/birefnet/dataset.py +++ b/birefnet/dataset.py @@ -36,7 +36,7 @@ def __init__(self, datasets, image_size, is_train=True): self.size_train = image_size self.size_test = image_size self.keep_size = not config.size - self.data_size = (config.size, config.size) + self.data_size = config.size self.is_train = is_train self.load_all = config.load_all self.device = config.device @@ -45,12 +45,12 @@ def __init__(self, datasets, image_size, is_train=True): if self.is_train and config.auxiliary_classification: self.cls_name2id = {_name: _id for _id, _name in enumerate(class_labels_TR_sorted)} self.transform_image = transforms.Compose([ - transforms.Resize(self.data_size), + transforms.Resize(self.data_size[::-1]), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ][self.load_all or self.keep_size:]) self.transform_label = transforms.Compose([ - transforms.Resize(self.data_size), + transforms.Resize(self.data_size[::-1]), transforms.ToTensor(), ][self.load_all or self.keep_size:]) dataset_root = os.path.join(config.data_root_dir, config.task) @@ -73,6 +73,9 @@ def __init__(self, datasets, image_size, is_train=True): print('Not exists:', p_gt) if len(self.label_paths) != len(self.image_paths): + set_image_paths = set([os.path.splitext(p.split(os.sep)[-1])[0] for p in self.image_paths]) + set_label_paths = set([os.path.splitext(p.split(os.sep)[-1])[0] for p in self.label_paths]) + print('diff:', set_image_paths - set_label_paths) raise ValueError(f"There are different numbers of images ({len(self.label_paths)}) and labels ({len(self.image_paths)})") if self.load_all: @@ -80,8 +83,8 @@ def __init__(self, datasets, image_size, is_train=True): self.class_labels_loaded = [] # for image_path, label_path in zip(self.image_paths, self.label_paths): for image_path, label_path in tqdm(zip(self.image_paths, self.label_paths), total=len(self.image_paths)): - _image = path_to_image(image_path, size=(config.size, config.size), color_type='rgb') - _label = path_to_image(label_path, size=(config.size, config.size), color_type='gray') + _image = path_to_image(image_path, size=config.size, color_type='rgb') + _label = path_to_image(label_path, size=config.size, color_type='gray') self.images_loaded.append(_image) self.labels_loaded.append(_label) self.class_labels_loaded.append( @@ -95,8 +98,8 @@ def __getitem__(self, index): label = self.labels_loaded[index] class_label = self.class_labels_loaded[index] if self.is_train and config.auxiliary_classification else -1 else: - image = path_to_image(self.image_paths[index], size=(config.size, config.size), color_type='rgb') - label = path_to_image(self.label_paths[index], size=(config.size, config.size), color_type='gray') + image = path_to_image(self.image_paths[index], size=config.size, color_type='rgb') + label = path_to_image(self.label_paths[index], size=config.size, color_type='gray') class_label = self.cls_name2id[self.label_paths[index].split('/')[-1].split('#')[3]] if self.is_train and config.auxiliary_classification else -1 # loading image and label