Skip to content

Commit b48fb5a

Browse files
committed
fix: specialize finish_requests in V1 scheduler
Signed-off-by: Travis Johnson <[email protected]>
1 parent 3b67ae0 commit b48fb5a

File tree

1 file changed

+38
-1
lines changed

1 file changed

+38
-1
lines changed

vllm_spyre/v1/core/scheduler.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# SPDX-License-Identifier: Apache-2.0
22

33
from collections import deque
4-
from typing import TYPE_CHECKING, Deque
4+
from typing import TYPE_CHECKING, Deque, Iterable, Union
55

66
from vllm.logger import init_logger
77
from vllm.sampling_params import SamplingParams
@@ -140,6 +140,43 @@ def get_num_unfinished_requests(self) -> int:
140140
# Override this to include our extra queue
141141
return len(self.waiting) + len(self.running) + len(self.holdback_queue)
142142

143+
def finish_requests(
144+
self,
145+
request_ids: Union[str, Iterable[str]],
146+
finished_status: RequestStatus,
147+
) -> None:
148+
"""Handles the finish signal from outside the scheduler.
149+
150+
For example, the API server can abort a request when the client
151+
disconnects.
152+
153+
Specialized in vllm_spyre to handle the holdback_queue.
154+
"""
155+
assert RequestStatus.is_finished(finished_status)
156+
if isinstance(request_ids, str):
157+
request_ids = (request_ids, )
158+
else:
159+
request_ids = set(request_ids)
160+
161+
for req_id in request_ids:
162+
request = self.requests.get(req_id)
163+
if request is None:
164+
# Invalid request ID.
165+
continue
166+
167+
if request.status == RequestStatus.RUNNING:
168+
self.running.remove(request)
169+
self.scheduled_req_ids.discard(request.request_id)
170+
else:
171+
# this try-except is the specialization for Spyre
172+
try:
173+
self.holdback_queue.remove(request)
174+
except ValueError:
175+
self.waiting.remove(request)
176+
177+
request.status = finished_status
178+
self._free_request(request)
179+
143180
def _get_matching_warmup_shapes(
144181
self, request: Request, warmup_shapes: list[dict[str, int]],
145182
current_batch_size: int) -> list[dict[str, int]]:

0 commit comments

Comments
 (0)