Commit d67a53d

sufubaohiworldwzj authored and committed
Deepseek MTP for dp backend (#923)
1 parent bf96f4b commit d67a53d

File tree

10 files changed: +581, -184 lines

lightllm/common/basemodel/basemodel.py

Lines changed: 81 additions & 100 deletions
@@ -311,127 +311,82 @@ def _decode(
         return self._token_forward(model_input.input_ids, infer_state)
 
     @torch.no_grad()
-    def microbatch_overlap_decode(self, batch: DecodeMicroBatch, batch1: DecodeMicroBatch):
-        assert batch.batch_size == batch1.batch_size
-        assert batch.mem_indexes.is_cuda
-        assert batch1.mem_indexes.is_cuda
-        input_ids, input_ids1 = batch.input_ids, batch1.input_ids
-
-        def create_inferstate(cur_batch: DecodeMicroBatch, batch_index):
-            infer_state = self.infer_state_class()
-            infer_state.is_prefill = False
-            infer_state.batch_size = cur_batch.batch_size
-            infer_state.total_token_num = cur_batch.total_token_num
-            infer_state.max_len_in_batch = cur_batch.max_len_in_batch
-            infer_state.use_dynamic_prompt_cache = self.use_dynamic_prompt_cache
-            assert cur_batch.b_req_idx.shape[0] == cur_batch.b_seq_len.shape[0]
-            infer_state.b_req_idx = cur_batch.b_req_idx
-            infer_state.b_seq_len = cur_batch.b_seq_len
-            infer_state.multimodal_params = None
-            infer_state.microbatch_index = batch_index
-
-            infer_state.mem_manager = self.mem_manager
-            infer_state.req_manager = self.req_manager
-
-            infer_state.mem_index = cur_batch.mem_indexes
-            infer_state.kv_buffer_shapedtype = (
-                (cur_batch.batch_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
-                self.data_type,
-            )
-            infer_state.dist_group = dist_group_manager.get_group(batch_index)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs, cur_batch.b_req_idx, cur_batch.b_seq_len, infer_state.mem_index
-            )
-            return infer_state
+    def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: ModelInput):
+        assert model_input0.batch_size == model_input1.batch_size
+        assert model_input0.mem_indexes.is_cuda
+        assert model_input1.mem_indexes.is_cuda
+        input_ids0, input_ids1 = model_input0.input_ids, model_input1.input_ids
 
-        infer_state = create_inferstate(batch, 0)
-        infer_state1 = create_inferstate(batch1, 1)
+        infer_state0 = self._create_inferstate(model_input0, 0)
+        copy_kv_index_to_req(
+            self.req_manager.req_to_token_indexs, model_input0.b_req_idx, model_input0.b_seq_len, infer_state0.mem_index
+        )
+        infer_state0.init_some_extra_state(self, input_ids0)
 
-        infer_state.init_some_extra_state(self, input_ids)
+        infer_state1 = self._create_inferstate(model_input1, 1)
+        copy_kv_index_to_req(
+            self.req_manager.req_to_token_indexs, model_input1.b_req_idx, model_input1.b_seq_len, infer_state1.mem_index
+        )
         infer_state1.init_some_extra_state(self, input_ids1)
 
-        batch_size = batch.batch_size
-        max_len_in_batch = max(batch.max_len_in_batch, batch1.max_len_in_batch)
+        batch_size = model_input0.batch_size
+        max_len_in_batch = max(model_input0.max_len_in_batch, model_input1.max_len_in_batch)
 
         if self.graph is not None and self.graph.can_run(batch_size, max_len_in_batch):
             if self.graph.need_capture(batch_size):
-                infer_state.is_cuda_graph = True
+                infer_state0.is_cuda_graph = True
                 infer_state1.is_cuda_graph = True
 
-                predict_logits, predict_logits1 = self.graph.capture_decode(
+                model_output0, model_output1 = self.graph.capture_decode(
                     self._overlap_tpsp_token_forward,
-                    input_ids,
-                    infer_state,
+                    input_ids0,
+                    infer_state0,
                     input_ids1=input_ids1,
                     infer_state1=infer_state1,
                 )
             else:
-                predict_logits, predict_logits1 = self.graph.replay(
-                    input_ids, infer_state, input_ids1=input_ids1, infer_state1=infer_state1
+                model_output0, model_output1 = self.graph.replay(
+                    input_ids0, infer_state0, input_ids1=input_ids1, infer_state1=infer_state1
                 )
         else:
-            predict_logits, predict_logits1 = self._overlap_tpsp_token_forward(
-                input_ids, infer_state, input_ids1=input_ids1, infer_state1=infer_state1
+            model_output0, model_output1 = self._overlap_tpsp_token_forward(
+                input_ids0, infer_state0, input_ids1=input_ids1, infer_state1=infer_state1
            )
-        return predict_logits, predict_logits1
+        return model_output0, model_output1
 
     @torch.no_grad()
-    def microbatch_overlap_prefill(self, batch: PrefillMicroBatch, batch1: PrefillMicroBatch):
-        assert batch.mem_indexes.is_cuda
-        assert batch1.mem_indexes.is_cuda
-        input_ids, input_ids1 = batch.input_ids, batch1.input_ids
-
-        def create_inferstate(cur_batch: PrefillMicroBatch, batch_index):
-            infer_state = self.infer_state_class()
-            infer_state.is_prefill = True
-            infer_state.is_token_healing = self.is_token_healing
-            infer_state.return_all_prompt_logics = self.return_all_prompt_logics
-            infer_state.use_dynamic_prompt_cache = self.use_dynamic_prompt_cache
-            infer_state.batch_size = cur_batch.batch_size
-            infer_state.total_token_num = cur_batch.total_token_num
-            infer_state.max_len_in_batch = cur_batch.max_len_in_batch
-            assert cur_batch.b_req_idx.shape[0] == cur_batch.b_seq_len.shape[0]
-            infer_state.b_req_idx = cur_batch.b_req_idx
-            infer_state.b_seq_len = cur_batch.b_seq_len
-            if cur_batch.b_ready_cache_len is not None:
-                infer_state.b_ready_cache_len = cur_batch.b_ready_cache_len
-            else:
-                infer_state.b_ready_cache_len = torch.zeros_like(
-                    cur_batch.b_seq_len, dtype=cur_batch.b_seq_len.dtype, device=cur_batch.b_seq_len.device
-                )
-            infer_state.multimodal_params = cur_batch.multimodal_params
-            infer_state.microbatch_index = batch_index
+    def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: ModelInput):
+        assert model_input0.mem_indexes.is_cuda
+        assert model_input1.mem_indexes.is_cuda
+        input_ids0, input_ids1 = model_input0.input_ids, model_input1.input_ids
 
-            infer_state.mem_manager = self.mem_manager
-            infer_state.req_manager = self.req_manager
-
-            infer_state.mem_index = cur_batch.mem_indexes
-            infer_state.kv_buffer_shapedtype = (
-                (cur_batch.input_ids.shape[0], self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
-                self.data_type,
-            )
-            infer_state.dist_group = dist_group_manager.get_group(batch_index)
-            init_req_to_token_indexes(
-                self.req_manager.req_to_token_indexs,
-                cur_batch.b_req_idx,
-                cur_batch.b_seq_len,
-                infer_state.b_ready_cache_len,
-                cur_batch.max_len_in_batch,
-                infer_state.mem_index,
-            )
-            return infer_state
-
-        infer_state = create_inferstate(batch, 0)
-        infer_state1 = create_inferstate(batch1, 1)
-
-        infer_state.init_some_extra_state(self, input_ids)
+        infer_state0 = self._create_inferstate(model_input0, 0)
+        init_req_to_token_indexes(
+            self.req_manager.req_to_token_indexs,
+            model_input0.b_req_idx,
+            model_input0.b_seq_len,
+            infer_state0.b_ready_cache_len,
+            model_input0.max_len_in_batch,
+            infer_state0.mem_index,
+        )
+        infer_state0.init_some_extra_state(self, input_ids0)
+
+        infer_state1 = self._create_inferstate(model_input1, 1)
+        init_req_to_token_indexes(
+            self.req_manager.req_to_token_indexs,
+            model_input1.b_req_idx,
+            model_input1.b_seq_len,
+            infer_state1.b_ready_cache_len,
+            model_input1.max_len_in_batch,
+            infer_state1.mem_index,
+        )
         infer_state1.init_some_extra_state(self, input_ids1)
 
-        predict_logits, predict_logits1 = self._overlap_tpsp_context_forward(
-            input_ids, infer_state, input_ids1=input_ids1, infer_state1=infer_state1
+        model_output0, model_output1 = self._overlap_tpsp_context_forward(
+            input_ids0, infer_state0, input_ids1=input_ids1, infer_state1=infer_state1
        )
         dist_group_manager.clear_deepep_buffer()
-        return predict_logits, predict_logits1
+        return model_output0, model_output1
 
     @final
     def _context_forward(self, input_ids, infer_state: InferStateInfo):

@@ -508,9 +463,21 @@ def _overlap_tpsp_token_forward(
         predict_logits, predict_logits1 = self.post_infer.overlap_tpsp_token_forward(
             input_embs, input_embs1, infer_state, infer_state1, self.pre_post_weight
         )
-
+
         g_cache_manager.cache_env_out()
-        return predict_logits, predict_logits1
+        is_return_hidden_states = self.spec_algo.is_mtp() or (
+            self.spec_algo.is_mtp_module() and not self.last_mtp_module
+        )
+        model_output = ModelOutput(
+            logits=predict_logits,
+            hidden_states=input_embs if is_return_hidden_states else None,
+        )
+
+        model_output1 = ModelOutput(
+            logits=predict_logits1,
+            hidden_states=input_embs1 if is_return_hidden_states else None,
+        )
+        return model_output, model_output1
 
     @final
     def _overlap_tpsp_context_forward(

@@ -528,7 +495,21 @@ def _overlap_tpsp_context_forward(
             input_embs, input_embs1, infer_state, infer_state1, self.pre_post_weight
         )
         g_cache_manager.cache_env_out()
-        return predict_logits, predict_logits1
+
+        is_return_hidden_states = self.spec_algo.is_mtp() or (
+            self.spec_algo.is_mtp_module() and not self.last_mtp_module
+        )
+        model_output = ModelOutput(
+            logits=predict_logits,
+            hidden_states=input_embs if is_return_hidden_states else None,
+        )
+
+        model_output1 = ModelOutput(
+            logits=predict_logits1,
+            hidden_states=input_embs1 if is_return_hidden_states else None,
+        )
+
+        return model_output, model_output1
 
     @final
     @torch.no_grad()
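
The net effect of this file's changes: the per-micro-batch InferStateInfo setup moves into a shared self._create_inferstate helper, and the overlap forward paths now return ModelOutput objects instead of bare logits, carrying the pre-head hidden states whenever an MTP draft module downstream still needs them. A minimal sketch of that packing step, assuming a dataclass-style container like the ModelOutput imported from lightllm.common.basemodel.batch_objs; the class and helper names below are illustrative, not the project's actual definitions:

from dataclasses import dataclass
from typing import Optional
import torch


@dataclass
class ModelOutputSketch:
    # Illustrative stand-in for lightllm's ModelOutput: logits for sampling, plus
    # the main model's hidden states when an MTP draft module will consume them.
    logits: torch.Tensor
    hidden_states: Optional[torch.Tensor] = None


def pack_output(predict_logits, input_embs, is_mtp, is_mtp_module, last_mtp_module):
    # Mirrors the condition added in the diff: keep hidden states when running the
    # main model under MTP, or inside an MTP module that is not the last in the chain.
    is_return_hidden_states = is_mtp or (is_mtp_module and not last_mtp_module)
    return ModelOutputSketch(
        logits=predict_logits,
        hidden_states=input_embs if is_return_hidden_states else None,
    )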

lightllm/models/deepseek_mtp/deepseek3_mtp_mem_manager.py

Lines changed: 0 additions & 5 deletions
@@ -26,11 +26,6 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False
 
         self.can_use_mem_size = self.size
 
-        rank_in_node = get_current_rank_in_node()
-        self.shared_can_use_token_num = SharedInt(f"MTP_mem_manger_can_use_token_num_{rank_in_node}")
-
-        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
-
         self._init_buffers(
             self.size,
             dtype,

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 1 addition & 1 deletion
@@ -325,7 +325,7 @@ def get_chunked_input_token_ids_shift(self, shift=1):
         shift_input_ids = np.roll(input_ids, -1 * shift)
         chunked_start = self.cur_kv_len
         chunked_end = min(self.get_cur_total_len(), chunked_start + self.shm_req.chunked_prefill_size)
-        is_last_chunked = chunked_end == self.get_cur_total_len() + shift
+        is_last_chunked = chunked_end == self.get_cur_total_len() - shift
         return shift_input_ids[0:chunked_end], is_last_chunked
 
     def get_chuncked_input_token_len(self):
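
The one-line fix above tightens the last-chunk test for the shifted (MTP) token stream: chunked_end is produced by min(get_cur_total_len(), ...), so it can never exceed get_cur_total_len(), and the old comparison against get_cur_total_len() + shift could therefore never be true; comparing against get_cur_total_len() - shift matches the last position that still has a valid shifted token. A tiny standalone check with made-up numbers (not lightllm code) illustrating the unreachable old condition:

# Toy check: every value min(total_len, ...) can produce is <= total_len,
# so the old "== total_len + shift" test never fires for shift >= 1.
total_len, shift = 8, 1
for chunked_end in range(0, total_len + 1):
    assert chunked_end != total_len + shift  # old test: unreachable
print("old check never fires; fixed check can fire at", total_len - shift)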

lightllm/server/router/model_infer/mode_backend/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@
 from .chunked_prefill.impl_for_outlines_constraint_mode import OutlinesConstraintBackend
 from .chunked_prefill.impl_for_first_token_constraint_mode import FirstTokenConstraintBackend
 from .dp_backend.impl import DPChunkedPrefillBackend
+from .dp_backend.impl_mtp import DPChunkedPrefillWithMTPBackend
 from .continues_batch.pd_mode.prefill_node_impl.prefill_impl import ChunckedPrefillForPrefillNode
 from .continues_batch.pd_mode.decode_node_impl.decode_impl import ContinuesBatchBackendForDecodeNode
 from .chunked_prefill.impl_for_xgrammar_mode import XgrammarBackend

lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py

Lines changed: 27 additions & 44 deletions
@@ -11,6 +11,11 @@
 from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample
 from lightllm.utils.envs_utils import get_env_start_args
 
+from lightllm.server.router.model_infer.mode_backend.generic_pre_process import (
+    prepare_prefill_inputs,
+    prepare_decode_inputs
+)
+from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
 
 class DPChunkedPrefillBackend(ModeBackend):
     def __init__(self) -> None:

@@ -24,16 +29,14 @@ def __init__(self) -> None:
         pass
 
     def init_custom(self):
-        self.reduce_tensor = torch.tensor([0], dtype=torch.int32, device="cuda", requires_grad=False)
         # Run one prefill inference ahead of time here, mainly to fill the first token position of the fake
         # requests built later: a padded decode request carries at least two tokens at inference time (one
         # token that already has kv and one waiting for its kv to be computed) and then generates a third token.
         # These tokens all reference g_infer_context.req_manager.mem_manager.HOLD_TOKEN_MEMINDEX, which must be initialized to exclude nan values so that later fake requests do not hit numerical errors.
-        from .pre_process import padded_prepare_prefill_inputs
-
-        kwargs, run_reqs, padded_req_num = padded_prepare_prefill_inputs([], 1, is_multimodal=self.is_multimodal)
-        self.model.forward(**kwargs)
-        assert len(run_reqs) == 0 and padded_req_num == 1
+        self.reduce_tensor = torch.tensor([0], dtype=torch.int32, device="cuda", requires_grad=False)
+        model_input, run_reqs = prepare_prefill_inputs([], is_chuncked_mode=True, is_multimodal=self.is_multimodal, pad_for_empty_batch=True)
+        self.model.forward(model_input)
+        assert len(run_reqs) == 0 and model_input.batch_size == 1
         return
 
     def prefill(self, reqs: List[Tuple]):

@@ -71,15 +74,14 @@ def decode(self):
         return
 
     def normal_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: int, uninit_reqs, ok_finished_reqs):
-        from .pre_process import padded_prepare_prefill_inputs
-
-        kwargs, run_reqs, padded_req_num = padded_prepare_prefill_inputs(
-            prefill_reqs, max_prefill_num, is_multimodal=self.is_multimodal
+        model_input, run_reqs = prepare_prefill_inputs(
+            prefill_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal, pad_for_empty_batch=True
         )
-        logits = self.model.forward(**kwargs)
+        model_output: ModelOutput = self.model.forward(model_input)
+
         self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
         if len(run_reqs) != 0:
-            logits = logits[0 : len(run_reqs), :]
+            logits = model_output.logits[0 : len(run_reqs), :]
             next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
             next_token_ids = next_token_ids.detach().cpu().numpy()
             next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()

@@ -89,43 +91,31 @@ def normal_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: int
         return
 
     def normal_decode(self, decode_reqs: List[InferReq], max_decode_num: int, uninit_reqs, ok_finished_reqs):
-        from .pre_process import padded_prepare_decode_inputs
-
-        kwargs, run_reqs, padded_req_num = padded_prepare_decode_inputs(
-            decode_reqs, max_decode_num, is_multimodal=self.is_multimodal
-        )
-        logits = self.model.forward(**kwargs)
+        model_input, run_reqs = prepare_decode_inputs(decode_reqs, pad_for_empty_batch=True)
+        model_output: ModelOutput = self.model.forward(model_input)
 
         self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
 
         if len(run_reqs) != 0:
-            logits = logits[0 : len(run_reqs), :]
+            logits = model_output.logits[0 : len(run_reqs), :]
             next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
             next_token_ids = next_token_ids.detach().cpu().numpy()
             next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
             self._post_handle(
                 run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False
             )
-            logits = None
 
     def overlap_decode(self, decode_reqs: List[InferReq], max_decode_num: int, uninit_reqs, ok_finished_reqs):
         from .pre_process import padded_overlap_prepare_decode_inputs
 
-        (
-            micro_batch,
-            run_reqs,
-            padded_req_num,
-            micro_batch1,
-            run_reqs1,
-            padded_req_num1,
-        ) = padded_overlap_prepare_decode_inputs(decode_reqs, max_decode_num, is_multimodal=self.is_multimodal)
-        logits, logits1 = self.model.microbatch_overlap_decode(micro_batch, micro_batch1)
+        micro_input, run_reqs, micro_input1, run_reqs1 = padded_overlap_prepare_decode_inputs(decode_reqs, max_decode_num, is_multimodal=self.is_multimodal)
+        micro_output, micro_output1 = self.model.microbatch_overlap_decode(micro_input, micro_input1)
         self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
         req_num, req_num1 = len(run_reqs), len(run_reqs1)
-        all_logits = torch.empty((req_num + req_num1, logits.shape[1]), dtype=logits.dtype, device=logits.device)
+        all_logits = torch.empty((req_num + req_num1, micro_output.logits.shape[1]), dtype=micro_output.logits.dtype, device=micro_output.logits.device)
 
-        all_logits[0:req_num, :].copy_(logits[0:req_num, :], non_blocking=True)
-        all_logits[req_num : (req_num + req_num1), :].copy_(logits1[0:req_num1, :], non_blocking=True)
+        all_logits[0:req_num, :].copy_(micro_output.logits[0:req_num, :], non_blocking=True)
+        all_logits[req_num : (req_num + req_num1), :].copy_(micro_output1.logits[0:req_num1, :], non_blocking=True)
 
         all_run_reqs = run_reqs + run_reqs1
         if all_run_reqs:

@@ -140,21 +130,14 @@ def overlap_decode(self, decode_reqs: List[InferReq], max_decode_num: int, unini
     def overlap_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: int, uninit_reqs, ok_finished_reqs):
         from .pre_process import padded_overlap_prepare_prefill_inputs
 
-        (
-            micro_batch,
-            run_reqs,
-            padded_req_num,
-            micro_batch1,
-            run_reqs1,
-            padded_req_num1,
-        ) = padded_overlap_prepare_prefill_inputs(prefill_reqs, max_prefill_num, is_multimodal=self.is_multimodal)
-        logits, logits1 = self.model.microbatch_overlap_prefill(micro_batch, micro_batch1)
+        micro_input, run_reqs, micro_input1, run_reqs1 = padded_overlap_prepare_prefill_inputs(prefill_reqs, max_prefill_num, is_multimodal=self.is_multimodal)
+        micro_output, micro_output1 = self.model.microbatch_overlap_prefill(micro_input, micro_input1)
         self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
         req_num, req_num1 = len(run_reqs), len(run_reqs1)
-        all_logits = torch.empty((req_num + req_num1, logits.shape[1]), dtype=logits.dtype, device=logits.device)
+        all_logits = torch.empty((req_num + req_num1, micro_output.logits.shape[1]), dtype=micro_output.logits.dtype, device=micro_output.logits.device)
 
-        all_logits[0:req_num, :].copy_(logits[0:req_num, :], non_blocking=True)
-        all_logits[req_num : (req_num + req_num1), :].copy_(logits1[0:req_num1, :], non_blocking=True)
+        all_logits[0:req_num, :].copy_(micro_output.logits[0:req_num, :], non_blocking=True)
+        all_logits[req_num : (req_num + req_num1), :].copy_(micro_output1.logits[0:req_num1, :], non_blocking=True)
 
         all_run_reqs = run_reqs + run_reqs1
         if all_run_reqs:
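
Both overlap paths now unpack ModelOutput objects from the two micro batches and splice the real (non-padded) requests' logits into one tensor before sampling. A rough standalone sketch of that merge, with a hypothetical helper name that is not part of the commit:

import torch

def merge_microbatch_logits(micro_output0, micro_output1, req_num0, req_num1):
    # Hypothetical helper mirroring the merge in overlap_decode / overlap_prefill_reqs:
    # keep only the rows belonging to real requests from each micro batch and pack
    # them into a single logits tensor for sampling.
    logits0, logits1 = micro_output0.logits, micro_output1.logits
    all_logits = torch.empty(
        (req_num0 + req_num1, logits0.shape[1]), dtype=logits0.dtype, device=logits0.device
    )
    all_logits[0:req_num0, :].copy_(logits0[0:req_num0, :], non_blocking=True)
    all_logits[req_num0 : (req_num0 + req_num1), :].copy_(logits1[0:req_num1, :], non_blocking=True)
    return all_logits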
