Commit e191a3d

platform: optimize grouped topk op
Signed-off-by: Xinyu Chen <[email protected]>
1 parent e38c8e9 commit e191a3d

File tree

1 file changed: +80 -0 lines changed

vllm_gaudi/platform.py

Lines changed: 80 additions & 0 deletions
@@ -268,3 +268,83 @@ def torch_function(origin_cls, func, types, args=(), kwargs=None):
        BasevLLMParameter.__torch_function__ = classmethod(torch_function)
        return

    @classmethod
    def has_optimized_grouped_topk(cls) -> bool:
        """
        Return whether the current platform has an optimized grouped_topk op.
        """
        return True

    @classmethod
    def grouped_topk(
        cls,
        hidden_states: torch.Tensor,
        gating_output: torch.Tensor,
        topk: int,
        renormalize: bool,
        num_expert_group: int = 0,
        topk_group: int = 0,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:

        gating_output = gating_output.float()
        if e_score_correction_bias is not None:
            e_score_correction_bias = e_score_correction_bias.float()

        if scoring_func == "softmax":
            scores = torch.softmax(gating_output, dim=-1)
        elif scoring_func == "sigmoid":
            scores = gating_output.sigmoid()
        else:
            raise ValueError(f"Unsupported scoring function: {scoring_func}")

        # For batch invariance, use sorted=True to ensure deterministic expert selection
        from vllm.model_executor.layers.batch_invariant import (
            vllm_is_batch_invariant, )
        use_sorted = vllm_is_batch_invariant()

        num_token = scores.size(0)
        if e_score_correction_bias is not None:
            # Store original scores before applying correction bias. We use biased
            # scores for expert selection but original scores for routing weights
            original_scores = scores
            scores = scores + e_score_correction_bias.unsqueeze(0)
            # Group score is the sum of the top-2 biased expert scores per group,
            # computed with two max-and-mask passes instead of a topk call.
            scores_tmp = scores.clone().reshape(num_token, num_expert_group, -1)
            top1_val, top1_idx = torch.max(scores_tmp, dim=-1)
            scores_tmp.scatter_(-1, top1_idx.unsqueeze(-1), torch.finfo(scores.dtype).min)
            group_scores, top2_idx = torch.max(scores_tmp, dim=-1)
            group_scores.add_(top1_val)
        else:
            group_scores = (scores.view(num_token, num_expert_group, -1).max(dim=-1).values)  # [n, n_group]
        if num_token > 1024:
            # For large batches, pick the top groups with an iterative
            # max-and-mask loop rather than a single topk call.
            group_mask = torch.zeros_like(group_scores)
            for i in range(topk_group):
                _, group_idx = torch.max(group_scores, dim=-1)
                group_mask.scatter_(1, group_idx.unsqueeze(-1), 1)
                if i < topk_group - 1:
                    group_scores.scatter_(1, group_idx.unsqueeze(-1), torch.finfo(scores.dtype).min)
        else:
            group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[1]  # [n, top_k_group]
            group_mask = torch.zeros_like(group_scores)  # [n, n_group]
            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]

        # Mask out experts belonging to unselected groups before the final topk.
        tmp_scores = scores.reshape(num_token, num_expert_group, -1) + (
            (1 - group_mask) * torch.finfo(scores.dtype).min).unsqueeze(-1)
        tmp_scores = tmp_scores.reshape(num_token, -1)

        if e_score_correction_bias is not None:
            topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1]
            # Use original unbiased scores for the routing weights
            topk_weights = original_scores.gather(1, topk_ids)
        else:
            topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)

        if renormalize:
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

        if routed_scaling_factor != 1.0:
            topk_weights = topk_weights * routed_scaling_factor
        return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
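For reference, a call to the new op might look like the sketch below. The enclosing class name HpuPlatform, the tensor shapes, and the routing parameters are assumptions for illustration, not values taken from the commit.

    import torch
    from vllm_gaudi.platform import HpuPlatform  # assumed class name

    # Hypothetical routing setup: 4 tokens, 16 experts split into 4 groups;
    # keep the top 2 groups, then the top 4 experts overall.
    num_tokens, num_experts = 4, 16
    hidden_states = torch.randn(num_tokens, 128)  # not used by this implementation
    gating_output = torch.randn(num_tokens, num_experts)

    topk_weights, topk_ids = HpuPlatform.grouped_topk(
        hidden_states,
        gating_output,
        topk=4,
        renormalize=True,
        num_expert_group=4,
        topk_group=2,
        scoring_func="sigmoid",
    )
    print(topk_weights.shape, topk_weights.dtype)  # torch.Size([4, 4]) torch.float32
    print(topk_ids.shape, topk_ids.dtype)          # torch.Size([4, 4]) torch.int32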

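One design choice worth noting: for batches larger than 1024 tokens, the commit selects the top groups with an iterative max-and-mask loop instead of a single torch.topk over the group scores, presumably because that form maps better onto the Gaudi stack at large batch sizes. The minimal sketch below (an illustration, not part of the commit) checks that the two paths produce the same group mask when scores have no ties.

    import torch

    num_token, num_expert_group, topk_group = 4, 8, 3
    group_scores = torch.randn(num_token, num_expert_group)

    # Path 1: iterative max + scatter (the num_token > 1024 branch).
    scores_loop = group_scores.clone()
    mask_loop = torch.zeros_like(scores_loop)
    for i in range(topk_group):
        _, idx = torch.max(scores_loop, dim=-1)
        mask_loop.scatter_(1, idx.unsqueeze(-1), 1)
        if i < topk_group - 1:
            scores_loop.scatter_(1, idx.unsqueeze(-1), torch.finfo(scores_loop.dtype).min)

    # Path 2: a single topk call (the small-batch branch).
    idx_topk = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
    mask_topk = torch.zeros_like(group_scores)
    mask_topk.scatter_(1, idx_topk, 1)

    assert torch.equal(mask_loop, mask_topk)  # both branches select the same groups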