SFT example fails on saving checkpoint after evaluation #759

@dsawczuk-int

Description

🐛 Describe the bug

When running the Qwen3 SFT example on more than one GPU, saving a checkpoint fails if evaluation runs during the same step. Error:

  File "/root/miniforge3/envs/forge/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict_saver.py", line 383, in _stateful_to_state_dict
    elem.state_dict() if isinstance(elem, Stateful) else elem
    ^^^^^^^^^^^^^^^^^
  File "/root/miniforge3/envs/forge/lib/python3.12/site-packages/torchtitan/components/optimizer.py", line 105, in state_dict
    for sd in map(func, self.model_parts, self.optimizers)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniforge3/envs/forge/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict.py", line 1106, in get_optimizer_state_dict
    optim_state_dict = _get_optim_state_dict(model, optimizers, info)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniforge3/envs/forge/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniforge3/envs/forge/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict.py", line 818, in _get_optim_state_dict
    fqn = fqn_pid_mapping[key]
          ~~~~~~~~~~~~~~~^^^^^
KeyError: 397
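
The failing lookup is `fqn_pid_mapping[key]`, so the optimizer apparently holds state for a parameter index (397) that has no mapped name. One possible way to narrow this down is to check, right before the save, whether every parameter object held by the optimizer is still one of the objects the model exposes. The sketch below is only a diagnostic idea: the helper name `find_unmapped_optimizer_params` is hypothetical, and it assumes the flat index across param groups corresponds to integer keys such as 397.

```python
def find_unmapped_optimizer_params(named_parameters, param_groups):
    """Return flat optimizer parameter indices that the model does not expose.

    named_parameters: iterable of (name, param) pairs,
        e.g. model.named_parameters()
    param_groups: list of dicts with a "params" list,
        e.g. optimizer.param_groups
    """
    # Compare by object identity: a parameter that was swapped for a new
    # object no longer matches the one the optimizer still references.
    model_param_ids = {id(param) for _, param in named_parameters}

    unmapped = []
    index = 0  # flat index across all groups (assumed to match the state keys)
    for group in param_groups:
        for param in group["params"]:
            if id(param) not in model_param_ids:
                unmapped.append(index)
            index += 1
    return unmapped
```

In a real run this could be called with `model.named_parameters()` and `optimizer.param_groups` on each rank, once before and once after evaluation. A non-empty result after evaluation would be consistent with the `KeyError`, though it would not by itself identify the cause.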

Tested on 4xA100.
Changes in qwen3_8b.yaml config:

diff --git a/apps/sft/qwen3_8b.yaml b/apps/sft/qwen3_8b.yaml
index a8d2244..dcf9f1d 100644
--- a/apps/sft/qwen3_8b.yaml
+++ b/apps/sft/qwen3_8b.yaml
@@ -14,7 +14,7 @@ model:
   hf_assets_path: hf://${model_name}
 
 processes:
-  procs: 8
+  procs: 4
   with_gpus: true
 
 optimizer:
@@ -57,7 +57,7 @@ checkpoint:
   initial_load_path: hf://${model_name}  # The path to load the initial checkpoint from. Ignored if `folder` exists.
   initial_load_in_hf: true          # If true, interpret initial_load_path as a HuggingFace model repo
   last_save_in_hf: true
-  interval: 500
+  interval: 50
   async_mode: "disabled"

CMD: python -m apps.sft.main --config apps/sft/qwen3_8b.yaml

If I modify the example to save the checkpoint before evaluation runs, it seems to work fine, which suggests that evaluation alters the optimizer state.
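
The reordering can be sketched roughly as follows. This is not the actual torchforge training loop: `run_periodic_hooks`, `save_checkpoint`, `evaluate`, and the interval parameters are hypothetical names used only to show the order of operations.

```python
def run_periodic_hooks(step, checkpoint_interval, eval_interval,
                       save_checkpoint, evaluate):
    """Run the per-step hooks, saving the checkpoint before evaluating.

    Returns the names of the hooks that ran, in order.
    """
    ran = []
    # Save first, so the optimizer state is captured before evaluation
    # has had any chance to touch the model or optimizer.
    if step % checkpoint_interval == 0:
        save_checkpoint(step)
        ran.append("checkpoint")
    if step % eval_interval == 0:
        evaluate(step)
        ran.append("evaluate")
    return ran
```

This only sidesteps the failure on steps where both hooks fire. A checkpoint saved on a later step, after an earlier evaluation, may still need checking.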

Versions

Torchforge: a111042
Other dependencies as installed by build script.
