diff --git a/src/training/data.py b/src/training/data.py
index 622e6207f..a88215a23 100644
--- a/src/training/data.py
+++ b/src/training/data.py
@@ -45,6 +45,87 @@ def __getitem__(self, idx):
         return images, texts
 
 
+class DebiasDataLoader:
+    """Mix several training datasets into a single stream of batches.
+
+    ``args.train_data`` is split on ``', '`` and one dataloader is built per
+    entry.  Every step draws one batch from a dataset picked at random with
+    probability proportional to its ``num_samples``.  One pass over this
+    object yields ``num_batches`` batches, the sum over all dataloaders.
+    """
+
+    def __init__(self, args, preprocess_train, epoch):
+        train_data_list = args.train_data.split(', ')
+        data_train = []
+        for train_data in train_data_list:
+            data_train.append(get_dataset_fn(train_data, args.dataset_type)(
+                args, preprocess_train, is_train=True, epoch=epoch, filepath=train_data))
+
+        self.args = args
+        self.num_datasets = len(data_train)
+        self.dataloaders = [dataset.dataloader for dataset in data_train]
+        self.dataiters = [iter(dataloader) for dataloader in self.dataloaders]
+        self.samplers = [dataset.sampler for dataset in data_train]
+
+        self.num_batches = sum(dataloader.num_batches for dataloader in self.dataloaders)
+        self.num_samples = sum(dataloader.num_samples for dataloader in self.dataloaders)
+
+        # calculate sample weights according to num_samples of multiple datasets
+        self.sample_weights = np.array([float(dataloader.num_samples) / self.num_samples for dataloader in self.dataloaders])
+
+        self.count = 0
+        self.current_epoch = 0
+        # initialize each sampler
+        self._set_sampler_epoch(0)
+
+    def _set_sampler_epoch(self, epoch):
+        """Pass ``epoch`` to every sampler; does nothing unless ``args.distributed``."""
+        if not self.args.distributed:
+            return
+        for sampler in self.samplers:
+            # A dataset may come without a sampler: get_wds_dataset builds
+            # DataInfo(dataloader, None), which presumably leaves sampler as None.
+            if sampler is not None:
+                sampler.set_epoch(epoch)
+
+    def __len__(self):
+        """Return the number of batches yielded per epoch."""
+        return self.num_batches
+
+    def __iter__(self):
+        """Yield batches until ``num_batches`` have been produced this epoch."""
+        while True:
+
+            if self.count == self.num_batches:
+                self.current_epoch += 1
+                self.count = 0
+                self._set_sampler_epoch(self.current_epoch)
+                return  # end each epoch
+
+            # The seed depends only on the step and epoch counters, so the
+            # dataset choice is reproducible.  A private RandomState is used so
+            # the global np.random state is not reseeded on every batch.
+            stable_random_seed = int(self.count + self.num_batches * self.current_epoch)
+            rng = np.random.RandomState(stable_random_seed)
+
+            # sample a dataset according to sample_weights
+            iter_index = rng.choice(range(self.num_datasets), p=self.sample_weights)
+
+            # generate training image-text pairs from the sampled dataset.
+            try:
+                data_iter = self.dataiters[iter_index]
+                batch = next(data_iter)
+            except StopIteration:
+                # refresh dataiter if dataloader is used up.
+                self.dataiters[iter_index] = iter(self.dataloaders[iter_index])
+                data_iter = self.dataiters[iter_index]
+                batch = next(data_iter)
+
+            self.count += 1
+
+            yield batch
+
+
 @dataclass
 class DataInfo:
     dataloader: DataLoader
@@ -184,8 +265,11 @@ def tarfile_to_samples_nothrow(src, handler=log_and_continue):
 _SAMPLE_SHUFFLE_INITIAL = 1000
 
 
-def get_wds_dataset(args, preprocess_img, is_train, epoch=0):
-    input_shards = args.train_data if is_train else args.val_data
+def get_wds_dataset(args, preprocess_img, is_train, epoch=0, filepath=None):
+    if filepath:
+        input_shards = filepath
+    else:
+        input_shards = args.train_data if is_train else args.val_data
     assert input_shards is not None
 
     num_samples, num_shards = get_dataset_size(input_shards)
@@ -280,8 +364,11 @@ def get_wds_dataset(args, preprocess_img, is_train, epoch=0):
     return DataInfo(dataloader, None)
 
 
-def get_csv_dataset(args, preprocess_fn, is_train, epoch=0):
-    input_filename = args.train_data if is_train else args.val_data
+def get_csv_dataset(args, preprocess_fn, is_train, epoch=0, filepath=None):
+    if filepath:
+        input_filename = filepath
+    else:
+        input_filename = args.train_data if is_train else args.val_data
     assert input_filename
     dataset = CsvDataset(
         input_filename,
@@ -331,8 +418,12 @@ def get_data(args, preprocess_fns, epoch=0):
     data = {}
 
     if args.train_data:
-        data["train"] = get_dataset_fn(args.train_data, args.dataset_type)(
-            args, preprocess_train, is_train=True, epoch=epoch)
+        if args.debias_sample:
+            dataloader = DebiasDataLoader(args, preprocess_train, epoch)
+            data['train'] = DataInfo(dataloader, None)
+        else:
+            data["train"] = get_dataset_fn(args.train_data, args.dataset_type)(
+                args, preprocess_train, is_train=True, epoch=epoch)
 
     if args.val_data:
         data["val"] = get_dataset_fn(args.val_data, args.dataset_type)(
diff --git a/src/training/params.py b/src/training/params.py
index ef2b0990a..3638a4168 100644
--- a/src/training/params.py
+++ b/src/training/params.py
@@ -205,6 +205,12 @@ def parse_args():
         action='store_true',
         help="Force use of QuickGELU activation for non-OpenAI transformer models.",
     )
+    parser.add_argument(
+        "--debias-sample",
+        default=False,
+        action='store_true',
+        help="Enable debias sampling.",
+    )
     parser.add_argument(
         "--torchscript",
         default=False,