|
1 | 1 | # SPDX-License-Identifier: Apache-2.0 |
2 | 2 |
|
3 | 3 | from collections import deque |
4 | | -from typing import TYPE_CHECKING, Deque |
| 4 | +from typing import TYPE_CHECKING, Deque, Iterable, Union |
5 | 5 |
|
6 | 6 | from vllm.logger import init_logger |
7 | 7 | from vllm.sampling_params import SamplingParams |
@@ -140,6 +140,43 @@ def get_num_unfinished_requests(self) -> int: |
140 | 140 | # Override this to include our extra queue |
141 | 141 | return len(self.waiting) + len(self.running) + len(self.holdback_queue) |
142 | 142 |
|
def finish_requests(
    self,
    request_ids: Union[str, Iterable[str]],
    finished_status: RequestStatus,
) -> None:
    """Finish one or more requests on behalf of a caller outside the scheduler.

    A typical caller is the API server aborting a request whose client
    has disconnected.

    This is the vllm_spyre variant: besides the ``running`` and
    ``waiting`` queues it also looks for the request in the
    ``holdback_queue``.

    Args:
        request_ids: A single request id, or an iterable of request ids.
            Ids that are not present in ``self.requests`` are skipped.
        finished_status: The status to assign to every affected request.
            It must satisfy ``RequestStatus.is_finished``.
    """
    assert RequestStatus.is_finished(finished_status)

    # A lone id is wrapped in a tuple; any other iterable is de-duplicated.
    pending_ids = ((request_ids, ) if isinstance(request_ids, str) else
                   set(request_ids))

    for pending_id in pending_ids:
        target = self.requests.get(pending_id)
        if target is None:
            # Unknown request id: nothing to finish.
            continue

        is_running = target.status == RequestStatus.RUNNING
        if not is_running:
            # Spyre specialization: a non-running request is looked up in
            # the holdback queue first, falling back to the waiting queue
            # when it is not found there.
            try:
                self.holdback_queue.remove(target)
            except ValueError:
                self.waiting.remove(target)
        else:
            self.running.remove(target)
            self.scheduled_req_ids.discard(target.request_id)

        target.status = finished_status
        self._free_request(target)
| 179 | + |
143 | 180 | def _get_matching_warmup_shapes( |
144 | 181 | self, request: Request, warmup_shapes: list[dict[str, int]], |
145 | 182 | current_batch_size: int) -> list[dict[str, int]]: |
|
0 commit comments