|
2 | 2 |
|
3 | 3 | import torch |
4 | 4 | import vllm |
| 5 | +from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant |
| 6 | +from vllm.model_executor.layers.fused_moe.fused_moe import GroupedTopk |
5 | 7 | from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, UnquantizedFusedMoEMethod) |
6 | 8 | from vllm_gaudi.extension.ops import (VllmMixtureOfExpertsOp) |
7 | 9 |
|
8 | 10 |
|
| 11 | +@GroupedTopk.register_oot |
| 12 | +class HPUGroupedTopk(GroupedTopk): |
| 13 | + """GroupedTopk used by the Deepseek-V2 and Deepseek-V3 model.""" |
| 14 | + |
| 15 | + def forward_oot( |
| 16 | + self, |
| 17 | + hidden_states: torch.Tensor, |
| 18 | + gating_output: torch.Tensor, |
| 19 | + e_score_correction_bias: torch.Tensor | None = None, |
| 20 | + ) -> tuple[torch.Tensor, torch.Tensor]: |
| 21 | + |
| 22 | + gating_output = gating_output.float() |
| 23 | + if e_score_correction_bias is not None: |
| 24 | + e_score_correction_bias = e_score_correction_bias.float() |
| 25 | + |
| 26 | + if self.scoring_func == "softmax": |
| 27 | + scores = torch.softmax(gating_output, dim=-1) |
| 28 | + elif self.scoring_func == "sigmoid": |
| 29 | + scores = gating_output.sigmoid() |
| 30 | + else: |
| 31 | + raise ValueError(f"Unsupported scoring function: {self.scoring_func}") |
| 32 | + |
| 33 | + # For batch invariance, use sorted=True to ensure deterministic expert selection |
| 34 | + use_sorted = vllm_is_batch_invariant() |
| 35 | + |
| 36 | + num_token = scores.size(0) |
| 37 | + if e_score_correction_bias is not None: |
| 38 | + # Store original scores before applying correction bias. We use biased |
| 39 | + # scores for expert selection but original scores for routing weights |
| 40 | + original_scores = scores |
| 41 | + scores = scores + e_score_correction_bias.unsqueeze(0) |
| 42 | + scores_tmp = scores.clone().reshape(num_token, self.num_expert_group, -1) |
| 43 | + top1_val, top1_idx = torch.max(scores_tmp, dim=-1) |
| 44 | + scores_tmp.scatter_(-1, top1_idx.unsqueeze(-1), torch.finfo(scores.dtype).min) |
| 45 | + group_scores, top2_idx = torch.max(scores_tmp, dim=-1) |
| 46 | + group_scores.add_(top1_val) |
| 47 | + else: |
| 48 | + group_scores = (scores.view(num_token, self.num_expert_group, -1).max(dim=-1).values) # [n, n_group] |
| 49 | + if num_token > 1024: |
| 50 | + group_mask = torch.zeros_like(group_scores) |
| 51 | + for i in range(self.topk_group): |
| 52 | + _, group_idx = torch.max(group_scores, dim=-1) |
| 53 | + group_mask.scatter_(1, group_idx.unsqueeze(-1), 1) |
| 54 | + if i < self.topk_group - 1: |
| 55 | + group_scores.scatter_(1, group_idx.unsqueeze(-1), torch.finfo(scores.dtype).min) |
| 56 | + else: |
| 57 | + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=use_sorted)[1] # [n, top_k_group] |
| 58 | + group_mask = torch.zeros_like(group_scores) # [n, n_group] |
| 59 | + group_mask.scatter_(1, group_idx, 1) # [n, n_group] |
| 60 | + |
| 61 | + tmp_scores = scores.reshape(num_token, self.num_expert_group, -1) + ( |
| 62 | + (1 - group_mask) * torch.finfo(scores.dtype).min).unsqueeze(-1) |
| 63 | + tmp_scores = tmp_scores.reshape(num_token, -1) |
| 64 | + |
| 65 | + if e_score_correction_bias is not None: |
| 66 | + topk_ids = torch.topk(tmp_scores, k=self.topk, dim=-1, sorted=use_sorted)[1] |
| 67 | + # Use original unbiased scores for the routing weights |
| 68 | + topk_weights = original_scores.gather(1, topk_ids) |
| 69 | + else: |
| 70 | + topk_weights, topk_ids = torch.topk(tmp_scores, k=self.topk, dim=-1, sorted=use_sorted) |
| 71 | + |
| 72 | + if self.renormalize: |
| 73 | + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) |
| 74 | + |
| 75 | + if self.routed_scaling_factor != 1.0: |
| 76 | + topk_weights = topk_weights * self.routed_scaling_factor |
| 77 | + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) |
| 78 | + |
| 79 | + |
9 | 80 | @UnquantizedFusedMoEMethod.register_oot |
10 | 81 | class HPUUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): |
11 | 82 | """MoE method without quantization.""" |
|
0 commit comments