[FSDP2/Megatron-FSDP/DCP] If model parameters are DTensors, optimizer states should also be DTensors. #2795

Merged
vthumbe1503 merged 17 commits into NVIDIA:main from cspades:cye/fused-adam-dcp
Apr 4, 2026
Conversation


@cspades cspades commented Mar 24, 2026

Description

Without this fix, Megatron-FSDP checkpointing with DCP breaks in MLM:

[rank0]:   File "/opt/megatron-lm/megatron/training/checkpointing.py", line 1015, in preprocess_fsdp_dtensor_state_dict
[rank0]:     model_state_dict, optimizer_state_dict = handle_swiglu_in_state_dict(
[rank0]:                                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/megatron-lm/megatron/core/transformer/fsdp_dtensor_checkpoint.py", line 318, in handle_swiglu_in_state_dict
[rank0]:     weight_w, weight_v = split_swiglu_linear_fc1(
[rank0]:                          ^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/megatron-lm/megatron/core/transformer/fsdp_dtensor_checkpoint.py", line 246, in split_swiglu_linear_fc1
[rank0]:     local_tensor = data.to_local()
[rank0]:                    ^^^^^^^^^^^^^
[rank0]: AttributeError: 'Tensor' object has no attribute 'to_local'

and it causes FSDP2 to lack a distributed (sharded) optimizer state. (See the re-sharding tests.)

Details

  • We wrap each optimizer state as a DTensor matching the distribution characteristics (device mesh and placements) of the DTensor parameter it is associated with. The state is therefore always a DTensor, while compute and optimizer steps are applied to the local Tensor. (There are very few line changes if you ignore variable renaming.)
  • Test Structure
    • dcp_resharding_save must be run before dcp_resharding_load.
    • dcp_resharding_save deletes any existing checkpoint directory and writes a new DCP checkpoint; dcp_resharding_load then deletes the saved DCP checkpoint after loading it.

Testing

  • TE CI/CD
TE_PATH=/workspace/TransformerEngine ./qa/L1_pytorch_distributed_unittest/test.sh

OR

pytest -v -s tests/pytorch/distributed/test_torch_fsdp2.py::test_fsdp2_fused_adam_dcp_resharding
  • Megatron-LM + --use-precision-aware-optimizer
# TE@00ba0b493c27f32e2f210b0022132c50da78dac7 (Llama 8B + Precision-Aware Optimizer + FP8Blockwise + TP2 + GB300)
[2026-03-25 15:18:07.588704] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 9614.0 | throughput per GPU (TFLOP/s/GPU): 1403.6 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 1.131176E+00 | loss scale: 1.0 | grad norm: 5.337 | number of skipped iterations:   0 | number of nan iterations:   0 |

# This PR (Llama 8B Precision-Aware Optimizer + FP8Blockwise + TP2 + GB300)
[2026-03-25 14:58:55.856189] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 9588.0 | throughput per GPU (TFLOP/s/GPU): 1407.4 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 1.131045E+00 | loss scale: 1.0 | grad norm: 5.336 | number of skipped iterations:   0 | number of nan iterations:   0 |
  • This PR fixes the Megatron-FSDP uneven-DTensor preprocessing error, and training works fine:
# Megatron-FSDP + self.use_precision_aware_optimizer=True + --use-precision-aware-optimizer + BF16 + HFSDP
# And: --save-interval 1 and --ckpt-format fsdp_dtensor to reproduce the checkpointing error quickly.
[2026-03-31 08:37:02.253124] iteration        3/15258789 | consumed samples:          384 | elapsed time per iteration (ms): 18363.2 | throughput per GPU (TFLOP/s/GPU): 734.8 | learning rate: 1.474559E-08 | global batch size:   128 | lm loss: 1.213362E+01 | loss scale: 1.0 | grad norm: 0.000 | num zeros: 0 | number of skipped iterations:   0 | number of nan iterations:   0 |
  • To reproduce the FSDP2 error motivating this PR, use the broken FusedAdam code before this PR/commit and run this test that saves a checkpoint with 4 GPUs and loads the saved checkpoint with 2 GPUs:
torchrun --nproc-per-node 4 -m pytest tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py -v -s -k "dcp_resharding_save" && torchrun --nproc-per-node 2 -m pytest tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py -v -s -k "dcp_resharding_load"

E               raise ValueError(
E           ValueError: Size mismatch between saved torch.Size([64]) and current: torch.Size([128]) for optimizer.state.0.exp_avg
E           Traceback (most recent call last): (RANK 1)
E             File "/usr/local/lib/python3.12/dist-packages/torch/distributed/checkpoint/utils.py", line 193, in reduce_scatter
E               local_data = map_fun()
E                            ^^^^^^^^^
E             File "/usr/local/lib/python3.12/dist-packages/torch/distributed/checkpoint/logger.py", line 90, in wrapper
E               result = func(*args, **kwargs)
E                        ^^^^^^^^^^^^^^^^^^^^^
E             File "/usr/local/lib/python3.12/dist-packages/torch/distributed/checkpoint/state_dict_loader.py", line 269, in local_step
E               local_plan = planner.create_local_plan()
E                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
E             File "/usr/local/lib/python3.12/dist-packages/torch/distributed/checkpoint/default_planner.py", line 352, in create_local_plan
E               return create_default_local_load_plan(
E                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E             File "/usr/local/lib/python3.12/dist-packages/torch/distributed/checkpoint/default_planner.py", line 485, in create_default_local_load_plan
E               raise ValueError(
E           ValueError: Size mismatch between saved torch.Size([64]) and current: torch.Size([128]) for optimizer.state.0.exp_avg

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@cspades cspades marked this pull request as ready for review March 24, 2026 17:42

cspades commented Mar 24, 2026


greptile-apps bot commented Mar 24, 2026

Greptile Summary

This PR fixes a bug where FusedAdam optimizer states (exp_avg, exp_avg_sq, master_param) were stored as plain local tensors even when the associated parameter was a DTensor, causing Torch DCP to interpret them as global (non-sharded) state and producing shape-mismatch errors during checkpoint resharding. The fix wraps each optimizer state in DTensor.from_local after initialization, unwraps to the local tensor only for actual compute in step(), and adds DTensor-parity assertions throughout.

Confidence Score: 5/5

Safe to merge — bug is well-understood, fix is targeted, and cross-topology resharding test validates the happy path.

All remaining findings are P2 style suggestions. The previously discussed shape-mismatch concern (param.shape vs data.shape for uint8 states) is correctly fixed using data.shape. The DTensor parity checks in load_state_dict and set_scaled_state are now consistent. The only open item is the use of assert for runtime validation, which is a quality suggestion and not a correctness blocker.

transformer_engine/pytorch/optimizers/fused_adam.py — specifically the assert-based DTensor parity checks in step() at lines 627–638.

Important Files Changed

Filename Overview
transformer_engine/pytorch/optimizers/fused_adam.py Core fix: _initialize_state now wraps optimizer states as DTensors; step/state_dict/load_state_dict updated accordingly. Validation assertions use `assert`, which is silently disabled under `python -O`.
transformer_engine/pytorch/tensor/float8_tensor.py Extends slice.Tensor/select.int dispatch to propagate the transpose cache slice alongside the data slice, using the same t_dim formula already proven in split.Tensor.
tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py Adds dcp_resharding_save and dcp_resharding_load test phases with proper rank-0-only cleanup, barrier synchronisation, and reference-output verification.
tests/pytorch/distributed/test_torch_fsdp2.py Adds test_fsdp2_fused_adam_dcp_resharding orchestrator that runs the save (4 ranks) then load (2 ranks) phases sequentially; correctly excluded from test_fsdp2_fused_adam_tests via -k filter.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[initialize_state / step entry] --> B{param is DTensor?}
    B -- Yes --> C[Extract local_param = param._local_tensor]
    B -- No --> D[Use param directly]
    C --> E[Create local state tensor via torch.empty_like]
    D --> E
    E --> F{dtype == uint8 FP8?}
    F -- Yes --> G[quantizer.make_empty using data.shape\nquantize_ local data]
    F -- No --> H[Plain tensor state]
    G --> I{param is DTensor?}
    H --> I
    I -- Yes --> J[DTensor.from_local wraps local state\nusing param.device_mesh + placements]
    I -- No --> K[Store plain tensor in self.state]
    J --> K
    K --> L[step: unwrap DTensor to _local_tensor\nfor CUDA kernel compute]
    L --> M[state_dict: get_unscaled_state to float32 local\nre-wrap as DTensor for DCP]
    M --> N[load_state_dict: unpack DTensor to local\nset_scaled_state restores quantised form\nre-wrap as DTensor]



@vthumbe1503
Collaborator

@cspades could you please elaborate on the downstream error/issue caused? That is, what happens if we load the unsharded tensor for the optimizer state as a plain tensor instead of a DTensor?


cspades commented Mar 24, 2026

@cspades could you please elaborate on the downstream error/issue caused? That is, what happens if we load the unsharded tensor for the optimizer state as a plain tensor instead of a DTensor?

Here is how I understand it, @shjwudp correct me if I am wrong about the Megatron-FSDP details, as I still need to reproduce the bug and ensure this PR fixes it. I believe a customer reported this bug?

  • Add fused_adam, quantized_model_init, and fsdp2 example #2698 introduced logic in FusedAdam.__init__ such that if the TE model parameters are DTensors, the optimizer state is converted to a normal Tensor.
    • The reason (per the in-line commentary) is that empty_like does not pick up the correct dtype from a DTensor when the local data is a QuantizedTensor. Note that Megatron-FSDP's main weights are FP32, not QuantizedTensor, so our code worked with the original FusedAdam.
  • When Megatron-FSDP (or Megatron-LM's distributed optimizer) performs its first optimizer.step(), Megatron-FSDP exposes FP32 DTensor main weights to the FusedAdam optimizer, and because of the above logic, normal Tensor optimizer states are constructed from the DTensor main weights.
  • Megatron-FSDP depends on DTensor optimizer states for DCP checkpointing of FusedAdam's state, because we employ uneven sharding. Instead, it now sees normal Tensors, and this may break our DCP integration and/or uneven-DTensor metadata.

The fix is to keep the optimizer state in DTensor form whenever the model parameters are DTensors, and to perform all FusedAdam operations in place on (or localized to) the local Tensor.


cspades commented Mar 25, 2026


@pstjohn pstjohn left a comment


TL;DR: it would fail if you train on 4 ranks and load on 2 ranks; this PR adds a test for that.

(among other issues with mFSDP)


cspades commented Mar 25, 2026

/te-ci L1 pytorch

1 similar comment
@vthumbe1503
Collaborator

/te-ci L1 pytorch


cspades commented Mar 31, 2026

https://gitlab-master.nvidia.com/dl/transformerengine/transformerengine/-/pipelines/47362013

Some random error:

ERROR: file or directory not found: /opt/pytorch/lightning-thunder/thunder/tests/test_transformer_engine_executor.py

Rebased and rerunning CI.

cspades and others added 10 commits March 31, 2026 09:00
…sor.

Signed-off-by: Cory Ye <cye@nvidia.com>
Signed-off-by: Cory Ye <cye@nvidia.com>
Add Greptile bug-fixes.

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Cory Ye <44509866+cspades@users.noreply.github.com>
Signed-off-by: Cory Ye <cye@nvidia.com>
Signed-off-by: Cory Ye <cye@nvidia.com>
… re-sharding test.

Signed-off-by: Cory Ye <cye@nvidia.com>
Signed-off-by: Cory Ye <cye@nvidia.com>
@cspades cspades force-pushed the cye/fused-adam-dcp branch from 1cf3948 to f095376 Compare March 31, 2026 16:07

cspades commented Mar 31, 2026

… as those tests need to be run in sequence.

Signed-off-by: Cory Ye <cye@nvidia.com>
@cspades cspades force-pushed the cye/fused-adam-dcp branch from e3ae4ea to 0164aaa Compare March 31, 2026 19:59

cspades commented Mar 31, 2026

@cspades cspades changed the title If model parameters are DTensors, optimizer states should also be DTensors. [FSDP2/Megatron-FSDP/DCP] If model parameters are DTensors, optimizer states should also be DTensors. Mar 31, 2026
@vthumbe1503 vthumbe1503 added 2.14.0 bug Something isn't working labels Apr 3, 2026
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

cspades commented Apr 3, 2026

https://gitlab-master.nvidia.com/dl/transformerengine/transformerengine/-/pipelines/47670017 🙏🏻


cspades commented Apr 4, 2026

/te-ci L1 pytorch -> https://gitlab-master.nvidia.com/dl/transformerengine/transformerengine/-/pipelines/47678738 (Fixed misplaced filters.)


@vthumbe1503 vthumbe1503 left a comment


LGTM. CI has passed as well now. This fix is urgent for 2.14: if MLM upgrades to 2.14, DCP checkpoint save/load functionality won't work for models trained with MFSDP2. Thanks for catching and fixing the bug @cspades

@vthumbe1503 vthumbe1503 merged commit 5abadf4 into NVIDIA:main Apr 4, 2026
10 of 12 checks passed