diff --git a/vllm_spyre/platform.py b/vllm_spyre/platform.py index c7046e6b6..02d83c6ec 100644 --- a/vllm_spyre/platform.py +++ b/vllm_spyre/platform.py @@ -6,8 +6,9 @@ from vllm.logger import init_logger if TYPE_CHECKING: - from vllm.config import VllmConfig + from vllm.config import ModelConfig, VllmConfig else: + ModelConfig = None VllmConfig = None import vllm.envs as envs from vllm.platforms import Platform, PlatformEnum @@ -165,3 +166,10 @@ def set_warmup_shapes(cls, scheduler_config) -> None: @classmethod def get_warmup_shapes(cls) -> tuple[dict[str, int], ...]: return cls.spyre_warmup_shapes + + @classmethod + def supports_v1(cls, model_config: ModelConfig) -> bool: + """Returns whether the current platform can support v1 for the supplied + model configuration. + """ + return True