diff --git a/open_lm/main.py b/open_lm/main.py index 22666960..44e64b7a 100644 --- a/open_lm/main.py +++ b/open_lm/main.py @@ -249,13 +249,10 @@ def check_args(args): raise ValueError("Sync protocol not supported when using resume latest.") if args.target_mask_left is not None and args.target_mask_individual == args.target_mask_left: - raise ValueError(f"--target-mask-left and --target-mask-individual set to same value of {args.target_mask_left}.") + ValueError(f"--target-mask-left and --target-mask-individual set to same value of {args.target_mask_left}.") if args.lr_scheduler != "cosine": - raise ValueError(f"Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const, const-cooldown.") - - if args.init_meta_device and not args.fsdp: - raise ValueError("--init-meta-device can only be specified if --fsdp is specified.") + ValueError(f"Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const, const-cooldown.") def main(args): @@ -402,8 +399,10 @@ def main(args): if args.hf_model is not None: model = create_wrapped_hf_model(args) else: - with torch.device("meta" if args.init_meta_device else args.device): + with torch.device("meta" if args.fsdp else args.device): model = create_model(args) + if not args.fsdp: + model.reset_parameters() args.vocab_size = model.vocab_size args.seq_len = model.seq_len diff --git a/open_lm/params.py b/open_lm/params.py index 5f957bb9..4920cbb7 100644 --- a/open_lm/params.py +++ b/open_lm/params.py @@ -561,12 +561,6 @@ def parse_args(args): default=False, help="If true, ignore parse errors in data loading. This should ideally be False, as errors in dataloading can point to bigger issues in your dataset. However, this can be useful when training on a large dataset which has a couple errors.", ) - parser.add_argument( - "--init-meta-device", - action="store_true", - default=False, - help="If true, initialize the model on the meta device. This allows creating models larger than CPU memory. Can only be specified if --fsdp is also specified." - ) add_model_args(parser)