diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 873ab00d4..72dba7b57 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -1,8 +1,10 @@ import os # os.environ["CUDA_LAUNCH_BLOCKING"] = "1" +import copy import json import torch +import torch.nn.functional as F from typing import final from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights @@ -19,7 +21,9 @@ from lightllm.utils.dist_utils import get_dp_world_size from lightllm.utils.envs_utils import get_env_start_args from lightllm.distributed.communication_op import CustomProcessGroup, dist_group_manager -from lightllm.common.basemodel.microbatch_overlap_objs import DecodeMicroBatch, PrefillMicroBatch +from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput +from lightllm.utils.custom_kernel_utis import pad2dim_tensor_to_new_batch + logger = init_logger(__name__) @@ -40,6 +44,7 @@ class TpPartBaseModel: infer_state_class = InferStateInfo def __init__(self, kvargs): + self.args = get_env_start_args() self.run_mode = kvargs["run_mode"] self.weight_dir_ = kvargs["weight_dir"] self.max_total_token_num = kvargs["max_total_token_num"] @@ -71,6 +76,8 @@ def __init__(self, kvargs): self.tp_world_size_ = get_dp_world_size() self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode + self.is_deepseekv3_mtp_mode = self.args.mtp_mode == "deepseekv3" + self._init_datatype() self._init_config() self._verify_must() @@ -226,263 +233,150 @@ def _init_custom(self): pass @torch.no_grad() - def forward( - self, - batch_size, - total_token_num, - max_len_in_batch, - input_ids: torch.Tensor, - mem_indexes: torch.Tensor, - b_req_idx: torch.Tensor, - b_seq_len: torch.Tensor, - b_ready_cache_len: torch.Tensor = None, - multimodal_params=None, - is_prefill=True, - ): - assert mem_indexes.is_cuda - - if is_prefill: - return self._prefill( - batch_size, - total_token_num, - max_len_in_batch, - input_ids, - mem_indexes, - b_req_idx, - b_seq_len, - b_ready_cache_len, - multimodal_params, - ) + def forward(self, model_input: ModelInput): + assert model_input.mem_indexes.is_cuda + + if model_input.is_prefill: + return self._prefill(model_input) else: - return self._decode( - batch_size, - total_token_num, - max_len_in_batch, - input_ids, - mem_indexes, - b_req_idx, - b_seq_len, - multimodal_params, - ) + return self._decode(model_input) - def _prefill( - self, - batch_size, - total_token_num, - max_len_in_batch, - input_ids, - mem_indexes, - b_req_idx, - b_seq_len, - b_ready_cache_len, - multimodal_params, - ): + def _create_inferstate(self, model_input: ModelInput, batch_index: int = 0): infer_state = self.infer_state_class() - infer_state.is_prefill = True + infer_state.is_prefill = model_input.is_prefill infer_state.is_token_healing = self.is_token_healing infer_state.return_all_prompt_logics = self.return_all_prompt_logics infer_state.use_dynamic_prompt_cache = self.use_dynamic_prompt_cache - infer_state.batch_size = batch_size - infer_state.total_token_num = total_token_num - infer_state.max_len_in_batch = max_len_in_batch - assert b_req_idx.shape[0] == b_seq_len.shape[0] - infer_state.b_req_idx = b_req_idx - infer_state.b_seq_len = b_seq_len - if b_ready_cache_len is not None: - infer_state.b_ready_cache_len = b_ready_cache_len - else: - infer_state.b_ready_cache_len = torch.zeros_like(b_seq_len, dtype=b_seq_len.dtype, device=b_seq_len.device) - infer_state.multimodal_params = multimodal_params + 
infer_state.batch_size = model_input.batch_size + infer_state.total_token_num = model_input.total_token_num + infer_state.max_len_in_batch = model_input.max_len_in_batch + assert model_input.b_req_idx.shape[0] == model_input.b_seq_len.shape[0] + infer_state.b_req_idx = model_input.b_req_idx + infer_state.b_seq_len = model_input.b_seq_len + if model_input.is_prefill: + if model_input.b_ready_cache_len is not None: + infer_state.b_ready_cache_len = model_input.b_ready_cache_len + else: + infer_state.b_ready_cache_len = torch.zeros_like(input=infer_state.b_seq_len) + + infer_state.multimodal_params = model_input.multimodal_params infer_state.mem_manager = self.mem_manager infer_state.req_manager = self.req_manager - infer_state.mem_index = mem_indexes + infer_state.mem_index = model_input.mem_indexes infer_state.kv_buffer_shapedtype = ( - (input_ids.shape[0], self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_), + (model_input.input_ids.shape[0], self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_), self.data_type, ) - infer_state.dist_group = dist_group_manager.get_default_group() + infer_state.dist_group = dist_group_manager.get_group(batch_index) - init_req_to_token_indexes( - self.req_manager.req_to_token_indexs, - b_req_idx, - b_seq_len, - infer_state.b_ready_cache_len, - max_len_in_batch, - infer_state.mem_index, - ) + # 特殊模型,特殊模式的特定变量初始化操作。 + infer_state.deepseekv3_mtp_draft_input_hiddens = model_input.deepseekv3_mtp_draft_input_hiddens - infer_state.init_some_extra_state(self, input_ids) - predict_logics = self._context_forward(input_ids, infer_state) - return predict_logics + return infer_state - def _decode( - self, - batch_size, - total_token_num, - max_len_in_batch, - input_ids, - mem_indexes, - b_req_idx, - b_seq_len, - multimodal_params, - ): - infer_state = self.infer_state_class() - infer_state.is_prefill = False - infer_state.batch_size = batch_size - infer_state.total_token_num = total_token_num - infer_state.max_len_in_batch = max_len_in_batch - infer_state.use_dynamic_prompt_cache = self.use_dynamic_prompt_cache - assert b_req_idx.shape[0] == b_seq_len.shape[0] - infer_state.b_req_idx = b_req_idx - infer_state.b_seq_len = b_seq_len - infer_state.multimodal_params = multimodal_params + def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_size: int): + if model_input.batch_size == new_batch_size: + return model_input - infer_state.mem_manager = self.mem_manager - infer_state.req_manager = self.req_manager + assert model_input.batch_size <= new_batch_size - infer_state.mem_index = mem_indexes - infer_state.kv_buffer_shapedtype = ( - (batch_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_), - self.data_type, + padded_batch_size = new_batch_size - model_input.batch_size + new_model_input = copy.copy(model_input) + new_model_input.batch_size = new_batch_size + new_model_input.total_token_num += padded_batch_size * 2 + new_model_input.input_ids = F.pad(new_model_input.input_ids, (0, padded_batch_size), mode="constant", value=1) + new_model_input.b_req_idx = F.pad( + new_model_input.b_req_idx, (0, padded_batch_size), mode="constant", value=self.req_manager.HOLD_REQUEST_ID + ) + new_model_input.b_seq_len = F.pad(new_model_input.b_seq_len, (0, padded_batch_size), mode="constant", value=2) + new_model_input.mem_indexes = F.pad( + new_model_input.mem_indexes, + (0, padded_batch_size), + mode="constant", + value=self.mem_manager.HOLD_TOKEN_MEMINDEX, ) - infer_state.dist_group = dist_group_manager.get_default_group() - 
copy_kv_index_to_req(self.req_manager.req_to_token_indexs, b_req_idx, b_seq_len, infer_state.mem_index) - - infer_state.init_some_extra_state(self, input_ids) - if self.graph is not None and self.graph.can_run(batch_size, max_len_in_batch): - if self.graph.need_capture(batch_size): - infer_state.is_cuda_graph = True - predict_logics = self.graph.capture_decode(self._token_forward, input_ids, infer_state) - else: - predict_logics = self.graph.replay(input_ids, infer_state) - else: - predict_logics = self._token_forward(input_ids, infer_state) - return predict_logics - @torch.no_grad() - def microbatch_overlap_decode(self, batch: DecodeMicroBatch, batch1: DecodeMicroBatch): - assert batch.batch_size == batch1.batch_size - assert batch.mem_indexes.is_cuda - assert batch1.mem_indexes.is_cuda - input_ids, input_ids1 = batch.input_ids, batch1.input_ids - - def create_inferstate(cur_batch: DecodeMicroBatch, batch_index): - infer_state = self.infer_state_class() - infer_state.is_prefill = False - infer_state.batch_size = cur_batch.batch_size - infer_state.total_token_num = cur_batch.total_token_num - infer_state.max_len_in_batch = cur_batch.max_len_in_batch - infer_state.use_dynamic_prompt_cache = self.use_dynamic_prompt_cache - assert cur_batch.b_req_idx.shape[0] == cur_batch.b_seq_len.shape[0] - infer_state.b_req_idx = cur_batch.b_req_idx - infer_state.b_seq_len = cur_batch.b_seq_len - infer_state.multimodal_params = None - infer_state.microbatch_index = batch_index - - infer_state.mem_manager = self.mem_manager - infer_state.req_manager = self.req_manager - - infer_state.mem_index = cur_batch.mem_indexes - infer_state.kv_buffer_shapedtype = ( - (cur_batch.batch_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_), - self.data_type, - ) - infer_state.dist_group = dist_group_manager.get_group(batch_index) - copy_kv_index_to_req( - self.req_manager.req_to_token_indexs, cur_batch.b_req_idx, cur_batch.b_seq_len, infer_state.mem_index + # 特殊模型,特殊模式的特殊变量的特殊 padding + if new_model_input.deepseekv3_mtp_draft_input_hiddens is not None: + new_model_input.deepseekv3_mtp_draft_input_hiddens = pad2dim_tensor_to_new_batch( + input=new_model_input.deepseekv3_mtp_draft_input_hiddens, + new_batch_size=new_batch_size, ) - return infer_state - infer_state = create_inferstate(batch, 0) - infer_state1 = create_inferstate(batch1, 1) + return new_model_input - infer_state.init_some_extra_state(self, input_ids) - infer_state1.init_some_extra_state(self, input_ids1) + def _create_unpad_decode_model_output(self, model_output: ModelOutput, origin_batch_size: int): + padded_batch_size = model_output.logits.shape[0] + if padded_batch_size == origin_batch_size: + return model_output + new_model_output = copy.copy(model_output) + new_model_output.logits = new_model_output.logits[0:origin_batch_size] - batch_size = batch.batch_size - max_len_in_batch = max(batch.max_len_in_batch, batch1.max_len_in_batch) + # 特殊模型,特殊模式的特殊变量的特殊 unpad + if new_model_output.deepseekv3_mtp_main_output_hiddens is not None: + _hidden_states = new_model_output.deepseekv3_mtp_main_output_hiddens + new_model_output.deepseekv3_mtp_main_output_hiddens = _hidden_states[0:origin_batch_size] - if self.graph is not None and self.graph.can_run(batch_size, max_len_in_batch): - if self.graph.need_capture(batch_size): - infer_state.is_cuda_graph = True - infer_state1.is_cuda_graph = True + return new_model_output - predict_logics, predict_logics1 = self.graph.capture_decode( - self._overlap_tpsp_token_forward, - input_ids, - infer_state, - 
input_ids1=input_ids1, - infer_state1=infer_state1, - ) - else: - predict_logics, predict_logics1 = self.graph.replay( - input_ids, infer_state, input_ids1=input_ids1, infer_state1=infer_state1 - ) - else: - predict_logics, predict_logics1 = self._overlap_tpsp_token_forward( - input_ids, infer_state, input_ids1=input_ids1, infer_state1=infer_state1 + def _prefill( + self, + model_input: ModelInput, + ): + infer_state = self._create_inferstate(model_input) + init_req_to_token_indexes( + self.req_manager.req_to_token_indexs, + model_input.b_req_idx, + model_input.b_seq_len, + infer_state.b_ready_cache_len, + model_input.max_len_in_batch, + infer_state.mem_index, + ) + + infer_state.init_some_extra_state(self, model_input.input_ids) + return self._context_forward(model_input.input_ids, infer_state) + + def _decode( + self, + model_input: ModelInput, + ) -> ModelOutput: + if self.graph is not None and self.graph.can_run(model_input.batch_size, model_input.max_len_in_batch): + find_graph_batch_size = self.graph.find_closest_graph_batch_size(model_input.batch_size) + padded_model_input = self._create_padded_decode_model_input(model_input, find_graph_batch_size) + infer_state = self._create_inferstate(padded_model_input) + copy_kv_index_to_req( + self.req_manager.req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_seq_len, + infer_state.mem_index, ) - return predict_logics, predict_logics1 + infer_state.init_some_extra_state(self, padded_model_input.input_ids) - @torch.no_grad() - def microbatch_overlap_prefill(self, batch: PrefillMicroBatch, batch1: PrefillMicroBatch): - assert batch.mem_indexes.is_cuda - assert batch1.mem_indexes.is_cuda - input_ids, input_ids1 = batch.input_ids, batch1.input_ids - - def create_inferstate(cur_batch: PrefillMicroBatch, batch_index): - infer_state = self.infer_state_class() - infer_state.is_prefill = True - infer_state.is_token_healing = self.is_token_healing - infer_state.return_all_prompt_logics = self.return_all_prompt_logics - infer_state.use_dynamic_prompt_cache = self.use_dynamic_prompt_cache - infer_state.batch_size = cur_batch.batch_size - infer_state.total_token_num = cur_batch.total_token_num - infer_state.max_len_in_batch = cur_batch.max_len_in_batch - assert cur_batch.b_req_idx.shape[0] == cur_batch.b_seq_len.shape[0] - infer_state.b_req_idx = cur_batch.b_req_idx - infer_state.b_seq_len = cur_batch.b_seq_len - if cur_batch.b_ready_cache_len is not None: - infer_state.b_ready_cache_len = cur_batch.b_ready_cache_len - else: - infer_state.b_ready_cache_len = torch.zeros_like( - cur_batch.b_seq_len, dtype=cur_batch.b_seq_len.dtype, device=cur_batch.b_seq_len.device + if self.graph.need_capture(find_graph_batch_size): + infer_state.is_cuda_graph = True + model_output: ModelOutput = self.graph.capture_decode( + self._token_forward, padded_model_input.input_ids, infer_state ) - infer_state.multimodal_params = cur_batch.multimodal_params - infer_state.microbatch_index = batch_index - - infer_state.mem_manager = self.mem_manager - infer_state.req_manager = self.req_manager + else: + model_output: ModelOutput = self.graph.replay(padded_model_input.input_ids, infer_state) - infer_state.mem_index = cur_batch.mem_indexes - infer_state.kv_buffer_shapedtype = ( - (cur_batch.input_ids.shape[0], self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_), - self.data_type, + model_output = self._create_unpad_decode_model_output( + model_output, origin_batch_size=model_input.batch_size ) - infer_state.dist_group = dist_group_manager.get_group(batch_index) - 
init_req_to_token_indexes( + else: + infer_state = self._create_inferstate(model_input) + copy_kv_index_to_req( self.req_manager.req_to_token_indexs, - cur_batch.b_req_idx, - cur_batch.b_seq_len, - infer_state.b_ready_cache_len, - cur_batch.max_len_in_batch, + infer_state.b_req_idx, + infer_state.b_seq_len, infer_state.mem_index, ) - return infer_state - - infer_state = create_inferstate(batch, 0) - infer_state1 = create_inferstate(batch1, 1) - - infer_state.init_some_extra_state(self, input_ids) - infer_state1.init_some_extra_state(self, input_ids1) + infer_state.init_some_extra_state(self, model_input.input_ids) + model_output = self._token_forward(model_input.input_ids, infer_state) - predict_logics, predict_logics1 = self._overlap_tpsp_context_forward( - input_ids, infer_state, input_ids1=input_ids1, infer_state1=infer_state1 - ) - dist_group_manager.clear_deepep_buffer() - return predict_logics, predict_logics1 + return model_output @final def _context_forward(self, input_ids, infer_state: InferStateInfo): @@ -499,10 +393,17 @@ def _context_forward(self, input_ids, infer_state: InferStateInfo): input_embs = layer_method(input_embs, infer_state, self.trans_layers_weight[i]) post_method = (self.post_infer.token_forward, self.post_infer.tpsp_token_forward)[run_mode_index] - predict_logics = post_method(input_embs, infer_state, self.pre_post_weight) + predict_logits = post_method(input_embs, infer_state, self.pre_post_weight) g_cache_manager.cache_env_out() - return predict_logics + + model_output = ModelOutput(logits=predict_logits) + + # 特殊模型特殊模式的额外输出 + if self.is_deepseekv3_mtp_mode: + model_output.deepseekv3_mtp_main_output_hiddens = input_embs + + return model_output @final def _token_forward(self, input_ids, infer_state: InferStateInfo): @@ -521,10 +422,160 @@ def _token_forward(self, input_ids, infer_state: InferStateInfo): input_embs = layer_method(input_embs, infer_state, self.trans_layers_weight[i]) post_method = (self.post_infer.token_forward, self.post_infer.tpsp_token_forward)[run_mode_index] - predict_logics = post_method(input_embs, infer_state, self.pre_post_weight) + predict_logits = post_method(input_embs, infer_state, self.pre_post_weight) + + if self.is_deepseekv3_mtp_mode: + graph_out_hiddens = g_cache_manager.alloc_tensor( + input_embs.shape, + data_type=input_embs.data_type, + is_graph_out=True, + microbatch_index=infer_state.microbatch_index, + graph_out_key=520, + ) + graph_out_hiddens.copy_(input_embs) + + g_cache_manager.cache_env_out() + + model_output = ModelOutput(logits=predict_logits) + + # 特殊模型特殊模式的额外输出 + if self.is_deepseekv3_mtp_mode: + model_output.deepseekv3_mtp_main_output_hiddens = graph_out_hiddens + return model_output + + @torch.no_grad() + def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: ModelInput): + assert model_input0.mem_indexes.is_cuda + assert model_input1.mem_indexes.is_cuda + input_ids0, input_ids1 = model_input0.input_ids, model_input1.input_ids + + infer_state0 = self._create_inferstate(model_input0, 0) + init_req_to_token_indexes( + self.req_manager.req_to_token_indexs, + model_input0.b_req_idx, + model_input0.b_seq_len, + infer_state0.b_ready_cache_len, + model_input0.max_len_in_batch, + infer_state0.mem_index, + ) + infer_state0.init_some_extra_state(self, input_ids0) + + infer_state1 = self._create_inferstate(model_input1, 1) + init_req_to_token_indexes( + self.req_manager.req_to_token_indexs, + model_input1.b_req_idx, + model_input1.b_seq_len, + infer_state1.b_ready_cache_len, + 
model_input1.max_len_in_batch, + infer_state1.mem_index, + ) + infer_state1.init_some_extra_state(self, input_ids1) + + model_output0, model_output1 = self._overlap_tpsp_context_forward( + input_ids0, infer_state0, input_ids1=input_ids1, infer_state1=infer_state1 + ) + dist_group_manager.clear_deepep_buffer() + return model_output0, model_output1 + + @torch.no_grad() + def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: ModelInput): + assert model_input0.batch_size == model_input1.batch_size + assert model_input0.mem_indexes.is_cuda + assert model_input1.mem_indexes.is_cuda + + origin_batch_size = model_input0.batch_size + max_len_in_batch = max(model_input0.max_len_in_batch, model_input1.max_len_in_batch) + + if self.graph is not None and self.graph.can_run(origin_batch_size, max_len_in_batch): + find_graph_batch_size = self.graph.find_closest_graph_batch_size(origin_batch_size) + padded_model_input0 = self._create_padded_decode_model_input(model_input0, find_graph_batch_size) + padded_model_input1 = self._create_padded_decode_model_input(model_input1, find_graph_batch_size) + infer_state0 = self._create_inferstate(padded_model_input0, 0) + copy_kv_index_to_req( + self.req_manager.req_to_token_indexs, + infer_state0.b_req_idx, + infer_state0.b_seq_len, + infer_state0.mem_index, + ) + infer_state0.init_some_extra_state(self, padded_model_input0.input_ids) + infer_state1 = self._create_inferstate(padded_model_input1, 1) + copy_kv_index_to_req( + self.req_manager.req_to_token_indexs, + infer_state1.b_req_idx, + infer_state1.b_seq_len, + infer_state1.mem_index, + ) + infer_state1.init_some_extra_state(self, padded_model_input1.input_ids) + + if self.graph.need_capture(find_graph_batch_size): + infer_state0.is_cuda_graph = True + infer_state1.is_cuda_graph = True + + model_output0, model_output1 = self.graph.capture_decode( + self._overlap_tpsp_token_forward, + padded_model_input0.input_ids, + infer_state0, + input_ids1=padded_model_input1.input_ids, + infer_state1=infer_state1, + ) + else: + model_output0, model_output1 = self.graph.replay( + padded_model_input0.input_ids, + infer_state0, + input_ids1=padded_model_input1.input_ids, + infer_state1=infer_state1, + ) + model_output0 = self._create_unpad_decode_model_output(model_output0, origin_batch_size=origin_batch_size) + model_output1 = self._create_unpad_decode_model_output(model_output1, origin_batch_size=origin_batch_size) + else: + infer_state0 = self._create_inferstate(model_input0, 0) + copy_kv_index_to_req( + self.req_manager.req_to_token_indexs, + infer_state0.b_req_idx, + infer_state0.b_seq_len, + infer_state0.mem_index, + ) + infer_state0.init_some_extra_state(self, model_input0.input_ids) + infer_state1 = self._create_inferstate(model_input1, 1) + copy_kv_index_to_req( + self.req_manager.req_to_token_indexs, + infer_state1.b_req_idx, + infer_state1.b_seq_len, + infer_state1.mem_index, + ) + infer_state1.init_some_extra_state(self, model_input1.input_ids) + + model_output0, model_output1 = self._overlap_tpsp_token_forward( + model_input0.input_ids, infer_state0, input_ids1=model_input1.input_ids, infer_state1=infer_state1 + ) + return model_output0, model_output1 + + @final + def _overlap_tpsp_context_forward( + self, input_ids, infer_state: InferStateInfo, input_ids1, infer_state1: InferStateInfo + ): + g_cache_manager.cache_env_in() + input_embs, input_embs1 = self.pre_infer.overlap_tpsp_context_forward( + input_ids, input_ids1, infer_state, infer_state1, self.pre_post_weight + ) + for i in 
range(self.layers_num): + input_embs, input_embs1 = self.layers_infer[i].overlap_tpsp_context_forward( + input_embs, input_embs1, infer_state, infer_state1, self.trans_layers_weight[i] + ) + predict_logits, predict_logits1 = self.post_infer.overlap_tpsp_token_forward( + input_embs, input_embs1, infer_state, infer_state1, self.pre_post_weight + ) g_cache_manager.cache_env_out() - return predict_logics + + model_output = ModelOutput(logits=predict_logits) + model_output1 = ModelOutput(logits=predict_logits1) + + if self.is_deepseekv3_mtp_mode: + model_output.deepseekv3_mtp_main_output_hiddens = input_embs + model_output1.deepseekv3_mtp_main_output_hiddens = input_embs1 + + return model_output, model_output1 @final def _overlap_tpsp_token_forward( @@ -544,30 +595,38 @@ def _overlap_tpsp_token_forward( input_embs, input_embs1, infer_state, infer_state1, self.trans_layers_weight[i] ) - predict_logics, predict_logics1 = self.post_infer.overlap_tpsp_token_forward( + predict_logits, predict_logits1 = self.post_infer.overlap_tpsp_token_forward( input_embs, input_embs1, infer_state, infer_state1, self.pre_post_weight ) - g_cache_manager.cache_env_out() - return predict_logics, predict_logics1 - - @final - def _overlap_tpsp_context_forward( - self, input_ids, infer_state: InferStateInfo, input_ids1, infer_state1: InferStateInfo - ): - g_cache_manager.cache_env_in() - input_embs, input_embs1 = self.pre_infer.overlap_tpsp_context_forward( - input_ids, input_ids1, infer_state, infer_state1, self.pre_post_weight - ) - for i in range(self.layers_num): - input_embs, input_embs1 = self.layers_infer[i].overlap_tpsp_context_forward( - input_embs, input_embs1, infer_state, infer_state1, self.trans_layers_weight[i] + if self.is_deepseekv3_mtp_mode: + graph_out_hiddens = g_cache_manager.alloc_tensor( + input_embs.shape, + data_type=input_embs.data_type, + is_graph_out=True, + microbatch_index=0, + graph_out_key=520, ) - predict_logics, predict_logics1 = self.post_infer.overlap_tpsp_token_forward( - input_embs, input_embs1, infer_state, infer_state1, self.pre_post_weight - ) + graph_out_hiddens.copy_(input_embs) + graph_out_hiddens1 = g_cache_manager.alloc_tensor( + input_embs1.shape, + data_type=input_embs1.data_type, + is_graph_out=True, + microbatch_index=1, + graph_out_key=520, + ) + graph_out_hiddens1.copy_(input_embs1) + g_cache_manager.cache_env_out() - return predict_logics, predict_logics1 + + model_output = ModelOutput(logits=predict_logits) + model_output1 = ModelOutput(logits=predict_logits1) + + if self.is_deepseekv3_mtp_mode: + model_output.deepseekv3_mtp_main_output_hiddens = graph_out_hiddens + model_output1.deepseekv3_mtp_main_output_hiddens = graph_out_hiddens1 + + return model_output, model_output1 @final @torch.no_grad() @@ -590,20 +649,22 @@ def _check_max_len_infer(self): b_seq_len[:] = self.batch_max_tokens b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda") total_token_num = self.batch_max_tokens - logics = self.forward( - 1, - total_token_num, - self.batch_max_tokens, - dummy_input_ids, - mem_indexes, - b_req_idx, - b_seq_len, - b_ready_cache_len=b_ready_cache_len, + model_input = ModelInput( + batch_size=1, + total_token_num=total_token_num, + max_len_in_batch=self.batch_max_tokens, + input_ids=dummy_input_ids, + mem_indexes=mem_indexes, + b_req_idx=b_req_idx, + b_seq_len=b_seq_len, is_prefill=True, - multimodal_params=[], + b_ready_cache_len=b_ready_cache_len, + ) + model_output = self.forward( + model_input, ) - prob_out = torch.softmax(logics, dim=-1) - logics = None 
+            prob_out = torch.softmax(model_output.logits, dim=-1)
+            del model_output
             torch.argmax(prob_out, dim=1, keepdim=True)
             prob_out = None
             self.req_manager.free_all()
diff --git a/lightllm/common/basemodel/batch_objs.py b/lightllm/common/basemodel/batch_objs.py
new file mode 100644
index 000000000..d9d131056
--- /dev/null
+++ b/lightllm/common/basemodel/batch_objs.py
@@ -0,0 +1,39 @@
+import torch
+from dataclasses import dataclass, field
+from typing import Optional
+
+
+@dataclass
+class ModelInput:
+    # common fields
+    batch_size: int
+    total_token_num: int
+    max_len_in_batch: int
+    input_ids: torch.Tensor
+    mem_indexes: torch.Tensor
+    b_req_idx: torch.Tensor
+    b_seq_len: torch.Tensor
+    is_prefill: bool = False
+    b_ready_cache_len: torch.Tensor = None
+    multimodal_params: list = field(default_factory=list)
+
+    # Model/mode-specific fields: they carry extra inputs that only a few special
+    # models or special run modes actually use; everywhere else they stay unset.
+
+    # deepseekv3_mtp_draft_input_hiddens is the input of the draft model when the
+    # deepseekv3 model runs in mtp mode.
+    deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None
+
+
+@dataclass
+class ModelOutput:
+    # common fields
+    logits: torch.Tensor
+
+    # Model/mode-specific fields: they carry extra outputs that only a few special
+    # models or special run modes actually use; everywhere else they stay unset.
+
+    # In mtp mode, deepseekv3_mtp_main_output_hiddens holds the last-layer hidden
+    # states of the main LLM, which become the draft model's
+    # deepseekv3_mtp_draft_input_hiddens input.
+    deepseekv3_mtp_main_output_hiddens: Optional[torch.Tensor] = None
diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py
index 41cd59039..271c307fb 100644
--- a/lightllm/common/basemodel/cuda_graph.py
+++ b/lightllm/common/basemodel/cuda_graph.py
@@ -1,10 +1,14 @@
 import os
 import torch
 import copy
+import bisect
+from typing import Optional
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.distributed import dist_group_manager, lightllm_capture_graph, CustomProcessGroup
-from lightllm.common.basemodel.microbatch_overlap_objs import DecodeMicroBatch
+from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
+from .infer_struct import InferStateInfo
+
 
 logger = init_logger(__name__)
 
@@ -17,15 +21,45 @@ def __init__(self, max_batch_size=8, max_len_in_batch=8192):
         self.mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
         self.max_batch_size = max_batch_size
         self.graph_max_len_in_batch = max_len_in_batch
-        self.enable_decode_microbatch_overlap = get_env_start_args().enable_decode_microbatch_overlap
+        self.args = get_env_start_args()
+        self.enable_decode_microbatch_overlap = self.args.enable_decode_microbatch_overlap
+
+        # gen cuda graph batch_sizes
+        # cuda graph gen for batch size = [1, 2, 3, ..., graph_split_batch_size]
+        # and [graph_split_batch_size + graph_grow_step_size,
+        # graph_split_batch_size + 2 * graph_grow_step_size, ..., self.max_batch_size]
+        graph_split_batch_size = self.args.graph_split_batch_size
+        max_batch_size = self.max_batch_size
+        graph_grow_step_size = self.args.graph_grow_step_size
+
+        batch_sizes = [i for i in range(1, min(graph_split_batch_size, max_batch_size) + 1)]
+        for _batch_size in range(
+            graph_split_batch_size + graph_grow_step_size, max_batch_size + 1, graph_grow_step_size
+        ):
+            batch_sizes.append(_batch_size)
+        self.cuda_graph_batch_sizes = batch_sizes
+        assert batch_sizes[-1] == self.max_batch_size
+        logger.info(f"cuda graph batch_sizes: {self.cuda_graph_batch_sizes}")
 
     def can_run(self, batch_size, max_len_in_batch):
         return batch_size <= self.max_batch_size and max_len_in_batch <= self.graph_max_len_in_batch
 
     def
need_capture(self, batch_size): - return batch_size not in self.graph + find_batch_size = self.find_closest_graph_batch_size(batch_size) + if find_batch_size is not None: + return find_batch_size not in self.graph + else: + assert False, "dead code" - def _capture_decode(self, decode_func, input_ids, infer_state): + def find_closest_graph_batch_size(self, batch_size): + index = bisect.bisect_left(self.cuda_graph_batch_sizes, batch_size) + if index < len(self.cuda_graph_batch_sizes): + find_batch_size = self.cuda_graph_batch_sizes[index] + return find_batch_size + else: + return None + + def _capture_decode(self, decode_func, input_ids: torch.Tensor, infer_state: InferStateInfo): dist_group: CustomProcessGroup = infer_state.dist_group graph_obj = torch.cuda.CUDAGraph() batch_size = input_ids.shape[0] @@ -46,12 +80,19 @@ def _capture_decode(self, decode_func, input_ids, infer_state): with lightllm_capture_graph(dist_group): with torch.cuda.graph(graph_obj, pool=self.mempool): - predict_logics = decode_func(input_ids, infer_state) - self.graph[batch_size] = (graph_obj, input_ids, infer_state, predict_logics) + model_output = decode_func(input_ids, infer_state) + self.graph[batch_size] = (graph_obj, input_ids, infer_state, model_output) graph_obj.replay() - return predict_logics + return model_output - def _capture_decode_overlap(self, decode_func, input_ids, infer_state, input_ids1, infer_state1): + def _capture_decode_overlap( + self, + decode_func, + input_ids: torch.Tensor, + infer_state: InferStateInfo, + input_ids1: torch.Tensor, + infer_state1: InferStateInfo, + ): dist_group: CustomProcessGroup = infer_state.dist_group dist_group1 = infer_state1.dist_group graph_obj = torch.cuda.CUDAGraph() @@ -81,7 +122,14 @@ def _capture_decode_overlap(self, decode_func, input_ids, infer_state, input_ids graph_obj.replay() return predict_logics, predict_logics1 - def capture_decode(self, decode_func, input_ids, infer_state, input_ids1=None, infer_state1=None): + def capture_decode( + self, + decode_func, + input_ids: torch.Tensor, + infer_state: InferStateInfo, + input_ids1: Optional[torch.Tensor] = None, + infer_state1: Optional[torch.Tensor] = None, + ): """ Capture the cuda graph for the decoding stage. input_ids1 and infer_state1 is used for the overlap. 
@@ -92,15 +140,21 @@ def capture_decode(self, decode_func, input_ids, infer_state, input_ids1=None, i assert input_ids1 is None and infer_state1 is None return self._capture_decode(decode_func, input_ids, infer_state) - def _replay(self, input_ids, infer_state): + def _replay(self, input_ids: torch.Tensor, infer_state: InferStateInfo): batch_size = input_ids.shape[0] - graph_obj, graph_input_ids, graph_infer_state, graph_predict_logics = self.graph[batch_size] + graph_obj, graph_input_ids, graph_infer_state, graph_output = self.graph[batch_size] graph_input_ids.copy_(input_ids) graph_infer_state.copy_for_cuda_graph(infer_state) graph_obj.replay() - return graph_predict_logics + return graph_output - def _replay_overlap(self, input_ids, infer_state, input_ids1, infer_state1): + def _replay_overlap( + self, + input_ids: torch.Tensor, + infer_state: InferStateInfo, + input_ids1: torch.Tensor, + infer_state1: InferStateInfo, + ): batch_size = input_ids.shape[0] ( graph_obj, @@ -128,52 +182,44 @@ def replay(self, input_ids, infer_state, input_ids1=None, infer_state1=None): @torch.no_grad() def warmup(self, model): logger.info("Begin capture cudagraph, use the --disable_cudagraph to disable it.") - for batch_size in range(self.max_batch_size, 0, -1): - # dummy prefill - prefill_input_len = 1 - dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cuda") + # for typing easy + from .basemodel import TpPartBaseModel + + model: TpPartBaseModel = model + + # prefill init padding req. + predict_id = self._warmup_prefill(model) + # decode cuda graph init + for batch_size in self.cuda_graph_batch_sizes[::-1]: + seq_len = 2 + total_token_num = batch_size * seq_len + max_len_in_batch = self.graph_max_len_in_batch + input_ids = torch.tensor([predict_id for _ in range(batch_size)], dtype=torch.int32, device="cuda") + mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda() b_req_idx = torch.tensor( - [model.req_manager.alloc() for _ in range(batch_size)], dtype=torch.int32, device="cuda" - ) - mem_indexes = model.mem_manager.alloc(len(dummy_input_ids)).cuda() - b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda") - b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda") - total_token_num = prefill_input_len * batch_size - logics = model.forward( - batch_size, - total_token_num, - prefill_input_len, - dummy_input_ids, - mem_indexes, - b_req_idx, - b_seq_len, - b_ready_cache_len=b_ready_cache_len, - is_prefill=True, - multimodal_params=[], + [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda" ) - mem_indexes = None - prob_out = torch.softmax(logics, dim=-1) - logics = None - predict_ids = torch.argmax(prob_out, dim=1, keepdim=True) - prob_out = None - predict_ids = predict_ids.detach().cpu().numpy() - torch.cuda.empty_cache() + b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda") + b_seq_len.fill_(seq_len) - # dummy decoding, capture the cudagraph - total_token_num += batch_size - b_seq_len += 1 - mem_indexes = model.mem_manager.alloc(len(predict_ids)).cuda() - logics = model.forward( - batch_size, - total_token_num, - prefill_input_len + 1, - torch.from_numpy(predict_ids).cuda().reshape(-1), - mem_indexes, - b_req_idx, - b_seq_len, + model_input = ModelInput( + batch_size=batch_size, + total_token_num=total_token_num, + max_len_in_batch=max_len_in_batch, + input_ids=input_ids, + mem_indexes=mem_indexes, + b_req_idx=b_req_idx, + b_seq_len=b_seq_len, is_prefill=False, + 
**self._gen_special_model_input(model, batch_size), ) - mem_indexes = None + model_output: ModelOutput = model.forward(model_input) + del model_output + del input_ids + del mem_indexes + del b_req_idx + del b_seq_len + model.mem_manager.free_all() model.req_manager.free_all() # release local tensors @@ -181,6 +227,7 @@ def warmup(self, model): if isinstance(var_value, torch.Tensor): del locals()[var_name] torch.cuda.empty_cache() + logger.info( f"Capture cudagraph success, batch_size <={self.max_batch_size} " f"and max_len_in_batch <= {self.graph_max_len_in_batch} will infer with cudagraph." @@ -189,52 +236,37 @@ def warmup(self, model): @torch.no_grad() def warmup_overlap(self, model): logger.info("Begin capture overlap cudagraph, use the --disable_cudagraph to disable it.") - for batch_size in range(self.max_batch_size, 0, -1): + # for typing easy + from .basemodel import TpPartBaseModel + + model: TpPartBaseModel = model + + predict_id = self._warmup_prefill(model) + + for batch_size in self.cuda_graph_batch_sizes[::-1]: decode_batches = [] for micro_batch_index in [0, 1]: - # dummy prefill - prefill_input_len = 1 - dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cuda") + # dummy decoding, capture the cudagraph + seq_len = 2 + total_token_num = batch_size * seq_len + max_len_in_batch = self.graph_max_len_in_batch + input_ids = torch.tensor([predict_id for _ in range(batch_size)], dtype=torch.int32, device="cuda") + mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda() b_req_idx = torch.tensor( - [model.req_manager.alloc() for _ in range(batch_size)], dtype=torch.int32, device="cuda" - ) - mem_indexes = model.mem_manager.alloc(len(dummy_input_ids)).cuda() - b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda") - b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda") - total_token_num = prefill_input_len * batch_size - logics = model.forward( - batch_size, - total_token_num, - prefill_input_len, - dummy_input_ids, - mem_indexes, - b_req_idx, - b_seq_len, - b_ready_cache_len=b_ready_cache_len, - is_prefill=True, - multimodal_params=[], + [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda" ) - mem_indexes = None - prob_out = torch.softmax(logics, dim=-1) - logics = None - predict_ids = torch.argmax(prob_out, dim=1, keepdim=True) - prob_out = None - predict_ids = predict_ids.detach().cpu().numpy() - torch.cuda.empty_cache() - - # dummy decoding, capture the cudagraph - total_token_num += batch_size - b_seq_len += 1 - mem_indexes = model.mem_manager.alloc(len(predict_ids)).cuda() + b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda") + b_seq_len.fill_(seq_len) - micro_batch = DecodeMicroBatch( + micro_batch = ModelInput( batch_size=batch_size, total_token_num=total_token_num, - max_len_in_batch=prefill_input_len + 1, - input_ids=torch.from_numpy(predict_ids).cuda().reshape(-1), + max_len_in_batch=max_len_in_batch, + input_ids=input_ids, mem_indexes=mem_indexes, b_req_idx=b_req_idx, b_seq_len=b_seq_len, + **self._gen_special_model_input(model, batch_size), ) decode_batches.append(micro_batch) @@ -242,6 +274,7 @@ def warmup_overlap(self, model): if isinstance(var_value, torch.Tensor): del locals()[var_name] torch.cuda.empty_cache() + _, _ = model.microbatch_overlap_decode(decode_batches[0], decode_batches[1]) model.mem_manager.free_all() @@ -257,3 +290,63 @@ def warmup_overlap(self, model): f"Capture overlap cudagraph success, batch_size <={self.max_batch_size} " 
f"and max_len_in_batch <= {self.graph_max_len_in_batch} will infer with cudagraph." ) + + def _warmup_prefill(self, model) -> int: + from .basemodel import TpPartBaseModel + + model: TpPartBaseModel = model + + # prefill init padding req. + prefill_input_len = 1 + batch_size = 1 + dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cuda") + b_req_idx = torch.tensor( + [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda" + ) + mem_indexes = torch.tensor( + [model.mem_manager.HOLD_TOKEN_MEMINDEX for _ in range(batch_size)], dtype=torch.int32, device="cuda" + ) + b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda") + b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + total_token_num = prefill_input_len * batch_size + model_input = ModelInput( + batch_size=batch_size, + total_token_num=total_token_num, + max_len_in_batch=prefill_input_len, + input_ids=dummy_input_ids, + mem_indexes=mem_indexes, + b_req_idx=b_req_idx, + b_seq_len=b_seq_len, + b_ready_cache_len=b_ready_cache_len, + is_prefill=True, + multimodal_params=[], + **self._gen_special_model_input(model, total_token_num), + ) + + model_output: ModelOutput = model.forward(model_input) + del dummy_input_ids + del b_req_idx + del mem_indexes + del b_seq_len + del b_ready_cache_len + prob_out = torch.softmax(model_output.logits, dim=-1) + del model_output + predict_ids = torch.argmax(prob_out, dim=1, keepdim=True) + del prob_out + predict_ids = predict_ids.detach().cpu().numpy() + predict_id = int(predict_ids[0][0]) + torch.cuda.empty_cache() + return predict_id + + def _gen_special_model_input(self, model, batch_size): + special_model_input = {} + + is_deepseekv3_mtp_draft_model = "Deepseek3MTPModel" in str(model.__class__) + if is_deepseekv3_mtp_draft_model: + special_model_input["deepseekv3_mtp_draft_input_hiddens"] = torch.randn( + batch_size, model.config["hidden_size"], dtype=model.data_type, device="cuda" + ) + else: + special_model_input["deepseekv3_mtp_draft_input_hiddens"] = None + + return special_model_input diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py index a946219fb..021de6843 100755 --- a/lightllm/common/basemodel/infer_struct.py +++ b/lightllm/common/basemodel/infer_struct.py @@ -2,7 +2,7 @@ from lightllm.common.mem_manager import MemoryManager from lightllm.common.req_manager import ReqManager from lightllm.distributed import CustomProcessGroup -from typing import Tuple, Any +from typing import Tuple, Any, Optional from .triton_kernel.gen_prefill_params import gen_prefill_params from .triton_kernel.gen_decode_params import gen_decode_params @@ -54,6 +54,15 @@ def __init__(self): self.max_q_seq_len: int = None self.max_kv_seq_len: int = None + # 一些特殊模型,特殊模式使用的输入变量,本身这些变量不适合放在 + # inferstate的基类中,但是为了代码的简洁和方便,都放在基类中 + # 进行管理。注意这些成员变量只会在特定的模型和模式下才会生效。 + + # deepseekv3 mtp draft model 使用的额外输入参数, + # 在开启 mtp_mode == deepseekv3 时,mtp draft model + # 的输入会用到,其他模型和场景都不会用到 + self.deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None + def init_some_extra_state(self, model, input_ids: torch.Tensor): if self.is_prefill: ( @@ -82,7 +91,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): ) = gen_decode_params(b_seq_len=self.b_seq_len) self.b_start_loc = self.b1_cu_kv_seq_len[0:-1] - def copy_for_cuda_graph(self, new_infer_state): + def copy_for_cuda_graph(self, new_infer_state: "InferStateInfo"): for attr_name, attr_value in 
vars(new_infer_state).items(): if isinstance(attr_value, torch.Tensor): attr_ = getattr(self, attr_name, None) diff --git a/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py b/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py index 570654b55..0cf0c7ae9 100644 --- a/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py +++ b/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py @@ -44,7 +44,8 @@ def __del__(self): class CudaGraphCacheTensorManager: def __init__(self, cuda_graph_max_batch_size: int): self.cuda_graph_max_batch_size = cuda_graph_max_batch_size - self.graph_out_tensor_dict: Dict[int, torch.Tensor] = {} + # Dict[graph_out_key, Dict[microbatch_index, tensor_chache]] + self.graph_out_tensor_dict: Dict[int, Dict[int, torch.Tensor]] = collections.defaultdict(dict) self.managed_total_tensor_bytes = 0 return @@ -56,6 +57,7 @@ def alloc_tensor_for_cuda_graph( device: str = "cuda", is_graph_out: bool = False, microbatch_index: int = 0, + graph_out_key: int = 0, ) -> torch.Tensor: assert microbatch_index in [0, 1] if not is_graph_out: @@ -66,13 +68,16 @@ def alloc_tensor_for_cuda_graph( max_size = size // cur_batch_size * self.cuda_graph_max_batch_size # graph out tensor, 只有一个, 不需要进行引用计数管理 - if microbatch_index not in self.graph_out_tensor_dict: + microbatch_index_to_tensor_cache = self.graph_out_tensor_dict[graph_out_key] + + if microbatch_index not in microbatch_index_to_tensor_cache: graph_out_tensor = torch.empty((max_size,), dtype=data_type, device=device, requires_grad=False) logger.info(f"pid {os.getpid()} cuda graph alloc graph out mem {shape} {data_type} {size} {max_size}") self.managed_total_tensor_bytes += graph_out_tensor.element_size() * graph_out_tensor.numel() logger.info(f"cuda graph managed_total_tensor_bytes: {self.managed_total_tensor_bytes}") - self.graph_out_tensor_dict[microbatch_index] = graph_out_tensor - return self.graph_out_tensor_dict[microbatch_index][0:size].view(shape) + microbatch_index_to_tensor_cache[microbatch_index] = graph_out_tensor + + return self.graph_out_tensor_dict[graph_out_key][microbatch_index][0:size].view(shape) class CacheTensorManager: def __init__(self): @@ -119,6 +124,7 @@ def alloc_tensor( device: str = "cuda", is_graph_out: bool = False, microbatch_index: int = 0, + graph_out_key: int = 0, ) -> torch.Tensor: # shape 类型转换 if isinstance(shape, list): @@ -126,7 +132,13 @@ def alloc_tensor( # 是 cuda graph的时候,由cuda graph manager 接管 if self.is_cuda_graph: return self.inner_cuda_graph_manager.alloc_tensor_for_cuda_graph( - self.cuda_graph_cur_batch_size, shape, data_type, device, is_graph_out, microbatch_index + self.cuda_graph_cur_batch_size, + shape, + data_type, + device, + is_graph_out, + microbatch_index, + graph_out_key, ) # 回收可能消亡的 tensor @@ -191,6 +203,7 @@ def alloc_tensor( device: str = "cuda", is_graph_out: bool = False, microbatch_index: int = 0, + graph_out_key: int = 0, ) -> torch.Tensor: return torch.empty(shape, dtype=data_type, device=device, requires_grad=False) diff --git a/lightllm/common/basemodel/microbatch_overlap_objs.py b/lightllm/common/basemodel/microbatch_overlap_objs.py deleted file mode 100644 index 35835cd34..000000000 --- a/lightllm/common/basemodel/microbatch_overlap_objs.py +++ /dev/null @@ -1,26 +0,0 @@ -import torch -from dataclasses import dataclass - - -@dataclass -class DecodeMicroBatch: - batch_size: int - total_token_num: int - max_len_in_batch: int - input_ids: torch.Tensor - mem_indexes: torch.Tensor - b_req_idx: torch.Tensor - b_seq_len: torch.Tensor - - 
-@dataclass -class PrefillMicroBatch: - batch_size: int - total_token_num: int - max_len_in_batch: int - input_ids: torch.Tensor - mem_indexes: torch.Tensor - b_req_idx: torch.Tensor - b_seq_len: torch.Tensor - b_ready_cache_len: torch.Tensor - multimodal_params: list diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 9a61e5294..643ccd894 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -175,7 +175,7 @@ def update_reqs_token_counter( if not self.enable_gpu_buffer_for_out_token_id_counter: for req_obj, next_token_id in zip(req_objs, next_token_ids): - if req_obj.need_out_token_id_statistics: + if req_obj.need_out_token_id_statistics and req_obj.cur_output_len > 0: req_obj.out_token_id_count[next_token_id] += 1 else: b_req_idx = torch.tensor( diff --git a/lightllm/models/deepseek2/infer_struct.py b/lightllm/models/deepseek2/infer_struct.py index 021f7a123..e065adc40 100644 --- a/lightllm/models/deepseek2/infer_struct.py +++ b/lightllm/models/deepseek2/infer_struct.py @@ -3,6 +3,7 @@ import numpy as np import torch.distributed as dist from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.common.spec_info import SpeculativeDecodeAlgorithm class Deepseek2InferStateInfo(LlamaInferStateInfo): diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index d91dd6ede..4e769d7ad 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -65,6 +65,7 @@ def __init__(self, kvargs): self.enable_flashinfer = ( get_env_start_args().enable_flashinfer_prefill or get_env_start_args().enable_flashinfer_decode ) + self.mtp_layer_num = get_env_start_args().spec_step super().__init__(kvargs) return @@ -102,7 +103,7 @@ def _init_mem_manager(self): dtype=self.data_type, head_num=1, head_dim=self.config["kv_lora_rank"] + self.config["qk_rope_head_dim"], - layer_num=self.config["num_hidden_layers"], + layer_num=self.config["num_hidden_layers"] + self.mtp_layer_num, mem_fraction=self.mem_fraction, ) return diff --git a/lightllm/models/deepseek_mtp/__init__.py b/lightllm/models/deepseek_mtp/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/deepseek_mtp/layer_infer/__init__.py b/lightllm/models/deepseek_mtp/layer_infer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py new file mode 100644 index 000000000..66a41da73 --- /dev/null +++ b/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py @@ -0,0 +1,60 @@ +import torch + +from lightllm.models.deepseek_mtp.layer_weights.pre_and_post_layer_weight import Deepseek3MTPPreAndPostLayerWeight +from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo +from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer +from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward + + +class Deepseek3MTPPreLayerInfer(LlamaPreLayerInfer): + """ """ + + def __init__(self, network_config, mode): + super().__init__(network_config, mode) + self.eps_ = network_config["rms_norm_eps"] + self.hidden_size = network_config["hidden_size"] + return + + def _mtp_context_forward( + self, input_embdings, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight + ): + tgt_embdings = infer_state.deepseekv3_mtp_draft_input_hiddens + assert input_embdings.shape[0] == tgt_embdings.shape[0] + 
+        rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, out=input_embdings)
+        rmsnorm_forward(tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, out=tgt_embdings)
+
+        cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1)
+
+        ans_logics = self.alloc_tensor(
+            (cat_embdings.shape[0], layer_weight.eh_proj_weight_.shape[1]), dtype=input_embdings.dtype
+        )
+        torch.mm(cat_embdings, layer_weight.eh_proj_weight_, out=ans_logics)
+        return ans_logics
+
+    def _mtp_token_forward(
+        self, input_embdings, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight
+    ):
+        tgt_embdings = infer_state.deepseekv3_mtp_draft_input_hiddens
+        assert input_embdings.shape[0] == tgt_embdings.shape[0]
+        rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, out=input_embdings)
+        rmsnorm_forward(tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, out=tgt_embdings)
+
+        cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1)
+
+        ans_logics = self.alloc_tensor(
+            (cat_embdings.shape[0], layer_weight.eh_proj_weight_.shape[1]), dtype=input_embdings.dtype
+        )
+        torch.mm(cat_embdings, layer_weight.eh_proj_weight_, out=ans_logics)
+        return ans_logics
+
+    def context_forward(
+        self, input_ids, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight
+    ):
+        input_embdings = super().context_forward(input_ids, infer_state, layer_weight)
+        return self._mtp_context_forward(input_embdings, infer_state, layer_weight)
+
+    def token_forward(
+        self, input_ids, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight
+    ):
+        input_embdings = super().token_forward(input_ids, infer_state, layer_weight)
+        return self._mtp_token_forward(input_embdings, infer_state, layer_weight)
diff --git a/lightllm/models/deepseek_mtp/layer_weights/__init__.py b/lightllm/models/deepseek_mtp/layer_weights/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py
new file mode 100644
index 000000000..f5b805647
--- /dev/null
+++ b/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py
@@ -0,0 +1,29 @@
+import numpy as np
+from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight
+
+
+class Deepseek3MTPPreAndPostLayerWeight(LlamaPreAndPostLayerWeight):
+    def __init__(self, data_type, network_config, mode):
+        super().__init__(data_type, network_config, mode)
+        # shared with the main DeepseekV3 model
+        self.wte_weight_ = None
+        self.lm_head_weight_ = None
+        return
+
+    def load_hf_weights(self, weights):
+        if "model.layers.0.eh_proj.weight" in weights:
+            self.eh_proj_weight_ = self._cuda(weights["model.layers.0.eh_proj.weight"]).t()
+        if "model.layers.0.enorm.weight" in weights:
+            self.enorm_weight_ = self._cuda(weights["model.layers.0.enorm.weight"])
+        if "model.layers.0.hnorm.weight" in weights:
+            self.hnorm_weight_ = self._cuda(weights["model.layers.0.hnorm.weight"])
+        if "model.layers.0.shared_head.norm.weight" in weights:
+            self.final_norm_weight_ = self._cuda(weights["model.layers.0.shared_head.norm.weight"])
+        return
+
+    def verify_load(self):
+        errors = "weights load not ok"
+        weights = [self.eh_proj_weight_, self.enorm_weight_, self.hnorm_weight_, self.final_norm_weight_]
+        for i in range(len(weights)):
+            assert weights[i] is not None, "index:" + str(i) + " " + errors
+        return
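The MTP pre-layer above fuses the embedding of the token being drafted with the main model's last hidden state: both tensors are RMS-normalized (enorm / hnorm), concatenated, and projected back to hidden_size through eh_proj. A minimal stand-alone sketch of that fusion, using toy tensors in place of the real Deepseek3MTPPreAndPostLayerWeight weights and a plain PyTorch RMSNorm in place of the Triton rmsnorm_forward kernel (rms_norm and mtp_fuse below are illustrative names, not part of the repo):

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # plain RMSNorm with a learned scale, matching the semantics of rmsnorm_forward
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight

def mtp_fuse(token_emb, prev_hidden, enorm_w, hnorm_w, eh_proj_w):
    # eh_proj_w is kept transposed ([2 * hidden, hidden]), as in load_hf_weights above
    fused = torch.cat((rms_norm(token_emb, enorm_w), rms_norm(prev_hidden, hnorm_w)), dim=-1)
    return fused @ eh_proj_w

hidden = 8
token_emb = torch.randn(4, hidden)     # embeddings of the tokens fed to the draft model
main_hiddens = torch.randn(4, hidden)  # deepseekv3_mtp_main_output_hiddens from the main model
out = mtp_fuse(token_emb, main_hiddens, torch.ones(hidden), torch.ones(hidden), torch.randn(2 * hidden, hidden))
print(out.shape)  # torch.Size([4, 8])

The fused tensor then flows through the draft model's reused DeepseekV3 transformer layers.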
diff --git a/lightllm/models/deepseek_mtp/model.py b/lightllm/models/deepseek_mtp/model.py
new file mode 100644
index 000000000..2e2e95187
--- /dev/null
+++ b/lightllm/models/deepseek_mtp/model.py
@@ -0,0 +1,46 @@
+from lightllm.models.deepseek2.model import Deepseek2TpPartModel
+from lightllm.models.deepseek_mtp.layer_infer.pre_layer_infer import Deepseek3MTPPreLayerInfer
+from lightllm.models.deepseek_mtp.layer_weights.pre_and_post_layer_weight import Deepseek3MTPPreAndPostLayerWeight
+from lightllm.common.basemodel import TpPartBaseModel
+
+
+class Deepseek3MTPModel(Deepseek2TpPartModel):
+
+    pre_and_post_weight_class = Deepseek3MTPPreAndPostLayerWeight
+    pre_layer_infer_class = Deepseek3MTPPreLayerInfer
+
+    def __init__(self, kvargs: dict):
+        self._pre_init(kvargs)
+        super().__init__(kvargs)
+        return
+
+    def _pre_init(self, kvargs: dict):
+        self.main_model: TpPartBaseModel = kvargs.pop("main_model")
+        self.mem_layer_start = kvargs.pop("mem_layer_start", 0)
+        return
+
+    def _init_custom(self):
+        self._cos_cached = self.main_model._cos_cached
+        self._sin_cached = self.main_model._sin_cached
+        return
+
+    def _init_req_manager(self):
+        self.req_manager = self.main_model.req_manager
+        return
+
+    def _init_mem_manager(self):
+        self.mem_manager = self.main_model.mem_manager
+        return
+
+    def _init_weights(self):
+        super()._init_weights()
+        self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_
+        self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_
+        return
+
+    def _init_infer_layer(self):
+        super()._init_infer_layer()
+        # shift the layer_num_ of self.layers_infer by mem_layer_start
+        for layer in self.layers_infer:
+            layer.layer_num_ = layer.layer_num_ + self.mem_layer_start
+        return
diff --git a/lightllm/models/llama/layer_infer/post_layer_infer.py b/lightllm/models/llama/layer_infer/post_layer_infer.py
index 202df6969..237ddde4e 100644
--- a/lightllm/models/llama/layer_infer/post_layer_infer.py
+++ b/lightllm/models/llama/layer_infer/post_layer_infer.py
@@ -65,8 +65,8 @@ def _slice_get_last_input(self, input_embdings, infer_state: LlamaInferStateInfo
         assert False, "Error State"
 
     def token_forward(self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight):
-        last_input, token_num = self._slice_get_last_input(input_embdings, infer_state)
         input_embdings_dtype = input_embdings.dtype
+        last_input, token_num = self._slice_get_last_input(input_embdings, infer_state)
         input_embdings = None
         last_input = self._norm(last_input, infer_state, layer_weight)
         last_input = last_input.permute(1, 0).view(-1, token_num)
diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index 528361e69..7f2da75b1 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -330,10 +330,31 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--graph_max_batch_size",
         type=int,
-        default=16,
+        default=64,
         help="""Maximum batch size that can be captured by the cuda graph for the decoding stage.
                 The default value is 64. It will fall back to eager mode if it encounters a larger value.""",
     )
+    parser.add_argument(
+        "--graph_split_batch_size",
+        type=int,
+        default=32,
+        help="""
+        Controls the interval for generating CUDA graphs during decoding.
+        CUDA graphs will be generated continuously for values ranging from 1 up to the specified
+        graph_split_batch_size. For values from graph_split_batch_size to graph_max_batch_size,
+        a new CUDA graph will be generated for every increment of graph_grow_step_size.
+        Properly configuring this parameter can help optimize the performance of CUDA graph execution.
+        """,
+    )
+    parser.add_argument(
+        "--graph_grow_step_size",
+        type=int,
+        default=8,
+        help="""
+        For batch_size values from graph_split_batch_size to graph_max_batch_size,
+        a new CUDA graph will be generated for every increment of graph_grow_step_size.
+        """,
+    )
     parser.add_argument(
         "--graph_max_len_in_batch",
         type=int,
@@ -391,4 +412,26 @@ def make_argument_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="""Whether to update the redundant expert for deepseekv3 model by online expert used counter.""",
     )
+    parser.add_argument(
+        "--mtp_mode",
+        choices=["deepseekv3", None],
+        default=None,
+        help="""Supported mtp modes; None means mtp is disabled.""",
+    )
+    parser.add_argument(
+        "--mtp_draft_model_dir",
+        type=str,
+        default=None,
+        help="""Path to the draft model for the MTP multi-prediction feature,
+        used for loading the MTP multi-output token model.""",
+    )
+    parser.add_argument(
+        "--mtp_step",
+        type=int,
+        default=0,
+        help="""Specifies the number of additional tokens to predict using the draft model.
+        Currently, this feature supports only the DeepSeekV3 model.
+        Increasing this value allows for more predictions,
+        but ensure that the model is compatible with the specified step count.""",
+    )
     return parser
diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py
index 3ce8c9872..418e2715a 100644
--- a/lightllm/server/api_start.py
+++ b/lightllm/server/api_start.py
@@ -115,6 +115,11 @@ def normal_or_p_d_start(args):
     if args.diverse_mode:
         assert args.router_token_ratio == 0.0
 
+    # mtp params check
+    if args.mtp_mode is not None:
+        assert args.mtp_draft_model_dir is not None
+        assert args.mtp_step > 0
+
     # check that there are enough GPUs
     if args.visual_gpu_ids is None:
         args.visual_gpu_ids = list(range(args.visual_dp * args.visual_tp))
diff --git a/lightllm/server/core/objs/__init__.py b/lightllm/server/core/objs/__init__.py
index d4e69fd36..5594df6a0 100644
--- a/lightllm/server/core/objs/__init__.py
+++ b/lightllm/server/core/objs/__init__.py
@@ -2,3 +2,4 @@
 from .req import Req, FinishStatus
 from .shm_req_manager import ShmReqManager
 from .rpc_shm import RpcShmParams, RpcShmResults, ShmSyncStatusArray
+from .start_args_type import StartArgs
diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py
index c8d8476e5..b4a14923d 100644
--- a/lightllm/server/core/objs/req.py
+++ b/lightllm/server/core/objs/req.py
@@ -94,6 +94,8 @@ class Req(ctypes.Structure):
         ("reward_score", ctypes.c_float),
         # cumulative log-prob of the generated response
         ("cumlogprob", ctypes.c_float),
+        # number of extra mtp draft-model output tokens that were accepted
+        ("mtp_accepted_token_num", ctypes.c_int),
     ]
 
     def get_str(self):
@@ -145,6 +147,7 @@ def init(
         self.create_prompt_ids_shm_array()
         self.chunked_prefill_size = chunked_prefill_size
         self.shm_prompt_ids.arr[0 : len(prompt_ids)] = prompt_ids
+        self.mtp_accepted_token_num = 0
 
         self.post_init()
diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py
index e9a17d2d5..2098c9fc7 100644
--- a/lightllm/server/core/objs/start_args_type.py
+++ b/lightllm/server/core/objs/start_args_type.py
@@ -30,6 +30,8 @@ class StartArgs:
     running_max_req_size: int = field(default=1000)
     tp: int = field(default=1)
     dp: int = field(default=1)
+    nnodes: int = field(default=1)
+    node_rank: int = field(default=0)
     max_req_total_len: int = field(default=2048 + 1024)
     nccl_host: str = field(default="127.0.0.1")
     nccl_port: int = field(default=28765)
@@ -74,7 +76,9 @@ class StartArgs:
visual_nccl_ports: List[int] = field(default_factory=lambda: [29500]) enable_monitor_auth: bool = field(default=False) disable_cudagraph: bool = field(default=False) - graph_max_batch_size: int = field(default=16) + graph_max_batch_size: int = field(default=64) + graph_split_batch_size: int = field(default=32) + graph_grow_step_size: int = field(default=8) graph_max_len_in_batch: int = field(default=8192) quant_type: Optional[str] = field(default=None) quant_cfg: Optional[str] = field(default=None) @@ -85,3 +89,6 @@ class StartArgs: sampling_backend: str = field(default="triton", metadata={"choices": ["triton", "sglang_kernel"]}) ep_redundancy_expert_config_path: Optional[str] = field(default=None) auto_update_redundancy_expert: bool = field(default=False) + mtp_mode: Optional[str] = field(default=None) + mtp_draft_model_dir: Optional[str] = field(default=None) + mtp_step: int = field(default=0) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index d25b91864..ebd9ed94b 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -542,6 +542,10 @@ async def _wait_to_token_package( x_request_id = request.headers.get("X-Request-Id", "") if request is not None else "" x_session_id = request.headers.get("X-Session-Id", "") if request is not None else "" prompt_cache_ratio = prompt_cache_len / prompt_tokens + + mtp_avg_token_per_step = out_token_counter / ( + out_token_counter - metadata["mtp_accepted_token_num"] + ) format_start_time = datetime.datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S") logger.info( f"X-Request-Id:{x_request_id} " @@ -552,6 +556,7 @@ async def _wait_to_token_package( f"prompt_token_num:{prompt_tokens} " f"prompt_cache_len:{prompt_cache_len} " f"prompt_cache_ratio:{prompt_cache_ratio} " + f"mtp_avg_token_per_step:{mtp_avg_token_per_step} " ) if group_request_id < 0: # health 探测请求,不记录日志和监控 @@ -654,6 +659,7 @@ async def handle_loop(self): "special": special, "count_output_tokens": count_output_tokens, "prompt_cache_len": req.prompt_cache_len, + "mtp_accepted_token_num": req.mtp_accepted_token_num, } if self.args.return_all_prompt_logprobs: metadata.update(req.get_all_prompt_metadata()) diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 2811e4228..45e82ff3d 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -22,7 +22,7 @@ from .req_queue import build_req_queue from lightllm.utils.infer_utils import calculate_time from lightllm.server.core.objs.io_objs import GroupReqIndexes -from lightllm.server.core.objs import ShmReqManager +from lightllm.server.core.objs import ShmReqManager, StartArgs from .dynamic_prompt.radix_cache import RadixCacheReadOnlyClient from .stats import Stats from .pause_strategy import Fcfs, select_paused_reqs @@ -35,11 +35,12 @@ from lightllm.utils.process_check import start_parent_check_thread from lightllm.utils.envs_utils import get_unique_server_name + logger = init_logger(__name__) class RouterManager: - def __init__(self, args, router_port, detokenization_port, metric_port): + def __init__(self, args: StartArgs, router_port, detokenization_port, metric_port): self.args = args self.model_weightdir = args.model_dir self.world_size = args.tp @@ -62,6 +63,8 @@ def __init__(self, args, router_port, detokenization_port, metric_port): # 初始化 radix_cache_client 用于读取 prompt cache 的管理信息 self.radix_cache_client = None + self.mtp_step = args.mtp_step + # 共享变量,用于存储router端调度分析得到的机器负载信息 
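A quick worked example of the mtp_avg_token_per_step figure logged above: it is the ratio of all tokens returned to the client to the number of real decode iterations, so 1.0 means no draft token was ever accepted and mtp_step + 1 is the theoretical ceiling. The numbers below are made up:

out_token_counter = 120        # tokens streamed back for one request
mtp_accepted_token_num = 40    # draft tokens that passed verification
decode_steps = out_token_counter - mtp_accepted_token_num    # 80 main-model decode steps
mtp_avg_token_per_step = out_token_counter / decode_steps    # 120 / 80 = 1.5 tokens per step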
self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", self.dp_size_in_node) for dp_index in range(self.dp_size_in_node): @@ -383,8 +386,7 @@ async def _prefill_batch(self, batch: Batch): self.overlap_event.set() await self.model_rpc_client.prefill(reqs) batch.filter_out_finished_req(self.shm_req_manager) - # 发个None包触发一下detokenization - self.send_to_detokenization.send_pyobj(None, protocol=pickle.HIGHEST_PROTOCOL) + self._send_detokenization_pack() logger.debug(f"Prefill Batch: {batch.simple_log()} \n") self.metric_client.histogram_observe( @@ -400,8 +402,8 @@ async def _decode_batch(self, batch: Batch): # 在 self.is_multinode_and_multidp 为 True 时,传入的 batch 对象可能为 None。 if batch is not None: batch.filter_out_finished_req(self.shm_req_manager) - # 发个None包触发一下detokenization - self.send_to_detokenization.send_pyobj(None, protocol=pickle.HIGHEST_PROTOCOL) + + self._send_detokenization_pack() self.metric_client.histogram_observe( "lightllm_batch_inference_duration_bucket", time.time() - start_time, "decode" ) @@ -424,6 +426,14 @@ def _can_decode(self, batch: Batch, dp_index: int): batch.get_batch_decode_need_tokens()[dp_index] + self.get_used_tokens(dp_index) <= self.max_total_token_num ) + def _send_detokenization_pack(self): + # 发 mtp_step + 1 个 None 包触发一下 detokenization, 因为在开启 mtp feature 以后,每一步 + # 生成的 token 数量最多为 mtp_step + 1 个,如果不及时触发 detokenization, 会带来一些性能 + # 损失 + for _ in range(self.mtp_step + 1): + self.send_to_detokenization.send_pyobj(None, protocol=pickle.HIGHEST_PROTOCOL) + return + def get_used_tokens(self, dp_index): if not self.args.disable_dynamic_prompt_cache: return ( diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index dfdfd43de..f1076e565 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -17,6 +17,7 @@ from lightllm.common.basemodel.infer_lock import g_infer_state_lock from lightllm.server.multimodal_params import MultimodalParams from lightllm.utils.custom_kernel_utis import custom_cat +from lightllm.utils.envs_utils import get_env_start_args logger = init_logger(__name__) @@ -260,6 +261,14 @@ def __init__( self.need_out_token_id_statistics = True self.out_token_id_count: Dict[int, int] = None + # mtp_gen_token_ids 用于处理一个请求可以通过mtp进行很多token的预先生成 + # 的技术,在没有开启 mtp 功能的时候,这个成员变量不会有任何的实际实用意义。 + # 当开启后,mtp_gen_token_ids 保存多生成的多余的token_id,但是在后面的 + # 步骤中需要重新进行校验。 + self.mtp_gen_token_ids: List[int] = [] + # 用于记录每一次 decode verify 接受的 mtp token 的数量 + self.mtp_step_accepted_token_num = 0 + def init_all(self): if self.initialized is False: self.shm_req = g_infer_context.shm_req_manager.get_req_obj_by_index(self.shm_index) @@ -314,11 +323,27 @@ def get_cur_total_len(self): def get_input_token_ids(self): return self.shm_req.shm_prompt_ids.arr[0 : self.get_cur_total_len()] + def get_input_token_ids_shift(self, next_token_ids_cpu: np.ndarray, shift: int = 1): + origin_input_ids = self.get_input_token_ids() + input_ids = np.concatenate([origin_input_ids, next_token_ids_cpu, self.mtp_gen_token_ids]) + return input_ids[shift:] + def get_chuncked_input_token_ids(self): chunked_start = self.cur_kv_len chunked_end = min(self.get_cur_total_len(), chunked_start + self.shm_req.chunked_prefill_size) return self.shm_req.shm_prompt_ids.arr[0:chunked_end] + def get_chunked_input_token_ids_shift(self, next_token_ids_cpu: np.ndarray, shift: int = 1): + chunked_start = self.cur_kv_len + chunked_end = min(self.get_cur_total_len(), 
chunked_start + self.shm_req.chunked_prefill_size) + is_last_chunk = chunked_end == self.get_cur_total_len() + # if the current chunk is not the last chunk, the main model has not generated the next token, + # so we should not use the next token ids. + if not is_last_chunk: + next_token_ids_cpu = [] + input_ids = self.get_input_token_ids_shift(next_token_ids_cpu, shift) + return input_ids[0:chunked_end] + def get_chuncked_input_token_len(self): chunked_start = self.cur_kv_len chunked_end = min(self.get_cur_total_len(), chunked_start + self.shm_req.chunked_prefill_size) @@ -330,6 +355,9 @@ def set_next_gen_token_id(self, next_token_id: int, logprob: float): self.shm_req.shm_logprobs.arr[index] = logprob return + def set_total_accepted_len(self): + self.shm_req.mtp_accepted_token_num += self.mtp_step_accepted_token_num + def get_last_gen_token(self): return self.shm_req.shm_prompt_ids.arr[self.shm_req.input_len + self.cur_output_len - 1] diff --git a/lightllm/server/router/model_infer/mode_backend/__init__.py b/lightllm/server/router/model_infer/mode_backend/__init__.py index 4594eec28..d60fcf7c3 100644 --- a/lightllm/server/router/model_infer/mode_backend/__init__.py +++ b/lightllm/server/router/model_infer/mode_backend/__init__.py @@ -7,8 +7,18 @@ from .chunked_prefill.impl_for_outlines_constraint_mode import OutlinesConstraintBackend from .chunked_prefill.impl_for_first_token_constraint_mode import FirstTokenConstraintBackend from .dp_backend.impl import DPChunkedPrefillBackend +from .dp_backend.impl_mtp import DPChunkedPrefillWithMTPBackend from .continues_batch.pd_mode.prefill_node_impl.prefill_impl import ChunckedPrefillForPrefillNode from .continues_batch.pd_mode.decode_node_impl.decode_impl import ContinuesBatchBackendForDecodeNode from .chunked_prefill.impl_for_xgrammar_mode import XgrammarBackend from .continues_batch.pd_mode.prefill_node_impl.prefill_impl_for_dp_chuncked import DPChunkedForPrefillNode from .continues_batch.pd_mode.decode_node_impl.decode_impl_for_dp import DPForDecodeNode +from .continues_batch.impl_mtp import ContinuesBatchWithMTPBackend + + +from .generic_pre_process import prepare_prefill_inputs +from .generic_pre_process import prepare_decode_inputs +from .generic_padded_pre_process import padded_prepare_prefill_inputs +from .generic_padded_pre_process import padded_prepare_decode_inputs +from .generic_padded_pre_process import padded_overlap_prepare_prefill_inputs +from .generic_padded_pre_process import padded_overlap_prepare_decode_inputs diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 0d33163ac..e590a413b 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -15,9 +15,10 @@ from lightllm.server.router.model_infer.infer_batch import InferReq, InferSamplingParams from lightllm.server.router.token_load import TokenLoad from lightllm.common.basemodel.infer_lock import g_infer_state_lock, InferStateLock +from lightllm.common.basemodel.basemodel import TpPartBaseModel from lightllm.utils.dist_utils import init_distributed_env from lightllm.utils.envs_utils import get_unique_server_name -from lightllm.server.core.objs import ShmReqManager +from lightllm.server.core.objs import ShmReqManager, StartArgs from lightllm.server.router.model_infer.infer_batch import g_infer_context from lightllm.utils.dist_utils import get_global_rank, get_global_world_size, get_dp_size from 
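The get_input_token_ids_shift and get_chunked_input_token_ids_shift helpers above build the draft model's token stream: the main model's sequence shifted left by one, with the freshly sampled token (and any still-unverified draft tokens) appended, so position i of the draft input holds the token that follows position i of the main input. A standalone restatement with made-up token ids:

import numpy as np

origin_input_ids = np.array([11, 12, 13, 14])   # prompt + previously verified tokens
next_token_ids_cpu = np.array([15])             # token the main model just produced
mtp_gen_token_ids = [16]                        # unverified draft tokens from the last step

shifted = np.concatenate([origin_input_ids, next_token_ids_cpu, mtp_gen_token_ids])[1:]
# shifted == [12, 13, 14, 15, 16]: exactly get_input_token_ids_shift(...) with shift=1,
# which is the stream the MTP draft layers condition on.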
lightllm.utils.dist_utils import get_dp_world_size, get_global_dp_rank, get_current_rank_in_dp @@ -33,26 +34,26 @@ def __init__(self) -> None: pass def init_model(self, kvargs): - self.args = kvargs.get("args", None) + self.args: StartArgs = kvargs.get("args", None) + assert self.args is not None # p d 分离模式下会有特殊的一些初始化, 所以需要传递 # 模式参数到模型的初始化过程中进行控制 - self.run_mode = "normal" if self.args is None else self.args.run_mode + self.run_mode = self.args.run_mode self.is_multimodal = False self.nnodes = self.args.nnodes self.node_rank = self.args.node_rank - self.tp_rank = kvargs["rank_id"] self.world_size = kvargs["world_size"] - self.dp_size = kvargs.get("dp_size", 1) + self.dp_size = self.args.dp # dp_size_in_node 计算兼容多机纯tp的运行模式,这时候 1 // 2 == 0, 需要兼容 self.dp_size_in_node = max(1, self.dp_size // self.nnodes) self.load_way = kvargs["load_way"] self.mode = kvargs["mode"] - self.disable_chunked_prefill = kvargs.get("disable_chunked_prefill", False) - self.chunked_prefill_size = kvargs.get("chunked_prefill_size", None) - self.return_all_prompt_logprobs = kvargs.get("return_all_prompt_logprobs", False) - self.use_dynamic_prompt_cache = not kvargs.get("disable_dynamic_prompt_cache", False) + self.disable_chunked_prefill = self.args.disable_chunked_prefill + self.chunked_prefill_size = self.args.chunked_prefill_size + self.return_all_prompt_logprobs = self.args.return_all_prompt_logprobs + self.use_dynamic_prompt_cache = not self.args.disable_dynamic_prompt_cache self.eos_id: List[int] = kvargs.get("eos_id", [2]) - self.disable_cudagraph = kvargs.get("disable_cudagraph", False) + self.disable_cudagraph = self.args.disable_cudagraph self.cache = {} self.logger = init_logger(__name__) @@ -113,6 +114,7 @@ def init_model(self, kvargs): "run_mode": self.run_mode, } self.model, self.is_multimodal = get_model(model_cfg, model_kvargs) + self.model: TpPartBaseModel = self.model # for easy typing set_random_seed(2147483647) self.radix_cache = ( RadixCache( @@ -260,6 +262,7 @@ def _post_handle( for req_obj, next_token_id, next_token_logprob in zip(run_reqs, next_token_ids, next_token_logprobs): req_obj: InferReq = req_obj shm_req = req_obj.shm_req + finish_status = req_obj.finish_status if is_chuncked_mode: new_kv_len = req_obj.get_chuncked_input_token_len() else: @@ -269,13 +272,8 @@ def _post_handle( if self.is_master_in_dp: shm_req.shm_cur_kv_len = req_obj.cur_kv_len - # 这个地方主要是为了提前判断是否存在abort的情况,如果abort了 - # 直接将请求放入finished 处理队列中。 - if req_obj.is_finished_or_aborted(): - finished_req_ids.append(shm_req.request_id) - continue - - # 对于没有到达需要输出 token 阶段的请求,直接略过 + # 对于没有到达需要输出 token 阶段的请求,直接略过, 说明还 + # 处于chuncked prefill kv 填充的阶段。 if req_obj.cur_kv_len < req_obj.get_cur_total_len(): continue @@ -283,13 +281,26 @@ def _post_handle( req_obj.set_next_gen_token_id(next_token_id, next_token_logprob) req_obj.cur_output_len += 1 + # 这里提前判定的主要作用是: + # 在 mtp mode 下,可以存在同一个 req 对象的多次处理, + # 在这种情况下, 如果前一步接收的mtp token 已经导致了请求 + # 达到了finished 状态,后续的请求就不再进行后续的复杂流程 + # 判断和处理,但是,因为 mtp 多请求还是导致了kv 的使用,所以 + # 还是需要更新对应的 input_tokens 和 cur_kv_len 信息,否则 + # 在 filter req 的时候,容易导致kv 管理的泄露和插入radix cache + # 的信息不完整等问题。 + if finish_status.is_finished(): + finished_req_ids.append(shm_req.request_id) + continue + + # 更新判断请求的 finished 状态 req_obj.update_finish_status(self.eos_id) if extra_post_req_handle_func is not None: extra_post_req_handle_func(req_obj, next_token_id, next_token_logprob) # 判断是否已经满足生成结束条件。 - if req_obj.is_finished_or_aborted(): + if finish_status.is_finished(): finished_req_ids.append(shm_req.request_id) if 
self.is_master_in_dp: @@ -298,7 +309,7 @@ def _post_handle( # detokenization 进程需要的信息,注意这些变量的写入顺序避免异步协同问题。 shm_req.shm_cur_output_len = req_obj.cur_output_len - if req_obj.finish_status.is_finished(): + if finish_status.is_finished(): shm_req.finish_token_index = req_obj.get_cur_total_len() - 1 shm_req.finish_status = req_obj.finish_status @@ -308,8 +319,13 @@ def _post_handle( req_objs=run_reqs, next_token_ids=next_token_ids ) + # mtp_mode 模式下,因为存在重复对象,需要进行去重操作。 + if self.args.mtp_mode is not None: + finished_req_ids = list(set(finished_req_ids)) + if do_filter_finished_reqs: g_infer_context.filter(finished_req_ids) + return finished_req_ids # 一些可以复用的通用功能函数 diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 8dad113d4..329c79981 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -38,37 +38,37 @@ def decode(self): # 先 decode if decode_reqs: - kwargs, run_reqs = prepare_decode_inputs(decode_reqs) - logits = self.model.forward(**kwargs) + model_input, run_reqs = prepare_decode_inputs(decode_reqs) + model_output = self.model.forward(model_input) self._overlap_req_init_and_filter( uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True ) - next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) + next_token_ids, next_token_probs = sample(model_output.logits, run_reqs, self.eos_id) next_token_ids = next_token_ids.detach().cpu().numpy() next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() self._post_handle( run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False ) - logits = None + del model_output # 再 prefill if len(decode_reqs) == 0 or (self.forward_step % self.max_wait_step == 0) or (self.need_prefill_count > 0): if prefill_reqs: self.need_prefill_count -= 1 - kwargs, run_reqs = prepare_prefill_inputs( + model_input, run_reqs = prepare_prefill_inputs( prefill_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal ) - logits = self.model.forward(**kwargs) + model_output = self.model.forward(model_input) self._overlap_req_init_and_filter( uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True ) - next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) + next_token_ids, next_token_probs = sample(model_output.logits, run_reqs, self.eos_id) next_token_ids = next_token_ids.detach().cpu().numpy() next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() self._post_handle( run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False ) - logits = None + del model_output self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) self.forward_step += 1 diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_first_token_constraint_mode.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_first_token_constraint_mode.py index 752d5748f..ef8bdb95f 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_first_token_constraint_mode.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_first_token_constraint_mode.py @@ -41,8 +41,9 @@ def decode(self): # 先 decode if decode_reqs: - kwargs, run_reqs = 
prepare_decode_inputs(decode_reqs) - logits = self.model.forward(**kwargs) + model_input, run_reqs = prepare_decode_inputs(decode_reqs) + model_output = self.model.forward(model_input) + logits = model_output.logits self._overlap_req_init_and_filter( uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True ) @@ -53,16 +54,17 @@ def decode(self): self._post_handle( run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False ) - logits = None + del model_output # 再 prefill if len(decode_reqs) == 0 or (self.forward_step % self.max_wait_step == 0) or (self.need_prefill_count > 0): if prefill_reqs: self.need_prefill_count -= 1 - kwargs, run_reqs = prepare_prefill_inputs( + model_input, run_reqs = prepare_prefill_inputs( prefill_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal ) - logits = self.model.forward(**kwargs) + model_output = self.model.forward(model_input) + logits = model_output.logits self._overlap_req_init_and_filter( uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True ) @@ -73,7 +75,7 @@ def decode(self): self._post_handle( run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False ) - logits = None + del model_output self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) self.forward_step += 1 diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_outlines_constraint_mode.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_outlines_constraint_mode.py index ed092b42b..9c9dcc4f2 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_outlines_constraint_mode.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_outlines_constraint_mode.py @@ -66,8 +66,9 @@ def decode(self): # 先 decode if decode_reqs: - kwargs, run_reqs = prepare_decode_inputs(decode_reqs) - logits = self.model.forward(**kwargs) + model_input, run_reqs = prepare_decode_inputs(decode_reqs) + model_output = self.model.forward(model_input) + logits = model_output.logits self._overlap_req_init_and_filter( uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True ) @@ -91,16 +92,17 @@ def decode(self): do_filter_finished_reqs=False, extra_post_req_handle_func=self._update_state_fsm, ) - logits = None + del model_output # 再 prefill if len(decode_reqs) == 0 or (self.forward_step % self.max_wait_step == 0) or (self.need_prefill_count > 0): if prefill_reqs: self.need_prefill_count -= 1 - kwargs, run_reqs = prepare_prefill_inputs( + model_input, run_reqs = prepare_prefill_inputs( prefill_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal ) - logits = self.model.forward(**kwargs) + model_output = self.model.forward(model_input) + logits = model_output.logits self._overlap_req_init_and_filter( uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True ) @@ -123,7 +125,7 @@ def decode(self): do_filter_finished_reqs=False, extra_post_req_handle_func=self._update_state_fsm, ) - logits = None + del model_output self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) self.forward_step += 1 diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_token_healing.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_token_healing.py index 09d9f0fb8..d376aeafd 100644 --- 
a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_token_healing.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_token_healing.py @@ -50,8 +50,9 @@ def decode(self): # 先 decode if decode_reqs: - kwargs, run_reqs = prepare_decode_inputs(decode_reqs) - logits = self.model.forward(**kwargs) + model_input, run_reqs = prepare_decode_inputs(decode_reqs) + model_output = self.model.forward(model_input) + logits = model_output.logits self._overlap_req_init_and_filter( uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True ) @@ -77,16 +78,17 @@ def decode(self): do_filter_finished_reqs=False, extra_post_req_handle_func=self._update_tokenhealing_req_prefix_str, ) - logits = None + del model_output # 再 prefill if len(decode_reqs) == 0 or (self.forward_step % self.max_wait_step == 0) or (self.need_prefill_count > 0): if prefill_reqs: self.need_prefill_count -= 1 - kwargs, run_reqs = prepare_prefill_inputs( + model_input, run_reqs = prepare_prefill_inputs( prefill_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal ) - logits = self.model.forward(**kwargs) + model_output = self.model.forward(model_input) + logits = model_output.logits self._overlap_req_init_and_filter( uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True ) @@ -113,7 +115,7 @@ def decode(self): do_filter_finished_reqs=False, extra_post_req_handle_func=self._update_tokenhealing_req_prefix_str, ) - logits = None + del model_output self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) self.forward_step += 1 diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_xgrammar_mode.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_xgrammar_mode.py index 1c7a3a36d..c75e15229 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_xgrammar_mode.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_xgrammar_mode.py @@ -48,8 +48,9 @@ def decode(self): # 先 decode if decode_reqs: - kwargs, run_reqs = prepare_decode_inputs(decode_reqs) - logits = self.model.forward(**kwargs) + model_input, run_reqs = prepare_decode_inputs(decode_reqs) + model_output = self.model.forward(model_input) + logits = model_output.logits self._overlap_req_init_and_filter( uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True ) @@ -75,16 +76,17 @@ def decode(self): do_filter_finished_reqs=False, extra_post_req_handle_func=self._update_xgrammer_fsm, ) - logits = None + del model_output # 再 prefill if len(decode_reqs) == 0 or (self.forward_step % self.max_wait_step == 0) or (self.need_prefill_count > 0): if prefill_reqs: self.need_prefill_count -= 1 - kwargs, run_reqs = prepare_prefill_inputs( + model_input, run_reqs = prepare_prefill_inputs( prefill_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal ) - logits = self.model.forward(**kwargs) + model_output = self.model.forward(model_input) + logits = model_output.logits self._overlap_req_init_and_filter( uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True ) @@ -109,7 +111,7 @@ def decode(self): do_filter_finished_reqs=False, extra_post_req_handle_func=self._update_xgrammer_fsm, ) - logits = None + del model_output self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) self.forward_step += 1 diff --git 
a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py index 052a180c0..a910a41ef 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py @@ -30,10 +30,11 @@ def decode(self): g_infer_context.filter_reqs(aborted_reqs) if prefill_reqs: - kwargs, run_reqs = prepare_prefill_inputs( + model_input, run_reqs = prepare_prefill_inputs( prefill_reqs, is_chuncked_mode=False, is_multimodal=self.is_multimodal ) - logits = self.model.forward(**kwargs) + model_output = self.model.forward(model_input) + logits = model_output.logits self._overlap_req_init_and_filter( uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True @@ -48,8 +49,9 @@ def decode(self): ) if decode_reqs: - kwargs, run_reqs = prepare_decode_inputs(decode_reqs) - logits = self.model.forward(**kwargs) + model_input, run_reqs = prepare_decode_inputs(decode_reqs) + model_output = self.model.forward(model_input) + logits = model_output.logits self._overlap_req_init_and_filter( uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_return_all_prompt_logprobs.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_return_all_prompt_logprobs.py index db1e680ec..8fddc31b9 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_return_all_prompt_logprobs.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_return_all_prompt_logprobs.py @@ -17,12 +17,15 @@ def prefill(self, run_reqs: List[Tuple]): req_ids = self._init_reqs(run_reqs, init_req_obj=True) req_objs = self._trans_req_ids_to_req_objs(req_ids) - kwargs, run_reqs = prepare_prefill_inputs(req_objs, is_chuncked_mode=False, is_multimodal=self.is_multimodal) - - prompt_all_logits = self.model.forward(**kwargs) - input_ids = kwargs["input_ids"] - b_ready_cache_len = kwargs["b_ready_cache_len"] - b_seq_len = kwargs["b_seq_len"] + model_input, run_reqs = prepare_prefill_inputs( + req_objs, is_chuncked_mode=False, is_multimodal=self.is_multimodal + ) + + model_output = self.model.forward(model_input) + prompt_all_logits = model_output.logits + input_ids = model_input.input_ids + b_ready_cache_len = model_input.b_ready_cache_len + b_seq_len = model_input.b_seq_len last_index = torch.cumsum(b_seq_len, dim=0, dtype=torch.long) - 1 logits = prompt_all_logits[last_index, :] diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_reward_model.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_reward_model.py index 569da5150..0b04ca8e7 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_reward_model.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_reward_model.py @@ -14,9 +14,12 @@ def prefill(self, reqs: List[Tuple]): req_ids = self._init_reqs(reqs, init_req_obj=True) req_objs = self._trans_req_ids_to_req_objs(req_ids) - kwargs, run_reqs = prepare_prefill_inputs(req_objs, is_chuncked_mode=False, is_multimodal=self.is_multimodal) + model_input, run_reqs = prepare_prefill_inputs( + req_objs, is_chuncked_mode=False, is_multimodal=self.is_multimodal + ) - scores: torch.Tensor = self.model.forward(**kwargs) + model_output = self.model.forward(model_input) + scores: torch.Tensor = 
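Throughout these backends the old kwargs-plus-raw-logits calling convention is replaced by ModelInput and ModelOutput from lightllm/common/basemodel/batch_objs.py. That file is not part of this diff; as used here, a ModelOutput behaves roughly like the following sketch, where the field names are taken from the call sites and everything else is assumed:

from dataclasses import dataclass
import torch

@dataclass
class ModelOutputSketch:
    # What the refactored backends read after self.model.forward(model_input):
    logits: torch.Tensor                # sampling input, shape [token_num, vocab_size]
    hidden_states: torch.Tensor = None  # last-layer hidden states, chained into the MTP draft models that follow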
model_output.logits scores = scores.unsqueeze(1).detach().cpu().float().numpy() next_token_id = 1 diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_mtp.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_mtp.py new file mode 100644 index 000000000..eb552abff --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_mtp.py @@ -0,0 +1,225 @@ +import os +import torch +import numpy as np +from typing import List, Tuple +from transformers.configuration_utils import PretrainedConfig +from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend +from lightllm.server.router.model_infer.infer_batch import g_infer_context +from lightllm.server.router.model_infer.mode_backend.generic_pre_process import ( + prepare_prefill_inputs, + prepare_decode_inputs, +) +from lightllm.server.router.model_infer.mode_backend.mtp_pre_process import ( + prepare_mtp_prefill_inputs, +) +from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample +from lightllm.server.router.model_infer.infer_batch import InferReq +from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end +from lightllm.utils.log_utils import init_logger +from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel +from lightllm.common.basemodel.infer_lock import g_infer_state_lock +from lightllm.common.basemodel.batch_objs import ModelOutput + + +logger = init_logger(__name__) + + +class ContinuesBatchWithMTPBackend(ModeBackend): + def __init__(self) -> None: + super().__init__() + + # 支持双模型 + def init_model(self, kvargs): + super().init_model(kvargs) + self._init_mtp_draft_model(kvargs) + return + + def _init_mtp_draft_model(self, main_kvargs: dict): + self.mtp_step = self.args.mtp_step + self.draft_models = [] + + os.environ["DISABLE_CHECK_MAX_LEN_INFER"] = "1" + for i in range(self.mtp_step): + mtp_model_cfg, _ = PretrainedConfig.get_config_dict(self.args.mtp_draft_model_dir) + mtp_model_kvargs = { + "weight_dir": self.args.mtp_draft_model_dir, + "max_total_token_num": self.model.mem_manager.size, + "load_way": main_kvargs["load_way"], + "mode": main_kvargs["mode"], + "max_req_num": main_kvargs.get("max_req_num", 1000), + "max_seq_length": main_kvargs.get("max_seq_length", 1024 * 5), + "is_token_healing": False, + "return_all_prompt_logics": False, + "use_dynamic_prompt_cache": self.use_dynamic_prompt_cache, + "disable_chunked_prefill": self.disable_chunked_prefill, + "data_type": main_kvargs.get("data_type", "float16"), + "graph_max_batch_size": main_kvargs.get("graph_max_batch_size", 16), + "graph_max_len_in_batch": main_kvargs.get("graph_max_len_in_batch", 8196), + "disable_cudagraph": main_kvargs.get("disable_cudagraph", False), + "mem_fraction": main_kvargs["mem_fraction"], + "batch_max_tokens": main_kvargs.get("batch_max_tokens", None), + "quant_type": main_kvargs.get("quant_type", None), + "quant_cfg": main_kvargs.get("quant_cfg", None), + "run_mode": "normal", + "main_model": self.model, + "mem_layer_start": self.model.config["num_hidden_layers"] + i * mtp_model_cfg["num_hidden_layers"], + } + + mtp_model_cfg, _ = PretrainedConfig.get_config_dict(self.args.mtp_draft_model_dir) + assert mtp_model_cfg["model_type"] == "deepseek_v3" + assert mtp_model_cfg["architectures"][0] == "DeepseekV3ForCausalLMNextN" + self.draft_models.append(Deepseek3MTPModel(mtp_model_kvargs)) + + self.logger.info(f"loaded mtp model class {self.draft_models[i].__class__}") + + def prefill(self, reqs: 
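The mem_layer_start computed above is what lets every draft model share the main model's mem_manager without clobbering its KV cache: each Deepseek3MTPModel shifts its layer indices into a range after the main model's layers. A small illustration with assumed layer counts; the real values come from the two config.json files:

main_num_layers = 61    # e.g. the DeepSeek-V3 main model (assumption)
draft_num_layers = 1    # e.g. the DeepseekV3ForCausalLMNextN draft module (assumption)
mtp_step = 2

for i in range(mtp_step):
    mem_layer_start = main_num_layers + i * draft_num_layers
    # draft model i stores KV in layers [mem_layer_start, mem_layer_start + draft_num_layers)
    print(f"draft {i}: layers start at {mem_layer_start}")   # 61, then 62

# Deepseek3MTPModel._init_infer_layer adds this offset to every layer_infer.layer_num_,
# so main-model KV (layers 0..60 under these assumptions) and draft KV never overlap
# in the shared mem_manager.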
List[Tuple]): + self._init_reqs(reqs, init_req_obj=False) + return + + def decode(self): + uninit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = self._get_classed_reqs( + g_infer_context.infer_req_ids + ) + + if aborted_reqs: + g_infer_context.filter_reqs(aborted_reqs) + + if prefill_reqs: + model_input, run_reqs = prepare_prefill_inputs( + prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal + ) + model_output = self.model.forward(model_input) + + self._overlap_req_init_and_filter( + uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True + ) + + next_token_ids, next_token_probs = sample(model_output.logits, run_reqs, self.eos_id) + next_token_ids_cpu = next_token_ids.detach().cpu().numpy() + next_token_logprobs_cpu = torch.log(next_token_probs).detach().cpu().numpy() + + # spec prefill: MTP + draft_model_input = model_input + draft_model_input.hidden_states = model_output.hidden_states + for draft_model_idx in range(self.spec_step): + draft_model_input = prepare_mtp_prefill_inputs( + prefill_reqs, + model_input, + next_token_ids_cpu, + draft_model_idx, + is_chunked_mode=not self.disable_chunked_prefill, + ) + draft_model_output = self.draft_models[draft_model_idx].forward(draft_model_input) + _, draft_next_token_ids_cpu = self._gen_draft_tokens(draft_model_output) + model_input.hidden_states = draft_model_output.hidden_states + self._save_prefill_draft_tokens(draft_next_token_ids_cpu, run_reqs, draft_model_idx) + + self._post_handle( + run_reqs, + next_token_ids_cpu, + next_token_logprobs_cpu, + is_chuncked_mode=not self.disable_chunked_prefill, + do_filter_finished_reqs=False, + ) + + if decode_reqs: + model_input, run_reqs = prepare_decode_inputs(decode_reqs) + model_output = self.model.forward(model_input) + assert model_output.logits.shape[0] % self.spec_stride == 0 + + self._overlap_req_init_and_filter( + uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True + ) + + next_token_ids, next_token_probs = sample(model_output.logits, run_reqs, self.eos_id) + next_token_ids_cpu = next_token_ids.detach().cpu().numpy() + next_token_logprobs_cpu = torch.log(next_token_probs).detach().cpu().numpy() + + # verify + mem_indexes_cpu = model_input.mem_indexes.cpu() + accepted_reqs, accepted_index, need_free_mem_indexes = self._verify( + next_token_ids_cpu, run_reqs, mem_indexes_cpu + ) + self._post_handle( + accepted_reqs, + next_token_ids_cpu[accepted_index], + next_token_logprobs_cpu[accepted_index], + is_chuncked_mode=not self.disable_chunked_prefill, + do_filter_finished_reqs=False, + ) + + # share some inference info with the main model + draft_model_input = model_input + draft_model_input.input_ids = next_token_ids + draft_model_input.hidden_states = model_output.hidden_states + # process the draft model output + for draft_model_idx in range(self.spec_step): + # spec decode: MTP + draft_model_output = self.draft_models[draft_model_idx].forward(draft_model_input) + draft_next_token_ids, draft_next_token_ids_cpu = self._gen_draft_tokens(draft_model_output) + # prepare inputs for the next draft model + draft_model_input.input_ids = draft_next_token_ids + draft_model_input.hidden_states = draft_model_output.hidden_states + self._save_decode_draft_token_ids(draft_next_token_ids_cpu, run_reqs, draft_model_idx) + + if need_free_mem_indexes: + g_infer_state_lock.acquire() + g_infer_context.req_manager.mem_manager.free(need_free_mem_indexes) + g_infer_state_lock.release() + + 
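The verification done by self._verify above (and defined further below) reduces to a greedy accept rule over each request's spec_stride packed positions, presumably mtp_step + 1 of them: the last verified token plus the pending draft tokens. Position 0 is the main model's own continuation and is always kept; each draft token is kept only while it matches what the main model actually produced at the previous position. A standalone sketch of that rule, with illustrative names:

def greedy_verify(main_next_tokens, draft_tokens):
    # main_next_tokens: the main model's sampled token for each packed position of one request
    # draft_tokens: the mtp_gen_token_ids that were fed in at positions 1..len(draft_tokens)
    accepted = 1  # position 0 is the main model's own continuation, always accepted
    for step_idx in range(1, len(draft_tokens) + 1):
        if draft_tokens[step_idx - 1] == main_next_tokens[step_idx - 1]:
            accepted += 1   # the draft guessed right, so the output at this position is valid too
        else:
            break           # first mismatch: remaining positions (and their KV slots) are discarded
    return accepted

# e.g. main outputs [7, 9, 3] against draft guesses [7, 8] -> 2 tokens accepted this step.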
self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) + return + + def _gen_draft_tokens(self, model_output: ModelOutput): + logits = model_output.logits + probs = torch.softmax(logits, dim=-1) + draft_next_token_ids = torch.argmax(probs, dim=-1) + return draft_next_token_ids, draft_next_token_ids.detach().cpu().numpy() + + def _verify(self, next_token_ids: torch.Tensor, run_reqs: List[InferReq], draft_mem_indexes: torch.Tensor): + accepted_reqs = [] + accepted_index = [] + need_free_mem_indexes = [] + assert next_token_ids.shape[0] % self.spec_stride == 0 + batch_size = next_token_ids.shape[0] // self.spec_stride + for b in range(batch_size): + req: InferReq = run_reqs[b % self.spec_stride] + req_start_idx = b * self.spec_stride + req_end_idx = (b + 1) * self.spec_stride + # step_idx==0 means the output of the main model + for step_idx in range(self.spec_stride): + if step_idx == 0 or req.mtp_gen_token_ids[step_idx - 1] == next_token_ids[req_start_idx + step_idx - 1]: + accepted_reqs.append(req) + accepted_index.append(req_start_idx + step_idx) + req.mtp_step_accepted_token_num += 1 if step_idx != 0 else 0 + else: + need_free_mem_indexes.extend(draft_mem_indexes[req_start_idx + step_idx : req_end_idx]) + break + # reset the mtp status + req.mtp_gen_token_ids = [] + return accepted_reqs, accepted_index, need_free_mem_indexes + + def _save_prefill_draft_tokens( + self, draft_next_token_ids: torch.Tensor, run_reqs: List[InferReq], draft_model_idx: int + ): + batch_size = len(run_reqs) + for i in range(batch_size): + req = run_reqs[i] + # if the request has unfinished chunked tokens, skip it. + if req.get_chuncked_input_token_len() < req.get_cur_total_len(): + continue + req.mtp_gen_token_ids.append(draft_next_token_ids[i]) + + def _save_decode_draft_token_ids( + self, draft_next_token_ids: torch.Tensor, run_reqs: List[InferReq], draft_model_idx: int + ): + batch_size = len(run_reqs) // self.spec_stride + for i in range(batch_size): + req: InferReq = run_reqs[self.spec_stride * i] + # append the draft token + req.mtp_gen_token_ids.append(draft_next_token_ids[i * self.spec_stride + req.mtp_step_accepted_token_num]) + # reset the mtp status + if draft_model_idx == self.spec_step - 1: + if self.is_master_in_dp: + req.set_total_accepted_len() + req.mtp_step_accepted_token_num = 0 diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py index d75f44f20..7d3f2666c 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py @@ -69,8 +69,9 @@ def decode(self): if decode_reqs: - kwargs, run_reqs = prepare_decode_inputs(decode_reqs) - logits = self.model.forward(**kwargs) + model_input, run_reqs = prepare_decode_inputs(decode_reqs) + model_output = self.model.forward(model_input) + logits = model_output.logits self._overlap_req_init_and_filter( uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_for_dp.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_for_dp.py index 614bc9bc5..240eba0a2 100644 --- 
a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_for_dp.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_for_dp.py @@ -1,24 +1,13 @@ -import os import torch import torch.multiprocessing as mp import torch.distributed as dist -import threading -from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend from typing import List, Tuple -from lightllm.utils.infer_utils import set_random_seed -from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end -from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq, InferSamplingParams -from lightllm.server.core.objs import FinishStatus -from lightllm.server.pd_io_struct import UpKVStatus +from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq from lightllm.utils.log_utils import init_logger -from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample -from .up_status import UpStatusManager -from rpyc.utils.server import ThreadedServer -from lightllm.common.basemodel.infer_lock import g_infer_state_lock -from .decode_task_cache import g_success_kv_move_task_cache, KVMoveTask -from lightllm.utils.device_utils import kv_trans_use_p2p +from lightllm.server.router.model_infer.mode_backend import padded_prepare_prefill_inputs from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args from .decode_impl import ContinuesBatchBackendForDecodeNode +from lightllm.server.router.model_infer.mode_backend.dp_backend.impl import DPChunkedPrefillBackend logger = init_logger(__name__) @@ -36,10 +25,8 @@ def init_custom(self): # 在推理的时候至少是两个token,1个是已经有kv的token,一个是等待计算kv的token,然后生成第三个token,这几个 # token 实际引用的都是 g_infer_context.req_manager.mem_manager.HOLD_TOKEN_MEMINDEX,但是需要初始化排除 # nan 值,避免后续构建的fake请求在计算的过程中出现计算错误。 - from lightllm.server.router.model_infer.mode_backend.dp_backend.pre_process import padded_prepare_prefill_inputs - - kwargs, run_reqs, padded_req_num = padded_prepare_prefill_inputs([], 1, is_multimodal=self.is_multimodal) - self.model.forward(**kwargs) + model_input, run_reqs, padded_req_num = padded_prepare_prefill_inputs([], is_multimodal=self.is_multimodal) + self.model.forward(model_input) assert len(run_reqs) == 0 and padded_req_num == 1 return @@ -69,51 +56,21 @@ def decode(self): return def normal_decode(self, decode_reqs: List[InferReq], max_decode_num: int, uninit_reqs, ok_finished_reqs): - from lightllm.server.router.model_infer.mode_backend.dp_backend.pre_process import padded_prepare_decode_inputs - - kwargs, run_reqs, padded_req_num = padded_prepare_decode_inputs( - decode_reqs, max_decode_num, is_multimodal=self.is_multimodal + DPChunkedPrefillBackend.normal_decode( + self, + decode_reqs=decode_reqs, + max_decode_num=max_decode_num, + uninit_reqs=uninit_reqs, + ok_finished_reqs=ok_finished_reqs, ) - logits = self.model.forward(**kwargs) - self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) - if len(run_reqs) != 0: - logits = logits[0 : len(run_reqs), :] - next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) - next_token_ids = next_token_ids.detach().cpu().numpy() - next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() - self._post_handle( - run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False - ) return def overlap_decode(self, decode_reqs: 
List[InferReq], max_decode_num: int, uninit_reqs, ok_finished_reqs): - from lightllm.server.router.model_infer.mode_backend.dp_backend.pre_process import ( - padded_overlap_prepare_decode_inputs, + DPChunkedPrefillBackend.overlap_decode( + self, + decode_reqs=decode_reqs, + max_decode_num=max_decode_num, + uninit_reqs=uninit_reqs, + ok_finished_reqs=ok_finished_reqs, ) - - ( - micro_batch, - run_reqs, - padded_req_num, - micro_batch1, - run_reqs1, - padded_req_num1, - ) = padded_overlap_prepare_decode_inputs(decode_reqs, max_decode_num, is_multimodal=self.is_multimodal) - - logits, logits1 = self.model.microbatch_overlap_decode(micro_batch, micro_batch1) - self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) - req_num, req_num1 = len(run_reqs), len(run_reqs1) - all_logits = torch.empty((req_num + req_num1, logits.shape[1]), dtype=logits.dtype, device=logits.device) - - all_logits[0:req_num, :].copy_(logits[0:req_num, :], non_blocking=True) - all_logits[req_num : (req_num + req_num1), :].copy_(logits1[0:req_num1, :], non_blocking=True) - - all_run_reqs = run_reqs + run_reqs1 - if all_run_reqs: - next_token_ids, next_token_probs = sample(all_logits, all_run_reqs, self.eos_id) - next_token_ids = next_token_ids.detach().cpu().numpy() - next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() - self._post_handle( - all_run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False - ) return diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py index 2652f4a33..a8084f2a1 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py @@ -72,11 +72,12 @@ def decode(self): self._filter_reqs(ok_finished_reqs) if prefill_reqs: - kwargs, run_reqs = prepare_prefill_inputs( + model_input, run_reqs = prepare_prefill_inputs( prefill_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal ) - logits = self.model.forward(**kwargs) + model_output = self.model.forward(model_input) + logits = model_output.logits next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) next_token_ids = next_token_ids.detach().cpu().numpy() next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp_chuncked.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp_chuncked.py index a728b6b03..04f436b80 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp_chuncked.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl_for_dp_chuncked.py @@ -1,22 +1,9 @@ -import os -import time -import threading import torch import torch.multiprocessing as mp import torch.distributed as dist from typing import List, Tuple -from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend -from lightllm.utils.infer_utils import set_random_seed -from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end from 
lightllm.server.router.model_infer.infer_batch import InferReq, InferSamplingParams, g_infer_context -from lightllm.server.core.objs import FinishStatus -from lightllm.server.pd_io_struct import KVMoveTask, DecodeNodeInfo from lightllm.utils.log_utils import init_logger -from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample -from lightllm.common.basemodel.infer_lock import g_infer_state_lock -from rpyc.utils.server import ThreadedServer -from .prefill_task_cache import g_kv_move_task_cache -from lightllm.utils.device_utils import kv_trans_use_p2p from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args from .prefill_impl import ChunckedPrefillForPrefillNode from lightllm.server.router.model_infer.mode_backend.dp_backend.impl import DPChunkedPrefillBackend @@ -68,50 +55,21 @@ def decode(self): return def normal_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: int, uninit_reqs, ok_finished_reqs): - from lightllm.server.router.model_infer.mode_backend.dp_backend.pre_process import ( - padded_prepare_prefill_inputs, + DPChunkedPrefillBackend.normal_prefill_reqs( + self, + prefill_reqs=prefill_reqs, + max_prefill_num=max_prefill_num, + uninit_reqs=uninit_reqs, + ok_finished_reqs=ok_finished_reqs, ) - - kwargs, run_reqs, padded_req_num = padded_prepare_prefill_inputs( - prefill_reqs, max_prefill_num, is_multimodal=self.is_multimodal - ) - logits = self.model.forward(**kwargs) - self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) - if len(run_reqs) != 0: - logits = logits[0 : len(run_reqs), :] - next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id) - next_token_ids = next_token_ids.detach().cpu().numpy() - next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() - self._post_handle( - run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False - ) + return def overlap_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: int, uninit_reqs, ok_finished_reqs): - from lightllm.server.router.model_infer.mode_backend.dp_backend.pre_process import ( - padded_overlap_prepare_prefill_inputs, + DPChunkedPrefillBackend.overlap_prefill_reqs( + self, + prefill_reqs=prefill_reqs, + max_prefill_num=max_prefill_num, + uninit_reqs=uninit_reqs, + ok_finished_reqs=ok_finished_reqs, ) - - ( - micro_batch, - run_reqs, - padded_req_num, - micro_batch1, - run_reqs1, - padded_req_num1, - ) = padded_overlap_prepare_prefill_inputs(prefill_reqs, max_prefill_num, is_multimodal=self.is_multimodal) - logits, logits1 = self.model.microbatch_overlap_prefill(micro_batch, micro_batch1) - self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) - req_num, req_num1 = len(run_reqs), len(run_reqs1) - all_logits = torch.empty((req_num + req_num1, logits.shape[1]), dtype=logits.dtype, device=logits.device) - - all_logits[0:req_num, :].copy_(logits[0:req_num, :], non_blocking=True) - all_logits[req_num : (req_num + req_num1), :].copy_(logits1[0:req_num1, :], non_blocking=True) - - all_run_reqs = run_reqs + run_reqs1 - if all_run_reqs: - next_token_ids, next_token_probs = sample(all_logits, all_run_reqs, self.eos_id) - next_token_ids = next_token_ids.detach().cpu().numpy() - next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() - self._post_handle( - all_run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, 
do_filter_finished_reqs=False - ) + return diff --git a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py index 18b73ee23..4480c2536 100644 --- a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py @@ -69,10 +69,11 @@ def decode(self): for req in prefill_reqs if convert_sub_id_to_group_id(req.req_id) == req.req_id ] - kwargs, group_run_reqs = prepare_prefill_inputs( + model_input, group_run_reqs = prepare_prefill_inputs( group_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal ) - logits = self.model.forward(**kwargs) + model_output = self.model.forward(model_input) + logits = model_output.logits uninit_req_ids = [req.req_id for req in uninit_reqs] self._overlap_req_init_and_filter( @@ -90,8 +91,9 @@ def decode(self): ) if decode_reqs: - kwargs, run_reqs = prepare_decode_inputs(decode_reqs) - logits = self.model.forward(**kwargs) + model_input, run_reqs = prepare_decode_inputs(decode_reqs) + model_output = self.model.forward(model_input) + logits = model_output.logits uninit_req_ids = [req.req_id for req in uninit_reqs] self._overlap_req_init_and_filter( uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index c9ecafdb2..f6ed334f8 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -3,6 +3,7 @@ import numpy as np from typing import List, Tuple from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend +from lightllm.common.basemodel.batch_objs import ModelOutput from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq, InferSamplingParams @@ -10,6 +11,10 @@ from lightllm.utils.log_utils import init_logger from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from lightllm.utils.envs_utils import get_env_start_args +from lightllm.server.router.model_infer.mode_backend import padded_prepare_prefill_inputs +from lightllm.server.router.model_infer.mode_backend import padded_overlap_prepare_prefill_inputs +from lightllm.server.router.model_infer.mode_backend import padded_prepare_decode_inputs +from lightllm.server.router.model_infer.mode_backend import padded_overlap_prepare_decode_inputs class DPChunkedPrefillBackend(ModeBackend): @@ -29,10 +34,8 @@ def init_custom(self): # 在推理的时候至少是两个token,1个是已经有kv的token,一个是等待计算kv的token,然后生成第三个token,这几个 # token 实际引用的都是 g_infer_context.req_manager.mem_manager.HOLD_TOKEN_MEMINDEX,但是需要初始化排除 # nan 值,避免后续构建的fake请求在计算的过程中出现计算错误。 - from .pre_process import padded_prepare_prefill_inputs - - kwargs, run_reqs, padded_req_num = padded_prepare_prefill_inputs([], 1, is_multimodal=self.is_multimodal) - self.model.forward(**kwargs) + model_input, run_reqs, padded_req_num = padded_prepare_prefill_inputs([], is_multimodal=self.is_multimodal) + self.model.forward(model_input) assert len(run_reqs) == 0 and padded_req_num == 1 return @@ -71,12 +74,11 @@ def decode(self): return def normal_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: int, uninit_reqs, ok_finished_reqs): 
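The DP prefill and decode node backends above reuse DPChunkedPrefillBackend's batching logic by calling its methods through the class with an explicit self; the methods only need the instance to provide the attributes they use (self.model, self.eos_id, and so on), so no shared base class is required. A generic sketch of the pattern, with placeholder class names:

class SharedImpl:
    def run(self, x):
        return f"{self.name} handled {x}"

class NodeBackend:
    def __init__(self, name):
        self.name = name

    def run(self, x):
        # Delegate to SharedImpl.run without inheriting from SharedImpl:
        # the plain function only needs `self` to carry the attributes it touches.
        return SharedImpl.run(self, x)

print(NodeBackend("dp_decode_node").run(3))   # "dp_decode_node handled 3"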
- from .pre_process import padded_prepare_prefill_inputs - - kwargs, run_reqs, padded_req_num = padded_prepare_prefill_inputs( - prefill_reqs, max_prefill_num, is_multimodal=self.is_multimodal + model_input, run_reqs, padded_req_num = padded_prepare_prefill_inputs( + prefill_reqs, is_multimodal=self.is_multimodal ) - logits = self.model.forward(**kwargs) + model_output: ModelOutput = self.model.forward(model_input) + logits = model_output.logits self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) if len(run_reqs) != 0: logits = logits[0 : len(run_reqs), :] @@ -89,12 +91,11 @@ def normal_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: int return def normal_decode(self, decode_reqs: List[InferReq], max_decode_num: int, uninit_reqs, ok_finished_reqs): - from .pre_process import padded_prepare_decode_inputs - - kwargs, run_reqs, padded_req_num = padded_prepare_decode_inputs( - decode_reqs, max_decode_num, is_multimodal=self.is_multimodal + model_input, run_reqs, padded_req_num = padded_prepare_decode_inputs( + decode_reqs, is_multimodal=self.is_multimodal ) - logits = self.model.forward(**kwargs) + model_output: ModelOutput = self.model.forward(model_input) + logits = model_output.logits self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) @@ -109,17 +110,17 @@ def normal_decode(self, decode_reqs: List[InferReq], max_decode_num: int, uninit logits = None def overlap_decode(self, decode_reqs: List[InferReq], max_decode_num: int, uninit_reqs, ok_finished_reqs): - from .pre_process import padded_overlap_prepare_decode_inputs - ( - micro_batch, + micro_input, run_reqs, padded_req_num, - micro_batch1, + micro_input1, run_reqs1, padded_req_num1, - ) = padded_overlap_prepare_decode_inputs(decode_reqs, max_decode_num, is_multimodal=self.is_multimodal) - logits, logits1 = self.model.microbatch_overlap_decode(micro_batch, micro_batch1) + ) = padded_overlap_prepare_decode_inputs(decode_reqs, is_multimodal=self.is_multimodal) + model_output, model_output1 = self.model.microbatch_overlap_decode(micro_input, micro_input1) + logits = model_output.logits + logits1 = model_output1.logits self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) req_num, req_num1 = len(run_reqs), len(run_reqs1) all_logits = torch.empty((req_num + req_num1, logits.shape[1]), dtype=logits.dtype, device=logits.device) @@ -138,17 +139,17 @@ def overlap_decode(self, decode_reqs: List[InferReq], max_decode_num: int, unini return def overlap_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: int, uninit_reqs, ok_finished_reqs): - from .pre_process import padded_overlap_prepare_prefill_inputs - ( - micro_batch, + micro_input, run_reqs, padded_req_num, - micro_batch1, + micro_input1, run_reqs1, padded_req_num1, - ) = padded_overlap_prepare_prefill_inputs(prefill_reqs, max_prefill_num, is_multimodal=self.is_multimodal) - logits, logits1 = self.model.microbatch_overlap_prefill(micro_batch, micro_batch1) + ) = padded_overlap_prepare_prefill_inputs(prefill_reqs, is_multimodal=self.is_multimodal) + model_output, model_output1 = self.model.microbatch_overlap_prefill(micro_input, micro_input1) + logits = model_output.logits + logits1 = model_output1.logits self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) req_num, req_num1 = len(run_reqs), len(run_reqs1) all_logits = 
torch.empty((req_num + req_num1, logits.shape[1]), dtype=logits.dtype, device=logits.device) diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl_mtp.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl_mtp.py new file mode 100644 index 000000000..12dac4ff6 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl_mtp.py @@ -0,0 +1,333 @@ +import torch +import torch.distributed as dist +import numpy as np +import triton +from typing import List, Tuple +from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend +from lightllm.utils.infer_utils import set_random_seed +from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end +from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq, InferSamplingParams +from lightllm.server.core.objs import FinishStatus +from lightllm.utils.log_utils import init_logger +from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.server.router.model_infer.mode_backend.continues_batch.impl_mtp import ContinuesBatchWithMTPBackend +from lightllm.server.router.model_infer.mode_backend import padded_prepare_prefill_inputs +from lightllm.server.router.model_infer.mode_backend import padded_overlap_prepare_prefill_inputs +from lightllm.server.router.model_infer.mode_backend import padded_prepare_decode_inputs +from lightllm.server.router.model_infer.mode_backend import padded_overlap_prepare_decode_inputs +from lightllm.server.router.model_infer.mode_backend.mtp_pre_process import ( + prepare_mtp_prefill_inputs, +) +from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput +from lightllm.common.basemodel.infer_lock import g_infer_state_lock + + +class DPChunkedPrefillWithMTPBackend(ContinuesBatchWithMTPBackend): + def __init__(self) -> None: + super().__init__() + self.enable_decode_microbatch_overlap = get_env_start_args().enable_decode_microbatch_overlap + self.enable_prefill_microbatch_overlap = get_env_start_args().enable_prefill_microbatch_overlap + pass + + def init_model(self, kvargs): + super().init_model(kvargs) + + def init_custom(self): + self.reduce_tensor = torch.tensor([0], dtype=torch.int32, device="cuda", requires_grad=False) + model_input, run_reqs, padded_req_num = padded_prepare_prefill_inputs([], is_multimodal=self.is_multimodal) + self.model.forward(model_input) + assert len(run_reqs) == 0 and padded_req_num == 1 + return + + def prefill(self, reqs: List[Tuple]): + self._init_reqs(reqs, init_req_obj=False) + return + + def decode(self): + uninit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = self._get_classed_reqs( + g_infer_context.infer_req_ids + ) + + if aborted_reqs: + g_infer_context.filter_reqs(aborted_reqs) + + current_dp_prefill_num = len(prefill_reqs) + self.reduce_tensor.fill_(current_dp_prefill_num) + dist.all_reduce(self.reduce_tensor, op=dist.ReduceOp.MAX, group=None, async_op=False) + max_prefill_num = self.reduce_tensor.item() + if max_prefill_num != 0: + if not self.enable_prefill_microbatch_overlap: + self.normal_prefill_reqs(prefill_reqs, max_prefill_num, uninit_reqs, ok_finished_reqs) + else: + self.overlap_prefill_reqs(prefill_reqs, max_prefill_num, uninit_reqs, ok_finished_reqs) + + self.reduce_tensor.fill_(len(decode_reqs)) + dist.all_reduce(self.reduce_tensor, op=dist.ReduceOp.MAX, group=None, async_op=False) + max_decode_num = self.reduce_tensor.item() + 
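The reduce_tensor.fill_ / all_reduce(MAX) pattern right above is how every data-parallel rank agrees on how many prefill and decode rounds to run: each rank publishes its local request count, all of them adopt the maximum, and ranks with fewer (or zero) real requests fill the gap with padded fake requests, so the collectives inside the model forward stay aligned. A standalone sketch, assuming the default process group is already initialized:

import torch
import torch.distributed as dist

def agree_on_rounds(local_num_reqs: int) -> int:
    # Each DP rank contributes its own request count; every rank receives the max.
    t = torch.tensor([local_num_reqs], dtype=torch.int32, device="cuda")
    dist.all_reduce(t, op=dist.ReduceOp.MAX)
    return t.item()

# A rank with 0 decode requests still sees max_decode_num > 0 and runs forwards on
# padded fake requests (padded_prepare_decode_inputs), so no rank skips a collective.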
if max_decode_num != 0: + if not self.enable_decode_microbatch_overlap: + self.normal_decode(decode_reqs, max_decode_num, uninit_reqs, ok_finished_reqs) + else: + self.overlap_decode(decode_reqs, max_decode_num, uninit_reqs, ok_finished_reqs) + + self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) + return + + def normal_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: int, uninit_reqs, ok_finished_reqs): + # main model prefill + model_input, run_reqs, padded_req_num = padded_prepare_prefill_inputs( + prefill_reqs, is_multimodal=self.is_multimodal + ) + model_output: ModelOutput = self.model.forward(model_input) + + self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) + next_token_ids_cpu = [] + if len(run_reqs) != 0: + next_token_ids, next_token_probs = sample(model_output.logits[: len(run_reqs)], run_reqs, self.eos_id) + next_token_ids_cpu = next_token_ids.detach().cpu().numpy() + next_token_logprobs_cpu = torch.log(next_token_probs).detach().cpu().numpy() + + # spec prefill: MTP + draft_model_input = model_input + draft_model_input.hidden_states = model_output.hidden_states + + for draft_model_idx in range(self.spec_step): + draft_model_input = prepare_mtp_prefill_inputs( + prefill_reqs, + model_input, + next_token_ids_cpu, + draft_model_idx, + is_chunked_mode=True, + padded_req_num=padded_req_num, + ) + + draft_model_output = self.draft_models[draft_model_idx].forward(draft_model_input) + _, draft_next_token_ids_cpu = self._gen_draft_tokens(draft_model_output) + model_input.hidden_states = draft_model_output.hidden_states + self._save_prefill_draft_tokens(draft_next_token_ids_cpu, run_reqs, draft_model_idx) + + if len(run_reqs) != 0: + self._post_handle( + run_reqs, + next_token_ids_cpu, + next_token_logprobs_cpu, + is_chuncked_mode=True, + do_filter_finished_reqs=False, + ) + + def normal_decode(self, decode_reqs: List[InferReq], max_decode_num: int, uninit_reqs, ok_finished_reqs): + model_input, run_reqs, padded_req_num = padded_prepare_decode_inputs( + decode_reqs, is_multimodal=self.is_multimodal + ) + # main model decode + model_output = self.model.forward(model_input) + + self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) + next_token_ids = torch.empty((0,), dtype=torch.int64, device="cuda") + need_free_mem_indexes = [] + if len(run_reqs) != 0: + next_token_ids, next_token_probs = sample(model_output.logits[: len(run_reqs)], run_reqs, self.eos_id) + next_token_ids_cpu = next_token_ids.detach().cpu().numpy() + next_token_logprobs_cpu = torch.log(next_token_probs).detach().cpu().numpy() + + # verify + mem_indexes_cpu = model_input.mem_indexes.cpu() + accepted_reqs, accepted_index, need_free_mem_indexes = self._verify( + next_token_ids_cpu[: len(run_reqs)], run_reqs, mem_indexes_cpu + ) + self._post_handle( + accepted_reqs, + next_token_ids_cpu[accepted_index], + next_token_logprobs_cpu[accepted_index], + is_chuncked_mode=True, + do_filter_finished_reqs=False, + ) + + if padded_req_num != 0: + next_token_ids = torch.cat( + [next_token_ids, torch.ones((padded_req_num,), dtype=torch.int64, device="cuda")], dim=0 + ) + + # share some inference info with the main model + draft_model_input = model_input + draft_model_input.input_ids = next_token_ids + draft_model_input.hidden_states = model_output.hidden_states + # process the draft model output + for draft_model_idx in range(self.spec_step): + # 
spec decode: MTP + draft_model_output = self.draft_models[draft_model_idx].forward(draft_model_input) + draft_next_token_ids, draft_next_token_ids_cpu = self._gen_draft_tokens(draft_model_output) + self._save_decode_draft_token_ids(draft_next_token_ids_cpu, run_reqs, draft_model_idx) + draft_model_input.input_ids = draft_next_token_ids + draft_model_input.hidden_states = draft_model_output.hidden_states + + if need_free_mem_indexes: + g_infer_state_lock.acquire() + g_infer_context.req_manager.mem_manager.free(need_free_mem_indexes) + g_infer_state_lock.release() + + def overlap_decode(self, decode_reqs: List[InferReq], max_decode_num: int, uninit_reqs, ok_finished_reqs): + ( + micro_input, + run_reqs, + padded_req_num, + micro_input1, + run_reqs1, + padded_req_num1, + ) = padded_overlap_prepare_decode_inputs(decode_reqs, is_multimodal=self.is_multimodal) + micro_output, micro_output1 = self.model.microbatch_overlap_decode(micro_input, micro_input1) + + assert micro_output.logits.shape[0] % self.spec_stride == 0 + assert micro_output1.logits.shape[0] % self.spec_stride == 0 + + self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) + req_num, req_num1 = len(run_reqs), len(run_reqs1) + all_run_reqs = run_reqs + run_reqs1 + next_token_ids = torch.empty((0,), dtype=torch.int64, device="cuda") + next_token_ids_cpu = [] + need_free_mem_indexes = [] + if len(all_run_reqs) != 0: + all_logits = torch.empty( + (req_num + req_num1, micro_output.logits.shape[1]), + dtype=micro_output.logits.dtype, + device=micro_output.logits.device, + ) + + all_logits[0:req_num, :].copy_(micro_output.logits[0:req_num, :], non_blocking=True) + all_logits[req_num : (req_num + req_num1), :].copy_(micro_output1.logits[0:req_num1, :], non_blocking=True) + + next_token_ids, next_token_probs = sample(all_logits, all_run_reqs, self.eos_id) + next_token_ids_cpu = next_token_ids.detach().cpu().numpy() + next_token_logprobs_cpu = torch.log(next_token_probs).detach().cpu().numpy() + micro_mem_indexes_cpu = micro_input.mem_indexes.cpu() + micro_mem_indexes_cpu1 = micro_input1.mem_indexes.cpu() + mem_indexes_cpu = torch.cat((micro_mem_indexes_cpu, micro_mem_indexes_cpu1), dim=0) + + # verify + accepted_reqs, accepted_index, need_free_mem_indexes = self._verify( + next_token_ids_cpu, all_run_reqs, mem_indexes_cpu + ) + + self._post_handle( + accepted_reqs, + next_token_ids_cpu[accepted_index], + next_token_logprobs_cpu[accepted_index], + is_chuncked_mode=True, + do_filter_finished_reqs=False, + ) + + # share some inference info with the main model + draft_micro_input, draft_micro_input1 = micro_input, micro_input1 + + draft_micro_input.input_ids = next_token_ids[:req_num] + draft_micro_input.hidden_states = micro_output.hidden_states + draft_micro_input1.input_ids = next_token_ids[req_num:] + draft_micro_input1.hidden_states = micro_output1.hidden_states + + if padded_req_num != 0: + draft_micro_input.input_ids = torch.cat( + [draft_micro_input.input_ids, torch.ones((padded_req_num,), dtype=torch.int64, device="cuda")], dim=0 + ) + if padded_req_num1 != 0: + draft_micro_input1.input_ids = torch.cat( + [draft_micro_input1.input_ids, torch.ones((padded_req_num1,), dtype=torch.int64, device="cuda")], dim=0 + ) + + # process the draft model output + for draft_model_idx in range(self.spec_step): + # spec decode: MTP + draft_micro_output, draft_micro_output1 = self.draft_models[draft_model_idx].microbatch_overlap_decode( + draft_micro_input, draft_micro_input1 + ) + + 
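# Hedged sketch of the acceptance rule behind self._verify(...): the real method is
# inherited from ContinuesBatchWithMTPBackend and is not part of this diff, so the
# function below only illustrates the usual greedy criterion — per request, keep a
# draft token only while the main model re-predicts exactly the token the draft
# proposed, then free the KV slots of the rejected rows.
from typing import List, Tuple


def greedy_verify(main_token_ids: List[int], draft_token_ids: List[int]) -> Tuple[List[int], int]:
    # main_token_ids[i] is the main model's sample at row i of a request;
    # draft_token_ids[i] is the draft token that was fed as the input of row i + 1.
    accepted = [main_token_ids[0]]              # the row-0 prediction is always kept
    for i, proposed in enumerate(draft_token_ids):
        if main_token_ids[i] != proposed:       # the draft diverged -> stop accepting
            break
        accepted.append(main_token_ids[i + 1])
    return accepted, len(accepted)


if __name__ == "__main__":
    # The draft proposed [42, 8]; the main model agreed on 42 but then sampled 7,
    # so two of the three candidate positions survive this step.
    print(greedy_verify([42, 7, 99], [42, 8]))  # -> ([42, 7], 2)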
draft_next_token_ids, draft_next_token_ids_cpu = self._gen_draft_tokens(draft_micro_output) + draft_next_token_ids1, draft_next_token_ids_cpu1 = self._gen_draft_tokens(draft_micro_output1) + self._save_decode_draft_token_ids(draft_next_token_ids_cpu, run_reqs, draft_model_idx) + self._save_decode_draft_token_ids(draft_next_token_ids_cpu1, run_reqs1, draft_model_idx) + # prepare inputs for the next draft model + draft_micro_input.input_ids = draft_next_token_ids + draft_micro_input.hidden_states = draft_micro_output.hidden_states + draft_micro_input1.input_ids = draft_next_token_ids1 + draft_micro_input1.hidden_states = draft_micro_output1.hidden_states + + if need_free_mem_indexes: + g_infer_state_lock.acquire() + g_infer_context.req_manager.mem_manager.free(need_free_mem_indexes) + g_infer_state_lock.release() + return + + def overlap_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: int, uninit_reqs, ok_finished_reqs): + ( + micro_input, + run_reqs, + padded_req_num, + micro_input1, + run_reqs1, + padded_req_num1, + ) = padded_overlap_prepare_prefill_inputs(prefill_reqs, is_multimodal=self.is_multimodal) + + micro_output, micro_output1 = self.model.microbatch_overlap_prefill(micro_input, micro_input1) + + self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) + + req_num, req_num1 = len(run_reqs), len(run_reqs1) + all_run_reqs = run_reqs + run_reqs1 + next_token_ids_cpu = [] + if len(all_run_reqs) != 0: + all_logits = torch.empty( + (len(all_run_reqs), micro_output.logits.shape[1]), + dtype=micro_output.logits.dtype, + device=micro_output.logits.device, + ) + + all_logits[0:req_num, :].copy_(micro_output.logits[0:req_num, :], non_blocking=True) + all_logits[req_num : (req_num + req_num1), :].copy_(micro_output1.logits[0:req_num1, :], non_blocking=True) + + next_token_ids, next_token_probs = sample(all_logits, all_run_reqs, self.eos_id) + next_token_ids_cpu = next_token_ids.detach().cpu().numpy() + next_token_logprobs_cpu = torch.log(next_token_probs).detach().cpu().numpy() + + # spec prefill: MTP + draft_micro_input, draft_micro_input1 = micro_input, micro_input1 + draft_micro_input.hidden_states = micro_output.hidden_states + draft_micro_input1.hidden_states = micro_output1.hidden_states + + for draft_model_idx in range(self.spec_step): + + draft_micro_input = prepare_mtp_prefill_inputs( + run_reqs, + draft_micro_input, + next_token_ids_cpu[0:req_num], + draft_model_idx, + is_chunked_mode=True, + padded_req_num=padded_req_num, + ) + + draft_micro_input1 = prepare_mtp_prefill_inputs( + run_reqs1, + draft_micro_input1, + next_token_ids_cpu[req_num:], + draft_model_idx, + is_chunked_mode=True, + padded_req_num=padded_req_num1, + ) + + draft_micro_output, draft_micro_output1 = self.draft_models[draft_model_idx].microbatch_overlap_prefill( + draft_micro_input, draft_micro_input1 + ) + _, draft_next_token_ids_cpu = self._gen_draft_tokens(draft_micro_output) + _, draft_next_token_ids_cpu1 = self._gen_draft_tokens(draft_micro_output1) + self._save_prefill_draft_tokens(draft_next_token_ids_cpu, run_reqs, draft_model_idx) + self._save_prefill_draft_tokens(draft_next_token_ids_cpu1, run_reqs1, draft_model_idx) + + if len(all_run_reqs) != 0: + self._post_handle( + all_run_reqs, + next_token_ids_cpu, + next_token_logprobs_cpu, + is_chuncked_mode=True, + do_filter_finished_reqs=False, + ) + + return diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/pre_process.py 
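# Standalone illustration of the logits-gathering buffer used by the two overlap
# paths above: allocate (req_num + req_num1, vocab) once, then copy only the real
# rows of each micro-batch output (padded fake-request rows are dropped) with
# non-blocking copies before sampling. CPU tensors and a toy vocab size here for
# runnability; the real code stays on the CUDA device of micro_output.logits.
import torch

req_num, req_num1, vocab = 3, 2, 8
micro_logits = torch.randn(req_num + 1, vocab)    # this half carried one padded request
micro_logits1 = torch.randn(req_num1, vocab)      # this half had no padding

all_logits = torch.empty((req_num + req_num1, vocab), dtype=micro_logits.dtype)
all_logits[0:req_num, :].copy_(micro_logits[0:req_num, :], non_blocking=True)
all_logits[req_num:, :].copy_(micro_logits1[0:req_num1, :], non_blocking=True)
print(all_logits.shape)  # torch.Size([5, 8])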
b/lightllm/server/router/model_infer/mode_backend/dp_backend/pre_process.py deleted file mode 100644 index f249757d9..000000000 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/pre_process.py +++ /dev/null @@ -1,310 +0,0 @@ -import torch -import torch.distributed as dist -import numpy as np -import triton -from typing import List -from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq -from lightllm.utils.infer_utils import calculate_time -from lightllm.common.mem_manager import MemoryManager -from lightllm.common.basemodel.infer_lock import g_infer_state_lock -from lightllm.common.basemodel.microbatch_overlap_objs import DecodeMicroBatch, PrefillMicroBatch - - -def padded_prepare_prefill_inputs(req_objs: List[InferReq], max_prefill_num: int, is_multimodal=False): - assert max_prefill_num != 0 - run_reqs = [] - nopad_total_token_num = 0 - nopad_max_len_in_batch = 0 - # 当前 dp 没有请求的时候,需要进行 dp 操作。 - padded_req_num = 1 if len(req_objs) == 0 else 0 - input_ids = [] - nopad_b_req_idx = [] - nopad_b_seq_len = [] - batch_multimodal_params = [] - b_ready_cache_len = [] - for req in req_objs: - - run_reqs.append(req) - batch_multimodal_params.append(req.multimodal_params) - nopad_b_req_idx.append(req.req_idx) - - input_token_ids = req.get_chuncked_input_token_ids() - seq_len = len(input_token_ids) - input_token_len = seq_len - req.cur_kv_len - input_id = input_token_ids[req.cur_kv_len :] - - nopad_b_seq_len.append(seq_len) - input_ids.append(input_id) - nopad_total_token_num += seq_len - nopad_max_len_in_batch = max(nopad_max_len_in_batch, input_token_len) - b_ready_cache_len.append(req.cur_kv_len) - - # padding fake req for prefill - for _ in range(padded_req_num): - input_ids.append([1]) - nopad_b_req_idx.append(g_infer_context.req_manager.HOLD_REQUEST_ID) - nopad_b_seq_len.append(1) - b_ready_cache_len.append(0) - nopad_total_token_num += 1 - nopad_max_len_in_batch = max(nopad_max_len_in_batch, 1) - - input_ids = np.concatenate(input_ids, dtype=np.int64) - input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cuda") - nopad_b_req_idx = torch.tensor(nopad_b_req_idx, dtype=torch.int32, device="cuda") - nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda") - b_ready_cache_len = torch.tensor(b_ready_cache_len, dtype=torch.int32, device="cuda") - - # dynamic prompt cache 准备 token - g_infer_state_lock.acquire() - if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0] - padded_req_num) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0] - padded_req_num).cuda() - g_infer_state_lock.release() - if padded_req_num > 0: - padding_mem_indexs = torch.full( - (padded_req_num,), - fill_value=g_infer_context.req_manager.mem_manager.HOLD_TOKEN_MEMINDEX, - dtype=torch.int32, - device="cuda", - ) - mem_indexes = torch.cat((mem_indexes, padding_mem_indexs), dim=0) - - kwargs = { - "batch_size": nopad_b_seq_len.shape[0], - "total_token_num": nopad_total_token_num, - "max_len_in_batch": nopad_max_len_in_batch, - "input_ids": input_ids, - "mem_indexes": mem_indexes, - "b_req_idx": nopad_b_req_idx, - "b_seq_len": nopad_b_seq_len, - "b_ready_cache_len": b_ready_cache_len, - "is_prefill": True, - } - if is_multimodal: - kwargs["multimodal_params"] = batch_multimodal_params - - return kwargs, run_reqs, padded_req_num - - -def padded_prepare_decode_inputs(req_objs: List[InferReq], max_decode_num: int, is_multimodal=False): - assert 
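# The helper being deleted here padded the freshly allocated mem_indexes with a
# sentinel (HOLD_TOKEN_MEMINDEX) through torch.full + torch.cat; its replacement in
# generic_padded_pre_process.py expresses the same step with F.pad. A quick CPU
# equivalence check — the sentinel value 0 below is an arbitrary stand-in for the
# real HOLD_TOKEN_MEMINDEX constant:
import torch
import torch.nn.functional as F

HOLD_TOKEN_MEMINDEX = 0            # stand-in value for illustration only
mem_indexes = torch.tensor([11, 12, 13], dtype=torch.int32)
padded_req_num = 2

old_style = torch.cat(
    (mem_indexes, torch.full((padded_req_num,), HOLD_TOKEN_MEMINDEX, dtype=torch.int32)), dim=0
)
new_style = F.pad(mem_indexes, (0, padded_req_num), mode="constant", value=HOLD_TOKEN_MEMINDEX)
assert torch.equal(old_style, new_style)
print(new_style.tolist())          # [11, 12, 13, 0, 0]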
max_decode_num != 0 - run_reqs = [] - nopad_total_token_num = 0 - nopad_max_len_in_batch = 0 - input_ids = [] - nopad_b_req_idx = [] - nopad_b_seq_len = [] - padded_req_num = 1 if len(req_objs) == 0 else 0 - for req in req_objs: - run_reqs.append(req) - nopad_b_req_idx.append(req.req_idx) - input_token_ids = req.get_input_token_ids() - input_id = input_token_ids[-1] - seq_len = len(input_token_ids) - assert req.cur_kv_len == seq_len - 1 - nopad_b_seq_len.append(seq_len) - input_ids.append(input_id) - nopad_total_token_num += seq_len - nopad_max_len_in_batch = max(nopad_max_len_in_batch, seq_len) - - # padding fake req for decode - for _ in range(padded_req_num): - input_ids.append(1) - seq_len = 2 - nopad_b_req_idx.append(g_infer_context.req_manager.HOLD_REQUEST_ID) - nopad_b_seq_len.append(seq_len) - nopad_total_token_num += seq_len - nopad_max_len_in_batch = max(nopad_max_len_in_batch, seq_len) - - input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cuda") - nopad_b_req_idx = torch.tensor(nopad_b_req_idx, dtype=torch.int32, device="cuda") - nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda") - - # dynamic prompt cache 准备 token - g_infer_state_lock.acquire() - if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0] - padded_req_num) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0] - padded_req_num).cuda() - g_infer_state_lock.release() - - if padded_req_num > 0: - padding_indexs = torch.full( - (padded_req_num,), - fill_value=g_infer_context.req_manager.mem_manager.HOLD_TOKEN_MEMINDEX, - dtype=torch.int32, - device="cuda", - ) - mem_indexes = torch.cat((mem_indexes, padding_indexs), dim=0) - - kwargs = { - "batch_size": nopad_b_seq_len.shape[0], - "total_token_num": nopad_total_token_num, - "max_len_in_batch": nopad_max_len_in_batch, - "input_ids": input_ids, - "mem_indexes": mem_indexes, - "b_req_idx": nopad_b_req_idx, - "b_seq_len": nopad_b_seq_len, - "is_prefill": False, - } - return kwargs, run_reqs, padded_req_num - - -def padded_overlap_prepare_decode_inputs(req_objs: List[InferReq], max_decode_num: int, is_multimodal=False): - assert max_decode_num != 0 - micro_batch_size = triton.cdiv(max_decode_num, 2) - micro_batch1_req_num = triton.cdiv(len(req_objs), 2) - micro_batch, run_reqs, padded_req_num = _padded_prepare_decode_micro_batch( - req_objs[0:micro_batch1_req_num], micro_batch_size, is_multimodal=is_multimodal - ) - micro_batch1, run_reqs1, padded_req_num1 = _padded_prepare_decode_micro_batch( - req_objs[micro_batch1_req_num:], micro_batch_size, is_multimodal=is_multimodal - ) - - return micro_batch, run_reqs, padded_req_num, micro_batch1, run_reqs1, padded_req_num1 - - -def _padded_prepare_decode_micro_batch(req_objs: List[InferReq], micro_batch_size: int, is_multimodal=False): - run_reqs = [] - nopad_total_token_num = 0 - nopad_max_len_in_batch = 0 - input_ids = [] - nopad_b_req_idx = [] - nopad_b_seq_len = [] - padded_req_num = micro_batch_size - len(req_objs) - for req in req_objs: - run_reqs.append(req) - nopad_b_req_idx.append(req.req_idx) - input_token_ids = req.get_input_token_ids() - input_id = input_token_ids[-1] - seq_len = len(input_token_ids) - assert req.cur_kv_len == seq_len - 1 - nopad_b_seq_len.append(seq_len) - input_ids.append(input_id) - nopad_total_token_num += seq_len - nopad_max_len_in_batch = max(nopad_max_len_in_batch, seq_len) - - # padding fake req for decode - for _ in range(padded_req_num): - 
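# The removed helpers only ever padded one fake request, and only when the local dp
# rank had no work at all (padded_req_num = 1 if len(req_objs) == 0 else 0). The new
# generic_padded_pre_process.py further below generalizes this with an optional
# dest_batch_size so two micro-batches can be padded to the same size. Worked
# examples of that arithmetic:
from typing import Optional


def compute_padded_req_num(real_req_num: int, dest_batch_size: Optional[int] = None) -> int:
    if dest_batch_size is None:
        dest_batch_size = real_req_num if real_req_num > 0 else 1
    else:
        assert real_req_num <= dest_batch_size
    return dest_batch_size - real_req_num


print(compute_padded_req_num(0))      # 1 -> an empty rank still runs one fake request
print(compute_padded_req_num(3))      # 0 -> no padding needed
print(compute_padded_req_num(3, 5))   # 2 -> padded up to the agreed micro-batch size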
input_ids.append(1) - seq_len = 2 - nopad_b_req_idx.append(g_infer_context.req_manager.HOLD_REQUEST_ID) - nopad_b_seq_len.append(seq_len) - nopad_total_token_num += seq_len - nopad_max_len_in_batch = max(nopad_max_len_in_batch, seq_len) - - input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cuda") - nopad_b_req_idx = torch.tensor(nopad_b_req_idx, dtype=torch.int32, device="cuda") - nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda") - - # dynamic prompt cache 准备 token - g_infer_state_lock.acquire() - if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0] - padded_req_num) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0] - padded_req_num).cuda() - g_infer_state_lock.release() - - if padded_req_num > 0: - padding_indexs = torch.full( - (padded_req_num,), - fill_value=g_infer_context.req_manager.mem_manager.HOLD_TOKEN_MEMINDEX, - dtype=torch.int32, - device="cuda", - ) - mem_indexes = torch.cat((mem_indexes, padding_indexs), dim=0) - - micro_batch = DecodeMicroBatch( - batch_size=nopad_b_seq_len.shape[0], - total_token_num=nopad_total_token_num, - max_len_in_batch=nopad_max_len_in_batch, - input_ids=input_ids, - mem_indexes=mem_indexes, - b_req_idx=nopad_b_req_idx, - b_seq_len=nopad_b_seq_len, - ) - - return micro_batch, run_reqs, padded_req_num - - -def padded_overlap_prepare_prefill_inputs(req_objs: List[InferReq], max_prefill_num: int, is_multimodal=False): - assert max_prefill_num != 0 - micro_batch1_req_num = triton.cdiv(len(req_objs), 2) - micro_batch, run_reqs, padded_req_num = _padded_prepare_prefill_micro_batch( - req_objs[0:micro_batch1_req_num], is_multimodal=is_multimodal - ) - micro_batch1, run_reqs1, padded_req_num1 = _padded_prepare_prefill_micro_batch( - req_objs[micro_batch1_req_num:], is_multimodal=is_multimodal - ) - - return micro_batch, run_reqs, padded_req_num, micro_batch1, run_reqs1, padded_req_num1 - - -def _padded_prepare_prefill_micro_batch(req_objs: List[InferReq], is_multimodal=False): - run_reqs = [] - nopad_total_token_num = 0 - nopad_max_len_in_batch = 0 - input_ids = [] - nopad_b_req_idx = [] - nopad_b_seq_len = [] - # prefill 只需要 padding 一个请求形成 micro_batch, 并不需要两个 - # micro batch 的 batch_size 相同。 - padded_req_num = 1 if len(req_objs) == 0 else 0 - b_ready_cache_len = [] - batch_multimodal_params = [] - for req in req_objs: - run_reqs.append(req) - batch_multimodal_params.append(req.multimodal_params) - nopad_b_req_idx.append(req.req_idx) - - input_token_ids = req.get_chuncked_input_token_ids() - seq_len = len(input_token_ids) - input_token_len = seq_len - req.cur_kv_len - input_id = input_token_ids[req.cur_kv_len :] - - nopad_b_seq_len.append(seq_len) - input_ids.extend(input_id) - nopad_total_token_num += seq_len - nopad_max_len_in_batch = max(nopad_max_len_in_batch, input_token_len) - b_ready_cache_len.append(req.cur_kv_len) - - # padding fake req for decode - for _ in range(padded_req_num): - input_ids.append(1) - nopad_b_req_idx.append(g_infer_context.req_manager.HOLD_REQUEST_ID) - nopad_b_seq_len.append(1) - b_ready_cache_len.append(0) - nopad_total_token_num += 1 - nopad_max_len_in_batch = max(nopad_max_len_in_batch, 1) - - input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cuda") - nopad_b_req_idx = torch.tensor(nopad_b_req_idx, dtype=torch.int32, device="cuda") - nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda") - b_ready_cache_len = 
torch.tensor(b_ready_cache_len, dtype=torch.int32, device="cuda") - - # dynamic prompt cache 准备 token - g_infer_state_lock.acquire() - if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0] - padded_req_num) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0] - padded_req_num).cuda() - g_infer_state_lock.release() - if padded_req_num > 0: - padding_indexs = torch.full( - (padded_req_num,), - fill_value=g_infer_context.req_manager.mem_manager.HOLD_TOKEN_MEMINDEX, - dtype=torch.int32, - device="cuda", - ) - mem_indexes = torch.cat((mem_indexes, padding_indexs), dim=0) - - micro_batch = PrefillMicroBatch( - batch_size=nopad_b_seq_len.shape[0], - total_token_num=nopad_total_token_num, - max_len_in_batch=nopad_max_len_in_batch, - input_ids=input_ids, - mem_indexes=mem_indexes, - b_req_idx=nopad_b_req_idx, - b_seq_len=nopad_b_seq_len, - b_ready_cache_len=b_ready_cache_len, - multimodal_params=batch_multimodal_params, - ) - - return micro_batch, run_reqs, padded_req_num diff --git a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py new file mode 100644 index 000000000..569bc8acd --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py @@ -0,0 +1,213 @@ +import torch +import torch.distributed as dist +import torch.nn.functional as F +import numpy as np +import triton +from typing import List, Optional +from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq +from lightllm.utils.infer_utils import calculate_time +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.common.basemodel.infer_lock import g_infer_state_lock +from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput + + +def padded_prepare_prefill_inputs(req_objs: List[InferReq], dest_batch_size: Optional[int] = None, is_multimodal=False): + + if dest_batch_size is None: + req_num = len(req_objs) + if req_num > 0: + dest_batch_size = req_num + else: + dest_batch_size = 1 + else: + assert len(req_objs) <= dest_batch_size + + run_reqs = [] + total_token_num = 0 + max_len_in_batch = 0 + padded_req_num = dest_batch_size - len(req_objs) + input_ids = [] + b_req_idx = [] + b_seq_len = [] + batch_multimodal_params = [] + b_ready_cache_len = [] + for req in req_objs: + + run_reqs.append(req) + batch_multimodal_params.append(req.multimodal_params) + b_req_idx.append(req.req_idx) + + input_token_ids = req.get_chuncked_input_token_ids() + seq_len = len(input_token_ids) + input_token_len = seq_len - req.cur_kv_len + input_id = input_token_ids[req.cur_kv_len :] + + b_seq_len.append(seq_len) + input_ids.append(input_id) + total_token_num += seq_len + max_len_in_batch = max(max_len_in_batch, input_token_len) + b_ready_cache_len.append(req.cur_kv_len) + + # padding fake req for prefill + for _ in range(padded_req_num): + input_ids.append([1]) + b_req_idx.append(g_infer_context.req_manager.HOLD_REQUEST_ID) + b_seq_len.append(1) + b_ready_cache_len.append(0) + total_token_num += 1 + max_len_in_batch = max(max_len_in_batch, 1) + + input_ids = np.concatenate(input_ids, dtype=np.int64) + input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cuda") + b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cuda") + b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cuda") + b_ready_cache_len = 
torch.tensor(b_ready_cache_len, dtype=torch.int32, device="cuda") + + # dynamic prompt cache 准备 token + g_infer_state_lock.acquire() + if g_infer_context.radix_cache is not None: + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0] - padded_req_num) + mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0] - padded_req_num).cuda() + g_infer_state_lock.release() + + if padded_req_num > 0: + mem_indexes = F.pad( + input=mem_indexes, + pad=(0, padded_req_num), + mode="constant", + value=g_infer_context.req_manager.mem_manager.HOLD_TOKEN_MEMINDEX, + ) + + model_input = ModelInput( + batch_size=b_seq_len.shape[0], + total_token_num=total_token_num, + max_len_in_batch=max_len_in_batch, + input_ids=input_ids, + mem_indexes=mem_indexes, + b_req_idx=b_req_idx, + b_seq_len=b_seq_len, + is_prefill=True, + ) + if is_multimodal: + model_input.multimodal_params = batch_multimodal_params + + return model_input, run_reqs, padded_req_num + + +def padded_prepare_decode_inputs(req_objs: List[InferReq], dest_batch_size: Optional[int] = None, is_multimodal=False): + run_reqs = [] + total_token_num = 0 + max_len_in_batch = 0 + input_ids = [] + b_req_idx = [] + b_seq_len = [] + + for req in req_objs: + run_reqs.append(req) + b_req_idx.append(req.req_idx) + input_token_ids = req.get_input_token_ids() + input_id = input_token_ids[-1] + seq_len = len(input_token_ids) + assert req.cur_kv_len == seq_len - 1 + b_seq_len.append(seq_len) + input_ids.append(input_id) + total_token_num += seq_len + max_len_in_batch = max(max_len_in_batch, seq_len) + # process the draft tokens. + for step in range(len(req.mtp_gen_token_ids)): + run_reqs.append(req) + b_req_idx.append(req.req_idx) + seq_len += 1 + b_seq_len.append(seq_len) + input_ids.append(req.mtp_gen_token_ids[step]) + total_token_num += seq_len + max_len_in_batch = max(max_len_in_batch, seq_len) + + if dest_batch_size is None: + if len(run_reqs) == 0: + dest_batch_size = 1 + else: + dest_batch_size = len(run_reqs) + else: + assert len(run_reqs) <= dest_batch_size + + padded_req_num = dest_batch_size - len(run_reqs) + + # padding fake req for decode + for _ in range(padded_req_num): + input_ids.append(1) + seq_len = 2 + b_req_idx.append(g_infer_context.req_manager.HOLD_REQUEST_ID) + b_seq_len.append(seq_len) + total_token_num += seq_len + max_len_in_batch = max(max_len_in_batch, seq_len) + + input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cuda") + b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cuda") + b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cuda") + + # dynamic prompt cache 准备 token + g_infer_state_lock.acquire() + if g_infer_context.radix_cache is not None: + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0] - padded_req_num) + mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0] - padded_req_num).cuda() + g_infer_state_lock.release() + + if padded_req_num > 0: + mem_indexes = F.pad( + input=mem_indexes, + pad=(0, padded_req_num), + mode="constant", + value=g_infer_context.req_manager.mem_manager.HOLD_TOKEN_MEMINDEX, + ) + + model_input = ModelInput( + batch_size=b_seq_len.shape[0], + total_token_num=total_token_num, + max_len_in_batch=max_len_in_batch, + input_ids=input_ids, + mem_indexes=mem_indexes, + b_req_idx=b_req_idx, + b_seq_len=b_seq_len, + is_prefill=False, + ) + return model_input, run_reqs, padded_req_num + + +def padded_overlap_prepare_decode_inputs(req_objs: List[InferReq], 
is_multimodal=False): + split_req_bound = triton.cdiv(len(req_objs), 2) + req_objs_0 = req_objs[0:split_req_bound] + req_objs_1 = req_objs[split_req_bound:] + + enable_mtp = get_env_start_args().mtp_mode is not None + if enable_mtp: + micro_batch_size = max( + sum([len(req.mtp_gen_token_ids) + 1 for req in req_objs_0]), + sum([len(req.mtp_gen_token_ids) + 1 for req in req_objs_1]), + ) + else: + micro_batch_size = triton.cdiv(len(req_objs), 2) + + micro_batch_size = max(1, micro_batch_size) + + micro_input, run_reqs, padded_req_num = padded_prepare_decode_inputs( + req_objs_0, dest_batch_size=micro_batch_size, is_multimodal=is_multimodal + ) + micro_input1, run_reqs1, padded_req_num1 = padded_prepare_decode_inputs( + req_objs_1, dest_batch_size=micro_batch_size, is_multimodal=is_multimodal + ) + return micro_input, run_reqs, padded_req_num, micro_input1, run_reqs1, padded_req_num1 + + +def padded_overlap_prepare_prefill_inputs(req_objs: List[InferReq], is_multimodal=False): + micro_batch1_req_num = triton.cdiv(len(req_objs), 2) + + micro_input, run_reqs, padded_req_num = padded_prepare_prefill_inputs( + req_objs[0:micro_batch1_req_num], is_multimodal=is_multimodal + ) + + micro_input1, run_reqs1, padded_req_num1 = padded_prepare_prefill_inputs( + req_objs[micro_batch1_req_num:], is_multimodal=is_multimodal + ) + + return micro_input, run_reqs, padded_req_num, micro_input1, run_reqs1, padded_req_num1 diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index 9f174e8bc..92f43dac5 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -3,22 +3,22 @@ from typing import List from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context from lightllm.common.basemodel.infer_lock import g_infer_state_lock +from lightllm.common.basemodel.batch_objs import ModelInput -def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool, is_multimodal=False): +def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool, is_multimodal: bool = False): run_reqs = [] - nopad_total_token_num = 0 - nopad_max_len_in_batch = 0 + total_token_num = 0 + max_len_in_batch = 0 input_ids = [] - nopad_b_req_idx = [] - nopad_b_seq_len = [] + b_req_idx = [] + b_seq_len = [] batch_multimodal_params = [] b_ready_cache_len = [] for req in req_objs: - run_reqs.append(req) batch_multimodal_params.append(req.multimodal_params) - nopad_b_req_idx.append(req.req_idx) + b_req_idx.append(req.req_idx) if is_chuncked_mode: input_token_ids = req.get_chuncked_input_token_ids() @@ -30,17 +30,17 @@ def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool, is_ input_id = input_token_ids[req.cur_kv_len :] - nopad_b_seq_len.append(seq_len) + b_seq_len.append(seq_len) input_ids.append(input_id) - nopad_total_token_num += seq_len - nopad_max_len_in_batch = max(nopad_max_len_in_batch, input_token_len) + total_token_num += seq_len + max_len_in_batch = max(max_len_in_batch, input_token_len) b_ready_cache_len.append(req.cur_kv_len) input_ids = np.concatenate(input_ids, dtype=np.int64) input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cuda") - nopad_b_req_idx = torch.tensor(nopad_b_req_idx, dtype=torch.int32, device="cuda") - nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda") + b_req_idx = torch.tensor(b_req_idx, 
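# The decode split above sizes both micro-batches by row count rather than request
# count: with MTP enabled every request occupies 1 + len(mtp_gen_token_ids) rows, so
# the shared dest_batch_size is the larger of the two halves' row sums (and never
# below 1); without MTP it falls back to ceil(len(req_objs) / 2), which is what
# triton.cdiv computes. A quick check of that arithmetic with toy draft counts:
import math

drafts_per_req_half0 = [1, 1]       # two requests, one draft token each -> 4 rows
drafts_per_req_half1 = [1, 1, 1]    # three requests, one draft token each -> 6 rows

rows0 = sum(n + 1 for n in drafts_per_req_half0)
rows1 = sum(n + 1 for n in drafts_per_req_half1)
print(max(1, max(rows0, rows1)))    # 6 -> the smaller half is padded up to 6 rows

print(max(1, math.ceil(5 / 2)))     # 3 -> non-MTP fallback for the same 5 requests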
dtype=torch.int32, device="cuda") + b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cuda") b_ready_cache_len = torch.tensor(b_ready_cache_len, dtype=torch.int32, device="cuda") # dynamic prompt cache 准备 token @@ -50,44 +50,54 @@ def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool, is_ mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0]).cuda() g_infer_state_lock.release() - kwargs = { - "batch_size": len(run_reqs), - "total_token_num": nopad_total_token_num, - "max_len_in_batch": nopad_max_len_in_batch, - "input_ids": input_ids, - "mem_indexes": mem_indexes, - "b_req_idx": nopad_b_req_idx, - "b_seq_len": nopad_b_seq_len, - "b_ready_cache_len": b_ready_cache_len, - "is_prefill": True, - } + model_input = ModelInput( + batch_size=b_seq_len.shape[0], + total_token_num=total_token_num, + max_len_in_batch=max_len_in_batch, + input_ids=input_ids, + mem_indexes=mem_indexes, + b_req_idx=b_req_idx, + b_seq_len=b_seq_len, + b_ready_cache_len=b_ready_cache_len, + is_prefill=True, + ) if is_multimodal: - kwargs["multimodal_params"] = batch_multimodal_params + model_input.multimodal_params = batch_multimodal_params - return kwargs, run_reqs + return model_input, run_reqs def prepare_decode_inputs(req_objs: List[InferReq]): run_reqs = [] - nopad_total_token_num = 0 - nopad_max_len_in_batch = 0 + total_token_num = 0 + max_len_in_batch = 0 input_ids = [] - nopad_b_req_idx = [] - nopad_b_seq_len = [] + b_req_idx = [] + b_seq_len = [] for req in req_objs: run_reqs.append(req) - nopad_b_req_idx.append(req.req_idx) + b_req_idx.append(req.req_idx) input_id = req.get_last_gen_token() seq_len = req.get_cur_total_len() assert req.cur_kv_len == seq_len - 1 - nopad_b_seq_len.append(seq_len) + b_seq_len.append(seq_len) input_ids.append(input_id) - nopad_total_token_num += seq_len - nopad_max_len_in_batch = max(nopad_max_len_in_batch, seq_len) + total_token_num += seq_len + max_len_in_batch = max(max_len_in_batch, seq_len) + + # process the draft tokens. 
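# Worked example of the expansion performed by the loop that follows: every request
# contributes one row for its last confirmed token plus one extra row per entry in
# mtp_gen_token_ids, each with seq_len grown by one. FakeReq is a plain stand-in for
# InferReq; the real code builds CUDA tensors and wraps them in a ModelInput.
class FakeReq:
    def __init__(self, token_ids, mtp_gen_token_ids):
        self.token_ids = token_ids
        self.mtp_gen_token_ids = mtp_gen_token_ids


def expand_decode_rows(reqs):
    input_ids, b_seq_len = [], []
    for req in reqs:
        seq_len = len(req.token_ids)
        input_ids.append(req.token_ids[-1])        # last confirmed token
        b_seq_len.append(seq_len)
        for draft_token in req.mtp_gen_token_ids:  # one extra row per draft token
            seq_len += 1
            input_ids.append(draft_token)
            b_seq_len.append(seq_len)
    return input_ids, b_seq_len


print(expand_decode_rows([FakeReq([5, 6, 7], [81, 82]), FakeReq([9, 10], [])]))
# -> ([7, 81, 82, 10], [3, 4, 5, 2])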
+ for step in range(len(req.mtp_gen_token_ids)): + run_reqs.append(req) + b_req_idx.append(req.req_idx) + seq_len += 1 + b_seq_len.append(seq_len) + input_ids.append(req.mtp_gen_token_ids[step]) + total_token_num += seq_len + max_len_in_batch = max(max_len_in_batch, seq_len) input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cuda") - nopad_b_req_idx = torch.tensor(nopad_b_req_idx, dtype=torch.int32, device="cuda") - nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda") + b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cuda") + b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cuda") # dynamic prompt cache 准备 token g_infer_state_lock.acquire() @@ -96,14 +106,14 @@ def prepare_decode_inputs(req_objs: List[InferReq]): mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0]).cuda() g_infer_state_lock.release() - kwargs = { - "batch_size": len(run_reqs), - "total_token_num": nopad_total_token_num, - "max_len_in_batch": nopad_max_len_in_batch, - "input_ids": input_ids, - "mem_indexes": mem_indexes, - "b_req_idx": nopad_b_req_idx, - "b_seq_len": nopad_b_seq_len, - "is_prefill": False, - } - return kwargs, run_reqs + model_input = ModelInput( + batch_size=b_seq_len.shape[0], + total_token_num=total_token_num, + max_len_in_batch=max_len_in_batch, + input_ids=input_ids, + mem_indexes=mem_indexes, + b_req_idx=b_req_idx, + b_seq_len=b_seq_len, + is_prefill=False, + ) + return model_input, run_reqs diff --git a/lightllm/server/router/model_infer/mode_backend/mtp_pre_process.py b/lightllm/server/router/model_infer/mode_backend/mtp_pre_process.py new file mode 100644 index 000000000..33c02a504 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/mtp_pre_process.py @@ -0,0 +1,33 @@ +import torch +import numpy as np +from typing import List, Tuple +from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context +from lightllm.common.basemodel.infer_lock import g_infer_state_lock +from lightllm.common.basemodel.batch_objs import ModelInput + + +def prepare_mtp_prefill_inputs( + req_objs: List[InferReq], + model_input: ModelInput, + next_token_ids_cpu: torch.Tensor, + draft_model_idx: int, + is_chunked_mode: bool, + padded_req_num: int = 0, +): + assert padded_req_num >= 0, f"padded_req_num must be greater than or euqal to 0, but got {padded_req_num}" + input_ids = [] + for i, req in enumerate(req_objs): + if is_chunked_mode: + input_token_ids = req.get_chunked_input_token_ids_shift(next_token_ids_cpu[i : i + 1], draft_model_idx + 1) + else: + input_token_ids = req.get_input_token_ids_shift(next_token_ids_cpu[i : i + 1], draft_model_idx + 1) + input_ids.append(input_token_ids[req.cur_kv_len :].astype(np.int64)) + + # padding fake req for prefill + for _ in range(padded_req_num): + input_ids.append([1]) + + input_ids = np.concatenate(input_ids, dtype=np.int64) + input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cuda") + model_input.input_ids = input_ids + return model_input diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index 1568d62ef..cf9972edc 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -21,6 +21,8 @@ DPForDecodeNode, ChunckedPrefillForPrefillNode, DPChunkedForPrefillNode, + ContinuesBatchWithMTPBackend, + DPChunkedPrefillWithMTPBackend, ) from lightllm.server.router.model_infer.mode_backend.redundancy_expert_manager import 
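# The get_*_input_token_ids_shift helpers called in prepare_mtp_prefill_inputs are
# defined on InferReq outside this diff; the sketch below only illustrates the
# alignment they are assumed to perform for draft depth k = draft_model_idx + 1:
# drop the first k prompt tokens and append the token(s) the main model just
# sampled, so the draft model is queried k positions ahead of the main model.
import numpy as np


def shift_input_token_ids(prompt_ids: np.ndarray, sampled_ids: np.ndarray, shift: int) -> np.ndarray:
    return np.concatenate([prompt_ids[shift:], sampled_ids]).astype(np.int64)


prompt = np.array([101, 102, 103, 104], dtype=np.int64)
print(shift_input_token_ids(prompt, np.array([105]), shift=1))  # [102 103 104 105]
print(shift_input_token_ids(prompt, np.array([105]), shift=2))  # [103 104 105]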
RedundancyExpertManager from lightllm.server.core.objs import RpcShmParams, RpcShmResults, ShmSyncStatusArray @@ -108,38 +110,36 @@ def init_model(self, kvargs): # 填充真正的 rank_id 参数 kvargs["rank_id"] = self.rank self.world_size = kvargs["world_size"] - disable_chunked_prefill = kvargs.get("disable_chunked_prefill", False) - return_all_prompt_logprobs = kvargs.get("return_all_prompt_logprobs", False) - use_reward_model = kvargs.get("use_reward_model", False) - diverse_mode = kvargs.get("diverse_mode", False) - is_token_healing = kvargs.get("is_token_healing", False) - is_first_token_constraint_mode = kvargs.get("is_first_token_constraint_mode", False) - if kvargs.get("args", None) is not None: - is_outlines_constraint_mode = kvargs.get("args", None).output_constraint_mode == "outlines" - is_xgrammar_constraint_mode = kvargs.get("args", None).output_constraint_mode == "xgrammar" - assert not ( - is_outlines_constraint_mode and is_xgrammar_constraint_mode - ), "only one constraint mode can be true" - is_prefill_node = kvargs.get("args", None).run_mode == "prefill" - is_decode_node = kvargs.get("args", None).run_mode == "decode" - else: - is_outlines_constraint_mode = False - is_xgrammar_constraint_mode = False - is_prefill_node = False - is_decode_node = False + disable_chunked_prefill = self.args.disable_chunked_prefill + return_all_prompt_logprobs = self.args.return_all_prompt_logprobs + use_reward_model = self.args.use_reward_model + diverse_mode = self.args.diverse_mode + is_token_healing = self.args.token_healing_mode + is_first_token_constraint_mode = self.args.first_token_constraint_mode + + is_outlines_constraint_mode = self.args.output_constraint_mode == "outlines" + is_xgrammar_constraint_mode = self.args.output_constraint_mode == "xgrammar" + assert not (is_outlines_constraint_mode and is_xgrammar_constraint_mode), "only one constraint mode can be true" + is_prefill_node = self.args.run_mode == "prefill" + is_decode_node = self.args.run_mode == "decode" + + enable_mtp = self.args.mtp_mode is not None if is_prefill_node: - if kvargs.get("args", None).dp > 1: + if self.args.dp > 1: self.backend = DPChunkedForPrefillNode(self.info_queue, self.mem_queue) else: self.backend = ChunckedPrefillForPrefillNode(self.info_queue, self.mem_queue) elif is_decode_node: - if kvargs.get("args", None).dp > 1: + if self.args.dp > 1: self.backend = DPForDecodeNode(self.info_queue, self.mem_queue) else: self.backend = ContinuesBatchBackendForDecodeNode(self.info_queue, self.mem_queue) - elif kvargs.get("dp_size", 1) > 1: - self.backend = DPChunkedPrefillBackend() + elif self.args.dp > 1: + if enable_mtp: + self.backend = DPChunkedPrefillWithMTPBackend() + else: + self.backend = DPChunkedPrefillBackend() elif use_reward_model: self.backend = RewardModelBackend() elif return_all_prompt_logprobs: @@ -155,9 +155,15 @@ def init_model(self, kvargs): elif is_first_token_constraint_mode: self.backend = FirstTokenConstraintBackend() elif disable_chunked_prefill: - self.backend = ContinuesBatchBackend() + if enable_mtp: + self.backend = ContinuesBatchWithMTPBackend() + else: + self.backend = ContinuesBatchBackend() else: - self.backend = ChunkedPrefillBackend() + if enable_mtp: + self.backend = ContinuesBatchWithMTPBackend() + else: + self.backend = ChunkedPrefillBackend() logger.info(f"use {self.backend.__class__.__name__}") self.backend.init_model(kvargs) diff --git a/lightllm/utils/custom_kernel_utis.py b/lightllm/utils/custom_kernel_utis.py index f72dfabfc..cbc52c8f9 100644 --- 
a/lightllm/utils/custom_kernel_utis.py +++ b/lightllm/utils/custom_kernel_utis.py @@ -114,3 +114,13 @@ def tensor_copy_3dim(dest_tensor: torch.Tensor, source_tensor: torch.Tensor): num_stages=3, ) return + + +def pad2dim_tensor_to_new_batch(input: torch.Tensor, new_batch_size: int): + assert input.ndim == 2 + origin_batch_size = input.shape[0] + hidden = input.shape[1] + out = torch.empty((new_batch_size, hidden), dtype=input.dtype, device=input.device, requires_grad=False) + out[0:origin_batch_size, :] = input + out[origin_batch_size:, :] = input[0:1, :] + return out diff --git a/test/model/model_infer_mtp.py b/test/model/model_infer_mtp.py new file mode 100644 index 000000000..84a4140c7 --- /dev/null +++ b/test/model/model_infer_mtp.py @@ -0,0 +1,284 @@ +import os +import torch +import numpy as np +from multiprocessing import Queue +import multiprocessing +from transformers import PretrainedConfig +from lightllm.utils.dist_utils import init_distributed_env, get_current_rank_in_dp +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.models import get_model +from lightllm.common.basemodel.microbatch_overlap_objs import DecodeMicroBatch, PrefillMicroBatch +from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput +from torch.profiler import profile, record_function, ProfilerActivity +from lightllm.utils.log_utils import init_logger +from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel +import torch.cuda as cuda + +logger = init_logger(__name__) + + +def init_mtp_model(args, kvargs, main_model): + spec_step = args.spec_step + draft_models = [] + + os.environ["DISABLE_CHECK_MAX_LEN_INFER"] = "1" + mtp_model_kvargs = kvargs + mtp_model_kvargs.update( + { + "weight_dir": args.spec_model_dir, + "max_total_token_num": main_model.mem_manager.size, + "use_dynamic_prompt_cache": False, + "disable_chunked_prefill": True, + "spec_algo": "MTP_MOUDLE", + "main_model": main_model, + } + ) + for i in range(spec_step): + mtp_model_kvargs.update( + { + "weight_dir": args.spec_model_dir, + "max_total_token_num": main_model.mem_manager.size, + "use_dynamic_prompt_cache": False, + "disable_chunked_prefill": True, + "spec_algo": "MTP_MOUDLE", + "main_model": main_model, + } + ) + mtp_model_kvargs.update( + { + "last_mtp_module": i == spec_step - 1, + } + ) + draft_models.append(Deepseek3MTPModel(mtp_model_kvargs)) + return draft_models + + +def test_model_inference_mtp(args): + ans_queue = Queue() + workers = [] + dp_size = args.get("dp", 1) + + for rank_id in range(args.node_rank * args.tp, (args.node_rank + 1) * args.tp): + model_kvargs = { + "args": args, + "nccl_host": args.nccl_host, + "data_type": args.data_type, + "nccl_port": args.nccl_port, + "rank_id": rank_id, + "world_size": args.tp, + "dp_size": dp_size, + "weight_dir": args.model_dir, + "quant_type": args.quant_type, + "load_way": "HF", + "max_total_token_num": args.max_total_token_num, + "graph_max_len_in_batch": args.max_req_total_len, + "graph_max_batch_size": args.graph_max_batch_size, + "mem_faction": args.mem_fraction, + "max_req_num": 2000, + "batch_max_tokens": 2048, + "run_mode": "normal", + "max_seq_length": args.max_req_total_len, + "spec_algo": args.spec_algo, + "disable_cudagraph": args.disable_cudagraph, + } + proc = multiprocessing.Process( + target=tppart_model_infer, + args=(args, model_kvargs, args.batch_size, args.input_len, args.output_len, ans_queue), + ) + proc.start() + workers.append(proc) + + for proc in workers: + proc.join() + + assert not ans_queue.empty() + while not 
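# A quick CPU demonstration of the pad2dim_tensor_to_new_batch helper added above:
# rows belonging to padded fake requests are filled by repeating row 0, which is
# sufficient for hidden states whose padded rows are discarded anyway. The helper is
# copied verbatim here so the snippet runs standalone.
import torch


def pad2dim_tensor_to_new_batch(input: torch.Tensor, new_batch_size: int):
    assert input.ndim == 2
    origin_batch_size = input.shape[0]
    hidden = input.shape[1]
    out = torch.empty((new_batch_size, hidden), dtype=input.dtype, device=input.device, requires_grad=False)
    out[0:origin_batch_size, :] = input
    out[origin_batch_size:, :] = input[0:1, :]
    return out


hidden_states = torch.arange(6, dtype=torch.float32).reshape(2, 3)
print(pad2dim_tensor_to_new_batch(hidden_states, 4))
# rows 2 and 3 repeat row 0: [[0, 1, 2], [3, 4, 5], [0, 1, 2], [0, 1, 2]]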
ans_queue.empty(): + assert ans_queue.get() + return + + +def torch_profile(fn, log_dir=None): + torch.cuda.synchronize() + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + record_shapes=False, + profile_memory=False, + on_trace_ready=torch.profiler.tensorboard_trace_handler(log_dir), + ) as prof: + fn() + if get_current_rank_in_dp() == 0: + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + + +def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_models, warmup=False): + import time + + torch.cuda.synchronize() + prefill_start_time = time.time() + + test_data = np.vstack([np.random.randint(0, 50256, input_len) for _ in range(batch_size)]) + test_data = test_data.reshape(-1) + test_data = torch.from_numpy(test_data).cuda() + + b_req_idx = torch.tensor( + [main_model.req_manager.alloc() for _ in range(batch_size)], dtype=torch.int32, device="cuda" + ) + b_seq_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + for i in range(batch_size): + b_seq_len[i] = input_len + + total_token_num = input_len * batch_size + mem_indexes = main_model.req_manager.mem_manager.alloc(test_data.shape[0]).cuda() + # Main model Prefill + model_input = ModelInput( + batch_size=batch_size, + total_token_num=total_token_num, + max_len_in_batch=input_len, + input_ids=test_data, + mem_indexes=mem_indexes, + b_req_idx=b_req_idx, + b_seq_len=b_seq_len, + is_prefill=True, + b_ready_cache_len=b_ready_cache_len, + ) + + model_output = main_model.forward(model_input) + prob_out = torch.softmax(model_output.logits, dim=-1) + predict_ids = torch.argmax(prob_out, dim=1, keepdim=True) + predict_ids = predict_ids.detach().cpu().numpy() + + draft_ids = [predict_ids] + + # Draft model Prefill + # For simplicity, we'll just take the input of main_model to draft model. 
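# A toy, self-contained illustration of the chaining used in the draft loop that
# follows (and in the backend's normal_decode / overlap_decode above): each draft
# model consumes the token it was just handed plus the previous model's hidden
# states, and passes its own hidden states on to the next draft model. ToyDraft is
# not the real Deepseek3MTPModel — just a stand-in with the same call pattern.
import torch
import torch.nn as nn


class ToyDraft(nn.Module):
    def __init__(self, hidden: int, vocab: int):
        super().__init__()
        self.embed = nn.Embedding(vocab, hidden)
        self.proj = nn.Linear(2 * hidden, hidden)
        self.head = nn.Linear(hidden, vocab)

    def forward(self, token_ids: torch.Tensor, hidden_states: torch.Tensor):
        x = torch.cat([self.embed(token_ids), hidden_states], dim=-1)
        new_hidden = torch.tanh(self.proj(x))      # stand-in for the MTP transformer block
        return self.head(new_hidden), new_hidden


if __name__ == "__main__":
    torch.manual_seed(0)
    hidden_states = torch.randn(4, 16)             # "main model" hidden states, 4 requests
    token_ids = torch.zeros(4, dtype=torch.int64)  # tokens the main model just emitted
    for draft in [ToyDraft(16, 100) for _ in range(2)]:  # spec_step == 2
        logits, hidden_states = draft(token_ids, hidden_states)
        token_ids = logits.argmax(dim=-1)          # greedy draft token feeds the next model
        print(token_ids.tolist())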
+ model_input.hidden_states = model_output.hidden_states + for draft_model_id in range(len(draft_models)): + draft_model = draft_models[draft_model_id] + model_output = draft_model.forward(model_input) + prob_out = torch.softmax(model_output.logits, dim=-1) + predict_ids = torch.argmax(prob_out, dim=1, keepdim=True) + predict_ids = predict_ids.detach().cpu().numpy() + draft_ids.append(predict_ids) + model_input.hidden_states = model_output.hidden_states + + torch.cuda.synchronize() + prefill_end_time = time.time() + if get_current_rank_in_dp() == 0 and not warmup: + print("prefill time cost:", (prefill_end_time - prefill_start_time) * 1000) + print( + f"Prefill throughput: {batch_size * input_len * args.dp / (prefill_end_time - prefill_start_time)} tokens/s" + ) + + torch.cuda.synchronize() + + decode_input_ids = np.stack(draft_ids, axis=-1).reshape(-1) + decode_input_ids = torch.from_numpy(decode_input_ids).cuda() + + # build main decode input: + nopad_b_seq_idx = [] + nopad_b_seq_len = [] + nopad_total_token_num = 0 + nopad_max_len_in_batch = 0 + + for i in range(batch_size): + nopad_b_seq_idx.append(b_req_idx[i]) + seq_len = b_seq_len[i].item() + nopad_b_seq_len.append(seq_len + 1) + nopad_total_token_num += seq_len + 1 + nopad_max_len_in_batch = max(nopad_max_len_in_batch, b_seq_len[i] + 1) + + for step in range(len(draft_models)): + nopad_b_seq_idx.append(b_req_idx[i]) + nopad_b_seq_len.append(seq_len + step + 2) + nopad_total_token_num += seq_len + step + 2 + nopad_max_len_in_batch = max(nopad_max_len_in_batch, seq_len + step + 2) + + nopad_b_seq_idx = torch.tensor(nopad_b_seq_idx, dtype=torch.int32, device="cuda") + nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda") + mem_indexes = main_model.req_manager.mem_manager.alloc(batch_size * (len(draft_models) + 1)).cuda() + + model_input = ModelInput( + batch_size=batch_size * (len(draft_models) + 1), + total_token_num=nopad_total_token_num, + max_len_in_batch=nopad_max_len_in_batch, + input_ids=decode_input_ids, + mem_indexes=mem_indexes, + b_req_idx=nopad_b_seq_idx, + b_seq_len=nopad_b_seq_len, + is_prefill=False, + ) + + # Main decode + for i in range(0, output_len, len(draft_models) + 1): + torch.cuda.synchronize() + step_start_time = time.time() + model_output = main_model.forward( + model_input, + ) + prob_out = torch.softmax(model_output.logits, dim=-1) + predict_ids = torch.argmax(prob_out, dim=1, keepdim=True) + + # draft decode + model_input.input_ids = predict_ids.reshape(-1) + model_input.hidden_states = model_output.hidden_states + + for draft_model_id in range(len(draft_models)): + draft_model = draft_models[draft_model_id] + model_output = draft_model.forward( + model_input, + ) + prob_out = torch.softmax(model_output.logits, dim=-1) + predict_ids = torch.argmax(prob_out, dim=1, keepdim=True) + model_input.input_ids = predict_ids.reshape(-1) + model_input.hidden_states = model_output.hidden_states + + # accept all draft ids by default. 
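# Worked example of the decode bookkeeping in run_forward_once, assuming two draft
# models (spec_step == 2): every request occupies spec_step + 1 = 3 rows per decode
# step, so one forward pass can advance each request by up to 3 tokens and the loop
# strides over output_len in steps of 3.
batch_size, num_draft_models, input_len = 2, 2, 8
rows_per_step = num_draft_models + 1

print(batch_size * rows_per_step)   # 6 -> rows fed to the main model per decode step

# Per-request sequence lengths for the first decode step (seq + 1, seq + 2, seq + 3):
b_seq_len = [input_len + 1 + k for _ in range(batch_size) for k in range(rows_per_step)]
print(b_seq_len)                    # [9, 10, 11, 9, 10, 11]

output_len = 10
print(list(range(0, output_len, rows_per_step)))  # [0, 3, 6, 9] -> 4 passes cover 10 tokens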
+ model_input.input_ids = predict_ids.reshape(-1) + model_input.hidden_states = model_output.hidden_states + torch.cuda.synchronize() + if i % 100 == 0 or i == output_len - 1: + step_end_time = time.time() + if get_current_rank_in_dp() == 0 and not warmup: + step_time = step_end_time - step_start_time + print(i, " step cost time:", step_time * 1000) + print(f"Decode throughput: {batch_size * (len(draft_models) + 1) * args.dp / step_time} tokens/s") + + main_model.mem_manager.free_all() + main_model.req_manager.free_all() + + +def tppart_model_infer(args, model_kvargs, batch_sizes, input_len, output_len, ans_queue): + args = get_env_start_args() + import triton.profiler as proton + import torch + from lightllm.distributed import dist_group_manager + from lightllm.utils.dist_utils import set_current_device_id + + import torch.distributed as dist + + enable_decode_overlap = args.enable_decode_microbatch_overlap + group_size = 1 + if enable_decode_overlap or args.enable_prefill_microbatch_overlap: + group_size = 2 + init_distributed_env(model_kvargs) + dist_group_manager.create_groups(group_size=group_size) + model_cfg, _ = PretrainedConfig.get_config_dict(model_kvargs["weight_dir"]) + dist.barrier() + + torch.cuda.empty_cache() + + main_model, _ = get_model(model_cfg, model_kvargs) + draft_models = init_mtp_model(args, model_kvargs, main_model) + if isinstance(batch_sizes, int): + batch_sizes = [batch_sizes] + + for batch_size in batch_sizes: + # warm up + run_forward_once(args, input_len, output_len, batch_size, main_model, draft_models, warmup=True) + torch.cuda.synchronize() + run_forward_once(args, input_len, output_len, batch_size, main_model, draft_models, warmup=False) + dist.barrier() + + ans_queue.put(True) + return diff --git a/test/model/test_model.py b/test/model/test_model.py index 1376e2570..026e6b044 100644 --- a/test/model/test_model.py +++ b/test/model/test_model.py @@ -5,6 +5,7 @@ import unittest from model_infer import test_model_inference +from model_infer_mtp import test_model_inference_mtp from lightllm.server.api_cli import make_argument_parser from lightllm.utils.envs_utils import set_env_start_args, get_env_start_args from lightllm.utils.config_utils import get_config_json, get_dtype @@ -15,7 +16,10 @@ def test_model_infer(self): args = get_env_start_args() if args.data_type is None: args.data_type = get_dtype(args.model_dir) - test_model_inference(args) + if args.spec_algo == "MTP": + test_model_inference_mtp(args) + else: + test_model_inference(args) return @@ -23,9 +27,9 @@ def test_model_infer(self): import torch parser = make_argument_parser() - parser.add_argument("--batch_size", type=int, default=2, help="batch size") - parser.add_argument("--input_len", type=int, default=4096, help="input sequence length") - parser.add_argument("--output_len", type=int, default=128, help="output sequence length") + parser.add_argument("--batch_size", nargs="+", type=int, default=1, help="batch size") + parser.add_argument("--input_len", type=int, default=64, help="input sequence length") + parser.add_argument("--output_len", type=int, default=4096 + 1024, help="output sequence length") parser.add_argument( "--profile", action="store_true",