[Core] Cherry-pick from 0.7.1 to keep the main branch current (#127)
Cherry-pick the changes from the 0.7.1 release so that the main branch stays up to date. Signed-off-by: wangxiyuan <[email protected]>
1 parent 36991b2 · commit 5f46501 · 11 changed files with 1,137 additions and 354 deletions. Large diffs are not rendered by default, so only some of the changed files appear below.
@@ -0,0 +1,29 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import torch
from vllm.model_executor.layers.activation import SiluAndMul


def silu_and_mul_forward_oot(self, x: torch.Tensor) -> torch.Tensor:
    # Imported lazily so this module can be loaded on hosts without torch_npu.
    import torch_npu

    # npu_swiglu fuses the SiLU activation and the elementwise multiply
    # (SwiGLU) into a single NPU kernel.
    out = torch_npu.npu_swiglu(x)
    return out


SiluAndMul.forward_oot = silu_and_mul_forward_oot
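
For context, SiluAndMul takes the gate and up projections concatenated along the last dimension and returns silu(gate) * up at half the input width. A pure-PyTorch reference of the same computation, useful as a CPU cross-check for the fused NPU kernel (a sketch, not part of this commit):

import torch
import torch.nn.functional as F

def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # x packs [gate, up] along the last dim; the output halves the width.
    gate, up = x.chunk(2, dim=-1)
    return F.silu(gate) * up

x = torch.randn(4, 2 * 128)
assert silu_and_mul_reference(x).shape == (4, 128)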
@@ -0,0 +1,176 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Callable, Optional

import torch
import torch_npu
from vllm.model_executor.layers.fused_moe.layer import \
    UnquantizedFusedMoEMethod


def group_topk(hidden_states: torch.Tensor,
               gating_output: torch.Tensor,
               topk: int,
               renormalize: bool,
               num_expert_group: Optional[int] = 0,
               topk_group: Optional[int] = 0,
               scoring_func: str = "softmax",
               e_score_correction_bias: Optional[torch.Tensor] = None):

    assert hidden_states.shape[0] == gating_output.shape[0], (
        "Number of tokens mismatch")

    if scoring_func == "softmax":
        scores = torch.softmax(gating_output, dim=-1)
    elif scoring_func == "sigmoid":
        scores = gating_output.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    if e_score_correction_bias is not None:
        # Store original scores before applying correction bias. We use biased
        # scores for expert selection but original scores for routing weights.
        original_scores = scores
        scores = scores + e_score_correction_bias.unsqueeze(0)

    # Zero out the scores of experts outside the top `topk_group` expert
    # groups, so the top-k below can only select experts inside those groups.
    torch_npu.npu_group_topk(input=scores,
                             out=scores,
                             group_num=num_expert_group,
                             k=topk_group)
    if e_score_correction_bias is not None:
        topk_ids = torch.topk(scores, k=topk, dim=-1, sorted=False)[1]
        # Use original unbiased scores for the routing weights.
        topk_weights = original_scores.gather(1, topk_ids)
    else:
        topk_weights, topk_ids = torch.topk(scores,
                                            k=topk,
                                            dim=-1,
                                            sorted=False)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
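
npu_group_topk implements group-limited routing: the experts are partitioned into num_expert_group groups, only the top topk_group groups survive, and scores outside them are zeroed so the later top-k cannot pick their experts. A pure-PyTorch sketch of that masking step, assuming the common DeepSeek-style convention of scoring each group by its best expert (an illustration, not the kernel's actual implementation):

import torch

def group_topk_mask_reference(scores: torch.Tensor, group_num: int,
                              k: int) -> torch.Tensor:
    # scores: (tokens, experts), with experts split into `group_num` groups.
    tokens, experts = scores.shape
    grouped = scores.view(tokens, group_num, experts // group_num)
    # Score each group by its best expert and keep the top-k groups.
    top_groups = grouped.max(dim=-1).values.topk(k, dim=-1).indices
    mask = torch.zeros(tokens, group_num, dtype=scores.dtype)
    mask.scatter_(1, top_groups, 1.0)
    # Zero out scores of experts outside the selected groups.
    return (grouped * mask.unsqueeze(-1)).view(tokens, experts)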

def fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor,
                  w2: torch.Tensor, topk_weights: torch.Tensor,
                  topk_ids: torch.Tensor, top_k: int):
    # Check constraints.
    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype in [
        torch.float32, torch.float16, torch.bfloat16
    ]
    ori_shape = hidden_states.shape
    if len(ori_shape) == 3:
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

    num_tokens, _ = hidden_states.shape
    E, N, _ = w1.shape

    # Expand each token top_k times and permute the copies so that tokens
    # routed to the same expert are contiguous for the grouped matmuls below.
    row_idx_len = num_tokens * top_k
    row_idx = torch.arange(0,
                           row_idx_len,
                           dtype=torch.int32,
                           device=topk_weights.device).view(top_k, -1).permute(
                               1, 0).contiguous()
    expanded_x, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
        hidden_states,
        row_idx=row_idx,
        expert_idx=topk_ids,
        active_num=num_tokens)

    # Per-expert token boundaries, used as group_list by the grouped matmuls.
    expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
        expanded_expert_idx, E)
    expert_tokens = expert_tokens.to(torch.int64)

    # Gate/up projection for all experts in a single grouped matmul.
    w1 = w1.transpose(1, 2)
    gate_up_out_list = torch_npu.npu_grouped_matmul(x=[expanded_x],
                                                    weight=[w1],
                                                    split_item=2,
                                                    group_list_type=0,
                                                    group_type=0,
                                                    group_list=expert_tokens)

    # TODO: Remove this in the future.
    gate_up_out = torch.cat(gate_up_out_list, dim=0)
    gate_up_out = torch_npu.npu_swiglu(gate_up_out)

    # Down projection back to the hidden size.
    w2 = w2.transpose(1, 2)
    down_out_list = torch_npu.npu_grouped_matmul(x=[gate_up_out],
                                                 weight=[w2],
                                                 split_item=2,
                                                 group_list_type=0,
                                                 group_type=0,
                                                 group_list=expert_tokens)

    down_out_list = torch.cat(down_out_list, dim=0)
    # TODO: Reorder device memory 2 times here, replace the current
    # implementation here when suitable operators become available.
    routing_weights = topk_weights.to(down_out_list.dtype)
    # Scatter the expert outputs back to token order and combine them
    # with the routing weights.
    hidden_states = torch_npu.npu_moe_finalize_routing(
        down_out_list,
        skip1=None,
        skip2=None,
        bias=None,
        scales=routing_weights,
        expanded_src_to_dst_row=expanded_row_idx,
        export_for_source_row=topk_ids)
    if len(ori_shape) == 3:
        hidden_states = hidden_states.view(ori_shape)
    return hidden_states
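
Semantically, the grouped-matmul pipeline above is equivalent to a naive loop over each token's selected experts. A minimal pure-PyTorch reference, assuming the weight layout vLLM uses for unquantized MoE layers (w1 as (E, 2*I, H) with gate and up stacked, w2 as (E, H, I)); it can serve as a small-shape CPU cross-check:

import torch
import torch.nn.functional as F

def fused_experts_reference(x: torch.Tensor, w1: torch.Tensor,
                            w2: torch.Tensor, topk_weights: torch.Tensor,
                            topk_ids: torch.Tensor) -> torch.Tensor:
    # x: (T, H); w1: (E, 2*I, H); w2: (E, H, I); topk_*: (T, K).
    out = torch.zeros_like(x)
    num_tokens, top_k = topk_ids.shape
    for t in range(num_tokens):
        for k in range(top_k):
            e = int(topk_ids[t, k])
            gate, up = (w1[e] @ x[t]).chunk(2)   # gate/up: (I,)
            act = F.silu(gate) * up              # SwiGLU
            out[t] += topk_weights[t, k] * (w2[e] @ act)
    return out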

def forward_oot(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    use_grouped_topk: bool,
    top_k: int,
    router_logits: torch.Tensor,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None
) -> torch.Tensor:

    topk_weights, topk_ids = group_topk(
        hidden_states=x,
        gating_output=router_logits,
        topk=top_k,
        renormalize=renormalize,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias)

    return fused_experts(hidden_states=x,
                         w1=layer.w13_weight,
                         w2=layer.w2_weight,
                         topk_weights=topk_weights,
                         topk_ids=topk_ids,
                         top_k=top_k)


UnquantizedFusedMoEMethod.forward_oot = forward_oot
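
Each file in this commit uses the same out-of-tree hook: a plain function is assigned onto the vLLM layer class as forward_oot, and vLLM's CustomOp dispatch routes calls to it on platforms without a native kernel. A simplified illustration of that dispatch (the real logic lives in vLLM's CustomOp base class and varies by version; this is a sketch, not vLLM's actual code):

import torch

class CustomOpSketch(torch.nn.Module):
    # Illustrative only: vLLM selects a backend-specific forward at
    # dispatch time. On an out-of-tree platform such as the Ascend NPU
    # it falls through to forward_oot, which the modules above install
    # by assigning onto the class.
    def forward(self, *args, **kwargs):
        if hasattr(self, "forward_oot"):
            return self.forward_oot(*args, **kwargs)
        return self.forward_native(*args, **kwargs)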
@@ -0,0 +1,56 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Optional, Tuple

import torch
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding


def rope_forward_oot(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Imported lazily so this module can be loaded on hosts without torch_npu.
    import torch_npu

    # Keep the precomputed cos/sin cache on the same device and dtype as
    # the query before handing it to the NPU kernel.
    if self.cos_sin_cache.device != query.device:
        self.cos_sin_cache = self.cos_sin_cache.to(query.device)
    if self.cos_sin_cache.dtype != query.dtype:
        self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
    if offsets is not None:
        raise NotImplementedError(
            "Batched rotary embedding is currently not supported on NPU.")
    else:
        # TODO: Remove the contiguous in the future.
        query = query.contiguous()
        key = key.contiguous()
        # npu_rope applies rotary position embedding to query and key;
        # the results are written in place, so no return value is captured.
        torch_npu.npu_rope(
            positions,
            query,
            key,
            self.head_size,
            self.cos_sin_cache,
            self.is_neox_style,
        )

    return query, key


RotaryEmbedding.forward_oot = rope_forward_oot
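
For intuition, neox-style rotary embedding splits each head vector into two halves and rotates each pair (x1, x2) by a position-dependent angle taken from the precomputed cos/sin cache. A minimal pure-PyTorch sketch of that rotation (an illustration of the math, not the npu_rope kernel's implementation):

import torch

def rope_neox_reference(x: torch.Tensor, cos: torch.Tensor,
                        sin: torch.Tensor) -> torch.Tensor:
    # x: (..., head_size); cos/sin: (..., head_size // 2), looked up from
    # the position-indexed cache.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)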