6
6
7
7
from lightllm .utils .log_utils import init_logger
8
8
9
- from .pd_remote_prefill_obj import RemoteAgent , KVMoveRequest , PrefillRequest , RemotePrefillStatus , ThreadSafeDict
9
+ from .pd_remote_prefill_obj import (
10
+ RemoteAgent , KVMoveRequest , PrefillRequest ,
11
+ RemotePrefillStatus , ThreadSafeDict , KVMoveRequestState
12
+ )
10
13
11
14
12
15
logger = init_logger (__name__ )
@@ -108,11 +111,20 @@ def _get_token_desc_ids(self, token_ids: List[int]):
108
111
descs_ids .append (layer_id * self .num_tokens + token_id )
109
112
return descs_ids
110
113
111
- def write_blocks (self , request : KVMoveRequest , prefill_request : PrefillRequest ):
114
+ def write_blocks (self , request : KVMoveRequest , prefill_request : PrefillRequest , is_finished : bool ):
112
115
group_reqeust_id = request .group_req_id
113
116
skip_kv_move_len = prefill_request .data .local_cached_len
114
- src_token_ids = request .token_ids [skip_kv_move_len :]
115
- dst_token_ids = prefill_request .data .token_ids
117
+
118
+ # current kv len is less than remote cached kv len, just skip
119
+ if request .cur_kv_len <= skip_kv_move_len :
120
+ return
121
+
122
+ kv_move_start = max (skip_kv_move_len , request .prev_kv_len )
123
+ kv_move_end = request .cur_kv_len
124
+
125
+ src_token_ids = request .token_ids [kv_move_start :]
126
+ dst_token_ids = prefill_request .data .token_ids [kv_move_start - skip_kv_move_len : kv_move_end ]
127
+
116
128
remote_agent : RemoteAgent = self .remote_agents [prefill_request .decode_id ][
117
129
self .tp_idx
118
130
] # TODO one-one mapping now
@@ -124,52 +136,85 @@ def write_blocks(self, request: KVMoveRequest, prefill_request: PrefillRequest):
124
136
125
137
src_handle = self .local_xfer_handles
126
138
dst_handle = remote_agent .kv_xfer_handles
127
- notify_status = RemotePrefillStatus (group_req_id = group_reqeust_id , status = 1 )
139
+ notify_status = RemotePrefillStatus (
140
+ group_req_id = group_reqeust_id ,
141
+ status = 1 ,
142
+ chunk_id = prefill_request .transfer_state .current_chunk_id ,
143
+ is_last = is_finished )
144
+
128
145
handle = self .nixl_agent .make_prepped_xfer (
129
146
"WRITE" , src_handle , src_token_descs , dst_handle , dst_token_descs , notify_status .serialize ()
130
147
)
131
148
132
149
status = self .nixl_agent .transfer (handle )
133
150
assert status != "ERR"
134
151
135
- self .inflight_transfers [group_reqeust_id ] = (handle , remote_agent , False )
152
+ if group_reqeust_id not in self .inflight_transfers :
153
+ self .inflight_transfers [group_reqeust_id ] = KVMoveRequestState (
154
+ handles = [],
155
+ done_handles = [],
156
+ remote_agent = remote_agent ,
157
+ abort = False
158
+ )
159
+ self .inflight_transfers [group_reqeust_id ].handles .append (handle )
136
160
137
161
return handle
138
162
139
163
return None
140
164
141
165
def send_abort_notify (self , remote_id : int , group_reqeust_id ):
142
166
remote_agent : RemoteAgent = self .remote_agents [remote_id ][self .tp_idx ]
143
- notify_status = RemotePrefillStatus (group_req_id = group_reqeust_id , status = - 1 )
167
+ notify_status = RemotePrefillStatus (group_req_id = group_reqeust_id , status = - 1 , chunk_id = - 1 , is_last = True )
144
168
self .nixl_agent .send_notif (remote_agent .name , notify_status .serialize ())
145
169
146
170
if group_reqeust_id in self .inflight_transfers :
147
- self .inflight_transfers [group_reqeust_id ][ 2 ] = True
171
+ self .inflight_transfers [group_reqeust_id ]. abort = True
148
172
149
173
def get_done_tranfers (self ):
150
174
done_req_ids = []
151
175
152
- for req_id , (handle , remote_agent , is_abort ) in self .inflight_transfers .items ():
153
- if is_abort :
176
+ for req_id , kv_move_state in self .inflight_transfers .items ():
177
+ kv_move_state : KVMoveRequestState
178
+ if kv_move_state .abort :
154
179
logger .warning (f"{ req_id } Transfer aborted" )
155
180
done_req_ids .append ((req_id , - 1 ))
156
181
continue
157
182
158
- remote_agent : RemoteAgent
159
- xfer_state = self .nixl_agent .check_xfer_state (handle )
160
- if xfer_state == "DONE" :
161
- done_req_ids .append ((req_id , 1 ))
162
- elif xfer_state == "PROC" :
163
- continue
164
- else :
165
- logger .warning (f"{ req_id } Transfer failed with state { xfer_state } " )
183
+ remote_agent : RemoteAgent = kv_move_state .remote_agent
184
+
185
+ left_handles = []
186
+ failed = False
187
+ for handle in kv_move_state .handles :
188
+ if failed :
189
+ left_handles .append (handle )
190
+ continue
191
+
192
+ xfer_state = self .nixl_agent .check_xfer_state (handle )
193
+
194
+ if xfer_state == "DONE" :
195
+ kv_move_state .done_handles .append (handle )
196
+ elif xfer_state == "PROC" :
197
+ left_handles .append (handle )
198
+ else :
199
+ logger .warning (f"{ req_id } Transfer failed with state { xfer_state } " )
200
+ failed = True
201
+ kv_move_state .done_handles .append (handle )
202
+ notify_failed_status = RemotePrefillStatus (group_req_id = req_id , status = - 1 , chunk_id = - 1 , is_last = True )
203
+ self .nixl_agent .send_notif (remote_agent .name , notify_failed_status .serialize ())
204
+
205
+ kv_move_state .handles = left_handles
206
+
207
+ if failed :
166
208
done_req_ids .append ((req_id , - 1 ))
167
- notify_failed_status = RemotePrefillStatus ( group_req_id = req_id , status = - 1 )
168
- self . nixl_agent . send_notif ( remote_agent . name , notify_failed_status . serialize ( ))
209
+ elif len ( left_handles ) == 0 :
210
+ done_req_ids . append (( req_id , 1 ))
169
211
170
212
for req_id , _ in done_req_ids :
171
- # release will abort inflight transfer
172
- self .nixl_agent .release_xfer_handle (self .inflight_transfers [req_id ][0 ])
213
+ kv_move_state : KVMoveRequestState = self .inflight_transfers [req_id ]
214
+ for handle in kv_move_state .handles + kv_move_state .done_handles :
215
+ # release will abort inflight transfer
216
+ self .nixl_agent .release_xfer_handle (handle )
217
+
173
218
del self .inflight_transfers [req_id ]
174
219
175
220
return done_req_ids
0 commit comments