@@ -68,6 +68,9 @@ def __init__(
# From https://github.com/deepseek-ai/DeepEP/blob/9fe9021f29c9083cd1808ab36b740208524d9f63/deep_ep/buffer.py#L164
self.available_rank_configs = [2, 4, 8, 16, 24, 32, 64, 128, 144, 160]

def have_expert_num_tokens(self) -> bool:
return True

def num_dispatchers(self) -> int:
return self.num_dispatchers_

@@ -94,6 +94,9 @@ def __init__(
self.handles: list[tuple | None] = [None, None]
self.num_dispatchers_ = num_dispatchers

def have_expert_num_tokens(self) -> bool:
return True

def num_dispatchers(self) -> int:
return self.num_dispatchers_

3 changes: 3 additions & 0 deletions vllm/model_executor/layers/fused_moe/fused_batched_moe.py
@@ -510,6 +510,9 @@ def num_dispatchers(self) -> int:
def output_is_reduced(self) -> bool:
return False

def have_expert_num_tokens(self) -> bool:
return True

def prepare(
self,
a1: torch.Tensor,
23 changes: 18 additions & 5 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -31,6 +31,7 @@
_valid_deep_gemm,
deep_gemm_moe_fp8,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size,
)
@@ -1248,6 +1249,7 @@ def eplb_map_to_physical_and_record(
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
indices_type: torch.dtype | None = None,
fused_experts_method: Callable | None = None,
) -> torch.Tensor:
"""
Map the logical expert ids to physical expert ids
@@ -1305,13 +1307,24 @@ def eplb_map_to_physical_and_record(
# `expert_load_view`: (num_physical_experts,)

# `torch.bincount` is not compilable, so use `scatter_add_` instead.
topk_ids_flatten = topk_ids.flatten()
expert_load_view.scatter_add_(
dim=0,
index=topk_ids_flatten.long(),
src=torch.ones_like(topk_ids_flatten).to(expert_load_view),
skip_expert_load_scatter_add = (
(fused_experts_method is not None)
and isinstance(fused_experts_method, FusedMoEModularKernel)
and fused_experts_method.prepare_finalize.have_expert_num_tokens()
)

if not skip_expert_load_scatter_add:
logger.debug("expert_load_view update from topk_ids.")
topk_ids_flatten = topk_ids.flatten()
expert_load_view.scatter_add_(
dim=0,
index=topk_ids_flatten.long(),
src=torch.ones_like(topk_ids_flatten).to(expert_load_view),
)

else:
logger.debug("expert_load_view update in modular_kernel.")

if indices_type is not None:
topk_ids = topk_ids.to(dtype=indices_type)
return topk_ids
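
For reference, a minimal standalone sketch of the accumulation this hunk makes conditional (illustrative shapes and values, not the vLLM code itself): torch.bincount would give the per-expert token histogram directly but is not torch.compile-friendly, so the fallback path builds the same counts with scatter_add_ over a tensor of ones.

import torch

num_physical_experts = 8
expert_load_view = torch.zeros(num_physical_experts, dtype=torch.int64)

# topk_ids: (num_tokens, top_k) routing decisions with values in [0, num_physical_experts)
topk_ids = torch.tensor([[0, 3], [3, 5], [7, 0]])

# Reference result via bincount (not usable inside a compiled graph)
load_bincount = torch.bincount(topk_ids.flatten(), minlength=num_physical_experts)

# Compile-friendly equivalent: scatter_add_ with a ones() source
topk_ids_flatten = topk_ids.flatten()
expert_load_view.scatter_add_(
    dim=0,
    index=topk_ids_flatten.long(),
    src=torch.ones_like(topk_ids_flatten).to(expert_load_view),
)

assert torch.equal(expert_load_view, load_bincount)

With backends whose have_expert_num_tokens() returns True, this whole pass over topk_ids is skipped and the counts come from the prepare/finalize step instead (see modular_kernel.py below).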
4 changes: 4 additions & 0 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -647,6 +647,7 @@ def forward_cuda(
zero_expert_num=zero_expert_num,
zero_expert_type=zero_expert_type,
num_fused_shared_experts=layer.num_fused_shared_experts,
fused_experts_method=self.fused_experts,
)

if self.rocm_aiter_moe_enabled:
Expand Down Expand Up @@ -683,6 +684,7 @@ def forward_cuda(
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
expert_load_view=expert_load_view,
)
else:
assert fused_experts is not None
@@ -2035,6 +2037,7 @@ def select_experts(
zero_expert_num: int | None = None,
zero_expert_type: str | None = None,
num_fused_shared_experts: int = 0,
fused_experts_method: Callable | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Route the input hidden states to the top-k experts based on the
@@ -2130,6 +2133,7 @@ def select_experts(
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
indices_type=indices_type,
fused_experts_method=fused_experts_method,
)

assert topk_ids.dtype == indices_type or indices_type is None
33 changes: 33 additions & 0 deletions vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -190,6 +190,9 @@ def supports_async(self) -> bool:
"""
return False

def have_expert_num_tokens(self) -> bool:
return False

def prepare_async(
self,
a1: torch.Tensor,
@@ -698,6 +701,9 @@ def __init__(
self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts
self.shared_experts = shared_experts
# for EPLB
self.local_to_global_physical_experts = None
self.expert_map = None
assert (
prepare_finalize.activation_format == fused_experts.activation_formats[0]
), (
@@ -867,6 +873,7 @@ def _prepare(
global_num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
expert_load_view: torch.Tensor | None,
) -> tuple[
torch.Tensor,
torch.Tensor | None,
@@ -937,6 +944,26 @@ def _prepare(
_expert_topk_weights,
) = receiver()

# In EPLB, update expert load from expert_num_tokens.
if (
expert_tokens_meta is not None
and expert_load_view is not None
and expert_tokens_meta.expert_num_tokens is not None
and expert_map is not None
):
# Initialize the mapping of the local physical experts
# to global physical experts, after which it will not change.
# expert_load_view: (num_physical_experts,)
# expert_num_tokens: (local_num_physical_experts,)
local_num_experts = expert_tokens_meta.expert_num_tokens.shape[0]
if self.expert_map is None or not torch.equal(self.expert_map, expert_map):
self.expert_map = expert_map.clone()

start_idx = int(torch.distributed.get_rank()) * local_num_experts
expert_load_view[start_idx : start_idx + local_num_experts] += (
expert_tokens_meta.expert_num_tokens
)

# Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
topk_weights = (
@@ -1118,6 +1145,7 @@ def forward(
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
expert_load_view: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
This function computes a Mixture of Experts (MoE) layer using two sets
@@ -1142,6 +1170,10 @@
- apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is
1.
- expert_load_view (Optional[torch.Tensor]): Optional tensor for
tracking expert load statistics. If provided, the kernel will
update it using ExpertTokensMetadata.expert_num_tokens for
better performance.

Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
@@ -1163,6 +1195,7 @@
global_num_experts,
expert_map,
apply_router_weight_on_input,
expert_load_view=expert_load_view,
)

fused_out = self._fused_experts(
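
A minimal sketch of the slice update added to _prepare above (illustrative sizes; the real code takes the rank from torch.distributed.get_rank() and the counts from ExpertTokensMetadata.expert_num_tokens): each rank owns a contiguous block of local_num_experts physical experts, so its local counts are added into the matching slice of the global expert_load_view.

import torch

num_ranks = 4
local_num_experts = 2
num_physical_experts = num_ranks * local_num_experts

expert_load_view = torch.zeros(num_physical_experts, dtype=torch.int64)

# Per-local-expert token counts produced by this rank's prepare step.
expert_num_tokens = torch.tensor([5, 3], dtype=torch.int64)

rank = 2  # stand-in for torch.distributed.get_rank()
start_idx = rank * local_num_experts
expert_load_view[start_idx : start_idx + local_num_experts] += expert_num_tokens

print(expert_load_view)  # tensor([0, 0, 0, 0, 5, 3, 0, 0])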
@@ -82,6 +82,9 @@ def __init__(
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.BatchedExperts

def have_expert_num_tokens(self) -> bool:
return True

def max_num_tokens_per_rank(self) -> int | None:
return self.max_num_tokens

1 change: 1 addition & 0 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -1288,6 +1288,7 @@ def apply(
zero_expert_num=zero_expert_num,
zero_expert_type=zero_expert_type,
num_fused_shared_experts=layer.num_fused_shared_experts,
fused_experts_method=self.fused_experts,
)

#
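
Taken together, the changes follow a capability-flag pattern: the base prepare/finalize class reports have_expert_num_tokens() as False, backends that already compute per-expert token counts override it to True, and the routing code consults the flag through the modular kernel to decide whether it still needs the scatter_add_ pass. A schematic of that decision (hypothetical class names, not the actual vLLM hierarchy):

class PrepareFinalizeBase:
    def have_expert_num_tokens(self) -> bool:
        # Default: prepare() does not produce per-expert token counts.
        return False

class BatchedPrepareFinalize(PrepareFinalizeBase):
    def have_expert_num_tokens(self) -> bool:
        # This backend already knows expert_num_tokens, so the router can
        # skip its own topk_ids-based accumulation.
        return True

def should_skip_scatter_add(fused_experts_method) -> bool:
    # Mirrors the check added in eplb_map_to_physical_and_record().
    return (
        fused_experts_method is not None
        and hasattr(fused_experts_method, "prepare_finalize")
        and fused_experts_method.prepare_finalize.have_expert_num_tokens()
    )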