diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
index 929cff79980c..63c784c07f0b 100644
--- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
@@ -68,6 +68,9 @@ def __init__(
         # From https://github.com/deepseek-ai/DeepEP/blob/9fe9021f29c9083cd1808ab36b740208524d9f63/deep_ep/buffer.py#L164
         self.available_rank_configs = [2, 4, 8, 16, 24, 32, 64, 128, 144, 160]

+    def have_expert_num_tokens(self) -> bool:
+        return True
+
     def num_dispatchers(self) -> int:
         return self.num_dispatchers_

diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
index 500bcefcfaa9..b85b2c7d5fd6 100644
--- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
@@ -94,6 +94,9 @@ def __init__(
         self.handles: list[tuple | None] = [None, None]
         self.num_dispatchers_ = num_dispatchers

+    def have_expert_num_tokens(self) -> bool:
+        return True
+
     def num_dispatchers(self) -> int:
         return self.num_dispatchers_

diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
index 7fd8511e297d..89db5694146f 100644
--- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
@@ -510,6 +510,9 @@ def num_dispatchers(self) -> int:
     def output_is_reduced(self) -> bool:
         return False

+    def have_expert_num_tokens(self) -> bool:
+        return True
+
     def prepare(
         self,
         a1: torch.Tensor,
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index d0f5eb498127..86c2951c30f4 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -31,6 +31,7 @@
     _valid_deep_gemm,
     deep_gemm_moe_fp8,
 )
+from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
 from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
     moe_align_block_size,
 )
@@ -1248,6 +1249,7 @@ def eplb_map_to_physical_and_record(
     logical_to_physical_map: torch.Tensor,
     logical_replica_count: torch.Tensor,
     indices_type: torch.dtype | None = None,
+    fused_experts_method: Callable | None = None,
 ) -> torch.Tensor:
     """
     Map the logical expert ids to physical expert ids
@@ -1305,13 +1307,24 @@ def eplb_map_to_physical_and_record(
     # `expert_load_view`: (num_physical_experts,)
     # `torch.bincount` is not compilable, so use `scatter_add_` instead.
-    topk_ids_flatten = topk_ids.flatten()
-    expert_load_view.scatter_add_(
-        dim=0,
-        index=topk_ids_flatten.long(),
-        src=torch.ones_like(topk_ids_flatten).to(expert_load_view),
+    skip_expert_load_scatter_add = (
+        (fused_experts_method is not None)
+        and isinstance(fused_experts_method, FusedMoEModularKernel)
+        and fused_experts_method.prepare_finalize.have_expert_num_tokens()
     )
+    if not skip_expert_load_scatter_add:
+        logger.debug("expert_load_view update from topk_ids.")
+        topk_ids_flatten = topk_ids.flatten()
+        expert_load_view.scatter_add_(
+            dim=0,
+            index=topk_ids_flatten.long(),
+            src=torch.ones_like(topk_ids_flatten).to(expert_load_view),
+        )
+
+    else:
+        logger.debug("expert_load_view update in modular_kernel.")
+
     if indices_type is not None:
         topk_ids = topk_ids.to(dtype=indices_type)
     return topk_ids
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 55aa2593193a..51700a88e8a4 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -647,6 +647,7 @@ def forward_cuda(
             zero_expert_num=zero_expert_num,
             zero_expert_type=zero_expert_type,
             num_fused_shared_experts=layer.num_fused_shared_experts,
+            fused_experts_method=self.fused_experts,
         )

         if self.rocm_aiter_moe_enabled:
@@ -683,6 +684,7 @@ def forward_cuda(
                 apply_router_weight_on_input=apply_router_weight_on_input,
                 global_num_experts=global_num_experts,
                 expert_map=expert_map,
+                expert_load_view=expert_load_view,
             )
         else:
             assert fused_experts is not None
@@ -2035,6 +2037,7 @@ def select_experts(
         zero_expert_num: int | None = None,
         zero_expert_type: str | None = None,
         num_fused_shared_experts: int = 0,
+        fused_experts_method: Callable | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Route the input hidden states to the top-k experts based on the
@@ -2130,6 +2133,7 @@ def select_experts(
             logical_to_physical_map=logical_to_physical_map,
             logical_replica_count=logical_replica_count,
             indices_type=indices_type,
+            fused_experts_method=fused_experts_method,
         )

         assert topk_ids.dtype == indices_type or indices_type is None
diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py
index 3b5916f8ccaf..d925b82fcfef 100644
--- a/vllm/model_executor/layers/fused_moe/modular_kernel.py
+++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -190,6 +190,9 @@ def supports_async(self) -> bool:
         """
         return False

+    def have_expert_num_tokens(self) -> bool:
+        return False
+
     def prepare_async(
         self,
         a1: torch.Tensor,
@@ -698,6 +701,9 @@ def __init__(
         self.prepare_finalize = prepare_finalize
         self.fused_experts = fused_experts
         self.shared_experts = shared_experts
+        # for EPLB
+        self.local_to_global_physical_experts = None
+        self.expert_map = None
         assert (
             prepare_finalize.activation_format == fused_experts.activation_formats[0]
         ), (
@@ -867,6 +873,7 @@ def _prepare(
         global_num_experts: int,
         expert_map: torch.Tensor | None,
         apply_router_weight_on_input: bool,
+        expert_load_view: torch.Tensor | None,
     ) -> tuple[
         torch.Tensor,
         torch.Tensor | None,
@@ -937,6 +944,26 @@ def _prepare(
                 _expert_topk_weights,
             ) = receiver()

+        # In EPLB, update expert load from expert_num_tokens.
+        if (
+            expert_tokens_meta is not None
+            and expert_load_view is not None
+            and expert_tokens_meta.expert_num_tokens is not None
+            and expert_map is not None
+        ):
+            # Initialize the mapping of the local physical experts
+            # to global physical experts, after which it will not change.
+            # expert_load_view: (num_physical_experts,)
+            # expert_num_tokens: (local_num_physical_experts,)
+            local_num_experts = expert_tokens_meta.expert_num_tokens.shape[0]
+            if self.expert_map is None or not torch.equal(self.expert_map, expert_map):
+                self.expert_map = expert_map.clone()
+
+            start_idx = int(torch.distributed.get_rank()) * local_num_experts
+            expert_load_view[start_idx : start_idx + local_num_experts] += (
+                expert_tokens_meta.expert_num_tokens
+            )
+
         # Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
         topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
         topk_weights = (
@@ -1118,6 +1145,7 @@ def forward(
         global_num_experts: int = -1,
         expert_map: torch.Tensor | None = None,
         apply_router_weight_on_input: bool = False,
+        expert_load_view: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         """
         This function computes a Mixture of Experts (MoE) layer using two sets
@@ -1142,6 +1170,10 @@ def forward(
         - apply_router_weight_on_input (bool): When true, the topk weights
           are applied directly on the inputs. This is only applicable when
           topk is 1.
+        - expert_load_view (Optional[torch.Tensor]): Optional tensor for
+          tracking expert load statistics. If provided, the kernel will
+          update it using ExpertTokensMetadata.expert_num_tokens for
+          better performance.

         Returns:
         - torch.Tensor: The output tensor after applying the MoE layer.
@@ -1163,6 +1195,7 @@ def forward(
             global_num_experts,
             expert_map,
             apply_router_weight_on_input,
+            expert_load_view=expert_load_view,
         )

         fused_out = self._fused_experts(
diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
index 2766a2c2249f..1c541964a8ae 100644
--- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
@@ -82,6 +82,9 @@ def __init__(
     def activation_format(self) -> mk.FusedMoEActivationFormat:
         return mk.FusedMoEActivationFormat.BatchedExperts

+    def have_expert_num_tokens(self) -> bool:
+        return True
+
     def max_num_tokens_per_rank(self) -> int | None:
         return self.max_num_tokens

diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index f82eccb88ce0..cb28bd566d16 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -1288,6 +1288,7 @@ def apply(
             zero_expert_num=zero_expert_num,
             zero_expert_type=zero_expert_type,
             num_fused_shared_experts=layer.num_fused_shared_experts,
+            fused_experts_method=self.fused_experts,
         )

         #
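
Illustration only, not part of the patch: a minimal, self-contained sketch of why accumulating per-local-expert token counts into expert_load_view gives the same result as the scatter_add_ over flattened topk_ids that this change skips. It assumes each EP rank owns a contiguous block of physical experts; the plain "rank" variable stands in for torch.distributed.get_rank(), and expert_num_tokens is recomputed here with bincount purely for the demo (in the patch it comes from ExpertTokensMetadata).

import torch

num_physical_experts = 8   # global physical experts (EP world size * local experts)
local_num_experts = 4      # physical experts owned by this rank
rank = 1                   # stand-in for torch.distributed.get_rank()

# Routing result: 5 tokens, top-2 experts each, as global physical expert ids
# that all happen to live on this rank (ids 4..7).
topk_ids = torch.tensor([[4, 5], [5, 6], [4, 7], [6, 6], [5, 4]])
flat = topk_ids.flatten()

# Path 1: the previous eplb_map_to_physical_and_record behaviour -- a single
# scatter_add_ of ones over the flattened topk_ids.
load_from_topk = torch.zeros(num_physical_experts)
load_from_topk.scatter_add_(
    dim=0,
    index=flat.long(),
    src=torch.ones_like(flat, dtype=load_from_topk.dtype),
)

# Path 2: what the modular kernel now does when the prepare/finalize backend
# already reports per-local-expert token counts (expert_num_tokens): add the
# local counts into this rank's contiguous slice of expert_load_view.
expert_num_tokens = torch.bincount(
    flat - rank * local_num_experts, minlength=local_num_experts
).to(load_from_topk.dtype)
load_from_counts = torch.zeros(num_physical_experts)
start_idx = rank * local_num_experts
load_from_counts[start_idx : start_idx + local_num_experts] += expert_num_tokens

# For tokens routed to this rank's experts, both paths record the same load.
assert torch.equal(load_from_topk, load_from_counts)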