diff --git a/open_lm/main.py b/open_lm/main.py index c3968fc9..4779d20b 100644 --- a/open_lm/main.py +++ b/open_lm/main.py @@ -109,7 +109,9 @@ def load_model(args, model): global_step = checkpoint.get("step", None) if next(iter(sd.items()))[0].startswith("module"): sd = {k[len("module.") :]: v for k, v in sd.items()} - if args.distributed: + if args.fsdp: + model.load_state_dict(sd) + elif args.distributed: model.module.load_state_dict(sd) else: model.load_state_dict(sd) diff --git a/setup.py b/setup.py index a6de7b58..d2a07311 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ def _read_reqs(relpath): setuptools.setup( name="open_lm", - version="0.0.24", + version="0.0.25", author=[ "Suchin Gururangan*", "Mitchell Wortsman*",