diff --git a/lmdeploy/pytorch/engine/devices/ascend.py b/lmdeploy/pytorch/engine/devices/ascend.py index a09fa5f655..acf77dd274 100644 --- a/lmdeploy/pytorch/engine/devices/ascend.py +++ b/lmdeploy/pytorch/engine/devices/ascend.py @@ -17,7 +17,7 @@ def update_step_context(cls, step_context): single_attention_mask = torch.logical_not( torch.tril( torch.ones(step_context.q_seq_length[i], - step_context.kv_seq_length[i], + (step_context.kv_seq_length[i] + 31) & (~31), dtype=torch.bool).cuda(), diagonal=step_context.kv_seq_length[i] - step_context.q_seq_length[i], @@ -28,7 +28,7 @@ def update_step_context(cls, step_context): block_loc = step_context.block_offsets[i][block_idx] token_loc = history_length % block_size for _ in range(step_context.q_seq_length[i]): - kv_start_indices.append(block_loc * block_size + token_loc) + kv_start_indices.append([block_loc * block_size + token_loc]) if _ == step_context.q_seq_length[i] - 1: break token_loc = (token_loc + 1) % block_size diff --git a/lmdeploy/pytorch/kernels/ascend/__init__.py b/lmdeploy/pytorch/kernels/ascend/__init__.py index bd207a1ecb..c93ac2be0e 100644 --- a/lmdeploy/pytorch/kernels/ascend/__init__.py +++ b/lmdeploy/pytorch/kernels/ascend/__init__.py @@ -1,6 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. -from ..dipu import (apply_rotary_pos_emb, fill_kv_cache, fused_rotary_emb, - multinomial_sampling, paged_attention_fwd, rms_norm) +from ..default import multinomial_sampling +from .apply_rotary_pos_emb import apply_rotary_pos_emb +from .fill_kv_cache import fill_kv_cache +from .fused_rotary_emb import fused_rotary_emb +from .moe_gating_topk_softmax import moe_gating_topk_softmax +from .paged_attention_fwd import paged_attention_fwd +from .rms_norm import rms_norm __all__ = [ 'rms_norm', @@ -8,5 +13,6 @@ 'fused_rotary_emb', 'fill_kv_cache', 'paged_attention_fwd', + 'moe_gating_topk_softmax', 'multinomial_sampling', ] diff --git a/lmdeploy/pytorch/kernels/ascend/apply_rotary_pos_emb.py b/lmdeploy/pytorch/kernels/ascend/apply_rotary_pos_emb.py new file mode 100644 index 0000000000..ec5f669feb --- /dev/null +++ b/lmdeploy/pytorch/kernels/ascend/apply_rotary_pos_emb.py @@ -0,0 +1,39 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import infer_ext.ops as ext_ops +from torch import Tensor + + +def apply_rotary_pos_emb( + query_states: Tensor, + key_states: Tensor, + cos: Tensor, + sin: Tensor, + position_ids: Tensor, + position_ids_1d: Tensor, + q_embed=None, + k_embed=None, + context=None, +): + bs, head, dim = query_states.shape + num_kv_heads = key_states.shape[1] + query_states_reshaped = query_states.reshape(1, bs, head, dim) + key_states_reshaped = key_states.reshape(1, bs, num_kv_heads, dim) + if not (hasattr(context, 'cos') or hasattr(context, 'sin')): + cos = cos[position_ids_1d].view(1, bs, 1, -1) + sin = sin[position_ids_1d].view(1, bs, 1, -1) + if context: + setattr(context, 'cos', cos) + setattr(context, 'sin', sin) + cached_cos = context.cos if context else cos + cached_sin = context.sin if context else sin + ext_ops.apply_rotary_pos_emb(query_states_reshaped, key_states_reshaped, + cached_cos, cached_sin, None, None, None) + if q_embed is None: + q_embed = query_states + else: + q_embed.copy_(query_states) + if k_embed is None: + k_embed = key_states + else: + k_embed.copy_(key_states) + return q_embed, k_embed diff --git a/lmdeploy/pytorch/kernels/ascend/fill_kv_cache.py b/lmdeploy/pytorch/kernels/ascend/fill_kv_cache.py new file mode 100644 index 0000000000..3e495a5081 --- /dev/null +++ b/lmdeploy/pytorch/kernels/ascend/fill_kv_cache.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import infer_ext.ops as ext_ops +from torch import Tensor + + +def fill_kv_cache( + key_states: Tensor, + value_states: Tensor, + key_caches: Tensor, + value_caches: Tensor, + q_start_loc: Tensor, + q_seq_length: Tensor, + kv_seq_length: Tensor, + max_q_seq_length: int, + block_offsets: Tensor, + context: None, +): + """fill key/value state to cache for paged attention.""" + ext_ops.fill_kv_cache(key_states, value_states, key_caches, value_caches, + context.kv_start_indices) diff --git a/lmdeploy/pytorch/kernels/ascend/fused_rotary_emb.py b/lmdeploy/pytorch/kernels/ascend/fused_rotary_emb.py new file mode 100644 index 0000000000..01346bfb58 --- /dev/null +++ b/lmdeploy/pytorch/kernels/ascend/fused_rotary_emb.py @@ -0,0 +1,45 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import infer_ext.ops as ext_ops +import torch +from torch import Tensor + + +def fused_rotary_emb( + query_states: Tensor, + key_states: Tensor, + position_ids: torch.LongTensor, + inv_freq: Tensor, + scaling_factor: float, + out_q: Tensor = None, + out_k: Tensor = None, + context=None, +): + batch, seqlen, head, dim = query_states.shape + num_kv_heads = key_states.shape[-2] + query_states_reshaped = query_states.view(batch, seqlen, head, dim) + key_states_reshaped = key_states.view(batch, seqlen, num_kv_heads, dim) + position_ids = position_ids.squeeze(0).unsqueeze(-1) + pos_freq = position_ids / scaling_factor * inv_freq + if not (hasattr(context, 'cos') or hasattr(context, 'sin')): + cos = (torch.cos(pos_freq).view(batch, seqlen, 1, + -1).repeat(1, 1, 1, + 2).to(query_states.dtype)) + sin = (torch.sin(pos_freq).view(batch, seqlen, 1, + -1).repeat(1, 1, 1, + 2).to(query_states.dtype)) + if context: + setattr(context, 'cos', cos) + setattr(context, 'sin', sin) + cached_cos = context.cos if context else cos + cached_sin = context.sin if context else sin + ext_ops.apply_rotary_pos_emb(query_states_reshaped, key_states_reshaped, + cached_cos, cached_sin, None, None, None) + if out_q is None: + out_q = query_states + else: + out_q.copy_(query_states) + if out_k is None: + out_k = key_states + else: + out_k.copy_(key_states) + return out_q, out_k diff --git a/lmdeploy/pytorch/kernels/ascend/moe_gating_topk_softmax.py b/lmdeploy/pytorch/kernels/ascend/moe_gating_topk_softmax.py new file mode 100644 index 0000000000..4d9ec312c7 --- /dev/null +++ b/lmdeploy/pytorch/kernels/ascend/moe_gating_topk_softmax.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import infer_ext.ops as ext_ops +import torch +from torch import Tensor + + +def moe_gating_topk_softmax(router_logits: Tensor, topk: int): + routing_weights, selected_experts = ext_ops.moe_gating_topk_softmax( + router_logits, topk) + return routing_weights.to(torch.float32), selected_experts.to(torch.int64) diff --git a/lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py b/lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py new file mode 100644 index 0000000000..40c47d7036 --- /dev/null +++ b/lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py @@ -0,0 +1,120 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import infer_ext.ops as ext_ops +import torch +from torch import Tensor + + +def flash_context_attention( + query_states: Tensor, + key_states: Tensor, + value_states: Tensor, + attn_output: Tensor, + key_cache: Tensor, + value_cache: Tensor, + block_offsets: Tensor, + q_start_loc: Tensor, + q_seq_len: Tensor, + kv_seq_len: Tensor, + block_size: int, + kv_cache_len: int, + context=None, +): + num_q_heads, dim = query_states.shape[1:3] + num_kv_heads = value_states.shape[1] + batch = q_start_loc.shape[0] + + for i in range(batch): + if torch.equal(q_seq_len[i], kv_seq_len[i]): + ext_ops.context_attention( + query_states, + key_states, + value_states, + q_start_loc[i:i + 1], + q_seq_len[i:i + 1], + num_q_heads, + num_kv_heads, + attn_mask=context.attention_mask[i:i + 1], + attn_output=attn_output, + ) + else: + key_cache = key_cache.reshape(1, kv_cache_len, num_kv_heads * dim) + value_cache = value_cache.reshape(1, kv_cache_len, + num_kv_heads * dim) + ext_ops.paged_prefill_attention( + query_states, + key_cache, + value_cache, + block_offsets, + block_size, + q_start_loc[i:i + 1], + q_seq_len[i:i + 1], + kv_seq_len[i:i + 1], + num_q_heads, + num_kv_heads, + attn_mask=context.attention_mask[i:i + 1], + attn_output=attn_output, + ) + + +def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len, + block_offsets, block_size): + num_kv_heads, num_q_heads = k_cache.shape[1], q.shape[1] + ext_ops.paged_decode_attention( + q, + k_cache, + v_cache, + block_offsets, + block_size, + kv_seq_len, + num_q_heads, + num_kv_heads, + attn_output=attn_output.view(q.shape), + ) + + +def paged_attention_fwd( + query_states: Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + key_cache: Tensor, + value_cache: Tensor, + attn_output: Tensor, + block_offsets: Tensor, + q_start_loc: Tensor, + q_seqlens: Tensor, + kv_seqlens: Tensor, + max_seqlen: int, + window_size: int = 1, + context=None, +): + is_decoding = query_states.shape[-3] == q_seqlens.size(0) + block_num, block_size, head, dim = key_cache.size() + kv_cache_len = block_num * block_size + k = key_cache.reshape(block_num * block_size, head, dim) + v = value_cache.reshape(block_num * block_size, head, dim) + if not is_decoding: + flash_context_attention( + query_states, + key_states, + value_states, + attn_output, + k, + v, + block_offsets, + q_start_loc, + q_seqlens, + kv_seqlens, + block_size, + kv_cache_len, + context=context, + ) + else: + paged_token_attention( + query_states, + k, + v, + attn_output, + kv_seqlens, + block_offsets, + block_size, + ) diff --git a/lmdeploy/pytorch/kernels/ascend/rms_norm.py b/lmdeploy/pytorch/kernels/ascend/rms_norm.py new file mode 100644 index 0000000000..0abda45f0b --- /dev/null +++ b/lmdeploy/pytorch/kernels/ascend/rms_norm.py @@ -0,0 +1,7 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import infer_ext.ops as ext_ops +from torch import Tensor + + +def rms_norm(hidden_states: Tensor, weight: Tensor, epsilon: float = 1e-6): + return ext_ops.rms_norm(hidden_states, weight, epsilon) diff --git a/lmdeploy/pytorch/kernels/moe_gating_topk_softmax.py b/lmdeploy/pytorch/kernels/moe_gating_topk_softmax.py new file mode 100644 index 0000000000..b8a55d4225 --- /dev/null +++ b/lmdeploy/pytorch/kernels/moe_gating_topk_softmax.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .dispatcher import FunctionDispatcher + +moe_gating_topk_softmax = FunctionDispatcher( + 'moe_gating_topk_softmax').make_caller() diff --git a/lmdeploy/pytorch/models/internlm2.py b/lmdeploy/pytorch/models/internlm2.py index 68f4ea7f26..8bb77f84e3 100644 --- a/lmdeploy/pytorch/models/internlm2.py +++ b/lmdeploy/pytorch/models/internlm2.py @@ -233,7 +233,7 @@ def __qkv_proj(hidden_states): ) query_states = qkv_states[..., :self.num_key_value_groups, :] query_states = query_states.flatten(1, 2) - key_states = qkv_states[..., -2, :] + key_states = qkv_states[..., -2, :].contiguous() value_states = qkv_states[..., -1, :] return query_states, key_states, value_states diff --git a/lmdeploy/pytorch/models/mixtral.py b/lmdeploy/pytorch/models/mixtral.py index b1be9b81df..a8d9f25e44 100644 --- a/lmdeploy/pytorch/models/mixtral.py +++ b/lmdeploy/pytorch/models/mixtral.py @@ -8,6 +8,7 @@ from ..kernels import apply_rotary_pos_emb, fill_kv_cache, paged_attention_fwd from ..kernels.fused_moe import fused_moe +from ..kernels.moe_gating_topk_softmax import moe_gating_topk_softmax from ..weight_loader.dist_utils import (colwise_parallelize_linear, rowwise_parallelize_linear) @@ -156,6 +157,158 @@ def forward( ) +class PatchedMixtralAttentionAscend(nn.Module): + """Rewrite module of MixtralAttention.""" + + def _load_weights(self, loader, rank: int, world_size: int, + device: torch.device): + """load weights.""" + for mod_name in ['q_proj', 'k_proj', 'v_proj']: + colwise_parallelize_linear(getattr(self, mod_name), + loader, + rank=rank, + world_size=world_size, + prefix=mod_name) + rowwise_parallelize_linear(self.o_proj, + loader, + rank=rank, + world_size=world_size, + prefix='o_proj') + + @classmethod + def _distribute_output_fn(cls, outputs, **kwargs): + """Distribution output hook.""" + try: + dist.all_reduce(outputs[0]) + except Exception as e: + print(e) + return outputs + + def _contiguous_batching_forward_impl( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + attention_mask: Optional[torch.Tensor] = None, + world_size: int = 1, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + """default rewrite.""" + + context = self.context.context + kv_seq_length = context.kv_seq_length + q_seq_length = context.q_seq_length + q_start_loc = context.q_start_loc + block_offsets = context.block_offsets + max_q_seq_length = context.max_q_seq_length + max_kv_seq_length = context.max_kv_seq_length + + num_heads = self.num_heads // world_size + num_kv_heads = self.num_key_value_heads // world_size + hidden_size = num_heads * self.head_dim + + def __qkv_proj(hidden_states): + """qkv proj.""" + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + return query_states, key_states, value_states + + def __rotary_emb_fn(query_states, key_states, value_states): + if hasattr(self, 'rotary_emb'): + if not hasattr(context, '_cos'): + cos, sin = self.rotary_emb(value_states, + seq_len=max_kv_seq_length) + context._cos = cos + context._sin = sin + else: + cos = context._cos + sin = context._sin + query_states, key_states = apply_rotary_pos_emb( + query_states, + key_states, + cos, + sin, + position_ids, + context.position_ids_1d, + q_embed=query_states, + k_embed=key_states, + context=context) + return query_states, key_states, value_states + + query_states, key_states, 
value_states = __qkv_proj(hidden_states) + + query_states = query_states.view(-1, num_heads, self.head_dim) + key_states = key_states.view(-1, num_kv_heads, self.head_dim) + value_states = value_states.view(-1, num_kv_heads, self.head_dim) + + query_states, key_states, value_states = __rotary_emb_fn( + query_states, key_states, value_states) + # fill kv cache + fill_kv_cache( + key_states, + value_states, + past_key_value[0], + past_key_value[1], + q_start_loc, + q_seq_length, + kv_seq_length=kv_seq_length, + max_q_seq_length=max_q_seq_length, + block_offsets=block_offsets, + context=context, + ) + # page attention + attn_output = query_states + window_size = self.config.sliding_window or -1 + paged_attention_fwd( + query_states, + key_states, + value_states, + past_key_value[0], + past_key_value[1], + attn_output, + block_offsets, + q_start_loc=q_start_loc, + q_seqlens=q_seq_length, + kv_seqlens=kv_seq_length, + max_seqlen=max_q_seq_length, + window_size=window_size, + context=context, + ) + + attn_output = attn_output.reshape(*hidden_states.shape[:-1], + hidden_size) + + attn_output = self.o_proj(attn_output) + return attn_output, None, past_key_value + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + """Rewrite of MistralAttention.forward.""" + world_size = 1 + if dist.is_initialized(): + world_size = dist.get_world_size() + return self._contiguous_batching_forward_impl( + hidden_states, + position_ids, + past_key_value, + output_attentions, + attention_mask=attention_mask, + world_size=world_size, + ) + + class PatchedMixtralBLockSparseTop2MLP(nn.Module): def _load_weights(self, loader, rank: int, world_size: int, @@ -255,6 +408,50 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return out_states, router_logits +class PatchedMixtralSparseMoeBlockAscend(nn.Module): + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """rewrite moe forward.""" + + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + router_logits = self.gate(hidden_states) + + routing_weights, selected_experts = moe_gating_topk_softmax( + router_logits, self.top_k) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), + dtype=hidden_states.dtype, + device=hidden_states.device) + + expert_mask = torch.nn.functional.one_hot( + selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.shape[0] == 0: + continue + + top_x_list = top_x.tolist() + idx_list = idx.tolist() + + current_state = hidden_states[None, + top_x_list].reshape(-1, hidden_dim) + current_hidden_states = expert_layer( + current_state) * routing_weights[top_x_list, idx_list, None] + + final_hidden_states.index_add_( + 0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape( + batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits + + class 
PatchedMixtralModel(nn.Module): def _continuous_batching_forward( diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index 9d08e48219..e0f49715b6 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -390,3 +390,37 @@ 'modeling_internlm2.InternLM2FlashAttention2': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm2.PatchedInternLM2AttentionAscend', }) + +# ascend mixtral +ASCEND_MODULE_MAP.update({ + 'transformers.models.mixtral.modeling_mixtral.MixtralAttention': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralAttentionAscend', + 'transformers.models.mixtral.modeling_mixtral.MixtralFlashAttention2': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralAttentionAscend', + 'transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralAttentionAscend', + 'transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralSparseMoeBlockAscend', # noqa: E501 +}) + +# ascend qwen1.5 +ASCEND_MODULE_MAP.update({ + 'transformers.models.qwen2.modeling_qwen2.Qwen2Attention': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2.PatchedQwen2AttentionAscend', + 'transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2.PatchedQwen2AttentionAscend', + 'transformers.models.qwen2.modeling_qwen2.Qwen2SdpaAttention': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2.PatchedQwen2AttentionAscend', +}) + +# ascend qwen2 moe +ASCEND_MODULE_MAP.update({ + 'transformers.models.qwen2_moe.modeling_qwen2_moe.Qwen2MoeAttention': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2.PatchedQwen2AttentionAscend', + 'transformers.models.qwen2_moe.modeling_qwen2_moe.Qwen2MoeFlashAttention2': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2.PatchedQwen2AttentionAscend', + 'transformers.models.qwen2_moe.modeling_qwen2_moe.Qwen2MoeSdpaAttention': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2.PatchedQwen2AttentionAscend', + 'transformers.models.qwen2_moe.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2_moe.PatchedQwen2MoeSparseMoeBlockAscend', # noqa: E501 +}) diff --git a/lmdeploy/pytorch/models/qwen2.py b/lmdeploy/pytorch/models/qwen2.py index 02185c703d..048bc370e1 100644 --- a/lmdeploy/pytorch/models/qwen2.py +++ b/lmdeploy/pytorch/models/qwen2.py @@ -146,3 +146,150 @@ def forward( past_key_value, world_size=world_size, ) + + +class PatchedQwen2AttentionAscend(nn.Module): + + def _load_weights(self, loader, rank: int, world_size: int, + device: torch.device): + """load weights.""" + for mod_name in ['q_proj', 'k_proj', 'v_proj']: + colwise_parallelize_linear(getattr(self, mod_name), + loader, + rank=rank, + world_size=world_size, + prefix=mod_name) + for mod_name in ['o_proj']: + rowwise_parallelize_linear(getattr(self, mod_name), + loader, + rank=rank, + world_size=world_size, + prefix=mod_name) + + @classmethod + def _distribute_output_fn(cls, outputs, **kwargs): + """Distribution output hook.""" + dist.all_reduce(outputs[0]) + return outputs + + def _contiguous_batching_forward_impl( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + world_size: int = 1, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + """Rewrite implementation of forward. + + Add continuous batching support. Add paged attention support. TP + support. 
+ """ + context = self.context.context + kv_seq_length = context.kv_seq_length + q_seq_length = context.q_seq_length + q_start_loc = context.q_start_loc + block_offsets = context.block_offsets + max_q_seq_length = context.max_q_seq_length + max_kv_seq_length = context.max_kv_seq_length + + num_heads = self.num_heads // world_size + num_kv_heads = self.num_key_value_heads // world_size + head_dim = self.head_dim + hidden_size = num_heads * head_dim + + def __qkv_proj(hidden_states): + """qkv proj.""" + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + return query_states, key_states, value_states + + def __rotary_emb_fn(query_states, key_states, value_states): + if hasattr(self, 'rotary_emb'): + cos, sin = self.rotary_emb(value_states, + seq_len=max_kv_seq_length) + query_states, key_states = apply_rotary_pos_emb( + query_states, + key_states, + cos, + sin, + position_ids, + context.position_ids_1d, + context=context) + return query_states, key_states, value_states + + query_states, key_states, value_states = __qkv_proj(hidden_states) + + query_states = query_states.view(-1, num_heads, head_dim) + key_states = key_states.view(-1, num_kv_heads, head_dim) + value_states = value_states.view(-1, num_kv_heads, head_dim) + + query_states, key_states, value_states = __rotary_emb_fn( + query_states, key_states, value_states) + + fill_kv_cache( + key_states, + value_states, + past_key_value[0], + past_key_value[1], + q_start_loc, + q_seq_length, + kv_seq_length=kv_seq_length, + max_q_seq_length=max_q_seq_length, + block_offsets=block_offsets, + context=context, + ) + + attn_output = query_states + + use_sliding_windows = (getattr(self.config, 'sliding_window', None) + is not None and self.config.use_sliding_window) + window_size = self.config.sliding_window + if not use_sliding_windows: + window_size = -1 + paged_attention_fwd( + query_states, + key_states, + value_states, + past_key_value[0], + past_key_value[1], + attn_output, + block_offsets, + q_start_loc=q_start_loc, + q_seqlens=q_seq_length, + kv_seqlens=kv_seq_length, + max_seqlen=max_q_seq_length, + window_size=window_size, + context=context, + ) + + attn_output = attn_output.reshape(*hidden_states.shape[:-1], + hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + """Rewrite of forward.""" + world_size = 1 + if dist.is_initialized(): + world_size = dist.get_world_size() + return self._contiguous_batching_forward_impl( + hidden_states, + position_ids, + past_key_value, + world_size=world_size, + ) diff --git a/lmdeploy/pytorch/models/qwen2_moe.py b/lmdeploy/pytorch/models/qwen2_moe.py index 5db3b68d55..ec7133f112 100644 --- a/lmdeploy/pytorch/models/qwen2_moe.py +++ b/lmdeploy/pytorch/models/qwen2_moe.py @@ -7,6 +7,8 @@ from torch import nn from lmdeploy.pytorch.kernels.fused_moe import fused_moe +from lmdeploy.pytorch.kernels.moe_gating_topk_softmax import \ + moe_gating_topk_softmax class PatchedQwen2MoeSparseMoeBlock(nn.Module): @@ -90,6 +92,49 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return out_states, router_logits 
+class PatchedQwen2MoeSparseMoeBlockAscend(nn.Module):
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """rewrite moe forward."""
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        router_logits = self.gate(hidden_states)
+
+        routing_weights, selected_experts = moe_gating_topk_softmax(
+            router_logits, self.top_k)
+        if self.norm_topk_prob:
+            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        routing_weights = routing_weights.to(hidden_states.dtype)
+
+        out_states = torch.zeros(
+            (batch_size * sequence_length, hidden_dim),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device)
+
+        expert_mask = torch.nn.functional.one_hot(
+            selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+
+        for expert_idx in range(self.num_experts):
+            expert_layer = self.experts[expert_idx]
+            idx, top_x = torch.where(expert_mask[expert_idx])
+
+            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+            current_hidden_states = expert_layer(
+                current_state) * routing_weights[top_x, idx, None]
+
+            out_states.index_add_(
+                0, top_x, current_hidden_states.to(hidden_states.dtype))
+
+        shared_expert_output = self.shared_expert(hidden_states)
+        shared_expert_output = F.sigmoid(
+            self.shared_expert_gate(hidden_states)) * shared_expert_output
+
+        out_states = out_states + shared_expert_output
+        out_states = out_states.unflatten(0, (-1, sequence_length))
+
+        return out_states, router_logits
+
+
 class PatchedQwen2MoeModel(nn.Module):
 
     def _continuous_batching_forward(
diff --git a/requirements/runtime.txt b/requirements/runtime.txt
index c6a1e74444..3c0e46f19a 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime.txt
@@ -17,5 +17,5 @@ tiktoken
 torch<=2.3.1,>=2.0.0
 torchvision<=0.18.1,>=0.15.0
 transformers
-triton>=2.1.0,<=2.3.1; sys_platform == "linux"
+triton>=2.1.0,<=2.2.0; sys_platform == "linux"
 uvicorn
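
For reviewers without Ascend hardware, here is a torch-only sketch (not part of the patch) of the routing semantics the new `moe_gating_topk_softmax` wrapper is assumed to expose: softmax over the router logits followed by top-k selection, returning float32 weights and int64 expert indices, with any top-k re-normalization left to the caller as in the patched MoE blocks above. The name `moe_gating_topk_softmax_ref` and the test values are illustrative only; this is a CPU reference for sanity-checking the fused `ext_ops` kernel, not its implementation.

import torch


def moe_gating_topk_softmax_ref(router_logits: torch.Tensor, topk: int):
    """Eager reference: per-token softmax over experts, then top-k."""
    scores = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    routing_weights, selected_experts = torch.topk(scores, topk, dim=-1)
    return routing_weights, selected_experts.to(torch.int64)


if __name__ == '__main__':
    logits = torch.randn(4, 8)  # 4 tokens routed over 8 experts
    weights, experts = moe_gating_topk_softmax_ref(logits, topk=2)
    print(weights.shape, experts.shape)  # torch.Size([4, 2]) torch.Size([4, 2])

Similarly, assuming the intent of the engine/devices/ascend.py change is to pad the causal mask's key dimension up to a multiple of 32, the rounding expression can be checked in isolation (the helper name below is hypothetical):

def round_up_to_multiple_of_32(kv_seq_len: int) -> int:
    # (x + 31) & ~31 clears the low 5 bits after adding 31,
    # i.e. rounds up to the next multiple of 32: 1 -> 32, 32 -> 32, 33 -> 64.
    return (kv_seq_len + 31) & (~31)


assert [round_up_to_multiple_of_32(n)
        for n in (1, 31, 32, 33, 100)] == [32, 32, 32, 64, 128]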