Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions megatron/core/transformer/moe/moe_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
MoECudaGraphTensorStore,
get_default_pg_collection,
maybe_skip_or_early_return_by_cudagraph,
save_overload_factor_to_tracker,
)
from megatron.core.transformer.moe.router import TopKRouter
from megatron.core.transformer.moe.token_dispatcher import (
Expand Down Expand Up @@ -233,6 +234,8 @@ def __init__(
)

self.tp_group = pg_collection.tp
self.tp_ep_group = pg_collection.tp_ep
self.dp_group = pg_collection.dp

# Initialize router.
self.router = self.submodules.router(
Expand Down Expand Up @@ -431,6 +434,20 @@ def dispatch(self, hidden_states: torch.Tensor, probs: torch.Tensor):
"""
return self.token_dispatcher.token_dispatch(hidden_states, probs)

def _routing_map_after_token_dispatch(self) -> Optional[torch.Tensor]:
    """Return the routing map that is still alive after ``token_dispatch``.

    The AllGather dispatcher keeps the map as a ``routing_map`` attribute
    until ``dispatch_postprocess`` clears it; Flex/HybridEP dispatchers hold
    it on their ``_comm_manager`` instead.

    Returns:
        The routing map tensor, or ``None`` if no dispatcher currently
        holds one.
    """
    dispatcher = self.token_dispatcher
    # AllGather-style dispatchers expose the map directly on the dispatcher.
    routing_map = getattr(dispatcher, "routing_map", None)
    if routing_map is not None:
        return routing_map
    # Flex/HybridEP dispatchers keep the map on an internal comm manager.
    comm_manager = getattr(dispatcher, "_comm_manager", None)
    if comm_manager is None:
        return None
    return getattr(comm_manager, "routing_map", None)

@maybe_skip_or_early_return_by_cudagraph("shared_experts_compute")
def shared_experts_compute(self, hidden_states: torch.Tensor):
"""Computes the output of the shared experts.
Expand Down Expand Up @@ -467,9 +484,37 @@ def routed_experts_compute(self, hidden_states: torch.Tensor, probs: torch.Tenso
for each expert. It then passes the tokens through the local experts.
The output from the experts is preprocessed for the combine step.
"""
routing_map_for_balanced_count = None
if self.config.log_overload_factor:
routing_map_for_balanced_count = self._routing_map_after_token_dispatch()

dispatched_input, tokens_per_expert, permuted_probs = (
self.token_dispatcher.dispatch_postprocess(hidden_states, probs)
)
if self.config.log_overload_factor and routing_map_for_balanced_count is not None:
ws = float(self.tp_ep_group.size())
base = float(routing_map_for_balanced_count.shape[0]) * float(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[SUGGESTION] base is too generic in this context. Rename to local_balanced_count to make clear it represents the balanced token count for this rank before the tp_ep SUM reduction.

self.config.moe_router_topk
)
# AllGather replicates the full concatenated map on every rank; contribute
# fair share so report()'s SUM over tp_ep matches one global balanced count.
td = self.token_dispatcher
if isinstance(td, MoEAllGatherTokenDispatcher) and (
td.tp_size > 1 or td.ep_size > 1
):
base = base / ws
local_balanced = torch.empty(
(), device=dispatched_input.device, dtype=torch.float32
)
local_balanced.fill_(base)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[SUGGESTION] torch.empty(()) + fill_() is unnecessarily verbose. Use torch.tensor directly:

local_balanced = torch.tensor(base, device=dispatched_input.device, dtype=torch.float32)

dispatched_input = save_overload_factor_to_tracker(
tensor=dispatched_input,
tokens_per_expert=tokens_per_expert,
local_balanced_token_count=local_balanced,
layer_number=self.layer_number,
tp_ep_group=self.tp_ep_group,
dp_group=self.dp_group,
)
if (
hasattr(self, "_inference_token_dispatcher")
and self.is_inference_cuda_graphed_iteration
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[SUGGESTION] The log_overload_factor block inside experts_compute_dispatch (balanced count retrieval, dispatcher type check, tensor construction, hook registration) should be extracted into a private method such as _record_overload_factor(self, dispatched_input, tokens_per_expert).

Mixing this logic into experts_compute_dispatch hurts readability. A single call site keeps the dispatch method focused on dispatch logic.

Expand Down
Loading