4 changes: 4 additions & 0 deletions megatron/core/distributed/param_and_grad_buffer.py
@@ -330,6 +330,10 @@ def start_param_sync(self, force_sync: bool = False):
  # (falling back to grouped send/recv internally when sizes differ),
  # so no manual padding is needed.
  dp_size = self.intra_distributed_optimizer_instance_size
+ if dp_size == 1:
+     # Single-rank group (e.g., expt_dp_size == 1): no all-gather needed.
+     self.param_gather_dispatched = True
+     return
  local_rank = self.intra_distributed_optimizer_instance_rank
  group = self.intra_distributed_optimizer_instance_group
  layerwise_work_handles = []
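The guard added above short-circuits the all-gather dispatch when the intra-instance data-parallel group contains only one rank. A minimal, self-contained sketch of that control flow follows; _BufferSketch and its attributes are illustrative stand-ins, not the real buffer class.

# Hypothetical stand-in mirroring the early exit added in start_param_sync().
class _BufferSketch:
    def __init__(self, instance_size: int):
        self.intra_distributed_optimizer_instance_size = instance_size
        self.param_gather_dispatched = False

    def start_param_sync(self):
        dp_size = self.intra_distributed_optimizer_instance_size
        if dp_size == 1:
            # Single-rank group: every shard is already local, so no all-gather
            # is issued, but the flag is still set for downstream wait logic.
            self.param_gather_dispatched = True
            return
        # ... otherwise dispatch the layer-wise all-gathers as before ...

buf = _BufferSketch(instance_size=1)
buf.start_param_sync()
assert buf.param_gather_dispatched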
12 changes: 8 additions & 4 deletions megatron/core/optimizer/layer_wise_optimizer.py
@@ -175,9 +175,9 @@ def set_bucket_layerwise_params_list(self, model_chunks):
  bucket_list.append(param)
  bucket.set_layerwise_params_list(bucket_params_list)
  # Do the same for expert parallel bucket groups.
- if self.expt_dp_params_list is not None:
- for group in model_chunk.expert_parallel_bucket_groups:
- for bucket in group.buckets:
+ for group in model_chunk.expert_parallel_bucket_groups:
+ for bucket in group.buckets:
+ if self.expt_dp_params_list is not None:
  bucket_params_list = [
  [] for _ in range(get_pg_size(self.pg_collection.expt_dp))
  ]
@@ -187,7 +187,11 @@ def set_bucket_layerwise_params_list(self, model_chunks):
  for param in full_params_list:
  if param in bucket.params:
  bucket_list.append(param)
- bucket.set_layerwise_params_list(bucket_params_list)
+ else:
+ # expt_dp_size == 1: single rank owns all params, no
+ # all-gather needed but data structures must be initialized.
+ bucket_params_list = [list(bucket.params_list)]
+ bucket.set_layerwise_params_list(bucket_params_list)

  @torch.no_grad()
  def allgather_params(self) -> None:
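The new else branch keeps the per-rank bookkeeping well-formed when the expert data-parallel group has a single rank. The sketch below illustrates the resulting shape; build_layerwise_params_list, bucket_params, and rank_of are illustrative names under assumption, not the optimizer's actual API.

# Hedged sketch: with expt_dp size N the structure is one sub-list per rank;
# with size 1 it collapses to a single sub-list holding every bucket param.
def build_layerwise_params_list(bucket_params, expt_dp_size, rank_of=lambda p: 0):
    if expt_dp_size == 1:
        # Single rank owns all params; keep the [[...]] shape so code that
        # indexes by rank still works without an all-gather.
        return [list(bucket_params)]
    per_rank = [[] for _ in range(expt_dp_size)]
    for p in bucket_params:
        per_rank[rank_of(p)].append(p)
    return per_rank

print(build_layerwise_params_list(['w1', 'w2'], expt_dp_size=1))  # [['w1', 'w2']]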
38 changes: 20 additions & 18 deletions megatron/training/arguments.py
@@ -1376,6 +1376,26 @@ def validate_args(args, defaults={}):
  )
  args.iterations_to_skip.extend(iterations_to_skip_from_file)

+ # emerging optimizer check
+ args.use_layer_wise_distributed_optimizer = False
+ if args.optimizer not in ('sgd', 'adam'):
+ if args.optimizer == 'dist_muon':
+ warn_rank_0(
+ "optimizer='dist_muon' is deprecated. "
+ "Use --optimizer muon --use-distributed-optimizer instead."
+ )
+ args.optimizer = 'muon'
+ args.use_layer_wise_distributed_optimizer = True
+
+ if args.use_distributed_optimizer:
+ args.use_layer_wise_distributed_optimizer = True
+ args.use_distributed_optimizer = False
+
+ assert not args.use_torch_fsdp2, "Emerging optimizer does not support Torch-FSDP2 for now."
+ assert not args.use_megatron_fsdp, "Emerging optimizer does not support Megatron-FSDP for now."
+ assert args.ckpt_format in ["torch", "torch_dist"], "Emerging optimizer supports torch and torch_dist checkpoint format."
+
+
  # Make sure all functionality that requires Gloo process groups is disabled.
  if not args.use_gloo_process_groups:
  if args.use_distributed_optimizer:
@@ -1472,24 +1492,6 @@ def validate_args(args, defaults={}):
  '--no-load-optim with --skip-train --perform-rl-step skips the optimizer; ' \
  '--rl-offload-optimizer-during-inference is incompatible (no optimizer to offload).'

- # emerging optimizer check
- if args.optimizer not in ('sgd', 'adam'):
- if args.optimizer == 'dist_muon':
- warn_rank_0(
- "optimizer='dist_muon' is deprecated. "
- "Use --optimizer muon --use-distributed-optimizer instead."
- )
- args.optimizer = 'muon'
- args.use_layer_wise_distributed_optimizer = True
-
- if args.use_distributed_optimizer:
- args.use_layer_wise_distributed_optimizer = True
- args.use_distributed_optimizer = False
-
- assert not args.use_torch_fsdp2, "Muon optimizer does not support Torch-FSDP2 for now."
- assert not args.use_megatron_fsdp, "Muon optimizer does not support Megatron-FSDP for now."
- assert args.ckpt_format in ["torch", "torch_dist"], "Muon optimizer supports torch and torch_dist checkpoint format."
-
  # Optimizer CPU offload check
  if args.optimizer_cpu_offload:
  assert args.use_precision_aware_optimizer, (
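The block moved to the earlier position in validate_args now initializes use_layer_wise_distributed_optimizer unconditionally and folds --use-distributed-optimizer into the layer-wise path for the emerging optimizers. A hedged, standalone sketch of that flow, using SimpleNamespace and warnings.warn in place of Megatron's args object and warn_rank_0:

import warnings
from types import SimpleNamespace

def check_emerging_optimizer(args):
    args.use_layer_wise_distributed_optimizer = False
    if args.optimizer in ('sgd', 'adam'):
        return
    if args.optimizer == 'dist_muon':
        warnings.warn("optimizer='dist_muon' is deprecated; use --optimizer muon "
                      "--use-distributed-optimizer instead.")
        args.optimizer = 'muon'
        args.use_layer_wise_distributed_optimizer = True
    if args.use_distributed_optimizer:
        # Emerging optimizers reuse the layer-wise path instead of the
        # classic distributed optimizer.
        args.use_layer_wise_distributed_optimizer = True
        args.use_distributed_optimizer = False
    assert not args.use_torch_fsdp2 and not args.use_megatron_fsdp
    assert args.ckpt_format in ('torch', 'torch_dist')

args = SimpleNamespace(optimizer='dist_muon', use_distributed_optimizer=False,
                       use_torch_fsdp2=False, use_megatron_fsdp=False,
                       ckpt_format='torch_dist')
check_emerging_optimizer(args)
assert args.optimizer == 'muon' and args.use_layer_wise_distributed_optimizer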
2 changes: 1 addition & 1 deletion megatron/training/training.py
@@ -3720,6 +3720,6 @@ def should_disable_forward_pre_hook(args):
"""Block forward pre-hook for certain configurations."""
return (
not args.use_megatron_fsdp
and (args.use_distributed_optimizer or 'dist' in args.optimizer)
and (args.use_distributed_optimizer or args.use_layer_wise_distributed_optimizer)
and args.overlap_param_gather
)
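With the flag-based test above, a layer-wise distributed-optimizer run (e.g., muon) disables the forward pre-hook under overlap-param-gather just as the old 'dist' substring check did for 'dist_muon'. A small standalone check, with a stand-in namespace for args:

from types import SimpleNamespace

def should_disable_forward_pre_hook(args):
    # Mirrors the updated predicate in training.py.
    return (
        not args.use_megatron_fsdp
        and (args.use_distributed_optimizer or args.use_layer_wise_distributed_optimizer)
        and args.overlap_param_gather
    )

cfg = SimpleNamespace(use_megatron_fsdp=False, use_distributed_optimizer=False,
                      use_layer_wise_distributed_optimizer=True, overlap_param_gather=True)
print(should_disable_forward_pre_hook(cfg))  # True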