
Issue loading FSDP-wrapped module using FULL_STATE_DICT type #141

@hbikki

Description


🐛 Describe the bug

Hello, I am training the pretrained Hugging Face model "t5-small" with FSDP. Using the torchsnapshot examples provided in the documentation, I can save and load checkpoints with the LOCAL_STATE_DICT type, and I can also save a checkpoint with the FULL_STATE_DICT type. However, loading the FULL_STATE_DICT checkpoint fails with the error below.
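The two state dict types mentioned above differ in what each rank holds. The sketch below is a simplified pure-Python model, not FSDP's actual implementation: it assumes one FSDP unit whose parameters are concatenated into a flat buffer, padded, and split evenly across ranks. A LOCAL_STATE_DICT roughly corresponds to one rank's shard, while a FULL_STATE_DICT holds every parameter unsharded under its original name. Real FSDP works on tensors, tracks shape metadata, and its exact key names depend on the PyTorch version.

```python
from typing import Dict, List


def flatten_and_shard(params: Dict[str, List[float]], world_size: int) -> List[List[float]]:
    """Simplified model: concatenate one unit's params, pad, split evenly across ranks.

    Each returned shard is roughly what one rank holds under LOCAL_STATE_DICT.
    """
    flat = [v for values in params.values() for v in values]
    padding = (-len(flat)) % world_size  # pad so every rank gets an equal-sized shard
    flat += [0.0] * padding
    shard_size = len(flat) // world_size
    return [flat[r * shard_size:(r + 1) * shard_size] for r in range(world_size)]


def unshard(shards: List[List[float]], sizes: Dict[str, int]) -> Dict[str, List[float]]:
    """Simplified model: gather shards back into named, unsharded parameters.

    The result is roughly what a FULL_STATE_DICT contains.
    """
    flat = [v for shard in shards for v in shard]
    full: Dict[str, List[float]] = {}
    offset = 0
    for name, n in sizes.items():
        full[name] = flat[offset:offset + n]
        offset += n
    return full  # trailing padding is dropped


params = {"weight": [1.0, 2.0, 3.0, 4.0, 5.0], "bias": [0.5, 0.25]}
shards = flatten_and_shard(params, world_size=4)
print(shards)  # [[1.0, 2.0], [3.0, 4.0], [5.0, 0.5], [0.25, 0.0]]
full = unshard(shards, {name: len(v) for name, v in params.items()})
print(full == params)  # True
```

In this toy model, a checkpoint saved from shards only restores correctly into shards of the same layout, while a full checkpoint must be matched against full, per-parameter shapes.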

Versions:
pytorch = 2.0.0+cu117
torchx-nightly>=2023.3.15
torchsnapshot=0.1.0

Host Details:
The training below was run on a single node with 8 processes (NPROC_PER_NODE=8).

Code:

Model training code:

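import functools
import os

import torch
from torch.distributed import init_process_group
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.t5.modeling_t5 import T5Block

def local_rank() -> int:
    # Assumed definition (not shown in the original): torchrun sets LOCAL_RANK per worker
    return int(os.environ["LOCAL_RANK"])

def train() -> None:
    init_process_group(backend="nccl")
    torch.cuda.empty_cache()
    torch.cuda.set_device(local_rank())
    model = load_model("t5-small")  # helper from the original script (not shown)

    # Make each T5Block its own FSDP unit; HYBRID_SHARD shards parameters
    # within a node and replicates them across nodes.
    fsdp_model = FSDP(
        model,
        auto_wrap_policy=functools.partial(
            transformer_auto_wrap_policy, transformer_layer_cls={T5Block}
        ),
        sharding_strategy=ShardingStrategy.HYBRID_SHARD,
        device_id=local_rank(),
    )
    # ... training loop ...
    # ... save_checkpoint() ...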
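Because the model is wrapped with ShardingStrategy.HYBRID_SHARD, it may help to see which ranks share shards in this setup. HYBRID_SHARD shards parameters within a node and replicates them across nodes. The sketch below is a pure-Python model of that default per-node grouping, not a call into torch.distributed, and it assumes global rank = node_index * procs_per_node + local_rank, as torchrun normally assigns. On the single 8-process node used here, every replicate group has one member, so the layout is effectively full sharding across the 8 GPUs.

```python
from typing import List, Tuple


def hybrid_shard_groups(num_nodes: int, procs_per_node: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Return (shard_groups, replicate_groups) for a HYBRID_SHARD-style layout.

    Assumes global rank = node_index * procs_per_node + local_rank.
    """
    # Ranks on the same node split the parameters among themselves.
    shard_groups = [
        list(range(node * procs_per_node, (node + 1) * procs_per_node))
        for node in range(num_nodes)
    ]
    # Ranks with the same local rank on different nodes hold identical shards.
    replicate_groups = [
        [node * procs_per_node + local for node in range(num_nodes)]
        for local in range(procs_per_node)
    ]
    return shard_groups, replicate_groups


shard, replicate = hybrid_shard_groups(num_nodes=1, procs_per_node=8)
print(shard)      # [[0, 1, 2, 3, 4, 5, 6, 7]]
print(replicate)  # [[0], [1], [2], [3], [4], [5], [6], [7]]
```

For comparison, two nodes with 4 processes each would give shard groups [[0, 1, 2, 3], [4, 5, 6, 7]] and replicate groups [[0, 4], [1, 5], [2, 6], [3, 7]].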

Here, self.stateDictType = StateDictType.FULL_STATE_DICT.
Related saving/loading code:

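from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torchsnapshot import Snapshot


class Trainer:  # hypothetical name; the original class is not shown
    # checkpoint, app_state, save_dir and load_dir come from the original script.
    def save_checkpoint(self) -> None:
        with FSDP.state_dict_type(checkpoint.model, self.stateDictType):
            Snapshot.take(path=str(save_dir), app_state=app_state)

    def load_checkpoint(self) -> None:
        with FSDP.state_dict_type(checkpoint.model, self.stateDictType):
            Snapshot(path=str(load_dir)).restore(app_state=app_state)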

Error stack trace:
https://pastebin.com/ih9qSbwR

The .snapshot_metadata file for the model on the local rank:
https://pastebin.com/t6grkKyX
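One way to narrow down a restore failure like this is to compare the parameter names and shapes recorded in the snapshot against those the model produces under the same state dict type. The helper below is hypothetical and not part of torchsnapshot. It assumes both sides have already been reduced to {name: shape} dicts, for example {k: tuple(v.shape) for k, v in model.state_dict().items()} inside the FSDP.state_dict_type(...) context for the model side, and the entries listed in the snapshot metadata for the saved side. The names and shapes in the example are toy values, not taken from this issue.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def diff_state_dict_shapes(saved: Dict[str, Shape], expected: Dict[str, Shape]) -> Dict[str, List]:
    """Hypothetical helper: report keys missing from, extra in, or shaped differently in a checkpoint."""
    missing = sorted(k for k in expected if k not in saved)
    unexpected = sorted(k for k in saved if k not in expected)
    mismatched = sorted(
        (k, saved[k], expected[k])
        for k in saved.keys() & expected.keys()
        if saved[k] != expected[k]
    )
    return {"missing": missing, "unexpected": unexpected, "mismatched": mismatched}


# Toy values: "layer.bias" was saved as a 2-element shard but the model expects 4 elements.
saved = {"layer.weight": (4, 4), "layer.bias": (2,)}
expected = {"layer.weight": (4, 4), "layer.bias": (4,), "head.weight": (4, 2)}
report = diff_state_dict_shapes(saved, expected)
print(report)
# {'missing': ['head.weight'], 'unexpected': [], 'mismatched': [('layer.bias', (2,), (4,))]}
```

A shape mismatch where the saved size is a fraction of the expected size would suggest sharded tensors being restored into full ones (or the reverse); this is a debugging aid, not a confirmed diagnosis of this issue.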

Does anyone know how to resolve this? Thanks!
