@@ -5,22 +5,21 @@
 import torch.multiprocessing as mp
 import torch.distributed as dist
 from typing import List, Tuple
-from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend
-from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context
+from lightllm.server.router.model_infer.infer_batch import InferReq
 from lightllm.server.pd_io_struct import KVMoveTask, DecodeNodeInfo
 from lightllm.utils.log_utils import init_logger
 from lightllm.common.basemodel.infer_lock import g_router_lock, g_infer_state_lock
-from lightllm.server.router.model_infer.mode_backend.continues_batch.impl import ContinuesBatchBackend
 from rpyc.utils.server import ThreadedServer
 from .prefill_task_cache import g_kv_move_task_cache
 from lightllm.utils.device_utils import kv_trans_use_p2p
 from lightllm.utils.envs_utils import get_unique_server_name
 from lightllm.utils.dist_utils import create_new_group_for_current_dp
+from lightllm.server.router.model_infer.mode_backend.chunked_prefill.impl import ChunkedPrefillBackend

 logger = init_logger(__name__)


-class ChunckedPrefillForPrefillNode(ModeBackend):
+class ChunckedPrefillForPrefillNode(ChunkedPrefillBackend):
     def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue) -> None:
         super().__init__()
         self.info_queue: mp.Queue = info_queue
@@ -49,36 +48,23 @@ def init_custom(self):

         return

-    def decode(self):
-        uinit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = self._get_classed_reqs(
-            g_infer_context.infer_req_ids,
-            no_decode=True,
-        )
-        assert len(decode_reqs) == 0
-
-        if aborted_reqs:
-            self._filter_reqs(aborted_reqs)
-
-        if ok_finished_reqs:
-            self.prefill_req_frozen_tokens_and_put_to_kvmove_taskqueue(ok_finished_reqs)
-            self._filter_reqs(ok_finished_reqs)
-            ok_finished_reqs.clear()
-
-        if prefill_reqs:
-            ContinuesBatchBackend.normal_prefill_reqs(
-                self, prefill_reqs=prefill_reqs, uninit_reqs=uinit_reqs, ok_finished_reqs=ok_finished_reqs
-            )
-
-        self._overlap_req_init_and_filter(uninit_reqs=uinit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
+    def _pre_handle_finished_reqs(self, finished_reqs):
+        self._prefill_req_frozen_tokens_and_put_to_kvmove_taskqueue(run_reqs=finished_reqs)
         return

-    def prefill_req_frozen_tokens_and_put_to_kvmove_taskqueue(self, run_reqs: List[InferReq]):
-        # Reclaim the related entries in the radix cache ahead of time and add reference info.
+    def _prefill_req_frozen_tokens_and_put_to_kvmove_taskqueue(self, run_reqs: List[InferReq]):
+        # Reclaim the related entries in the radix cache ahead of time and add references to lock them, so the transfer process can move the KV cache.
         if self.is_master_in_dp:
             logger.info("prefill_req_handle_and_frozen_tokens")
+
         g_infer_state_lock.acquire()
         try:
             for req in run_reqs:
+
+                # Distinguish aborted requests from normally finished ones; only normally finished requests start a KV move task.
+                if not req.finish_status.is_finished():
+                    continue
+
                 req: InferReq = req
                 key = req.get_input_token_ids()[0 : req.cur_kv_len]
                 key = torch.tensor(key, dtype=torch.int64, device="cpu")
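Net effect of this change: the prefill node no longer reimplements the whole decode() scheduling step; it inherits the shared loop from ChunkedPrefillBackend and only overrides the _pre_handle_finished_reqs hook, freezing KV and enqueueing move tasks before finished requests are released. A minimal sketch of the assumed hook pattern follows; the base-class body and the _collect_finished_reqs helper are illustrative stand-ins, not the real lightllm implementation.

class ChunkedPrefillBackend:  # assumed shape of the shared backend, for illustration only
    def decode(self):
        # ... classify requests, run chunked prefill, etc. ...
        finished_reqs = self._collect_finished_reqs()  # hypothetical helper name
        # hook: subclasses may act on finished requests before they are filtered out
        self._pre_handle_finished_reqs(finished_reqs)
        self._filter_reqs(finished_reqs)

    def _pre_handle_finished_reqs(self, finished_reqs):
        # default hook: no extra handling
        return


class PrefillNode(ChunkedPrefillBackend):
    def _pre_handle_finished_reqs(self, finished_reqs):
        # prefill node: freeze KV of finished requests and enqueue KV move tasks
        self._prefill_req_frozen_tokens_and_put_to_kvmove_taskqueue(run_reqs=finished_reqs)
        return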