Skip to content

Commit 5abadf4

Browse files
cspadespre-commit-ci[bot]greptile-apps[bot]vthumbe1503
authored
[FSDP2/Megatron-FSDP/DCP] If model parameters are DTensors, optimizer states should also be DTensors. (#2795)
* If model parameters are DTensors, optimizer state should also be DTensor. Signed-off-by: Cory Ye <cye@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Unpack DTensor in FusedAdam.step(). Signed-off-by: Cory Ye <cye@nvidia.com> * Apply suggestions from code review Add Greptile bug-fixes. Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Cory Ye <44509866+cspades@users.noreply.github.com> * Revert erroneous Greptile diff. Signed-off-by: Cory Ye <cye@nvidia.com> * Add DTensor parity check to FusedAdam.step(). Signed-off-by: Cory Ye <cye@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add DTensor handling in state_dict and load_state_dict, and add a DCP re-sharding test. Signed-off-by: Cory Ye <cye@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update test commentary. Signed-off-by: Cory Ye <cye@nvidia.com> * Filter out DCP resharding tests from the 2 GPU FusedAdam test matrix, as those tests need to be run in sequence. Signed-off-by: Cory Ye <cye@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix float8 Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * xfail block scaling Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * Fix rebase error, pytest filters were shoved into a different test. Signed-off-by: Cory Ye <cye@nvidia.com> --------- Signed-off-by: Cory Ye <cye@nvidia.com> Signed-off-by: Cory Ye <44509866+cspades@users.noreply.github.com> Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Co-authored-by: vthumbe1503 <vthumbe@nvidia.com>
1 parent e83c097 commit 5abadf4

File tree

4 files changed

+345
-28
lines changed

4 files changed

+345
-28
lines changed

tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py

Lines changed: 181 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,17 @@
1616
fused_adam_fp8_master_weights, fused_adam_fp8_master_weights_no_meta,
1717
fused_adam_bf16, fused_adam_fp8_no_master, fused_adam_bf16_store_param_remainders,
1818
fuse_wgrad_accumulation, dcp_output_parity, dcp_output_parity_async,
19-
safetensors_fp32_export
19+
dcp_resharding_save, dcp_resharding_load, safetensors_fp32_export
2020
2121
Available --recipe values:
2222
DelayedScaling, Float8CurrentScaling, Float8BlockScaling,
2323
MXFP8BlockScaling, NVFP4BlockScaling
24+
25+
Note: dcp_resharding_save and dcp_resharding_load are two phases of a single
26+
cross-topology test. Run dcp_resharding_save under a larger world_size first
27+
(e.g. --nproc_per_node=4), then run dcp_resharding_load under a smaller one
28+
(e.g. --nproc_per_node=2). The orchestration is handled automatically by
29+
test_fsdp2_fused_adam_dcp_resharding in test_torch_fsdp2.py.
2430
"""
2531

2632
import argparse
@@ -465,7 +471,8 @@ def test_safetensors_fp32_export(recipe_name):
465471
if recipe_name == "MXFP8BlockScaling":
466472
pytest.xfail(
467473
"MXFP8BlockScaling: FusedAdam CUDA kernel does not support "
468-
"MXFP8 quantized tensors, causing illegal memory access"
474+
"MXFP8 quantized tensors, causing illegal memory access. "
475+
"Fixed by https://github.com/NVIDIA/TransformerEngine/pull/2789."
469476
)
470477

471478
from safetensors.torch import load_file, save_file
@@ -554,7 +561,8 @@ def test_dcp_output_parity(recipe_name, async_save):
554561
"MXFP8BlockScaling: FusedAdam CUDA kernel does not support "
555562
"MXFP8 quantized tensors, causing illegal memory access: "
556563
"/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh:92 in function "
557-
"multi_tensor_apply: CUDA Error: an illegal memory access was encountered"
564+
"multi_tensor_apply: CUDA Error: an illegal memory access was encountered. "
565+
"Fixed by https://github.com/NVIDIA/TransformerEngine/pull/2789."
558566
)
559567

560568
if recipe_name == "NVFP4BlockScaling":
@@ -740,6 +748,173 @@ def test_dcp_output_parity(recipe_name, async_save):
740748
shutil.rmtree(checkpoint_dir, ignore_errors=True)
741749

742750

751+
def test_dcp_resharding_save(recipe_name):
752+
"""Phase 1 of the DCP resharding test: train with current world_size and save checkpoint.
753+
754+
Trains a model for NUM_STEPS, records the forward-pass output, and writes:
755+
- A DCP checkpoint to /tmp/te_test_fsdp2_dcp_resharding_<recipe>/
756+
- A reference output tensor to /tmp/te_test_fsdp2_dcp_resharding_<recipe>_ref.pt
757+
758+
These artifacts are consumed by test_dcp_resharding_load, which runs under
759+
a *different* world_size (typically half as many ranks) to verify that DCP
760+
correctly reshards the checkpoint into the new topology.
761+
762+
The two phases are orchestrated by test_fsdp2_fused_adam_dcp_resharding in
763+
test_torch_fsdp2.py using two sequential plain torchrun invocations.
764+
"""
765+
recipe = get_recipe_from_string(recipe_name)
766+
767+
import torch.distributed.checkpoint as dcp
768+
769+
world_size, device = _get_dist_info()
770+
rank = int(os.environ.get("RANK", "0"))
771+
checkpoint_dir = f"/tmp/te_test_fsdp2_dcp_resharding_{recipe_name}"
772+
ref_output_path = f"/tmp/te_test_fsdp2_dcp_resharding_{recipe_name}_ref.pt"
773+
774+
if rank == 0:
775+
shutil.rmtree(checkpoint_dir, ignore_errors=True)
776+
if os.path.exists(ref_output_path):
777+
os.remove(ref_output_path)
778+
dist.barrier()
779+
780+
model = _build_model(fp8_init=True, recipe=recipe)
781+
model = _shard_model(model, world_size)
782+
783+
optimizer = te.optimizers.FusedAdam(
784+
model.parameters(),
785+
lr=1e-3,
786+
master_weights=True,
787+
master_weight_dtype=torch.float32,
788+
)
789+
790+
# Fixed seed so the load phase reproduces the exact same input tensor.
791+
torch.manual_seed(12345)
792+
torch.cuda.manual_seed(12345)
793+
x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device)
794+
target = torch.randn_like(x)
795+
796+
for _ in range(NUM_STEPS):
797+
optimizer.zero_grad(set_to_none=True)
798+
with te.autocast(enabled=True, recipe=recipe):
799+
output = model(x)
800+
loss = F.mse_loss(output, target)
801+
loss.backward()
802+
optimizer.step()
803+
804+
# Record the reference output before saving.
805+
with torch.no_grad():
806+
with te.autocast(enabled=True, recipe=recipe):
807+
ref_output = model(x).clone().cpu()
808+
809+
dist.barrier()
810+
if rank == 0:
811+
torch.save(ref_output, ref_output_path)
812+
813+
if isinstance(recipe, transformer_engine.common.recipe.DelayedScaling):
814+
model_state = {
815+
k: v for k, v in model.state_dict().items() if not k.endswith("_extra_state")
816+
}
817+
else:
818+
model_state = model.state_dict()
819+
820+
dcp.save(
821+
{"model": model_state, "optimizer": optimizer.state_dict()}, checkpoint_id=checkpoint_dir
822+
)
823+
dist.barrier()
824+
825+
826+
def test_dcp_resharding_load(recipe_name):
827+
"""Phase 2 of the DCP resharding test: load into a different world_size and verify parity.
828+
829+
Loads the DCP checkpoint written by test_dcp_resharding_save (which ran
830+
under a larger world_size, e.g. 4 ranks) into a fresh model sharded over
831+
the current, smaller world_size (e.g. 2 ranks). Asserts that the model
832+
output after loading is bitwise-identical to the reference saved in phase 1,
833+
confirming that DCP resharding correctly reconstructs all parameter shards.
834+
"""
835+
recipe = get_recipe_from_string(recipe_name)
836+
837+
import torch.distributed.checkpoint as dcp
838+
839+
world_size, device = _get_dist_info()
840+
rank = int(os.environ.get("RANK", "0"))
841+
checkpoint_dir = f"/tmp/te_test_fsdp2_dcp_resharding_{recipe_name}"
842+
ref_output_path = f"/tmp/te_test_fsdp2_dcp_resharding_{recipe_name}_ref.pt"
843+
844+
try:
845+
model2 = _build_model(fp8_init=True, recipe=recipe)
846+
model2 = _shard_model(model2, world_size)
847+
848+
optimizer2 = te.optimizers.FusedAdam(
849+
model2.parameters(),
850+
lr=1e-3,
851+
master_weights=True,
852+
master_weight_dtype=torch.float32,
853+
)
854+
855+
# Same fixed seed as the save phase to reproduce identical x/target.
856+
torch.manual_seed(12345)
857+
torch.cuda.manual_seed(12345)
858+
x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=torch.bfloat16, device=device)
859+
target = torch.randn_like(x)
860+
861+
# Populate optimizer state so load_state_dict has a matching structure.
862+
optimizer2.zero_grad(set_to_none=True)
863+
with te.autocast(enabled=True, recipe=recipe):
864+
out_tmp = model2(x)
865+
F.mse_loss(out_tmp, target).backward()
866+
optimizer2.step()
867+
868+
if isinstance(recipe, transformer_engine.common.recipe.DelayedScaling):
869+
model2_state = {
870+
k: v for k, v in model2.state_dict().items() if not k.endswith("_extra_state")
871+
}
872+
else:
873+
model2_state = model2.state_dict()
874+
875+
state_to_load = {"model": model2_state, "optimizer": optimizer2.state_dict()}
876+
dcp.load(state_to_load, checkpoint_id=checkpoint_dir)
877+
model2.load_state_dict(
878+
state_to_load["model"],
879+
strict=(
880+
False
881+
if isinstance(recipe, transformer_engine.common.recipe.DelayedScaling)
882+
else True
883+
),
884+
)
885+
optimizer2.load_state_dict(state_to_load["optimizer"])
886+
887+
with torch.no_grad():
888+
with te.autocast(enabled=True, recipe=recipe):
889+
loaded_output = model2(x).cpu()
890+
891+
if rank == 0:
892+
ref_output = torch.load(ref_output_path, weights_only=True)
893+
894+
if isinstance(recipe, transformer_engine.common.recipe.DelayedScaling):
895+
torch.testing.assert_close(
896+
loaded_output,
897+
ref_output,
898+
rtol=0.05,
899+
atol=0.1,
900+
msg=lambda m: f"Resharded model output differs from reference: {m}",
901+
)
902+
else:
903+
torch.testing.assert_close(
904+
loaded_output,
905+
ref_output,
906+
rtol=0,
907+
atol=0,
908+
msg=lambda m: f"Resharded model output differs from reference: {m}",
909+
)
910+
finally:
911+
dist.barrier()
912+
if rank == 0:
913+
shutil.rmtree(checkpoint_dir, ignore_errors=True)
914+
if os.path.exists(ref_output_path):
915+
os.remove(ref_output_path)
916+
917+
743918
TESTS = {
744919
"fused_adam_fp8_master_weights": test_fused_adam_fp8_master_weights,
745920
"fused_adam_fp8_master_weights_no_meta": test_fused_adam_fp8_master_weights_no_meta,
@@ -749,13 +924,15 @@ def test_dcp_output_parity(recipe_name, async_save):
749924
"fuse_wgrad_accumulation": test_fuse_wgrad_accumulation,
750925
"dcp_output_parity": functools.partial(test_dcp_output_parity, async_save=False),
751926
"dcp_output_parity_async": functools.partial(test_dcp_output_parity, async_save=True),
927+
"dcp_resharding_save": test_dcp_resharding_save,
928+
"dcp_resharding_load": test_dcp_resharding_load,
752929
"safetensors_fp32_export": test_safetensors_fp32_export,
753930
}
754931

755932

756933
if __name__ == "__main__":
757934
parser = argparse.ArgumentParser()
758-
parser.add_argument("--test", required=True, choices=list(TESTS.keys()))
935+
parser.add_argument("--test", required=True, choices=sorted(TESTS.keys()))
759936
parser.add_argument(
760937
"--recipe",
761938
type=str,

tests/pytorch/distributed/test_torch_fsdp2.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import os
66
import sys
77
import subprocess
8+
import sys
89
from pathlib import Path
910

1011
sys.path.append(str(Path(__file__).resolve().parent.parent))
@@ -18,6 +19,12 @@
1819
NUM_PROCS: int = torch.cuda.device_count()
1920
_FSDP2_DIR = Path(__file__).parent.resolve() / "fsdp2_tests"
2021

22+
# Import some utilities from PyTest-owned conftest.py.
23+
sys.path.insert(0, str(_FSDP2_DIR))
24+
from conftest import _parametrize_recipes
25+
26+
sys.path.pop(0)
27+
2128

2229
@pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs")
2330
@pytest.mark.skipif(not te.torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+")
@@ -59,6 +66,10 @@ def test_fsdp2_fused_adam_tests():
5966
"-v",
6067
"-s",
6168
"--tb=short",
69+
# The following 2 tests need to be run in sequence,
70+
# as they depend on each other.
71+
"-k",
72+
"not dcp_resharding_save and not dcp_resharding_load",
6273
],
6374
valid_returncodes=(0, 5),
6475
env=os.environ,
@@ -90,6 +101,70 @@ def test_fsdp2_mem_leak_tests():
90101
assert result.returncode in (0, 5), f"Inner pytest failed with exit code {result.returncode}"
91102

92103

104+
@pytest.mark.skipif(NUM_PROCS < 4, reason="Requires 4+ GPUs for DP4→DP2 resharding test")
105+
@pytest.mark.skipif(not te.torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+")
106+
@pytest.mark.parametrize("recipe", _parametrize_recipes())
107+
def test_fsdp2_fused_adam_dcp_resharding(recipe):
108+
"""DCP checkpoint saved with DP4 loads correctly into DP2 (cross-topology resharding).
109+
110+
Runs two sequential torchrun invocations against run_fsdp2_fused_adam.py:
111+
1. nproc=4 → dcp_resharding_save (train + write checkpoint + ref output)
112+
2. nproc=2 → dcp_resharding_load (load checkpoint, assert output parity)
113+
"""
114+
if recipe == "MXFP8BlockScaling":
115+
pytest.xfail(
116+
"MXFP8BlockScaling: FusedAdam CUDA kernel does not support "
117+
"MXFP8 quantized tensors, causing illegal memory access. "
118+
"Fixed by https://github.com/NVIDIA/TransformerEngine/pull/2789."
119+
)
120+
if recipe == "NVFP4BlockScaling":
121+
pytest.xfail(
122+
"NVFP4BlockScaling: DCP load_state_dict triggers reset_sharded_param() "
123+
"which calls data_ptr() on NVFP4Tensor wrapper subclass with invalid storage"
124+
)
125+
if recipe == "Float8BlockScaling":
126+
pytest.xfail(
127+
"Float8BlockScaling doesnt work for DCP resharding with scale inv padding "
128+
"not being handled correctly for slice ops"
129+
)
130+
131+
test_path = _FSDP2_DIR / "run_fsdp2_fused_adam.py"
132+
133+
# Phase 1: save checkpoint with 4 ranks.
134+
result = subprocess.run(
135+
[
136+
"torchrun",
137+
"--nproc_per_node=4",
138+
"--local-ranks-filter=0",
139+
str(test_path),
140+
"--test",
141+
"dcp_resharding_save",
142+
"--recipe",
143+
recipe,
144+
],
145+
env=os.environ,
146+
timeout=300,
147+
)
148+
assert result.returncode == 0, f"DCP resharding save phase failed: {result.returncode}"
149+
150+
# Phase 2: load checkpoint with 2 ranks (different topology).
151+
result = subprocess.run(
152+
[
153+
"torchrun",
154+
"--nproc_per_node=2",
155+
"--local-ranks-filter=0",
156+
str(test_path),
157+
"--test",
158+
"dcp_resharding_load",
159+
"--recipe",
160+
recipe,
161+
],
162+
env=os.environ,
163+
timeout=300,
164+
)
165+
assert result.returncode == 0, f"DCP resharding load phase failed: {result.returncode}"
166+
167+
93168
def test_dummy() -> None:
94169
"""Dummy test
95170

0 commit comments

Comments
 (0)