Disable meta device by default (#165)
* Disable meta device by default

* Call reset_parameters for no meta device case
achalddave authored Dec 18, 2023
1 parent 1398898 commit 6570b81
Showing 3 changed files with 10 additions and 5 deletions.
7 changes: 5 additions & 2 deletions open_lm/main.py
@@ -258,6 +258,9 @@ def check_args(args):
             f"Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const, const-cooldown."
         )
 
+    if args.experimental_meta_device:
+        print("WARNING: Meta device initialization requested, but this is not currently fully tested.")
+
 
 def main(args):
     args = parse_args(args)
@@ -420,8 +423,8 @@ def main(args):
     if args.hf_model is not None:
         model = create_wrapped_hf_model(args)
     else:
-        # Use meta device when FSDP is provided, unless user explicitly requests not to.
-        with torch.device("meta" if args.fsdp and not args.disable_meta_device else args.device):
+        # Optional: Use meta device
+        with torch.device("meta" if args.experimental_meta_device and args.fsdp else args.device):
             model = create_model(args)
 
     args.vocab_size = model.vocab_size
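For context on the hunk above: constructing a model inside a torch.device("meta") context allocates parameter tensors with shapes and dtypes but no backing storage, which makes building very large models essentially free, at the cost that every tensor must later be materialized and re-initialized on a real device before use. The following is a minimal, standalone sketch of that lifecycle using plain PyTorch APIs; it is illustrative only and not code from this repository.

import torch
import torch.nn as nn

# Build a toy module under the meta device: parameters get shapes and
# dtypes but no real storage, so construction is cheap even at scale.
with torch.device("meta"):
    model = nn.Linear(4096, 4096)
print(model.weight.device)  # meta

# Meta tensors hold no data, so the module must be materialized on a
# real device and re-initialized before it can be used. to_empty()
# allocates uninitialized storage; reset_parameters() then fills it in,
# which is why modules on the meta path need a working reset_parameters().
model = model.to_empty(device="cpu")
model.reset_parameters()
print(model.weight.device)  # cpu

When FSDP is in play it can perform the materialization step shard by shard (e.g. via its param_init_fn hook), which is presumably why the meta path here is only taken when args.fsdp is set.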
4 changes: 2 additions & 2 deletions open_lm/params.py
@@ -606,10 +606,10 @@ def parse_args(args):
         help="If true, ignore parse errors in data loading. This should ideally be False, as errors in dataloading can point to bigger issues in your dataset. However, this can be useful when training on a large dataset which has a couple errors.",
     )
     parser.add_argument(
-        "--disable-meta-device",
+        "--experimental-meta-device",
         action="store_true",
         default=False,
-        help="If True, initialize the model on CPU instead of on meta device. This can be useful for debugging or for new models which do not support the meta device.",
+        help="If True, initialize the model on meta device. This can be useful for loading large models, but is not currently fully tested.",
     )
 
     add_model_args(parser)
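Note that the renamed flag also flips the default behavior: meta-device initialization is now opt-in rather than enabled whenever FSDP is used. A run that wants the previous behavior would pass both --fsdp and --experimental-meta-device (the exact launcher invocation depends on the rest of your setup); runs that pass only --fsdp, or neither flag, now initialize the model directly on args.device.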
4 changes: 3 additions & 1 deletion open_lm/positional_embedding/rotary.py
@@ -54,10 +54,12 @@ def __init__(self, dim_model: int, seq_len: int, *_, **__):
         self._cos_cached = None
         self._sin_cached = None
         self._seq_len_cached = 0
-        self._update_cos_sin_tables(seq_len)
+        self.seq_len = seq_len
+        self.reset_parameters()
 
     def reset_parameters(self):
         self.inv_freq = 1.0 / (10000 ** (torch.arange(0, self.dim_model, 2).float() / self.dim_model))
+        self._update_cos_sin_tables(self.seq_len)
 
     def _update_cos_sin_tables(self, seq_len: int = None, device: torch.device = None, dtype: torch.dtype = None):
         # If no seq_len is provided, use the cached one
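The rotary change follows the same pattern as the rest of the commit: when the module is constructed on the meta device, the cos/sin caches built in __init__ would hold no real data, so the cache update moves into reset_parameters(), which can be called again after materialization on a real device; calling reset_parameters() from __init__ keeps the non-meta path behaving as before (the second bullet of the commit message). As a rough, standalone sketch of the quantities reset_parameters() rebuilds (illustrative only, not the repository's exact _update_cos_sin_tables implementation):

import torch

# Hypothetical standalone version of the rotary quantities rebuilt by
# reset_parameters(): inverse frequencies plus cached cos/sin tables.
dim_model, seq_len = 64, 16

# inv_freq[i] = 1 / 10000**(2i / d): one frequency per pair of dimensions.
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_model, 2).float() / dim_model))

# Rotation angles are the outer product of positions and frequencies;
# the cos/sin caches are later used to rotate queries and keys.
positions = torch.arange(seq_len).float()
angles = torch.outer(positions, inv_freq)  # (seq_len, dim_model // 2)
cos_cached, sin_cached = angles.cos(), angles.sin()
print(cos_cached.shape)  # torch.Size([16, 32])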
