
Commit bd0712e

blueswhen and niushengxiao authored
feat: add cc, acc method for deepseek2 (#618)
Co-authored-by: niushengxiao <[email protected]>
1 parent f855b48 commit bd0712e

3 files changed (+610 −2 lines)


lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 174 additions & 2 deletions
@@ -9,6 +9,10 @@
     context_attention_fwd,
     context_attention_fwd_no_prompt_cache,
 )
+from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad_with_v import (
+    context_attention_fwd_with_v,
+    context_attention_fwd_no_prompt_cache_with_v,
+)

 from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding
 from lightllm.models.deepseek2.layer_infer.fused_moe import fused_experts, grouped_topk
@@ -18,6 +22,7 @@
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
 from functools import partial
 from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale
+import os


 class Deepseek2TransformerLayerInfer(LlamaTransformerLayerInfer):
@@ -55,6 +60,12 @@ def __init__(
             self.softmax_scale = self.softmax_scale * mscale * mscale
         super().__init__(layer_num, tp_rank, world_size, network_config, mode)
         self.tp_o_head_num_ = self.tp_q_head_num_
+
+        self.num_heads = network_config["num_attention_heads"]
+        self.num_kv_heads = network_config["num_key_value_heads"]
+        self.enable_opt_decoding_mha = os.getenv("ENABLE_OPT_DECODE_MHA", "False").upper() in ["ON", "TRUE", "1"]
+        self.mla_type = "ACCM"
+
         return

     def _bind_attention(self):
@@ -97,7 +108,12 @@ def _get_qkv(

         q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim)
         q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
+        if infer_state.use_dynamic_prompt_cache and infer_state.is_prefill:
+            self.mla_type = "ACCM"
+        else:
+            self.mla_type = layer_weight.mla_type
+        if self.mla_type == "ACCM":
+            q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)

         layer_weight.kv_a_proj_with_mqa_.mm(input, out=cache_kv.view(-1, self.kv_lora_rank + self.qk_rope_head_dim))

@@ -123,11 +139,153 @@ def _get_o(
             input = input.view(-1, self.tp_q_head_num_ * self.kv_lora_rank)
             o_tensor = layer_weight.fuse_vo_weight_.mm(input)
         else:
-            input = layer_weight.v_b_proj_.bmm(input.transpose(0, 1)).transpose(0, 1)
+            if self.mla_type == "ACCM":
+                input = layer_weight.v_b_proj_.bmm(input.transpose(0, 1)).transpose(0, 1)
             o_tensor = layer_weight.o_weight_.mm(input.reshape(-1, self.tp_q_head_num_ * self.qk_nope_head_dim))
         return o_tensor

+    def _CC_method(
+        self, q, compressed_kv, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
+    ):
+        num_local_heads = self.num_heads
+        num_local_kv_heads = self.num_kv_heads
+        if self.world_size_ > 1:
+            num_local_heads //= self.world_size_
+            num_local_kv_heads //= self.world_size_
+        if infer_state.use_dynamic_prompt_cache:
+            compressed_kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+        # CC
+        compressed_kv, k_pe = torch.split(  # (b*s, 1, kv_lora + qk_r)
+            compressed_kv, [layer_weight.kv_lora_rank, layer_weight.qk_rope_head_dim], dim=-1
+        )
+        compressed_kv = compressed_kv.view(-1, layer_weight.kv_lora_rank)
+        k = self.alloc_tensor(
+            [k_pe.shape[0], num_local_kv_heads, layer_weight.qk_nope_head_dim + layer_weight.qk_rope_head_dim],
+            dtype=q[0].dtype,
+        )
+        k[..., layer_weight.qk_nope_head_dim :] = k_pe
+        wk = layer_weight.k_b_proj_.weight.view(-1, layer_weight.k_b_proj_.weight.shape[-1])
+        o_tensor = self.alloc_tensor([compressed_kv.shape[0], wk.shape[0]], dtype=q[0].dtype)
+        torch.mm(compressed_kv, wk.transpose(0, 1), out=o_tensor)
+        k[..., : layer_weight.qk_nope_head_dim] = o_tensor.view(-1, num_local_kv_heads, layer_weight.qk_nope_head_dim)
+        trans_weight = layer_weight.v_b_proj_.weight.transpose(1, 2)
+        wv = trans_weight.view(-1, trans_weight.shape[-1])
+        o_tensor = self.alloc_tensor([compressed_kv.shape[0], wv.shape[0]], dtype=q[0].dtype)
+        torch.mm(compressed_kv, wv.transpose(0, 1), out=o_tensor)
+        v = o_tensor.view(-1, num_local_kv_heads, layer_weight.qk_nope_head_dim)
+        return self._context_attention_kernel_with_v(q, k, v, infer_state, layer_weight)
+
+    def _ACC_method(
+        self, q, compressed_kv, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
+    ):
+        q_ne, q_pe = q
+        num_local_heads = self.num_heads
+        num_local_kv_heads = self.num_kv_heads
+        if self.world_size_ > 1:
+            num_local_heads //= self.world_size_
+            num_local_kv_heads //= self.world_size_
+        # ACC
+        q = self.alloc_tensor(
+            [q_ne.shape[0], num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim], dtype=q_ne.dtype
+        )
+        q[..., self.kv_lora_rank :] = q_pe
+        torch.bmm(  # TODO: convert to einsum or cuBLAS
+            q_ne.transpose(0, 1),  # (h, b*s, qk_n)
+            layer_weight.k_b_proj_.weight,  # (h, qk_n, kv_lora)
+            out=q[..., : self.kv_lora_rank].view(q_ne.shape[1], q_ne.shape[0], self.kv_lora_rank),
+        ).transpose(
+            0, 1
+        )  # (b*s, h, kv_lora)
+        q_nope, q_rope = torch.split(  # (b*s, h, qk_n + qk_r) -> (b*s, h, qk_n), (b*s, h, qk_r)
+            q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+        )
+        if self.enable_opt_decoding_mha:
+            import lightllm_ppl_mla
+
+            o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype)
+            kvstarts = torch.cat(
+                [infer_state.b_start_loc, infer_state.b_start_loc[-1:] + infer_state.b_seq_len[-1:]], dim=0
+            )
+            lightllm_ppl_mla.decode_mla(
+                o_tensor,
+                q,
+                compressed_kv[: infer_state.mem_end, :, :],
+                infer_state.b_start_loc,
+                kvstarts,
+                self.softmax_scale,
+                q.shape[-1],
+                q_nope.shape[-1],
+            )
+            output_parallel = o_tensor
+        else:
+            output_parallel = self._token_gqa_decode_attention_flashdecoding_origin(
+                (q_nope, q_rope), infer_state, layer_weight
+            )
+        o_tensor = self.alloc_tensor(
+            [output_parallel.shape[1], output_parallel.shape[0], self.qk_nope_head_dim], dtype=q_ne.dtype
+        )
+        torch.bmm(  # TODO: convert to einsum or cuBLAS
+            output_parallel.transpose(0, 1),  # (h, b*s, kv_lora)
+            layer_weight.v_b_proj_.weight,  # (h, kv_lora, vo_d)
+            out=o_tensor,
+        ).transpose(
+            0, 1
+        )  # (b*s, h, vo_d)
+        return o_tensor
+
     def _context_attention_kernel(
+        self, q, kv, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
+    ) -> torch.Tensor:
+        if self.mla_type == "MIX":
+            return self._context_attention_kernel_with_CC(q, kv, infer_state, layer_weight, out)
+        else:
+            return self._context_attention_kernel_origin(q, kv, infer_state, layer_weight, out)
+
+    def _context_attention_kernel_with_CC(
+        self, q, kv, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
+    ) -> torch.Tensor:
+        return self._CC_method(q, kv, infer_state, layer_weight)
+
+    def _context_attention_kernel_with_v(
+        self, q: Tuple[torch.Tensor, torch.Tensor], kv, v, infer_state: LlamaInferStateInfo, layer_weight, out=None
+    ) -> torch.Tensor:
+        q_nope, q_rope = q
+        nope_head_dim = q_nope.shape[-1]
+        o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
+        if infer_state.use_dynamic_prompt_cache:
+            context_attention_fwd_with_v(
+                q_nope,
+                q_rope,
+                kv[:, :, :nope_head_dim],
+                kv[:, :, nope_head_dim:],
+                v,
+                o_tensor.view(-1, self.tp_q_head_num_, nope_head_dim),
+                infer_state.b_req_idx,
+                infer_state.b_start_loc,
+                infer_state.b_seq_len,
+                infer_state.b_ready_cache_len,
+                infer_state.max_len_in_batch,
+                infer_state.req_manager.req_to_token_indexs,
+                self.softmax_scale,
+            )
+        else:
+            context_attention_fwd_no_prompt_cache_with_v(
+                q_nope,
+                q_rope,
+                kv[:, :, :nope_head_dim],
+                kv[:, :, nope_head_dim:],
+                v,
+                o_tensor.view(-1, self.tp_q_head_num_, nope_head_dim),
+                infer_state.b_start_loc,
+                infer_state.b_seq_len,
+                infer_state.max_len_in_batch,
+                self.softmax_scale,
+            )
+        q_nope = None
+        q_rope = None
+        return o_tensor
+
+    def _context_attention_kernel_origin(
         self, q: Tuple[torch.Tensor, torch.Tensor], kv, infer_state: LlamaInferStateInfo, layer_weight, out=None
     ) -> torch.Tensor:
         q_nope, q_rope = q
@@ -166,6 +324,20 @@ def _context_attention_kernel
         return o_tensor

     def _token_gqa_decode_attention_flashdecoding(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None):
+        if self.mla_type == "MIX":
+            return self._token_gqa_decode_attention_flashdecoding_with_ACC(q, infer_state, layer_weight, out)
+        else:
+            return self._token_gqa_decode_attention_flashdecoding_origin(q, infer_state, layer_weight, out)
+
+    def _token_gqa_decode_attention_flashdecoding_with_ACC(
+        self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None
+    ):
+        compressed_kv = infer_state.mem_manager.kv_buffer[self.layer_num_][: infer_state.mem_end, :, :]
+        return self._ACC_method(q, compressed_kv, infer_state, layer_weight)
+
+    def _token_gqa_decode_attention_flashdecoding_origin(
+        self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None
+    ):
         q_nope, q_rope = q
         kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank]
         kv_rope = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, self.kv_lora_rank :]
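
Reading note (not part of the commit): the new _CC_method decompresses the latent KV cache into per-head no-PE keys and values through k_b_proj_ / v_b_proj_ and calls the new *_with_v context kernels, while _ACC_method absorbs k_b_proj_ into the query and applies v_b_proj_ after attention, so decoding attends directly over the compressed cache. The two orderings compute the same attention; the minimal single-head PyTorch sketch below (toy shapes, not the lightllm kernels) illustrates the identity the CC/ACC split relies on.

import torch

# Toy single-head sketch of the CC/ACC equivalence (illustrative shapes only).
n, qk_nope, kv_lora, vo_dim = 4, 128, 512, 128
q_nope = torch.randn(n, qk_nope, dtype=torch.float64)      # no-PE query for one head
c_kv = torch.randn(n, kv_lora, dtype=torch.float64)        # compressed (latent) KV cache
w_kb = torch.randn(qk_nope, kv_lora, dtype=torch.float64)  # per-head k_b projection
w_vb = torch.randn(kv_lora, vo_dim, dtype=torch.float64)   # per-head v_b projection

# CC view: decompress the latent cache into full keys, then score.
k_nope = c_kv @ w_kb.t()                  # (n, qk_nope)
scores_cc = q_nope @ k_nope.t()           # (n, n)

# ACC view: absorb w_kb into the query and score against the latent cache.
q_absorbed = q_nope @ w_kb                # (n, kv_lora)
scores_acc = q_absorbed @ c_kv.t()        # (n, n)
assert torch.allclose(scores_cc, scores_acc)

# Value side: applying w_vb after attending over latents (ACC) matches
# attending over decompressed values (CC).
attn = torch.softmax(scores_cc, dim=-1)
out_cc = attn @ (c_kv @ w_vb)             # (n, vo_dim)
out_acc = (attn @ c_kv) @ w_vb            # (n, vo_dim)
assert torch.allclose(out_cc, out_acc)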

lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py

Lines changed: 5 additions & 0 deletions
@@ -75,6 +75,11 @@ def __init__(
         self.disable_qk_absorb = disable_qk_absorb
         self.disable_vo_absorb = disable_vo_absorb
         super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg)
+        # mla_type = "ACCM", "MIX"
+        # MIX: CC for prefill, ACC for decoding
+        self.mla_type = "MIX"
+        if not disable_vo_absorb or not disable_qk_absorb:
+            self.mla_type = "ACCM"
         return

     def _parse_config(self):
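
Taken together with the infer-layer changes above, the weight-side flag means the MIX schedule (CC during prefill, ACC during decode) is only reached when both disable_qk_absorb and disable_vo_absorb are set; otherwise the original absorbed "ACCM" path is kept, and _get_qkv also forces ACCM for prefill requests that use the dynamic prompt cache. The snippet below is a condensed paraphrase of that selection logic plus the ENABLE_OPT_DECODE_MHA switch read in the infer layer, written as a standalone sketch rather than lightllm API.

import os

def select_mla_type(disable_qk_absorb: bool, disable_vo_absorb: bool) -> str:
    # Mirrors the weight-side logic: MIX only when both absorb paths are disabled.
    if not disable_vo_absorb or not disable_qk_absorb:
        return "ACCM"
    return "MIX"

# The ACC decode branch additionally consults this environment switch before
# dispatching to the optional lightllm_ppl_mla.decode_mla kernel.
enable_opt_decoding_mha = os.getenv("ENABLE_OPT_DECODE_MHA", "False").upper() in ["ON", "TRUE", "1"]

assert select_mla_type(True, True) == "MIX"
assert select_mla_type(False, True) == "ACCM"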
