Commit 3860daf

Fix incorrectly set decoupled_grad in training.py for MFSDP.
Signed-off-by: Cory Ye <cye@nvidia.com>
1 parent dcc6d62 commit 3860daf

File tree

megatron/core/optimizer/__init__.py
megatron/training/training.py
tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py

3 files changed: +54 -17 lines changed

megatron/core/optimizer/__init__.py

Lines changed: 25 additions & 14 deletions
@@ -951,21 +951,32 @@ def get_megatron_optimizer(
             buffer_name='buffers',
         )
 
-        optimizers.append(
-            _get_megatron_optimizer_based_on_param_groups(
-                config=config,
-                model_chunks=model_chunk,
-                param_groups=param_groups,
-                per_model_buffers=buffers,
-                model_parallel_group=mp_group,
-                data_parallel_group=dp_cp_group,
-                data_parallel_group_gloo=intra_dp_cp_group_gloo,
-                data_parallel_group_idx=model_parallel_rank,
-                intra_dist_opt_group=intra_dist_opt_group,
-                distributed_optimizer_instance_id=distributed_optimizer_instance_id,
-                pg_collection=pg_collection,
-            )
+        optimizer_part = _get_megatron_optimizer_based_on_param_groups(
+            config=config,
+            model_chunks=model_chunk,
+            param_groups=param_groups,
+            per_model_buffers=buffers,
+            model_parallel_group=mp_group,
+            data_parallel_group=dp_cp_group,
+            data_parallel_group_gloo=intra_dp_cp_group_gloo,
+            data_parallel_group_idx=model_parallel_rank,
+            intra_dist_opt_group=intra_dist_opt_group,
+            distributed_optimizer_instance_id=distributed_optimizer_instance_id,
+            pg_collection=pg_collection,
         )
+        if (
+            not USING_PYTORCH_OPTIMIZER
+            and config.use_precision_aware_optimizer
+            and getattr(optimizer_part.optimizer, "master_weights", None) is not None
+        ):
+            # NOTE(@cspades): FusedAdam is provided Megatron-FSDP's main weights as
+            # non-quantized DTensor(s). Megatron-FSDP should NEVER use FusedAdam's
+            # main weights, complete waste of memory as the optimizer step is still
+            # applied to the Megatron-FSDP main weight and extended to FusedAdam
+            # main weights. Override this here.
+            setattr(optimizer_part.optimizer, "master_weights", False)
+
+        optimizers.append(optimizer_part)
         model_chunk_offset += 1
 
     if len(optimizers) == 1:
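Stated another way, the new guard simply turns off FusedAdam's own master-weight copy whenever Megatron-FSDP already owns the full-precision main weights. Below is a minimal standalone sketch of that check; the helper name and the using_pytorch_optimizer argument are illustrative and not part of Megatron-LM.

def disable_fused_adam_master_weights(optimizer_part, config, using_pytorch_optimizer: bool):
    """Illustrative sketch of the guard added above, factored into a helper."""
    fused_adam = optimizer_part.optimizer
    if (
        not using_pytorch_optimizer
        and config.use_precision_aware_optimizer
        and getattr(fused_adam, "master_weights", None) is not None
    ):
        # Megatron-FSDP keeps the authoritative main weights, so a second
        # fp32 copy inside FusedAdam would only duplicate memory.
        fused_adam.master_weights = False
    return optimizer_part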

megatron/training/training.py

Lines changed: 10 additions & 1 deletion
@@ -1467,7 +1467,16 @@ def build_model():
         kwargs['megatron_fsdp_main_params_dtype'] = args.megatron_fsdp_main_params_dtype
         kwargs['megatron_fsdp_main_grads_dtype'] = args.megatron_fsdp_main_grads_dtype
         kwargs['megatron_fsdp_grad_comm_dtype'] = args.megatron_fsdp_grad_comm_dtype
-        kwargs['megatron_fsdp_use_decoupled_grad'] = args.use_precision_aware_optimizer
+        kwargs['megatron_fsdp_use_decoupled_grad'] = args.use_precision_aware_optimizer and (
+            # NOTE(@cspades): Follow the same __post_init__ as in OptimizerConfig:
+            # OptimizerConfig.use_precision_aware_optimizer_no_fp8_or_ds_fp8
+            # Megatron-FSDP only needs to be consistent with FusedAdam(use_decoupled_grad=?),
+            # and because this OR logic isn't particularly restrictive, we can just use
+            # use_precision_aware_optimizer_no_fp8_or_ds_fp8 for Megatron-FSDP optimization.
+            args.main_params_dtype != torch.float32
+            or (args.fp8_recipe is None or args.fp8_recipe == "delayed")
+            or args.optimizer_cpu_offload
+        )
 
         # Initialize DDPConfig.
         ddp_config = DistributedDataParallelConfig(**kwargs)
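Per the NOTE in the diff, the boolean expression mirrors how OptimizerConfig derives use_precision_aware_optimizer_no_fp8_or_ds_fp8 in its __post_init__, so megatron_fsdp_use_decoupled_grad stays consistent with whatever FusedAdam receives for use_decoupled_grad. Pulled out as a standalone predicate over the same args fields (an illustrative sketch, not the actual OptimizerConfig or training.py code), the condition reads:

import torch

def megatron_fsdp_use_decoupled_grad(args) -> bool:
    """Illustrative restatement of the expression assigned to the DDP kwarg above."""
    if not args.use_precision_aware_optimizer:
        return False
    # Same OR logic as use_precision_aware_optimizer_no_fp8_or_ds_fp8: decoupled grads
    # are enabled unless fp32 main params are combined with a non-delayed FP8 recipe
    # and the optimizer is not offloaded to CPU.
    return (
        args.main_params_dtype != torch.float32
        or (args.fp8_recipe is None or args.fp8_recipe == "delayed")
        or args.optimizer_cpu_offload
    )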

tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py

Lines changed: 19 additions & 2 deletions
@@ -690,6 +690,12 @@ def _training_loop(seed=42, **kwargs):
             train_iters=NUM_TRAINING_STEPS,
             **kwargs,
         )
+        if kwargs.get("use_megatron_fsdp", False) and kwargs.get(
+            "use_precision_aware_optimizer", False
+        ):
+            assert (
+                not optim.optimizer.master_weights
+            ), "Megatron-FSDP should not use FusedAdam master weights."
 
         # Prepare data iterator
         data_iterator = make_gpt_mock_data_iterator(
@@ -759,6 +765,15 @@ def _training_loop(seed=42, **kwargs):
                 ),
                 id="optim_grads_params_double_buffer",
             ),
+            pytest.param(
+                dict(
+                    data_parallel_sharding_strategy="optim_grads_params",
+                    megatron_fsdp_main_params_dtype=torch.float32,
+                    use_precision_aware_optimizer=True,
+                    fp8_recipe=None,
+                ),
+                id="optim_grads_params_fused_adam_e2e",
+            ),
             pytest.param(
                 dict(
                     data_parallel_sharding_strategy="optim_grads_params", fsdp_double_buffer=False
@@ -776,8 +791,10 @@ def _training_loop(seed=42, **kwargs):
         ],
     )
     def test_compatible_with_nd_parallel(self, ref_cache, nd_topology, spec_configs):
-        if spec_configs.get("fp8_recipe") == "mxfp8" and not HAVE_TE_MXFP8TENSOR:
-            pytest.skip("Requires PyTorch with TE MXFP8Tensor support")
+        if spec_configs.get("fp8_recipe") == "mxfp8" and (
+            torch.cuda.get_device_capability()[0] < 10 or not HAVE_TE_MXFP8TENSOR
+        ):
+            pytest.skip("Requires PyTorch & CUDA device with TE MXFP8Tensor support")
 
         nd_topology_str = "_".join([f"{k}{v}" for k, v in nd_topology.items()])
         if nd_topology_str not in ref_cache:
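The widened skip also guards on hardware: torch.cuda.get_device_capability()[0] < 10 rules out GPUs older than compute capability 10.x (Blackwell-class), which the Transformer Engine MXFP8 path targets. Restated as a standalone predicate, purely for illustration (the function name and arguments are assumptions, not test code):

from typing import Optional

import torch

def should_skip_mxfp8_case(fp8_recipe: Optional[str], have_te_mxfp8tensor: bool) -> bool:
    """Illustrative standalone restatement of the skip condition in the test above."""
    if fp8_recipe != "mxfp8":
        return False
    # get_device_capability() returns (major, minor); MXFP8 kernels require a
    # Blackwell-class GPU, i.e. compute capability major version >= 10.
    major, _minor = torch.cuda.get_device_capability()
    return major < 10 or not have_te_mxfp8tensor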
