From 4f42503d5e00d6483f32ef1a90dc23881a1d8b42 Mon Sep 17 00:00:00 2001 From: Achal Dave Date: Fri, 15 Dec 2023 14:59:24 -0800 Subject: [PATCH] Only use meta device if explicitly requested --- open_lm/main.py | 11 ++++++----- open_lm/params.py | 6 ++++++ 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/open_lm/main.py b/open_lm/main.py index 44e64b7a..22666960 100644 --- a/open_lm/main.py +++ b/open_lm/main.py @@ -249,10 +249,13 @@ def check_args(args): raise ValueError("Sync protocol not supported when using resume latest.") if args.target_mask_left is not None and args.target_mask_individual == args.target_mask_left: - ValueError(f"--target-mask-left and --target-mask-individual set to same value of {args.target_mask_left}.") + raise ValueError(f"--target-mask-left and --target-mask-individual set to same value of {args.target_mask_left}.") if args.lr_scheduler != "cosine": - ValueError(f"Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const, const-cooldown.") + raise ValueError(f"Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const, const-cooldown.") + + if args.init_meta_device and not args.fsdp: + raise ValueError("--init-meta-device can only be specified if --fsdp is specified.") def main(args): @@ -399,10 +402,8 @@ def main(args): if args.hf_model is not None: model = create_wrapped_hf_model(args) else: - with torch.device("meta" if args.fsdp else args.device): + with torch.device("meta" if args.init_meta_device else args.device): model = create_model(args) - if not args.fsdp: - model.reset_parameters() args.vocab_size = model.vocab_size args.seq_len = model.seq_len diff --git a/open_lm/params.py b/open_lm/params.py index 4920cbb7..5f957bb9 100644 --- a/open_lm/params.py +++ b/open_lm/params.py @@ -561,6 +561,12 @@ def parse_args(args): default=False, help="If true, ignore parse errors in data loading. This should ideally be False, as errors in dataloading can point to bigger issues in your dataset. However, this can be useful when training on a large dataset which has a couple errors.", ) + parser.add_argument( + "--init-meta-device", + action="store_true", + default=False, + help="If true, initialize the model on the meta device. This allows creating models larger than CPU memory. Can only be specified if --fsdp is also specified." + ) add_model_args(parser)