diff --git a/tests/shared.py b/tests/shared.py index 091be0b1..75938f18 100644 --- a/tests/shared.py +++ b/tests/shared.py @@ -100,7 +100,7 @@ def create_train_fixtures(model="open_lm_11m", fsdp=False): args = MockTrainArgs(model) # only want to look at one batch - args.train_num_samples = args.batch_size + args.train_num_samples = args.global_batch_size # increase learning rate and remove warmup for maximize change to model weights args.lr = 1e-3