
Commit e25e054

niushengxiao committed
feat: move init_req_to_token_indexes and copy_kv_index_to_req to alloc fun
1 parent 8ed97c7 commit e25e054

File tree: 9 files changed (+53, -77 lines)

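In short: the per-request KV index bookkeeping (init_req_to_token_indexes for prefill, the copy_kv_index_to_req Triton kernel for decode) moves out of the forward paths in basemodel.py and into MemoryManager.alloc, so every caller now hands the batch metadata to alloc directly. A rough sketch of the new call convention, inferred from the call sites in this diff (the wrapper functions below are illustrative, not part of the codebase):

    # Hypothetical helpers showing the new alloc() call convention; mem_manager and the
    # b_* tensors stand in for the real lightllm objects used throughout the diff below.
    def alloc_for_prefill(mem_manager, need_size, b_req_idx, b_seq_len, b_ready_cache_len):
        # Prefill: alloc() now fills req_to_token_indexs itself (via init_req_to_token_indexes),
        # so callers drop their separate bookkeeping call.
        return mem_manager.alloc(need_size, b_req_idx, b_seq_len, b_ready_cache_len, True).cuda()

    def alloc_for_decode(mem_manager, need_size, b_req_idx, b_seq_len):
        # Decode: no ready-cache lengths and is_prefill=False; alloc() runs copy_kv_index_to_req internally.
        return mem_manager.alloc(need_size, b_req_idx, b_seq_len, None, False).cuda()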

lightllm/common/basemodel/basemodel.py

Lines changed: 3 additions & 63 deletions
@@ -11,9 +11,7 @@
 from lightllm.common.basemodel.infer_struct import InferStateInfo
 from lightllm.common.mem_manager import MemoryManager
 from lightllm.common.req_manager import ReqManager
-from lightllm.common.infer_utils import init_req_to_token_indexes
 from lightllm.common.build_utils import repair_config
-from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req
 from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
 from lightllm.common.basemodel.cuda_graph import CudaGraph
 from lightllm.common.quantization import Quantcfg

@@ -330,14 +328,6 @@ def _prefill(
         model_input: ModelInput,
     ):
         infer_state = self._create_inferstate(model_input)
-        init_req_to_token_indexes(
-            self.req_manager.req_to_token_indexs,
-            model_input.b_req_idx,
-            model_input.b_seq_len,
-            infer_state.b_ready_cache_len,
-            model_input.max_len_in_batch,
-            infer_state.mem_index,
-        )

         infer_state.init_some_extra_state(self, model_input.input_ids)
         return self._context_forward(model_input.input_ids, infer_state)

@@ -350,12 +340,6 @@ def _decode(
             find_graph_batch_size = self.graph.find_closest_graph_batch_size(model_input.batch_size)
             padded_model_input = self._create_padded_decode_model_input(model_input, find_graph_batch_size)
             infer_state = self._create_inferstate(padded_model_input)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state.b_req_idx,
-                infer_state.b_seq_len,
-                infer_state.mem_index,
-            )
             infer_state.init_some_extra_state(self, padded_model_input.input_ids)

             if self.graph.need_capture(find_graph_batch_size):

@@ -371,12 +355,6 @@ def _decode(
             )
         else:
             infer_state = self._create_inferstate(model_input)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state.b_req_idx,
-                infer_state.b_seq_len,
-                infer_state.mem_index,
-            )
             infer_state.init_some_extra_state(self, model_input.input_ids)
             model_output = self._token_forward(model_input.input_ids, infer_state)

@@ -458,25 +436,9 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
         input_ids0, input_ids1 = model_input0.input_ids, model_input1.input_ids

         infer_state0 = self._create_inferstate(model_input0, 0)
-        init_req_to_token_indexes(
-            self.req_manager.req_to_token_indexs,
-            model_input0.b_req_idx,
-            model_input0.b_seq_len,
-            infer_state0.b_ready_cache_len,
-            model_input0.max_len_in_batch,
-            infer_state0.mem_index,
-        )
         infer_state0.init_some_extra_state(self, input_ids0)

         infer_state1 = self._create_inferstate(model_input1, 1)
-        init_req_to_token_indexes(
-            self.req_manager.req_to_token_indexs,
-            model_input1.b_req_idx,
-            model_input1.b_seq_len,
-            infer_state1.b_ready_cache_len,
-            model_input1.max_len_in_batch,
-            infer_state1.mem_index,
-        )
         infer_state1.init_some_extra_state(self, input_ids1)

         model_output0, model_output1 = self._overlap_tpsp_context_forward(

@@ -502,20 +464,8 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode
             padded_model_input0 = self._create_padded_decode_model_input(model_input0, find_graph_batch_size)
             padded_model_input1 = self._create_padded_decode_model_input(model_input1, find_graph_batch_size)
             infer_state0 = self._create_inferstate(padded_model_input0, 0)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state0.b_req_idx,
-                infer_state0.b_seq_len,
-                infer_state0.mem_index,
-            )
             infer_state0.init_some_extra_state(self, padded_model_input0.input_ids)
             infer_state1 = self._create_inferstate(padded_model_input1, 1)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state1.b_req_idx,
-                infer_state1.b_seq_len,
-                infer_state1.mem_index,
-            )
             infer_state1.init_some_extra_state(self, padded_model_input1.input_ids)

             if self.graph.need_capture(find_graph_batch_size):

@@ -540,20 +490,8 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode
             model_output1 = self._create_unpad_decode_model_output(model_output1, origin_batch_size=origin_batch_size)
         else:
             infer_state0 = self._create_inferstate(model_input0, 0)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state0.b_req_idx,
-                infer_state0.b_seq_len,
-                infer_state0.mem_index,
-            )
             infer_state0.init_some_extra_state(self, model_input0.input_ids)
             infer_state1 = self._create_inferstate(model_input1, 1)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                infer_state1.b_req_idx,
-                infer_state1.b_seq_len,
-                infer_state1.mem_index,
-            )
             infer_state1.init_some_extra_state(self, model_input1.input_ids)

             model_output0, model_output1 = self._overlap_tpsp_token_forward(

@@ -654,10 +592,12 @@ def _check_max_len_infer(self):
             logger.info("begin check max_len infer")
             dummy_input_ids = torch.ones(self.batch_max_tokens, dtype=torch.int32, device="cuda")
             b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda")
-            mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda()
             b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
             b_seq_len[:] = self.batch_max_tokens
             b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
+            mem_indexes = self.mem_manager.alloc(
+                len(dummy_input_ids), b_req_idx, b_seq_len, b_ready_cache_len, True
+            ).cuda()
             total_token_num = self.batch_max_tokens
             model_input = ModelInput(
                 batch_size=1,

lightllm/common/basemodel/cuda_graph.py

Lines changed: 2 additions & 2 deletions
@@ -196,12 +196,12 @@ def warmup(self, model):
             total_token_num = batch_size * seq_len
             max_len_in_batch = self.graph_max_len_in_batch
             input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda")
-            mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda()
             b_req_idx = torch.tensor(
                 [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda"
             )
             b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
             b_seq_len.fill_(seq_len)
+            mem_indexes = model.mem_manager.alloc(len(input_ids), b_req_idx, b_seq_len, None, False).cuda()

             model_input = ModelInput(
                 batch_size=batch_size,

@@ -250,12 +250,12 @@ def warmup_overlap(self, model):
             total_token_num = batch_size * seq_len
             max_len_in_batch = self.graph_max_len_in_batch
             input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda")
-            mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda()
             b_req_idx = torch.tensor(
                 [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda"
             )
             b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
             b_seq_len.fill_(seq_len)
+            mem_indexes = model.mem_manager.alloc(len(input_ids), b_req_idx, b_seq_len, None, False).cuda()

             micro_batch = ModelInput(
                 is_prefill=False,

lightllm/common/infer_utils.py

Lines changed: 1 addition & 3 deletions
@@ -1,6 +1,4 @@
-def init_req_to_token_indexes(
-    req_to_token_indexs, b_req_idx, b_seq_len, b_ready_cache_len, max_len_in_batch, alloc_mem_index
-):
+def init_req_to_token_indexes(req_to_token_indexs, b_req_idx, b_seq_len, b_ready_cache_len, alloc_mem_index):
     start_index = 0
     b_seq_len_numpy = b_seq_len.cpu().numpy()
     b_ready_cache_len_numpy = b_ready_cache_len.cpu().numpy()

lightllm/common/mem_manager.py

Lines changed: 22 additions & 1 deletion
@@ -12,6 +12,8 @@
 from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args
 from lightllm.distributed.pynccl import PyNcclCommunicator
 from lightllm.utils.dist_utils import get_current_device_id
+from lightllm.common.infer_utils import init_req_to_token_indexes
+from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req

 logger = init_logger(__name__)


@@ -52,6 +54,7 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False
             layer_num,
         )
         self.HOLD_TOKEN_MEMINDEX = self.size
+        self.req_to_token_indexs = None

     def get_cell_size(self):
         return 2 * self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype)

@@ -243,7 +246,7 @@ def _write_kv_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: to
     def _free_buffers(self):
         self.kv_buffer = None

-    def alloc(self, need_size) -> torch.Tensor:
+    def alloc(self, need_size, b_req_idx=None, b_seq_len=None, b_ready_cache_len=None, is_prefill=True) -> torch.Tensor:
         if need_size > self.mark_end - self.mark_start:
             logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}")
             assert False, "error alloc state"

@@ -255,6 +258,24 @@ def alloc(self, need_size) -> torch.Tensor:

         self.can_use_mem_size -= need_size
         self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+
+        if self.req_to_token_indexs is not None:
+            assert b_req_idx is not None and b_seq_len is not None, "b_req_idx and b_seq_len must be provided"
+            if is_prefill:
+                init_req_to_token_indexes(
+                    self.req_to_token_indexs,
+                    b_req_idx,
+                    b_seq_len,
+                    b_ready_cache_len,
+                    ans,
+                )
+            else:
+                copy_kv_index_to_req(
+                    self.req_to_token_indexs,
+                    b_req_idx,
+                    b_seq_len,
+                    ans,
+                )
         return ans

     def free(self, free_index: Union[torch.Tensor, List[int]]):
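Taken together, the new branch in alloc() fires only when req_to_token_indexs has been wired in, and then dispatches to the prefill or decode bookkeeping. Below is a small, self-contained sketch of that routing (a toy under stated assumptions: plain tensor writes stand in for init_req_to_token_indexes and the copy_kv_index_to_req Triton kernel, CPU tensors replace the CUDA ones, and the contiguous mark_start/mark_end allocator is a simplification of the real MemoryManager state):

    import torch

    class ToyMemManager:
        """Toy stand-in for MemoryManager.alloc() after this commit (simplified state, CPU tensors)."""

        def __init__(self, size, max_request_num, max_sequence_length):
            self.mark_start = 0
            self.mark_end = size
            self.can_use_mem_size = size
            # In the real code this starts as None and is wired in by ReqManager (see req_manager.py below).
            self.req_to_token_indexs = torch.zeros((max_request_num + 1, max_sequence_length), dtype=torch.int32)

        def alloc(self, need_size, b_req_idx=None, b_seq_len=None, b_ready_cache_len=None, is_prefill=True):
            assert need_size <= self.mark_end - self.mark_start, "error alloc state"
            ans = torch.arange(self.mark_start, self.mark_start + need_size, dtype=torch.int32)
            self.mark_start += need_size
            self.can_use_mem_size -= need_size

            if self.req_to_token_indexs is not None:
                assert b_req_idx is not None and b_seq_len is not None
                if is_prefill:
                    # Stand-in for init_req_to_token_indexes: each request's newly allocated slots
                    # fill positions [ready_cache_len, seq_len) of its row.
                    start = 0
                    for i in range(b_req_idx.shape[0]):
                        ready = int(b_ready_cache_len[i]) if b_ready_cache_len is not None else 0
                        cur = int(b_seq_len[i]) - ready
                        self.req_to_token_indexs[int(b_req_idx[i]), ready:ready + cur] = ans[start:start + cur]
                        start += cur
                else:
                    # Stand-in for copy_kv_index_to_req: one new slot per request at position seq_len - 1.
                    self.req_to_token_indexs[b_req_idx.long(), (b_seq_len - 1).long()] = ans
            return ans

    # Example: a two-request prefill (seq lens 3 and 2), then one decode step.
    mm = ToyMemManager(size=64, max_request_num=8, max_sequence_length=16)
    b_req_idx = torch.tensor([0, 1], dtype=torch.int32)
    b_seq_len = torch.tensor([3, 2], dtype=torch.int32)
    b_ready_cache_len = torch.zeros(2, dtype=torch.int32)
    mm.alloc(5, b_req_idx, b_seq_len, b_ready_cache_len, True)   # prefill: 3 + 2 new tokens
    mm.alloc(2, b_req_idx, b_seq_len + 1, None, False)           # decode: one token per request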

lightllm/common/req_manager.py

Lines changed: 1 addition & 0 deletions
@@ -62,6 +62,7 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana
         self.req_to_token_indexs = torch.zeros(
             (max_request_num + 1, max_sequence_length), dtype=torch.int32, device="cuda"
         )
+        mem_manager.req_to_token_indexs = self.req_to_token_indexs
         self.mem_manager = mem_manager
         self.req_sampling_params_manager = ReqSamplingParamsManager(max_request_num)
         self.max_request_num = max_request_num
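The single added line above is what arms that path: MemoryManager.__init__ now sets req_to_token_indexs = None, and ReqManager hands over its table at construction time, after which alloc() starts updating it. A minimal illustration of the wiring pattern, using toy classes rather than the real lightllm ones:

    import torch

    class MemManagerLike:
        def __init__(self):
            self.req_to_token_indexs = None  # stays None until a request manager wires it in

    class ReqManagerLike:
        def __init__(self, max_request_num, max_sequence_length, mem_manager):
            self.req_to_token_indexs = torch.zeros((max_request_num + 1, max_sequence_length), dtype=torch.int32)
            mem_manager.req_to_token_indexs = self.req_to_token_indexs  # the one-line change above
            self.mem_manager = mem_manager

    mem = MemManagerLike()
    req = ReqManagerLike(max_request_num=8, max_sequence_length=64, mem_manager=mem)
    assert mem.req_to_token_indexs is req.req_to_token_indexs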

lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py

Lines changed: 6 additions & 2 deletions
@@ -69,7 +69,9 @@ def padded_prepare_prefill_inputs(
     g_infer_state_lock.acquire()
     if g_infer_context.radix_cache is not None:
         g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0] - padded_req_num)
-    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0] - padded_req_num).cuda()
+    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(
+        input_ids.shape[0] - padded_req_num, b_req_idx, b_seq_len, b_ready_cache_len, True
+    ).cuda()
     g_infer_state_lock.release()

     if padded_req_num > 0:

@@ -155,7 +157,9 @@ def padded_prepare_decode_inputs(
     g_infer_state_lock.acquire()
     if g_infer_context.radix_cache is not None:
         g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0] - padded_req_num)
-    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0] - padded_req_num).cuda()
+    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(
+        input_ids.shape[0] - padded_req_num, b_req_idx, b_seq_len, None, False
+    ).cuda()
     g_infer_state_lock.release()

     if padded_req_num > 0:

lightllm/server/router/model_infer/mode_backend/generic_pre_process.py

Lines changed: 6 additions & 2 deletions
@@ -49,7 +49,9 @@ def prepare_prefill_inputs(
     g_infer_state_lock.acquire()
     if g_infer_context.radix_cache is not None:
         g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0])
-    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0]).cuda()
+    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(
+        input_ids.shape[0], b_req_idx, b_seq_len, b_ready_cache_len, True
+    ).cuda()
     g_infer_state_lock.release()

     model_input = ModelInput(

@@ -105,7 +107,9 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In
     g_infer_state_lock.acquire()
     if g_infer_context.radix_cache is not None:
         g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0])
-    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0]).cuda()
+    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(
+        input_ids.shape[0], b_req_idx, b_seq_len, None, False
+    ).cuda()
     g_infer_state_lock.release()

     model_input = ModelInput(

test/benchmark/static_inference/model_infer.py

Lines changed: 6 additions & 2 deletions
@@ -242,7 +242,9 @@ def run_forward_once(
         b_seq_len[i] = input_len

     total_token_num = batch_size * input_len
-    mem_indexes = model_part.req_manager.mem_manager.alloc(test_data.shape[0]).cuda()
+    mem_indexes = model_part.req_manager.mem_manager.alloc(
+        test_data.shape[0], b_req_idx, b_seq_len, b_ready_cache_len, True
+    ).cuda()

     rank_id = model_kvargs["rank_id"]


@@ -303,7 +305,9 @@ def run_forward_once(
         step_start = time.time()
         total_token_num += batch_size
         b_seq_len += 1
-        mem_indexes = model_part.req_manager.mem_manager.alloc(predict_ids.shape[0]).cuda()
+        mem_indexes = model_part.req_manager.mem_manager.alloc(
+            predict_ids.shape[0], b_req_idx, b_seq_len, None, False
+        ).cuda()
         max_len_in_batch = input_len + i + 1
         logits = decode_fn(
             model_part,

test/benchmark/static_inference/model_infer_mtp.py

Lines changed: 6 additions & 2 deletions
@@ -126,7 +126,9 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_
         b_seq_len[i] = input_len

     total_token_num = input_len * batch_size
-    mem_indexes = main_model.req_manager.mem_manager.alloc(test_data.shape[0]).cuda()
+    mem_indexes = main_model.req_manager.mem_manager.alloc(
+        test_data.shape[0], b_req_idx, b_seq_len, b_ready_cache_len, True
+    ).cuda()
     # Main model Prefill
     model_input = ModelInput(
         batch_size=batch_size,

@@ -193,7 +195,9 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_

         nopad_b_seq_idx = torch.tensor(nopad_b_seq_idx, dtype=torch.int32, device="cuda")
         nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda")
-        mem_indexes = main_model.req_manager.mem_manager.alloc(batch_size * (len(draft_models) + 1)).cuda()
+        mem_indexes = main_model.req_manager.mem_manager.alloc(
+            batch_size * (len(draft_models) + 1), nopad_b_seq_idx, nopad_b_seq_len, None, False
+        ).cuda()

         model_input = ModelInput(
             batch_size=batch_size * (len(draft_models) + 1),
