Description
I noticed that the orig_size in each batch is simply set to (resolution, resolution) in the training code.
Does this cause a bug when the shape of the input data is not exactly (resolution, resolution)?
class CustomImageDataset(Dataset):
    def __init__(self, img_dir, sample_size):
        """
        Args:
            img_dir (string): Directory with all the images and text files.
            sample_size (tuple): Desired sample size as (height, width).
        """
        self.img_dir = img_dir
        self.sample_size = sample_size

    ...

        return (
            image,
            text,
            (self.sample_size, self.sample_size),
            (c_top, c_left),
        )
train_dataset = CustomImageDataset("/mnt/data/wangfuyun/cc3m", args.resolution)
for epoch in range(first_epoch, args.num_train_epochs):
    for step, batch in enumerate(train_dataloader):
        with accelerator.accumulate(unet, discriminator):
            image, text, orig_size, crop_coords = batch
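
To make the question concrete, here is a minimal sketch of what I would have expected the dataset to return instead. It assumes the preprocessing resizes the shorter side to `resolution` and then center-crops a square, which may not match what the repo actually does. The helper name `size_conditioning` is hypothetical and not part of the training code.

```python
def size_conditioning(orig_h, orig_w, resolution):
    """Return (orig_size, crop_top_left, target_size) for one image.

    Assumption: the image is resized so its shorter side equals
    `resolution`, then center-cropped to (resolution, resolution).
    """
    short = min(orig_h, orig_w)
    # Size after the aspect-preserving resize.
    resized_h = max(resolution, round(orig_h * resolution / short))
    resized_w = max(resolution, round(orig_w * resolution / short))
    # Top-left corner of the center crop, in resized-image coordinates.
    c_top = (resized_h - resolution) // 2
    c_left = (resized_w - resolution) // 2
    # orig_size is the true size on disk, not (resolution, resolution).
    return (orig_h, orig_w), (c_top, c_left), (resolution, resolution)


# A 480x640 image trained at resolution 512.
orig_size, crop_coords, target_size = size_conditioning(480, 640, 512)
```

With this, `orig_size` differs from `target_size` whenever the source image is not already `resolution` x `resolution`, whereas the current code reports the same value for both.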