Skip to content

Commit

Permalink
Only use meta device if explicitly requested
Browse files Browse the repository at this point in the history
  • Loading branch information
achalddave committed Dec 15, 2023
1 parent b2e8762 commit 4f42503
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
11 changes: 6 additions & 5 deletions open_lm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,10 +249,13 @@ def check_args(args):
raise ValueError("Sync protocol not supported when using resume latest.")

if args.target_mask_left is not None and args.target_mask_individual == args.target_mask_left:
ValueError(f"--target-mask-left and --target-mask-individual set to same value of {args.target_mask_left}.")
raise ValueError(f"--target-mask-left and --target-mask-individual set to same value of {args.target_mask_left}.")

if args.lr_scheduler != "cosine":
ValueError(f"Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const, const-cooldown.")
raise ValueError(f"Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const, const-cooldown.")

if args.init_meta_device and not args.fsdp:
raise ValueError("--init-meta-device can only be specified if --fsdp is specified.")


def main(args):
Expand Down Expand Up @@ -399,10 +402,8 @@ def main(args):
if args.hf_model is not None:
model = create_wrapped_hf_model(args)
else:
with torch.device("meta" if args.fsdp else args.device):
with torch.device("meta" if args.init_meta_device else args.device):
model = create_model(args)
if not args.fsdp:
model.reset_parameters()

args.vocab_size = model.vocab_size
args.seq_len = model.seq_len
Expand Down
6 changes: 6 additions & 0 deletions open_lm/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,12 @@ def parse_args(args):
default=False,
help="If true, ignore parse errors in data loading. This should ideally be False, as errors in dataloading can point to bigger issues in your dataset. However, this can be useful when training on a large dataset which has a couple errors.",
)
parser.add_argument(
"--init-meta-device",
action="store_true",
default=False,
help="If true, initialize the model on the meta device. This allows creating models larger than CPU memory. Can only be specified if --fsdp is also specified."
)

add_model_args(parser)

Expand Down

0 comments on commit 4f42503

Please sign in to comment.