Skip to content

Commit 7daf768

Browse files
author
Weichao Luo
committed
chunk transfer.
1 parent 24fef1f commit 7daf768

File tree

4 files changed

+134
-45
lines changed

4 files changed

+134
-45
lines changed

lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py

Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
KVMoveRequest,
2020
RemotePrefillStatus,
2121
ThreadSafeDict,
22+
TransferState,
2223
)
2324

2425
logger = init_logger(__name__)
@@ -55,15 +56,17 @@ def handle_remote_prefill(req_status: RemotePrefillStatus):
5556
status = req_status.status
5657
if status != 1:
5758
logger.warning(f"remote prefill reqeust: {group_req_id} done with state: {status}")
59+
5860
if run_req := self.remote_prefilled_reqs.get(group_req_id, None):
59-
shm_req: PDChunkedPrefillReq = run_req.shm_req
60-
shm_req.set_pd_req_rank_state(self.rank_in_dp, status)
61-
self.remote_prefilled_reqs.pop(group_req_id)
62-
if self.is_master_in_dp:
63-
logger.info(
64-
f"remote prefill reqeust: {group_req_id} done with status: {status} "
65-
f"took: {time.time() - run_req.remote_prefill_start} seconds"
66-
)
61+
if req_status.is_last or status != 1:
62+
shm_req: PDChunkedPrefillReq = run_req.shm_req
63+
shm_req.set_pd_req_rank_state(self.rank_in_dp, status)
64+
self.remote_prefilled_reqs.pop(group_req_id)
65+
if self.is_master_in_dp:
66+
logger.info(
67+
f"remote prefill reqeust: {group_req_id} done with status: {status} "
68+
f"took: {time.time() - run_req.remote_prefill_start} seconds"
69+
)
6770
else:
6871
if self.is_master_in_dp:
6972
logger.warning(f"remote prefill reqeust: {group_req_id} not found")
@@ -100,12 +103,15 @@ def _wait_transfer_loop(self):
100103
req: InferReq = self.inflght_transfer_requests[req_id]
101104
shm_req: PDChunkedPrefillReq = req.shm_req
102105
shm_req.set_pd_req_rank_state(self.rank_in_dp, state)
103-
del self.inflght_transfer_requests[req_id]
106+
transfer_state = self.remote_prefill_requests[req_id].transfer_state
104107
if self.is_master_in_dp:
105108
logger.info(
106109
f"req: {req_id} kv transfer with state: {state} "
107-
f"took: {time.time() - req.kv_transfer_start} seconds"
110+
f"took: {time.time() - transfer_state.start_time} seconds"
108111
)
112+
del self.remote_prefill_requests[req_id]
113+
del self.inflght_transfer_requests[req_id]
114+
109115
time.sleep(PDNIXLBackendBase._THEAD_WAIT_INTERVAL)
110116

111117
def _handle_prefill_loop(self):
@@ -140,20 +146,37 @@ def _transfer_kv_to_remote(self, req: InferReq):
140146
logger.info(f"remote prefill request {group_req_id} not found")
141147
return
142148

143-
# kick off kv transfer
144-
if req.finish_status.is_finished():
145-
req.kv_transfer_start = time.time()
146-
kv_transfer_req = KVMoveRequest(
147-
group_req_id=group_req_id,
148-
token_ids=self.model.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].tolist(),
149+
remote_request: PrefillRequest = self.remote_prefill_requests[group_req_id]
150+
if remote_request.transfer_state is None:
151+
remote_request.transfer_state = TransferState(
152+
start_time=time.time(),
153+
current_kv_len=0,
154+
current_chunk_id=0,
149155
)
150-
remote_request = self.remote_prefill_requests[group_req_id]
151-
self.nixl_agent.write_blocks(kv_transfer_req, remote_request)
156+
157+
transfer_state = remote_request.transfer_state
158+
token_index = self.model.req_manager.req_to_token_indexs[req.req_idx]
159+
is_finished = req.finish_status.is_finished()
160+
161+
kv_transfer_req = KVMoveRequest(
162+
group_req_id=group_req_id,
163+
token_ids=token_index[ : req.cur_kv_len].tolist(),
164+
prev_kv_len=transfer_state.current_kv_len,
165+
cur_kv_len=req.cur_kv_len,
166+
)
167+
# kick off kv transfer
168+
self.nixl_agent.write_blocks(kv_transfer_req, remote_request, is_finished)
169+
170+
if transfer_state.current_chunk_id == 0:
152171
shm_req: PDChunkedPrefillReq = req.shm_req
153172
shm_req.set_pd_req_rank_state(self.rank_in_dp, 0)
154-
req.kv_transfering = True
173+
req.in_prefill_or_transfer = True
155174
self.inflght_transfer_requests[group_req_id] = req
156175

