diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 873ab00d4..c7760e995 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -1,8 +1,10 @@ import os # os.environ["CUDA_LAUNCH_BLOCKING"] = "1" +import copy import json import torch +import torch.nn.functional as F from typing import final from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights @@ -19,7 +21,9 @@ from lightllm.utils.dist_utils import get_dp_world_size from lightllm.utils.envs_utils import get_env_start_args from lightllm.distributed.communication_op import CustomProcessGroup, dist_group_manager -from lightllm.common.basemodel.microbatch_overlap_objs import DecodeMicroBatch, PrefillMicroBatch +from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput +from lightllm.utils.custom_kernel_utis import pad2dim_tensor_to_new_batch + logger = init_logger(__name__) @@ -40,6 +44,7 @@ class TpPartBaseModel: infer_state_class = InferStateInfo def __init__(self, kvargs): + self.args = get_env_start_args() self.run_mode = kvargs["run_mode"] self.weight_dir_ = kvargs["weight_dir"] self.max_total_token_num = kvargs["max_total_token_num"] @@ -71,6 +76,8 @@ def __init__(self, kvargs): self.tp_world_size_ = get_dp_world_size() self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode + self.is_deepseekv3_mtp_mode = self.args.mtp_mode == "deepseekv3" + self._init_datatype() self._init_config() self._verify_must() @@ -92,6 +99,7 @@ def __init__(self, kvargs): self._init_some_value() self._init_custom() self._init_inferstate_cls() + self._init_padded_req() self._init_cudagraph() self._check_max_len_infer() torch.cuda.empty_cache() @@ -226,263 +234,151 @@ def _init_custom(self): pass @torch.no_grad() - def forward( - self, - batch_size, - total_token_num, - max_len_in_batch, - input_ids: torch.Tensor, - mem_indexes: torch.Tensor, - b_req_idx: torch.Tensor, - b_seq_len: torch.Tensor, - b_ready_cache_len: torch.Tensor = None, - multimodal_params=None, - is_prefill=True, - ): - assert mem_indexes.is_cuda - - if is_prefill: - return self._prefill( - batch_size, - total_token_num, - max_len_in_batch, - input_ids, - mem_indexes, - b_req_idx, - b_seq_len, - b_ready_cache_len, - multimodal_params, - ) + def forward(self, model_input: ModelInput): + assert model_input.mem_indexes.is_cuda + + if model_input.is_prefill: + return self._prefill(model_input) else: - return self._decode( - batch_size, - total_token_num, - max_len_in_batch, - input_ids, - mem_indexes, - b_req_idx, - b_seq_len, - multimodal_params, - ) + return self._decode(model_input) - def _prefill( - self, - batch_size, - total_token_num, - max_len_in_batch, - input_ids, - mem_indexes, - b_req_idx, - b_seq_len, - b_ready_cache_len, - multimodal_params, - ): + def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0): infer_state = self.infer_state_class() - infer_state.is_prefill = True + infer_state.is_prefill = model_input.is_prefill infer_state.is_token_healing = self.is_token_healing infer_state.return_all_prompt_logics = self.return_all_prompt_logics infer_state.use_dynamic_prompt_cache = self.use_dynamic_prompt_cache - infer_state.batch_size = batch_size - infer_state.total_token_num = total_token_num - infer_state.max_len_in_batch = max_len_in_batch - assert b_req_idx.shape[0] == b_seq_len.shape[0] - infer_state.b_req_idx = b_req_idx - infer_state.b_seq_len = b_seq_len - if b_ready_cache_len is not None: - 
infer_state.b_ready_cache_len = b_ready_cache_len - else: - infer_state.b_ready_cache_len = torch.zeros_like(b_seq_len, dtype=b_seq_len.dtype, device=b_seq_len.device) - infer_state.multimodal_params = multimodal_params + infer_state.batch_size = model_input.batch_size + infer_state.total_token_num = model_input.total_token_num + infer_state.max_len_in_batch = model_input.max_len_in_batch + assert model_input.b_req_idx.shape[0] == model_input.b_seq_len.shape[0] + infer_state.b_req_idx = model_input.b_req_idx + infer_state.b_seq_len = model_input.b_seq_len + if model_input.is_prefill: + if model_input.b_ready_cache_len is not None: + infer_state.b_ready_cache_len = model_input.b_ready_cache_len + else: + infer_state.b_ready_cache_len = torch.zeros_like(input=infer_state.b_seq_len) + + infer_state.multimodal_params = model_input.multimodal_params infer_state.mem_manager = self.mem_manager infer_state.req_manager = self.req_manager - infer_state.mem_index = mem_indexes + infer_state.mem_index = model_input.mem_indexes infer_state.kv_buffer_shapedtype = ( - (input_ids.shape[0], self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_), + (model_input.input_ids.shape[0], self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_), self.data_type, ) - infer_state.dist_group = dist_group_manager.get_default_group() + infer_state.microbatch_index = microbatch_index + infer_state.dist_group = dist_group_manager.get_group(microbatch_index) - init_req_to_token_indexes( - self.req_manager.req_to_token_indexs, - b_req_idx, - b_seq_len, - infer_state.b_ready_cache_len, - max_len_in_batch, - infer_state.mem_index, - ) + # 特殊模型,特殊模式的特定变量初始化操作。 + infer_state.deepseekv3_mtp_draft_input_hiddens = model_input.deepseekv3_mtp_draft_input_hiddens - infer_state.init_some_extra_state(self, input_ids) - predict_logics = self._context_forward(input_ids, infer_state) - return predict_logics + return infer_state - def _decode( - self, - batch_size, - total_token_num, - max_len_in_batch, - input_ids, - mem_indexes, - b_req_idx, - b_seq_len, - multimodal_params, - ): - infer_state = self.infer_state_class() - infer_state.is_prefill = False - infer_state.batch_size = batch_size - infer_state.total_token_num = total_token_num - infer_state.max_len_in_batch = max_len_in_batch - infer_state.use_dynamic_prompt_cache = self.use_dynamic_prompt_cache - assert b_req_idx.shape[0] == b_seq_len.shape[0] - infer_state.b_req_idx = b_req_idx - infer_state.b_seq_len = b_seq_len - infer_state.multimodal_params = multimodal_params + def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_size: int): + if model_input.batch_size == new_batch_size: + return model_input - infer_state.mem_manager = self.mem_manager - infer_state.req_manager = self.req_manager + assert model_input.batch_size <= new_batch_size - infer_state.mem_index = mem_indexes - infer_state.kv_buffer_shapedtype = ( - (batch_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_), - self.data_type, + padded_batch_size = new_batch_size - model_input.batch_size + new_model_input = copy.copy(model_input) + new_model_input.batch_size = new_batch_size + new_model_input.total_token_num += padded_batch_size * 2 + new_model_input.input_ids = F.pad(new_model_input.input_ids, (0, padded_batch_size), mode="constant", value=1) + new_model_input.b_req_idx = F.pad( + new_model_input.b_req_idx, (0, padded_batch_size), mode="constant", value=self.req_manager.HOLD_REQUEST_ID + ) + new_model_input.b_seq_len = F.pad(new_model_input.b_seq_len, (0, 
padded_batch_size), mode="constant", value=2) + new_model_input.mem_indexes = F.pad( + new_model_input.mem_indexes, + (0, padded_batch_size), + mode="constant", + value=self.mem_manager.HOLD_TOKEN_MEMINDEX, ) - infer_state.dist_group = dist_group_manager.get_default_group() - copy_kv_index_to_req(self.req_manager.req_to_token_indexs, b_req_idx, b_seq_len, infer_state.mem_index) - - infer_state.init_some_extra_state(self, input_ids) - if self.graph is not None and self.graph.can_run(batch_size, max_len_in_batch): - if self.graph.need_capture(batch_size): - infer_state.is_cuda_graph = True - predict_logics = self.graph.capture_decode(self._token_forward, input_ids, infer_state) - else: - predict_logics = self.graph.replay(input_ids, infer_state) - else: - predict_logics = self._token_forward(input_ids, infer_state) - return predict_logics - @torch.no_grad() - def microbatch_overlap_decode(self, batch: DecodeMicroBatch, batch1: DecodeMicroBatch): - assert batch.batch_size == batch1.batch_size - assert batch.mem_indexes.is_cuda - assert batch1.mem_indexes.is_cuda - input_ids, input_ids1 = batch.input_ids, batch1.input_ids - - def create_inferstate(cur_batch: DecodeMicroBatch, batch_index): - infer_state = self.infer_state_class() - infer_state.is_prefill = False - infer_state.batch_size = cur_batch.batch_size - infer_state.total_token_num = cur_batch.total_token_num - infer_state.max_len_in_batch = cur_batch.max_len_in_batch - infer_state.use_dynamic_prompt_cache = self.use_dynamic_prompt_cache - assert cur_batch.b_req_idx.shape[0] == cur_batch.b_seq_len.shape[0] - infer_state.b_req_idx = cur_batch.b_req_idx - infer_state.b_seq_len = cur_batch.b_seq_len - infer_state.multimodal_params = None - infer_state.microbatch_index = batch_index - - infer_state.mem_manager = self.mem_manager - infer_state.req_manager = self.req_manager - - infer_state.mem_index = cur_batch.mem_indexes - infer_state.kv_buffer_shapedtype = ( - (cur_batch.batch_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_), - self.data_type, + # 特殊模型,特殊模式的特殊变量的特殊 padding + if new_model_input.deepseekv3_mtp_draft_input_hiddens is not None: + new_model_input.deepseekv3_mtp_draft_input_hiddens = pad2dim_tensor_to_new_batch( + input=new_model_input.deepseekv3_mtp_draft_input_hiddens, + new_batch_size=new_batch_size, ) - infer_state.dist_group = dist_group_manager.get_group(batch_index) - copy_kv_index_to_req( - self.req_manager.req_to_token_indexs, cur_batch.b_req_idx, cur_batch.b_seq_len, infer_state.mem_index - ) - return infer_state - infer_state = create_inferstate(batch, 0) - infer_state1 = create_inferstate(batch1, 1) + return new_model_input - infer_state.init_some_extra_state(self, input_ids) - infer_state1.init_some_extra_state(self, input_ids1) + def _create_unpad_decode_model_output(self, model_output: ModelOutput, origin_batch_size: int): + padded_batch_size = model_output.logits.shape[0] + if padded_batch_size == origin_batch_size: + return model_output + new_model_output = copy.copy(model_output) + new_model_output.logits = new_model_output.logits[0:origin_batch_size] - batch_size = batch.batch_size - max_len_in_batch = max(batch.max_len_in_batch, batch1.max_len_in_batch) + # 特殊模型,特殊模式的特殊变量的特殊 unpad + if new_model_output.deepseekv3_mtp_main_output_hiddens is not None: + _hidden_states = new_model_output.deepseekv3_mtp_main_output_hiddens + new_model_output.deepseekv3_mtp_main_output_hiddens = _hidden_states[0:origin_batch_size] - if self.graph is not None and self.graph.can_run(batch_size, max_len_in_batch): - 
if self.graph.need_capture(batch_size): - infer_state.is_cuda_graph = True - infer_state1.is_cuda_graph = True + return new_model_output - predict_logics, predict_logics1 = self.graph.capture_decode( - self._overlap_tpsp_token_forward, - input_ids, - infer_state, - input_ids1=input_ids1, - infer_state1=infer_state1, - ) - else: - predict_logics, predict_logics1 = self.graph.replay( - input_ids, infer_state, input_ids1=input_ids1, infer_state1=infer_state1 - ) - else: - predict_logics, predict_logics1 = self._overlap_tpsp_token_forward( - input_ids, infer_state, input_ids1=input_ids1, infer_state1=infer_state1 + def _prefill( + self, + model_input: ModelInput, + ): + infer_state = self._create_inferstate(model_input) + init_req_to_token_indexes( + self.req_manager.req_to_token_indexs, + model_input.b_req_idx, + model_input.b_seq_len, + infer_state.b_ready_cache_len, + model_input.max_len_in_batch, + infer_state.mem_index, + ) + + infer_state.init_some_extra_state(self, model_input.input_ids) + return self._context_forward(model_input.input_ids, infer_state) + + def _decode( + self, + model_input: ModelInput, + ) -> ModelOutput: + if self.graph is not None and self.graph.can_run(model_input.batch_size, model_input.max_len_in_batch): + find_graph_batch_size = self.graph.find_closest_graph_batch_size(model_input.batch_size) + padded_model_input = self._create_padded_decode_model_input(model_input, find_graph_batch_size) + infer_state = self._create_inferstate(padded_model_input) + copy_kv_index_to_req( + self.req_manager.req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_seq_len, + infer_state.mem_index, ) - return predict_logics, predict_logics1 + infer_state.init_some_extra_state(self, padded_model_input.input_ids) - @torch.no_grad() - def microbatch_overlap_prefill(self, batch: PrefillMicroBatch, batch1: PrefillMicroBatch): - assert batch.mem_indexes.is_cuda - assert batch1.mem_indexes.is_cuda - input_ids, input_ids1 = batch.input_ids, batch1.input_ids - - def create_inferstate(cur_batch: PrefillMicroBatch, batch_index): - infer_state = self.infer_state_class() - infer_state.is_prefill = True - infer_state.is_token_healing = self.is_token_healing - infer_state.return_all_prompt_logics = self.return_all_prompt_logics - infer_state.use_dynamic_prompt_cache = self.use_dynamic_prompt_cache - infer_state.batch_size = cur_batch.batch_size - infer_state.total_token_num = cur_batch.total_token_num - infer_state.max_len_in_batch = cur_batch.max_len_in_batch - assert cur_batch.b_req_idx.shape[0] == cur_batch.b_seq_len.shape[0] - infer_state.b_req_idx = cur_batch.b_req_idx - infer_state.b_seq_len = cur_batch.b_seq_len - if cur_batch.b_ready_cache_len is not None: - infer_state.b_ready_cache_len = cur_batch.b_ready_cache_len - else: - infer_state.b_ready_cache_len = torch.zeros_like( - cur_batch.b_seq_len, dtype=cur_batch.b_seq_len.dtype, device=cur_batch.b_seq_len.device + if self.graph.need_capture(find_graph_batch_size): + infer_state.is_cuda_graph = True + model_output: ModelOutput = self.graph.capture_decode( + self._token_forward, padded_model_input.input_ids, infer_state ) - infer_state.multimodal_params = cur_batch.multimodal_params - infer_state.microbatch_index = batch_index - - infer_state.mem_manager = self.mem_manager - infer_state.req_manager = self.req_manager + else: + model_output: ModelOutput = self.graph.replay(padded_model_input.input_ids, infer_state) - infer_state.mem_index = cur_batch.mem_indexes - infer_state.kv_buffer_shapedtype = ( - (cur_batch.input_ids.shape[0], 
self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_), - self.data_type, + model_output = self._create_unpad_decode_model_output( + model_output, origin_batch_size=model_input.batch_size ) - infer_state.dist_group = dist_group_manager.get_group(batch_index) - init_req_to_token_indexes( + else: + infer_state = self._create_inferstate(model_input) + copy_kv_index_to_req( self.req_manager.req_to_token_indexs, - cur_batch.b_req_idx, - cur_batch.b_seq_len, - infer_state.b_ready_cache_len, - cur_batch.max_len_in_batch, + infer_state.b_req_idx, + infer_state.b_seq_len, infer_state.mem_index, ) - return infer_state + infer_state.init_some_extra_state(self, model_input.input_ids) + model_output = self._token_forward(model_input.input_ids, infer_state) - infer_state = create_inferstate(batch, 0) - infer_state1 = create_inferstate(batch1, 1) - - infer_state.init_some_extra_state(self, input_ids) - infer_state1.init_some_extra_state(self, input_ids1) - - predict_logics, predict_logics1 = self._overlap_tpsp_context_forward( - input_ids, infer_state, input_ids1=input_ids1, infer_state1=infer_state1 - ) - dist_group_manager.clear_deepep_buffer() - return predict_logics, predict_logics1 + return model_output @final def _context_forward(self, input_ids, infer_state: InferStateInfo): @@ -499,10 +395,20 @@ def _context_forward(self, input_ids, infer_state: InferStateInfo): input_embs = layer_method(input_embs, infer_state, self.trans_layers_weight[i]) post_method = (self.post_infer.token_forward, self.post_infer.tpsp_token_forward)[run_mode_index] - predict_logics = post_method(input_embs, infer_state, self.pre_post_weight) + predict_logits = post_method(input_embs, infer_state, self.pre_post_weight) g_cache_manager.cache_env_out() - return predict_logics + + model_output = ModelOutput(logits=predict_logits) + + # 特殊模型特殊模式的额外输出 + if self.is_deepseekv3_mtp_mode: + model_output.deepseekv3_mtp_main_output_hiddens = input_embs + + # 在开启使用deepep的时候,需要调用clear_deepep_buffer做资源清理,没有启用的时候 + # 该调用没有实际意义 + dist_group_manager.clear_deepep_buffer() + return model_output @final def _token_forward(self, input_ids, infer_state: InferStateInfo): @@ -521,53 +427,214 @@ def _token_forward(self, input_ids, infer_state: InferStateInfo): input_embs = layer_method(input_embs, infer_state, self.trans_layers_weight[i]) post_method = (self.post_infer.token_forward, self.post_infer.tpsp_token_forward)[run_mode_index] - predict_logics = post_method(input_embs, infer_state, self.pre_post_weight) + predict_logits = post_method(input_embs, infer_state, self.pre_post_weight) + + if self.is_deepseekv3_mtp_mode: + graph_out_hiddens = g_cache_manager.alloc_tensor( + input_embs.shape, + data_type=input_embs.dtype, + is_graph_out=True, + microbatch_index=infer_state.microbatch_index, + graph_out_key=520, + ) + graph_out_hiddens.copy_(input_embs) g_cache_manager.cache_env_out() - return predict_logics + + model_output = ModelOutput(logits=predict_logits) + + # 特殊模型特殊模式的额外输出 + if self.is_deepseekv3_mtp_mode: + model_output.deepseekv3_mtp_main_output_hiddens = graph_out_hiddens + + return model_output + + @torch.no_grad() + def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: ModelInput): + assert model_input0.mem_indexes.is_cuda + assert model_input1.mem_indexes.is_cuda + input_ids0, input_ids1 = model_input0.input_ids, model_input1.input_ids + + infer_state0 = self._create_inferstate(model_input0, 0) + init_req_to_token_indexes( + self.req_manager.req_to_token_indexs, + model_input0.b_req_idx, + 
model_input0.b_seq_len, + infer_state0.b_ready_cache_len, + model_input0.max_len_in_batch, + infer_state0.mem_index, + ) + infer_state0.init_some_extra_state(self, input_ids0) + + infer_state1 = self._create_inferstate(model_input1, 1) + init_req_to_token_indexes( + self.req_manager.req_to_token_indexs, + model_input1.b_req_idx, + model_input1.b_seq_len, + infer_state1.b_ready_cache_len, + model_input1.max_len_in_batch, + infer_state1.mem_index, + ) + infer_state1.init_some_extra_state(self, input_ids1) + + model_output0, model_output1 = self._overlap_tpsp_context_forward( + input_ids0, infer_state0, input_ids1=input_ids1, infer_state1=infer_state1 + ) + + # 在开启使用deepep的时候,需要调用clear_deepep_buffer做资源清理,没有启用的时候 + # 该调用没有实际意义 + dist_group_manager.clear_deepep_buffer() + return model_output0, model_output1 + + @torch.no_grad() + def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: ModelInput): + assert model_input0.batch_size == model_input1.batch_size + assert model_input0.mem_indexes.is_cuda + assert model_input1.mem_indexes.is_cuda + + origin_batch_size = model_input0.batch_size + max_len_in_batch = max(model_input0.max_len_in_batch, model_input1.max_len_in_batch) + + if self.graph is not None and self.graph.can_run(origin_batch_size, max_len_in_batch): + find_graph_batch_size = self.graph.find_closest_graph_batch_size(origin_batch_size) + padded_model_input0 = self._create_padded_decode_model_input(model_input0, find_graph_batch_size) + padded_model_input1 = self._create_padded_decode_model_input(model_input1, find_graph_batch_size) + infer_state0 = self._create_inferstate(padded_model_input0, 0) + copy_kv_index_to_req( + self.req_manager.req_to_token_indexs, + infer_state0.b_req_idx, + infer_state0.b_seq_len, + infer_state0.mem_index, + ) + infer_state0.init_some_extra_state(self, padded_model_input0.input_ids) + infer_state1 = self._create_inferstate(padded_model_input1, 1) + copy_kv_index_to_req( + self.req_manager.req_to_token_indexs, + infer_state1.b_req_idx, + infer_state1.b_seq_len, + infer_state1.mem_index, + ) + infer_state1.init_some_extra_state(self, padded_model_input1.input_ids) + + if self.graph.need_capture(find_graph_batch_size): + infer_state0.is_cuda_graph = True + infer_state1.is_cuda_graph = True + + model_output0, model_output1 = self.graph.capture_decode( + self._overlap_tpsp_token_forward, + padded_model_input0.input_ids, + infer_state0, + input_ids1=padded_model_input1.input_ids, + infer_state1=infer_state1, + ) + else: + model_output0, model_output1 = self.graph.replay( + padded_model_input0.input_ids, + infer_state0, + input_ids1=padded_model_input1.input_ids, + infer_state1=infer_state1, + ) + model_output0 = self._create_unpad_decode_model_output(model_output0, origin_batch_size=origin_batch_size) + model_output1 = self._create_unpad_decode_model_output(model_output1, origin_batch_size=origin_batch_size) + else: + infer_state0 = self._create_inferstate(model_input0, 0) + copy_kv_index_to_req( + self.req_manager.req_to_token_indexs, + infer_state0.b_req_idx, + infer_state0.b_seq_len, + infer_state0.mem_index, + ) + infer_state0.init_some_extra_state(self, model_input0.input_ids) + infer_state1 = self._create_inferstate(model_input1, 1) + copy_kv_index_to_req( + self.req_manager.req_to_token_indexs, + infer_state1.b_req_idx, + infer_state1.b_seq_len, + infer_state1.mem_index, + ) + infer_state1.init_some_extra_state(self, model_input1.input_ids) + + model_output0, model_output1 = self._overlap_tpsp_token_forward( + model_input0.input_ids, 
infer_state0, input_ids1=model_input1.input_ids, infer_state1=infer_state1 + ) + return model_output0, model_output1 @final - def _overlap_tpsp_token_forward( + def _overlap_tpsp_context_forward( self, input_ids, infer_state: InferStateInfo, input_ids1, infer_state1: InferStateInfo ): - g_cache_manager.cache_env_in( - is_cuda_graph=infer_state.is_cuda_graph, - cur_batch_size=infer_state.batch_size, - cuda_graph_max_batch_size=self.graph_max_batch_size, - ) - input_embs, input_embs1 = self.pre_infer.overlap_tpsp_token_forward( + g_cache_manager.cache_env_in() + input_embs, input_embs1 = self.pre_infer.overlap_tpsp_context_forward( input_ids, input_ids1, infer_state, infer_state1, self.pre_post_weight ) - for i in range(self.layers_num): - input_embs, input_embs1 = self.layers_infer[i].overlap_tpsp_token_forward( + input_embs, input_embs1 = self.layers_infer[i].overlap_tpsp_context_forward( input_embs, input_embs1, infer_state, infer_state1, self.trans_layers_weight[i] ) - - predict_logics, predict_logics1 = self.post_infer.overlap_tpsp_token_forward( + predict_logits, predict_logits1 = self.post_infer.overlap_tpsp_token_forward( input_embs, input_embs1, infer_state, infer_state1, self.pre_post_weight ) - g_cache_manager.cache_env_out() - return predict_logics, predict_logics1 + + model_output = ModelOutput(logits=predict_logits) + model_output1 = ModelOutput(logits=predict_logits1) + + if self.is_deepseekv3_mtp_mode: + model_output.deepseekv3_mtp_main_output_hiddens = input_embs + model_output1.deepseekv3_mtp_main_output_hiddens = input_embs1 + + return model_output, model_output1 @final - def _overlap_tpsp_context_forward( + def _overlap_tpsp_token_forward( self, input_ids, infer_state: InferStateInfo, input_ids1, infer_state1: InferStateInfo ): - g_cache_manager.cache_env_in() - input_embs, input_embs1 = self.pre_infer.overlap_tpsp_context_forward( + g_cache_manager.cache_env_in( + is_cuda_graph=infer_state.is_cuda_graph, + cur_batch_size=infer_state.batch_size, + cuda_graph_max_batch_size=self.graph_max_batch_size, + ) + input_embs, input_embs1 = self.pre_infer.overlap_tpsp_token_forward( input_ids, input_ids1, infer_state, infer_state1, self.pre_post_weight ) + for i in range(self.layers_num): - input_embs, input_embs1 = self.layers_infer[i].overlap_tpsp_context_forward( + input_embs, input_embs1 = self.layers_infer[i].overlap_tpsp_token_forward( input_embs, input_embs1, infer_state, infer_state1, self.trans_layers_weight[i] ) - predict_logics, predict_logics1 = self.post_infer.overlap_tpsp_token_forward( + + predict_logits, predict_logits1 = self.post_infer.overlap_tpsp_token_forward( input_embs, input_embs1, infer_state, infer_state1, self.pre_post_weight ) + + if self.is_deepseekv3_mtp_mode: + graph_out_hiddens = g_cache_manager.alloc_tensor( + input_embs.shape, + data_type=input_embs.dtype, + is_graph_out=True, + microbatch_index=0, + graph_out_key=520, + ) + graph_out_hiddens.copy_(input_embs) + graph_out_hiddens1 = g_cache_manager.alloc_tensor( + input_embs1.shape, + data_type=input_embs1.dtype, + is_graph_out=True, + microbatch_index=1, + graph_out_key=520, + ) + graph_out_hiddens1.copy_(input_embs1) + g_cache_manager.cache_env_out() - return predict_logics, predict_logics1 + + model_output = ModelOutput(logits=predict_logits) + model_output1 = ModelOutput(logits=predict_logits1) + + if self.is_deepseekv3_mtp_mode: + model_output.deepseekv3_mtp_main_output_hiddens = graph_out_hiddens + model_output1.deepseekv3_mtp_main_output_hiddens = graph_out_hiddens1 + + return model_output, 
model_output1 @final @torch.no_grad() @@ -590,20 +657,22 @@ def _check_max_len_infer(self): b_seq_len[:] = self.batch_max_tokens b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda") total_token_num = self.batch_max_tokens - logics = self.forward( - 1, - total_token_num, - self.batch_max_tokens, - dummy_input_ids, - mem_indexes, - b_req_idx, - b_seq_len, - b_ready_cache_len=b_ready_cache_len, + model_input = ModelInput( + batch_size=1, + total_token_num=total_token_num, + max_len_in_batch=self.batch_max_tokens, + input_ids=dummy_input_ids, + mem_indexes=mem_indexes, + b_req_idx=b_req_idx, + b_seq_len=b_seq_len, is_prefill=True, - multimodal_params=[], + b_ready_cache_len=b_ready_cache_len, + ) + model_output = self.forward( + model_input, ) - prob_out = torch.softmax(logics, dim=-1) - logics = None + prob_out = torch.softmax(model_output.logits, dim=-1) + del model_output torch.argmax(prob_out, dim=1, keepdim=True) prob_out = None self.req_manager.free_all() @@ -620,3 +689,63 @@ def _check_max_len_infer(self): logger.error(exception_str) raise Exception(exception_str) return + + @final + @torch.no_grad() + def _init_padded_req(self): + """ + 对 padded 所使用的req 进行初始化, 目前有非常多的地方需要使用,所以将其初始化固定为固定流程 + """ + # 做一次 同步 + torch.distributed.barrier() + + # prefill init padding req. + prefill_input_len = 1 + batch_size = 1 + dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cuda") + b_req_idx = torch.tensor( + [self.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda" + ) + mem_indexes = torch.tensor( + [self.mem_manager.HOLD_TOKEN_MEMINDEX for _ in range(batch_size)], dtype=torch.int32, device="cuda" + ) + b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda") + b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + total_token_num = prefill_input_len * batch_size + model_input = ModelInput( + batch_size=batch_size, + total_token_num=total_token_num, + max_len_in_batch=prefill_input_len, + input_ids=dummy_input_ids, + mem_indexes=mem_indexes, + b_req_idx=b_req_idx, + b_seq_len=b_seq_len, + b_ready_cache_len=b_ready_cache_len, + is_prefill=True, + multimodal_params=[], + **self._gen_special_model_input(total_token_num), + ) + + model_output: ModelOutput = self.forward(model_input) + del model_input + del dummy_input_ids + del b_req_idx + del mem_indexes + del b_seq_len + del b_ready_cache_len + del model_output + torch.cuda.empty_cache() + return + + def _gen_special_model_input(self, token_num: int): + special_model_input = {} + + is_deepseekv3_mtp_draft_model = "Deepseek3MTPModel" in str(self.__class__) + if is_deepseekv3_mtp_draft_model: + special_model_input["deepseekv3_mtp_draft_input_hiddens"] = torch.randn( + token_num, self.config["hidden_size"], dtype=self.data_type, device="cuda" + ) + else: + special_model_input["deepseekv3_mtp_draft_input_hiddens"] = None + + return special_model_input diff --git a/lightllm/common/basemodel/batch_objs.py b/lightllm/common/basemodel/batch_objs.py new file mode 100644 index 000000000..d9d131056 --- /dev/null +++ b/lightllm/common/basemodel/batch_objs.py @@ -0,0 +1,39 @@ +import torch +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass +class ModelInput: + # 通用变量 + batch_size: int + total_token_num: int + max_len_in_batch: int + input_ids: torch.Tensor + mem_indexes: torch.Tensor + b_req_idx: torch.Tensor + b_seq_len: torch.Tensor + is_prefill: bool = False + b_ready_cache_len: torch.Tensor = None + 
    multimodal_params: list = field(default_factory=list)
+
+    # Model- and mode-specific fields, used to pass special inputs for special
+    # models in special modes. They are only used and take effect in those cases.
+
+    # deepseekv3_mtp_draft_input_hiddens is used as the draft model's input
+    # when the deepseekv3 model runs in mtp mode.
+    deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None
+
+
+@dataclass
+class ModelOutput:
+    # common fields
+    logits: torch.Tensor
+
+    # Model- and mode-specific fields, used to pass special outputs for special
+    # models in special modes. They are only used and take effect in those cases.
+
+    # In mtp mode, deepseekv3_mtp_main_output_hiddens holds the last-layer hidden
+    # states of the llm main model; they feed the draft model's
+    # deepseekv3_mtp_draft_input_hiddens input.
+    deepseekv3_mtp_main_output_hiddens: Optional[torch.Tensor] = None
diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py
index 41cd59039..dc615eb46 100644
--- a/lightllm/common/basemodel/cuda_graph.py
+++ b/lightllm/common/basemodel/cuda_graph.py
@@ -1,10 +1,14 @@
 import os
 import torch
 import copy
+import bisect
+from typing import Optional
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.distributed import dist_group_manager, lightllm_capture_graph, CustomProcessGroup
-from lightllm.common.basemodel.microbatch_overlap_objs import DecodeMicroBatch
+from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
+from .infer_struct import InferStateInfo
+
 
 logger = init_logger(__name__)
 
@@ -17,15 +21,48 @@ def __init__(self, max_batch_size=8, max_len_in_batch=8192):
         self.mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
         self.max_batch_size = max_batch_size
         self.graph_max_len_in_batch = max_len_in_batch
-        self.enable_decode_microbatch_overlap = get_env_start_args().enable_decode_microbatch_overlap
+        self.args = get_env_start_args()
+        self.enable_decode_microbatch_overlap = self.args.enable_decode_microbatch_overlap
+
+        # gen cuda graph batch_sizes
+        # cuda graph gen for batch size = [1, 2, 3, ..., graph_split_batch_size]
+        # and [graph_split_batch_size + graph_grow_step_size,
+        # graph_split_batch_size + 2 * graph_grow_step_size, ..., self.max_batch_size]
+        graph_split_batch_size = self.args.graph_split_batch_size
+        max_batch_size = self.max_batch_size
+        graph_grow_step_size = self.args.graph_grow_step_size
+
+        batch_sizes = [i for i in range(1, graph_split_batch_size + 1)]
+        for _batch_size in range(graph_split_batch_size + graph_grow_step_size, max_batch_size, graph_grow_step_size):
+            batch_sizes.append(_batch_size)
+
+        batch_sizes = list(set([e for e in batch_sizes if e < max_batch_size]))
+        batch_sizes.append(max_batch_size)
+        batch_sizes.sort()
+
+        self.cuda_graph_batch_sizes = batch_sizes
+        assert batch_sizes[-1] == self.max_batch_size
+        logger.info(f"cuda graph batch_sizes: {self.cuda_graph_batch_sizes}")
 
     def can_run(self, batch_size, max_len_in_batch):
         return batch_size <= self.max_batch_size and max_len_in_batch <= self.graph_max_len_in_batch
 
     def need_capture(self, batch_size):
-        return batch_size not in self.graph
+        find_batch_size = self.find_closest_graph_batch_size(batch_size)
+        if find_batch_size is not None:
+            return find_batch_size not in self.graph
+        else:
+            assert False, "dead code"
 
-    def _capture_decode(self, decode_func, input_ids, infer_state):
+    def find_closest_graph_batch_size(self, batch_size):
+        index = bisect.bisect_left(self.cuda_graph_batch_sizes, batch_size)
+        if index < len(self.cuda_graph_batch_sizes):
+            find_batch_size = self.cuda_graph_batch_sizes[index]
+            return find_batch_size
+        else:
+            return None
+
+    def _capture_decode(self, decode_func, input_ids:
torch.Tensor, infer_state: InferStateInfo): dist_group: CustomProcessGroup = infer_state.dist_group graph_obj = torch.cuda.CUDAGraph() batch_size = input_ids.shape[0] @@ -46,12 +83,19 @@ def _capture_decode(self, decode_func, input_ids, infer_state): with lightllm_capture_graph(dist_group): with torch.cuda.graph(graph_obj, pool=self.mempool): - predict_logics = decode_func(input_ids, infer_state) - self.graph[batch_size] = (graph_obj, input_ids, infer_state, predict_logics) + model_output = decode_func(input_ids, infer_state) + self.graph[batch_size] = (graph_obj, input_ids, infer_state, model_output) graph_obj.replay() - return predict_logics + return model_output - def _capture_decode_overlap(self, decode_func, input_ids, infer_state, input_ids1, infer_state1): + def _capture_decode_overlap( + self, + decode_func, + input_ids: torch.Tensor, + infer_state: InferStateInfo, + input_ids1: torch.Tensor, + infer_state1: InferStateInfo, + ): dist_group: CustomProcessGroup = infer_state.dist_group dist_group1 = infer_state1.dist_group graph_obj = torch.cuda.CUDAGraph() @@ -68,20 +112,27 @@ def _capture_decode_overlap(self, decode_func, input_ids, infer_state, input_ids with lightllm_capture_graph(dist_group1): with lightllm_capture_graph(dist_group): with torch.cuda.graph(graph_obj, pool=self.mempool): - predict_logics, predict_logics1 = decode_func(input_ids, infer_state, input_ids1, infer_state1) + model_output, model_output1 = decode_func(input_ids, infer_state, input_ids1, infer_state1) self.graph[batch_size] = ( graph_obj, input_ids, infer_state, input_ids1, infer_state1, - predict_logics, - predict_logics1, + model_output, + model_output1, ) graph_obj.replay() - return predict_logics, predict_logics1 + return model_output, model_output1 - def capture_decode(self, decode_func, input_ids, infer_state, input_ids1=None, infer_state1=None): + def capture_decode( + self, + decode_func, + input_ids: torch.Tensor, + infer_state: InferStateInfo, + input_ids1: Optional[torch.Tensor] = None, + infer_state1: Optional[torch.Tensor] = None, + ): """ Capture the cuda graph for the decoding stage. input_ids1 and infer_state1 is used for the overlap. 
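Note on the cuda_graph.py changes above: decode batches are no longer captured per exact batch size. Graphs exist only for the bucketed sizes built in __init__, and the decode path pads a smaller batch up to the closest captured size. Below is a minimal, standalone sketch of that bucketing and lookup; the helper names and flag values are illustrative, not the shipped defaults.

import bisect

def gen_capture_batch_sizes(graph_split_batch_size=8, graph_grow_step_size=4, max_batch_size=24):
    # dense sizes 1..graph_split_batch_size, then sparse steps up to max_batch_size
    sizes = list(range(1, graph_split_batch_size + 1))
    sizes += list(range(graph_split_batch_size + graph_grow_step_size, max_batch_size, graph_grow_step_size))
    sizes = sorted(set(s for s in sizes if s < max_batch_size))
    sizes.append(max_batch_size)
    return sizes

def find_closest_graph_batch_size(sizes, batch_size):
    # smallest captured size >= batch_size; the decode batch is padded up to it
    index = bisect.bisect_left(sizes, batch_size)
    return sizes[index] if index < len(sizes) else None

sizes = gen_capture_batch_sizes()
print(sizes)                                     # [1, 2, 3, 4, 5, 6, 7, 8, 12, 16, 20, 24]
print(find_closest_graph_batch_size(sizes, 10))  # 12 -> a 10-request batch is padded to 12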
@@ -92,15 +143,21 @@ def capture_decode(self, decode_func, input_ids, infer_state, input_ids1=None, i assert input_ids1 is None and infer_state1 is None return self._capture_decode(decode_func, input_ids, infer_state) - def _replay(self, input_ids, infer_state): + def _replay(self, input_ids: torch.Tensor, infer_state: InferStateInfo): batch_size = input_ids.shape[0] - graph_obj, graph_input_ids, graph_infer_state, graph_predict_logics = self.graph[batch_size] + graph_obj, graph_input_ids, graph_infer_state, graph_output = self.graph[batch_size] graph_input_ids.copy_(input_ids) graph_infer_state.copy_for_cuda_graph(infer_state) graph_obj.replay() - return graph_predict_logics + return graph_output - def _replay_overlap(self, input_ids, infer_state, input_ids1, infer_state1): + def _replay_overlap( + self, + input_ids: torch.Tensor, + infer_state: InferStateInfo, + input_ids1: torch.Tensor, + infer_state1: InferStateInfo, + ): batch_size = input_ids.shape[0] ( graph_obj, @@ -108,15 +165,15 @@ def _replay_overlap(self, input_ids, infer_state, input_ids1, infer_state1): graph_infer_state, graph_input_ids1, graph_infer_state1, - graph_predict_logics, - graph_predict_logics1, + graph_model_output, + graph_model_output1, ) = self.graph[batch_size] graph_input_ids.copy_(input_ids) graph_infer_state.copy_for_cuda_graph(infer_state) graph_input_ids1.copy_(input_ids1) graph_infer_state1.copy_for_cuda_graph(infer_state1) graph_obj.replay() - return graph_predict_logics, graph_predict_logics1 + return graph_model_output, graph_model_output1 def replay(self, input_ids, infer_state, input_ids1=None, infer_state1=None): if self.enable_decode_microbatch_overlap: @@ -128,52 +185,42 @@ def replay(self, input_ids, infer_state, input_ids1=None, infer_state1=None): @torch.no_grad() def warmup(self, model): logger.info("Begin capture cudagraph, use the --disable_cudagraph to disable it.") - for batch_size in range(self.max_batch_size, 0, -1): - # dummy prefill - prefill_input_len = 1 - dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cuda") + # for typing easy + from .basemodel import TpPartBaseModel + + model: TpPartBaseModel = model + + # decode cuda graph init + for batch_size in self.cuda_graph_batch_sizes[::-1]: + seq_len = 2 + total_token_num = batch_size * seq_len + max_len_in_batch = self.graph_max_len_in_batch + input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda") + mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda() b_req_idx = torch.tensor( - [model.req_manager.alloc() for _ in range(batch_size)], dtype=torch.int32, device="cuda" - ) - mem_indexes = model.mem_manager.alloc(len(dummy_input_ids)).cuda() - b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda") - b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda") - total_token_num = prefill_input_len * batch_size - logics = model.forward( - batch_size, - total_token_num, - prefill_input_len, - dummy_input_ids, - mem_indexes, - b_req_idx, - b_seq_len, - b_ready_cache_len=b_ready_cache_len, - is_prefill=True, - multimodal_params=[], + [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda" ) - mem_indexes = None - prob_out = torch.softmax(logics, dim=-1) - logics = None - predict_ids = torch.argmax(prob_out, dim=1, keepdim=True) - prob_out = None - predict_ids = predict_ids.detach().cpu().numpy() - torch.cuda.empty_cache() + b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda") + 
b_seq_len.fill_(seq_len) - # dummy decoding, capture the cudagraph - total_token_num += batch_size - b_seq_len += 1 - mem_indexes = model.mem_manager.alloc(len(predict_ids)).cuda() - logics = model.forward( - batch_size, - total_token_num, - prefill_input_len + 1, - torch.from_numpy(predict_ids).cuda().reshape(-1), - mem_indexes, - b_req_idx, - b_seq_len, + model_input = ModelInput( + batch_size=batch_size, + total_token_num=total_token_num, + max_len_in_batch=max_len_in_batch, + input_ids=input_ids, + mem_indexes=mem_indexes, + b_req_idx=b_req_idx, + b_seq_len=b_seq_len, is_prefill=False, + **model._gen_special_model_input(batch_size), ) - mem_indexes = None + model_output: ModelOutput = model.forward(model_input) + del model_output + del input_ids + del mem_indexes + del b_req_idx + del b_seq_len + model.mem_manager.free_all() model.req_manager.free_all() # release local tensors @@ -181,6 +228,7 @@ def warmup(self, model): if isinstance(var_value, torch.Tensor): del locals()[var_name] torch.cuda.empty_cache() + logger.info( f"Capture cudagraph success, batch_size <={self.max_batch_size} " f"and max_len_in_batch <= {self.graph_max_len_in_batch} will infer with cudagraph." @@ -189,64 +237,52 @@ def warmup(self, model): @torch.no_grad() def warmup_overlap(self, model): logger.info("Begin capture overlap cudagraph, use the --disable_cudagraph to disable it.") - for batch_size in range(self.max_batch_size, 0, -1): + # for typing easy + from .basemodel import TpPartBaseModel + + model: TpPartBaseModel = model + + for batch_size in self.cuda_graph_batch_sizes[::-1]: decode_batches = [] for micro_batch_index in [0, 1]: - # dummy prefill - prefill_input_len = 1 - dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cuda") + # dummy decoding, capture the cudagraph + seq_len = 2 + total_token_num = batch_size * seq_len + max_len_in_batch = self.graph_max_len_in_batch + input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda") + mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda() b_req_idx = torch.tensor( - [model.req_manager.alloc() for _ in range(batch_size)], dtype=torch.int32, device="cuda" - ) - mem_indexes = model.mem_manager.alloc(len(dummy_input_ids)).cuda() - b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda") - b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda") - total_token_num = prefill_input_len * batch_size - logics = model.forward( - batch_size, - total_token_num, - prefill_input_len, - dummy_input_ids, - mem_indexes, - b_req_idx, - b_seq_len, - b_ready_cache_len=b_ready_cache_len, - is_prefill=True, - multimodal_params=[], + [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda" ) - mem_indexes = None - prob_out = torch.softmax(logics, dim=-1) - logics = None - predict_ids = torch.argmax(prob_out, dim=1, keepdim=True) - prob_out = None - predict_ids = predict_ids.detach().cpu().numpy() - torch.cuda.empty_cache() - - # dummy decoding, capture the cudagraph - total_token_num += batch_size - b_seq_len += 1 - mem_indexes = model.mem_manager.alloc(len(predict_ids)).cuda() + b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda") + b_seq_len.fill_(seq_len) - micro_batch = DecodeMicroBatch( + micro_batch = ModelInput( + is_prefill=False, batch_size=batch_size, total_token_num=total_token_num, - max_len_in_batch=prefill_input_len + 1, - input_ids=torch.from_numpy(predict_ids).cuda().reshape(-1), + max_len_in_batch=max_len_in_batch, + 
input_ids=input_ids, mem_indexes=mem_indexes, b_req_idx=b_req_idx, b_seq_len=b_seq_len, + **model._gen_special_model_input(batch_size), ) decode_batches.append(micro_batch) + del micro_batch for var_name, var_value in list(locals().items()): if isinstance(var_value, torch.Tensor): del locals()[var_name] torch.cuda.empty_cache() + _, _ = model.microbatch_overlap_decode(decode_batches[0], decode_batches[1]) model.mem_manager.free_all() model.req_manager.free_all() + del decode_batches + # release local tensors for var_name, var_value in list(locals().items()): if isinstance(var_value, torch.Tensor): diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py index a946219fb..021de6843 100755 --- a/lightllm/common/basemodel/infer_struct.py +++ b/lightllm/common/basemodel/infer_struct.py @@ -2,7 +2,7 @@ from lightllm.common.mem_manager import MemoryManager from lightllm.common.req_manager import ReqManager from lightllm.distributed import CustomProcessGroup -from typing import Tuple, Any +from typing import Tuple, Any, Optional from .triton_kernel.gen_prefill_params import gen_prefill_params from .triton_kernel.gen_decode_params import gen_decode_params @@ -54,6 +54,15 @@ def __init__(self): self.max_q_seq_len: int = None self.max_kv_seq_len: int = None + # 一些特殊模型,特殊模式使用的输入变量,本身这些变量不适合放在 + # inferstate的基类中,但是为了代码的简洁和方便,都放在基类中 + # 进行管理。注意这些成员变量只会在特定的模型和模式下才会生效。 + + # deepseekv3 mtp draft model 使用的额外输入参数, + # 在开启 mtp_mode == deepseekv3 时,mtp draft model + # 的输入会用到,其他模型和场景都不会用到 + self.deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None + def init_some_extra_state(self, model, input_ids: torch.Tensor): if self.is_prefill: ( @@ -82,7 +91,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): ) = gen_decode_params(b_seq_len=self.b_seq_len) self.b_start_loc = self.b1_cu_kv_seq_len[0:-1] - def copy_for_cuda_graph(self, new_infer_state): + def copy_for_cuda_graph(self, new_infer_state: "InferStateInfo"): for attr_name, attr_value in vars(new_infer_state).items(): if isinstance(attr_value, torch.Tensor): attr_ = getattr(self, attr_name, None) diff --git a/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py b/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py index 570654b55..0cf0c7ae9 100644 --- a/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py +++ b/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py @@ -44,7 +44,8 @@ def __del__(self): class CudaGraphCacheTensorManager: def __init__(self, cuda_graph_max_batch_size: int): self.cuda_graph_max_batch_size = cuda_graph_max_batch_size - self.graph_out_tensor_dict: Dict[int, torch.Tensor] = {} + # Dict[graph_out_key, Dict[microbatch_index, tensor_chache]] + self.graph_out_tensor_dict: Dict[int, Dict[int, torch.Tensor]] = collections.defaultdict(dict) self.managed_total_tensor_bytes = 0 return @@ -56,6 +57,7 @@ def alloc_tensor_for_cuda_graph( device: str = "cuda", is_graph_out: bool = False, microbatch_index: int = 0, + graph_out_key: int = 0, ) -> torch.Tensor: assert microbatch_index in [0, 1] if not is_graph_out: @@ -66,13 +68,16 @@ def alloc_tensor_for_cuda_graph( max_size = size // cur_batch_size * self.cuda_graph_max_batch_size # graph out tensor, 只有一个, 不需要进行引用计数管理 - if microbatch_index not in self.graph_out_tensor_dict: + microbatch_index_to_tensor_cache = self.graph_out_tensor_dict[graph_out_key] + + if microbatch_index not in microbatch_index_to_tensor_cache: graph_out_tensor = torch.empty((max_size,), dtype=data_type, device=device, 
requires_grad=False) logger.info(f"pid {os.getpid()} cuda graph alloc graph out mem {shape} {data_type} {size} {max_size}") self.managed_total_tensor_bytes += graph_out_tensor.element_size() * graph_out_tensor.numel() logger.info(f"cuda graph managed_total_tensor_bytes: {self.managed_total_tensor_bytes}") - self.graph_out_tensor_dict[microbatch_index] = graph_out_tensor - return self.graph_out_tensor_dict[microbatch_index][0:size].view(shape) + microbatch_index_to_tensor_cache[microbatch_index] = graph_out_tensor + + return self.graph_out_tensor_dict[graph_out_key][microbatch_index][0:size].view(shape) class CacheTensorManager: def __init__(self): @@ -119,6 +124,7 @@ def alloc_tensor( device: str = "cuda", is_graph_out: bool = False, microbatch_index: int = 0, + graph_out_key: int = 0, ) -> torch.Tensor: # shape 类型转换 if isinstance(shape, list): @@ -126,7 +132,13 @@ def alloc_tensor( # 是 cuda graph的时候,由cuda graph manager 接管 if self.is_cuda_graph: return self.inner_cuda_graph_manager.alloc_tensor_for_cuda_graph( - self.cuda_graph_cur_batch_size, shape, data_type, device, is_graph_out, microbatch_index + self.cuda_graph_cur_batch_size, + shape, + data_type, + device, + is_graph_out, + microbatch_index, + graph_out_key, ) # 回收可能消亡的 tensor @@ -191,6 +203,7 @@ def alloc_tensor( device: str = "cuda", is_graph_out: bool = False, microbatch_index: int = 0, + graph_out_key: int = 0, ) -> torch.Tensor: return torch.empty(shape, dtype=data_type, device=device, requires_grad=False) diff --git a/lightllm/common/basemodel/microbatch_overlap_objs.py b/lightllm/common/basemodel/microbatch_overlap_objs.py deleted file mode 100644 index 35835cd34..000000000 --- a/lightllm/common/basemodel/microbatch_overlap_objs.py +++ /dev/null @@ -1,26 +0,0 @@ -import torch -from dataclasses import dataclass - - -@dataclass -class DecodeMicroBatch: - batch_size: int - total_token_num: int - max_len_in_batch: int - input_ids: torch.Tensor - mem_indexes: torch.Tensor - b_req_idx: torch.Tensor - b_seq_len: torch.Tensor - - -@dataclass -class PrefillMicroBatch: - batch_size: int - total_token_num: int - max_len_in_batch: int - input_ids: torch.Tensor - mem_indexes: torch.Tensor - b_req_idx: torch.Tensor - b_seq_len: torch.Tensor - b_ready_cache_len: torch.Tensor - multimodal_params: list diff --git a/lightllm/common/basemodel/triton_kernel/gen_mtp_prefill_params.py b/lightllm/common/basemodel/triton_kernel/gen_mtp_prefill_params.py new file mode 100644 index 000000000..8b1ca912f --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/gen_mtp_prefill_params.py @@ -0,0 +1,51 @@ +import torch +import torch.nn.functional as F +import triton +import triton.language as tl +from typing import Optional + + +@triton.jit +def _gen_mtp_new_input_ids( + b1_cu_q_seq_len_ptr, old_input_ids_ptr, insert_tail_input_ids, new_input_ids_ptr, BLOCK: tl.constexpr +): + batch_index = tl.program_id(0) + start_index = tl.load(b1_cu_q_seq_len_ptr + batch_index) + end_index = tl.load(b1_cu_q_seq_len_ptr + batch_index + 1) + offs = tl.arange(0, BLOCK) + + for iter_start_index in tl.range(start_index + 1, end_index, BLOCK, num_stages=3): + input_offs = iter_start_index + offs + t_input_ids = tl.load(old_input_ids_ptr + input_offs, mask=input_offs < end_index, other=0) + tl.store(new_input_ids_ptr + input_offs - 1, t_input_ids, mask=input_offs - 1 < end_index - 1) + tail_token_id = tl.load(insert_tail_input_ids + batch_index) + tl.store(new_input_ids_ptr + end_index - 1, tail_token_id) + return + + +@torch.no_grad() +def gen_mtp_new_input_ids( + 
input_ids: torch.Tensor, + b_next_token_ids: torch.Tensor, + b_seq_len: torch.Tensor, + b_ready_cache_len: Optional[torch.Tensor] = None, +): + assert len(b_seq_len.shape) == 1 + batch_size = b_seq_len.shape[0] + if b_ready_cache_len is None: + b_q_seq_len = b_seq_len + else: + b_q_seq_len = b_seq_len - b_ready_cache_len + b1_cu_q_seq_len = F.pad(torch.cumsum(b_q_seq_len, dim=0, dtype=torch.int32), pad=(1, 0), mode="constant", value=0) + new_input_ids = torch.empty_like(input_ids) + BLOCK = 512 + num_warps = 4 + _gen_mtp_new_input_ids[(batch_size,)]( + b1_cu_q_seq_len_ptr=b1_cu_q_seq_len, + old_input_ids_ptr=input_ids, + insert_tail_input_ids=b_next_token_ids, + new_input_ids_ptr=new_input_ids, + BLOCK=BLOCK, + num_warps=num_warps, + ) + return new_input_ids diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 9a61e5294..643ccd894 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -175,7 +175,7 @@ def update_reqs_token_counter( if not self.enable_gpu_buffer_for_out_token_id_counter: for req_obj, next_token_id in zip(req_objs, next_token_ids): - if req_obj.need_out_token_id_statistics: + if req_obj.need_out_token_id_statistics and req_obj.cur_output_len > 0: req_obj.out_token_id_count[next_token_id] += 1 else: b_req_idx = torch.tensor( diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index d91dd6ede..9101cb963 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -97,12 +97,18 @@ def _init_mem_manager(self): manager_class = Deepseek2MemoryManager if "triton_fp8kv" in self.mode: manager_class = Deepseek2FP8KVMemoryManager + + # mtp 模式下需要在mem manger上扩展draft model使用的layer + added_mtp_layer_num = 0 + if get_env_start_args().mtp_mode == "deepseekv3": + added_mtp_layer_num += get_env_start_args().mtp_step + self.mem_manager = manager_class( self.max_total_token_num, dtype=self.data_type, head_num=1, head_dim=self.config["kv_lora_rank"] + self.config["qk_rope_head_dim"], - layer_num=self.config["num_hidden_layers"], + layer_num=self.config["num_hidden_layers"] + added_mtp_layer_num, mem_fraction=self.mem_fraction, ) return @@ -190,5 +196,4 @@ def _init_to_get_yarn_rotary(self): @final def _context_forward(self, input_ids, infer_state): predict_logics = super()._context_forward(input_ids, infer_state) - dist_group_manager.clear_deepep_buffer() return predict_logics diff --git a/lightllm/models/deepseek_mtp/__init__.py b/lightllm/models/deepseek_mtp/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/deepseek_mtp/layer_infer/__init__.py b/lightllm/models/deepseek_mtp/layer_infer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py new file mode 100644 index 000000000..66a41da73 --- /dev/null +++ b/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py @@ -0,0 +1,60 @@ +import torch + +from lightllm.models.deepseek_mtp.layer_weights.pre_and_post_layer_weight import Deepseek3MTPPreAndPostLayerWeight +from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo +from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer +from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward + + +class Deepseek3MTPPreLayerInfer(LlamaPreLayerInfer): + """ """ + + def __init__(self, network_config, mode): + super().__init__(network_config, mode) + 
self.eps_ = network_config["rms_norm_eps"] + self.hidden_size = network_config["hidden_size"] + return + + def _mtp_context_forward( + self, input_embdings, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight + ): + tgt_embdings = infer_state.deepseekv3_mtp_draft_input_hiddens + assert input_embdings.shape[0] == tgt_embdings.shape[0] + rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, out=input_embdings) + rmsnorm_forward(tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, out=tgt_embdings) + + cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) + + ans_logics = self.alloc_tensor( + (cat_embdings.shape[0], layer_weight.eh_proj_weight_.shape[1]), dtype=input_embdings.dtype + ) + torch.mm(cat_embdings, layer_weight.eh_proj_weight_, out=ans_logics) + return ans_logics + + def _mtp_token_forward( + self, input_embdings, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight + ): + tgt_embdings = infer_state.deepseekv3_mtp_draft_input_hiddens + assert input_embdings.shape[0] == tgt_embdings.shape[0] + rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, out=input_embdings) + rmsnorm_forward(tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, out=tgt_embdings) + + cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) + + ans_logics = self.alloc_tensor( + (cat_embdings.shape[0], layer_weight.eh_proj_weight_.shape[1]), dtype=input_embdings.dtype + ) + torch.mm(cat_embdings, layer_weight.eh_proj_weight_, out=ans_logics) + return ans_logics + + def context_forward( + self, input_ids, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight + ): + input_embdings = super().context_forward(input_ids, infer_state, layer_weight) + return self._mtp_context_forward(input_embdings, infer_state, layer_weight) + + def token_forward( + self, input_ids, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight + ): + input_embdings = super().token_forward(input_ids, infer_state, layer_weight) + return self._mtp_token_forward(input_embdings, infer_state, layer_weight) diff --git a/lightllm/models/deepseek_mtp/layer_weights/__init__.py b/lightllm/models/deepseek_mtp/layer_weights/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 000000000..f5b805647 --- /dev/null +++ b/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,29 @@ +import numpy as np +from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight + + +class Deepseek3MTPPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): + def __init__(self, data_type, network_config, mode): + super().__init__(data_type, network_config, mode) + # 与DeepseekV3模型共享 + self.wte_weight_ = None + self.lm_head_weight_ = None + return + + def load_hf_weights(self, weights): + if "model.layers.0.eh_proj.weight" in weights: + self.eh_proj_weight_ = self._cuda(weights["model.layers.0.eh_proj.weight"]).t() + if "model.layers.0.enorm.weight" in weights: + self.enorm_weight_ = self._cuda(weights["model.layers.0.enorm.weight"]) + if "model.layers.0.hnorm.weight" in weights: + self.hnorm_weight_ = self._cuda(weights["model.layers.0.hnorm.weight"]) + if "model.layers.0.shared_head.norm.weight" in 
weights:
+            self.final_norm_weight_ = self._cuda(weights["model.layers.0.shared_head.norm.weight"])
+        return
+
+    def verify_load(self):
+        errors = "weights load not ok"
+        weights = [self.eh_proj_weight_, self.enorm_weight_, self.hnorm_weight_, self.final_norm_weight_]
+        for i in range(len(weights)):
+            assert weights[i] is not None, "index:" + str(i) + " " + errors
+        return
diff --git a/lightllm/models/deepseek_mtp/model.py b/lightllm/models/deepseek_mtp/model.py
new file mode 100644
index 000000000..2e2e95187
--- /dev/null
+++ b/lightllm/models/deepseek_mtp/model.py
@@ -0,0 +1,46 @@
+from lightllm.models.deepseek2.model import Deepseek2TpPartModel
+from lightllm.models.deepseek_mtp.layer_infer.pre_layer_infer import Deepseek3MTPPreLayerInfer
+from lightllm.models.deepseek_mtp.layer_weights.pre_and_post_layer_weight import Deepseek3MTPPreAndPostLayerWeight
+from lightllm.common.basemodel import TpPartBaseModel
+
+
+class Deepseek3MTPModel(Deepseek2TpPartModel):
+
+    pre_and_post_weight_class = Deepseek3MTPPreAndPostLayerWeight
+    pre_layer_infer_class = Deepseek3MTPPreLayerInfer
+
+    def __init__(self, kvargs: dict):
+        self._pre_init(kvargs)
+        super().__init__(kvargs)
+        return
+
+    def _pre_init(self, kvargs: dict):
+        self.main_model: TpPartBaseModel = kvargs.pop("main_model")
+        self.mem_layer_start = kvargs.pop("mem_layer_start", 0)
+        return
+
+    def _init_custom(self):
+        self._cos_cached = self.main_model._cos_cached
+        self._sin_cached = self.main_model._sin_cached
+        return
+
+    def _init_req_manager(self):
+        self.req_manager = self.main_model.req_manager
+        return
+
+    def _init_mem_manager(self):
+        self.mem_manager = self.main_model.mem_manager
+        return
+
+    def _init_weights(self):
+        super()._init_weights()
+        self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_
+        self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_
+        return
+
+    def _init_infer_layer(self):
+        super()._init_infer_layer()
+        # shift layer_num_ of self.layers_infer by mem_layer_start
+        for layer in self.layers_infer:
+            layer.layer_num_ = layer.layer_num_ + self.mem_layer_start
+        return
diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index 528361e69..e9943b05f 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -330,9 +330,29 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--graph_max_batch_size",
         type=int,
+        default=256,
+        help="""Maximum batch size that can be captured by the cuda graph for the decoding stage.""",
+    )
+    parser.add_argument(
+        "--graph_split_batch_size",
+        type=int,
+        default=32,
+        help="""
+        Controls the interval for generating CUDA graphs during decoding.
+        CUDA graphs will be generated continuously for values ranging from 1 up to the specified
+        graph_split_batch_size. For values from graph_split_batch_size to graph_max_batch_size,
+        a new CUDA graph will be generated for every increment of graph_grow_step_size.
+        Properly configuring this parameter can help optimize the performance of CUDA graph execution.
+        """,
+    )
+    parser.add_argument(
+        "--graph_grow_step_size",
+        type=int,
         default=16,
-        help="""Maximum batch size that can be captured by the cuda graph for decodign stage.
-        The default value is 8. It will turn into eagar mode if encounters a larger value.""",
+        help="""
+        For batch_size values from graph_split_batch_size to graph_max_batch_size,
+        a new CUDA graph will be generated for every increment of graph_grow_step_size.
+ """, ) parser.add_argument( "--graph_max_len_in_batch", @@ -391,4 +411,27 @@ def make_argument_parser() -> argparse.ArgumentParser: action="store_true", help="""Whether to update the redundant expert for deepseekv3 model by online expert used counter.""", ) + parser.add_argument( + "--mtp_mode", + choices=["deepseekv3", None], + default=None, + help="""The supported mtp mode, None means mtp is disabled.""", + ) + parser.add_argument( + "--mtp_draft_model_dir", + type=str, + default=None, + help="""Path to the draft model for the MTP multi-prediction feature, + used for loading the MTP multi-output token model.""", + ) + parser.add_argument( + "--mtp_step", + type=int, + default=0, + help="""Specifies the number of additional tokens to predict using the draft model. + Currently, this feature supports only the DeepSeekV3 model. + Increasing this value allows for more predictions, + but ensure that the model is compatible with the specified step count. + Currently, the deepseekv3 model only supports 1 step.""", + ) return parser diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 3ce8c9872..de1e690a2 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -115,6 +115,14 @@ def normal_or_p_d_start(args): if args.diverse_mode: assert args.router_token_ratio == 0.0 + # mtp params check + if args.mtp_mode is not None: + assert args.mtp_draft_model_dir is not None + assert args.mtp_step > 0 + else: + assert args.mtp_draft_model_dir is None + assert args.mtp_step == 0 + # 检查GPU数量是否足够 if args.visual_gpu_ids is None: args.visual_gpu_ids = list(range(args.visual_dp * args.visual_tp)) diff --git a/lightllm/server/core/objs/__init__.py b/lightllm/server/core/objs/__init__.py index d4e69fd36..5594df6a0 100644 --- a/lightllm/server/core/objs/__init__.py +++ b/lightllm/server/core/objs/__init__.py @@ -2,3 +2,4 @@ from .req import Req, FinishStatus from .shm_req_manager import ShmReqManager from .rpc_shm import RpcShmParams, RpcShmResults, ShmSyncStatusArray +from .start_args_type import StartArgs diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index c8d8476e5..7bf5ee239 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -94,6 +94,10 @@ class Req(ctypes.Structure): ("reward_score", ctypes.c_float), # 请求回复累计概率和 ("cumlogprob", ctypes.c_float), + # the number of mtp draft model multi-output tokens that were verified and accepted + ("mtp_accepted_token_num", ctypes.c_int), + # mtp_step caches a constant used by mtp for fast access; it is never initialized from external input + ("_mtp_step", ctypes.c_int), ] def get_str(self): @@ -145,6 +149,8 @@ def init( self.create_prompt_ids_shm_array() self.chunked_prefill_size = chunked_prefill_size self.shm_prompt_ids.arr[0 : len(prompt_ids)] = prompt_ids + self.mtp_accepted_token_num = 0 + self._mtp_step = get_env_start_args().mtp_step self.post_init() @@ -264,8 +270,8 @@ def get_tuple_tokens(self, is_busy, router_max_new_token_len): return (a_len, b_len) def get_decode_need_tokens(self): - - return 1 + # when mtp mode is enabled, each decode step needs more tokens + return self._mtp_step + 1 def get_first_router_need_tokens(self): @@ -305,7 +311,12 @@ def get_decode_need_tokens(self): """ chunkedprefill 调度模式的实现 """ - return min(self.input_len + self.shm_cur_output_len - self.shm_cur_kv_len, self.chunked_prefill_size) + # when mtp mode is enabled, each decode step needs more tokens + need_tokens = min(self.input_len + self.shm_cur_output_len - self.shm_cur_kv_len, self.chunked_prefill_size) + if need_tokens == 1: + need_tokens = self._mtp_step + 1 + + return need_tokens def get_first_router_need_tokens(self): 
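The req.py change above is the scheduler-facing half of the mtp feature: once --mtp_mode deepseekv3 and --mtp_step 1 are set, every decode step has to budget one extra token slot per mtp step. A minimal sketch of that budgeting rule, using a standalone helper rather than the actual Req methods:

    # Illustrative only: mirrors the rule encoded in Req.get_decode_need_tokens above.
    def decode_need_tokens(mtp_step: int) -> int:
        # one token sampled by the main model plus one speculative token per mtp step
        return mtp_step + 1

    assert decode_need_tokens(0) == 1  # mtp disabled: a decode step needs a single token slot
    assert decode_need_tokens(1) == 2  # deepseekv3 mtp with mtp_step=1: main token + draft token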
diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index e9a17d2d5..8a43d983d 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -30,6 +30,8 @@ class StartArgs: running_max_req_size: int = field(default=1000) tp: int = field(default=1) dp: int = field(default=1) + nnodes: int = field(default=1) + node_rank: int = field(default=0) max_req_total_len: int = field(default=2048 + 1024) nccl_host: str = field(default="127.0.0.1") nccl_port: int = field(default=28765) @@ -74,7 +76,9 @@ class StartArgs: visual_nccl_ports: List[int] = field(default_factory=lambda: [29500]) enable_monitor_auth: bool = field(default=False) disable_cudagraph: bool = field(default=False) - graph_max_batch_size: int = field(default=16) + graph_max_batch_size: int = field(default=256) + graph_split_batch_size: int = field(default=32) + graph_grow_step_size: int = field(default=16) graph_max_len_in_batch: int = field(default=8192) quant_type: Optional[str] = field(default=None) quant_cfg: Optional[str] = field(default=None) @@ -85,3 +89,6 @@ class StartArgs: sampling_backend: str = field(default="triton", metadata={"choices": ["triton", "sglang_kernel"]}) ep_redundancy_expert_config_path: Optional[str] = field(default=None) auto_update_redundancy_expert: bool = field(default=False) + mtp_mode: Optional[str] = field(default=None) + mtp_draft_model_dir: Optional[str] = field(default=None) + mtp_step: int = field(default=0) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index d25b91864..8def96733 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -53,7 +53,7 @@ def __init__( self.multinode_req_manager = None self.nnodes = args.nnodes - self._shm_lock_pool = AtomicShmArrayLock("lightllm_resource_lock", 1) + self._shm_lock_pool = AtomicShmArrayLock(f"{get_unique_server_name()}_lightllm_resource_lock", 1) self._resource_lock = AsyncLock(self._shm_lock_pool.get_lock_context(0)) self.node_rank = args.node_rank self.transfer_lock = asyncio.Lock() # the lock for transfer to next module in multi node mode. 
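The StartArgs defaults added above pair with the --graph_split_batch_size / --graph_grow_step_size help text: a CUDA graph is captured for every batch size up to the split point, and only every grow step beyond it. A hypothetical sketch of that capture schedule (the real enumeration lives in the cuda graph wrapper and may differ in detail):

    def graph_capture_batch_sizes(split_size: int = 32, grow_step: int = 16, max_size: int = 256):
        # dense coverage for small batches, sparse coverage for large ones
        sizes = list(range(1, split_size + 1))
        sizes += list(range(split_size + grow_step, max_size + 1, grow_step))
        if sizes[-1] != max_size:
            sizes.append(max_size)
        return sizes

    # with the defaults this yields 1..32, then 48, 64, ..., 256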
@@ -542,6 +542,10 @@ async def _wait_to_token_package( x_request_id = request.headers.get("X-Request-Id", "") if request is not None else "" x_session_id = request.headers.get("X-Session-Id", "") if request is not None else "" prompt_cache_ratio = prompt_cache_len / prompt_tokens + + mtp_avg_token_per_step = out_token_counter / max( + (out_token_counter - metadata["mtp_accepted_token_num"]), 1 + ) format_start_time = datetime.datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S") logger.info( f"X-Request-Id:{x_request_id} " @@ -552,6 +556,7 @@ async def _wait_to_token_package( f"prompt_token_num:{prompt_tokens} " f"prompt_cache_len:{prompt_cache_len} " f"prompt_cache_ratio:{prompt_cache_ratio} " + f"mtp_avg_token_per_step:{mtp_avg_token_per_step} " ) if group_request_id < 0: # health 探测请求,不记录日志和监控 @@ -654,6 +659,7 @@ async def handle_loop(self): "special": special, "count_output_tokens": count_output_tokens, "prompt_cache_len": req.prompt_cache_len, + "mtp_accepted_token_num": req.mtp_accepted_token_num, } if self.args.return_all_prompt_logprobs: metadata.update(req.get_all_prompt_metadata()) diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 2811e4228..45e82ff3d 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -22,7 +22,7 @@ from .req_queue import build_req_queue from lightllm.utils.infer_utils import calculate_time from lightllm.server.core.objs.io_objs import GroupReqIndexes -from lightllm.server.core.objs import ShmReqManager +from lightllm.server.core.objs import ShmReqManager, StartArgs from .dynamic_prompt.radix_cache import RadixCacheReadOnlyClient from .stats import Stats from .pause_strategy import Fcfs, select_paused_reqs @@ -35,11 +35,12 @@ from lightllm.utils.process_check import start_parent_check_thread from lightllm.utils.envs_utils import get_unique_server_name + logger = init_logger(__name__) class RouterManager: - def __init__(self, args, router_port, detokenization_port, metric_port): + def __init__(self, args: StartArgs, router_port, detokenization_port, metric_port): self.args = args self.model_weightdir = args.model_dir self.world_size = args.tp @@ -62,6 +63,8 @@ def __init__(self, args, router_port, detokenization_port, metric_port): # 初始化 radix_cache_client 用于读取 prompt cache 的管理信息 self.radix_cache_client = None + self.mtp_step = args.mtp_step + # 共享变量,用于存储router端调度分析得到的机器负载信息 self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", self.dp_size_in_node) for dp_index in range(self.dp_size_in_node): @@ -383,8 +386,7 @@ async def _prefill_batch(self, batch: Batch): self.overlap_event.set() await self.model_rpc_client.prefill(reqs) batch.filter_out_finished_req(self.shm_req_manager) - # 发个None包触发一下detokenization - self.send_to_detokenization.send_pyobj(None, protocol=pickle.HIGHEST_PROTOCOL) + self._send_detokenization_pack() logger.debug(f"Prefill Batch: {batch.simple_log()} \n") self.metric_client.histogram_observe( @@ -400,8 +402,8 @@ async def _decode_batch(self, batch: Batch): # 在 self.is_multinode_and_multidp 为 True 时,传入的 batch 对象可能为 None。 if batch is not None: batch.filter_out_finished_req(self.shm_req_manager) - # 发个None包触发一下detokenization - self.send_to_detokenization.send_pyobj(None, protocol=pickle.HIGHEST_PROTOCOL) + + self._send_detokenization_pack() self.metric_client.histogram_observe( "lightllm_batch_inference_duration_bucket", time.time() - start_time, "decode" ) @@ -424,6 +426,14 @@ def _can_decode(self, batch: Batch, dp_index: int): 
batch.get_batch_decode_need_tokens()[dp_index] + self.get_used_tokens(dp_index) <= self.max_total_token_num ) + def _send_detokenization_pack(self): + # send mtp_step + 1 None packets to trigger detokenization, because once the mtp feature is enabled, + # each step can generate at most mtp_step + 1 tokens; if detokenization is not triggered in time, + # there is some performance loss + for _ in range(self.mtp_step + 1): + self.send_to_detokenization.send_pyobj(None, protocol=pickle.HIGHEST_PROTOCOL) + return + def get_used_tokens(self, dp_index): if not self.args.disable_dynamic_prompt_cache: return ( diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index dfdfd43de..10b68245c 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -17,6 +17,7 @@ from lightllm.common.basemodel.infer_lock import g_infer_state_lock from lightllm.server.multimodal_params import MultimodalParams from lightllm.utils.custom_kernel_utis import custom_cat +from lightllm.utils.envs_utils import get_env_start_args logger = init_logger(__name__) @@ -260,6 +261,12 @@ def __init__( self.need_out_token_id_statistics = True self.out_token_id_count: Dict[int, int] = None + # mtp_gen_token_ids supports pre-generating extra tokens for a request through mtp. + # When the mtp feature is not enabled, this member has no practical effect. + # When it is enabled, mtp_gen_token_ids holds the extra pre-generated token_ids, + # which still need to be re-verified in later steps. + self.mtp_gen_token_ids: List[int] = [] + def init_all(self): if self.initialized is False: self.shm_req = g_infer_context.shm_req_manager.get_req_obj_by_index(self.shm_index) @@ -330,6 +337,10 @@ def set_next_gen_token_id(self, next_token_id: int, logprob: float): self.shm_req.shm_logprobs.arr[index] = logprob return + def update_mtp_accepted_token_num(self, accept_token_num: int): + # used to track the mtp acceptance rate + self.shm_req.mtp_accepted_token_num += accept_token_num + def get_last_gen_token(self): return self.shm_req.shm_prompt_ids.arr[self.shm_req.input_len + self.cur_output_len - 1] diff --git a/lightllm/server/router/model_infer/mode_backend/__init__.py b/lightllm/server/router/model_infer/mode_backend/__init__.py index 4594eec28..7ad15f00f 100644 --- a/lightllm/server/router/model_infer/mode_backend/__init__.py +++ b/lightllm/server/router/model_infer/mode_backend/__init__.py @@ -7,8 +7,16 @@ from .chunked_prefill.impl_for_outlines_constraint_mode import OutlinesConstraintBackend from .chunked_prefill.impl_for_first_token_constraint_mode import FirstTokenConstraintBackend from .dp_backend.impl import DPChunkedPrefillBackend +from .dp_backend.impl_mtp import DPChunkedPrefillWithMTPBackend from .continues_batch.pd_mode.prefill_node_impl.prefill_impl import ChunckedPrefillForPrefillNode from .continues_batch.pd_mode.decode_node_impl.decode_impl import ContinuesBatchBackendForDecodeNode from .chunked_prefill.impl_for_xgrammar_mode import XgrammarBackend from .continues_batch.pd_mode.prefill_node_impl.prefill_impl_for_dp_chuncked import DPChunkedForPrefillNode from .continues_batch.pd_mode.decode_node_impl.decode_impl_for_dp import DPForDecodeNode +from .continues_batch.impl_mtp import ContinuesBatchWithMTPBackend + +# mtp and pd mode backend +from .continues_batch.pd_mode.decode_node_impl.decode_impl_mtp_for_dp import DPForMtpDecodeNode +from .continues_batch.pd_mode.decode_node_impl.decode_impl_mtp import ContinuesBatchBackendForMtpDecodeNode +from .continues_batch.pd_mode.prefill_node_impl.prefill_impl_mtp import ChunckedPrefillForMtpPrefillNode +from .continues_batch.pd_mode.prefill_node_impl.prefill_impl_mtp_for_dp_chuncked import DPChunkedForMtpPrefillNode 
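mtp_gen_token_ids and mtp_accepted_token_num come together in the verification logic added further down in base_backend.py: a draft token is only kept while it matches what the main model actually sampled at the same position, and acceptance stops at the first mismatch. A simplified sketch of that accept rule, assuming plain token lists instead of the backend's grouped (req, token, mem_index) tuples:

    def count_accepted(draft_token_ids, main_token_ids):
        # prefix acceptance: stop at the first draft token the main model disagrees with
        accepted = 0
        for draft_id, main_id in zip(draft_token_ids, main_token_ids):
            if draft_id != main_id:
                break
            accepted += 1
        return accepted

    # e.g. draft [512, 90] while the main model samples [512, 77, ...] -> 1 accepted token,
    # roughly the amount accumulated into Req.mtp_accepted_token_num for that step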
diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 0d33163ac..329dc9f3b 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -15,9 +15,11 @@ from lightllm.server.router.model_infer.infer_batch import InferReq, InferSamplingParams from lightllm.server.router.token_load import TokenLoad from lightllm.common.basemodel.infer_lock import g_infer_state_lock, InferStateLock +from lightllm.common.basemodel.basemodel import TpPartBaseModel +from lightllm.common.basemodel.batch_objs import ModelOutput from lightllm.utils.dist_utils import init_distributed_env from lightllm.utils.envs_utils import get_unique_server_name -from lightllm.server.core.objs import ShmReqManager +from lightllm.server.core.objs import ShmReqManager, StartArgs from lightllm.server.router.model_infer.infer_batch import g_infer_context from lightllm.utils.dist_utils import get_global_rank, get_global_world_size, get_dp_size from lightllm.utils.dist_utils import get_dp_world_size, get_global_dp_rank, get_current_rank_in_dp @@ -33,26 +35,26 @@ def __init__(self) -> None: pass def init_model(self, kvargs): - self.args = kvargs.get("args", None) + self.args: StartArgs = kvargs.get("args", None) + assert self.args is not None # p d 分离模式下会有特殊的一些初始化, 所以需要传递 # 模式参数到模型的初始化过程中进行控制 - self.run_mode = "normal" if self.args is None else self.args.run_mode + self.run_mode = self.args.run_mode self.is_multimodal = False self.nnodes = self.args.nnodes self.node_rank = self.args.node_rank - self.tp_rank = kvargs["rank_id"] self.world_size = kvargs["world_size"] - self.dp_size = kvargs.get("dp_size", 1) + self.dp_size = self.args.dp # dp_size_in_node 计算兼容多机纯tp的运行模式,这时候 1 // 2 == 0, 需要兼容 self.dp_size_in_node = max(1, self.dp_size // self.nnodes) self.load_way = kvargs["load_way"] self.mode = kvargs["mode"] - self.disable_chunked_prefill = kvargs.get("disable_chunked_prefill", False) - self.chunked_prefill_size = kvargs.get("chunked_prefill_size", None) - self.return_all_prompt_logprobs = kvargs.get("return_all_prompt_logprobs", False) - self.use_dynamic_prompt_cache = not kvargs.get("disable_dynamic_prompt_cache", False) + self.disable_chunked_prefill = self.args.disable_chunked_prefill + self.chunked_prefill_size = self.args.chunked_prefill_size + self.return_all_prompt_logprobs = self.args.return_all_prompt_logprobs + self.use_dynamic_prompt_cache = not self.args.disable_dynamic_prompt_cache self.eos_id: List[int] = kvargs.get("eos_id", [2]) - self.disable_cudagraph = kvargs.get("disable_cudagraph", False) + self.disable_cudagraph = self.args.disable_cudagraph self.cache = {} self.logger = init_logger(__name__) @@ -113,6 +115,7 @@ def init_model(self, kvargs): "run_mode": self.run_mode, } self.model, self.is_multimodal = get_model(model_cfg, model_kvargs) + self.model: TpPartBaseModel = self.model # for easy typing set_random_seed(2147483647) self.radix_cache = ( RadixCache( @@ -260,6 +263,7 @@ def _post_handle( for req_obj, next_token_id, next_token_logprob in zip(run_reqs, next_token_ids, next_token_logprobs): req_obj: InferReq = req_obj shm_req = req_obj.shm_req + finish_status = req_obj.finish_status if is_chuncked_mode: new_kv_len = req_obj.get_chuncked_input_token_len() else: @@ -269,13 +273,8 @@ def _post_handle( if self.is_master_in_dp: shm_req.shm_cur_kv_len = req_obj.cur_kv_len - # 这个地方主要是为了提前判断是否存在abort的情况,如果abort了 - # 直接将请求放入finished 处理队列中。 - 
if req_obj.is_finished_or_aborted(): - finished_req_ids.append(shm_req.request_id) - continue - - # 对于没有到达需要输出 token 阶段的请求,直接略过 + # 对于没有到达需要输出 token 阶段的请求,直接略过, 说明还 + # 处于chuncked prefill kv 填充的阶段。 if req_obj.cur_kv_len < req_obj.get_cur_total_len(): continue @@ -283,13 +282,26 @@ def _post_handle( req_obj.set_next_gen_token_id(next_token_id, next_token_logprob) req_obj.cur_output_len += 1 + # 这里提前判定的主要作用是: + # 在 mtp mode 下,可以存在同一个 req 对象的多次处理, + # 在这种情况下, 如果前一步接收的mtp token 已经导致了请求 + # 达到了finished 状态,后续的请求就不再进行后续的复杂流程 + # 判断和处理,但是,因为 mtp 多请求还是导致了kv 的使用,所以 + # 还是需要更新对应的 input_tokens 和 cur_kv_len 信息,否则 + # 在 filter req 的时候,容易导致kv 管理的泄露和插入radix cache + # 的信息不完整等问题。 + if finish_status.is_finished(): + finished_req_ids.append(shm_req.request_id) + continue + + # 更新判断请求的 finished 状态 req_obj.update_finish_status(self.eos_id) if extra_post_req_handle_func is not None: extra_post_req_handle_func(req_obj, next_token_id, next_token_logprob) # 判断是否已经满足生成结束条件。 - if req_obj.is_finished_or_aborted(): + if finish_status.is_finished(): finished_req_ids.append(shm_req.request_id) if self.is_master_in_dp: @@ -298,7 +310,7 @@ def _post_handle( # detokenization 进程需要的信息,注意这些变量的写入顺序避免异步协同问题。 shm_req.shm_cur_output_len = req_obj.cur_output_len - if req_obj.finish_status.is_finished(): + if finish_status.is_finished(): shm_req.finish_token_index = req_obj.get_cur_total_len() - 1 shm_req.finish_status = req_obj.finish_status @@ -308,8 +320,13 @@ def _post_handle( req_objs=run_reqs, next_token_ids=next_token_ids ) + # mtp_mode 模式下,因为存在重复对象,需要进行去重操作。 + if self.args.mtp_mode is not None: + finished_req_ids = list(set(finished_req_ids)) + if do_filter_finished_reqs: g_infer_context.filter(finished_req_ids) + return finished_req_ids # 一些可以复用的通用功能函数 @@ -359,6 +376,76 @@ def _filter_reqs(self, reqs: List[InferReq]): def _trans_req_ids_to_req_objs(self, req_ids: List[int]) -> List[InferReq]: return [g_infer_context.requests_mapping[req_id] for req_id in req_ids] + # 对mtp 运行模式下的请求进行校验和过滤,保留校验成功的请求对象,并释放不再使用的kv 的 mem_index + def _verify_mtp(self, run_reqs: List[InferReq], next_token_ids_cpu: np.ndarray, input_mem_indexes_cpu: np.ndarray): + verify_ok_reqs = [] + verify_ok_req_indexes = [] + verify_ok_req_last_indexes = [] + need_free_mem_indexes = [] + grouped_reqs = self._group_mtp_run_reqs(run_reqs, next_token_ids_cpu, input_mem_indexes_cpu) + for req_group in grouped_reqs: + pre_req, pre_out_token_id, _, pre_index = req_group[0] + verify_ok_reqs.append(pre_req) + verify_ok_req_indexes.append(pre_index) + need_verify = True + verify_ok_count = 0 + for i in range(1, len(req_group)): + cur_req, cur_out_token_id, cur_mem_index, cur_index = req_group[i] + cur_req: InferReq = cur_req + # cur_req 的输入,等于pre_req 的输出,表示校验成功 + if need_verify and cur_req.mtp_gen_token_ids[i - 1] == pre_out_token_id: + verify_ok_reqs.append(cur_req) + verify_ok_req_indexes.append(cur_index) + pre_req, pre_out_token_id, _, pre_index = ( + cur_req, + cur_out_token_id, + cur_mem_index, + cur_index, + ) + verify_ok_count += 1 + continue + + need_verify = False + need_free_mem_indexes.append(cur_mem_index) + + verify_ok_req_last_indexes.append(verify_ok_req_indexes[-1]) + + # 清理每个请求上的 mtp_gen_token_ids, 并更新接受率信息 + pre_req.mtp_gen_token_ids = [] + if self.is_master_in_dp: + pre_req.update_mtp_accepted_token_num(accept_token_num=verify_ok_count) + + return verify_ok_reqs, verify_ok_req_indexes, verify_ok_req_last_indexes, need_free_mem_indexes + + def _group_mtp_run_reqs(self, reqs: List[InferReq], next_token_ids_cpu: np.ndarray, input_mem_indexes: 
np.ndarray): + if not reqs: + return [] + + grouped_reqs = [] + current_group = [(reqs[0], next_token_ids_cpu[0], input_mem_indexes[0], 0)] + + for i in range(1, len(reqs)): + req = reqs[i] + if req.req_id == current_group[-1][0].req_id: + current_group.append((req, next_token_ids_cpu[i], input_mem_indexes[i], i)) + else: + grouped_reqs.append(current_group) + current_group = [(req, next_token_ids_cpu[i], input_mem_indexes[i], i)] + + grouped_reqs.append(current_group) + return grouped_reqs + + def _gen_argmax_token_ids(self, model_output: ModelOutput): + logits = model_output.logits + probs = torch.softmax(logits, dim=-1) + draft_next_token_ids_gpu = torch.argmax(probs, dim=-1) + return draft_next_token_ids_gpu, draft_next_token_ids_gpu.detach().cpu().numpy() + + def _update_reqs_mtp_gen_token_ids(self, reqs: List[InferReq], mtp_draft_next_token_ids: np.ndarray): + for req, token_id in zip(reqs, mtp_draft_next_token_ids): + req.mtp_gen_token_ids.append(token_id) + return + def preload_prompt_cache_kv_buffer(self, model_cfg): self.logger.info("Preload prompt cache kv buffer.") cur_rank = dist.get_rank() diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 8dad113d4..00528fec7 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -6,7 +6,7 @@ from lightllm.server.router.model_infer.infer_batch import g_infer_context from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample -from lightllm.server.router.model_infer.mode_backend.generic_pre_process import ( +from lightllm.server.router.model_infer.mode_backend.pre import ( prepare_prefill_inputs, prepare_decode_inputs, ) @@ -38,37 +38,37 @@ def decode(self): # 先 decode if decode_reqs: - kwargs, run_reqs = prepare_decode_inputs(decode_reqs) - logits = self.model.forward(**kwargs) + model_input, run_reqs = prepare_decode_inputs(decode_reqs) + model_output = self.model.forward(model_input) self._overlap_req_init_and_filter( uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True ) - next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) + next_token_ids, next_token_probs = sample(model_output.logits, run_reqs, self.eos_id) next_token_ids = next_token_ids.detach().cpu().numpy() next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() self._post_handle( run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False ) - logits = None + del model_output # 再 prefill if len(decode_reqs) == 0 or (self.forward_step % self.max_wait_step == 0) or (self.need_prefill_count > 0): if prefill_reqs: self.need_prefill_count -= 1 - kwargs, run_reqs = prepare_prefill_inputs( + model_input, run_reqs = prepare_prefill_inputs( prefill_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal ) - logits = self.model.forward(**kwargs) + model_output = self.model.forward(model_input) self._overlap_req_init_and_filter( uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True ) - next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) + next_token_ids, next_token_probs = sample(model_output.logits, run_reqs, self.eos_id) next_token_ids = next_token_ids.detach().cpu().numpy() next_token_logprobs = 
torch.log(next_token_probs).detach().cpu().numpy() self._post_handle( run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False ) - logits = None + del model_output self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) self.forward_step += 1 diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_first_token_constraint_mode.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_first_token_constraint_mode.py index 752d5748f..b083a7263 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_first_token_constraint_mode.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_first_token_constraint_mode.py @@ -4,7 +4,7 @@ from .impl import ChunkedPrefillBackend from typing import List, Tuple from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq -from lightllm.server.router.model_infer.mode_backend.generic_pre_process import ( +from lightllm.server.router.model_infer.mode_backend.pre import ( prepare_prefill_inputs, prepare_decode_inputs, ) @@ -41,8 +41,9 @@ def decode(self): # 先 decode if decode_reqs: - kwargs, run_reqs = prepare_decode_inputs(decode_reqs) - logits = self.model.forward(**kwargs) + model_input, run_reqs = prepare_decode_inputs(decode_reqs) + model_output = self.model.forward(model_input) + logits = model_output.logits self._overlap_req_init_and_filter( uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True ) @@ -53,16 +54,18 @@ def decode(self): self._post_handle( run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False ) - logits = None + del model_output + del logits # 再 prefill if len(decode_reqs) == 0 or (self.forward_step % self.max_wait_step == 0) or (self.need_prefill_count > 0): if prefill_reqs: self.need_prefill_count -= 1 - kwargs, run_reqs = prepare_prefill_inputs( + model_input, run_reqs = prepare_prefill_inputs( prefill_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal ) - logits = self.model.forward(**kwargs) + model_output = self.model.forward(model_input) + logits = model_output.logits self._overlap_req_init_and_filter( uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True ) @@ -73,7 +76,8 @@ def decode(self): self._post_handle( run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False ) - logits = None + del model_output + del logits self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) self.forward_step += 1 diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_outlines_constraint_mode.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_outlines_constraint_mode.py index ed092b42b..40401b895 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_outlines_constraint_mode.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_outlines_constraint_mode.py @@ -6,7 +6,7 @@ from .impl import ChunkedPrefillBackend from lightllm.server.core.objs import FinishStatus from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq -from lightllm.server.router.model_infer.mode_backend.generic_pre_process import ( +from lightllm.server.router.model_infer.mode_backend.pre import ( prepare_prefill_inputs, 
prepare_decode_inputs, ) @@ -66,8 +66,9 @@ def decode(self): # 先 decode if decode_reqs: - kwargs, run_reqs = prepare_decode_inputs(decode_reqs) - logits = self.model.forward(**kwargs) + model_input, run_reqs = prepare_decode_inputs(decode_reqs) + model_output = self.model.forward(model_input) + logits = model_output.logits self._overlap_req_init_and_filter( uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True ) @@ -91,16 +92,18 @@ def decode(self): do_filter_finished_reqs=False, extra_post_req_handle_func=self._update_state_fsm, ) - logits = None + del model_output + del logits # 再 prefill if len(decode_reqs) == 0 or (self.forward_step % self.max_wait_step == 0) or (self.need_prefill_count > 0): if prefill_reqs: self.need_prefill_count -= 1 - kwargs, run_reqs = prepare_prefill_inputs( + model_input, run_reqs = prepare_prefill_inputs( prefill_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal ) - logits = self.model.forward(**kwargs) + model_output = self.model.forward(model_input) + logits = model_output.logits self._overlap_req_init_and_filter( uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True ) @@ -123,7 +126,8 @@ def decode(self): do_filter_finished_reqs=False, extra_post_req_handle_func=self._update_state_fsm, ) - logits = None + del model_output + del logits self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) self.forward_step += 1 diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_token_healing.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_token_healing.py index 09d9f0fb8..701333116 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_token_healing.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_token_healing.py @@ -2,7 +2,7 @@ from .impl import ChunkedPrefillBackend from typing import List, Tuple from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq -from lightllm.server.router.model_infer.mode_backend.generic_pre_process import ( +from lightllm.server.router.model_infer.mode_backend.pre import ( prepare_prefill_inputs, prepare_decode_inputs, ) @@ -50,8 +50,9 @@ def decode(self): # 先 decode if decode_reqs: - kwargs, run_reqs = prepare_decode_inputs(decode_reqs) - logits = self.model.forward(**kwargs) + model_input, run_reqs = prepare_decode_inputs(decode_reqs) + model_output = self.model.forward(model_input) + logits = model_output.logits self._overlap_req_init_and_filter( uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True ) @@ -77,16 +78,18 @@ def decode(self): do_filter_finished_reqs=False, extra_post_req_handle_func=self._update_tokenhealing_req_prefix_str, ) - logits = None + del model_output + del logits # 再 prefill if len(decode_reqs) == 0 or (self.forward_step % self.max_wait_step == 0) or (self.need_prefill_count > 0): if prefill_reqs: self.need_prefill_count -= 1 - kwargs, run_reqs = prepare_prefill_inputs( + model_input, run_reqs = prepare_prefill_inputs( prefill_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal ) - logits = self.model.forward(**kwargs) + model_output = self.model.forward(model_input) + logits = model_output.logits self._overlap_req_init_and_filter( uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True ) @@ -113,7 +116,8 @@ def decode(self): do_filter_finished_reqs=False, 
extra_post_req_handle_func=self._update_tokenhealing_req_prefix_str, ) - logits = None + del model_output + del logits self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) self.forward_step += 1 diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_xgrammar_mode.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_xgrammar_mode.py index 1c7a3a36d..8cdd840e6 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_xgrammar_mode.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_xgrammar_mode.py @@ -2,7 +2,7 @@ from typing import List, Tuple from .impl import ChunkedPrefillBackend -from lightllm.server.router.model_infer.mode_backend.generic_pre_process import ( +from lightllm.server.router.model_infer.mode_backend.pre import ( prepare_prefill_inputs, prepare_decode_inputs, ) @@ -48,8 +48,9 @@ def decode(self): # 先 decode if decode_reqs: - kwargs, run_reqs = prepare_decode_inputs(decode_reqs) - logits = self.model.forward(**kwargs) + model_input, run_reqs = prepare_decode_inputs(decode_reqs) + model_output = self.model.forward(model_input) + logits = model_output.logits self._overlap_req_init_and_filter( uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True ) @@ -75,16 +76,18 @@ def decode(self): do_filter_finished_reqs=False, extra_post_req_handle_func=self._update_xgrammer_fsm, ) - logits = None + del model_output + del logits # 再 prefill if len(decode_reqs) == 0 or (self.forward_step % self.max_wait_step == 0) or (self.need_prefill_count > 0): if prefill_reqs: self.need_prefill_count -= 1 - kwargs, run_reqs = prepare_prefill_inputs( + model_input, run_reqs = prepare_prefill_inputs( prefill_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal ) - logits = self.model.forward(**kwargs) + model_output = self.model.forward(model_input) + logits = model_output.logits self._overlap_req_init_and_filter( uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True ) @@ -109,7 +112,8 @@ def decode(self): do_filter_finished_reqs=False, extra_post_req_handle_func=self._update_xgrammer_fsm, ) - logits = None + del model_output + del logits self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) self.forward_step += 1 diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py index 052a180c0..b0eb2b58f 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py @@ -4,7 +4,7 @@ from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end from lightllm.utils.log_utils import init_logger from lightllm.server.router.model_infer.infer_batch import g_infer_context -from lightllm.server.router.model_infer.mode_backend.generic_pre_process import ( +from lightllm.server.router.model_infer.mode_backend.pre import ( prepare_prefill_inputs, prepare_decode_inputs, ) @@ -30,10 +30,11 @@ def decode(self): g_infer_context.filter_reqs(aborted_reqs) if prefill_reqs: - kwargs, run_reqs = prepare_prefill_inputs( + model_input, run_reqs = prepare_prefill_inputs( prefill_reqs, is_chuncked_mode=False, is_multimodal=self.is_multimodal ) - logits = self.model.forward(**kwargs) + model_output = self.model.forward(model_input) + logits = 
model_output.logits self._overlap_req_init_and_filter( uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True @@ -48,8 +49,9 @@ def decode(self): ) if decode_reqs: - kwargs, run_reqs = prepare_decode_inputs(decode_reqs) - logits = self.model.forward(**kwargs) + model_input, run_reqs = prepare_decode_inputs(decode_reqs) + model_output = self.model.forward(model_input) + logits = model_output.logits self._overlap_req_init_and_filter( uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_return_all_prompt_logprobs.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_return_all_prompt_logprobs.py index db1e680ec..1462ca84e 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_return_all_prompt_logprobs.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_return_all_prompt_logprobs.py @@ -3,7 +3,7 @@ from typing import List, Tuple from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end from lightllm.server.router.model_infer.infer_batch import InferReq, InferSamplingParams, g_infer_context -from lightllm.server.router.model_infer.mode_backend.generic_pre_process import prepare_prefill_inputs +from lightllm.server.router.model_infer.mode_backend.pre import prepare_prefill_inputs from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample @@ -17,12 +17,15 @@ def prefill(self, run_reqs: List[Tuple]): req_ids = self._init_reqs(run_reqs, init_req_obj=True) req_objs = self._trans_req_ids_to_req_objs(req_ids) - kwargs, run_reqs = prepare_prefill_inputs(req_objs, is_chuncked_mode=False, is_multimodal=self.is_multimodal) - - prompt_all_logits = self.model.forward(**kwargs) - input_ids = kwargs["input_ids"] - b_ready_cache_len = kwargs["b_ready_cache_len"] - b_seq_len = kwargs["b_seq_len"] + model_input, run_reqs = prepare_prefill_inputs( + req_objs, is_chuncked_mode=False, is_multimodal=self.is_multimodal + ) + + model_output = self.model.forward(model_input) + prompt_all_logits = model_output.logits + input_ids = model_input.input_ids + b_ready_cache_len = model_input.b_ready_cache_len + b_seq_len = model_input.b_seq_len last_index = torch.cumsum(b_seq_len, dim=0, dtype=torch.long) - 1 logits = prompt_all_logits[last_index, :] diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_reward_model.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_reward_model.py index 569da5150..9425b64b5 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_reward_model.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_reward_model.py @@ -2,7 +2,7 @@ from typing import List, Tuple from .impl import ContinuesBatchBackend from lightllm.server.router.model_infer.infer_batch import InferReq, InferSamplingParams, g_infer_context -from lightllm.server.router.model_infer.mode_backend.generic_pre_process import prepare_prefill_inputs +from lightllm.server.router.model_infer.mode_backend.pre import prepare_prefill_inputs from lightllm.server.core.objs import FinishStatus @@ -14,9 +14,12 @@ def prefill(self, reqs: List[Tuple]): req_ids = self._init_reqs(reqs, init_req_obj=True) req_objs = self._trans_req_ids_to_req_objs(req_ids) - kwargs, run_reqs = prepare_prefill_inputs(req_objs, is_chuncked_mode=False, is_multimodal=self.is_multimodal) + model_input, 
run_reqs = prepare_prefill_inputs( + req_objs, is_chuncked_mode=False, is_multimodal=self.is_multimodal + ) - scores: torch.Tensor = self.model.forward(**kwargs) + model_output = self.model.forward(model_input) + scores: torch.Tensor = model_output.logits scores = scores.unsqueeze(1).detach().cpu().float().numpy() next_token_id = 1 diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_mtp.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_mtp.py new file mode 100644 index 000000000..c0ac89458 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_mtp.py @@ -0,0 +1,184 @@ +import os +import torch +import numpy as np +from typing import List, Tuple +from transformers.configuration_utils import PretrainedConfig +from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend +from lightllm.server.router.model_infer.infer_batch import g_infer_context +from lightllm.server.router.model_infer.mode_backend.pre import ( + prepare_prefill_inputs, + prepare_decode_inputs, +) +from lightllm.server.router.model_infer.mode_backend.mtp_pre_process import ( + prepare_mtp_prefill_inputs, +) +from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample +from lightllm.server.router.model_infer.infer_batch import InferReq +from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end +from lightllm.utils.log_utils import init_logger +from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel +from lightllm.common.basemodel.infer_lock import g_infer_state_lock +from lightllm.common.basemodel.batch_objs import ModelOutput + + +logger = init_logger(__name__) + + +class ContinuesBatchWithMTPBackend(ModeBackend): + def __init__(self) -> None: + super().__init__() + + # 支持双模型 + def init_model(self, kvargs): + super().init_model(kvargs) + self._init_mtp_draft_model(kvargs) + return + + def _init_mtp_draft_model(self, main_kvargs: dict): + self.mtp_step = self.args.mtp_step + self.draft_models: List[Deepseek3MTPModel] = [] + + os.environ["DISABLE_CHECK_MAX_LEN_INFER"] = "1" + for i in range(self.mtp_step): + mtp_model_cfg, _ = PretrainedConfig.get_config_dict(self.args.mtp_draft_model_dir) + mtp_model_kvargs = { + "weight_dir": self.args.mtp_draft_model_dir, + "max_total_token_num": self.model.mem_manager.size, + "load_way": main_kvargs["load_way"], + "mode": main_kvargs["mode"], + "max_req_num": main_kvargs.get("max_req_num", 1000), + "max_seq_length": main_kvargs.get("max_seq_length", 1024 * 5), + "is_token_healing": False, + "return_all_prompt_logics": False, + "use_dynamic_prompt_cache": self.use_dynamic_prompt_cache, + "disable_chunked_prefill": self.disable_chunked_prefill, + "data_type": main_kvargs.get("data_type", "float16"), + "graph_max_batch_size": main_kvargs.get("graph_max_batch_size", 16), + "graph_max_len_in_batch": main_kvargs.get("graph_max_len_in_batch", 8196), + "disable_cudagraph": main_kvargs.get("disable_cudagraph", False), + "mem_fraction": main_kvargs["mem_fraction"], + "batch_max_tokens": main_kvargs.get("batch_max_tokens", None), + "quant_type": main_kvargs.get("quant_type", None), + "quant_cfg": main_kvargs.get("quant_cfg", None), + "run_mode": "normal", + "main_model": self.model, + "mem_layer_start": self.model.config["num_hidden_layers"] + i * mtp_model_cfg["num_hidden_layers"], + } + + mtp_model_cfg, _ = PretrainedConfig.get_config_dict(self.args.mtp_draft_model_dir) + assert mtp_model_cfg["model_type"] == "deepseek_v3" + assert 
mtp_model_cfg["architectures"][0] == "DeepseekV3ForCausalLMNextN" + self.draft_models.append(Deepseek3MTPModel(mtp_model_kvargs)) + + self.logger.info(f"loaded mtp model class {self.draft_models[i].__class__}") + + def prefill(self, reqs: List[Tuple]): + self._init_reqs(reqs, init_req_obj=False) + return + + def decode(self): + uninit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = self._get_classed_reqs( + g_infer_context.infer_req_ids + ) + + if aborted_reqs: + g_infer_context.filter_reqs(aborted_reqs) + + if prefill_reqs: + self.normal_mtp_prefill_reqs( + prefill_reqs=prefill_reqs, uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs + ) + + if decode_reqs: + self.normal_mtp_decode(decode_reqs=decode_reqs, uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs) + + self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) + return + + def normal_mtp_prefill_reqs( + self, prefill_reqs: List[InferReq], uninit_reqs: List[InferReq], ok_finished_reqs: List[InferReq] + ): + model_input, run_reqs = prepare_prefill_inputs( + prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal + ) + model_output = self.model.forward(model_input) + + self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) + + next_token_ids_gpu, next_token_probs = sample(model_output.logits, run_reqs, self.eos_id) + next_token_ids_cpu = next_token_ids_gpu.detach().cpu().numpy() + next_token_logprobs_cpu = torch.log(next_token_probs).detach().cpu().numpy() + + self._post_handle( + run_reqs, + next_token_ids_cpu, + next_token_logprobs_cpu, + is_chuncked_mode=not self.disable_chunked_prefill, + do_filter_finished_reqs=False, + ) + + # mtp kv fill + draft_next_token_ids_gpu = next_token_ids_gpu + draft_model_output = model_output + draft_model_input = model_input + # spec prefill: MTP, 这个地方只是为了填充draft model的 kv, 并不会使用生成的token_id。 + for draft_model_idx in range(self.mtp_step): + draft_model_input = prepare_mtp_prefill_inputs( + model_input=draft_model_input, + b_next_token_ids=draft_next_token_ids_gpu, + deepseekv3_mtp_draft_input_hiddens=draft_model_output.deepseekv3_mtp_main_output_hiddens, + ) + + draft_model_output = self.draft_models[draft_model_idx].forward(draft_model_input) + draft_next_token_ids_gpu, draft_next_token_ids_cpu = self._gen_argmax_token_ids(draft_model_output) + return + + def normal_mtp_decode( + self, decode_reqs: List[InferReq], uninit_reqs: List[InferReq], ok_finished_reqs: List[InferReq] + ): + model_input, run_reqs = prepare_decode_inputs(decode_reqs) + model_output = self.model.forward(model_input) + + self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) + + next_token_ids_gpu, next_token_probs = sample(model_output.logits, run_reqs, self.eos_id) + next_token_ids_cpu = next_token_ids_gpu.detach().cpu().numpy() + next_token_logprobs_cpu = torch.log(next_token_probs).detach().cpu().numpy() + + # verify + mem_indexes_cpu = model_input.mem_indexes.detach().cpu().numpy() + verify_ok_reqs, verify_ok_req_indexes, verify_ok_req_last_indexes, need_free_mem_indexes = self._verify_mtp( + run_reqs, next_token_ids_cpu, mem_indexes_cpu + ) + + self._post_handle( + verify_ok_reqs, + next_token_ids_cpu[verify_ok_req_indexes], + next_token_logprobs_cpu[verify_ok_req_indexes], + is_chuncked_mode=False, + do_filter_finished_reqs=False, + ) + + # share some inference info with the main 
model + draft_model_input = model_input + draft_model_output = model_output + draft_next_token_ids = next_token_ids_gpu + # process the draft model output + for draft_model_idx in range(self.mtp_step): + + draft_model_input.input_ids = draft_next_token_ids + draft_model_input.deepseekv3_mtp_draft_input_hiddens = draft_model_output.deepseekv3_mtp_main_output_hiddens + # spec decode: MTP + draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input) + draft_next_token_ids, draft_next_token_ids_cpu = self._gen_argmax_token_ids(draft_model_output) + + unique_reqs = [run_reqs[index] for index in verify_ok_req_last_indexes] + self._update_reqs_mtp_gen_token_ids( + reqs=unique_reqs, mtp_draft_next_token_ids=draft_next_token_ids_cpu[verify_ok_req_last_indexes] + ) + + if need_free_mem_indexes: + g_infer_state_lock.acquire() + g_infer_context.req_manager.mem_manager.free(need_free_mem_indexes) + g_infer_state_lock.release() + return diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py index d75f44f20..81d3de6a2 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py @@ -11,7 +11,7 @@ from lightllm.server.core.objs import FinishStatus from lightllm.server.pd_io_struct import UpKVStatus from lightllm.utils.log_utils import init_logger -from lightllm.server.router.model_infer.mode_backend.generic_pre_process import prepare_decode_inputs +from lightllm.server.router.model_infer.mode_backend.pre import prepare_decode_inputs from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from .up_status import UpStatusManager from rpyc.utils.server import ThreadedServer @@ -69,8 +69,9 @@ def decode(self): if decode_reqs: - kwargs, run_reqs = prepare_decode_inputs(decode_reqs) - logits = self.model.forward(**kwargs) + model_input, run_reqs = prepare_decode_inputs(decode_reqs) + model_output = self.model.forward(model_input) + logits = model_output.logits self._overlap_req_init_and_filter( uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_for_dp.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_for_dp.py index 614bc9bc5..dd6afb034 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_for_dp.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_for_dp.py @@ -1,24 +1,13 @@ -import os import torch import torch.multiprocessing as mp import torch.distributed as dist -import threading -from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend from typing import List, Tuple -from lightllm.utils.infer_utils import set_random_seed -from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end -from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq, InferSamplingParams -from lightllm.server.core.objs import FinishStatus -from lightllm.server.pd_io_struct import UpKVStatus +from lightllm.server.router.model_infer.infer_batch import g_infer_context, 
InferReq from lightllm.utils.log_utils import init_logger -from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample -from .up_status import UpStatusManager -from rpyc.utils.server import ThreadedServer -from lightllm.common.basemodel.infer_lock import g_infer_state_lock -from .decode_task_cache import g_success_kv_move_task_cache, KVMoveTask -from lightllm.utils.device_utils import kv_trans_use_p2p +from lightllm.server.router.model_infer.mode_backend.pre import padded_prepare_prefill_inputs from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args from .decode_impl import ContinuesBatchBackendForDecodeNode +from lightllm.server.router.model_infer.mode_backend.dp_backend.impl import DPChunkedPrefillBackend logger = init_logger(__name__) @@ -32,16 +21,6 @@ def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue) -> None: def init_custom(self): super().init_custom() self.reduce_tensor = torch.tensor([0], dtype=torch.int32, device="cuda", requires_grad=False) - # 这个地方预先进行一次 prefill 推理,主要是为了填充后续fake请求的第一个token位置,因为填充的decode请求 - # 在推理的时候至少是两个token,1个是已经有kv的token,一个是等待计算kv的token,然后生成第三个token,这几个 - # token 实际引用的都是 g_infer_context.req_manager.mem_manager.HOLD_TOKEN_MEMINDEX,但是需要初始化排除 - # nan 值,避免后续构建的fake请求在计算的过程中出现计算错误。 - from lightllm.server.router.model_infer.mode_backend.dp_backend.pre_process import padded_prepare_prefill_inputs - - kwargs, run_reqs, padded_req_num = padded_prepare_prefill_inputs([], 1, is_multimodal=self.is_multimodal) - self.model.forward(**kwargs) - assert len(run_reqs) == 0 and padded_req_num == 1 - return def prefill(self, reqs: List[Tuple]): @@ -61,59 +40,9 @@ def decode(self): max_decode_num = self.reduce_tensor.item() if max_decode_num != 0: if not self.enable_decode_microbatch_overlap: - self.normal_decode(decode_reqs, max_decode_num, uninit_reqs, ok_finished_reqs) + DPChunkedPrefillBackend.normal_decode(self, decode_reqs, max_decode_num, uninit_reqs, ok_finished_reqs) else: - self.overlap_decode(decode_reqs, max_decode_num, uninit_reqs, ok_finished_reqs) - - self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) - return - - def normal_decode(self, decode_reqs: List[InferReq], max_decode_num: int, uninit_reqs, ok_finished_reqs): - from lightllm.server.router.model_infer.mode_backend.dp_backend.pre_process import padded_prepare_decode_inputs + DPChunkedPrefillBackend.overlap_decode(self, decode_reqs, max_decode_num, uninit_reqs, ok_finished_reqs) - kwargs, run_reqs, padded_req_num = padded_prepare_decode_inputs( - decode_reqs, max_decode_num, is_multimodal=self.is_multimodal - ) - logits = self.model.forward(**kwargs) self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) - if len(run_reqs) != 0: - logits = logits[0 : len(run_reqs), :] - next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) - next_token_ids = next_token_ids.detach().cpu().numpy() - next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() - self._post_handle( - run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False - ) - return - - def overlap_decode(self, decode_reqs: List[InferReq], max_decode_num: int, uninit_reqs, ok_finished_reqs): - from lightllm.server.router.model_infer.mode_backend.dp_backend.pre_process import ( - padded_overlap_prepare_decode_inputs, - ) - - ( - micro_batch, - run_reqs, - padded_req_num, - micro_batch1, - run_reqs1, - 
padded_req_num1, - ) = padded_overlap_prepare_decode_inputs(decode_reqs, max_decode_num, is_multimodal=self.is_multimodal) - - logits, logits1 = self.model.microbatch_overlap_decode(micro_batch, micro_batch1) - self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) - req_num, req_num1 = len(run_reqs), len(run_reqs1) - all_logits = torch.empty((req_num + req_num1, logits.shape[1]), dtype=logits.dtype, device=logits.device) - - all_logits[0:req_num, :].copy_(logits[0:req_num, :], non_blocking=True) - all_logits[req_num : (req_num + req_num1), :].copy_(logits1[0:req_num1, :], non_blocking=True) - - all_run_reqs = run_reqs + run_reqs1 - if all_run_reqs: - next_token_ids, next_token_probs = sample(all_logits, all_run_reqs, self.eos_id) - next_token_ids = next_token_ids.detach().cpu().numpy() - next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() - self._post_handle( - all_run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False - ) return diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_mtp.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_mtp.py new file mode 100644 index 000000000..5d7c461cc --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_mtp.py @@ -0,0 +1,36 @@ +import torch.multiprocessing as mp +from lightllm.server.router.model_infer.infer_batch import g_infer_context +from .decode_impl import ContinuesBatchBackendForDecodeNode +from ...impl_mtp import ContinuesBatchWithMTPBackend +from lightllm.utils.log_utils import init_logger + + +logger = init_logger(__name__) + + +class ContinuesBatchBackendForMtpDecodeNode(ContinuesBatchBackendForDecodeNode): + def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue) -> None: + super().__init__(info_queue=info_queue, mem_queue=mem_queue) + + def init_model(self, kvargs): + super().init_model(kvargs) + ContinuesBatchWithMTPBackend._init_mtp_draft_model(self, kvargs) + return + + def decode(self): + uninit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = self._get_classed_reqs( + g_infer_context.infer_req_ids, + no_decode=False, + ) + # p d 分离模式下, decode 节点不可能存在需要prefill操作的请求 + assert len(prefill_reqs) == 0 + + self._filter_reqs(aborted_reqs) + + if decode_reqs: + ContinuesBatchWithMTPBackend.normal_mtp_decode( + self, decode_reqs=decode_reqs, uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs + ) + + self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) + return diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_mtp_for_dp.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_mtp_for_dp.py new file mode 100644 index 000000000..61b7e07f6 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_mtp_for_dp.py @@ -0,0 +1,43 @@ +import torch.multiprocessing as mp +import torch.distributed as dist +from lightllm.server.router.model_infer.infer_batch import g_infer_context +from lightllm.utils.log_utils import init_logger +from .decode_impl_for_dp import DPForDecodeNode +from ....dp_backend.impl_mtp import DPChunkedPrefillWithMTPBackend + +logger = init_logger(__name__) + + +class 
DPForMtpDecodeNode(DPForDecodeNode): + def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue) -> None: + super().__init__(info_queue, mem_queue) + return + + def init_model(self, kvargs): + super().init_model(kvargs) + DPChunkedPrefillWithMTPBackend._init_mtp_draft_model(self, kvargs) + return + + def decode(self): + uninit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = self._get_classed_reqs( + g_infer_context.infer_req_ids + ) + assert len(prefill_reqs) == 0 + + self._filter_reqs(aborted_reqs) + + self.reduce_tensor.fill_(len(decode_reqs)) + dist.all_reduce(self.reduce_tensor, op=dist.ReduceOp.MAX, group=None, async_op=False) + max_decode_num = self.reduce_tensor.item() + if max_decode_num != 0: + if not self.enable_decode_microbatch_overlap: + DPChunkedPrefillWithMTPBackend.normal_mtp_decode( + self, decode_reqs, max_decode_num, uninit_reqs, ok_finished_reqs + ) + else: + DPChunkedPrefillWithMTPBackend.overlap_mtp_decode( + self, decode_reqs, max_decode_num, uninit_reqs, ok_finished_reqs + ) + + self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) + return diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py index 2652f4a33..a8084f2a1 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py @@ -72,11 +72,12 @@ def decode(self): self._filter_reqs(ok_finished_reqs) if prefill_reqs: - kwargs, run_reqs = prepare_prefill_inputs( + model_input, run_reqs = prepare_prefill_inputs( prefill_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal ) - logits = self.model.forward(**kwargs) + model_output = self.model.forward(model_input) + logits = model_output.logits next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) next_token_ids = next_token_ids.detach().cpu().numpy() next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp_chuncked.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp_chuncked.py index a728b6b03..84777711d 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp_chuncked.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp_chuncked.py @@ -1,22 +1,9 @@ -import os -import time -import threading import torch import torch.multiprocessing as mp import torch.distributed as dist from typing import List, Tuple -from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend -from lightllm.utils.infer_utils import set_random_seed -from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end from lightllm.server.router.model_infer.infer_batch import InferReq, InferSamplingParams, g_infer_context -from lightllm.server.core.objs import FinishStatus -from lightllm.server.pd_io_struct import KVMoveTask, DecodeNodeInfo from lightllm.utils.log_utils import init_logger -from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample -from 
lightllm.common.basemodel.infer_lock import g_infer_state_lock -from rpyc.utils.server import ThreadedServer -from .prefill_task_cache import g_kv_move_task_cache -from lightllm.utils.device_utils import kv_trans_use_p2p from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args from .prefill_impl import ChunckedPrefillForPrefillNode from lightllm.server.router.model_infer.mode_backend.dp_backend.impl import DPChunkedPrefillBackend @@ -60,58 +47,13 @@ def decode(self): max_prefill_num = self.reduce_tensor.item() if max_prefill_num != 0: if not self.enable_prefill_microbatch_overlap: - self.normal_prefill_reqs(prefill_reqs, max_prefill_num, uninit_reqs, ok_finished_reqs) + DPChunkedPrefillBackend.normal_prefill_reqs( + self, prefill_reqs, max_prefill_num, uninit_reqs, ok_finished_reqs + ) else: - self.overlap_prefill_reqs(prefill_reqs, max_prefill_num, uninit_reqs, ok_finished_reqs) + DPChunkedPrefillBackend.overlap_prefill_reqs( + self, prefill_reqs, max_prefill_num, uninit_reqs, ok_finished_reqs + ) self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) return - - def normal_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: int, uninit_reqs, ok_finished_reqs): - from lightllm.server.router.model_infer.mode_backend.dp_backend.pre_process import ( - padded_prepare_prefill_inputs, - ) - - kwargs, run_reqs, padded_req_num = padded_prepare_prefill_inputs( - prefill_reqs, max_prefill_num, is_multimodal=self.is_multimodal - ) - logits = self.model.forward(**kwargs) - self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) - if len(run_reqs) != 0: - logits = logits[0 : len(run_reqs), :] - next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) - next_token_ids = next_token_ids.detach().cpu().numpy() - next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() - self._post_handle( - run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False - ) - - def overlap_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: int, uninit_reqs, ok_finished_reqs): - from lightllm.server.router.model_infer.mode_backend.dp_backend.pre_process import ( - padded_overlap_prepare_prefill_inputs, - ) - - ( - micro_batch, - run_reqs, - padded_req_num, - micro_batch1, - run_reqs1, - padded_req_num1, - ) = padded_overlap_prepare_prefill_inputs(prefill_reqs, max_prefill_num, is_multimodal=self.is_multimodal) - logits, logits1 = self.model.microbatch_overlap_prefill(micro_batch, micro_batch1) - self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) - req_num, req_num1 = len(run_reqs), len(run_reqs1) - all_logits = torch.empty((req_num + req_num1, logits.shape[1]), dtype=logits.dtype, device=logits.device) - - all_logits[0:req_num, :].copy_(logits[0:req_num, :], non_blocking=True) - all_logits[req_num : (req_num + req_num1), :].copy_(logits1[0:req_num1, :], non_blocking=True) - - all_run_reqs = run_reqs + run_reqs1 - if all_run_reqs: - next_token_ids, next_token_probs = sample(all_logits, all_run_reqs, self.eos_id) - next_token_ids = next_token_ids.detach().cpu().numpy() - next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() - self._post_handle( - all_run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False - ) diff --git 
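
Note on the ModelInput / ModelOutput migration seen in the hunks above (prefill_impl.py and the other backends in this patch): the old kwargs dict passed to self.model.forward(**kwargs) is replaced by container objects from lightllm.common.basemodel.batch_objs. That module is not included in this section, so the sketch below only reconstructs the two containers from the fields the call sites actually touch; the defaults, field order, and any members not used here are assumptions, not the real definition.

    from dataclasses import dataclass
    from typing import Optional
    import torch

    @dataclass
    class ModelInput:
        # core batch description consumed by TpPartBaseModel.forward
        batch_size: int
        total_token_num: int
        max_len_in_batch: int
        input_ids: torch.Tensor
        mem_indexes: torch.Tensor
        b_req_idx: torch.Tensor
        b_seq_len: torch.Tensor
        is_prefill: bool = False
        b_ready_cache_len: Optional[torch.Tensor] = None      # prefill only
        multimodal_params: Optional[list] = None
        # extra hand-off used by the deepseekv3 MTP mode in this patch
        deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None

    @dataclass
    class ModelOutput:
        logits: torch.Tensor
        # hidden states the main model exposes to the MTP draft model
        deepseekv3_mtp_main_output_hiddens: Optional[torch.Tensor] = None

Every backend in the patch then follows the same call pattern: model_output = self.model.forward(model_input); logits = model_output.logits.
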
a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_mtp.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_mtp.py new file mode 100644 index 000000000..ca740b477 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_mtp.py @@ -0,0 +1,39 @@ +import torch.multiprocessing as mp +from lightllm.server.router.model_infer.infer_batch import g_infer_context +from .prefill_impl import ChunckedPrefillForPrefillNode +from ...impl_mtp import ContinuesBatchWithMTPBackend +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class ChunckedPrefillForMtpPrefillNode(ChunckedPrefillForPrefillNode): + def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue) -> None: + super().__init__(info_queue=info_queue, mem_queue=mem_queue) + return + + def init_model(self, kvargs): + super().init_model(kvargs) + ContinuesBatchWithMTPBackend._init_mtp_draft_model(self, kvargs) + return + + def decode(self): + uinit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = self._get_classed_reqs( + g_infer_context.infer_req_ids, + no_decode=True, + ) + assert len(uinit_reqs) == 0 + assert len(decode_reqs) == 0 + + self._filter_reqs(aborted_reqs) + + if ok_finished_reqs: + self.prefill_req_frozen_tokens_and_put_to_kvmove_taskqueue(ok_finished_reqs) + self._filter_reqs(ok_finished_reqs) + ok_finished_reqs.clear() + + if prefill_reqs: + ContinuesBatchWithMTPBackend.normal_mtp_prefill_reqs( + self, prefill_reqs=prefill_reqs, uninit_reqs=uinit_reqs, ok_finished_reqs=ok_finished_reqs + ) + return diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_mtp_for_dp_chuncked.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_mtp_for_dp_chuncked.py new file mode 100644 index 000000000..3f53070f6 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_mtp_for_dp_chuncked.py @@ -0,0 +1,61 @@ +import torch.multiprocessing as mp +import torch.distributed as dist +from lightllm.server.router.model_infer.infer_batch import g_infer_context +from lightllm.utils.log_utils import init_logger +from .prefill_impl_for_dp_chuncked import DPChunkedForPrefillNode +from ....dp_backend.impl_mtp import DPChunkedPrefillWithMTPBackend + +logger = init_logger(__name__) + + +class DPChunkedForMtpPrefillNode(DPChunkedForPrefillNode): + def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue) -> None: + super().__init__(info_queue=info_queue, mem_queue=mem_queue) + return + + def init_model(self, kvargs): + super().init_model(kvargs) + DPChunkedPrefillWithMTPBackend._init_mtp_draft_model(self, kvargs) + return + + def decode(self): + uninit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = self._get_classed_reqs( + g_infer_context.infer_req_ids, + no_decode=True, + ) + assert len(uninit_reqs) == 0 + assert len(decode_reqs) == 0 + + self._filter_reqs(aborted_reqs) + + if ok_finished_reqs: + self.prefill_req_frozen_tokens_and_put_to_kvmove_taskqueue(ok_finished_reqs) + self._filter_reqs(ok_finished_reqs) + ok_finished_reqs.clear() + + # 进行 chuncked prefill + current_dp_prefill_num = len(prefill_reqs) + self.reduce_tensor.fill_(current_dp_prefill_num) + dist.all_reduce(self.reduce_tensor, op=dist.ReduceOp.MAX, group=None, 
async_op=False) + max_prefill_num = self.reduce_tensor.item() + + if max_prefill_num != 0: + if not self.enable_prefill_microbatch_overlap: + DPChunkedPrefillWithMTPBackend.normal_mtp_prefill_reqs( + self, + prefill_reqs=prefill_reqs, + max_prefill_num=max_prefill_num, + uninit_reqs=uninit_reqs, + ok_finished_reqs=ok_finished_reqs, + ) + else: + DPChunkedPrefillWithMTPBackend.overlap_mtp_prefill_reqs( + self, + prefill_reqs=prefill_reqs, + max_prefill_num=max_prefill_num, + uninit_reqs=uninit_reqs, + ok_finished_reqs=ok_finished_reqs, + ) + + self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) + return diff --git a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py index 18b73ee23..83735fd7f 100644 --- a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py @@ -10,7 +10,7 @@ from lightllm.utils.log_utils import init_logger from lightllm.server.tokenizer import get_tokenizer from lightllm.server.req_id_generator import convert_sub_id_to_group_id -from lightllm.server.router.model_infer.mode_backend.generic_pre_process import ( +from lightllm.server.router.model_infer.mode_backend.pre import ( prepare_prefill_inputs, prepare_decode_inputs, ) @@ -69,10 +69,11 @@ def decode(self): for req in prefill_reqs if convert_sub_id_to_group_id(req.req_id) == req.req_id ] - kwargs, group_run_reqs = prepare_prefill_inputs( + model_input, group_run_reqs = prepare_prefill_inputs( group_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal ) - logits = self.model.forward(**kwargs) + model_output = self.model.forward(model_input) + logits = model_output.logits uninit_req_ids = [req.req_id for req in uninit_reqs] self._overlap_req_init_and_filter( @@ -90,8 +91,9 @@ def decode(self): ) if decode_reqs: - kwargs, run_reqs = prepare_decode_inputs(decode_reqs) - logits = self.model.forward(**kwargs) + model_input, run_reqs = prepare_decode_inputs(decode_reqs) + model_output = self.model.forward(model_input) + logits = model_output.logits uninit_req_ids = [req.req_id for req in uninit_reqs] self._overlap_req_init_and_filter( uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index c9ecafdb2..376ed1501 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -3,6 +3,7 @@ import numpy as np from typing import List, Tuple from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend +from lightllm.common.basemodel.batch_objs import ModelOutput from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq, InferSamplingParams @@ -10,6 +11,10 @@ from lightllm.utils.log_utils import init_logger from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from lightllm.utils.envs_utils import get_env_start_args +from lightllm.server.router.model_infer.mode_backend.pre import padded_prepare_prefill_inputs +from lightllm.server.router.model_infer.mode_backend.pre import 
padded_overlap_prepare_prefill_inputs +from lightllm.server.router.model_infer.mode_backend.pre import padded_prepare_decode_inputs +from lightllm.server.router.model_infer.mode_backend.pre import padded_overlap_prepare_decode_inputs class DPChunkedPrefillBackend(ModeBackend): @@ -25,15 +30,6 @@ def __init__(self) -> None: def init_custom(self): self.reduce_tensor = torch.tensor([0], dtype=torch.int32, device="cuda", requires_grad=False) - # 这个地方预先进行一次 prefill 推理,主要是为了填充后续fake请求的第一个token位置,因为填充的decode请求 - # 在推理的时候至少是两个token,1个是已经有kv的token,一个是等待计算kv的token,然后生成第三个token,这几个 - # token 实际引用的都是 g_infer_context.req_manager.mem_manager.HOLD_TOKEN_MEMINDEX,但是需要初始化排除 - # nan 值,避免后续构建的fake请求在计算的过程中出现计算错误。 - from .pre_process import padded_prepare_prefill_inputs - - kwargs, run_reqs, padded_req_num = padded_prepare_prefill_inputs([], 1, is_multimodal=self.is_multimodal) - self.model.forward(**kwargs) - assert len(run_reqs) == 0 and padded_req_num == 1 return def prefill(self, reqs: List[Tuple]): @@ -71,12 +67,11 @@ def decode(self): return def normal_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: int, uninit_reqs, ok_finished_reqs): - from .pre_process import padded_prepare_prefill_inputs - - kwargs, run_reqs, padded_req_num = padded_prepare_prefill_inputs( - prefill_reqs, max_prefill_num, is_multimodal=self.is_multimodal + model_input, run_reqs, padded_req_num = padded_prepare_prefill_inputs( + prefill_reqs, is_multimodal=self.is_multimodal ) - logits = self.model.forward(**kwargs) + model_output: ModelOutput = self.model.forward(model_input) + logits = model_output.logits self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) if len(run_reqs) != 0: logits = logits[0 : len(run_reqs), :] @@ -89,12 +84,11 @@ def normal_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: int return def normal_decode(self, decode_reqs: List[InferReq], max_decode_num: int, uninit_reqs, ok_finished_reqs): - from .pre_process import padded_prepare_decode_inputs - - kwargs, run_reqs, padded_req_num = padded_prepare_decode_inputs( - decode_reqs, max_decode_num, is_multimodal=self.is_multimodal + model_input, run_reqs, padded_req_num = padded_prepare_decode_inputs( + decode_reqs, is_multimodal=self.is_multimodal ) - logits = self.model.forward(**kwargs) + model_output: ModelOutput = self.model.forward(model_input) + logits = model_output.logits self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) @@ -109,17 +103,17 @@ def normal_decode(self, decode_reqs: List[InferReq], max_decode_num: int, uninit logits = None def overlap_decode(self, decode_reqs: List[InferReq], max_decode_num: int, uninit_reqs, ok_finished_reqs): - from .pre_process import padded_overlap_prepare_decode_inputs - ( - micro_batch, + micro_input, run_reqs, padded_req_num, - micro_batch1, + micro_input1, run_reqs1, padded_req_num1, - ) = padded_overlap_prepare_decode_inputs(decode_reqs, max_decode_num, is_multimodal=self.is_multimodal) - logits, logits1 = self.model.microbatch_overlap_decode(micro_batch, micro_batch1) + ) = padded_overlap_prepare_decode_inputs(decode_reqs, is_multimodal=self.is_multimodal) + model_output, model_output1 = self.model.microbatch_overlap_decode(micro_input, micro_input1) + logits = model_output.logits + logits1 = model_output1.logits self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) req_num, req_num1 = len(run_reqs), 
len(run_reqs1) all_logits = torch.empty((req_num + req_num1, logits.shape[1]), dtype=logits.dtype, device=logits.device) @@ -138,17 +132,17 @@ def overlap_decode(self, decode_reqs: List[InferReq], max_decode_num: int, unini return def overlap_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: int, uninit_reqs, ok_finished_reqs): - from .pre_process import padded_overlap_prepare_prefill_inputs - ( - micro_batch, + micro_input, run_reqs, padded_req_num, - micro_batch1, + micro_input1, run_reqs1, padded_req_num1, - ) = padded_overlap_prepare_prefill_inputs(prefill_reqs, max_prefill_num, is_multimodal=self.is_multimodal) - logits, logits1 = self.model.microbatch_overlap_prefill(micro_batch, micro_batch1) + ) = padded_overlap_prepare_prefill_inputs(prefill_reqs, is_multimodal=self.is_multimodal) + model_output, model_output1 = self.model.microbatch_overlap_prefill(micro_input, micro_input1) + logits = model_output.logits + logits1 = model_output1.logits self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) req_num, req_num1 = len(run_reqs), len(run_reqs1) all_logits = torch.empty((req_num + req_num1, logits.shape[1]), dtype=logits.dtype, device=logits.device) diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl_mtp.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl_mtp.py new file mode 100644 index 000000000..0b8294f46 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl_mtp.py @@ -0,0 +1,330 @@ +import torch +import torch.distributed as dist +import numpy as np +from typing import List, Tuple +from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq +from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.server.router.model_infer.mode_backend.continues_batch.impl_mtp import ContinuesBatchWithMTPBackend +from lightllm.server.router.model_infer.mode_backend.pre import padded_prepare_prefill_inputs +from lightllm.server.router.model_infer.mode_backend.pre import padded_overlap_prepare_prefill_inputs +from lightllm.server.router.model_infer.mode_backend.pre import padded_prepare_decode_inputs +from lightllm.server.router.model_infer.mode_backend.pre import padded_overlap_prepare_decode_inputs +from lightllm.server.router.model_infer.mode_backend.mtp_pre_process import ( + prepare_mtp_prefill_inputs, +) +from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput +from lightllm.common.basemodel.infer_lock import g_infer_state_lock + + +class DPChunkedPrefillWithMTPBackend(ContinuesBatchWithMTPBackend): + def __init__(self) -> None: + super().__init__() + self.enable_decode_microbatch_overlap = get_env_start_args().enable_decode_microbatch_overlap + self.enable_prefill_microbatch_overlap = get_env_start_args().enable_prefill_microbatch_overlap + pass + + def init_custom(self): + self.reduce_tensor = torch.tensor([0], dtype=torch.int32, device="cuda", requires_grad=False) + return + + def prefill(self, reqs: List[Tuple]): + self._init_reqs(reqs, init_req_obj=False) + return + + def decode(self): + uninit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = self._get_classed_reqs( + g_infer_context.infer_req_ids + ) + + if aborted_reqs: + g_infer_context.filter_reqs(aborted_reqs) + + current_dp_prefill_num = len(prefill_reqs) + self.reduce_tensor.fill_(current_dp_prefill_num) + 
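
In the DP backends above, each rank first writes its local count of runnable requests into self.reduce_tensor, and the dist.all_reduce call that continues just below takes the maximum across the data-parallel ranks. If that maximum is non-zero, every rank must execute the step, and ranks with fewer (or zero) requests pad their batch with fake requests so the collective kernels stay in lock-step. A minimal sketch of the synchronization pattern, with a hypothetical helper name and assuming torch.distributed is already initialized (the patch uses the default group, i.e. group=None):

    import torch
    import torch.distributed as dist

    def sync_step_size(local_num: int, device: str = "cuda") -> int:
        """Return the max number of runnable requests across all DP ranks.

        Every rank calls this with its own count; if the result is non-zero,
        all ranks run the step and pad up to a common batch size with fake
        requests (see the padded_prepare_*_inputs helpers in this patch).
        """
        buf = torch.tensor([local_num], dtype=torch.int32, device=device)
        dist.all_reduce(buf, op=dist.ReduceOp.MAX)
        return int(buf.item())

    # usage inside a decode loop (hypothetical):
    # max_decode_num = sync_step_size(len(decode_reqs))
    # if max_decode_num != 0:
    #     run_decode(decode_reqs)   # padded internally when this rank has none
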
dist.all_reduce(self.reduce_tensor, op=dist.ReduceOp.MAX, group=None, async_op=False) + max_prefill_num = self.reduce_tensor.item() + if max_prefill_num != 0: + if not self.enable_prefill_microbatch_overlap: + self.normal_mtp_prefill_reqs(prefill_reqs, max_prefill_num, uninit_reqs, ok_finished_reqs) + else: + self.overlap_mtp_prefill_reqs(prefill_reqs, max_prefill_num, uninit_reqs, ok_finished_reqs) + + self.reduce_tensor.fill_(len(decode_reqs)) + dist.all_reduce(self.reduce_tensor, op=dist.ReduceOp.MAX, group=None, async_op=False) + max_decode_num = self.reduce_tensor.item() + if max_decode_num != 0: + if not self.enable_decode_microbatch_overlap: + self.normal_mtp_decode(decode_reqs, max_decode_num, uninit_reqs, ok_finished_reqs) + else: + self.overlap_mtp_decode(decode_reqs, max_decode_num, uninit_reqs, ok_finished_reqs) + + self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) + return + + def normal_mtp_prefill_reqs( + self, prefill_reqs: List[InferReq], max_prefill_num: int, uninit_reqs, ok_finished_reqs + ): + # main model prefill + model_input, run_reqs, padded_req_num = padded_prepare_prefill_inputs( + prefill_reqs, is_multimodal=self.is_multimodal + ) + model_output: ModelOutput = self.model.forward(model_input) + + self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) + next_token_ids_cpu = [] + + if len(run_reqs) != 0: + next_token_ids_gpu, next_token_probs = sample(model_output.logits[: len(run_reqs)], run_reqs, self.eos_id) + next_token_ids_cpu = next_token_ids_gpu.detach().cpu().numpy() + next_token_logprobs_cpu = torch.log(next_token_probs).detach().cpu().numpy() + + self._post_handle( + run_reqs, + next_token_ids_cpu, + next_token_logprobs_cpu, + is_chuncked_mode=True, + do_filter_finished_reqs=False, + ) + + # fill mtp draft model prefill kv + # 因为存在padding的请求,需要将padding的请求一并考虑同时进行推理。 + draft_model_input = model_input + draft_next_token_ids_gpu = torch.zeros((model_input.batch_size), dtype=torch.int64, device="cuda") + if len(run_reqs) != 0: + draft_next_token_ids_gpu[0 : len(run_reqs)].copy_(next_token_ids_gpu) + + draft_model_output = model_output + + for draft_model_idx in range(self.mtp_step): + draft_model_input = prepare_mtp_prefill_inputs( + model_input=draft_model_input, + b_next_token_ids=draft_next_token_ids_gpu, + deepseekv3_mtp_draft_input_hiddens=draft_model_output.deepseekv3_mtp_main_output_hiddens, + ) + + draft_model_output = self.draft_models[draft_model_idx].forward(draft_model_input) + draft_next_token_ids_gpu, _ = self._gen_argmax_token_ids(draft_model_output) + return + + def normal_mtp_decode(self, decode_reqs: List[InferReq], max_decode_num: int, uninit_reqs, ok_finished_reqs): + model_input, run_reqs, padded_req_num = padded_prepare_decode_inputs( + decode_reqs, is_multimodal=self.is_multimodal + ) + # main model decode + model_output = self.model.forward(model_input) + + self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) + + need_free_mem_indexes = [] + verify_ok_req_last_indexes = [] + if len(run_reqs) != 0: + next_token_ids_gpu, next_token_probs = sample(model_output.logits[: len(run_reqs)], run_reqs, self.eos_id) + next_token_ids_cpu = next_token_ids_gpu.detach().cpu().numpy() + next_token_logprobs_cpu = torch.log(next_token_probs).detach().cpu().numpy() + + # verify + mem_indexes_cpu = model_input.mem_indexes[0 : len(run_reqs)].cpu().numpy() + verify_ok_reqs, 
verify_ok_req_indexes, verify_ok_req_last_indexes, need_free_mem_indexes = self._verify_mtp( + run_reqs, next_token_ids_cpu, mem_indexes_cpu + ) + + self._post_handle( + verify_ok_reqs, + next_token_ids_cpu[verify_ok_req_indexes], + next_token_logprobs_cpu[verify_ok_req_indexes], + is_chuncked_mode=False, + do_filter_finished_reqs=False, + ) + + # fill draft model kv and gen next mtp token ids. + draft_model_input = model_input + draft_model_output = model_output + draft_next_token_ids_gpu = torch.zeros((model_input.batch_size), dtype=torch.int64, device="cuda") + if len(run_reqs) != 0: + draft_next_token_ids_gpu[0 : len(run_reqs)].copy_(next_token_ids_gpu) + + # process the draft model output + for draft_model_idx in range(self.mtp_step): + + draft_model_input.input_ids = draft_next_token_ids_gpu + draft_model_input.deepseekv3_mtp_draft_input_hiddens = draft_model_output.deepseekv3_mtp_main_output_hiddens + # spec decode: MTP + draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input) + draft_next_token_ids_gpu, draft_next_token_ids_cpu = self._gen_argmax_token_ids(draft_model_output) + + if verify_ok_req_last_indexes: + unique_reqs = [run_reqs[index] for index in verify_ok_req_last_indexes] + self._update_reqs_mtp_gen_token_ids( + reqs=unique_reqs, mtp_draft_next_token_ids=draft_next_token_ids_cpu[verify_ok_req_last_indexes] + ) + + if need_free_mem_indexes: + g_infer_state_lock.acquire() + g_infer_context.req_manager.mem_manager.free(need_free_mem_indexes) + g_infer_state_lock.release() + return + + def overlap_mtp_decode(self, decode_reqs: List[InferReq], max_decode_num: int, uninit_reqs, ok_finished_reqs): + ( + micro_input, + run_reqs, + padded_req_num, + micro_input1, + run_reqs1, + padded_req_num1, + ) = padded_overlap_prepare_decode_inputs(decode_reqs, is_multimodal=self.is_multimodal) + micro_output, micro_output1 = self.model.microbatch_overlap_decode(micro_input, micro_input1) + + self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) + + req_num, req_num1 = len(run_reqs), len(run_reqs1) + all_run_reqs = run_reqs + run_reqs1 + need_free_mem_indexes = [] + verify_ok_req_last_indexes = [] + if len(all_run_reqs) != 0: + all_logits = torch.empty( + (req_num + req_num1, micro_output.logits.shape[1]), + dtype=micro_output.logits.dtype, + device=micro_output.logits.device, + ) + + all_logits[0:req_num, :].copy_(micro_output.logits[0:req_num, :], non_blocking=True) + all_logits[req_num : (req_num + req_num1), :].copy_(micro_output1.logits[0:req_num1, :], non_blocking=True) + + next_token_ids_gpu, next_token_probs = sample(all_logits, all_run_reqs, self.eos_id) + next_token_ids_cpu = next_token_ids_gpu.detach().cpu().numpy() + next_token_logprobs_cpu = torch.log(next_token_probs).detach().cpu().numpy() + micro_mem_indexes_cpu = micro_input.mem_indexes[0:req_num].cpu() + micro_mem_indexes_cpu1 = micro_input1.mem_indexes[0:req_num1].cpu() + mem_indexes_cpu = torch.cat((micro_mem_indexes_cpu, micro_mem_indexes_cpu1), dim=0).numpy() + + # verify + verify_ok_reqs, verify_ok_req_indexes, verify_ok_req_last_indexes, need_free_mem_indexes = self._verify_mtp( + all_run_reqs, next_token_ids_cpu, mem_indexes_cpu + ) + + self._post_handle( + verify_ok_reqs, + next_token_ids_cpu[verify_ok_req_indexes], + next_token_logprobs_cpu[verify_ok_req_indexes], + is_chuncked_mode=False, + do_filter_finished_reqs=False, + ) + + # share some inference info with the main model + draft_micro_input, draft_micro_input1 = 
micro_input, micro_input1 + + draft_next_token_ids_gpu = torch.zeros((micro_input.batch_size), dtype=torch.int64, device="cuda") + draft_next_token_ids_gpu1 = torch.zeros((micro_input1.batch_size), dtype=torch.int64, device="cuda") + if req_num > 0: + draft_next_token_ids_gpu[0:req_num].copy_(next_token_ids_gpu[0:req_num]) + if req_num1 > 1: + draft_next_token_ids_gpu1[0:req_num1].copy_(next_token_ids_gpu[req_num : (req_num + req_num1)]) + draft_micro_output, draft_micro_output1 = micro_output, micro_output1 + + # process the draft model output + for draft_model_idx in range(self.mtp_step): + + draft_micro_input.input_ids = draft_next_token_ids_gpu + draft_micro_input.deepseekv3_mtp_draft_input_hiddens = draft_micro_output.deepseekv3_mtp_main_output_hiddens + draft_micro_input1.input_ids = draft_next_token_ids_gpu1 + draft_micro_input1.deepseekv3_mtp_draft_input_hiddens = ( + draft_micro_output1.deepseekv3_mtp_main_output_hiddens + ) + + draft_micro_output, draft_micro_output1 = self.draft_models[draft_model_idx].microbatch_overlap_decode( + draft_micro_input, draft_micro_input1 + ) + + draft_next_token_ids_gpu, draft_next_token_ids_cpu = self._gen_argmax_token_ids(draft_micro_output) + draft_next_token_ids_gpu1, draft_next_token_ids_cpu1 = self._gen_argmax_token_ids(draft_micro_output1) + + if verify_ok_req_last_indexes: + all_draft_next_token_ids_cpu = np.concatenate( + [draft_next_token_ids_cpu[0:req_num], draft_next_token_ids_cpu1[0:req_num1]], axis=0 + ) + unique_reqs = [all_run_reqs[index] for index in verify_ok_req_last_indexes] + self._update_reqs_mtp_gen_token_ids( + reqs=unique_reqs, mtp_draft_next_token_ids=all_draft_next_token_ids_cpu[verify_ok_req_last_indexes] + ) + + if need_free_mem_indexes: + g_infer_state_lock.acquire() + g_infer_context.req_manager.mem_manager.free(need_free_mem_indexes) + g_infer_state_lock.release() + return + + def overlap_mtp_prefill_reqs( + self, prefill_reqs: List[InferReq], max_prefill_num: int, uninit_reqs, ok_finished_reqs + ): + ( + micro_input, + run_reqs, + padded_req_num, + micro_input1, + run_reqs1, + padded_req_num1, + ) = padded_overlap_prepare_prefill_inputs(prefill_reqs, is_multimodal=self.is_multimodal) + + micro_output, micro_output1 = self.model.microbatch_overlap_prefill(micro_input, micro_input1) + + self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) + + req_num, req_num1 = len(run_reqs), len(run_reqs1) + all_run_reqs = run_reqs + run_reqs1 + next_token_ids_cpu = [] + if len(all_run_reqs) != 0: + all_logits = torch.empty( + (len(all_run_reqs), micro_output.logits.shape[1]), + dtype=micro_output.logits.dtype, + device=micro_output.logits.device, + ) + + all_logits[0:req_num, :].copy_(micro_output.logits[0:req_num, :], non_blocking=True) + all_logits[req_num : (req_num + req_num1), :].copy_(micro_output1.logits[0:req_num1, :], non_blocking=True) + + next_token_ids_gpu, next_token_probs = sample(all_logits, all_run_reqs, self.eos_id) + next_token_ids_cpu = next_token_ids_gpu.detach().cpu().numpy() + next_token_logprobs_cpu = torch.log(next_token_probs).detach().cpu().numpy() + + self._post_handle( + all_run_reqs, + next_token_ids_cpu, + next_token_logprobs_cpu, + is_chuncked_mode=True, + do_filter_finished_reqs=False, + ) + + # spec prefill: MTP + draft_micro_input, draft_micro_input1 = micro_input, micro_input1 + draft_next_token_ids_gpu = torch.zeros((micro_input.batch_size), dtype=torch.int64, device="cuda") + if req_num > 0: + 
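
Both normal_mtp_decode and overlap_mtp_decode above depend on self._verify_mtp(...) to decide which speculated rows are kept and which KV slots go back to the memory manager; that helper is not part of this diff. For orientation only, the sketch below shows the usual greedy acceptance rule behind this kind of verification. It is purely illustrative: the function name, signature, and return values are invented here and do not reproduce LightLLM's actual _verify_mtp.

    from typing import List, Tuple

    def greedy_verify(draft_tokens: List[int], main_tokens: List[int]) -> Tuple[int, List[int]]:
        """Illustrative greedy acceptance for one request.

        draft_tokens[i] is the speculated token that was fed as the input of
        row i+1; main_tokens[i] is the token the main model produced at row i.
        Row 0 is always kept (it is the normal decode step); row i+1 is kept
        only if every earlier draft token matched.  Returns the number of
        accepted rows and the row indexes whose KV slots should be freed.
        """
        accepted = 1  # the non-speculative row
        for i, draft in enumerate(draft_tokens):
            if main_tokens[i] == draft:
                accepted += 1
            else:
                break
        rejected_rows = list(range(accepted, 1 + len(draft_tokens)))
        return accepted, rejected_rows

    # e.g. with one draft token per request (mtp_step == 1):
    # greedy_verify([42], [42, 7]) -> (2, [])    both rows kept
    # greedy_verify([42], [13, 7]) -> (1, [1])   draft row rejected, free its KV slot
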
draft_next_token_ids_gpu[0:req_num].copy_(next_token_ids_gpu[0:req_num]) + + draft_next_token_ids_gpu1 = torch.zeros((micro_input1.batch_size), dtype=torch.int64, device="cuda") + if req_num1 > 0: + draft_next_token_ids_gpu1[0:req_num1].copy_(next_token_ids_gpu[req_num : (req_num + req_num1)]) + + draft_micro_output, draft_micro_output1 = micro_output, micro_output1 + + for draft_model_idx in range(self.mtp_step): + + draft_micro_input = prepare_mtp_prefill_inputs( + model_input=draft_micro_input, + b_next_token_ids=draft_next_token_ids_gpu, + deepseekv3_mtp_draft_input_hiddens=draft_micro_output.deepseekv3_mtp_main_output_hiddens, + ) + + draft_micro_input1 = prepare_mtp_prefill_inputs( + model_input=draft_micro_input1, + b_next_token_ids=draft_next_token_ids_gpu1, + deepseekv3_mtp_draft_input_hiddens=draft_micro_output1.deepseekv3_mtp_main_output_hiddens, + ) + + draft_micro_output, draft_micro_output1 = self.draft_models[draft_model_idx].microbatch_overlap_prefill( + draft_micro_input, draft_micro_input1 + ) + draft_next_token_ids_gpu, _ = self._gen_argmax_token_ids(draft_micro_output) + draft_next_token_ids_gpu1, _ = self._gen_argmax_token_ids(draft_micro_output1) + return diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/pre_process.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/pre_process.py deleted file mode 100644 index f249757d9..000000000 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/pre_process.py +++ /dev/null @@ -1,310 +0,0 @@ -import torch -import torch.distributed as dist -import numpy as np -import triton -from typing import List -from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq -from lightllm.utils.infer_utils import calculate_time -from lightllm.common.mem_manager import MemoryManager -from lightllm.common.basemodel.infer_lock import g_infer_state_lock -from lightllm.common.basemodel.microbatch_overlap_objs import DecodeMicroBatch, PrefillMicroBatch - - -def padded_prepare_prefill_inputs(req_objs: List[InferReq], max_prefill_num: int, is_multimodal=False): - assert max_prefill_num != 0 - run_reqs = [] - nopad_total_token_num = 0 - nopad_max_len_in_batch = 0 - # 当前 dp 没有请求的时候,需要进行 dp 操作。 - padded_req_num = 1 if len(req_objs) == 0 else 0 - input_ids = [] - nopad_b_req_idx = [] - nopad_b_seq_len = [] - batch_multimodal_params = [] - b_ready_cache_len = [] - for req in req_objs: - - run_reqs.append(req) - batch_multimodal_params.append(req.multimodal_params) - nopad_b_req_idx.append(req.req_idx) - - input_token_ids = req.get_chuncked_input_token_ids() - seq_len = len(input_token_ids) - input_token_len = seq_len - req.cur_kv_len - input_id = input_token_ids[req.cur_kv_len :] - - nopad_b_seq_len.append(seq_len) - input_ids.append(input_id) - nopad_total_token_num += seq_len - nopad_max_len_in_batch = max(nopad_max_len_in_batch, input_token_len) - b_ready_cache_len.append(req.cur_kv_len) - - # padding fake req for prefill - for _ in range(padded_req_num): - input_ids.append([1]) - nopad_b_req_idx.append(g_infer_context.req_manager.HOLD_REQUEST_ID) - nopad_b_seq_len.append(1) - b_ready_cache_len.append(0) - nopad_total_token_num += 1 - nopad_max_len_in_batch = max(nopad_max_len_in_batch, 1) - - input_ids = np.concatenate(input_ids, dtype=np.int64) - input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cuda") - nopad_b_req_idx = torch.tensor(nopad_b_req_idx, dtype=torch.int32, device="cuda") - nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda") - 
b_ready_cache_len = torch.tensor(b_ready_cache_len, dtype=torch.int32, device="cuda") - - # dynamic prompt cache 准备 token - g_infer_state_lock.acquire() - if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0] - padded_req_num) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0] - padded_req_num).cuda() - g_infer_state_lock.release() - if padded_req_num > 0: - padding_mem_indexs = torch.full( - (padded_req_num,), - fill_value=g_infer_context.req_manager.mem_manager.HOLD_TOKEN_MEMINDEX, - dtype=torch.int32, - device="cuda", - ) - mem_indexes = torch.cat((mem_indexes, padding_mem_indexs), dim=0) - - kwargs = { - "batch_size": nopad_b_seq_len.shape[0], - "total_token_num": nopad_total_token_num, - "max_len_in_batch": nopad_max_len_in_batch, - "input_ids": input_ids, - "mem_indexes": mem_indexes, - "b_req_idx": nopad_b_req_idx, - "b_seq_len": nopad_b_seq_len, - "b_ready_cache_len": b_ready_cache_len, - "is_prefill": True, - } - if is_multimodal: - kwargs["multimodal_params"] = batch_multimodal_params - - return kwargs, run_reqs, padded_req_num - - -def padded_prepare_decode_inputs(req_objs: List[InferReq], max_decode_num: int, is_multimodal=False): - assert max_decode_num != 0 - run_reqs = [] - nopad_total_token_num = 0 - nopad_max_len_in_batch = 0 - input_ids = [] - nopad_b_req_idx = [] - nopad_b_seq_len = [] - padded_req_num = 1 if len(req_objs) == 0 else 0 - for req in req_objs: - run_reqs.append(req) - nopad_b_req_idx.append(req.req_idx) - input_token_ids = req.get_input_token_ids() - input_id = input_token_ids[-1] - seq_len = len(input_token_ids) - assert req.cur_kv_len == seq_len - 1 - nopad_b_seq_len.append(seq_len) - input_ids.append(input_id) - nopad_total_token_num += seq_len - nopad_max_len_in_batch = max(nopad_max_len_in_batch, seq_len) - - # padding fake req for decode - for _ in range(padded_req_num): - input_ids.append(1) - seq_len = 2 - nopad_b_req_idx.append(g_infer_context.req_manager.HOLD_REQUEST_ID) - nopad_b_seq_len.append(seq_len) - nopad_total_token_num += seq_len - nopad_max_len_in_batch = max(nopad_max_len_in_batch, seq_len) - - input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cuda") - nopad_b_req_idx = torch.tensor(nopad_b_req_idx, dtype=torch.int32, device="cuda") - nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda") - - # dynamic prompt cache 准备 token - g_infer_state_lock.acquire() - if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0] - padded_req_num) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0] - padded_req_num).cuda() - g_infer_state_lock.release() - - if padded_req_num > 0: - padding_indexs = torch.full( - (padded_req_num,), - fill_value=g_infer_context.req_manager.mem_manager.HOLD_TOKEN_MEMINDEX, - dtype=torch.int32, - device="cuda", - ) - mem_indexes = torch.cat((mem_indexes, padding_indexs), dim=0) - - kwargs = { - "batch_size": nopad_b_seq_len.shape[0], - "total_token_num": nopad_total_token_num, - "max_len_in_batch": nopad_max_len_in_batch, - "input_ids": input_ids, - "mem_indexes": mem_indexes, - "b_req_idx": nopad_b_req_idx, - "b_seq_len": nopad_b_seq_len, - "is_prefill": False, - } - return kwargs, run_reqs, padded_req_num - - -def padded_overlap_prepare_decode_inputs(req_objs: List[InferReq], max_decode_num: int, is_multimodal=False): - assert max_decode_num != 0 - micro_batch_size = 
triton.cdiv(max_decode_num, 2) - micro_batch1_req_num = triton.cdiv(len(req_objs), 2) - micro_batch, run_reqs, padded_req_num = _padded_prepare_decode_micro_batch( - req_objs[0:micro_batch1_req_num], micro_batch_size, is_multimodal=is_multimodal - ) - micro_batch1, run_reqs1, padded_req_num1 = _padded_prepare_decode_micro_batch( - req_objs[micro_batch1_req_num:], micro_batch_size, is_multimodal=is_multimodal - ) - - return micro_batch, run_reqs, padded_req_num, micro_batch1, run_reqs1, padded_req_num1 - - -def _padded_prepare_decode_micro_batch(req_objs: List[InferReq], micro_batch_size: int, is_multimodal=False): - run_reqs = [] - nopad_total_token_num = 0 - nopad_max_len_in_batch = 0 - input_ids = [] - nopad_b_req_idx = [] - nopad_b_seq_len = [] - padded_req_num = micro_batch_size - len(req_objs) - for req in req_objs: - run_reqs.append(req) - nopad_b_req_idx.append(req.req_idx) - input_token_ids = req.get_input_token_ids() - input_id = input_token_ids[-1] - seq_len = len(input_token_ids) - assert req.cur_kv_len == seq_len - 1 - nopad_b_seq_len.append(seq_len) - input_ids.append(input_id) - nopad_total_token_num += seq_len - nopad_max_len_in_batch = max(nopad_max_len_in_batch, seq_len) - - # padding fake req for decode - for _ in range(padded_req_num): - input_ids.append(1) - seq_len = 2 - nopad_b_req_idx.append(g_infer_context.req_manager.HOLD_REQUEST_ID) - nopad_b_seq_len.append(seq_len) - nopad_total_token_num += seq_len - nopad_max_len_in_batch = max(nopad_max_len_in_batch, seq_len) - - input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cuda") - nopad_b_req_idx = torch.tensor(nopad_b_req_idx, dtype=torch.int32, device="cuda") - nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda") - - # dynamic prompt cache 准备 token - g_infer_state_lock.acquire() - if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0] - padded_req_num) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0] - padded_req_num).cuda() - g_infer_state_lock.release() - - if padded_req_num > 0: - padding_indexs = torch.full( - (padded_req_num,), - fill_value=g_infer_context.req_manager.mem_manager.HOLD_TOKEN_MEMINDEX, - dtype=torch.int32, - device="cuda", - ) - mem_indexes = torch.cat((mem_indexes, padding_indexs), dim=0) - - micro_batch = DecodeMicroBatch( - batch_size=nopad_b_seq_len.shape[0], - total_token_num=nopad_total_token_num, - max_len_in_batch=nopad_max_len_in_batch, - input_ids=input_ids, - mem_indexes=mem_indexes, - b_req_idx=nopad_b_req_idx, - b_seq_len=nopad_b_seq_len, - ) - - return micro_batch, run_reqs, padded_req_num - - -def padded_overlap_prepare_prefill_inputs(req_objs: List[InferReq], max_prefill_num: int, is_multimodal=False): - assert max_prefill_num != 0 - micro_batch1_req_num = triton.cdiv(len(req_objs), 2) - micro_batch, run_reqs, padded_req_num = _padded_prepare_prefill_micro_batch( - req_objs[0:micro_batch1_req_num], is_multimodal=is_multimodal - ) - micro_batch1, run_reqs1, padded_req_num1 = _padded_prepare_prefill_micro_batch( - req_objs[micro_batch1_req_num:], is_multimodal=is_multimodal - ) - - return micro_batch, run_reqs, padded_req_num, micro_batch1, run_reqs1, padded_req_num1 - - -def _padded_prepare_prefill_micro_batch(req_objs: List[InferReq], is_multimodal=False): - run_reqs = [] - nopad_total_token_num = 0 - nopad_max_len_in_batch = 0 - input_ids = [] - nopad_b_req_idx = [] - nopad_b_seq_len = [] - # prefill 只需要 padding 一个请求形成 micro_batch, 
并不需要两个 - # micro batch 的 batch_size 相同。 - padded_req_num = 1 if len(req_objs) == 0 else 0 - b_ready_cache_len = [] - batch_multimodal_params = [] - for req in req_objs: - run_reqs.append(req) - batch_multimodal_params.append(req.multimodal_params) - nopad_b_req_idx.append(req.req_idx) - - input_token_ids = req.get_chuncked_input_token_ids() - seq_len = len(input_token_ids) - input_token_len = seq_len - req.cur_kv_len - input_id = input_token_ids[req.cur_kv_len :] - - nopad_b_seq_len.append(seq_len) - input_ids.extend(input_id) - nopad_total_token_num += seq_len - nopad_max_len_in_batch = max(nopad_max_len_in_batch, input_token_len) - b_ready_cache_len.append(req.cur_kv_len) - - # padding fake req for decode - for _ in range(padded_req_num): - input_ids.append(1) - nopad_b_req_idx.append(g_infer_context.req_manager.HOLD_REQUEST_ID) - nopad_b_seq_len.append(1) - b_ready_cache_len.append(0) - nopad_total_token_num += 1 - nopad_max_len_in_batch = max(nopad_max_len_in_batch, 1) - - input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cuda") - nopad_b_req_idx = torch.tensor(nopad_b_req_idx, dtype=torch.int32, device="cuda") - nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda") - b_ready_cache_len = torch.tensor(b_ready_cache_len, dtype=torch.int32, device="cuda") - - # dynamic prompt cache 准备 token - g_infer_state_lock.acquire() - if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0] - padded_req_num) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0] - padded_req_num).cuda() - g_infer_state_lock.release() - if padded_req_num > 0: - padding_indexs = torch.full( - (padded_req_num,), - fill_value=g_infer_context.req_manager.mem_manager.HOLD_TOKEN_MEMINDEX, - dtype=torch.int32, - device="cuda", - ) - mem_indexes = torch.cat((mem_indexes, padding_indexs), dim=0) - - micro_batch = PrefillMicroBatch( - batch_size=nopad_b_seq_len.shape[0], - total_token_num=nopad_total_token_num, - max_len_in_batch=nopad_max_len_in_batch, - input_ids=input_ids, - mem_indexes=mem_indexes, - b_req_idx=nopad_b_req_idx, - b_seq_len=nopad_b_seq_len, - b_ready_cache_len=b_ready_cache_len, - multimodal_params=batch_multimodal_params, - ) - - return micro_batch, run_reqs, padded_req_num diff --git a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py new file mode 100644 index 000000000..985800d44 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py @@ -0,0 +1,218 @@ +import torch +import torch.distributed as dist +import torch.nn.functional as F +import numpy as np +import triton +from typing import List, Optional, Tuple +from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq +from lightllm.utils.infer_utils import calculate_time +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.common.basemodel.infer_lock import g_infer_state_lock +from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput + + +def padded_prepare_prefill_inputs( + req_objs: List[InferReq], dest_batch_size: Optional[int] = None, is_multimodal=False +) -> Tuple[ModelInput, List[InferReq], int]: + + if dest_batch_size is None: + req_num = len(req_objs) + if req_num > 0: + dest_batch_size = req_num + else: + dest_batch_size = 1 + else: + assert len(req_objs) <= dest_batch_size + + run_reqs = [] + 
total_token_num = 0 + max_len_in_batch = 0 + padded_req_num = dest_batch_size - len(req_objs) + input_ids = [] + b_req_idx = [] + b_seq_len = [] + batch_multimodal_params = [] + b_ready_cache_len = [] + for req in req_objs: + + run_reqs.append(req) + batch_multimodal_params.append(req.multimodal_params) + b_req_idx.append(req.req_idx) + + input_token_ids = req.get_chuncked_input_token_ids() + seq_len = len(input_token_ids) + input_token_len = seq_len - req.cur_kv_len + input_id = input_token_ids[req.cur_kv_len :] + + b_seq_len.append(seq_len) + input_ids.append(input_id) + total_token_num += seq_len + max_len_in_batch = max(max_len_in_batch, input_token_len) + b_ready_cache_len.append(req.cur_kv_len) + + # padding fake req for prefill + for _ in range(padded_req_num): + input_ids.append([1]) + b_req_idx.append(g_infer_context.req_manager.HOLD_REQUEST_ID) + b_seq_len.append(1) + b_ready_cache_len.append(0) + total_token_num += 1 + max_len_in_batch = max(max_len_in_batch, 1) + + input_ids = np.concatenate(input_ids, dtype=np.int64) + input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cuda") + b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cuda") + b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cuda") + b_ready_cache_len = torch.tensor(b_ready_cache_len, dtype=torch.int32, device="cuda") + + # dynamic prompt cache 准备 token + g_infer_state_lock.acquire() + if g_infer_context.radix_cache is not None: + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0] - padded_req_num) + mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0] - padded_req_num).cuda() + g_infer_state_lock.release() + + if padded_req_num > 0: + mem_indexes = F.pad( + input=mem_indexes, + pad=(0, padded_req_num), + mode="constant", + value=g_infer_context.req_manager.mem_manager.HOLD_TOKEN_MEMINDEX, + ) + + model_input = ModelInput( + batch_size=b_seq_len.shape[0], + total_token_num=total_token_num, + max_len_in_batch=max_len_in_batch, + input_ids=input_ids, + mem_indexes=mem_indexes, + b_req_idx=b_req_idx, + b_seq_len=b_seq_len, + b_ready_cache_len=b_ready_cache_len, + is_prefill=True, + ) + if is_multimodal: + model_input.multimodal_params = batch_multimodal_params + + return model_input, run_reqs, padded_req_num + + +def padded_prepare_decode_inputs( + req_objs: List[InferReq], dest_batch_size: Optional[int] = None, is_multimodal=False +) -> Tuple[ModelInput, List[InferReq], int]: + run_reqs = [] + total_token_num = 0 + max_len_in_batch = 0 + input_ids = [] + b_req_idx = [] + b_seq_len = [] + + for req in req_objs: + run_reqs.append(req) + b_req_idx.append(req.req_idx) + input_token_ids = req.get_input_token_ids() + input_id = input_token_ids[-1] + seq_len = len(input_token_ids) + assert req.cur_kv_len == seq_len - 1 + b_seq_len.append(seq_len) + input_ids.append(input_id) + total_token_num += seq_len + max_len_in_batch = max(max_len_in_batch, seq_len) + # process the draft tokens. 
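
The loop that follows (for step in range(len(req.mtp_gen_token_ids))) is what turns MTP speculation into extra decode rows: for every draft token stored on a request, the request is appended to run_reqs again with a sequence length one larger and the draft token as its input id, so the main model can score all speculated positions in a single batched decode. A small worked example of that expansion, with made-up token ids and lengths:

    # Suppose one request has a current total length of 10 (its normal decode
    # row has seq_len == 10) and two previously generated draft tokens
    # [1001, 1002] (mtp_step == 2).  This mirrors what the loop below does:

    mtp_gen_token_ids = [1001, 1002]
    last_token_id = 999          # req.get_input_token_ids()[-1]
    seq_len = 10

    input_ids = [last_token_id]
    b_seq_len = [seq_len]
    for draft_token in mtp_gen_token_ids:
        seq_len += 1
        input_ids.append(draft_token)
        b_seq_len.append(seq_len)

    print(input_ids)   # [999, 1001, 1002]  -> three rows for one request
    print(b_seq_len)   # [10, 11, 12]
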
+ for step in range(len(req.mtp_gen_token_ids)): + run_reqs.append(req) + b_req_idx.append(req.req_idx) + seq_len += 1 + b_seq_len.append(seq_len) + input_ids.append(req.mtp_gen_token_ids[step]) + total_token_num += seq_len + max_len_in_batch = max(max_len_in_batch, seq_len) + + if dest_batch_size is None: + if len(run_reqs) == 0: + dest_batch_size = 1 + else: + dest_batch_size = len(run_reqs) + else: + assert len(run_reqs) <= dest_batch_size + + padded_req_num = dest_batch_size - len(run_reqs) + + # padding fake req for decode + for _ in range(padded_req_num): + input_ids.append(1) + seq_len = 2 + b_req_idx.append(g_infer_context.req_manager.HOLD_REQUEST_ID) + b_seq_len.append(seq_len) + total_token_num += seq_len + max_len_in_batch = max(max_len_in_batch, seq_len) + + input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cuda") + b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cuda") + b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cuda") + + # dynamic prompt cache 准备 token + g_infer_state_lock.acquire() + if g_infer_context.radix_cache is not None: + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0] - padded_req_num) + mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0] - padded_req_num).cuda() + g_infer_state_lock.release() + + if padded_req_num > 0: + mem_indexes = F.pad( + input=mem_indexes, + pad=(0, padded_req_num), + mode="constant", + value=g_infer_context.req_manager.mem_manager.HOLD_TOKEN_MEMINDEX, + ) + + model_input = ModelInput( + batch_size=b_seq_len.shape[0], + total_token_num=total_token_num, + max_len_in_batch=max_len_in_batch, + input_ids=input_ids, + mem_indexes=mem_indexes, + b_req_idx=b_req_idx, + b_seq_len=b_seq_len, + is_prefill=False, + ) + return model_input, run_reqs, padded_req_num + + +def padded_overlap_prepare_decode_inputs(req_objs: List[InferReq], is_multimodal=False): + split_req_bound = triton.cdiv(len(req_objs), 2) + req_objs_0 = req_objs[0:split_req_bound] + req_objs_1 = req_objs[split_req_bound:] + + enable_mtp = get_env_start_args().mtp_mode is not None + if enable_mtp: + micro_batch_size = max( + sum([len(req.mtp_gen_token_ids) + 1 for req in req_objs_0]), + sum([len(req.mtp_gen_token_ids) + 1 for req in req_objs_1]), + ) + else: + micro_batch_size = triton.cdiv(len(req_objs), 2) + + micro_batch_size = max(1, micro_batch_size) + + micro_input, run_reqs, padded_req_num = padded_prepare_decode_inputs( + req_objs_0, dest_batch_size=micro_batch_size, is_multimodal=is_multimodal + ) + micro_input1, run_reqs1, padded_req_num1 = padded_prepare_decode_inputs( + req_objs_1, dest_batch_size=micro_batch_size, is_multimodal=is_multimodal + ) + return micro_input, run_reqs, padded_req_num, micro_input1, run_reqs1, padded_req_num1 + + +def padded_overlap_prepare_prefill_inputs(req_objs: List[InferReq], is_multimodal=False): + micro_batch1_req_num = triton.cdiv(len(req_objs), 2) + + micro_input, run_reqs, padded_req_num = padded_prepare_prefill_inputs( + req_objs[0:micro_batch1_req_num], is_multimodal=is_multimodal + ) + + micro_input1, run_reqs1, padded_req_num1 = padded_prepare_prefill_inputs( + req_objs[micro_batch1_req_num:], is_multimodal=is_multimodal + ) + + return micro_input, run_reqs, padded_req_num, micro_input1, run_reqs1, padded_req_num1 diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index 9f174e8bc..185570d74 100644 --- 
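
padded_overlap_prepare_decode_inputs above splits the decode requests into two halves and pads both halves to the same dest_batch_size; when MTP is enabled the size is counted in rows (one per request plus one per stored draft token), so the two micro-batches keep matching shapes even when draft counts differ. A quick numeric illustration with made-up per-request counts (they need not be equal across requests):

    # hypothetical draft-token counts for 5 decode requests (mtp_step == 1)
    mtp_counts = [1, 1, 0, 1, 1]

    split = (len(mtp_counts) + 1) // 2        # triton.cdiv(len(req_objs), 2) in the patch
    half0, half1 = mtp_counts[:split], mtp_counts[split:]

    rows0 = sum(c + 1 for c in half0)         # 3 requests -> 5 rows
    rows1 = sum(c + 1 for c in half1)         # 2 requests -> 4 rows
    micro_batch_size = max(1, rows0, rows1)   # 5

    # The second half is padded with one fake row (HOLD_REQUEST_ID /
    # HOLD_TOKEN_MEMINDEX) so both micro-batches are built with
    # dest_batch_size == 5 and the two overlapped forwards stay shape-compatible.
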
a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -1,24 +1,26 @@ import torch import numpy as np -from typing import List +from typing import List, Tuple from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context from lightllm.common.basemodel.infer_lock import g_infer_state_lock +from lightllm.common.basemodel.batch_objs import ModelInput -def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool, is_multimodal=False): +def prepare_prefill_inputs( + req_objs: List[InferReq], is_chuncked_mode: bool, is_multimodal: bool = False +) -> Tuple[ModelInput, List[InferReq]]: run_reqs = [] - nopad_total_token_num = 0 - nopad_max_len_in_batch = 0 + total_token_num = 0 + max_len_in_batch = 0 input_ids = [] - nopad_b_req_idx = [] - nopad_b_seq_len = [] + b_req_idx = [] + b_seq_len = [] batch_multimodal_params = [] b_ready_cache_len = [] for req in req_objs: - run_reqs.append(req) batch_multimodal_params.append(req.multimodal_params) - nopad_b_req_idx.append(req.req_idx) + b_req_idx.append(req.req_idx) if is_chuncked_mode: input_token_ids = req.get_chuncked_input_token_ids() @@ -30,17 +32,17 @@ def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool, is_ input_id = input_token_ids[req.cur_kv_len :] - nopad_b_seq_len.append(seq_len) + b_seq_len.append(seq_len) input_ids.append(input_id) - nopad_total_token_num += seq_len - nopad_max_len_in_batch = max(nopad_max_len_in_batch, input_token_len) + total_token_num += seq_len + max_len_in_batch = max(max_len_in_batch, input_token_len) b_ready_cache_len.append(req.cur_kv_len) input_ids = np.concatenate(input_ids, dtype=np.int64) input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cuda") - nopad_b_req_idx = torch.tensor(nopad_b_req_idx, dtype=torch.int32, device="cuda") - nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda") + b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cuda") + b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cuda") b_ready_cache_len = torch.tensor(b_ready_cache_len, dtype=torch.int32, device="cuda") # dynamic prompt cache 准备 token @@ -50,44 +52,54 @@ def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool, is_ mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0]).cuda() g_infer_state_lock.release() - kwargs = { - "batch_size": len(run_reqs), - "total_token_num": nopad_total_token_num, - "max_len_in_batch": nopad_max_len_in_batch, - "input_ids": input_ids, - "mem_indexes": mem_indexes, - "b_req_idx": nopad_b_req_idx, - "b_seq_len": nopad_b_seq_len, - "b_ready_cache_len": b_ready_cache_len, - "is_prefill": True, - } + model_input = ModelInput( + batch_size=b_seq_len.shape[0], + total_token_num=total_token_num, + max_len_in_batch=max_len_in_batch, + input_ids=input_ids, + mem_indexes=mem_indexes, + b_req_idx=b_req_idx, + b_seq_len=b_seq_len, + b_ready_cache_len=b_ready_cache_len, + is_prefill=True, + ) if is_multimodal: - kwargs["multimodal_params"] = batch_multimodal_params + model_input.multimodal_params = batch_multimodal_params - return kwargs, run_reqs + return model_input, run_reqs -def prepare_decode_inputs(req_objs: List[InferReq]): +def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[InferReq]]: run_reqs = [] - nopad_total_token_num = 0 - nopad_max_len_in_batch = 0 + total_token_num = 0 + max_len_in_batch = 0 input_ids = [] - 
nopad_b_req_idx = [] - nopad_b_seq_len = [] + b_req_idx = [] + b_seq_len = [] for req in req_objs: run_reqs.append(req) - nopad_b_req_idx.append(req.req_idx) + b_req_idx.append(req.req_idx) input_id = req.get_last_gen_token() seq_len = req.get_cur_total_len() assert req.cur_kv_len == seq_len - 1 - nopad_b_seq_len.append(seq_len) + b_seq_len.append(seq_len) input_ids.append(input_id) - nopad_total_token_num += seq_len - nopad_max_len_in_batch = max(nopad_max_len_in_batch, seq_len) + total_token_num += seq_len + max_len_in_batch = max(max_len_in_batch, seq_len) + + # process the draft tokens. + for step in range(len(req.mtp_gen_token_ids)): + run_reqs.append(req) + b_req_idx.append(req.req_idx) + seq_len += 1 + b_seq_len.append(seq_len) + input_ids.append(req.mtp_gen_token_ids[step]) + total_token_num += seq_len + max_len_in_batch = max(max_len_in_batch, seq_len) input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cuda") - nopad_b_req_idx = torch.tensor(nopad_b_req_idx, dtype=torch.int32, device="cuda") - nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda") + b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cuda") + b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cuda") # dynamic prompt cache 准备 token g_infer_state_lock.acquire() @@ -96,14 +108,14 @@ def prepare_decode_inputs(req_objs: List[InferReq]): mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0]).cuda() g_infer_state_lock.release() - kwargs = { - "batch_size": len(run_reqs), - "total_token_num": nopad_total_token_num, - "max_len_in_batch": nopad_max_len_in_batch, - "input_ids": input_ids, - "mem_indexes": mem_indexes, - "b_req_idx": nopad_b_req_idx, - "b_seq_len": nopad_b_seq_len, - "is_prefill": False, - } - return kwargs, run_reqs + model_input = ModelInput( + batch_size=b_seq_len.shape[0], + total_token_num=total_token_num, + max_len_in_batch=max_len_in_batch, + input_ids=input_ids, + mem_indexes=mem_indexes, + b_req_idx=b_req_idx, + b_seq_len=b_seq_len, + is_prefill=False, + ) + return model_input, run_reqs diff --git a/lightllm/server/router/model_infer/mode_backend/mtp_pre_process.py b/lightllm/server/router/model_infer/mode_backend/mtp_pre_process.py new file mode 100644 index 000000000..3991e42de --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/mtp_pre_process.py @@ -0,0 +1,19 @@ +import torch +import copy +from lightllm.common.basemodel.batch_objs import ModelInput +from lightllm.common.basemodel.triton_kernel.gen_mtp_prefill_params import gen_mtp_new_input_ids + + +def prepare_mtp_prefill_inputs( + model_input: ModelInput, b_next_token_ids: torch.Tensor, deepseekv3_mtp_draft_input_hiddens: torch.Tensor +): + new_model_input = copy.copy(model_input) + new_input_ids = gen_mtp_new_input_ids( + input_ids=model_input.input_ids, + b_next_token_ids=b_next_token_ids, + b_seq_len=model_input.b_seq_len, + b_ready_cache_len=model_input.b_ready_cache_len, + ) + new_model_input.input_ids = new_input_ids + new_model_input.deepseekv3_mtp_draft_input_hiddens = deepseekv3_mtp_draft_input_hiddens + return new_model_input diff --git a/lightllm/server/router/model_infer/mode_backend/pre.py b/lightllm/server/router/model_infer/mode_backend/pre.py new file mode 100644 index 000000000..e88b1a62b --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/pre.py @@ -0,0 +1,6 @@ +from .generic_pre_process import prepare_prefill_inputs +from .generic_pre_process import prepare_decode_inputs +from .generic_padded_pre_process 
import padded_prepare_prefill_inputs +from .generic_padded_pre_process import padded_prepare_decode_inputs +from .generic_padded_pre_process import padded_overlap_prepare_prefill_inputs +from .generic_padded_pre_process import padded_overlap_prepare_decode_inputs diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index 1568d62ef..c1d89beb6 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -21,6 +21,12 @@ DPForDecodeNode, ChunckedPrefillForPrefillNode, DPChunkedForPrefillNode, + ContinuesBatchWithMTPBackend, + DPChunkedPrefillWithMTPBackend, + DPForMtpDecodeNode, + ContinuesBatchBackendForMtpDecodeNode, + ChunckedPrefillForMtpPrefillNode, + DPChunkedForMtpPrefillNode, ) from lightllm.server.router.model_infer.mode_backend.redundancy_expert_manager import RedundancyExpertManager from lightllm.server.core.objs import RpcShmParams, RpcShmResults, ShmSyncStatusArray @@ -108,38 +114,48 @@ def init_model(self, kvargs): # 填充真正的 rank_id 参数 kvargs["rank_id"] = self.rank self.world_size = kvargs["world_size"] - disable_chunked_prefill = kvargs.get("disable_chunked_prefill", False) - return_all_prompt_logprobs = kvargs.get("return_all_prompt_logprobs", False) - use_reward_model = kvargs.get("use_reward_model", False) - diverse_mode = kvargs.get("diverse_mode", False) - is_token_healing = kvargs.get("is_token_healing", False) - is_first_token_constraint_mode = kvargs.get("is_first_token_constraint_mode", False) - if kvargs.get("args", None) is not None: - is_outlines_constraint_mode = kvargs.get("args", None).output_constraint_mode == "outlines" - is_xgrammar_constraint_mode = kvargs.get("args", None).output_constraint_mode == "xgrammar" - assert not ( - is_outlines_constraint_mode and is_xgrammar_constraint_mode - ), "only one constraint mode can be true" - is_prefill_node = kvargs.get("args", None).run_mode == "prefill" - is_decode_node = kvargs.get("args", None).run_mode == "decode" - else: - is_outlines_constraint_mode = False - is_xgrammar_constraint_mode = False - is_prefill_node = False - is_decode_node = False + disable_chunked_prefill = self.args.disable_chunked_prefill + return_all_prompt_logprobs = self.args.return_all_prompt_logprobs + use_reward_model = self.args.use_reward_model + diverse_mode = self.args.diverse_mode + is_token_healing = self.args.token_healing_mode + is_first_token_constraint_mode = self.args.first_token_constraint_mode + + is_outlines_constraint_mode = self.args.output_constraint_mode == "outlines" + is_xgrammar_constraint_mode = self.args.output_constraint_mode == "xgrammar" + assert not (is_outlines_constraint_mode and is_xgrammar_constraint_mode), "only one constraint mode can be true" + is_prefill_node = self.args.run_mode == "prefill" + is_decode_node = self.args.run_mode == "decode" + + enable_mtp = self.args.mtp_mode is not None if is_prefill_node: - if kvargs.get("args", None).dp > 1: - self.backend = DPChunkedForPrefillNode(self.info_queue, self.mem_queue) + if enable_mtp: + if self.args.dp > 1: + self.backend = DPChunkedForMtpPrefillNode(self.info_queue, self.mem_queue) + else: + self.backend = ChunckedPrefillForMtpPrefillNode(self.info_queue, self.mem_queue) else: - self.backend = ChunckedPrefillForPrefillNode(self.info_queue, self.mem_queue) + if self.args.dp > 1: + self.backend = DPChunkedForPrefillNode(self.info_queue, self.mem_queue) + else: + self.backend = ChunckedPrefillForPrefillNode(self.info_queue, self.mem_queue) elif 
-            if kvargs.get("args", None).dp > 1:
-                self.backend = DPForDecodeNode(self.info_queue, self.mem_queue)
+            if enable_mtp:
+                if self.args.dp > 1:
+                    self.backend = DPForMtpDecodeNode(self.info_queue, self.mem_queue)
+                else:
+                    self.backend = ContinuesBatchBackendForMtpDecodeNode(self.info_queue, self.mem_queue)
             else:
-                self.backend = ContinuesBatchBackendForDecodeNode(self.info_queue, self.mem_queue)
-        elif kvargs.get("dp_size", 1) > 1:
-            self.backend = DPChunkedPrefillBackend()
+                if self.args.dp > 1:
+                    self.backend = DPForDecodeNode(self.info_queue, self.mem_queue)
+                else:
+                    self.backend = ContinuesBatchBackendForDecodeNode(self.info_queue, self.mem_queue)
+        elif self.args.dp > 1:
+            if enable_mtp:
+                self.backend = DPChunkedPrefillWithMTPBackend()
+            else:
+                self.backend = DPChunkedPrefillBackend()
         elif use_reward_model:
             self.backend = RewardModelBackend()
         elif return_all_prompt_logprobs:
@@ -155,9 +171,15 @@ def init_model(self, kvargs):
         elif is_first_token_constraint_mode:
             self.backend = FirstTokenConstraintBackend()
         elif disable_chunked_prefill:
-            self.backend = ContinuesBatchBackend()
+            if enable_mtp:
+                self.backend = ContinuesBatchWithMTPBackend()
+            else:
+                self.backend = ContinuesBatchBackend()
         else:
-            self.backend = ChunkedPrefillBackend()
+            if enable_mtp:
+                self.backend = ContinuesBatchWithMTPBackend()
+            else:
+                self.backend = ChunkedPrefillBackend()

         logger.info(f"use {self.backend.__class__.__name__}")
         self.backend.init_model(kvargs)
diff --git a/lightllm/server/router/req_queue/dp_base_queue.py b/lightllm/server/router/req_queue/dp_base_queue.py
index 9bb553b2b..b8df6e3e9 100644
--- a/lightllm/server/router/req_queue/dp_base_queue.py
+++ b/lightllm/server/router/req_queue/dp_base_queue.py
@@ -51,7 +51,7 @@ def _merge_batch(self, dp_batches: List[Batch]):
     def append(self, req: Req):
         suggested_dp_index = req.sample_params.suggested_dp_index
         if suggested_dp_index >= self.dp_size_in_node or suggested_dp_index < 0:
-            logger.error(f"input req {req.request_id} dp index {suggested_dp_index} has error")
+            logger.warning(f"input req {req.request_id} dp index {suggested_dp_index} is invalid")
             suggested_dp_index = self._get_suggest_dp_index()
         self.pre_select_dp_index = suggested_dp_index
         req.sample_params.suggested_dp_index = suggested_dp_index
@@ -66,7 +66,7 @@ def extend(self, req_group: List[Req]):
         for req in req_group:
             suggested_dp_index = req.sample_params.suggested_dp_index
             if suggested_dp_index >= self.dp_size_in_node or suggested_dp_index < 0:
-                logger.error(f"input req {req.request_id} dp index {suggested_dp_index} has error")
+                logger.warning(f"input req {req.request_id} dp index {suggested_dp_index} is invalid")
             self.pre_select_dp_index = index
             req.sample_params.suggested_dp_index = index
             self.inner_queues[index].append(req)
diff --git a/lightllm/utils/custom_kernel_utis.py b/lightllm/utils/custom_kernel_utis.py
index f72dfabfc..cbc52c8f9 100644
--- a/lightllm/utils/custom_kernel_utis.py
+++ b/lightllm/utils/custom_kernel_utis.py
@@ -114,3 +114,13 @@ def tensor_copy_3dim(dest_tensor: torch.Tensor, source_tensor: torch.Tensor):
         num_stages=3,
     )
     return
+
+
+def pad2dim_tensor_to_new_batch(input: torch.Tensor, new_batch_size: int):
+    assert input.ndim == 2
+    origin_batch_size = input.shape[0]
+    hidden = input.shape[1]
+    out = torch.empty((new_batch_size, hidden), dtype=input.dtype, device=input.device, requires_grad=False)
+    out[0:origin_batch_size, :] = input
+    out[origin_batch_size:, :] = input[0:1, :]
+    return out
diff --git a/test/model/model_infer_mtp.py b/test/model/model_infer_mtp.py
new file mode 100644
index 000000000..eb36bc873
--- /dev/null
+++ b/test/model/model_infer_mtp.py
@@ -0,0 +1,281 @@
+import os
+import torch
+import numpy as np
+from multiprocessing import Queue
+import multiprocessing
+from transformers import PretrainedConfig
+from lightllm.utils.dist_utils import init_distributed_env, get_current_rank_in_dp
+from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.models import get_model
+from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
+from lightllm.server.core.objs.start_args_type import StartArgs
+from torch.profiler import profile, record_function, ProfilerActivity
+from lightllm.utils.log_utils import init_logger
+from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel
+import torch.cuda as cuda
+
+logger = init_logger(__name__)
+
+
+def init_mtp_model(args: StartArgs, kvargs, main_model):
+    mtp_step = args.mtp_step
+    draft_models = []
+
+    os.environ["DISABLE_CHECK_MAX_LEN_INFER"] = "1"
+    mtp_model_kvargs = kvargs
+    mtp_model_kvargs.update(
+        {
+            "weight_dir": args.mtp_draft_model_dir,
+            "max_total_token_num": main_model.mem_manager.size,
+            "use_dynamic_prompt_cache": False,
+            "disable_chunked_prefill": True,
+            "mtp_mode": args.mtp_mode,
+            "main_model": main_model,
+        }
+    )
+    for i in range(mtp_step):
+        mtp_model_cfg, _ = PretrainedConfig.get_config_dict(args.mtp_draft_model_dir)
+        mtp_model_kvargs.update(
+            {
+                "weight_dir": args.spec_model_dir,
+                "max_total_token_num": main_model.mem_manager.size,
+                "use_dynamic_prompt_cache": False,
+                "disable_chunked_prefill": True,
+                "mtp_mode": args.mtp_mode,
+                "main_model": main_model,
+                "mem_layer_start": main_model.config["num_hidden_layers"] + i * mtp_model_cfg["num_hidden_layers"],
+            }
+        )
+        draft_models.append(Deepseek3MTPModel(mtp_model_kvargs))
+    return draft_models
+
+
+def test_model_inference_mtp(args):
+    ans_queue = Queue()
+    workers = []
+    dp_size = args.get("dp", 1)
+
+    for rank_id in range(args.node_rank * args.tp, (args.node_rank + 1) * args.tp):
+        model_kvargs = {
+            "args": args,
+            "nccl_host": args.nccl_host,
+            "data_type": args.data_type,
+            "nccl_port": args.nccl_port,
+            "rank_id": rank_id,
+            "world_size": args.tp,
+            "dp_size": dp_size,
+            "weight_dir": args.model_dir,
+            "quant_type": args.quant_type,
+            "load_way": "HF",
+            "max_total_token_num": args.max_total_token_num,
+            "graph_max_len_in_batch": args.max_req_total_len,
+            "graph_max_batch_size": args.graph_max_batch_size,
+            "mem_faction": args.mem_fraction,
+            "max_req_num": 2000,
+            "batch_max_tokens": 2048,
+            "run_mode": "normal",
+            "max_seq_length": args.max_req_total_len,
+            "spec_algo": args.spec_algo,
+            "disable_cudagraph": args.disable_cudagraph,
+        }
+        proc = multiprocessing.Process(
+            target=tppart_model_infer,
+            args=(args, model_kvargs, args.batch_size, args.input_len, args.output_len, ans_queue),
+        )
+        proc.start()
+        workers.append(proc)
+
+    for proc in workers:
+        proc.join()
+
+    assert not ans_queue.empty()
+    while not ans_queue.empty():
+        assert ans_queue.get()
+    return
+
+
+def torch_profile(fn, log_dir=None):
+    torch.cuda.synchronize()
+    with profile(
+        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+        record_shapes=False,
+        profile_memory=False,
+        on_trace_ready=torch.profiler.tensorboard_trace_handler(log_dir),
+    ) as prof:
+        fn()
+    if get_current_rank_in_dp() == 0:
+        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
+
+
+def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_models, warmup=False):
+    import time
+
+    torch.cuda.synchronize()
+    prefill_start_time = time.time()
+
+    test_data = np.vstack([np.random.randint(0, 50256, input_len) for _ in range(batch_size)])
+    test_data = test_data.reshape(-1)
+    test_data = torch.from_numpy(test_data).cuda()
+
+    b_req_idx = torch.tensor(
+        [main_model.req_manager.alloc() for _ in range(batch_size)], dtype=torch.int32, device="cuda"
+    )
+    b_seq_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+    b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+    for i in range(batch_size):
+        b_seq_len[i] = input_len
+
+    total_token_num = input_len * batch_size
+    mem_indexes = main_model.req_manager.mem_manager.alloc(test_data.shape[0]).cuda()
+    # Main model Prefill
+    model_input = ModelInput(
+        batch_size=batch_size,
+        total_token_num=total_token_num,
+        max_len_in_batch=input_len,
+        input_ids=test_data,
+        mem_indexes=mem_indexes,
+        b_req_idx=b_req_idx,
+        b_seq_len=b_seq_len,
+        is_prefill=True,
+        b_ready_cache_len=b_ready_cache_len,
+    )
+
+    model_output: ModelOutput = main_model.forward(model_input)
+    prob_out = torch.softmax(model_output.logits, dim=-1)
+    predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
+    predict_ids = predict_ids.detach().cpu().numpy()
+
+    draft_ids = [predict_ids]
+
+    # Draft model Prefill
+    # For simplicity, we'll just take the input of main_model to draft model.
+    model_input.deepseekv3_mtp_draft_input_hiddens = model_output.deepseekv3_mtp_main_output_hiddens
+    for draft_model_id in range(len(draft_models)):
+        draft_model = draft_models[draft_model_id]
+        model_output = draft_model.forward(model_input)
+        prob_out = torch.softmax(model_output.logits, dim=-1)
+        predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
+        predict_ids = predict_ids.detach().cpu().numpy()
+        draft_ids.append(predict_ids)
+        model_input.deepseekv3_mtp_draft_input_hiddens = model_output.deepseekv3_mtp_main_output_hiddens
+
+    torch.cuda.synchronize()
+    prefill_end_time = time.time()
+    if get_current_rank_in_dp() == 0 and not warmup:
+        print("prefill time cost:", (prefill_end_time - prefill_start_time) * 1000)
+        print(
+            f"Prefill throughput: {batch_size * input_len * args.dp / (prefill_end_time - prefill_start_time)} tokens/s"
+        )
+
+    torch.cuda.synchronize()
+
+    decode_input_ids = np.stack(draft_ids, axis=-1).reshape(-1)
+    decode_input_ids = torch.from_numpy(decode_input_ids).cuda()
+
+    # build main decode input:
+    nopad_b_seq_idx = []
+    nopad_b_seq_len = []
+    nopad_total_token_num = 0
+    nopad_max_len_in_batch = 0
+
+    for i in range(batch_size):
+        nopad_b_seq_idx.append(b_req_idx[i])
+        seq_len = b_seq_len[i].item()
+        nopad_b_seq_len.append(seq_len + 1)
+        nopad_total_token_num += seq_len + 1
+        nopad_max_len_in_batch = max(nopad_max_len_in_batch, b_seq_len[i] + 1)
+
+        for step in range(len(draft_models)):
+            nopad_b_seq_idx.append(b_req_idx[i])
+            nopad_b_seq_len.append(seq_len + step + 2)
+            nopad_total_token_num += seq_len + step + 2
+            nopad_max_len_in_batch = max(nopad_max_len_in_batch, seq_len + step + 2)
+
+    nopad_b_seq_idx = torch.tensor(nopad_b_seq_idx, dtype=torch.int32, device="cuda")
+    nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda")
+    mem_indexes = main_model.req_manager.mem_manager.alloc(batch_size * (len(draft_models) + 1)).cuda()
+
+    model_input = ModelInput(
+        batch_size=batch_size * (len(draft_models) + 1),
+        total_token_num=nopad_total_token_num,
+        max_len_in_batch=nopad_max_len_in_batch,
+        input_ids=decode_input_ids,
+        mem_indexes=mem_indexes,
+        b_req_idx=nopad_b_seq_idx,
+        b_seq_len=nopad_b_seq_len,
+        is_prefill=False,
+    )
+
+    # Main decode
+    for i in range(0, output_len, len(draft_models) + 1):
+        torch.cuda.synchronize()
+        step_start_time = time.time()
+        model_output = main_model.forward(
+            model_input,
+        )
+        prob_out = torch.softmax(model_output.logits, dim=-1)
+        predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
+
+        # draft decode
+        model_input.input_ids = predict_ids.reshape(-1)
+        model_input.deepseekv3_mtp_draft_input_hiddens = model_output.deepseekv3_mtp_main_output_hiddens
+
+        for draft_model_id in range(len(draft_models)):
+            draft_model = draft_models[draft_model_id]
+            model_output = draft_model.forward(
+                model_input,
+            )
+            prob_out = torch.softmax(model_output.logits, dim=-1)
+            predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
+            model_input.input_ids = predict_ids.reshape(-1)
+            model_input.deepseekv3_mtp_draft_input_hiddens = model_output.deepseekv3_mtp_main_output_hiddens
+
+        # accept all draft ids by default.
+        model_input.input_ids = predict_ids.reshape(-1)
+        model_input.deepseekv3_mtp_draft_input_hiddens = model_output.deepseekv3_mtp_main_output_hiddens
+        torch.cuda.synchronize()
+        if i % 100 == 0 or i == output_len - 1:
+            step_end_time = time.time()
+            if get_current_rank_in_dp() == 0 and not warmup:
+                step_time = step_end_time - step_start_time
+                print(i, " step cost time:", step_time * 1000)
+                print(f"Decode throughput: {batch_size * (len(draft_models) + 1) * args.dp / step_time} tokens/s")
+
+    main_model.mem_manager.free_all()
+    main_model.req_manager.free_all()
+
+
+def tppart_model_infer(args, model_kvargs, batch_sizes, input_len, output_len, ans_queue):
+    args = get_env_start_args()
+    import triton.profiler as proton
+    import torch
+    from lightllm.distributed import dist_group_manager
+    from lightllm.utils.dist_utils import set_current_device_id
+
+    import torch.distributed as dist
+
+    enable_decode_overlap = args.enable_decode_microbatch_overlap
+    group_size = 1
+    if enable_decode_overlap or args.enable_prefill_microbatch_overlap:
+        group_size = 2
+    init_distributed_env(model_kvargs)
+    dist_group_manager.create_groups(group_size=group_size)
+    model_cfg, _ = PretrainedConfig.get_config_dict(model_kvargs["weight_dir"])
+    dist.barrier()
+
+    torch.cuda.empty_cache()
+
+    main_model, _ = get_model(model_cfg, model_kvargs)
+    draft_models = init_mtp_model(args, model_kvargs, main_model)
+    if isinstance(batch_sizes, int):
+        batch_sizes = [batch_sizes]
+
+    for batch_size in batch_sizes:
+        # warm up
+        run_forward_once(args, input_len, output_len, batch_size, main_model, draft_models, warmup=True)
+        torch.cuda.synchronize()
+        run_forward_once(args, input_len, output_len, batch_size, main_model, draft_models, warmup=False)
+        dist.barrier()
+
+    ans_queue.put(True)
+    return
diff --git a/test/model/test_model.py b/test/model/test_model.py
index 1376e2570..bf7d0ac43 100644
--- a/test/model/test_model.py
+++ b/test/model/test_model.py
@@ -5,6 +5,7 @@

 import unittest
 from model_infer import test_model_inference
+from model_infer_mtp import test_model_inference_mtp
 from lightllm.server.api_cli import make_argument_parser
 from lightllm.utils.envs_utils import set_env_start_args, get_env_start_args
 from lightllm.utils.config_utils import get_config_json, get_dtype
@@ -15,7 +16,10 @@ def test_model_infer(self):
         args = get_env_start_args()
         if args.data_type is None:
             args.data_type = get_dtype(args.model_dir)
-        test_model_inference(args)
+        if args.mtp_mode == "deepseekv3":
+            test_model_inference_mtp(args)
+        else:
+            test_model_inference(args)
         return
diff --git a/unit_tests/common/basemodel/triton_kernel/test_gen_mtp_prefill_params.py b/unit_tests/common/basemodel/triton_kernel/test_gen_mtp_prefill_params.py
new file mode 100644
index 000000000..ecb7e6821
--- /dev/null
+++ b/unit_tests/common/basemodel/triton_kernel/test_gen_mtp_prefill_params.py
@@ -0,0 +1,28 @@
+import torch
+import pytest
+import numpy as np
+from lightllm.utils.log_utils import init_logger
+from lightllm.common.basemodel.triton_kernel.gen_mtp_prefill_params import gen_mtp_new_input_ids
+
+
+def test_gen_mtp_new_input_ids_0():
+    input_ids = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]).int().cuda()
+    b_next_token_ids = torch.tensor([10, 11, 12]).int().cuda()
+    b_seq_len = torch.tensor([3, 3, 3]).int().cuda()
+    expected_output = torch.tensor([2, 3, 10, 5, 6, 11, 8, 9, 12]).int().cuda()
+    new_input_ids = gen_mtp_new_input_ids(input_ids, b_next_token_ids, b_seq_len)
+    assert torch.equal(new_input_ids, expected_output)
+
+
+def test_gen_mtp_new_input_ids_1():
+    input_ids = torch.tensor([1, 2, 3, 4, 5, 6]).int().cuda()
+    b_next_token_ids = torch.tensor([10, 11, 12]).int().cuda()
+    b_seq_len = torch.tensor([3, 3, 3]).int().cuda()
+    b_ready_cache_len = torch.tensor([1, 1, 1]).int().cuda()
+    expected_output = torch.tensor([2, 10, 4, 11, 6, 12]).int().cuda()
+    new_input_ids = gen_mtp_new_input_ids(input_ids, b_next_token_ids, b_seq_len, b_ready_cache_len=b_ready_cache_len)
+    assert torch.equal(new_input_ids, expected_output)
+
+
+if __name__ == "__main__":
+    pytest.main()
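
Reviewer note: the sketch below is a plain-PyTorch reference for the behavior the new unit tests pin down for gen_mtp_new_input_ids. It is not the Triton kernel added by this patch, and the helper name gen_mtp_new_input_ids_ref is hypothetical. Per request, the segment of input_ids has length b_seq_len - b_ready_cache_len; the kernel drops the segment's first token and appends that request's next sampled token, producing the one-position-shifted inputs the MTP draft model prefills on.

    import torch

    def gen_mtp_new_input_ids_ref(input_ids, b_next_token_ids, b_seq_len, b_ready_cache_len=None):
        # Reference semantics (assumption, mirrors the unit tests above): for each request,
        # shift its token window left by one and append that request's next token id.
        if b_ready_cache_len is None:
            b_ready_cache_len = torch.zeros_like(b_seq_len)
        seg_lens = (b_seq_len - b_ready_cache_len).tolist()
        out, start = [], 0
        for i, seg_len in enumerate(seg_lens):
            seg = input_ids[start : start + seg_len]
            out.append(torch.cat([seg[1:], b_next_token_ids[i : i + 1]]))
            start += seg_len
        return torch.cat(out)

With the inputs of test_gen_mtp_new_input_ids_0 this yields [2, 3, 10, 5, 6, 11, 8, 9, 12], and with those of test_gen_mtp_new_input_ids_1 it yields [2, 10, 4, 11, 6, 12], matching the expected outputs above.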