
Commit 2b24404

CustomOp: grouped topk
Signed-off-by: Xinyu Chen <[email protected]>
1 parent b8515d5 commit 2b24404

File tree

1 file changed: +71 -0 lines changed


vllm_gaudi/ops/hpu_fused_moe.py

Lines changed: 71 additions & 0 deletions
@@ -2,10 +2,81 @@
 
 import torch
 import vllm
+from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
+from vllm.model_executor.layers.fused_moe.fused_moe import GroupedTopk
 from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, UnquantizedFusedMoEMethod)
 from vllm_gaudi.extension.ops import (VllmMixtureOfExpertsOp)
 
 
+@GroupedTopk.register_oot
+class HPUGroupedTopk(GroupedTopk):
+    """GroupedTopk used by the Deepseek-V2 and Deepseek-V3 model."""
+
+    def forward_oot(
+        self,
+        hidden_states: torch.Tensor,
+        gating_output: torch.Tensor,
+        e_score_correction_bias: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+
+        gating_output = gating_output.float()
+        if e_score_correction_bias is not None:
+            e_score_correction_bias = e_score_correction_bias.float()
+
+        if self.scoring_func == "softmax":
+            scores = torch.softmax(gating_output, dim=-1)
+        elif self.scoring_func == "sigmoid":
+            scores = gating_output.sigmoid()
+        else:
+            raise ValueError(f"Unsupported scoring function: {self.scoring_func}")
+
+        # For batch invariance, use sorted=True to ensure deterministic expert selection
+        use_sorted = vllm_is_batch_invariant()
+
+        num_token = scores.size(0)
+        if e_score_correction_bias is not None:
+            # Store original scores before applying correction bias. We use biased
+            # scores for expert selection but original scores for routing weights
+            original_scores = scores
+            scores = scores + e_score_correction_bias.unsqueeze(0)
+            scores_tmp = scores.clone().reshape(num_token, self.num_expert_group, -1)
+            top1_val, top1_idx = torch.max(scores_tmp, dim=-1)
+            scores_tmp.scatter_(-1, top1_idx.unsqueeze(-1), torch.finfo(scores.dtype).min)
+            group_scores, top2_idx = torch.max(scores_tmp, dim=-1)
+            group_scores.add_(top1_val)
+        else:
+            group_scores = (scores.view(num_token, self.num_expert_group, -1).max(dim=-1).values)  # [n, n_group]
+        if num_token > 1024:
+            group_mask = torch.zeros_like(group_scores)
+            for i in range(self.topk_group):
+                _, group_idx = torch.max(group_scores, dim=-1)
+                group_mask.scatter_(1, group_idx.unsqueeze(-1), 1)
+                if i < self.topk_group - 1:
+                    group_scores.scatter_(1, group_idx.unsqueeze(-1), torch.finfo(scores.dtype).min)
+        else:
+            group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=use_sorted)[1]  # [n, top_k_group]
+            group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+
+        tmp_scores = scores.reshape(num_token, self.num_expert_group, -1) + (
+            (1 - group_mask) * torch.finfo(scores.dtype).min).unsqueeze(-1)
+        tmp_scores = tmp_scores.reshape(num_token, -1)
+
+        if e_score_correction_bias is not None:
+            topk_ids = torch.topk(tmp_scores, k=self.topk, dim=-1, sorted=use_sorted)[1]
+            # Use original unbiased scores for the routing weights
+            topk_weights = original_scores.gather(1, topk_ids)
+        else:
+            topk_weights, topk_ids = torch.topk(tmp_scores, k=self.topk, dim=-1, sorted=use_sorted)
+
+        if self.renormalize:
+            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+        if self.routed_scaling_factor != 1.0:
+            topk_weights = topk_weights * self.routed_scaling_factor
+        return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+
+
 
 @UnquantizedFusedMoEMethod.register_oot
 class HPUUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
     """MoE method without quantization."""
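For readers of the diff: the routine above implements grouped (node-limited) top-k routing, in which gating scores are first reduced per expert group, only the strongest topk_group groups are kept, and the final topk experts are drawn from the surviving groups. Below is a minimal, self-contained reference sketch of that selection path, for illustration only; it omits the correction-bias path, renormalization, scaling, and the HPU-specific num_token > 1024 loop, and the function name and example sizes are hypothetical, not part of this commit.

import torch

def grouped_topk_reference(scores: torch.Tensor, num_expert_group: int,
                           topk_group: int, topk: int) -> tuple[torch.Tensor, torch.Tensor]:
    # Illustrative sketch only, not the commit's implementation.
    # scores: [num_tokens, num_experts], already softmax- or sigmoid-activated.
    num_token = scores.size(0)
    # Score each group by its best expert and keep only the top `topk_group` groups.
    group_scores = scores.reshape(num_token, num_expert_group, -1).max(dim=-1).values
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1)[1]
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1)
    # Push experts in discarded groups to a very large negative value (finfo.min),
    # then take the global top-k over the remaining experts.
    masked = scores.reshape(num_token, num_expert_group, -1) + (
        (1 - group_mask) * torch.finfo(scores.dtype).min).unsqueeze(-1)
    masked = masked.reshape(num_token, -1)
    topk_weights, topk_ids = torch.topk(masked, k=topk, dim=-1)
    return topk_weights, topk_ids

# Example: 2 tokens, 8 experts in 4 groups; keep 2 groups and route to 4 experts per token.
scores = torch.softmax(torch.randn(2, 8), dim=-1)
weights, ids = grouped_topk_reference(scores, num_expert_group=4, topk_group=2, topk=4)

When e_score_correction_bias is provided (the DeepSeek-V3 style path), the commit instead ranks each group by the sum of its two highest biased expert scores, selects expert ids on the biased scores, and gathers the routing weights from the original unbiased scores. The num_token > 1024 branch replaces the group-level torch.topk with an iterative max/scatter loop, presumably as an HPU-friendly alternative for large batches.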