Nanz/overload factor logging #4110
base: dev
```diff
@@ -17,6 +17,7 @@
     MoECudaGraphTensorStore,
     get_default_pg_collection,
     maybe_skip_or_early_return_by_cudagraph,
+    save_overload_factor_to_tracker,
 )
 from megatron.core.transformer.moe.router import TopKRouter
 from megatron.core.transformer.moe.token_dispatcher import (
```
```diff
@@ -233,6 +234,8 @@ def __init__(
         )

         self.tp_group = pg_collection.tp
+        self.tp_ep_group = pg_collection.tp_ep
+        self.dp_group = pg_collection.dp

         # Initialize router.
         self.router = self.submodules.router(
```
```diff
@@ -431,6 +434,20 @@ def dispatch(self, hidden_states: torch.Tensor, probs: torch.Tensor):
         """
         return self.token_dispatcher.token_dispatch(hidden_states, probs)

+    def _routing_map_after_token_dispatch(self) -> Optional[torch.Tensor]:
+        """Routing map still held after ``token_dispatch`` (cleared in AllGather ``dispatch_postprocess``).
+
+        Flex/HybridEP keep the map on ``_comm_manager``.
+        """
+        td = self.token_dispatcher
+        rm = getattr(td, "routing_map", None)
+        if rm is not None:
+            return rm
+        cm = getattr(td, "_comm_manager", None)
+        if cm is not None:
+            return getattr(cm, "routing_map", None)
+        return None
+
     @maybe_skip_or_early_return_by_cudagraph("shared_experts_compute")
     def shared_experts_compute(self, hidden_states: torch.Tensor):
         """Computes the output of the shared experts.
```

Contributor

[SUGGESTION] Avoid abbreviated variable names in parallel/distributed code where clarity is critical (e.g. `td`, `rm`, and `cm` in this helper).
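The fallback chain in `_routing_map_after_token_dispatch` can be sketched in plain Python. The dispatcher classes below are hypothetical stand-ins for illustration, not the real Megatron-LM types, and a plain string stands in for the `torch.Tensor` routing map:

```python
# Hypothetical stand-ins for the two dispatcher families described in the diff.

class AllGatherStyleDispatcher:
    """Holds the routing map directly on the dispatcher."""
    def __init__(self, routing_map):
        self.routing_map = routing_map

class _CommManager:
    def __init__(self, routing_map):
        self.routing_map = routing_map

class FlexStyleDispatcher:
    """Holds the routing map on an internal communication manager."""
    def __init__(self, routing_map):
        self._comm_manager = _CommManager(routing_map)

def routing_map_after_token_dispatch(dispatcher):
    # First try the dispatcher itself (AllGather keeps the map there), then
    # fall back to the comm manager (Flex/HybridEP), otherwise return None.
    routing_map = getattr(dispatcher, "routing_map", None)
    if routing_map is not None:
        return routing_map
    comm_manager = getattr(dispatcher, "_comm_manager", None)
    if comm_manager is not None:
        return getattr(comm_manager, "routing_map", None)
    return None

print(routing_map_after_token_dispatch(AllGatherStyleDispatcher("map-A")))  # map-A
print(routing_map_after_token_dispatch(FlexStyleDispatcher("map-B")))       # map-B
print(routing_map_after_token_dispatch(object()))                           # None
```

The `getattr(..., None)` probing is what couples the helper to dispatcher internals, which is also the concern raised in the review comments below the second hunk.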
```diff
@@ -467,9 +484,37 @@ def routed_experts_compute(self, hidden_states: torch.Tensor, probs: torch.Tensor):
         for each expert. It then passes the tokens through the local experts.
         The output from the experts is preprocessed for the combine step.
         """
+        routing_map_for_balanced_count = None
+        if self.config.log_overload_factor:
+            routing_map_for_balanced_count = self._routing_map_after_token_dispatch()
+
         dispatched_input, tokens_per_expert, permuted_probs = (
             self.token_dispatcher.dispatch_postprocess(hidden_states, probs)
         )
+        if self.config.log_overload_factor and routing_map_for_balanced_count is not None:
+            ws = float(self.tp_ep_group.size())
+            base = float(routing_map_for_balanced_count.shape[0]) * float(
+                self.config.moe_router_topk
+            )
+            # AllGather replicates the full concatenated map on every rank; contribute
+            # fair share so report()'s SUM over tp_ep matches one global balanced count.
+            td = self.token_dispatcher
+            if isinstance(td, MoEAllGatherTokenDispatcher) and (
+                td.tp_size > 1 or td.ep_size > 1
+            ):
+                base = base / ws
+            local_balanced = torch.empty(
+                (), device=dispatched_input.device, dtype=torch.float32
+            )
+            local_balanced.fill_(base)
+            dispatched_input = save_overload_factor_to_tracker(
+                tensor=dispatched_input,
+                tokens_per_expert=tokens_per_expert,
+                local_balanced_token_count=local_balanced,
+                layer_number=self.layer_number,
+                tp_ep_group=self.tp_ep_group,
+                dp_group=self.dp_group,
+            )
         if (
             hasattr(self, "_inference_token_dispatcher")
             and self.is_inference_cuda_graphed_iteration
```

Contributor

[SUGGESTION] `local_balanced` can be created in a single call instead of `torch.empty` followed by `fill_`:

```python
local_balanced = torch.tensor(base, device=dispatched_input.device, dtype=torch.float32)
```
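The fair-share division in the hunk above can be illustrated without torch or a distributed setup. This is a sketch under the assumptions stated in the diff's own comment: under AllGather every rank holds the full concatenated routing map, while Flex-style dispatchers hold only a local shard; ranks are simulated with a loop and the numbers are illustrative only:

```python
# Sketch of the fair-share accounting for the balanced token count.

def local_balanced_count(num_tokens_in_map, topk, world_size, map_is_replicated):
    # Ideal (perfectly balanced) number of routed token slots this rank
    # contributes to the SUM-reduction over the tp_ep group.
    base = float(num_tokens_in_map) * float(topk)
    if map_is_replicated and world_size > 1:
        # AllGather: every rank sees the full concatenated map, so each rank
        # contributes only 1/world_size to keep the SUM global, not inflated.
        base /= world_size
    return base

world_size = 4
global_tokens = 1024   # tokens in the concatenated (replicated) map
topk = 2

# AllGather-style: replicated map, each rank contributes its fair share.
total = sum(
    local_balanced_count(global_tokens, topk, world_size, map_is_replicated=True)
    for _rank in range(world_size)
)
print(total)  # 2048.0, i.e. one global balanced count (1024 * 2)

# Flex-style: each rank holds only its local shard, so no division is needed.
local_tokens = global_tokens // world_size
total = sum(
    local_balanced_count(local_tokens, topk, world_size, map_is_replicated=False)
    for _rank in range(world_size)
)
print(total)  # 2048.0 as well
```

Either way, the SUM over the group recovers exactly one global balanced count, which is what the tracker's `report()` appears to rely on.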
Contributor

[SUGGESTION] The … Mixing this logic into …
Contributor

[SUGGESTION] The balanced token count (`routing_map.shape[0] * topk`) could be computed earlier in `MoELayer.forward()`, directly from `hidden_states` before any dispatch, rather than reading `routing_map` in this post-`token_dispatch` window.

The current approach has two fragilities:

- It relies on `routing_map` still being alive after `token_dispatch` but before `dispatch_postprocess` clears it, a timing assumption tied to dispatcher internals.
- It branches on AllGather (`routing_map` attribute) vs Flex/HybridEP (`_comm_manager.routing_map`) dispatchers, coupling the code to internal implementation details.

Computing from `hidden_states.shape[0]` in `MoELayer.forward()` would remove `_routing_map_after_token_dispatch` entirely.