    context_attention_fwd,
    context_attention_fwd_no_prompt_cache,
)
+from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad_with_v import (
+    context_attention_fwd_with_v,
+    context_attention_fwd_no_prompt_cache_with_v,
+)
from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding
from lightllm.models.deepseek2.layer_infer.fused_moe import fused_experts, grouped_topk
from lightllm.models.llama.infer_struct import LlamaInferStateInfo
from functools import partial
from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale
+import os


class Deepseek2TransformerLayerInfer(LlamaTransformerLayerInfer):
@@ -55,6 +60,12 @@ def __init__(
        self.softmax_scale = self.softmax_scale * mscale * mscale
        super().__init__(layer_num, tp_rank, world_size, network_config, mode)
        self.tp_o_head_num_ = self.tp_q_head_num_
+
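+        # num_heads / num_kv_heads come from the model config and are divided across
+        # tensor-parallel ranks inside the CC/ACC helpers below.
+        # ENABLE_OPT_DECODE_MHA (env var) opts into the fused lightllm_ppl_mla decode kernel
+        # in _ACC_method; otherwise the Triton flash-decoding kernel is used.
+        # mla_type picks the MLA attention path: "ACCM" keeps the original absorbed,
+        # latent-space attention; "MIX" (taken from layer_weight in _get_qkv) materializes
+        # full K/V for prefill (CC) and uses the absorbed latent path for decode (ACC).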
+        self.num_heads = network_config["num_attention_heads"]
+        self.num_kv_heads = network_config["num_key_value_heads"]
+        self.enable_opt_decoding_mha = os.getenv("ENABLE_OPT_DECODE_MHA", "False").upper() in ["ON", "TRUE", "1"]
+        self.mla_type = "ACCM"
+
        return

    def _bind_attention(self):
@@ -97,7 +108,12 @@ def _get_qkv(

        q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim)
        q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
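+        # Decide the MLA path for this batch: with a dynamic prompt cache during prefill we
+        # stay on the absorbed "ACCM" path; otherwise follow the mla_type set on the weights.
+        # On the "ACCM" path q_nope is absorbed through k_b_proj_ so attention runs directly
+        # in the compressed kv_lora space.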
+        if infer_state.use_dynamic_prompt_cache and infer_state.is_prefill:
+            self.mla_type = "ACCM"
+        else:
+            self.mla_type = layer_weight.mla_type
+        if self.mla_type == "ACCM":
+            q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)

        layer_weight.kv_a_proj_with_mqa_.mm(input, out=cache_kv.view(-1, self.kv_lora_rank + self.qk_rope_head_dim))

@@ -123,11 +139,153 @@ def _get_o(
            input = input.view(-1, self.tp_q_head_num_ * self.kv_lora_rank)
            o_tensor = layer_weight.fuse_vo_weight_.mm(input)
        else:
-            input = layer_weight.v_b_proj_.bmm(input.transpose(0, 1)).transpose(0, 1)
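+            # The CC/ACC kernels on the "MIX" path already return per-head outputs of size
+            # qk_nope_head_dim, so the v_b_proj_ up-projection is only needed on the "ACCM" path.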
+            if self.mla_type == "ACCM":
+                input = layer_weight.v_b_proj_.bmm(input.transpose(0, 1)).transpose(0, 1)
            o_tensor = layer_weight.o_weight_.mm(input.reshape(-1, self.tp_q_head_num_ * self.qk_nope_head_dim))
        return o_tensor

+    def _CC_method(
+        self, q, compressed_kv, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
+    ):
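+        # "CC" prefill path: split the cached latent into its compressed and rope parts,
+        # decompress it with k_b_proj_ / v_b_proj_ into full per-head K and V, then run the
+        # context flash-attention kernel that takes an explicit V.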
+        num_local_heads = self.num_heads
+        num_local_kv_heads = self.num_kv_heads
+        if self.world_size_ > 1:
+            num_local_heads //= self.world_size_
+            num_local_kv_heads //= self.world_size_
+        if infer_state.use_dynamic_prompt_cache:
+            compressed_kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+        # CC
+        compressed_kv, k_pe = torch.split(  # (b*s, 1, kv_lora + qk_r)
+            compressed_kv, [layer_weight.kv_lora_rank, layer_weight.qk_rope_head_dim], dim=-1
+        )
+        compressed_kv = compressed_kv.view(-1, layer_weight.kv_lora_rank)
+        k = self.alloc_tensor(
+            [k_pe.shape[0], num_local_kv_heads, layer_weight.qk_nope_head_dim + layer_weight.qk_rope_head_dim],
+            dtype=q[0].dtype,
+        )
+        k[..., layer_weight.qk_nope_head_dim :] = k_pe
+        wk = layer_weight.k_b_proj_.weight.view(-1, layer_weight.k_b_proj_.weight.shape[-1])
+        o_tensor = self.alloc_tensor([compressed_kv.shape[0], wk.shape[0]], dtype=q[0].dtype)
+        torch.mm(compressed_kv, wk.transpose(0, 1), out=o_tensor)
+        k[..., : layer_weight.qk_nope_head_dim] = o_tensor.view(-1, num_local_kv_heads, layer_weight.qk_nope_head_dim)
+        trans_weight = layer_weight.v_b_proj_.weight.transpose(1, 2)
+        wv = trans_weight.view(-1, trans_weight.shape[-1])
+        o_tensor = self.alloc_tensor([compressed_kv.shape[0], wv.shape[0]], dtype=q[0].dtype)
+        torch.mm(compressed_kv, wv.transpose(0, 1), out=o_tensor)
+        v = o_tensor.view(-1, num_local_kv_heads, layer_weight.qk_nope_head_dim)
+        return self._context_attention_kernel_with_v(q, k, v, infer_state, layer_weight)
+
+    def _ACC_method(
+        self, q, compressed_kv, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
+    ):
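+        # "ACC" decode path: absorb k_b_proj_ into the nope half of q so attention is computed
+        # directly against the compressed latent cache (kv_lora_rank + rope dims), then
+        # up-project the attention output with v_b_proj_ back to qk_nope_head_dim per head.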
+        q_ne, q_pe = q
+        num_local_heads = self.num_heads
+        num_local_kv_heads = self.num_kv_heads
+        if self.world_size_ > 1:
+            num_local_heads //= self.world_size_
+            num_local_kv_heads //= self.world_size_
+        # ACC
+        q = self.alloc_tensor(
+            [q_ne.shape[0], num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim], dtype=q_ne.dtype
+        )
+        q[..., self.kv_lora_rank :] = q_pe
+        torch.bmm(  # TODO: switch to einsum or cuBLAS
+            q_ne.transpose(0, 1),  # (h, b*s, qk_n)
+            layer_weight.k_b_proj_.weight,  # (h, qk_n, kv_lora)
+            out=q[..., : self.kv_lora_rank].view(q_ne.shape[1], q_ne.shape[0], self.kv_lora_rank),
+        ).transpose(
+            0, 1
+        )  # (b*s, h, kv_lora)
+        q_nope, q_rope = torch.split(  # (b*s, h, qk_n + qk_r) -> (b*s, h, qk_n), (b*s, h, qk_r)
+            q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+        )
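+        # Two decode kernels are available here: the fused lightllm_ppl_mla kernel (opt-in via
+        # ENABLE_OPT_DECODE_MHA) or the default Triton GQA flash-decoding kernel.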
+        if self.enable_opt_decoding_mha:
+            import lightllm_ppl_mla
+
+            o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype)
+            kvstarts = torch.cat(
+                [infer_state.b_start_loc, infer_state.b_start_loc[-1:] + infer_state.b_seq_len[-1:]], dim=0
+            )
+            lightllm_ppl_mla.decode_mla(
+                o_tensor,
+                q,
+                compressed_kv[: infer_state.mem_end, :, :],
+                infer_state.b_start_loc,
+                kvstarts,
+                self.softmax_scale,
+                q.shape[-1],
+                q_nope.shape[-1],
+            )
+            output_parallel = o_tensor
+        else:
+            output_parallel = self._token_gqa_decode_attention_flashdecoding_origin(
+                (q_nope, q_rope), infer_state, layer_weight
+            )
+        o_tensor = self.alloc_tensor(
+            [output_parallel.shape[1], output_parallel.shape[0], self.qk_nope_head_dim], dtype=q_ne.dtype
+        )
+        torch.bmm(  # TODO: switch to einsum or cuBLAS
+            output_parallel.transpose(0, 1),  # (h, b*s, kv_lora)
+            layer_weight.v_b_proj_.weight,  # (h, kv_lora, vo_d)
+            out=o_tensor,
+        ).transpose(
+            0, 1
+        )  # (b*s, h, vo_d)
+        return o_tensor
+
    def _context_attention_kernel(
+        self, q, kv, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
+    ) -> torch.Tensor:
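+        # Prefill dispatch: "MIX" runs the CC path over decompressed K/V; everything else keeps
+        # the original latent-space kernel.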
+        if self.mla_type == "MIX":
+            return self._context_attention_kernel_with_CC(q, kv, infer_state, layer_weight, out)
+        else:
+            return self._context_attention_kernel_origin(q, kv, infer_state, layer_weight, out)
+
+    def _context_attention_kernel_with_CC(
+        self, q, kv, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
+    ) -> torch.Tensor:
+        return self._CC_method(q, kv, infer_state, layer_weight)
+
+    def _context_attention_kernel_with_v(
+        self, q: Tuple[torch.Tensor, torch.Tensor], kv, v, infer_state: LlamaInferStateInfo, layer_weight, out=None
+    ) -> torch.Tensor:
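+        # Context attention over the decompressed K (nope + rope halves of `kv`) with an explicit
+        # V tensor, choosing the prompt-cache-aware Triton kernel when dynamic prompt cache is on.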
+        q_nope, q_rope = q
+        nope_head_dim = q_nope.shape[-1]
+        o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
+        if infer_state.use_dynamic_prompt_cache:
+            context_attention_fwd_with_v(
+                q_nope,
+                q_rope,
+                kv[:, :, :nope_head_dim],
+                kv[:, :, nope_head_dim:],
+                v,
+                o_tensor.view(-1, self.tp_q_head_num_, nope_head_dim),
+                infer_state.b_req_idx,
+                infer_state.b_start_loc,
+                infer_state.b_seq_len,
+                infer_state.b_ready_cache_len,
+                infer_state.max_len_in_batch,
+                infer_state.req_manager.req_to_token_indexs,
+                self.softmax_scale,
+            )
+        else:
+            context_attention_fwd_no_prompt_cache_with_v(
+                q_nope,
+                q_rope,
+                kv[:, :, :nope_head_dim],
+                kv[:, :, nope_head_dim:],
+                v,
+                o_tensor.view(-1, self.tp_q_head_num_, nope_head_dim),
+                infer_state.b_start_loc,
+                infer_state.b_seq_len,
+                infer_state.max_len_in_batch,
+                self.softmax_scale,
+            )
+        q_nope = None
+        q_rope = None
+        return o_tensor
+
+    def _context_attention_kernel_origin(
        self, q: Tuple[torch.Tensor, torch.Tensor], kv, infer_state: LlamaInferStateInfo, layer_weight, out=None
    ) -> torch.Tensor:
        q_nope, q_rope = q
@@ -166,6 +324,20 @@ def _context_attention_kernel(
        return o_tensor

    def _token_gqa_decode_attention_flashdecoding(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None):
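+        # Decode dispatch: "MIX" uses the ACC path over the compressed latent cache; everything
+        # else keeps the original flash-decoding kernel.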
+        if self.mla_type == "MIX":
+            return self._token_gqa_decode_attention_flashdecoding_with_ACC(q, infer_state, layer_weight, out)
+        else:
+            return self._token_gqa_decode_attention_flashdecoding_origin(q, infer_state, layer_weight, out)
+
+    def _token_gqa_decode_attention_flashdecoding_with_ACC(
+        self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None
+    ):
+        compressed_kv = infer_state.mem_manager.kv_buffer[self.layer_num_][: infer_state.mem_end, :, :]
+        return self._ACC_method(q, compressed_kv, infer_state, layer_weight)
+
+    def _token_gqa_decode_attention_flashdecoding_origin(
+        self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None
+    ):
        q_nope, q_rope = q
        kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank]
        kv_rope = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, self.kv_lora_rank :]