@@ -268,3 +268,83 @@ def torch_function(origin_cls, func, types, args=(), kwargs=None):
 
         BasevLLMParameter.__torch_function__ = classmethod(torch_function)
         return
+
+    @classmethod
+    def has_optimized_grouped_topk(cls) -> bool:
+        """
+        Return whether the current platform has an optimized grouped_topk op.
+        """
+        return True
+
+    @classmethod
+    def grouped_topk(
+        cls,
+        hidden_states: torch.Tensor,
+        gating_output: torch.Tensor,
+        topk: int,
+        renormalize: bool,
+        num_expert_group: int = 0,
+        topk_group: int = 0,
+        scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
+        e_score_correction_bias: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
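+        """
+        Select the top-`topk` experts with group-limited routing: scores are
+        first reduced per expert group, only the top `topk_group` groups stay
+        eligible, and the final experts are chosen from those groups.
+        Returns (topk_weights, topk_ids) as float32 and int32 tensors.
+        """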
+
+        gating_output = gating_output.float()
+        if e_score_correction_bias is not None:
+            e_score_correction_bias = e_score_correction_bias.float()
+
+        if scoring_func == "softmax":
+            scores = torch.softmax(gating_output, dim=-1)
+        elif scoring_func == "sigmoid":
+            scores = gating_output.sigmoid()
+        else:
+            raise ValueError(f"Unsupported scoring function: {scoring_func}")
+
+        # For batch invariance, use sorted=True to ensure deterministic
+        # expert selection
+        from vllm.model_executor.layers.batch_invariant import (
+            vllm_is_batch_invariant)
+        use_sorted = vllm_is_batch_invariant()
+
+        num_token = scores.size(0)
+        if e_score_correction_bias is not None:
+            # Store original scores before applying correction bias. We use
+            # biased scores for expert selection but original scores for
+            # routing weights.
+            original_scores = scores
+            scores = scores + e_score_correction_bias.unsqueeze(0)
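+            # Per-group score: sum of the top-2 biased expert scores in
+            # each group.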
+            scores_tmp = scores.clone().reshape(num_token, num_expert_group, -1)
+            top1_val, top1_idx = torch.max(scores_tmp, dim=-1)
+            scores_tmp.scatter_(-1, top1_idx.unsqueeze(-1),
+                                torch.finfo(scores.dtype).min)
+            group_scores, top2_idx = torch.max(scores_tmp, dim=-1)
+            group_scores.add_(top1_val)
+        else:
+            group_scores = scores.view(num_token, num_expert_group,
+                                       -1).max(dim=-1).values  # [n, n_group]
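+        # For large batches, select the top-k groups by repeated argmax;
+        # smaller batches use a single topk call below.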
+        if num_token > 1024:
+            group_mask = torch.zeros_like(group_scores)
+            for i in range(topk_group):
+                _, group_idx = torch.max(group_scores, dim=-1)
+                group_mask.scatter_(1, group_idx.unsqueeze(-1), 1)
+                if i < topk_group - 1:
+                    group_scores.scatter_(1, group_idx.unsqueeze(-1),
+                                          torch.finfo(scores.dtype).min)
+        else:
+            group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
+                                   sorted=use_sorted)[1]  # [n, top_k_group]
+            group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+
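+        # Mask out experts in non-selected groups by adding dtype-min to
+        # their scores, then take the final top-k over all experts.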
+        tmp_scores = scores.reshape(num_token, num_expert_group, -1) + (
+            (1 - group_mask) * torch.finfo(scores.dtype).min).unsqueeze(-1)
+        tmp_scores = tmp_scores.reshape(num_token, -1)
+
+        if e_score_correction_bias is not None:
+            topk_ids = torch.topk(tmp_scores, k=topk, dim=-1,
+                                  sorted=use_sorted)[1]
+            # Use original unbiased scores for the routing weights
+            topk_weights = original_scores.gather(1, topk_ids)
+        else:
+            topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1,
+                                                sorted=use_sorted)
+
+        if renormalize:
+            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+        if routed_scaling_factor != 1.0:
+            topk_weights = topk_weights * routed_scaling_factor
+        return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
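
A minimal usage sketch of the new routing path, assuming a concrete `Platform` subclass (here called `CurrentPlatform`, a hypothetical name) and illustrative tensor sizes; none of these numbers come from the diff:

```python
import torch

# Illustrative sizes: 4 tokens, 64 experts in 8 groups; route each token to
# 6 experts drawn from its top 4 groups (DeepSeek-V3-style group routing).
num_tokens, num_experts, hidden = 4, 64, 1024
hidden_states = torch.randn(num_tokens, hidden)
gating_output = torch.randn(num_tokens, num_experts)
bias = torch.zeros(num_experts)  # e_score_correction_bias

# CurrentPlatform is a stand-in for whichever Platform subclass is active.
topk_weights, topk_ids = CurrentPlatform.grouped_topk(
    hidden_states,
    gating_output,
    topk=6,
    renormalize=True,
    num_expert_group=8,
    topk_group=4,
    scoring_func="sigmoid",
    e_score_correction_bias=bias,
)
assert topk_weights.shape == (num_tokens, 6)
assert topk_ids.dtype == torch.int32
```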