Skip to content

Commit

Permalink
Allow disabling meta device (#155)
Browse files Browse the repository at this point in the history
* Only use meta device if explicitly requested

* Lint fixes

* Use meta device by default but allow disabling it

* Fix one pytest
  • Loading branch information
achalddave authored Dec 16, 2023
1 parent 6db1660 commit b65ad6b
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
15 changes: 10 additions & 5 deletions open_lm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,10 +249,14 @@ def check_args(args):
raise ValueError("Sync protocol not supported when using resume latest.")

     if args.target_mask_left is not None and args.target_mask_individual == args.target_mask_left:
-        ValueError(f"--target-mask-left and --target-mask-individual set to same value of {args.target_mask_left}.")
+        raise ValueError(
+            f"--target-mask-left and --target-mask-individual set to same value of {args.target_mask_left}."
+        )

     if args.lr_scheduler != "cosine":
-        ValueError(f"Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const, const-cooldown.")
+        raise ValueError(
+            f"Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const, const-cooldown."
+        )


def main(args):
Expand All @@ -273,6 +277,8 @@ def main(args):

     # fully initialize distributed device environment
     device = init_distributed_device(args)
+    if args.fsdp and not args.distributed:
+        raise ValueError(f"--fsdp can only be specified in distributed mode.")

# get the name of the experiments
if args.name is None:
Expand Down Expand Up @@ -399,10 +405,9 @@ def main(args):
     if args.hf_model is not None:
         model = create_wrapped_hf_model(args)
     else:
-        with torch.device("meta" if args.fsdp else args.device):
+        # Use meta device when FSDP is provided, unless user explicitly requests not to.
+        with torch.device("meta" if args.fsdp and not args.disable_meta_device else args.device):
             model = create_model(args)
-        if not args.fsdp:
-            model.reset_parameters()

args.vocab_size = model.vocab_size
args.seq_len = model.seq_len
Expand Down
6 changes: 6 additions & 0 deletions open_lm/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,12 @@ def parse_args(args):
         default=False,
         help="If true, ignore parse errors in data loading. This should ideally be False, as errors in dataloading can point to bigger issues in your dataset. However, this can be useful when training on a large dataset which has a couple errors.",
     )
+    parser.add_argument(
+        "--disable-meta-device",
+        action="store_true",
+        default=False,
+        help="If True, initialize the model on CPU instead of on meta device. This can be useful for debugging or for new models which do not support the meta device.",
+    )

add_model_args(parser)

Expand Down

0 comments on commit b65ad6b

Please sign in to comment.