From e25e054485b1ae3793c4ff9b94e0aea3ae5993a2 Mon Sep 17 00:00:00 2001
From: niushengxiao
Date: Thu, 17 Jul 2025 15:42:11 +0800
Subject: [PATCH] feat: move init_req_to_token_indexes and copy_kv_index_to_req to alloc fun

---
 lightllm/common/basemodel/basemodel.py        | 66 +------------------
 lightllm/common/basemodel/cuda_graph.py       |  4 +-
 lightllm/common/infer_utils.py                |  4 +-
 lightllm/common/mem_manager.py                | 23 ++++++-
 lightllm/common/req_manager.py                |  1 +
 .../generic_padded_pre_process.py             |  8 ++-
 .../mode_backend/generic_pre_process.py       |  8 ++-
 .../benchmark/static_inference/model_infer.py |  8 ++-
 .../static_inference/model_infer_mtp.py       |  8 ++-
 9 files changed, 53 insertions(+), 77 deletions(-)

diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py
index ff3290233..63df953d8 100755
--- a/lightllm/common/basemodel/basemodel.py
+++ b/lightllm/common/basemodel/basemodel.py
@@ -11,9 +11,7 @@
 from lightllm.common.basemodel.infer_struct import InferStateInfo
 from lightllm.common.mem_manager import MemoryManager
 from lightllm.common.req_manager import ReqManager
-from lightllm.common.infer_utils import init_req_to_token_indexes
 from lightllm.common.build_utils import repair_config
-from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req
 from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
 from lightllm.common.basemodel.cuda_graph import CudaGraph
 from lightllm.common.quantization import Quantcfg
@@ -330,14 +328,6 @@ def _prefill(
         model_input: ModelInput,
     ):
         infer_state = self._create_inferstate(model_input)
-        init_req_to_token_indexes(
-            self.req_manager.req_to_token_indexs,
-            model_input.b_req_idx,
-            model_input.b_seq_len,
-            infer_state.b_ready_cache_len,
-            model_input.max_len_in_batch,
-            infer_state.mem_index,
-        )
         infer_state.init_some_extra_state(self, model_input.input_ids)
         return self._context_forward(model_input.input_ids, infer_state)
 
@@ -350,12 +340,6 @@ def _decode(
             find_graph_batch_size = self.graph.find_closest_graph_batch_size(model_input.batch_size)
             padded_model_input = self._create_padded_decode_model_input(model_input, find_graph_batch_size)
             infer_state = self._create_inferstate(padded_model_input)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state.b_req_idx,
-                infer_state.b_seq_len,
-                infer_state.mem_index,
-            )
             infer_state.init_some_extra_state(self, padded_model_input.input_ids)
 
             if self.graph.need_capture(find_graph_batch_size):
@@ -371,12 +355,6 @@ def _decode(
             )
         else:
             infer_state = self._create_inferstate(model_input)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state.b_req_idx,
-                infer_state.b_seq_len,
-                infer_state.mem_index,
-            )
             infer_state.init_some_extra_state(self, model_input.input_ids)
             model_output = self._token_forward(model_input.input_ids, infer_state)
 
@@ -458,25 +436,9 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
         input_ids0, input_ids1 = model_input0.input_ids, model_input1.input_ids
 
         infer_state0 = self._create_inferstate(model_input0, 0)
-        init_req_to_token_indexes(
-            self.req_manager.req_to_token_indexs,
-            model_input0.b_req_idx,
-            model_input0.b_seq_len,
-            infer_state0.b_ready_cache_len,
-            model_input0.max_len_in_batch,
-            infer_state0.mem_index,
-        )
         infer_state0.init_some_extra_state(self, input_ids0)
         infer_state1 = self._create_inferstate(model_input1, 1)
-        init_req_to_token_indexes(
-            self.req_manager.req_to_token_indexs,
-            model_input1.b_req_idx,
-            model_input1.b_seq_len,
-            infer_state1.b_ready_cache_len,
-            model_input1.max_len_in_batch,
-            infer_state1.mem_index,
-        )
         infer_state1.init_some_extra_state(self, input_ids1)
 
         model_output0, model_output1 = self._overlap_tpsp_context_forward(
@@ -502,20 +464,8 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode
             padded_model_input0 = self._create_padded_decode_model_input(model_input0, find_graph_batch_size)
             padded_model_input1 = self._create_padded_decode_model_input(model_input1, find_graph_batch_size)
             infer_state0 = self._create_inferstate(padded_model_input0, 0)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state0.b_req_idx,
-                infer_state0.b_seq_len,
-                infer_state0.mem_index,
-            )
             infer_state0.init_some_extra_state(self, padded_model_input0.input_ids)
             infer_state1 = self._create_inferstate(padded_model_input1, 1)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state1.b_req_idx,
-                infer_state1.b_seq_len,
-                infer_state1.mem_index,
-            )
             infer_state1.init_some_extra_state(self, padded_model_input1.input_ids)
 
             if self.graph.need_capture(find_graph_batch_size):
@@ -540,20 +490,8 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode
             model_output1 = self._create_unpad_decode_model_output(model_output1, origin_batch_size=origin_batch_size)
         else:
             infer_state0 = self._create_inferstate(model_input0, 0)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state0.b_req_idx,
-                infer_state0.b_seq_len,
-                infer_state0.mem_index,
-            )
             infer_state0.init_some_extra_state(self, model_input0.input_ids)
             infer_state1 = self._create_inferstate(model_input1, 1)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state1.b_req_idx,
-                infer_state1.b_seq_len,
-                infer_state1.mem_index,
-            )
             infer_state1.init_some_extra_state(self, model_input1.input_ids)
 
             model_output0, model_output1 = self._overlap_tpsp_token_forward(
@@ -654,10 +592,12 @@ def _check_max_len_infer(self):
             logger.info("begin check max_len infer")
             dummy_input_ids = torch.ones(self.batch_max_tokens, dtype=torch.int32, device="cuda")
             b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda")
-            mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda()
             b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
             b_seq_len[:] = self.batch_max_tokens
             b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
+            mem_indexes = self.mem_manager.alloc(
+                len(dummy_input_ids), b_req_idx, b_seq_len, b_ready_cache_len, True
+            ).cuda()
             total_token_num = self.batch_max_tokens
             model_input = ModelInput(
                 batch_size=1,
diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py
index dc615eb46..37e0172ab 100644
--- a/lightllm/common/basemodel/cuda_graph.py
+++ b/lightllm/common/basemodel/cuda_graph.py
@@ -196,12 +196,12 @@ def warmup(self, model):
             total_token_num = batch_size * seq_len
             max_len_in_batch = self.graph_max_len_in_batch
             input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda")
-            mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda()
             b_req_idx = torch.tensor(
                 [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda"
             )
             b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
             b_seq_len.fill_(seq_len)
+            mem_indexes = model.mem_manager.alloc(len(input_ids), b_req_idx, b_seq_len, None, False).cuda()
 
             model_input = ModelInput(
                 batch_size=batch_size,
@@ -250,12 +250,12 @@ def warmup_overlap(self, model):
             total_token_num = batch_size * seq_len
             max_len_in_batch = self.graph_max_len_in_batch
             input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda")
-            mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda()
             b_req_idx = torch.tensor(
                 [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda"
             )
             b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
             b_seq_len.fill_(seq_len)
+            mem_indexes = model.mem_manager.alloc(len(input_ids), b_req_idx, b_seq_len, None, False).cuda()
 
             micro_batch = ModelInput(
                 is_prefill=False,
diff --git a/lightllm/common/infer_utils.py b/lightllm/common/infer_utils.py
index da2f35e08..c3f980ecd 100644
--- a/lightllm/common/infer_utils.py
+++ b/lightllm/common/infer_utils.py
@@ -1,6 +1,4 @@
-def init_req_to_token_indexes(
-    req_to_token_indexs, b_req_idx, b_seq_len, b_ready_cache_len, max_len_in_batch, alloc_mem_index
-):
+def init_req_to_token_indexes(req_to_token_indexs, b_req_idx, b_seq_len, b_ready_cache_len, alloc_mem_index):
     start_index = 0
     b_seq_len_numpy = b_seq_len.cpu().numpy()
     b_ready_cache_len_numpy = b_ready_cache_len.cpu().numpy()
diff --git a/lightllm/common/mem_manager.py b/lightllm/common/mem_manager.py
index 4142ce4aa..6f3eb2c22 100755
--- a/lightllm/common/mem_manager.py
+++ b/lightllm/common/mem_manager.py
@@ -12,6 +12,8 @@
 from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args
 from lightllm.distributed.pynccl import PyNcclCommunicator
 from lightllm.utils.dist_utils import get_current_device_id
+from lightllm.common.infer_utils import init_req_to_token_indexes
+from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req
 
 logger = init_logger(__name__)
 
@@ -52,6 +54,7 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False
             layer_num,
         )
         self.HOLD_TOKEN_MEMINDEX = self.size
+        self.req_to_token_indexs = None
 
     def get_cell_size(self):
         return 2 * self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype)
@@ -243,7 +246,7 @@ def _write_kv_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: to
     def _free_buffers(self):
         self.kv_buffer = None
 
-    def alloc(self, need_size) -> torch.Tensor:
+    def alloc(self, need_size, b_req_idx=None, b_seq_len=None, b_ready_cache_len=None, is_prefill=True) -> torch.Tensor:
         if need_size > self.mark_end - self.mark_start:
             logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}")
             assert False, "error alloc state"
@@ -255,6 +258,24 @@ def alloc(self, need_size) -> torch.Tensor:
 
         self.can_use_mem_size -= need_size
         self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+
+        if self.req_to_token_indexs is not None:
+            assert b_req_idx is not None and b_seq_len is not None, "b_req_idx and b_seq_len must be provided"
+            if is_prefill:
+                init_req_to_token_indexes(
+                    self.req_to_token_indexs,
+                    b_req_idx,
+                    b_seq_len,
+                    b_ready_cache_len,
+                    ans,
+                )
+            else:
+                copy_kv_index_to_req(
+                    self.req_to_token_indexs,
+                    b_req_idx,
+                    b_seq_len,
+                    ans,
+                )
         return ans
 
     def free(self, free_index: Union[torch.Tensor, List[int]]):
diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py
index 643ccd894..f702c2480 100644
--- a/lightllm/common/req_manager.py
+++ b/lightllm/common/req_manager.py
@@ -62,6 +62,7 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana
         self.req_to_token_indexs = torch.zeros(
             (max_request_num + 1, max_sequence_length), dtype=torch.int32, device="cuda"
         )
+        mem_manager.req_to_token_indexs = self.req_to_token_indexs
         self.mem_manager = mem_manager
         self.req_sampling_params_manager = ReqSamplingParamsManager(max_request_num)
         self.max_request_num = max_request_num
diff --git a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py
index 985800d44..990dcad2a 100644
--- a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py
+++ b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py
@@ -69,7 +69,9 @@ def padded_prepare_prefill_inputs(
     g_infer_state_lock.acquire()
     if g_infer_context.radix_cache is not None:
         g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0] - padded_req_num)
-    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0] - padded_req_num).cuda()
+    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(
+        input_ids.shape[0] - padded_req_num, b_req_idx, b_seq_len, b_ready_cache_len, True
+    ).cuda()
     g_infer_state_lock.release()
 
     if padded_req_num > 0:
@@ -155,7 +157,9 @@ def padded_prepare_decode_inputs(
     g_infer_state_lock.acquire()
     if g_infer_context.radix_cache is not None:
         g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0] - padded_req_num)
-    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0] - padded_req_num).cuda()
+    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(
+        input_ids.shape[0] - padded_req_num, b_req_idx, b_seq_len, None, False
+    ).cuda()
     g_infer_state_lock.release()
 
     if padded_req_num > 0:
diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py
index 185570d74..57be00e54 100644
--- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py
+++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py
@@ -49,7 +49,9 @@ def prepare_prefill_inputs(
     g_infer_state_lock.acquire()
     if g_infer_context.radix_cache is not None:
         g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0])
-    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0]).cuda()
+    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(
+        input_ids.shape[0], b_req_idx, b_seq_len, b_ready_cache_len, True
+    ).cuda()
     g_infer_state_lock.release()
 
     model_input = ModelInput(
@@ -105,7 +107,9 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In
     g_infer_state_lock.acquire()
     if g_infer_context.radix_cache is not None:
         g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0])
-    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0]).cuda()
+    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(
+        input_ids.shape[0], b_req_idx, b_seq_len, None, False
+    ).cuda()
     g_infer_state_lock.release()
 
     model_input = ModelInput(
diff --git a/test/benchmark/static_inference/model_infer.py b/test/benchmark/static_inference/model_infer.py
index 3db0ac004..6aa9a17ef 100644
--- a/test/benchmark/static_inference/model_infer.py
+++ b/test/benchmark/static_inference/model_infer.py
@@ -242,7 +242,9 @@ def run_forward_once(
         b_seq_len[i] = input_len
 
     total_token_num = batch_size * input_len
-    mem_indexes = model_part.req_manager.mem_manager.alloc(test_data.shape[0]).cuda()
+    mem_indexes = model_part.req_manager.mem_manager.alloc(
+        test_data.shape[0], b_req_idx, b_seq_len, b_ready_cache_len, True
+    ).cuda()
 
     rank_id = model_kvargs["rank_id"]
 
@@ -303,7 +305,9 @@ def run_forward_once(
         step_start = time.time()
         total_token_num += batch_size
         b_seq_len += 1
-        mem_indexes = model_part.req_manager.mem_manager.alloc(predict_ids.shape[0]).cuda()
+        mem_indexes = model_part.req_manager.mem_manager.alloc(
+            predict_ids.shape[0], b_req_idx, b_seq_len, None, False
+        ).cuda()
         max_len_in_batch = input_len + i + 1
         logits = decode_fn(
             model_part,
diff --git a/test/benchmark/static_inference/model_infer_mtp.py b/test/benchmark/static_inference/model_infer_mtp.py
index eb36bc873..b204c0122 100644
--- a/test/benchmark/static_inference/model_infer_mtp.py
+++ b/test/benchmark/static_inference/model_infer_mtp.py
@@ -126,7 +126,9 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_
         b_seq_len[i] = input_len
 
     total_token_num = input_len * batch_size
-    mem_indexes = main_model.req_manager.mem_manager.alloc(test_data.shape[0]).cuda()
+    mem_indexes = main_model.req_manager.mem_manager.alloc(
+        test_data.shape[0], b_req_idx, b_seq_len, b_ready_cache_len, True
+    ).cuda()
     # Main model Prefill
     model_input = ModelInput(
         batch_size=batch_size,
@@ -193,7 +195,9 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_
         nopad_b_seq_idx = torch.tensor(nopad_b_seq_idx, dtype=torch.int32, device="cuda")
         nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda")
 
-        mem_indexes = main_model.req_manager.mem_manager.alloc(batch_size * (len(draft_models) + 1)).cuda()
+        mem_indexes = main_model.req_manager.mem_manager.alloc(
+            batch_size * (len(draft_models) + 1), nopad_b_seq_idx, nopad_b_seq_len, None, False
+        ).cuda()
 
         model_input = ModelInput(
             batch_size=batch_size * (len(draft_models) + 1),
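
Note: the sketch below is not part of the patch. It only restates the call pattern that the new MemoryManager.alloc signature expects (need_size, b_req_idx, b_seq_len, b_ready_cache_len, is_prefill), as seen at the call sites above; the helper names prefill_alloc and decode_alloc are hypothetical wrappers added here for illustration.

    # Hypothetical helpers mirroring the call sites touched by this patch; illustration only.
    def prefill_alloc(mem_manager, input_ids, b_req_idx, b_seq_len, b_ready_cache_len):
        # Prefill path: alloc() now fills req_to_token_indexs via init_req_to_token_indexes (is_prefill=True).
        return mem_manager.alloc(input_ids.shape[0], b_req_idx, b_seq_len, b_ready_cache_len, True).cuda()

    def decode_alloc(mem_manager, input_ids, b_req_idx, b_seq_len):
        # Decode path: alloc() now runs copy_kv_index_to_req; b_ready_cache_len is passed as None.
        return mem_manager.alloc(input_ids.shape[0], b_req_idx, b_seq_len, None, False).cuda()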