From e8f624dbdfb6c4066987dc6efd296f5d10efe18a Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Fri, 28 Mar 2025 13:36:27 -0600 Subject: [PATCH 1/5] fix: add optional arg to abort_seq_group for compat with v0.8 Signed-off-by: Travis Johnson --- vllm_spyre/core/scheduler.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm_spyre/core/scheduler.py b/vllm_spyre/core/scheduler.py index 733b586f4..fb5d4afb9 100644 --- a/vllm_spyre/core/scheduler.py +++ b/vllm_spyre/core/scheduler.py @@ -24,8 +24,8 @@ from vllm.logger import init_logger # SPYRE SPECIFIC CODE BLOCK END from vllm.sequence import (Sequence, SequenceData, SequenceGroup, - SequenceGroupMetadata, SequenceGroupMetadataDelta, - SequenceStatus) + SequenceGroupBase, SequenceGroupMetadata, + SequenceGroupMetadataDelta, SequenceStatus) from vllm.utils import Device, PyObjectCache from vllm_spyre.platform import SpyrePlatform @@ -179,7 +179,11 @@ def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None: # Only for testing purposes. self.swapped.append(seq_group) - def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: + def abort_seq_group( + self, + request_id: Union[str, Iterable[str]], + seq_id_to_seq_group: Optional[Dict[str, SequenceGroupBase]] = None, + ) -> None: """Aborts a sequence group with the given ID. Check if the sequence group with the given ID From 87e573e8aa1ee3e00acb8171a83612e3b9b98f07 Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Fri, 28 Mar 2025 13:40:57 -0600 Subject: [PATCH 2/5] fix: guard against KeyError with _req_ids2idx Signed-off-by: Travis Johnson --- vllm_spyre/worker/spyre_model_runner.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm_spyre/worker/spyre_model_runner.py b/vllm_spyre/worker/spyre_model_runner.py index c0253285f..a1df1e79f 100644 --- a/vllm_spyre/worker/spyre_model_runner.py +++ b/vllm_spyre/worker/spyre_model_runner.py @@ -284,7 +284,10 @@ def prepare_model_input( # updating indices: set indices of newly finished sequences False if finished_requests_ids: for seq_id in finished_requests_ids: - self.model.indices[self._req_ids2idx[seq_id]] = False + # ignore requests that are not in the batch, eg. requests + # cancelled while waiting + if idx := self._req_ids2idx.get(seq_id): + self.model.indices[idx] = False (input_tokens, input_positions, input_masks) = self._prepare_decode(seq_group_metadata_list) seq_lens = [] From 8cf95eb36f39542fb8c181f43134c435f8fca43b Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Fri, 28 Mar 2025 14:43:25 -0600 Subject: [PATCH 3/5] fix: specialize finish_requests in V1 scheduler Signed-off-by: Travis Johnson --- vllm_spyre/v1/core/scheduler.py | 39 ++++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/vllm_spyre/v1/core/scheduler.py b/vllm_spyre/v1/core/scheduler.py index f0f8f3537..fa9eb5e70 100644 --- a/vllm_spyre/v1/core/scheduler.py +++ b/vllm_spyre/v1/core/scheduler.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from collections import deque -from typing import TYPE_CHECKING, Deque +from typing import TYPE_CHECKING, Deque, Iterable, Union from vllm.logger import init_logger from vllm.sampling_params import SamplingParams @@ -142,6 +142,43 @@ def get_num_unfinished_requests(self) -> int: # Override this to include our extra queue return len(self.waiting) + len(self.running) + len(self.holdback_queue) + def finish_requests( + self, + request_ids: Union[str, Iterable[str]], + finished_status: RequestStatus, + ) -> None: + """Handles the finish signal from outside the scheduler. + + For example, the API server can abort a request when the client + disconnects. + + Specialized in vllm_spyre to handle the holdback_queue. + """ + assert RequestStatus.is_finished(finished_status) + if isinstance(request_ids, str): + request_ids = (request_ids, ) + else: + request_ids = set(request_ids) + + for req_id in request_ids: + request = self.requests.get(req_id) + if request is None: + # Invalid request ID. + continue + + if request.status == RequestStatus.RUNNING: + self.running.remove(request) + self.scheduled_req_ids.discard(request.request_id) + else: + # this try-except is the specialization for Spyre + try: + self.holdback_queue.remove(request) + except ValueError: + self.waiting.remove(request) + + request.status = finished_status + self._free_request(request) + def _get_matching_warmup_shapes( self, request: Request, warmup_shapes: list[dict[str, int]], current_batch_size: int) -> list[dict[str, int]]: From 8b659ea13f82c8a27d71e5a42ac9922c0382224d Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Fri, 28 Mar 2025 16:34:19 -0600 Subject: [PATCH 4/5] fix: check against None... Signed-off-by: Travis Johnson --- vllm_spyre/worker/spyre_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_spyre/worker/spyre_model_runner.py b/vllm_spyre/worker/spyre_model_runner.py index a1df1e79f..e91e77c2c 100644 --- a/vllm_spyre/worker/spyre_model_runner.py +++ b/vllm_spyre/worker/spyre_model_runner.py @@ -286,7 +286,7 @@ def prepare_model_input( for seq_id in finished_requests_ids: # ignore requests that are not in the batch, eg. requests # cancelled while waiting - if idx := self._req_ids2idx.get(seq_id): + if (idx := self._req_ids2idx.get(seq_id)) is not None: self.model.indices[idx] = False (input_tokens, input_positions, input_masks) = self._prepare_decode(seq_group_metadata_list) From d83732993f91c032cddd0c4fa28240aa1f86adc4 Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Mon, 7 Apr 2025 14:07:49 -0600 Subject: [PATCH 5/5] refactor: make holdback queue use more temporary Signed-off-by: Travis Johnson --- vllm_spyre/v1/core/scheduler.py | 53 +++++---------------------------- 1 file changed, 8 insertions(+), 45 deletions(-) diff --git a/vllm_spyre/v1/core/scheduler.py b/vllm_spyre/v1/core/scheduler.py index fa9eb5e70..cc85e3ce5 100644 --- a/vllm_spyre/v1/core/scheduler.py +++ b/vllm_spyre/v1/core/scheduler.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from collections import deque -from typing import TYPE_CHECKING, Deque, Iterable, Union +from typing import TYPE_CHECKING, Deque from vllm.logger import init_logger from vllm.sampling_params import SamplingParams @@ -41,10 +41,9 @@ def __init__(self, *args, **kwargs) -> None: self.spyre_warmup_shapes: tuple[dict[str, int], ...] = \ SpyrePlatform.get_warmup_shapes(self.scheduler_config) - # We'll put all new requests into this queue so that the base scheduler - # does not attempt to schedule them until we release them into the - # waiting queue. This lets us ensure that the set of requests the base - # scheduler sees have at least one common warmup shape. + # Requests are temporarily moved to this queue so that the base + # scheduler does not see them. This lets us ensure that the set of + # requests scheduled have at least one common warmup shape. self.holdback_queue: Deque[Request] = deque() self.rejected_requests: set[str] = set() @@ -136,48 +135,12 @@ def schedule(self) -> SchedulerOutput: len(self.running)) outputs = super().schedule() - return outputs - - def get_num_unfinished_requests(self) -> int: - # Override this to include our extra queue - return len(self.waiting) + len(self.running) + len(self.holdback_queue) - - def finish_requests( - self, - request_ids: Union[str, Iterable[str]], - finished_status: RequestStatus, - ) -> None: - """Handles the finish signal from outside the scheduler. - For example, the API server can abort a request when the client - disconnects. + # move unscheduled requests back to the waiting queue + while self.holdback_queue: + self.waiting.append(self.holdback_queue.popleft()) - Specialized in vllm_spyre to handle the holdback_queue. - """ - assert RequestStatus.is_finished(finished_status) - if isinstance(request_ids, str): - request_ids = (request_ids, ) - else: - request_ids = set(request_ids) - - for req_id in request_ids: - request = self.requests.get(req_id) - if request is None: - # Invalid request ID. - continue - - if request.status == RequestStatus.RUNNING: - self.running.remove(request) - self.scheduled_req_ids.discard(request.request_id) - else: - # this try-except is the specialization for Spyre - try: - self.holdback_queue.remove(request) - except ValueError: - self.waiting.remove(request) - - request.status = finished_status - self._free_request(request) + return outputs def _get_matching_warmup_shapes( self, request: Request, warmup_shapes: list[dict[str, int]],