Support a sampling strategy for multiple training datasets #107

Open · wants to merge 2 commits into base: main
src/training/data.py (86 changes: 80 additions & 6 deletions)
@@ -45,6 +45,70 @@ def __getitem__(self, idx):
        return images, texts


class DebiasDataLoader:
    def __init__(self, args, preprocess_train, epoch):
        # --train-data holds several dataset paths separated by ", "; build one DataInfo per path.
        train_data_list = args.train_data.split(', ')
        data_train = []
        for train_data in train_data_list:
            data_train.append(get_dataset_fn(train_data, args.dataset_type)(
                args, preprocess_train, is_train=True, epoch=epoch, filepath=train_data))

        self.args = args
        self.num_datasets = len(data_train)
        self.dataloaders = [dataset.dataloader for dataset in data_train]
        self.dataiters = [iter(dataloader) for dataloader in self.dataloaders]
        self.samplers = [dataset.sampler for dataset in data_train]

        self.num_batches = sum([dataloader.num_batches for dataloader in self.dataloaders])
        self.num_samples = sum([dataloader.num_samples for dataloader in self.dataloaders])

        # per-dataset sampling weights, proportional to each dataset's number of samples
        self.sample_weights = np.array([float(dataloader.num_samples) / self.num_samples for dataloader in self.dataloaders])

        self.count = 0
        self.current_epoch = 0
        # initialize each sampler (a sampler can be None, e.g. for webdataset pipelines)
        if self.args.distributed:
            for sampler in self.samplers:
                if sampler is not None:
                    sampler.set_epoch(0)

    def __len__(self):
        return self.num_batches

    def __iter__(self):
        while True:
            if self.count == self.num_batches:
                self.current_epoch += 1
                self.count = 0
                if self.args.distributed:
                    for sampler in self.samplers:
                        if sampler is not None:
                            sampler.set_epoch(self.current_epoch)
                return  # end of epoch

            # derive the seed from the global step so every worker samples the same dataset.
            stable_random_seed = int(self.count + self.num_batches * self.current_epoch)
            np.random.seed(stable_random_seed)

            # sample a dataset index according to sample_weights
            iter_index = np.random.choice(range(self.num_datasets), p=self.sample_weights)

            # draw a batch of image-text pairs from the sampled dataset.
            try:
                data_iter = self.dataiters[iter_index]
                batch = next(data_iter)
            except StopIteration:
                # restart the iterator once this dataloader is exhausted.
                self.dataiters[iter_index] = iter(self.dataloaders[iter_index])
                data_iter = self.dataiters[iter_index]
                batch = next(data_iter)

            self.count += 1

            yield batch


@dataclass
class DataInfo:
    dataloader: DataLoader
@@ -184,8 +248,11 @@ def tarfile_to_samples_nothrow(src, handler=log_and_continue):
_SAMPLE_SHUFFLE_INITIAL = 1000


-def get_wds_dataset(args, preprocess_img, is_train, epoch=0):
-    input_shards = args.train_data if is_train else args.val_data
+def get_wds_dataset(args, preprocess_img, is_train, epoch=0, filepath=None):
+    if filepath:
+        input_shards = filepath
+    else:
+        input_shards = args.train_data if is_train else args.val_data
    assert input_shards is not None

    num_samples, num_shards = get_dataset_size(input_shards)
@@ -280,8 +347,11 @@ def get_wds_dataset(args, preprocess_img, is_train, epoch=0):
    return DataInfo(dataloader, None)


-def get_csv_dataset(args, preprocess_fn, is_train, epoch=0):
-    input_filename = args.train_data if is_train else args.val_data
+def get_csv_dataset(args, preprocess_fn, is_train, epoch=0, filepath=None):
+    if filepath:
+        input_filename = filepath
+    else:
+        input_filename = args.train_data if is_train else args.val_data
    assert input_filename
    dataset = CsvDataset(
        input_filename,
@@ -331,8 +401,12 @@ def get_data(args, preprocess_fns, epoch=0):
    data = {}

    if args.train_data:
-        data["train"] = get_dataset_fn(args.train_data, args.dataset_type)(
-            args, preprocess_train, is_train=True, epoch=epoch)
+        if args.debias_sample:
+            dataloader = DebiasDataLoader(args, preprocess_train, epoch)
+            data['train'] = DataInfo(dataloader, None)
+        else:
+            data["train"] = get_dataset_fn(args.train_data, args.dataset_type)(
+                args, preprocess_train, is_train=True, epoch=epoch)

    if args.val_data:
        data["val"] = get_dataset_fn(args.val_data, args.dataset_type)(
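To make the sampling rule concrete, here is a minimal, self-contained sketch (not part of the diff) of how DebiasDataLoader picks a dataset for each batch: the per-step seed makes the choice reproducible across workers, and the weights make larger datasets proportionally more likely. The two dataset sizes below are made up for illustration; only numpy is assumed.

```python
import numpy as np

# Hypothetical dataset sizes; in the loader these come from dataloader.num_samples.
num_samples_per_dataset = [400_000, 100_000]
total = sum(num_samples_per_dataset)
sample_weights = np.array([n / total for n in num_samples_per_dataset])  # [0.8, 0.2]

num_batches, current_epoch = 1000, 0
for count in range(5):
    # Seed from the global step so every distributed worker draws the same index.
    np.random.seed(int(count + num_batches * current_epoch))
    iter_index = np.random.choice(range(len(sample_weights)), p=sample_weights)
    print(f"step {count}: draw batch from dataset {iter_index}")
```

Because the seed depends only on the step counter, all ranks agree on which dataset a given step uses, while the long-run mix of batches follows the 0.8/0.2 split.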
src/training/params.py (6 changes: 6 additions & 0 deletions)
@@ -205,6 +205,12 @@ def parse_args():
        action='store_true',
        help="Force use of QuickGELU activation for non-OpenAI transformer models.",
    )
+    parser.add_argument(
+        "--debias-sample",
+        default=False,
+        action='store_true',
+        help="Sample each batch from one of multiple training datasets with probability proportional to dataset size.",
+    )
    parser.add_argument(
        "--torchscript",
        default=False,
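For reference, a small usage sketch (not part of the diff) of how the new flag is expected to combine with a comma-separated --train-data value. It mirrors the argparse wiring above using plain argparse; the dataset paths are hypothetical.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--train-data", type=str)
parser.add_argument("--debias-sample", default=False, action='store_true')

args = parser.parse_args([
    "--train-data", "/data/cc3m/train.csv, /data/yfcc/train.csv",
    "--debias-sample",
])

print(args.debias_sample)           # True -> get_data() builds a DebiasDataLoader
# DebiasDataLoader splits on ", " (comma followed by a space), so paths must not contain that substring.
print(args.train_data.split(', '))  # ['/data/cc3m/train.csv', '/data/yfcc/train.csv']
```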