diff --git a/vllm_spyre/v1/core/sched/__init__.py b/vllm_spyre/v1/core/sched/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/vllm_spyre/v1/core/sched/output.py b/vllm_spyre/v1/core/sched/output.py
new file mode 100644
index 000000000..78058bc29
--- /dev/null
+++ b/vllm_spyre/v1/core/sched/output.py
@@ -0,0 +1,11 @@
+# This module wraps imports of vLLM classes whose location depends on the vLLM version
+
+try:
+    # vllm v0.8.2+
+    from vllm.v1.core.sched.output import CachedRequestData  # noqa: F401
+    from vllm.v1.core.sched.output import NewRequestData  # noqa: F401
+    from vllm.v1.core.sched.output import SchedulerOutput  # noqa: F401
+except ImportError:
+    from vllm.v1.core.scheduler import CachedRequestData  # noqa: F401
+    from vllm.v1.core.scheduler import NewRequestData  # noqa: F401
+    from vllm.v1.core.scheduler import SchedulerOutput  # noqa: F401
diff --git a/vllm_spyre/v1/core/scheduler.py b/vllm_spyre/v1/core/scheduler.py
index 391a989c0..f31502479 100644
--- a/vllm_spyre/v1/core/scheduler.py
+++ b/vllm_spyre/v1/core/scheduler.py
@@ -1,16 +1,23 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from collections import deque
-from typing import Deque
+from typing import TYPE_CHECKING, Deque
 
 from vllm.logger import init_logger
 from vllm.sampling_params import SamplingParams
-from vllm.v1.core.scheduler import Scheduler
-from vllm.v1.core.scheduler_output import SchedulerOutput
 from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs, FinishReason
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 
+try:
+    from vllm.v1.core.sched.scheduler import Scheduler
+except ImportError:
+    from vllm.v1.core.scheduler import Scheduler
+
+if TYPE_CHECKING:
+    from vllm_spyre.v1.core.sched.output import SchedulerOutput
+else:
+    SchedulerOutput = None
 from vllm_spyre.platform import SpyrePlatform
 
 logger = init_logger(__name__)
@@ -83,7 +90,7 @@ def update_from_output(
         outputs.outputs.extend(reject_outputs)
         return outputs
 
-    def schedule(self) -> "SchedulerOutput":
+    def schedule(self) -> SchedulerOutput:
         """This override adds constraints and then delegates most of the
         work to the base scheduler"""
 
diff --git a/vllm_spyre/v1/worker/spyre_model_runner.py b/vllm_spyre/v1/worker/spyre_model_runner.py
index deea7e8e6..8ab175574 100644
--- a/vllm_spyre/v1/worker/spyre_model_runner.py
+++ b/vllm_spyre/v1/worker/spyre_model_runner.py
@@ -26,8 +26,15 @@
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.model_executor.pooling_metadata import PoolingMetadata
-from vllm.v1.core.scheduler import (CachedRequestData, NewRequestData,
-                                    SchedulerOutput)
+if TYPE_CHECKING:
+    from vllm_spyre.v1.core.sched.output import (CachedRequestData,
+                                                 NewRequestData,
+                                                 SchedulerOutput)
+else:
+    CachedRequestData = None
+    SchedulerOutput = None
+    NewRequestData = None
+
 from vllm.v1.outputs import ModelRunnerOutput
 
 logger = init_logger(__name__)
 
@@ -133,7 +139,7 @@ def vocab_size(self) -> int:
 
     def _prepare_prompt(
        self,
-        new_requests: List[NewRequestData],
+        new_requests: list[NewRequestData],
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
         assert len(new_requests) > 0
         input_token_list: List[torch.Tensor] = []
@@ -198,7 +204,7 @@
 
     def _prepare_decode(
         self,
-        cached_requests: List[CachedRequestData],
+        cached_requests: list[CachedRequestData],
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         assert len(cached_requests) > 0
         input_tokens: List[List[int]] = [
@@ -293,7 +299,7 @@ def prepare_model_input(
 
     @SpyrePlatform.inference_mode()
     def execute_model(
         self,
-        scheduler_output: "SchedulerOutput",
+        scheduler_output: SchedulerOutput,
         **kwargs,
     ) -> ModelRunnerOutput:
diff --git a/vllm_spyre/v1/worker/spyre_worker.py b/vllm_spyre/v1/worker/spyre_worker.py
index 052eff402..cdac7e049 100644
--- a/vllm_spyre/v1/worker/spyre_worker.py
+++ b/vllm_spyre/v1/worker/spyre_worker.py
@@ -14,8 +14,6 @@
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
 from vllm.sampling_params import SamplingParams
-from vllm.v1.core.scheduler import (CachedRequestData, NewRequestData,
-                                    SchedulerOutput)
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.worker.worker_base import WorkerBase as WorkerBaseV1
@@ -24,6 +22,8 @@
 import vllm_spyre.envs as envs_spyre
 from vllm_spyre.model_executor.model_loader import spyre_setup
 from vllm_spyre.platform import SpyrePlatform
+from vllm_spyre.v1.core.sched.output import (CachedRequestData, NewRequestData,
+                                             SchedulerOutput)
 from vllm_spyre.v1.worker.spyre_model_runner import SpyreModelRunner
 
 logger = init_logger(__name__)
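
Note: the two compatibility patterns introduced above — trying the newer vLLM import path before falling back to the older one, and guarding an import with TYPE_CHECKING while binding a None placeholder at runtime — generalize beyond vLLM. Below is a minimal, self-contained sketch of both patterns; it uses only standard-library modules as stand-ins, and none of the names are part of vllm_spyre or vLLM.

# compat_sketch.py
# Illustrates the two import-compatibility patterns from this change set,
# using stdlib modules instead of vLLM's. Nothing here depends on vLLM.
from typing import TYPE_CHECKING

# Pattern 1: prefer the newer import location, fall back to the older one.
try:
    from collections.abc import Sequence  # newer location (Python 3.3+)
except ImportError:
    from collections import Sequence  # type: ignore[attr-defined]  # legacy location

# Pattern 2: import a name only for type checking and bind a runtime
# placeholder, so annotations that reference it still evaluate cleanly.
if TYPE_CHECKING:
    from collections.abc import Iterable
else:
    Iterable = None


def first_item(items: Sequence) -> object:
    """Return the first element of any sequence."""
    return items[0]


def count_items(items: Iterable) -> int:
    """Count elements; Iterable is None at runtime, which is still a valid annotation."""
    return sum(1 for _ in items)


if __name__ == "__main__":
    print(first_item([1, 2, 3]))  # -> 1
    print(count_items(range(4)))  # -> 4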