Skip to content

Commit

Permalink
Change handling of accum_freq.
Browse files Browse the repository at this point in the history
  • Loading branch information
GeorgiosSmyrnis committed Dec 13, 2023
1 parent 214cdae commit a79032c
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
5 changes: 2 additions & 3 deletions open_lm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,12 +245,11 @@ def main(args):

assert (
args.global_batch_size % args.world_size == 0
), "Global batch size is not divisible by number of GPUs, and thus cannot be respected."
), f"Global batch size ({args.global_batch_size}) is not divisible by number of GPUs ({args.world_size}), and thus cannot be respected."

args.per_gpu_batch_size = args.global_batch_size // args.world_size
if args.val_data is not None:
# Make sure that val batch size is set to micro batch size
args.per_gpu_val_batch_size = args.global_val_batch_size // args.world_size // args.accum_freq
args.per_gpu_val_batch_size = args.global_val_batch_size // args.world_size

if args.hf_model is not None and args.hf_seq_len is None:
raise ValueError("If passing --hf-model, must also pass --hf-seq-len to be used for training/fine-tuning.")
Expand Down
5 changes: 3 additions & 2 deletions open_lm/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def parse_args(args):
"--accum-freq",
type=int,
default=1,
help="Update the model every --acum-freq steps.",
help="Update the model every --accum-freq steps.",
)
# arguments for distributed training
parser.add_argument(
Expand Down Expand Up @@ -575,6 +575,7 @@ def parse_args(args):
assert args.dataset_manifest is None, "--dataset-manifest must not be specified if --dataset-type='synthetic'"

if args.val_data is not None and args.global_val_batch_size is None:
args.global_val_batch_size = args.global_batch_size
# Make sure that val batch size is set to micro batch size
args.global_val_batch_size = args.global_batch_size // args.accum_freq

return args

0 comments on commit a79032c

Please sign in to comment.