39 changes: 25 additions & 14 deletions megatron/core/optimizer/__init__.py
@@ -951,21 +951,32 @@ def get_megatron_optimizer(
buffer_name='buffers',
)

optimizers.append(
_get_megatron_optimizer_based_on_param_groups(
config=config,
model_chunks=model_chunk,
param_groups=param_groups,
per_model_buffers=buffers,
model_parallel_group=mp_group,
data_parallel_group=dp_cp_group,
data_parallel_group_gloo=intra_dp_cp_group_gloo,
data_parallel_group_idx=model_parallel_rank,
intra_dist_opt_group=intra_dist_opt_group,
distributed_optimizer_instance_id=distributed_optimizer_instance_id,
pg_collection=pg_collection,
)
optimizer_part = _get_megatron_optimizer_based_on_param_groups(
config=config,
model_chunks=model_chunk,
param_groups=param_groups,
per_model_buffers=buffers,
model_parallel_group=mp_group,
data_parallel_group=dp_cp_group,
data_parallel_group_gloo=intra_dp_cp_group_gloo,
data_parallel_group_idx=model_parallel_rank,
intra_dist_opt_group=intra_dist_opt_group,
distributed_optimizer_instance_id=distributed_optimizer_instance_id,
pg_collection=pg_collection,
)
if (
not USING_PYTORCH_OPTIMIZER
and config.use_precision_aware_optimizer
and getattr(optimizer_part.optimizer, "master_weights", None) is not None
):
# NOTE(@cspades): FusedAdam is provided Megatron-FSDP's main weights as
# non-quantized DTensor(s). Megatron-FSDP should NEVER use FusedAdam's own
# main weights: they are a complete waste of memory, since the optimizer step
# is still applied to the Megatron-FSDP main weights and only extended to the
# FusedAdam main weights. Override this here.
setattr(optimizer_part.optimizer, "master_weights", False)

optimizers.append(optimizer_part)
model_chunk_offset += 1

if len(optimizers) == 1:
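A minimal, self-contained sketch of the override pattern above, for reviewers who want to see the flow in isolation. `ToyFusedAdam` and `build_optimizer` are hypothetical stand-ins; only the `master_weights` attribute name and the guard conditions are taken from the diff:

```python
# Illustrative sketch only. ToyFusedAdam and build_optimizer are hypothetical;
# the `master_weights` attribute name and the guard mirror the diff above.
import torch


class ToyFusedAdam(torch.optim.Adam):
    """Stand-in for a fused optimizer that keeps its own fp32 master weights."""

    def __init__(self, params, master_weights=True, **kwargs):
        super().__init__(params, **kwargs)
        self.master_weights = master_weights


def build_optimizer(params, use_precision_aware_optimizer, using_pytorch_optimizer=False):
    optimizer = ToyFusedAdam(params, master_weights=True)
    if (
        not using_pytorch_optimizer
        and use_precision_aware_optimizer
        and getattr(optimizer, "master_weights", None) is not None
    ):
        # Megatron-FSDP already owns the fp32 main weights, so the fused
        # optimizer's own master copies would only duplicate memory.
        optimizer.master_weights = False
    return optimizer


model = torch.nn.Linear(8, 8)
opt = build_optimizer(model.parameters(), use_precision_aware_optimizer=True)
assert opt.master_weights is False
```

This mirrors the intent of the NOTE above: the fp32 main weights live in Megatron-FSDP, so FusedAdam's own master copies would only duplicate them.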
12 changes: 11 additions & 1 deletion megatron/training/training.py
@@ -1467,7 +1467,17 @@ def build_model():
kwargs['megatron_fsdp_main_params_dtype'] = args.megatron_fsdp_main_params_dtype
kwargs['megatron_fsdp_main_grads_dtype'] = args.megatron_fsdp_main_grads_dtype
kwargs['megatron_fsdp_grad_comm_dtype'] = args.megatron_fsdp_grad_comm_dtype
kwargs['megatron_fsdp_use_decoupled_grad'] = args.use_precision_aware_optimizer
kwargs['megatron_fsdp_use_decoupled_grad'] = args.use_precision_aware_optimizer and (
# NOTE(@cspades): Follow the same __post_init__ logic as
# OptimizerConfig.use_precision_aware_optimizer_no_fp8_or_ds_fp8.
# Megatron-FSDP only needs to be consistent with FusedAdam(use_decoupled_grad=?),
# and because the OR logic of use_precision_aware_optimizer_no_fp8_or_ds_fp8
# isn't particularly restrictive, Megatron-FSDP uses decoupled gradients under
# the same conditions as the distributed optimizer.
args.main_params_dtype != torch.float32
or (args.fp8_recipe is None or args.fp8_recipe == "delayed")
or args.optimizer_cpu_offload
)

# Initialize DDPConfig.
ddp_config = DistributedDataParallelConfig(**kwargs)
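Because the condition above is easy to misread, here is the same boolean factored into a standalone predicate with a few example evaluations. The helper name is hypothetical; the expression itself is copied from the diff:

```python
import torch


def fsdp_uses_decoupled_grad(use_precision_aware_optimizer, main_params_dtype,
                             fp8_recipe, optimizer_cpu_offload):
    # Mirrors kwargs['megatron_fsdp_use_decoupled_grad'] in the diff above:
    # decoupled gradients require the precision-aware optimizer plus any one
    # of the three OR'd sub-conditions.
    return use_precision_aware_optimizer and (
        main_params_dtype != torch.float32
        or (fp8_recipe is None or fp8_recipe == "delayed")
        or optimizer_cpu_offload
    )


print(fsdp_uses_decoupled_grad(True, torch.float32, "delayed", False))  # True
print(fsdp_uses_decoupled_grad(True, torch.float32, "mxfp8", False))    # False
print(fsdp_uses_decoupled_grad(False, torch.bfloat16, None, False))     # False
```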
@@ -690,6 +690,12 @@ def _training_loop(seed=42, **kwargs):
train_iters=NUM_TRAINING_STEPS,
**kwargs,
)
if kwargs.get("use_megatron_fsdp", False) and kwargs.get(
"use_precision_aware_optimizer", False
):
assert (
not optim.optimizer.master_weights
), "Megatron-FSDP should not use FusedAdam master weights."

# Prepare data iterator
data_iterator = make_gpt_mock_data_iterator(
@@ -759,6 +765,17 @@ def _training_loop(seed=42, **kwargs):
),
id="optim_grads_params_double_buffer",
),
pytest.param(
dict(
data_parallel_sharding_strategy="optim_grads_params",
megatron_fsdp_main_params_dtype=torch.float32,
use_precision_aware_optimizer=True,
fp8_recipe="delayed",
fp8_param_gather=True,
bf16=True,
),
id="optim_grads_params_fused_adam_e2e",
),
pytest.param(
dict(
data_parallel_sharding_strategy="optim_grads_params", fsdp_double_buffer=False
@@ -776,8 +793,10 @@
],
)
def test_compatible_with_nd_parallel(self, ref_cache, nd_topology, spec_configs):
if spec_configs.get("fp8_recipe") == "mxfp8" and not HAVE_TE_MXFP8TENSOR:
pytest.skip("Requires PyTorch with TE MXFP8Tensor support")
if spec_configs.get("fp8_recipe") == "mxfp8" and (
torch.cuda.get_device_capability()[0] < 10 or not HAVE_TE_MXFP8TENSOR
):
pytest.skip("Requires PyTorch & CUDA device with TE MXFP8Tensor support")

nd_topology_str = "_".join([f"{k}{v}" for k, v in nd_topology.items()])
if nd_topology_str not in ref_cache:
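The new skip condition relies on `torch.cuda.get_device_capability()`, which returns a `(major, minor)` tuple, so the mxfp8 recipe only runs on devices with compute capability 10.x or newer and with TE MXFP8Tensor support. A hedged sketch of the same gate factored into a reusable helper (the helper name is hypothetical):

```python
import pytest
import torch


def skip_unless_mxfp8_supported(have_te_mxfp8tensor: bool) -> None:
    """Skip the calling test unless MXFP8 can actually run on this machine."""
    if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 10:
        pytest.skip("Requires a CUDA device with compute capability >= 10.0")
    if not have_te_mxfp8tensor:
        pytest.skip("Requires TE MXFP8Tensor support")
```

Inside the test, such a helper would replace the inline `if ...: pytest.skip(...)` block shown in the diff.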