Commit
add param compatibility check
IgorVasiljevic-TRI committed Dec 13, 2023
1 parent 81aeb87 commit aaa5d44
Showing 1 changed file with 23 additions and 13 deletions.
36 changes: 23 additions & 13 deletions open_lm/main.py
@@ -227,9 +227,31 @@ def save_checkpoint(
os.remove(prev)


def check_args(args):
    resume_latest = args.resume == "latest"
    args.log_path = None

    if args.hf_model is not None and args.hf_seq_len is None:
        raise ValueError("If passing --hf-model, must also pass --hf-seq-len to be used for training/fine-tuning.")

    if args.hf_model is not None and args.fsdp and args.hf_fsdp_block is None:
        raise ValueError("If passing --hf-model and --fsdp, must also pass --hf-fsdp-block.")

    if resume_latest:
        # If using remote_sync, need to check the remote instead of the local checkpoints folder.
        if args.remote_sync is not None:
            if args.save_most_recent:
                raise ValueError("Cannot use save-most-recent with remote_sync and resume latest.")
            if args.remote_sync_protocol != "s3":
                raise ValueError("Sync protocol not supported when using resume latest.")


def main(args):
    args = parse_args(args)

    # Check the arg list for any incompatibilities.
    check_args(args)

    requires_training = args.train_data or args.dataset_type == "synthetic" or args.dataset_manifest is not None

    if torch.cuda.is_available():
@@ -243,12 +265,6 @@ def main(args):
    # fully initialize distributed device environment
    device = init_distributed_device(args)

    if args.hf_model is not None and args.hf_seq_len is None:
        raise ValueError("If passing --hf-model, must also pass --hf-seq-len to be used for training/fine-tuning.")

    if args.hf_model is not None and args.fsdp and args.hf_fsdp_block is None:
        raise ValueError("If passing --hf-model and --fsdp, must also pass --hf-fsdp-block.")

    # get the name of the experiments
    if args.name is None:
        # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule?
@@ -303,13 +319,7 @@ def main(args):
    if resume_latest:
        resume_from = None
        checkpoint_path = args.checkpoint_path
        # If using remote_sync, need to check the remote instead of the local checkpoints folder.
        if args.remote_sync is not None:
            checkpoint_path = os.path.join(args.remote_sync, args.name)
            if args.save_most_recent:
                raise ValueError("Cannot use save-most-recent with remote_sync and resume latest.")
            if args.remote_sync_protocol != "s3":
                raise ValueError("Sync protocol not supported when using resume latest.")

        if is_master(args):
            # Checking for existing checkpoint via master rank only. It is possible for
            # different rank processes to see different files if a shared file-system is under
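For context, a minimal sketch (not part of the commit) of how the new check_args helper could be exercised on its own. The attribute names mirror the flags referenced in the diff; the import path, the SimpleNamespace usage, and the argument values are assumptions for illustration.

# Hypothetical standalone use of check_args, assuming open_lm is installed
# and that check_args is defined at module level in open_lm/main.py as shown above.
from types import SimpleNamespace

from open_lm.main import check_args

args = SimpleNamespace(
    resume=None,
    hf_model="some-hf-model",   # hypothetical --hf-model value
    hf_seq_len=None,            # deliberately omitted to trigger the check
    fsdp=False,
    hf_fsdp_block=None,
    remote_sync=None,
    save_most_recent=False,
    remote_sync_protocol="s3",
)

try:
    check_args(args)
except ValueError as err:
    # Expected: the --hf-model / --hf-seq-len incompatibility error from the diff.
    print(f"Incompatible arguments: {err}")

Running the checks up front in main, before any distributed setup, means incompatible flag combinations fail fast on every rank rather than partway through checkpoint resolution.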
