add parameter compatibility check (#153)
* add param compatibility check

* moving logging.error to check_args

* format

* convert logging.error to ValueError

---------

Co-authored-by: Achal Dave <[email protected]>
IgorVasiljevic-TRI and achalddave authored Dec 15, 2023
1 parent 813d501 commit b2e8762
Showing 1 changed file with 36 additions and 31 deletions.
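
The substance of the change: argument validation moves out of main() into a dedicated check_args() function that runs before any setup, and the failure mode changes from logging.error() followed by exit(1) to a raised ValueError. A minimal sketch of the before/after pattern (simplified from the diff below, not the complete open_lm code):

    import logging

    # Before: the error is only logged and the whole process exits, so callers
    # and tests cannot observe the failure as an exception.
    def check_scheduler_old(lr_scheduler):
        if lr_scheduler != "cosine":
            logging.error(f"Unknown scheduler, {lr_scheduler}.")
            exit(1)

    # After: a ValueError propagates to the caller, can be asserted in tests,
    # and carries its message in the traceback.
    def check_scheduler_new(lr_scheduler):
        if lr_scheduler != "cosine":
            raise ValueError(f"Unknown scheduler, {lr_scheduler}.")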
open_lm/main.py (36 additions, 31 deletions)
@@ -231,9 +231,36 @@ def save_checkpoint(
             os.remove(prev)
 
 
+def check_args(args):
+    resume_latest = args.resume == "latest"
+
+    if args.hf_model is not None and args.hf_seq_len is None:
+        raise ValueError("If passing --hf-model, must also pass --hf-seq-len to be used for training/fine-tuning.")
+
+    if args.hf_model is not None and args.fsdp and args.hf_fsdp_block is None:
+        raise ValueError("If passing --hf-model and --fsdp, must also pass --hf-fsdp-block.")
+
+    if resume_latest:
+        # If using remote_sync, need to check the remote instead of the local checkpoints folder.
+        if args.remote_sync is not None:
+            if args.save_most_recent:
+                raise ValueError("Cannot use save-most-recent with remote_sync and resume latest.")
+            if args.remote_sync_protocol != "s3":
+                raise ValueError("Sync protocol not supported when using resume latest.")
+
+    if args.target_mask_left is not None and args.target_mask_individual == args.target_mask_left:
+        raise ValueError(f"--target-mask-left and --target-mask-individual set to same value of {args.target_mask_left}.")
+
+    if args.lr_scheduler != "cosine":
+        raise ValueError(f"Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const, const-cooldown.")
+
+
 def main(args):
     args = parse_args(args)
 
+    # Check the arg list for any incompatibilities.
+    check_args(args)
+
     requires_training = args.train_data or args.dataset_type == "synthetic" or args.dataset_manifest is not None
 
     if torch.cuda.is_available():
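
As a quick usage sketch (hypothetical, not part of the commit), the new check_args fails fast before any distributed or CUDA setup happens:

    from types import SimpleNamespace

    # Hypothetical arg combination: --hf-model without --hf-seq-len.
    args = SimpleNamespace(resume=None, hf_model="some/hf-model", hf_seq_len=None)
    try:
        check_args(args)
    except ValueError as e:
        print(e)  # "If passing --hf-model, must also pass --hf-seq-len ..."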
@@ -247,12 +274,6 @@ def main(args):
     # fully initialize distributed device environment
     device = init_distributed_device(args)
 
-    if args.hf_model is not None and args.hf_seq_len is None:
-        raise ValueError("If passing --hf-model, must also pass --hf-seq-len to be used for training/fine-tuning.")
-
-    if args.hf_model is not None and args.fsdp and args.hf_fsdp_block is None:
-        raise ValueError("If passing --hf-model and --fsdp, must also pass --hf-fspd-block.")
-
     # get the name of the experiments
     if args.name is None:
         # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule?
@@ -307,13 +328,7 @@ def main(args):
     if resume_latest:
         resume_from = None
         checkpoint_path = args.checkpoint_path
         # If using remote_sync, need to check the remote instead of the local checkpoints folder.
         if args.remote_sync is not None:
             checkpoint_path = os.path.join(args.remote_sync, args.name)
-            if args.save_most_recent:
-                raise ValueError("Cannot use save-most-recent with remote_sync and resume latest.")
-            if args.remote_sync_protocol != "s3":
-                raise ValueError("Sync protocol not supported when using resume latest.")
 
         if is_master(args):
             # Checking for existing checkpoint via master rank only. It is possible for
             # different rank processes to see different files if a shared file-system is under
@@ -571,10 +586,6 @@ def main(args):
                 floor=args.dataset_manifest is not None,
             )
 
-    if args.target_mask_left is not None and args.target_mask_individual == args.target_mask_left:
-        logging.error(f"--target-mask-left and --target-mask-individual set to same value of {args.target_mask_left}.")
-        exit(1)
-
     if args.target_mask_left is not None:
         # tokens handled with same modulo in dataloading
         args.target_mask_left = proc_token(args.target_mask_left, args.vocab_size)
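
The comment in that context line refers to the dataloader's modulo handling of special token ids. A rough illustration of the idea (proc_token_sketch is a hypothetical stand-in; open_lm's actual proc_token may differ):

    # Hypothetical: fold an out-of-range sentinel token id into the valid
    # vocabulary range, mirroring the modulo applied during dataloading.
    def proc_token_sketch(token_id: int, vocab_size: int) -> int:
        return token_id % vocab_size

    proc_token_sketch(-100, 50432)  # -> 50332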
@@ -595,20 +606,14 @@ def main(args):
     else:
         total_steps = (data["train"].dataloader.num_batches) * args.epochs
 
-    if args.lr_scheduler == "cosine":
-        scheduler = cosine_lr(
-            optimizer,
-            args.lr,
-            args.warmup,
-            total_steps,
-            args.lr_cooldown_end,
-            args.force_min_lr,
-        )
-    else:
-        logging.error(
-            f"Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const, const-cooldown."
-        )
-        exit(1)
+    scheduler = cosine_lr(
+        optimizer,
+        args.lr,
+        args.warmup,
+        total_steps,
+        args.lr_cooldown_end,
+        args.force_min_lr,
+    )
 
     # determine if this worker should save logs and checkpoints. only do so if it is rank == 0
     args.save_logs = args.logs and args.logs.lower() != "none" and is_master(args)
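
With the scheduler check moved into check_args, main() can now call cosine_lr unconditionally. For intuition, here is a self-contained sketch of what a warmup-plus-cosine schedule computes per step (illustrative only; parameter names mirror the call above, but this is not open_lm's cosine_lr implementation):

    import math

    def cosine_lr_value(step, base_lr, warmup, total_steps, end_lr=0.0, force_min_lr=0.0):
        if step < warmup:
            # Linear warmup from ~0 up to base_lr over `warmup` steps.
            lr = base_lr * (step + 1) / warmup
        else:
            # Cosine decay from base_lr down to end_lr over the remaining steps.
            progress = (step - warmup) / max(1, total_steps - warmup)
            lr = end_lr + 0.5 * (base_lr - end_lr) * (1.0 + math.cos(math.pi * progress))
        return max(lr, force_min_lr)  # never drop below the forced minimum

    cosine_lr_value(0, 3e-4, 2000, 100000)  # 1.5e-07, first warmup step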