176+
transfer_state.current_kv_len = req.cur_kv_len
177+
transfer_state.current_chunk_id += 1
178+
179+
157180
def _decode_filter_reqs(
158181
self, prefill_reqs: List[InferReq], aborted_reqs: List[InferReq], decode_reqs: List[InferReq]
159182
):

lightllm/server/router/model_infer/mode_backend/pd_nixl/nixl_kv_transporter.py

Lines changed: 67 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66

77
from lightllm.utils.log_utils import init_logger
88

9-
from .pd_remote_prefill_obj import RemoteAgent, KVMoveRequest, PrefillRequest, RemotePrefillStatus, ThreadSafeDict
9+
from .pd_remote_prefill_obj import (
10+
RemoteAgent, KVMoveRequest, PrefillRequest,
11+
RemotePrefillStatus, ThreadSafeDict, KVMoveRequestState
12+
)
1013

1114

1215
logger = init_logger(__name__)
@@ -108,11 +111,20 @@ def _get_token_desc_ids(self, token_ids: List[int]):
108111
descs_ids.append(layer_id * self.num_tokens + token_id)
109112
return descs_ids
110113

111-
def write_blocks(self, request: KVMoveRequest, prefill_request: PrefillRequest):
114+
def write_blocks(self, request: KVMoveRequest, prefill_request: PrefillRequest, is_finished: bool):
112115
group_reqeust_id = request.group_req_id
113116
skip_kv_move_len = prefill_request.data.local_cached_len
114-
src_token_ids = request.token_ids[skip_kv_move_len:]
115-
dst_token_ids = prefill_request.data.token_ids
117+
118+
# current kv len is less than remote cached kv len, just skip
119+
if request.cur_kv_len <= skip_kv_move_len:
120+
return
121+
122+
kv_move_start = max(skip_kv_move_len, request.prev_kv_len)
123+
kv_move_end = request.cur_kv_len
124+
125+
src_token_ids = request.token_ids[kv_move_start:]
126+
dst_token_ids = prefill_request.data.token_ids[kv_move_start - skip_kv_move_len: kv_move_end]
127+
116128
remote_agent: RemoteAgent = self.remote_agents[prefill_request.decode_id][
117129
self.tp_idx
118130
] # TODO one-one mapping now
@@ -124,52 +136,85 @@ def write_blocks(self, request: KVMoveRequest, prefill_request: PrefillRequest):
124136

125137
src_handle = self.local_xfer_handles
126138
dst_handle = remote_agent.kv_xfer_handles
127-
notify_status = RemotePrefillStatus(group_req_id=group_reqeust_id, status=1)
139+
notify_status = RemotePrefillStatus(
140+
group_req_id=group_reqeust_id,
141+
status=1,
142+
chunk_id=prefill_request.transfer_state.current_chunk_id,
143+
is_last=is_finished)
144+
128145
handle = self.nixl_agent.make_prepped_xfer(
129146
"WRITE", src_handle, src_token_descs, dst_handle, dst_token_descs, notify_status.serialize()
130147
)
131148

132149
status = self.nixl_agent.transfer(handle)
133150
assert status != "ERR"
134151

