Skip to content

Commit

Permalink
update
Browse files (browse the repository at this point in the history)
  • Loading branch information
Suchin Gururangan committed Dec 15, 2023
1 parent 39124e7 commit 6196cbc
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
1 change: 1 addition & 0 deletions .gitignore
Original file line number | Diff line number | Diff line change
Expand Up @@ -140,3 +140,4 @@ out*
tests/assets/*
.vscode/
checkpoints/
experiments/
3 changes: 2 additions & 1 deletion open_lm/main.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -27,7 +27,6 @@
StateDictType,
CPUOffload,
)

from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

from open_lm.data import proc_token
Expand Down Expand Up @@ -153,6 +152,7 @@ def load_data_chunks(args):
)
return [0 for _ in range(len(args.dataset_manifest))], 0


def save_checkpoint(
args,
model,
Expand Down Expand Up @@ -473,6 +473,7 @@ def main(args):
if args.ddp_static_graph:
# this doesn't exist in older PyTorch, arg only added if enabled
ddp_args["static_graph"] = True
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], **ddp_args)

if args.grad_checkpointing:
model.set_grad_checkpointing()
Expand Down

0 comments on commit 6196cbc

Please sign in to comment.