From af4e89cbd8528c61773fafd0a9e36d0fba3bd98d Mon Sep 17 00:00:00 2001 From: chenchiyu Date: Wed, 17 Jul 2024 09:04:18 +0800 Subject: [PATCH 1/8] support ascend using infer_ext --- lmdeploy/pytorch/engine/devices/ascend.py | 2 +- lmdeploy/pytorch/kernels/ascend/__init__.py | 222 +++++++++++++++++- tests/pytorch/kernel/test_apply_rotary.py | 7 +- tests/pytorch/kernel/test_fused_rotary_emb.py | 5 +- tests/pytorch/kernel/test_paged_attention.py | 1 + tests/pytorch/kernel/test_rms_norm.py | 3 + 6 files changed, 234 insertions(+), 6 deletions(-) diff --git a/lmdeploy/pytorch/engine/devices/ascend.py b/lmdeploy/pytorch/engine/devices/ascend.py index a09fa5f655..7d051dfb19 100644 --- a/lmdeploy/pytorch/engine/devices/ascend.py +++ b/lmdeploy/pytorch/engine/devices/ascend.py @@ -28,7 +28,7 @@ def update_step_context(cls, step_context): block_loc = step_context.block_offsets[i][block_idx] token_loc = history_length % block_size for _ in range(step_context.q_seq_length[i]): - kv_start_indices.append(block_loc * block_size + token_loc) + kv_start_indices.append([block_loc * block_size + token_loc]) if _ == step_context.q_seq_length[i] - 1: break token_loc = (token_loc + 1) % block_size diff --git a/lmdeploy/pytorch/kernels/ascend/__init__.py b/lmdeploy/pytorch/kernels/ascend/__init__.py index bd207a1ecb..84fe0020a5 100644 --- a/lmdeploy/pytorch/kernels/ascend/__init__.py +++ b/lmdeploy/pytorch/kernels/ascend/__init__.py @@ -1,6 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. -from ..dipu import (apply_rotary_pos_emb, fill_kv_cache, fused_rotary_emb, - multinomial_sampling, paged_attention_fwd, rms_norm) + +import torch +import infer_ext.ops as ext_ops +from torch import Tensor +from ..default import multinomial_sampling __all__ = [ 'rms_norm', @@ -10,3 +13,218 @@ 'paged_attention_fwd', 'multinomial_sampling', ] + +def rms_norm( + hidden_states: Tensor, + weight: Tensor, + epsilon: float = 1e-6 +): + return ext_ops.rms_norm(hidden_states, weight, epsilon) + +def apply_rotary_pos_emb( + query_states: Tensor, + key_states: Tensor, + cos: Tensor, + sin: Tensor, + position_ids: Tensor, + position_ids_1d: Tensor, + q_embed=None, + k_embed=None, + context=None, +): + bs, head, dim = query_states.shape + num_kv_heads = key_states.shape[1] + query_states_reshaped = query_states.reshape(1, bs, head, dim) + key_states_reshaped = key_states.reshape(1, bs, num_kv_heads, dim) + if not (hasattr(context, 'cos') or hasattr(context, 'sin')): + cos = cos[position_ids_1d].view(1, bs, 1, -1) + sin = sin[position_ids_1d].view(1, bs, 1, -1) + if context: + setattr(context, 'cos', cos) + setattr(context, 'sin', sin) + cached_cos = context.cos if context else cos + cached_sin = context.sin if context else sin + ext_ops.apply_rotary_pos_emb( + query_states_reshaped, key_states_reshaped, cached_cos, cached_sin, + None, None, None + ) + if q_embed is None: + q_embed = query_states + else: + q_embed.copy_(query_states) + if k_embed is None: + k_embed = key_states + else: + k_embed.copy_(key_states) + return q_embed, k_embed + +def fused_rotary_emb( + query_states: Tensor, + key_states: Tensor, + position_ids: torch.LongTensor, + inv_freq: Tensor, + scaling_factor: float, + out_q: Tensor = None, + out_k: Tensor = None, + context=None, +): + batch, seqlen, head, dim = query_states.shape + num_kv_heads = key_states.shape[-2] + query_states_reshaped = query_states.view(batch, seqlen, head, dim) + key_states_reshaped = key_states.view(batch, seqlen, num_kv_heads, dim) + position_ids = position_ids.squeeze(0).unsqueeze(-1) + 
pos_freq = position_ids / scaling_factor * inv_freq + if not (hasattr(context, 'cos') or hasattr(context, 'sin')): + cos = (torch.cos(pos_freq).view(batch, seqlen, 1, -1) + .repeat(1, 1, 1, 2).to(query_states.dtype)) + sin = (torch.sin(pos_freq).view(batch, seqlen, 1, -1) + .repeat(1, 1, 1, 2).to(query_states.dtype)) + if context: + setattr(context, 'cos', cos) + setattr(context, 'sin', sin) + cached_cos = context.cos if context else cos + cached_sin = context.sin if context else sin + ext_ops.apply_rotary_pos_emb(query_states_reshaped, key_states_reshaped, + cached_cos, cached_sin, None, None, None) + if out_q is None: + out_q = query_states + else: + out_q.copy_(query_states) + if out_k is None: + out_k = key_states + else: + out_k.copy_(key_states) + return out_q, out_k + +def fill_kv_cache( + key_states: Tensor, + value_states: Tensor, + key_caches: Tensor, + value_caches: Tensor, + q_start_loc: Tensor, + q_seq_length: Tensor, + kv_seq_length: Tensor, + max_q_seq_length: int, + block_offsets: Tensor, + context: None, +): + """fill key/value state to cache for paged attention.""" + ext_ops.fill_kv_cache(key_states, value_states, key_caches, + value_caches, context.kv_start_indices) + +def flash_context_attention( + query_states: Tensor, + key_states: Tensor, + value_states: Tensor, + attn_output: Tensor, + key_cache: Tensor, + value_cache: Tensor, + block_offsets: Tensor, + q_start_loc: Tensor, + q_seq_len: Tensor, + kv_seq_len: Tensor, + block_size: int, + kv_cache_len: int, + context=None, +): + num_q_heads, dim = query_states.shape[1:3] + num_kv_heads = value_states.shape[1] + batch = q_start_loc.shape[0] + + for i in range(batch): + if torch.equal(q_seq_len[i], kv_seq_len[i]): + ext_ops.context_attention( + attn_output, + query_states, + key_states, + value_states, + q_start_loc[i:i+1], + q_seq_len[i:i+1], + num_q_heads, + num_kv_heads, + context.attention_mask[i:i+1], + ) + else: + key_cache = key_cache.reshape(1, kv_cache_len, num_kv_heads * dim) + value_cache = value_cache.reshape(1, kv_cache_len, num_kv_heads * dim) + ext_ops.paged_prefill_attention( + attn_output, + query_states, + key_cache, + value_cache, + block_offsets, + block_size, + q_start_loc[i:i+1], + q_seq_len[i:i+1], + kv_seq_len[i:i+1], + num_q_heads, + num_kv_heads, + context.attention_mask[i:i+1], + ) + +def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len, + block_offsets, block_size): + num_kv_heads = k_cache.shape[1] + bs, num_q_heads, dim = q.shape + kv_cache_len = k_cache.shape[0] + q = q.reshape(bs, 1, num_q_heads * dim) + k_cache = k_cache.reshape(1, kv_cache_len, num_kv_heads * dim) + v_cache = v_cache.reshape(1, kv_cache_len, num_kv_heads * dim) + ext_ops.paged_decode_attention( + attn_output.view(q.shape), + q, + k_cache, + v_cache, + block_offsets, + block_size, + kv_seq_len, + num_q_heads, + num_kv_heads, + ) + +def paged_attention_fwd( + query_states: Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + key_cache: Tensor, + value_cache: Tensor, + attn_output: Tensor, + block_offsets: Tensor, + q_start_loc: Tensor, + q_seqlens: Tensor, + kv_seqlens: Tensor, + max_seqlen: int, + window_size: int = 1, + context=None, +): + is_decoding = query_states.shape[-3] == q_seqlens.size(0) + block_num, block_size, head, dim = key_cache.size() + kv_cache_len = block_num * block_size + k = key_cache.reshape(block_num * block_size, head, dim) + v = value_cache.reshape(block_num * block_size, head, dim) + if not is_decoding: + flash_context_attention( + query_states, + key_states, + 
value_states, + attn_output, + k, + v, + block_offsets, + q_start_loc, + q_seqlens, + kv_seqlens, + block_size, + kv_cache_len, + context=context, + ) + else: + paged_token_attention( + query_states, + k, + v, + attn_output, + kv_seqlens, + block_offsets, + block_size, + ) diff --git a/tests/pytorch/kernel/test_apply_rotary.py b/tests/pytorch/kernel/test_apply_rotary.py index e13c71d4ec..bc54a3cf3d 100644 --- a/tests/pytorch/kernel/test_apply_rotary.py +++ b/tests/pytorch/kernel/test_apply_rotary.py @@ -1,4 +1,7 @@ import pytest +import infer_ext +from lmdeploy.pytorch.devices.device_manager import DeviceContext, get_device_manager +get_device_manager().set_context(DeviceContext(device_type='ascend')) import torch from lmdeploy.pytorch.kernels import apply_rotary_pos_emb @@ -31,7 +34,7 @@ def num_heads_k(self, request): @pytest.fixture def feature_dim(self): - yield 16 + yield 128 @pytest.fixture def seq_length(self, batch_size): @@ -82,7 +85,7 @@ def gt(self, q_states, k_states, cached_cos, cached_sin, position_ids_1d): yield q_embed, k_embed @pytest.mark.parametrize('dtype', - [torch.bfloat16, torch.float16, torch.float32], + [torch.bfloat16, torch.float16, torch.float32][1:], indirect=True) @pytest.mark.parametrize(('num_heads_q', 'num_heads_k'), [(8, 8), (8, 4)], indirect=True) diff --git a/tests/pytorch/kernel/test_fused_rotary_emb.py b/tests/pytorch/kernel/test_fused_rotary_emb.py index c0504f51ca..43795d8898 100644 --- a/tests/pytorch/kernel/test_fused_rotary_emb.py +++ b/tests/pytorch/kernel/test_fused_rotary_emb.py @@ -1,5 +1,8 @@ import pytest import torch +import infer_ext +from lmdeploy.pytorch.devices.device_manager import DeviceContext, get_device_manager +get_device_manager().set_context(DeviceContext(device_type='ascend')) from torch import nn from lmdeploy.pytorch.kernels.fused_rotary_emb import fused_rotary_emb @@ -77,7 +80,7 @@ def batch_size(self): @pytest.fixture def head_dim(self): - yield 64 + yield 128 @pytest.fixture def q_num_heads(self): diff --git a/tests/pytorch/kernel/test_paged_attention.py b/tests/pytorch/kernel/test_paged_attention.py index 90dc153aeb..b3552b99b7 100644 --- a/tests/pytorch/kernel/test_paged_attention.py +++ b/tests/pytorch/kernel/test_paged_attention.py @@ -271,6 +271,7 @@ def window_gt(self, conti_q, conti_kv, seq_lens, history_lens, win_size): kv_lens, window_size=(win_size, win_size)) + @pytest.mark.skip() @pytest.mark.parametrize('feat_dim', [16], indirect=True) @pytest.mark.parametrize('feat_dim_v', [16], indirect=True) @pytest.mark.parametrize(['num_heads_q', 'num_heads_k'], [(4, 2)], diff --git a/tests/pytorch/kernel/test_rms_norm.py b/tests/pytorch/kernel/test_rms_norm.py index 0511ac5f43..21e2706fe9 100644 --- a/tests/pytorch/kernel/test_rms_norm.py +++ b/tests/pytorch/kernel/test_rms_norm.py @@ -1,4 +1,7 @@ import pytest +import infer_ext +from lmdeploy.pytorch.devices.device_manager import DeviceContext, get_device_manager +get_device_manager().set_context(DeviceContext(device_type='ascend')) import torch From d0e901a25fba6e77545ff8b9055b3ac3bed5d02e Mon Sep 17 00:00:00 2001 From: chenchiyu Date: Thu, 18 Jul 2024 07:59:18 +0800 Subject: [PATCH 2/8] fix(ascend): make infer_ext using TND format q,k,v in paged_token_attention --- lmdeploy/pytorch/kernels/ascend/__init__.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/lmdeploy/pytorch/kernels/ascend/__init__.py b/lmdeploy/pytorch/kernels/ascend/__init__.py index 84fe0020a5..c50ba16f1a 100644 --- a/lmdeploy/pytorch/kernels/ascend/__init__.py +++ 
b/lmdeploy/pytorch/kernels/ascend/__init__.py @@ -164,12 +164,7 @@ def flash_context_attention( def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len, block_offsets, block_size): - num_kv_heads = k_cache.shape[1] - bs, num_q_heads, dim = q.shape - kv_cache_len = k_cache.shape[0] - q = q.reshape(bs, 1, num_q_heads * dim) - k_cache = k_cache.reshape(1, kv_cache_len, num_kv_heads * dim) - v_cache = v_cache.reshape(1, kv_cache_len, num_kv_heads * dim) + num_kv_heads, num_q_heads = k_cache.shape[1], q.shape[1] ext_ops.paged_decode_attention( attn_output.view(q.shape), q, From 90ed8fb4d61dc417eeb97ceee8241ca5dcd8f48b Mon Sep 17 00:00:00 2001 From: yaofengchen Date: Tue, 23 Jul 2024 10:57:32 +0000 Subject: [PATCH 3/8] support ascend using infer_ext --- lmdeploy/pytorch/kernels/ascend/__init__.py | 219 +----------------- .../kernels/ascend/apply_rotary_pos_emb.py | 41 ++++ .../pytorch/kernels/ascend/fill_kv_cache.py | 20 ++ .../kernels/ascend/fused_rotary_emb.py | 43 ++++ .../kernels/ascend/paged_attention_fwd.py | 117 ++++++++++ lmdeploy/pytorch/kernels/ascend/rms_norm.py | 7 + lmdeploy/pytorch/models/internlm2.py | 2 +- requirements/runtime.txt | 4 + tests/pytorch/kernel/test_apply_rotary.py | 7 +- tests/pytorch/kernel/test_fused_rotary_emb.py | 5 +- tests/pytorch/kernel/test_paged_attention.py | 1 - tests/pytorch/kernel/test_rms_norm.py | 3 - 12 files changed, 241 insertions(+), 228 deletions(-) create mode 100644 lmdeploy/pytorch/kernels/ascend/apply_rotary_pos_emb.py create mode 100644 lmdeploy/pytorch/kernels/ascend/fill_kv_cache.py create mode 100644 lmdeploy/pytorch/kernels/ascend/fused_rotary_emb.py create mode 100644 lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py create mode 100644 lmdeploy/pytorch/kernels/ascend/rms_norm.py diff --git a/lmdeploy/pytorch/kernels/ascend/__init__.py b/lmdeploy/pytorch/kernels/ascend/__init__.py index c50ba16f1a..24aeaac0a3 100644 --- a/lmdeploy/pytorch/kernels/ascend/__init__.py +++ b/lmdeploy/pytorch/kernels/ascend/__init__.py @@ -1,9 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. 
- -import torch -import infer_ext.ops as ext_ops -from torch import Tensor from ..default import multinomial_sampling +from .apply_rotary_pos_emb import apply_rotary_pos_emb +from .fill_kv_cache import fill_kv_cache +from .fused_rotary_emb import fused_rotary_emb +from .paged_attention_fwd import paged_attention_fwd +from .rms_norm import rms_norm __all__ = [ 'rms_norm', @@ -13,213 +14,3 @@ 'paged_attention_fwd', 'multinomial_sampling', ] - -def rms_norm( - hidden_states: Tensor, - weight: Tensor, - epsilon: float = 1e-6 -): - return ext_ops.rms_norm(hidden_states, weight, epsilon) - -def apply_rotary_pos_emb( - query_states: Tensor, - key_states: Tensor, - cos: Tensor, - sin: Tensor, - position_ids: Tensor, - position_ids_1d: Tensor, - q_embed=None, - k_embed=None, - context=None, -): - bs, head, dim = query_states.shape - num_kv_heads = key_states.shape[1] - query_states_reshaped = query_states.reshape(1, bs, head, dim) - key_states_reshaped = key_states.reshape(1, bs, num_kv_heads, dim) - if not (hasattr(context, 'cos') or hasattr(context, 'sin')): - cos = cos[position_ids_1d].view(1, bs, 1, -1) - sin = sin[position_ids_1d].view(1, bs, 1, -1) - if context: - setattr(context, 'cos', cos) - setattr(context, 'sin', sin) - cached_cos = context.cos if context else cos - cached_sin = context.sin if context else sin - ext_ops.apply_rotary_pos_emb( - query_states_reshaped, key_states_reshaped, cached_cos, cached_sin, - None, None, None - ) - if q_embed is None: - q_embed = query_states - else: - q_embed.copy_(query_states) - if k_embed is None: - k_embed = key_states - else: - k_embed.copy_(key_states) - return q_embed, k_embed - -def fused_rotary_emb( - query_states: Tensor, - key_states: Tensor, - position_ids: torch.LongTensor, - inv_freq: Tensor, - scaling_factor: float, - out_q: Tensor = None, - out_k: Tensor = None, - context=None, -): - batch, seqlen, head, dim = query_states.shape - num_kv_heads = key_states.shape[-2] - query_states_reshaped = query_states.view(batch, seqlen, head, dim) - key_states_reshaped = key_states.view(batch, seqlen, num_kv_heads, dim) - position_ids = position_ids.squeeze(0).unsqueeze(-1) - pos_freq = position_ids / scaling_factor * inv_freq - if not (hasattr(context, 'cos') or hasattr(context, 'sin')): - cos = (torch.cos(pos_freq).view(batch, seqlen, 1, -1) - .repeat(1, 1, 1, 2).to(query_states.dtype)) - sin = (torch.sin(pos_freq).view(batch, seqlen, 1, -1) - .repeat(1, 1, 1, 2).to(query_states.dtype)) - if context: - setattr(context, 'cos', cos) - setattr(context, 'sin', sin) - cached_cos = context.cos if context else cos - cached_sin = context.sin if context else sin - ext_ops.apply_rotary_pos_emb(query_states_reshaped, key_states_reshaped, - cached_cos, cached_sin, None, None, None) - if out_q is None: - out_q = query_states - else: - out_q.copy_(query_states) - if out_k is None: - out_k = key_states - else: - out_k.copy_(key_states) - return out_q, out_k - -def fill_kv_cache( - key_states: Tensor, - value_states: Tensor, - key_caches: Tensor, - value_caches: Tensor, - q_start_loc: Tensor, - q_seq_length: Tensor, - kv_seq_length: Tensor, - max_q_seq_length: int, - block_offsets: Tensor, - context: None, -): - """fill key/value state to cache for paged attention.""" - ext_ops.fill_kv_cache(key_states, value_states, key_caches, - value_caches, context.kv_start_indices) - -def flash_context_attention( - query_states: Tensor, - key_states: Tensor, - value_states: Tensor, - attn_output: Tensor, - key_cache: Tensor, - value_cache: Tensor, - block_offsets: Tensor, 
- q_start_loc: Tensor, - q_seq_len: Tensor, - kv_seq_len: Tensor, - block_size: int, - kv_cache_len: int, - context=None, -): - num_q_heads, dim = query_states.shape[1:3] - num_kv_heads = value_states.shape[1] - batch = q_start_loc.shape[0] - - for i in range(batch): - if torch.equal(q_seq_len[i], kv_seq_len[i]): - ext_ops.context_attention( - attn_output, - query_states, - key_states, - value_states, - q_start_loc[i:i+1], - q_seq_len[i:i+1], - num_q_heads, - num_kv_heads, - context.attention_mask[i:i+1], - ) - else: - key_cache = key_cache.reshape(1, kv_cache_len, num_kv_heads * dim) - value_cache = value_cache.reshape(1, kv_cache_len, num_kv_heads * dim) - ext_ops.paged_prefill_attention( - attn_output, - query_states, - key_cache, - value_cache, - block_offsets, - block_size, - q_start_loc[i:i+1], - q_seq_len[i:i+1], - kv_seq_len[i:i+1], - num_q_heads, - num_kv_heads, - context.attention_mask[i:i+1], - ) - -def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len, - block_offsets, block_size): - num_kv_heads, num_q_heads = k_cache.shape[1], q.shape[1] - ext_ops.paged_decode_attention( - attn_output.view(q.shape), - q, - k_cache, - v_cache, - block_offsets, - block_size, - kv_seq_len, - num_q_heads, - num_kv_heads, - ) - -def paged_attention_fwd( - query_states: Tensor, - key_states: torch.Tensor, - value_states: torch.Tensor, - key_cache: Tensor, - value_cache: Tensor, - attn_output: Tensor, - block_offsets: Tensor, - q_start_loc: Tensor, - q_seqlens: Tensor, - kv_seqlens: Tensor, - max_seqlen: int, - window_size: int = 1, - context=None, -): - is_decoding = query_states.shape[-3] == q_seqlens.size(0) - block_num, block_size, head, dim = key_cache.size() - kv_cache_len = block_num * block_size - k = key_cache.reshape(block_num * block_size, head, dim) - v = value_cache.reshape(block_num * block_size, head, dim) - if not is_decoding: - flash_context_attention( - query_states, - key_states, - value_states, - attn_output, - k, - v, - block_offsets, - q_start_loc, - q_seqlens, - kv_seqlens, - block_size, - kv_cache_len, - context=context, - ) - else: - paged_token_attention( - query_states, - k, - v, - attn_output, - kv_seqlens, - block_offsets, - block_size, - ) diff --git a/lmdeploy/pytorch/kernels/ascend/apply_rotary_pos_emb.py b/lmdeploy/pytorch/kernels/ascend/apply_rotary_pos_emb.py new file mode 100644 index 0000000000..5d33021e6e --- /dev/null +++ b/lmdeploy/pytorch/kernels/ascend/apply_rotary_pos_emb.py @@ -0,0 +1,41 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import infer_ext.ops as ext_ops +from torch import Tensor + + +def apply_rotary_pos_emb( + query_states: Tensor, + key_states: Tensor, + cos: Tensor, + sin: Tensor, + position_ids: Tensor, + position_ids_1d: Tensor, + q_embed=None, + k_embed=None, + context=None, +): + bs, head, dim = query_states.shape + num_kv_heads = key_states.shape[1] + query_states_reshaped = query_states.reshape(1, bs, head, dim) + key_states_reshaped = key_states.reshape(1, bs, num_kv_heads, dim) + if not (hasattr(context, 'cos') or hasattr(context, 'sin')): + cos = cos[position_ids_1d].view(1, bs, 1, -1) + sin = sin[position_ids_1d].view(1, bs, 1, -1) + if context: + setattr(context, 'cos', cos) + setattr(context, 'sin', sin) + cached_cos = context.cos if context else cos + cached_sin = context.sin if context else sin + ext_ops.apply_rotary_pos_emb( + query_states_reshaped, key_states_reshaped, cached_cos, cached_sin, + None, None, None + ) + if q_embed is None: + q_embed = query_states + else: + q_embed.copy_(query_states) + if k_embed is None: + k_embed = key_states + else: + k_embed.copy_(key_states) + return q_embed, k_embed diff --git a/lmdeploy/pytorch/kernels/ascend/fill_kv_cache.py b/lmdeploy/pytorch/kernels/ascend/fill_kv_cache.py new file mode 100644 index 0000000000..26a190d50d --- /dev/null +++ b/lmdeploy/pytorch/kernels/ascend/fill_kv_cache.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import infer_ext.ops as ext_ops +from torch import Tensor + + +def fill_kv_cache( + key_states: Tensor, + value_states: Tensor, + key_caches: Tensor, + value_caches: Tensor, + q_start_loc: Tensor, + q_seq_length: Tensor, + kv_seq_length: Tensor, + max_q_seq_length: int, + block_offsets: Tensor, + context: None, +): + """fill key/value state to cache for paged attention.""" + ext_ops.fill_kv_cache(key_states, value_states, key_caches, + value_caches, context.kv_start_indices) diff --git a/lmdeploy/pytorch/kernels/ascend/fused_rotary_emb.py b/lmdeploy/pytorch/kernels/ascend/fused_rotary_emb.py new file mode 100644 index 0000000000..7dd85c49d3 --- /dev/null +++ b/lmdeploy/pytorch/kernels/ascend/fused_rotary_emb.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import infer_ext.ops as ext_ops +import torch +from torch import Tensor + + +def fused_rotary_emb( + query_states: Tensor, + key_states: Tensor, + position_ids: torch.LongTensor, + inv_freq: Tensor, + scaling_factor: float, + out_q: Tensor = None, + out_k: Tensor = None, + context=None, +): + batch, seqlen, head, dim = query_states.shape + num_kv_heads = key_states.shape[-2] + query_states_reshaped = query_states.view(batch, seqlen, head, dim) + key_states_reshaped = key_states.view(batch, seqlen, num_kv_heads, dim) + position_ids = position_ids.squeeze(0).unsqueeze(-1) + pos_freq = position_ids / scaling_factor * inv_freq + if not (hasattr(context, 'cos') or hasattr(context, 'sin')): + cos = (torch.cos(pos_freq).view(batch, seqlen, 1, -1) + .repeat(1, 1, 1, 2).to(query_states.dtype)) + sin = (torch.sin(pos_freq).view(batch, seqlen, 1, -1) + .repeat(1, 1, 1, 2).to(query_states.dtype)) + if context: + setattr(context, 'cos', cos) + setattr(context, 'sin', sin) + cached_cos = context.cos if context else cos + cached_sin = context.sin if context else sin + ext_ops.apply_rotary_pos_emb(query_states_reshaped, key_states_reshaped, + cached_cos, cached_sin, None, None, None) + if out_q is None: + out_q = query_states + else: + out_q.copy_(query_states) + if out_k is None: + out_k = key_states + else: + out_k.copy_(key_states) + return out_q, out_k diff --git a/lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py b/lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py new file mode 100644 index 0000000000..563ff7c324 --- /dev/null +++ b/lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py @@ -0,0 +1,117 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import infer_ext.ops as ext_ops +import torch +from torch import Tensor + + +def flash_context_attention( + query_states: Tensor, + key_states: Tensor, + value_states: Tensor, + attn_output: Tensor, + key_cache: Tensor, + value_cache: Tensor, + block_offsets: Tensor, + q_start_loc: Tensor, + q_seq_len: Tensor, + kv_seq_len: Tensor, + block_size: int, + kv_cache_len: int, + context=None, +): + num_q_heads, dim = query_states.shape[1:3] + num_kv_heads = value_states.shape[1] + batch = q_start_loc.shape[0] + + for i in range(batch): + if torch.equal(q_seq_len[i], kv_seq_len[i]): + ext_ops.context_attention( + attn_output, + query_states, + key_states, + value_states, + q_start_loc[i:i+1], + q_seq_len[i:i+1], + num_q_heads, + num_kv_heads, + context.attention_mask[i:i+1], + ) + else: + key_cache = key_cache.reshape(1, kv_cache_len, num_kv_heads * dim) + value_cache = value_cache.reshape(1, kv_cache_len, num_kv_heads * dim) + ext_ops.paged_prefill_attention( + attn_output, + query_states, + key_cache, + value_cache, + block_offsets, + block_size, + q_start_loc[i:i+1], + q_seq_len[i:i+1], + kv_seq_len[i:i+1], + num_q_heads, + num_kv_heads, + context.attention_mask[i:i+1], + ) + +def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len, + block_offsets, block_size): + num_kv_heads, num_q_heads = k_cache.shape[1], q.shape[1] + ext_ops.paged_decode_attention( + attn_output.view(q.shape), + q, + k_cache, + v_cache, + block_offsets, + block_size, + kv_seq_len, + num_q_heads, + num_kv_heads, + ) + +def paged_attention_fwd( + query_states: Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + key_cache: Tensor, + value_cache: Tensor, + attn_output: Tensor, + block_offsets: Tensor, + q_start_loc: Tensor, + q_seqlens: Tensor, + kv_seqlens: Tensor, + max_seqlen: int, + window_size: int = 1, + context=None, +): + is_decoding = 
query_states.shape[-3] == q_seqlens.size(0) + block_num, block_size, head, dim = key_cache.size() + kv_cache_len = block_num * block_size + k = key_cache.reshape(block_num * block_size, head, dim) + v = value_cache.reshape(block_num * block_size, head, dim) + if not is_decoding: + flash_context_attention( + query_states, + key_states, + value_states, + attn_output, + k, + v, + block_offsets, + q_start_loc, + q_seqlens, + kv_seqlens, + block_size, + kv_cache_len, + context=context, + ) + else: + paged_token_attention( + query_states, + k, + v, + attn_output, + kv_seqlens, + block_offsets, + block_size, + ) diff --git a/lmdeploy/pytorch/kernels/ascend/rms_norm.py b/lmdeploy/pytorch/kernels/ascend/rms_norm.py new file mode 100644 index 0000000000..0abda45f0b --- /dev/null +++ b/lmdeploy/pytorch/kernels/ascend/rms_norm.py @@ -0,0 +1,7 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import infer_ext.ops as ext_ops +from torch import Tensor + + +def rms_norm(hidden_states: Tensor, weight: Tensor, epsilon: float = 1e-6): + return ext_ops.rms_norm(hidden_states, weight, epsilon) diff --git a/lmdeploy/pytorch/models/internlm2.py b/lmdeploy/pytorch/models/internlm2.py index 68f4ea7f26..8bb77f84e3 100644 --- a/lmdeploy/pytorch/models/internlm2.py +++ b/lmdeploy/pytorch/models/internlm2.py @@ -233,7 +233,7 @@ def __qkv_proj(hidden_states): ) query_states = qkv_states[..., :self.num_key_value_groups, :] query_states = query_states.flatten(1, 2) - key_states = qkv_states[..., -2, :] + key_states = qkv_states[..., -2, :].contiguous() value_states = qkv_states[..., -1, :] return query_states, key_states, value_states diff --git a/requirements/runtime.txt b/requirements/runtime.txt index c6a1e74444..3c0e46f19a 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -17,5 +17,9 @@ tiktoken torch<=2.3.1,>=2.0.0 torchvision<=0.18.1,>=0.15.0 transformers +<<<<<<< HEAD triton>=2.1.0,<=2.3.1; sys_platform == "linux" +======= +triton>=2.1.0,<=2.2.0; sys_platform == "linux" +>>>>>>> support ascend using infer_ext uvicorn diff --git a/tests/pytorch/kernel/test_apply_rotary.py b/tests/pytorch/kernel/test_apply_rotary.py index bc54a3cf3d..e13c71d4ec 100644 --- a/tests/pytorch/kernel/test_apply_rotary.py +++ b/tests/pytorch/kernel/test_apply_rotary.py @@ -1,7 +1,4 @@ import pytest -import infer_ext -from lmdeploy.pytorch.devices.device_manager import DeviceContext, get_device_manager -get_device_manager().set_context(DeviceContext(device_type='ascend')) import torch from lmdeploy.pytorch.kernels import apply_rotary_pos_emb @@ -34,7 +31,7 @@ def num_heads_k(self, request): @pytest.fixture def feature_dim(self): - yield 128 + yield 16 @pytest.fixture def seq_length(self, batch_size): @@ -85,7 +82,7 @@ def gt(self, q_states, k_states, cached_cos, cached_sin, position_ids_1d): yield q_embed, k_embed @pytest.mark.parametrize('dtype', - [torch.bfloat16, torch.float16, torch.float32][1:], + [torch.bfloat16, torch.float16, torch.float32], indirect=True) @pytest.mark.parametrize(('num_heads_q', 'num_heads_k'), [(8, 8), (8, 4)], indirect=True) diff --git a/tests/pytorch/kernel/test_fused_rotary_emb.py b/tests/pytorch/kernel/test_fused_rotary_emb.py index 43795d8898..c0504f51ca 100644 --- a/tests/pytorch/kernel/test_fused_rotary_emb.py +++ b/tests/pytorch/kernel/test_fused_rotary_emb.py @@ -1,8 +1,5 @@ import pytest import torch -import infer_ext -from lmdeploy.pytorch.devices.device_manager import DeviceContext, get_device_manager -get_device_manager().set_context(DeviceContext(device_type='ascend')) from 
torch import nn from lmdeploy.pytorch.kernels.fused_rotary_emb import fused_rotary_emb @@ -80,7 +77,7 @@ def batch_size(self): @pytest.fixture def head_dim(self): - yield 128 + yield 64 @pytest.fixture def q_num_heads(self): diff --git a/tests/pytorch/kernel/test_paged_attention.py b/tests/pytorch/kernel/test_paged_attention.py index b3552b99b7..90dc153aeb 100644 --- a/tests/pytorch/kernel/test_paged_attention.py +++ b/tests/pytorch/kernel/test_paged_attention.py @@ -271,7 +271,6 @@ def window_gt(self, conti_q, conti_kv, seq_lens, history_lens, win_size): kv_lens, window_size=(win_size, win_size)) - @pytest.mark.skip() @pytest.mark.parametrize('feat_dim', [16], indirect=True) @pytest.mark.parametrize('feat_dim_v', [16], indirect=True) @pytest.mark.parametrize(['num_heads_q', 'num_heads_k'], [(4, 2)], diff --git a/tests/pytorch/kernel/test_rms_norm.py b/tests/pytorch/kernel/test_rms_norm.py index 21e2706fe9..0511ac5f43 100644 --- a/tests/pytorch/kernel/test_rms_norm.py +++ b/tests/pytorch/kernel/test_rms_norm.py @@ -1,7 +1,4 @@ import pytest -import infer_ext -from lmdeploy.pytorch.devices.device_manager import DeviceContext, get_device_manager -get_device_manager().set_context(DeviceContext(device_type='ascend')) import torch From f170fb8798b998c6783fe5269cddb2ed08354c52 Mon Sep 17 00:00:00 2001 From: yaofengchen Date: Tue, 23 Jul 2024 11:01:29 +0000 Subject: [PATCH 4/8] support ascend mixtral --- lmdeploy/pytorch/kernels/ascend/__init__.py | 2 + .../kernels/ascend/apply_rotary_pos_emb.py | 6 +- .../pytorch/kernels/ascend/fill_kv_cache.py | 4 +- .../kernels/ascend/fused_rotary_emb.py | 10 +- .../kernels/ascend/moe_gating_topk_softmax.py | 10 + .../kernels/ascend/paged_attention_fwd.py | 21 +- .../kernels/moe_gating_topk_softmax.py | 5 + lmdeploy/pytorch/models/mixtral.py | 196 ++++++++++++++++++ lmdeploy/pytorch/models/module_map.py | 12 ++ 9 files changed, 247 insertions(+), 19 deletions(-) create mode 100644 lmdeploy/pytorch/kernels/ascend/moe_gating_topk_softmax.py create mode 100644 lmdeploy/pytorch/kernels/moe_gating_topk_softmax.py diff --git a/lmdeploy/pytorch/kernels/ascend/__init__.py b/lmdeploy/pytorch/kernels/ascend/__init__.py index 24aeaac0a3..c93ac2be0e 100644 --- a/lmdeploy/pytorch/kernels/ascend/__init__.py +++ b/lmdeploy/pytorch/kernels/ascend/__init__.py @@ -3,6 +3,7 @@ from .apply_rotary_pos_emb import apply_rotary_pos_emb from .fill_kv_cache import fill_kv_cache from .fused_rotary_emb import fused_rotary_emb +from .moe_gating_topk_softmax import moe_gating_topk_softmax from .paged_attention_fwd import paged_attention_fwd from .rms_norm import rms_norm @@ -12,5 +13,6 @@ 'fused_rotary_emb', 'fill_kv_cache', 'paged_attention_fwd', + 'moe_gating_topk_softmax', 'multinomial_sampling', ] diff --git a/lmdeploy/pytorch/kernels/ascend/apply_rotary_pos_emb.py b/lmdeploy/pytorch/kernels/ascend/apply_rotary_pos_emb.py index 5d33021e6e..ec5f669feb 100644 --- a/lmdeploy/pytorch/kernels/ascend/apply_rotary_pos_emb.py +++ b/lmdeploy/pytorch/kernels/ascend/apply_rotary_pos_emb.py @@ -26,10 +26,8 @@ def apply_rotary_pos_emb( setattr(context, 'sin', sin) cached_cos = context.cos if context else cos cached_sin = context.sin if context else sin - ext_ops.apply_rotary_pos_emb( - query_states_reshaped, key_states_reshaped, cached_cos, cached_sin, - None, None, None - ) + ext_ops.apply_rotary_pos_emb(query_states_reshaped, key_states_reshaped, + cached_cos, cached_sin, None, None, None) if q_embed is None: q_embed = query_states else: diff --git 
a/lmdeploy/pytorch/kernels/ascend/fill_kv_cache.py b/lmdeploy/pytorch/kernels/ascend/fill_kv_cache.py index 26a190d50d..3e495a5081 100644 --- a/lmdeploy/pytorch/kernels/ascend/fill_kv_cache.py +++ b/lmdeploy/pytorch/kernels/ascend/fill_kv_cache.py @@ -16,5 +16,5 @@ def fill_kv_cache( context: None, ): """fill key/value state to cache for paged attention.""" - ext_ops.fill_kv_cache(key_states, value_states, key_caches, - value_caches, context.kv_start_indices) + ext_ops.fill_kv_cache(key_states, value_states, key_caches, value_caches, + context.kv_start_indices) diff --git a/lmdeploy/pytorch/kernels/ascend/fused_rotary_emb.py b/lmdeploy/pytorch/kernels/ascend/fused_rotary_emb.py index 7dd85c49d3..01346bfb58 100644 --- a/lmdeploy/pytorch/kernels/ascend/fused_rotary_emb.py +++ b/lmdeploy/pytorch/kernels/ascend/fused_rotary_emb.py @@ -21,10 +21,12 @@ def fused_rotary_emb( position_ids = position_ids.squeeze(0).unsqueeze(-1) pos_freq = position_ids / scaling_factor * inv_freq if not (hasattr(context, 'cos') or hasattr(context, 'sin')): - cos = (torch.cos(pos_freq).view(batch, seqlen, 1, -1) - .repeat(1, 1, 1, 2).to(query_states.dtype)) - sin = (torch.sin(pos_freq).view(batch, seqlen, 1, -1) - .repeat(1, 1, 1, 2).to(query_states.dtype)) + cos = (torch.cos(pos_freq).view(batch, seqlen, 1, + -1).repeat(1, 1, 1, + 2).to(query_states.dtype)) + sin = (torch.sin(pos_freq).view(batch, seqlen, 1, + -1).repeat(1, 1, 1, + 2).to(query_states.dtype)) if context: setattr(context, 'cos', cos) setattr(context, 'sin', sin) diff --git a/lmdeploy/pytorch/kernels/ascend/moe_gating_topk_softmax.py b/lmdeploy/pytorch/kernels/ascend/moe_gating_topk_softmax.py new file mode 100644 index 0000000000..4d9ec312c7 --- /dev/null +++ b/lmdeploy/pytorch/kernels/ascend/moe_gating_topk_softmax.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import infer_ext.ops as ext_ops +import torch +from torch import Tensor + + +def moe_gating_topk_softmax(router_logits: Tensor, topk: int): + routing_weights, selected_experts = ext_ops.moe_gating_topk_softmax( + router_logits, topk) + return routing_weights.to(torch.float32), selected_experts.to(torch.int64) diff --git a/lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py b/lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py index 563ff7c324..eed147b623 100644 --- a/lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py +++ b/lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py @@ -21,7 +21,7 @@ def flash_context_attention( ): num_q_heads, dim = query_states.shape[1:3] num_kv_heads = value_states.shape[1] - batch = q_start_loc.shape[0] + batch = q_start_loc.shape[0] for i in range(batch): if torch.equal(q_seq_len[i], kv_seq_len[i]): @@ -30,15 +30,16 @@ def flash_context_attention( query_states, key_states, value_states, - q_start_loc[i:i+1], - q_seq_len[i:i+1], + q_start_loc[i:i + 1], + q_seq_len[i:i + 1], num_q_heads, num_kv_heads, - context.attention_mask[i:i+1], + context.attention_mask[i:i + 1], ) else: key_cache = key_cache.reshape(1, kv_cache_len, num_kv_heads * dim) - value_cache = value_cache.reshape(1, kv_cache_len, num_kv_heads * dim) + value_cache = value_cache.reshape(1, kv_cache_len, + num_kv_heads * dim) ext_ops.paged_prefill_attention( attn_output, query_states, @@ -46,14 +47,15 @@ def flash_context_attention( value_cache, block_offsets, block_size, - q_start_loc[i:i+1], - q_seq_len[i:i+1], - kv_seq_len[i:i+1], + q_start_loc[i:i + 1], + q_seq_len[i:i + 1], + kv_seq_len[i:i + 1], num_q_heads, num_kv_heads, - context.attention_mask[i:i+1], + context.attention_mask[i:i + 1], ) + def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len, block_offsets, block_size): num_kv_heads, num_q_heads = k_cache.shape[1], q.shape[1] @@ -69,6 +71,7 @@ def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len, num_kv_heads, ) + def paged_attention_fwd( query_states: Tensor, key_states: torch.Tensor, diff --git a/lmdeploy/pytorch/kernels/moe_gating_topk_softmax.py b/lmdeploy/pytorch/kernels/moe_gating_topk_softmax.py new file mode 100644 index 0000000000..b8a55d4225 --- /dev/null +++ b/lmdeploy/pytorch/kernels/moe_gating_topk_softmax.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .dispatcher import FunctionDispatcher + +moe_gating_topk_softmax = FunctionDispatcher( + 'moe_gating_topk_softmax').make_caller() diff --git a/lmdeploy/pytorch/models/mixtral.py b/lmdeploy/pytorch/models/mixtral.py index b1be9b81df..54d9222b02 100644 --- a/lmdeploy/pytorch/models/mixtral.py +++ b/lmdeploy/pytorch/models/mixtral.py @@ -8,6 +8,7 @@ from ..kernels import apply_rotary_pos_emb, fill_kv_cache, paged_attention_fwd from ..kernels.fused_moe import fused_moe +from ..kernels.moe_gating_topk_softmax import moe_gating_topk_softmax from ..weight_loader.dist_utils import (colwise_parallelize_linear, rowwise_parallelize_linear) @@ -156,6 +157,157 @@ def forward( ) +class PatchedMixtralAttentionAscend(nn.Module): + """Rewrite module of MixtralAttention.""" + + def _load_weights(self, loader, rank: int, world_size: int, + device: torch.device): + """load weights.""" + for mod_name in ['q_proj', 'k_proj', 'v_proj']: + colwise_parallelize_linear(getattr(self, mod_name), + loader, + rank=rank, + world_size=world_size, + prefix=mod_name) + rowwise_parallelize_linear(self.o_proj, + loader, + rank=rank, + world_size=world_size, + prefix='o_proj') + + @classmethod + def _distribute_output_fn(cls, outputs, **kwargs): + """Distribution output hook.""" + try: + dist.all_reduce(outputs[0]) + except Exception as e: + print(e) + return outputs + + def _contiguous_batching_forward_impl( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + attention_mask: Optional[torch.Tensor] = None, + world_size: int = 1, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + """default rewrite.""" + + context = self.context.context + kv_seq_length = context.kv_seq_length + q_seq_length = context.q_seq_length + q_start_loc = context.q_start_loc + block_offsets = context.block_offsets + max_q_seq_length = context.max_q_seq_length + max_kv_seq_length = context.max_kv_seq_length + + num_heads = self.num_heads // world_size + num_kv_heads = self.num_key_value_heads // world_size + hidden_size = num_heads * self.head_dim + + def __qkv_proj(hidden_states): + """qkv proj.""" + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + return query_states, key_states, value_states + + def __rotary_emb_fn(query_states, key_states, value_states): + if hasattr(self, 'rotary_emb'): + if not hasattr(context, '_cos'): + cos, sin = self.rotary_emb(value_states, + seq_len=max_kv_seq_length) + context._cos = cos + context._sin = sin + else: + cos = context._cos + sin = context._sin + query_states, key_states = apply_rotary_pos_emb( + query_states, + key_states, + cos, + sin, + position_ids, + context.position_ids_1d, + q_embed=query_states, + k_embed=key_states, + context=context) + return query_states, key_states, value_states + + query_states, key_states, value_states = __qkv_proj(hidden_states) + + query_states = query_states.view(-1, num_heads, self.head_dim) + key_states = key_states.view(-1, num_kv_heads, self.head_dim) + value_states = value_states.view(-1, num_kv_heads, self.head_dim) + + query_states, key_states, value_states = __rotary_emb_fn( + query_states, key_states, value_states) + # fill kv cache + fill_kv_cache( + key_states, + value_states, + past_key_value[0], + past_key_value[1], + q_start_loc, + q_seq_length, + kv_seq_length=kv_seq_length, + max_q_seq_length=max_q_seq_length, + 
block_offsets=block_offsets, + context=context, + ) + # page attention + attn_output = query_states + window_size = self.config.sliding_window or -1 + paged_attention_fwd( + query_states, + key_states, + value_states, + past_key_value[0], + past_key_value[1], + attn_output, + block_offsets, + q_start_loc=q_start_loc, + q_seqlens=q_seq_length, + kv_seqlens=kv_seq_length, + max_seqlen=max_q_seq_length, + window_size=window_size, + context=context, + ) + + attn_output = attn_output.reshape(*hidden_states.shape[:-1], + hidden_size) + + attn_output = self.o_proj(attn_output) + return attn_output, None, past_key_value + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + """Rewrite of MistralAttention.forward.""" + world_size = 1 + if dist.is_initialized(): + world_size = dist.get_world_size() + return self._contiguous_batching_forward_impl( + hidden_states, + position_ids, + past_key_value, + output_attentions, + attention_mask=attention_mask, + world_size=world_size, + ) + + class PatchedMixtralBLockSparseTop2MLP(nn.Module): def _load_weights(self, loader, rank: int, world_size: int, @@ -255,6 +407,50 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return out_states, router_logits +class PatchedMixtralSparseMoeBlockAscend(nn.Module): + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """rewrite moe forward.""" + + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + router_logits = self.gate(hidden_states) + + routing_weights, selected_experts = moe_gating_topk_softmax( + router_logits, self.top_k) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), + dtype=hidden_states.dtype, + device=hidden_states.device) + + expert_mask = torch.nn.functional.one_hot( + selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.shape[0] == 0: + continue + + top_x_list = top_x.tolist() + idx_list = idx.tolist() + + current_state = hidden_states[None, + top_x_list].reshape(-1, hidden_dim) + current_hidden_states = expert_layer( + current_state) * routing_weights[top_x_list, idx_list, None] + + final_hidden_states.index_add_( + 0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape( + batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits + + class PatchedMixtralModel(nn.Module): def _continuous_batching_forward( diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index 9d08e48219..98934b5124 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -390,3 +390,15 @@ 'modeling_internlm2.InternLM2FlashAttention2': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm2.PatchedInternLM2AttentionAscend', }) + +# ascend mixtral +ASCEND_MODULE_MAP.update({ + 'transformers.models.mixtral.modeling_mixtral.MixtralAttention': + 
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralAttentionAscend', + 'transformers.models.mixtral.modeling_mixtral.MixtralFlashAttention2': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralAttentionAscend', + 'transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralAttentionAscend', + 'transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralSparseMoeBlockAscend', # noqa: E501 +}) From 41d99857ed5f0d135b89a61b89d398ab1b7d99df Mon Sep 17 00:00:00 2001 From: CyCle1024 Date: Thu, 8 Aug 2024 16:27:27 +0800 Subject: [PATCH 5/8] feat: change infer_ext ops function param order (#2) --- lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py b/lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py index eed147b623..c2730ae25b 100644 --- a/lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py +++ b/lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py @@ -26,7 +26,6 @@ def flash_context_attention( for i in range(batch): if torch.equal(q_seq_len[i], kv_seq_len[i]): ext_ops.context_attention( - attn_output, query_states, key_states, value_states, @@ -35,13 +34,13 @@ def flash_context_attention( num_q_heads, num_kv_heads, context.attention_mask[i:i + 1], + attn_output=attn_output, ) else: key_cache = key_cache.reshape(1, kv_cache_len, num_kv_heads * dim) value_cache = value_cache.reshape(1, kv_cache_len, num_kv_heads * dim) ext_ops.paged_prefill_attention( - attn_output, query_states, key_cache, value_cache, @@ -53,6 +52,7 @@ def flash_context_attention( num_q_heads, num_kv_heads, context.attention_mask[i:i + 1], + attn_output=attn_output, ) @@ -60,7 +60,6 @@ def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len, block_offsets, block_size): num_kv_heads, num_q_heads = k_cache.shape[1], q.shape[1] ext_ops.paged_decode_attention( - attn_output.view(q.shape), q, k_cache, v_cache, @@ -69,6 +68,7 @@ def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len, kv_seq_len, num_q_heads, num_kv_heads, + attn_output=attn_output.view(q.shape), ) From 33f70489094e4ef6fbc27afd2f4e1dcea8c101a9 Mon Sep 17 00:00:00 2001 From: yaofengchen <67218893+yao-fengchen@users.noreply.github.com> Date: Mon, 12 Aug 2024 17:26:54 +0800 Subject: [PATCH 6/8] feat: support ascend qwen2 and qwen2_moe (#6) * feat: support ascend qwen2 and qwen2_moe * fix: fix ascend mixtral --- lmdeploy/pytorch/models/mixtral.py | 1 + lmdeploy/pytorch/models/module_map.py | 22 ++++ lmdeploy/pytorch/models/qwen2.py | 147 ++++++++++++++++++++++++++ lmdeploy/pytorch/models/qwen2_moe.py | 45 ++++++++ 4 files changed, 215 insertions(+) diff --git a/lmdeploy/pytorch/models/mixtral.py b/lmdeploy/pytorch/models/mixtral.py index 54d9222b02..a8d9f25e44 100644 --- a/lmdeploy/pytorch/models/mixtral.py +++ b/lmdeploy/pytorch/models/mixtral.py @@ -292,6 +292,7 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Rewrite of MistralAttention.forward.""" diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index 98934b5124..e0f49715b6 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -402,3 +402,25 @@ 
'transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralSparseMoeBlockAscend', # noqa: E501 }) + +# ascend qwen1.5 +ASCEND_MODULE_MAP.update({ + 'transformers.models.qwen2.modeling_qwen2.Qwen2Attention': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2.PatchedQwen2AttentionAscend', + 'transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2.PatchedQwen2AttentionAscend', + 'transformers.models.qwen2.modeling_qwen2.Qwen2SdpaAttention': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2.PatchedQwen2AttentionAscend', +}) + +# ascend qwen2 moe +ASCEND_MODULE_MAP.update({ + 'transformers.models.qwen2_moe.modeling_qwen2_moe.Qwen2MoeAttention': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2.PatchedQwen2AttentionAscend', + 'transformers.models.qwen2_moe.modeling_qwen2_moe.Qwen2MoeFlashAttention2': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2.PatchedQwen2AttentionAscend', + 'transformers.models.qwen2_moe.modeling_qwen2_moe.Qwen2MoeSdpaAttention': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2.PatchedQwen2AttentionAscend', + 'transformers.models.qwen2_moe.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2_moe.PatchedQwen2MoeSparseMoeBlockAscend', # noqa: E501 +}) diff --git a/lmdeploy/pytorch/models/qwen2.py b/lmdeploy/pytorch/models/qwen2.py index 02185c703d..048bc370e1 100644 --- a/lmdeploy/pytorch/models/qwen2.py +++ b/lmdeploy/pytorch/models/qwen2.py @@ -146,3 +146,150 @@ def forward( past_key_value, world_size=world_size, ) + + +class PatchedQwen2AttentionAscend(nn.Module): + + def _load_weights(self, loader, rank: int, world_size: int, + device: torch.device): + """load weights.""" + for mod_name in ['q_proj', 'k_proj', 'v_proj']: + colwise_parallelize_linear(getattr(self, mod_name), + loader, + rank=rank, + world_size=world_size, + prefix=mod_name) + for mod_name in ['o_proj']: + rowwise_parallelize_linear(getattr(self, mod_name), + loader, + rank=rank, + world_size=world_size, + prefix=mod_name) + + @classmethod + def _distribute_output_fn(cls, outputs, **kwargs): + """Distribution output hook.""" + dist.all_reduce(outputs[0]) + return outputs + + def _contiguous_batching_forward_impl( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + world_size: int = 1, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + """Rewrite implementation of forward. + + Add continuous batching support. Add paged attention support. TP + support. 
+ """ + context = self.context.context + kv_seq_length = context.kv_seq_length + q_seq_length = context.q_seq_length + q_start_loc = context.q_start_loc + block_offsets = context.block_offsets + max_q_seq_length = context.max_q_seq_length + max_kv_seq_length = context.max_kv_seq_length + + num_heads = self.num_heads // world_size + num_kv_heads = self.num_key_value_heads // world_size + head_dim = self.head_dim + hidden_size = num_heads * head_dim + + def __qkv_proj(hidden_states): + """qkv proj.""" + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + return query_states, key_states, value_states + + def __rotary_emb_fn(query_states, key_states, value_states): + if hasattr(self, 'rotary_emb'): + cos, sin = self.rotary_emb(value_states, + seq_len=max_kv_seq_length) + query_states, key_states = apply_rotary_pos_emb( + query_states, + key_states, + cos, + sin, + position_ids, + context.position_ids_1d, + context=context) + return query_states, key_states, value_states + + query_states, key_states, value_states = __qkv_proj(hidden_states) + + query_states = query_states.view(-1, num_heads, head_dim) + key_states = key_states.view(-1, num_kv_heads, head_dim) + value_states = value_states.view(-1, num_kv_heads, head_dim) + + query_states, key_states, value_states = __rotary_emb_fn( + query_states, key_states, value_states) + + fill_kv_cache( + key_states, + value_states, + past_key_value[0], + past_key_value[1], + q_start_loc, + q_seq_length, + kv_seq_length=kv_seq_length, + max_q_seq_length=max_q_seq_length, + block_offsets=block_offsets, + context=context, + ) + + attn_output = query_states + + use_sliding_windows = (getattr(self.config, 'sliding_window', None) + is not None and self.config.use_sliding_window) + window_size = self.config.sliding_window + if not use_sliding_windows: + window_size = -1 + paged_attention_fwd( + query_states, + key_states, + value_states, + past_key_value[0], + past_key_value[1], + attn_output, + block_offsets, + q_start_loc=q_start_loc, + q_seqlens=q_seq_length, + kv_seqlens=kv_seq_length, + max_seqlen=max_q_seq_length, + window_size=window_size, + context=context, + ) + + attn_output = attn_output.reshape(*hidden_states.shape[:-1], + hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + """Rewrite of forward.""" + world_size = 1 + if dist.is_initialized(): + world_size = dist.get_world_size() + return self._contiguous_batching_forward_impl( + hidden_states, + position_ids, + past_key_value, + world_size=world_size, + ) diff --git a/lmdeploy/pytorch/models/qwen2_moe.py b/lmdeploy/pytorch/models/qwen2_moe.py index 5db3b68d55..ec7133f112 100644 --- a/lmdeploy/pytorch/models/qwen2_moe.py +++ b/lmdeploy/pytorch/models/qwen2_moe.py @@ -7,6 +7,8 @@ from torch import nn from lmdeploy.pytorch.kernels.fused_moe import fused_moe +from lmdeploy.pytorch.kernels.moe_gating_topk_softmax import \ + moe_gating_topk_softmax class PatchedQwen2MoeSparseMoeBlock(nn.Module): @@ -90,6 +92,49 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return out_states, router_logits 
+class PatchedQwen2MoeSparseMoeBlockAscend(nn.Module): + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """""" + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + router_logits = self.gate(hidden_states) + + routing_weights, selected_experts = moe_gating_topk_softmax( + router_logits, self.top_k) + if self.norm_topk_prob: + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(hidden_states.dtype) + + out_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), + dtype=hidden_states.dtype, + device=hidden_states.device) + + expert_mask = torch.nn.functional.one_hot( + selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer( + current_state) * routing_weights[top_x, idx, None] + + out_states.index_add_( + 0, top_x, current_hidden_states.to(hidden_states.dtype)) + + shared_expert_output = self.shared_expert(hidden_states) + shared_expert_output = F.sigmoid( + self.shared_expert_gate(hidden_states)) * shared_expert_output + + out_states = out_states + shared_expert_output + out_states = out_states.unflatten(0, (-1, sequence_length)) + + return out_states, router_logits + + class PatchedQwen2MoeModel(nn.Module): def _continuous_batching_forward( From 55413f286448c4c039272ee03ebf726cebcb72ef Mon Sep 17 00:00:00 2001 From: Wei Tao <1136862851@qq.com> Date: Mon, 19 Aug 2024 22:03:44 +0800 Subject: [PATCH 7/8] ascend: align attention mask to 32bytes (#7) --- lmdeploy/pytorch/engine/devices/ascend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lmdeploy/pytorch/engine/devices/ascend.py b/lmdeploy/pytorch/engine/devices/ascend.py index 7d051dfb19..acf77dd274 100644 --- a/lmdeploy/pytorch/engine/devices/ascend.py +++ b/lmdeploy/pytorch/engine/devices/ascend.py @@ -17,7 +17,7 @@ def update_step_context(cls, step_context): single_attention_mask = torch.logical_not( torch.tril( torch.ones(step_context.q_seq_length[i], - step_context.kv_seq_length[i], + (step_context.kv_seq_length[i] + 31) & (~31), dtype=torch.bool).cuda(), diagonal=step_context.kv_seq_length[i] - step_context.q_seq_length[i], From a459fdce7135965377c9548626db48e17b8832aa Mon Sep 17 00:00:00 2001 From: jinminxi104 Date: Tue, 20 Aug 2024 16:58:56 +0800 Subject: [PATCH 8/8] fix attn args (#9) --- lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py b/lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py index c2730ae25b..40c47d7036 100644 --- a/lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py +++ b/lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py @@ -33,7 +33,7 @@ def flash_context_attention( q_seq_len[i:i + 1], num_q_heads, num_kv_heads, - context.attention_mask[i:i + 1], + attn_mask=context.attention_mask[i:i + 1], attn_output=attn_output, ) else: @@ -51,7 +51,7 @@ def flash_context_attention( kv_seq_len[i:i + 1], num_q_heads, num_kv_heads, - context.attention_mask[i:i + 1], + attn_mask=context.attention_mask[i:i + 1], attn_output=attn_output, )
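
A minimal sketch of how the kernels added in this series are reached, assuming an Ascend machine with the infer_ext package installed. The device-context call mirrors the test setup that PATCH 1/8 added (and PATCH 3/8 later reverted); the tensor shapes and dtype below are illustrative only, and on a non-Ascend host the dispatched op will not run.

    # Sketch only: requires Ascend hardware plus infer_ext; shapes are illustrative.
    import torch

    from lmdeploy.pytorch.devices.device_manager import (DeviceContext,
                                                         get_device_manager)
    from lmdeploy.pytorch.kernels.moe_gating_topk_softmax import \
        moe_gating_topk_softmax

    # Select the ascend backend so the dispatcher resolves kernels from
    # lmdeploy/pytorch/kernels/ascend/ (thin wrappers around infer_ext.ops).
    get_device_manager().set_context(DeviceContext(device_type='ascend'))

    # (num_tokens, num_experts) router logits, as produced by an MoE gate.
    router_logits = torch.randn(8, 8, dtype=torch.float16)
    top_k = 2

    # Routes to the ascend moe_gating_topk_softmax wrapper, which returns
    # float32 routing weights and int64 expert indices (see PATCH 4/8).
    routing_weights, selected_experts = moe_gating_topk_softmax(
        router_logits, top_k)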