135-
self.inflight_transfers[group_reqeust_id] = (handle, remote_agent, False)
152+
if group_reqeust_id not in self.inflight_transfers:
153+
self.inflight_transfers[group_reqeust_id] = KVMoveRequestState(
154+
handles=[],
155+
done_handles=[],
156+
remote_agent=remote_agent,
157+
abort=False
158+
)
159+
self.inflight_transfers[group_reqeust_id].handles.append(handle)
136160

137161
return handle
138162

139163
return None
140164

141165
def send_abort_notify(self, remote_id: int, group_reqeust_id):
142166
remote_agent: RemoteAgent = self.remote_agents[remote_id][self.tp_idx]
143-
notify_status = RemotePrefillStatus(group_req_id=group_reqeust_id, status=-1)
167+
notify_status = RemotePrefillStatus(group_req_id=group_reqeust_id, status=-1, chunk_id=-1, is_last=True)
144168
self.nixl_agent.send_notif(remote_agent.name, notify_status.serialize())
145169

146170
if group_reqeust_id in self.inflight_transfers:
147-
self.inflight_transfers[group_reqeust_id][2] = True
171+
self.inflight_transfers[group_reqeust_id].abort = True
148172

149173
def get_done_tranfers(self):
150174
done_req_ids = []
151175

152-
for req_id, (handle, remote_agent, is_abort) in self.inflight_transfers.items():
153-
if is_abort:
176+
for req_id, kv_move_state in self.inflight_transfers.items():
177+
kv_move_state: KVMoveRequestState
178+
if kv_move_state.abort:
154179
logger.warning(f"{req_id} Transfer aborted")
155180
done_req_ids.append((req_id, -1))
156181
continue
157182

158-
remote_agent: RemoteAgent
159-
xfer_state = self.nixl_agent.check_xfer_state(handle)
160-
if xfer_state == "DONE":
161-
done_req_ids.append((req_id, 1))
162-
elif xfer_state == "PROC":
163-
continue
164-
else:
165-
logger.warning(f"{req_id} Transfer failed with state {xfer_state}")
183+
remote_agent: RemoteAgent = kv_move_state.remote_agent
184+
185+
left_handles = []
186+
failed = False
187+
for handle in kv_move_state.handles:
188+
if failed:
189+
left_handles.append(handle)
190+
continue
191+
192+
xfer_state = self.nixl_agent.check_xfer_state(handle)
193+
194+
if xfer_state == "DONE":
195+
kv_move_state.done_handles.append(handle)
196+
elif xfer_state == "PROC":
197+
left_handles.append(handle)
198+
else:
199+
logger.warning(f"{req_id} Transfer failed with state {xfer_state}")
200+
failed = True
201+
kv_move_state.done_handles.append(handle)
202+
notify_failed_status = RemotePrefillStatus(group_req_id=req_id, status=-1, chunk_id=-1, is_last=True)
203+
self.nixl_agent.send_notif(remote_agent.name, notify_failed_status.serialize())
204+
205+
kv_move_state.handles = left_handles
206+
207+
if failed:
166208
done_req_ids.append((req_id, -1))
167-
notify_failed_status = RemotePrefillStatus(group_req_id=req_id, status=-1)
168-
self.nixl_agent.send_notif(remote_agent.name, notify_failed_status.serialize())
209+
elif len(left_handles) == 0:
210+
done_req_ids.append((req_id, 1))
169211

170212
for req_id, _ in done_req_ids:
171-
# release will abort inflight transfer
172-
self.nixl_agent.release_xfer_handle(self.inflight_transfers[req_id][0])
213+
kv_move_state: KVMoveRequestState = self.inflight_transfers[req_id]
214+
for handle in kv_move_state.handles + kv_move_state.done_handles:
215+
# release will abort inflight transfer
216+
self.nixl_agent.release_xfer_handle(handle)
217+
173218
del self.inflight_transfers[req_id]
174219

175220
return done_req_ids

lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,13 +95,15 @@ def main_loop(self):
9595
queue.put(request)
9696

