Commit
feat: remove redundant code, improving performance
Signed-off-by: Yaphets24 <[email protected]>
Yaphets24 committed Feb 20, 2025
1 parent 8e30a57 commit 8315b04
Showing 3 changed files with 11 additions and 17 deletions.
22 changes: 8 additions & 14 deletions vllm_ascend/attention.py
@@ -695,23 +695,17 @@ def forward(
                                  self.qk_head_dim)
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
                                dim=-1)
-        if attn_metadata.num_prefills > 0:
-            np_positions = np.concatenate([np.arange(i) for i in attn_metadata.prefill_metadata.seq_lens])
-            positions = torch.tensor(np_positions, device=hidden_states_or_q_c.device)
-        else:
-            np_positions = np.array(attn_metadata.decode_metadata.seq_lens) - 1
-            positions = torch.tensor(np_positions, device=hidden_states_or_q_c.device)
         k_pe = k_pe.view(num_tokens, self.num_kv_heads, -1)

         if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
             ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
             q_pe = q_pe.reshape(num_tokens, -1)
             k_pe = k_pe.reshape(num_tokens, -1)
-            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+            q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe, k_pe)
             q_pe = q_pe.view(ori_q_pe_shape)
             k_pe = k_pe.view(ori_k_pe_shape)
         else:
-            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+            q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe, k_pe)

         if self.w_kc == None or self.w_vc == None:
             kv_b_proj_weight = self.kv_b_proj.weight.reshape(self.num_heads,
@@ -725,13 +719,13 @@ def forward(
             kv = self.kv_b_proj(kv_c_normed)[0].view(num_tokens, kv_heads_num, -1)
             k_nope, value = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
             k_cache = torch.cat([kv_c_normed.view(num_tokens, self.num_kv_heads, -1), k_pe], dim=2)
-            k_pe = k_pe.repeat(1, self.num_heads, 1)
+            k_pe = k_pe.expand(-1, self.num_heads, -1)
             key = torch.cat([k_nope.view(num_tokens, kv_heads_num, -1), k_pe], dim=2)
         else:
             kv_heads_num = self.num_kv_heads
-            q_nope_t = torch_npu.npu_transpose(q_nope, (1, 0, 2), require_contiguous=True)
+            q_nope_t = torch.transpose(q_nope, 0,1)
             q_nope_out = torch.bmm(q_nope_t, self.w_kc)
-            q_nope = torch_npu.npu_transpose(q_nope_out, (1, 0, 2), require_contiguous=True)
+            q_nope = torch.transpose(q_nope_out,0,1)
             k_cache = torch.cat([kv_c_normed.view(num_tokens, self.num_kv_heads, -1), k_pe], dim=2)

         query = torch.cat([q_nope, q_pe], dim=-1).view(num_tokens, self.num_heads, -1)
@@ -795,10 +789,10 @@ def forward(
                 batchRunStatusEnable=False, hasQuantOffset=False,
                 compressType=0, calcType=0, scaleType=0, quantType=0,
                 inputLayout=0, outDataType=-1, attnOut=attn_output)
-            attn_output_t = torch_npu.npu_transpose(attn_output, (1, 0, 2), require_contiguous=True)
+            attn_output_t = torch.transpose(attn_output, 0,1)
             attn_output_t = torch.bmm(attn_output_t, self.w_vc)
-            attn_output = torch_npu.npu_transpose(attn_output_t, (1, 0, 2), require_contiguous=True)
+            attn_output = torch.transpose(attn_output_t, 0,1)

-        output, _ = self.o_proj(attn_output.view(num_tokens, -1))
+        output, _ = self.o_proj(attn_output.reshape(num_tokens, -1))

         return output
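The attention change reuses the positions the model runner already computed (attn_metadata.input_positions) instead of rebuilding them every step from seq_lens with NumPy and copying them to the device, and it swaps copying ops for view-based ones: k_pe.repeat becomes k_pe.expand, torch_npu.npu_transpose(..., require_contiguous=True) becomes torch.transpose, and o_proj now takes reshape instead of view so the possibly non-contiguous result is still accepted. A minimal plain-PyTorch sketch of why the view-based ops are cheaper (illustrative shapes, not the vLLM-Ascend code):

import torch

# Illustrative shapes only.
num_tokens, num_heads, rope_dim, nope_dim, v_dim = 4, 8, 64, 128, 512
k_pe = torch.randn(num_tokens, 1, rope_dim)

# repeat() materializes num_heads copies of the rope keys in new storage ...
k_pe_repeated = k_pe.repeat(1, num_heads, 1)
# ... while expand() broadcasts a stride-0 view over the original buffer.
k_pe_expanded = k_pe.expand(-1, num_heads, -1)
assert torch.equal(k_pe_repeated, k_pe_expanded)
assert k_pe_expanded.data_ptr() == k_pe.data_ptr()  # no copy was made

# torch.transpose also returns a view; nothing is forced contiguous up front.
q_nope = torch.randn(num_tokens, num_heads, nope_dim)
w_kc = torch.randn(num_heads, nope_dim, v_dim)
q_nope_t = torch.transpose(q_nope, 0, 1)  # (num_heads, num_tokens, nope_dim) view
q_nope_out = torch.transpose(torch.bmm(q_nope_t, w_kc), 0, 1)

# The transposed result is not contiguous, so view() would raise here;
# reshape() copies only when it has to.
flat = q_nope_out.reshape(num_tokens, -1)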
1 change: 1 addition & 0 deletions vllm_ascend/model_runner.py
@@ -1136,6 +1136,7 @@ def execute_model(
         if not bypass_model_exec:
             with set_forward_context(model_input.attn_metadata,
                                      self.vllm_config, virtual_engine):
+                model_input.attn_metadata.input_positions = model_input.input_positions
                 hidden_or_intermediate_states = model_executable(
                     input_ids=model_input.input_tokens,
                     positions=model_input.input_positions,
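This one-line addition is the producer side of the attention change above: the runner already holds the flattened per-token positions, so it stashes them on the attention metadata it passes into the forward context, and the MLA layer reads them back instead of reconstructing them. A stripped-down sketch of the pattern (hypothetical class and function names, not the real model runner):

from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class AttnMetadata:
    # Hypothetical stand-in for the real attention metadata object.
    seq_lens: list
    input_positions: Optional[torch.Tensor] = None

def execute_model_step(attn_metadata: AttnMetadata,
                       input_positions: torch.Tensor) -> torch.Tensor:
    # Producer: attach the positions the runner already computed.
    attn_metadata.input_positions = input_positions
    return attention_forward(attn_metadata)

def attention_forward(attn_metadata: AttnMetadata) -> torch.Tensor:
    # Consumer: reuse the precomputed positions; no NumPy rebuild and no
    # extra host-to-device copy per step.
    return attn_metadata.input_positions

positions = torch.arange(6)
meta = AttnMetadata(seq_lens=[4, 2])
assert torch.equal(execute_model_step(meta, positions), positions)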
5 changes: 2 additions & 3 deletions vllm_ascend/ops/fused_moe.py
Expand Up @@ -60,7 +60,7 @@ def group_topk(hidden_states: torch.Tensor,
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
return topk_weights, topk_ids.to(torch.int32)

def fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
@@ -121,12 +121,11 @@ def fused_experts(hidden_states: torch.Tensor,
     down_out_list = torch.cat(down_out_list, dim=0)
     # TODO: Reorder device memory 2 times here, replace the current
     # implementation here when suitable operators become available.
-    routing_weights = topk_weights.to(down_out_list.dtype)
     hidden_states = torch_npu.npu_moe_finalize_routing(
         down_out_list,
         skip1=None, skip2=None,
         bias=None,
-        scales=routing_weights,
+        scales=topk_weights,
         expanded_src_to_dst_row=expanded_row_idx,
         export_for_source_row=topk_ids
     )
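The MoE change removes a dtype round trip: group_topk no longer upcasts the routing weights to float32, so fused_experts no longer needs to cast them back to the activation dtype before npu_moe_finalize_routing; the weights flow through in their original dtype. A small sketch of the round trip that is dropped (plain PyTorch tensors, illustrative only):

import torch

# Routing weights as they come out of softmax, already in the activation dtype.
topk_weights = torch.softmax(torch.randn(16, 8, dtype=torch.float16), dim=-1)

# Old flow (sketch): upcast in group_topk, downcast again in fused_experts.
w_fp32 = topk_weights.to(torch.float32)   # extra tensor
w_back = w_fp32.to(torch.float16)         # another extra tensor

# New flow (sketch): pass topk_weights through unchanged.
assert w_back.dtype == topk_weights.dtype
assert torch.equal(w_back, topk_weights)  # the round trip was a no-op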
