10 changes: 7 additions & 3 deletions vllm_spyre/core/scheduler.py
@@ -24,8 +24,8 @@
 from vllm.logger import init_logger
 # SPYRE SPECIFIC CODE BLOCK END
 from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
-                           SequenceGroupMetadata, SequenceGroupMetadataDelta,
-                           SequenceStatus)
+                           SequenceGroupBase, SequenceGroupMetadata,
+                           SequenceGroupMetadataDelta, SequenceStatus)
 from vllm.utils import Device, PyObjectCache
 
 from vllm_spyre.platform import SpyrePlatform
@@ -179,7 +179,11 @@ def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None:
         # Only for testing purposes.
         self.swapped.append(seq_group)
 
-    def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
+    def abort_seq_group(
+        self,
+        request_id: Union[str, Iterable[str]],
+        seq_id_to_seq_group: Optional[Dict[str, SequenceGroupBase]] = None,
+    ) -> None:
         """Aborts a sequence group with the given ID.
 
         Check if the sequence group with the given ID
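For context on the widened `abort_seq_group` signature: the method accepts either a single request id or an iterable of ids, plus a new optional id-to-group mapping. The snippet below is a minimal, standalone sketch of the id-normalization pattern such a `Union[str, Iterable[str]]` parameter implies; the helper name and body are invented for illustration and are not the vLLM implementation.

```python
from typing import Iterable, Union


def normalize_request_ids(request_id: Union[str, Iterable[str]]) -> set[str]:
    # Illustrative helper (hypothetical): accept either a single request id
    # or an iterable of ids and return a uniform set.
    if isinstance(request_id, str):
        return {request_id}
    return set(request_id)


# Both call styles collapse to the same representation.
assert normalize_request_ids("req-1") == {"req-1"}
assert normalize_request_ids(["req-1", "req-2"]) == {"req-1", "req-2"}
```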
16 changes: 8 additions & 8 deletions vllm_spyre/v1/core/scheduler.py
@@ -41,10 +41,9 @@ def __init__(self, *args, **kwargs) -> None:
         self.spyre_warmup_shapes: tuple[dict[str, int], ...] = \
             SpyrePlatform.get_warmup_shapes(self.scheduler_config)
 
-        # We'll put all new requests into this queue so that the base scheduler
-        # does not attempt to schedule them until we release them into the
-        # waiting queue. This lets us ensure that the set of requests the base
-        # scheduler sees have at least one common warmup shape.
+        # Requests are temporarily moved to this queue so that the base
+        # scheduler does not see them. This lets us ensure that the set of
+        # requests scheduled have at least one common warmup shape.
         self.holdback_queue: Deque[Request] = deque()
 
         self.rejected_requests: set[str] = set()
@@ -136,11 +135,12 @@ def schedule(self) -> SchedulerOutput:
                         len(self.running))
 
         outputs = super().schedule()
-        return outputs
 
-    def get_num_unfinished_requests(self) -> int:
-        # Override this to include our extra queue
-        return len(self.waiting) + len(self.running) + len(self.holdback_queue)
+        # move unscheduled requests back to the waiting queue
+        while self.holdback_queue:
+            self.waiting.append(self.holdback_queue.popleft())
+
+        return outputs
 
     def _get_matching_warmup_shapes(
         self, request: Request, warmup_shapes: list[dict[str, int]],
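The scheduler change above drains the holdback queue back into `waiting` at the end of every `schedule()` call, which is why the `get_num_unfinished_requests` override can be dropped: `len(waiting) + len(running)` already accounts for every unfinished request. The toy class below sketches that flow under simplified assumptions (plain string request ids and a numeric release limit standing in for warmup-shape matching); the class and method names are invented for this example and are not the vllm-spyre code.

```python
from collections import deque


class ToyHoldbackScheduler:
    """Toy illustration only; names and logic are invented for this sketch."""

    def __init__(self) -> None:
        self.waiting: deque[str] = deque()
        self.holdback_queue: deque[str] = deque()

    def add_request(self, request_id: str) -> None:
        # New requests start in the holdback queue, where the base scheduler
        # cannot see them.
        self.holdback_queue.append(request_id)

    def schedule(self, release_limit: int) -> list[str]:
        # Release a bounded subset into the waiting queue (the real scheduler
        # releases only requests sharing a common warmup shape).
        while self.holdback_queue and len(self.waiting) < release_limit:
            self.waiting.append(self.holdback_queue.popleft())

        # Stand-in for super().schedule(): consume everything in waiting.
        scheduled = list(self.waiting)
        self.waiting.clear()

        # Mirror the change above: return anything still held back to the
        # waiting queue so queue-length accounting stays correct without a
        # separate get_num_unfinished_requests override.
        while self.holdback_queue:
            self.waiting.append(self.holdback_queue.popleft())

        return scheduled


scheduler = ToyHoldbackScheduler()
for rid in ("req-a", "req-b", "req-c"):
    scheduler.add_request(rid)

print(scheduler.schedule(release_limit=2))  # ['req-a', 'req-b']
print(list(scheduler.waiting))              # ['req-c']
```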
5 changes: 4 additions & 1 deletion vllm_spyre/worker/spyre_model_runner.py
@@ -284,7 +284,10 @@ def prepare_model_input(
             # updating indices: set indices of newly finished sequences False
             if finished_requests_ids:
                 for seq_id in finished_requests_ids:
-                    self.model.indices[self._req_ids2idx[seq_id]] = False
+                    # ignore requests that are not in the batch, eg. requests
+                    # cancelled while waiting
+                    if (idx := self._req_ids2idx.get(seq_id)) is not None:
+                        self.model.indices[idx] = False
             (input_tokens, input_positions,
              input_masks) = self._prepare_decode(seq_group_metadata_list)
             seq_lens = []
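The model-runner fix replaces a direct dictionary index with a guarded `dict.get()` lookup, so a finished request that never entered the batch (for example, one cancelled while still waiting) is skipped instead of raising `KeyError`. Below is a self-contained illustration of that pattern with invented variable names and toy data:

```python
# Toy stand-ins for the request-id-to-batch-index map and the batch mask.
req_ids2idx = {"req-0": 0, "req-2": 1}
indices = [True, True]

# "req-1" was cancelled while waiting and never received a batch slot.
finished_requests_ids = ["req-0", "req-1"]

for seq_id in finished_requests_ids:
    # dict.get() returns None for unknown ids instead of raising KeyError,
    # so cancelled-while-waiting requests are silently ignored.
    if (idx := req_ids2idx.get(seq_id)) is not None:
        indices[idx] = False

print(indices)  # [False, True]
```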