9797
success = True
98-
for idx in range(self.tp_size):
98+
for idx in range(self.dist_info.node_world_size):
9999
ack = self.from_backend_queue.get()
100+
logger.info(f"received ack from backend {idx}: {ack}")
100101
if ack != "OK":
101102
success = False
102103
break
103104

104105
self.recv_from_decode.send_pyobj_multipart(client_obj, success)
106+
logger.info(f"Sent ack to decode: {success}")
105107
if not success:
106108
logger.warning(f"Remote connect failed: {request}")
107109

@@ -166,6 +168,7 @@ def _send_nixl_agent(self, socket: SockWithPoller):
166168
)
167169

168170
success = socket.recv_pyobj(timeout=60)
171+
logger.info(f"recv remote nixl connect response {success}")
169172
if success is None:
170173
logger.warning("timeout to recv remote nixl connect response")
171174
return False
@@ -203,6 +206,8 @@ def main_loop(self):
203206
RemotePrefillStatus(
204207
group_req_id=prefill_tasks.prefill_request.sampling_params.group_request_id,
205208
status=-1,
209+
chunk_id=-1,
210+
is_last=True,
206211
)
207212
)
208213
except Exception as e:
@@ -212,7 +217,7 @@ def main_loop(self):
212217
def remote_prefill(self, server_id: int, prefill_request: RemotePrefillRequest):
213218
socket, _ = self.remote_prefill_servers[server_id]
214219
prefill_request.sampling_params.max_new_tokens = 1
215-
socket.send_pyobj(PrefillRequest(type=RemoteRequstType.REMOTE_PREFILL, decode_id=self.id, data=prefill_request))
220+
socket.send_pyobj(PrefillRequest(type=RemoteRequstType.REMOTE_PREFILL, decode_id=self.id, data=prefill_request, transfer_state=None))
216221

217222

218223
def remote_prefill_server_loop(

lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_remote_prefill_obj.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
logger = init_logger(__name__)
1515

1616
try:
17-
from nixl._api import nixlBind, nixl_prepped_dlist_handle
17+
from nixl._api import nixlBind, nixl_prepped_dlist_handle, nixl_xfer_handle
1818

1919
except ImportError:
2020
logger.error("nixl is not installed, which is required for pd disagreggation!!!")
@@ -53,17 +53,26 @@ class ConnectRequest(RemoteRequest):
5353
agent_metadatas: List[bytes]
5454
agent_mem_descs: List[bytes]
5555

56+
@dataclass
57+
class TransferState:
58+
start_time: float
59+
current_kv_len: int
60+
current_chunk_id: int
5661

5762
@dataclass
5863
class PrefillRequest(RemoteRequest):
5964
decode_id: int
6065
data: RemotePrefillRequest
66+
# transfer status
67+
transfer_state: Optional[TransferState]
6168

6269

6370
@dataclass
6471
class KVMoveRequest:
6572
group_req_id: int
6673
token_ids: List[int]
74+
prev_kv_len: int
75+
cur_kv_len: int
6776

6877

6978
@dataclass
@@ -73,6 +82,13 @@ class RemoteAgent:
7382
kv_mem_desc: nixlBind.nixlRegDList
7483
kv_xfer_handles: nixl_prepped_dlist_handle
7584

85+
@dataclass
86+
class KVMoveRequestState:
87+
handles: List[nixl_xfer_handle]
88+
done_handles: List[nixl_xfer_handle]
89+
remote_agent: RemoteAgent
90+
abort: bool
91+
7692

7793
@dataclass
7894
class RemotePrefillStatus:
@@ -160,7 +176,7 @@ def recv_pyobj_multipart(self):
160176
client_id, data = self.sock.recv_multipart()
161177
return client_id, pickle.loads(data)
162178

163-
def send_pyobj_multipart(self, client_id: str, data: Any):
179+
def send_pyobj_multipart(self, client_id: bytes, data: Any):
164180
return self.sock.send_multipart([client_id, pickle.dumps(data)])
165181

166182
def bind(self, addr: str):

0 commit comments

Comments
 (0)