Skip to content

Commit

Permalink
Update vqa_fine_tune.py
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewwangva authored Feb 23, 2023
1 parent 008bfc5 commit cd009bd
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions src/training/vqa_fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,14 @@ def get_task_dataloaders(path, transforms, labelencoder, args):
#dataset_df = dataset_df[0:12800]
b_size = args.batch_size
if(split == "validation"):
b_size = args.batch_size
dataset = VQATextDataset(dataset_df,
split,
transforms,
labelencoder,
tokenizer=tokenizer,
)
b_size = args.batch_size * 20
dataset_df = dataset_df[0:12800]
dataset = VQATextDataset(dataset_df,
split,
transforms,
labelencoder,
tokenizer=tokenizer,
)
dataloader = DataLoader(
dataset,
batch_size=b_size,
Expand Down Expand Up @@ -222,7 +223,7 @@ def parse_args(args):
"--workers", type=int, default=2, help="Number of dataloader workers per GPU."
)
parser.add_argument(
"--batch-size", type=int, default=256, help="Batch size per GPU."
"--batch-size", type=int, default=128, help="Batch size per GPU."
)
parser.add_argument(
"--epochs", type=int, default=10, help="Number of epochs to train for."
Expand Down

0 comments on commit cd009bd

Please sign in to comment.