[FSDP2/Megatron-FSDP/DCP] If model parameters are DTensors, optimizer states should also be DTensors. #2795
Conversation
Greptile Summary

This PR fixes a bug where … Confidence Score: 5/5 — Safe to merge. The bug is well-understood, the fix is targeted, and the cross-topology resharding test validates the happy path. All remaining findings are P2 style suggestions. The previously discussed shape-mismatch concern (`param.shape` vs `data.shape` for uint8 states) is correctly fixed using `data.shape`. The DTensor parity checks in `load_state_dict` and `set_scaled_state` are now consistent. The only open item is the use of `assert` for runtime validation, which is a quality suggestion and not a correctness blocker. `transformer_engine/pytorch/optimizers/fused_adam.py` — specifically the assert-based DTensor parity checks in `step()` at lines 627–638.

Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[initialize_state / step entry] --> B{param is DTensor?}
B -- Yes --> C[Extract local_param = param._local_tensor]
B -- No --> D[Use param directly]
C --> E[Create local state tensor via torch.empty_like]
D --> E
E --> F{dtype == uint8 FP8?}
F -- Yes --> G[quantizer.make_empty using data.shape\nquantize_ local data]
F -- No --> H[Plain tensor state]
G --> I{param is DTensor?}
H --> I
I -- Yes --> J[DTensor.from_local wraps local state\nusing param.device_mesh + placements]
I -- No --> K[Store plain tensor in self.state]
J --> K
K --> L[step: unwrap DTensor to _local_tensor\nfor CUDA kernel compute]
L --> M[state_dict: get_unscaled_state to float32 local\nre-wrap as DTensor for DCP]
M --> N[load_state_dict: unpack DTensor to local\nset_scaled_state restores quantised form\nre-wrap as DTensor]
```
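The wrap/unwrap cycle in the flowchart can be sketched with a minimal stand-in for `DTensor` (the hypothetical `FakeDTensor` class below is illustration only; the real code uses `torch.distributed.tensor.DTensor` and CUDA tensors):

```python
# Minimal stand-in for torch.distributed.tensor.DTensor, used only to
# illustrate the wrap/unwrap pattern -- not the real API surface.
class FakeDTensor:
    def __init__(self, local_tensor, device_mesh, placements):
        self._local_tensor = local_tensor
        self.device_mesh = device_mesh
        self.placements = placements

    @staticmethod
    def from_local(local_tensor, device_mesh, placements):
        return FakeDTensor(local_tensor, device_mesh, placements)


def initialize_state(param):
    """Create an optimizer-state tensor with the same distribution as param."""
    if isinstance(param, FakeDTensor):
        # Allocate state for the local shard only...
        local_state = [0.0] * len(param._local_tensor)
        # ...then re-wrap so checkpointing sees a sharded (DTensor) state.
        return FakeDTensor.from_local(local_state, param.device_mesh, param.placements)
    return [0.0] * len(param)


def step_view(state):
    """Unwrap to the local shard for kernel compute in step()."""
    return state._local_tensor if isinstance(state, FakeDTensor) else state


param = FakeDTensor([1.0, 2.0], device_mesh="mesh0", placements=("Shard(0)",))
state = initialize_state(param)
print(isinstance(state, FakeDTensor))  # True: state parity with param
print(step_view(state))                # [0.0, 0.0]
```

The key invariant is the last branch of the flowchart: whatever the state looks like locally, it is stored (and checkpointed) in the same distributed form as the parameter it belongs to.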
@cspades could you please elaborate on the downstream error/issue caused? As in, what happens if we load the unsharded tensor for the optimizer state as a plain tensor instead of a DTensor?
Here is how I understand it — @shjwudp, correct me if I am wrong about the Megatron-FSDP details, as I still need to reproduce the bug and ensure this PR fixes it. I believe a customer reported this bug?
The fix is to keep the optimizer state in DTensor form whenever the parameter itself is a DTensor.
pstjohn left a comment
TL;DR: it would fail if you train on 4 ranks and load on 2 ranks; this PR adds a test for that (among other issues with mFSDP).
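The cross-topology failure mode is easiest to see in plain Python: a checkpoint saved as 4 shards must be re-chunkable when loaded on 2 ranks, which only works if the checkpointer knows the state is sharded. A toy resharding round-trip (no real DCP involved; `shard`/`unshard` are illustrative stand-ins for `Shard(0)` placement):

```python
def shard(tensor, world_size):
    """Split a flat list into world_size contiguous shards (toy Shard(0))."""
    n = len(tensor) // world_size
    return [tensor[i * n:(i + 1) * n] for i in range(world_size)]


def unshard(shards):
    """Reassemble the global tensor from its contiguous shards."""
    return [x for s in shards for x in s]


full = list(range(8))
saved = shard(full, 4)               # checkpoint written on 4 ranks
reloaded = shard(unshard(saved), 2)  # reload on 2 ranks: re-chunk globally
assert unshard(reloaded) == full
print(reloaded)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

If the saved state were a plain tensor instead, DCP would treat each rank's copy as the full global tensor and the 4-rank-to-2-rank mapping above would have no shard metadata to work from.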
/te-ci L1 pytorch
https://gitlab-master.nvidia.com/dl/transformerengine/transformerengine/-/pipelines/47362013 — some random error; rebased and rerunning CI.
Commits
- …sor. Signed-off-by: Cory Ye <cye@nvidia.com>
- for more information, see https://pre-commit.ci
- Add Greptile bug-fixes. Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Cory Ye <44509866+cspades@users.noreply.github.com>
- … re-sharding test. Signed-off-by: Cory Ye <cye@nvidia.com>
- … as those tests need to be run in sequence. Signed-off-by: Cory Ye <cye@nvidia.com>
- Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
/te-ci L1 pytorch -> https://gitlab-master.nvidia.com/dl/transformerengine/transformerengine/-/pipelines/47678738 (Fixed misplaced filters.)
LGTM. CI has passed as well now. This fix is urgent for 2.14: if MLM upgrades to 2.14, DCP checkpoint + load functionality won't work for models trained with MFSDP2. Thanks for catching and fixing the bug, @cspades.
Description
Fixes the DTensor(QuantizedTensor) (FSDP2-only) use case introduced in "Add fused_adam, quantized_model_init, and fsdp2 example" #2698, where FusedAdam's optimizer state is converted into a non-distributed Tensor, which is then loaded as a global state dictionary by Torch DCP. This breaks Megatron-FSDP checkpointing with DCP in MLM, and causes FSDP2 to not have a distributed optimizer state. (See the re-shard tests.)
Details
dcp_resharding_save must be run with, and before, dcp_resharding_load. dcp_resharding_save deletes existing checkpoint directories and writes a new DCP checkpoint, while dcp_resharding_load finally deletes the saved DCP checkpoint as well.

Testing
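The save-before-load ordering can be sketched as a pair of helpers sharing a checkpoint directory (hypothetical names mirroring the test structure; the real tests use Torch DCP rather than JSON):

```python
import json
import os
import shutil
import tempfile


def dcp_resharding_save_sketch(ckpt_dir, state):
    # Delete any stale checkpoint directory, then write a fresh one.
    shutil.rmtree(ckpt_dir, ignore_errors=True)
    os.makedirs(ckpt_dir)
    with open(os.path.join(ckpt_dir, "state.json"), "w") as f:
        json.dump(state, f)


def dcp_resharding_load_sketch(ckpt_dir):
    with open(os.path.join(ckpt_dir, "state.json")) as f:
        state = json.load(f)
    # The load test cleans up the saved checkpoint when it is done.
    shutil.rmtree(ckpt_dir)
    return state


d = os.path.join(tempfile.gettempdir(), "dcp_resharding_sketch")
dcp_resharding_save_sketch(d, {"step": 1})
restored = dcp_resharding_load_sketch(d)
print(restored)  # {'step': 1}
assert not os.path.exists(d)
```

Because the load step destroys the checkpoint, running the tests out of order (or load alone) fails, which is why they must run in sequence.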
--use-precision-aware-optimizer

Type of change
Checklist: