Merged
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -75,7 +75,7 @@ jobs:
       include:
         - vllm_version:
             name: "vLLM:lowest"
-            repo: "git+https://github.com/vllm-project/vllm --tag v0.10.1.1"
+            repo: "git+https://github.com/vllm-project/vllm --tag v0.10.2"
          test_suite:
            name: "backward compat"
            markers: "compat or (cpu and basic)"
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -13,7 +13,7 @@ license = {text = "Apache 2"}
 dependencies = [
     "fms-model-optimizer[fp8]>=0.6.0",
     "ibm-fms>=1.4.0,<2.0",
-    "vllm>=0.10.1.1,<=0.10.2",
+    "vllm>=0.10.2,<=0.11.0",
     "pytest-mock>=3.15.0",
 ]
 requires-python = ">=3.11"
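Aside, not part of this diff: version pins like the range above usually pair with runtime gating in the plugin code, since one installed wheel must run against several vLLM releases. A minimal sketch of branching on the installed vLLM version, assuming `packaging` is available:

```python
# Minimal sketch (illustrative, not from this PR): detect the installed
# vLLM version at runtime and branch compat code on it.
from importlib.metadata import version

from packaging.version import Version

VLLM_VERSION = Version(version("vllm"))

# True on the oldest release in the new ">=0.10.2,<=0.11.0" window, where
# the legacy get_sampler() import still exists.
NEEDS_LEGACY_SAMPLER = VLLM_VERSION <= Version("0.10.2")
```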
20 changes: 1 addition & 19 deletions tests/spyre_util.py
@@ -1,4 +1,3 @@
-import inspect
 import json
 import math
 import os
@@ -367,31 +366,14 @@ def create_random_request(
     assert (len(prompt_token_ids) == num_tokens
             ), f"need {num_tokens} but got {len(prompt_token_ids)}"

-    # temporary backward compat code for 0.10.1.1
-    annotations = inspect.getfullargspec(Request).annotations
-    extra_args = {}  # noqa
-    if ('multi_modal_hashes' in annotations):
-        extra_args.update({
-            'multi_modal_hashes': None,
-        })
-    if ('multi_modal_placeholders' in annotations):
-        extra_args.update({
-            'multi_modal_placeholders': None,
-        })
-    if ('multi_modal_kwargs' in annotations):
-        extra_args.update({
-            'multi_modal_kwargs': None,
-        })
-
     return Request(request_id=str(request_id),
                    prompt_token_ids=prompt_token_ids,
                    sampling_params=sampling_params,
                    eos_token_id=None,
                    arrival_time=0,
                    lora_request=None,
                    pooling_params=None,
-                   cache_salt=None,
-                   **extra_args)
+                   cache_salt=None)


 def skip_unsupported_tp_size(size: int, backend: str):
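The block deleted above used `inspect` to pass keyword arguments only when the installed `Request` signature still accepted them. A standalone sketch of that pattern, with a hypothetical helper name, for illustration:

```python
import inspect


def kwargs_accepted_by(callable_obj, **candidate_kwargs):
    """Keep only the kwargs that appear in the callable's annotations.

    This mirrors the deleted shim: vLLM 0.10.1.1's Request still took
    multi_modal_* arguments that 0.10.2 dropped.
    """
    annotations = inspect.getfullargspec(callable_obj).annotations
    return {k: v for k, v in candidate_kwargs.items() if k in annotations}
```

A call site can then write `Request(**common_args, **kwargs_accepted_by(Request, multi_modal_hashes=None))` and run unchanged against both signatures.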
88 changes: 20 additions & 68 deletions tests/utils/test_upstream_compatibility.py
@@ -1,83 +1,35 @@
-import inspect
 import os

 import pytest
+from vllm.v1.core.sched.output import NewRequestData
+from vllm.v1.kv_cache_interface import FullAttentionSpec
+
+from vllm_spyre.compat_utils import dataclass_fields

 pytestmark = pytest.mark.compat

 VLLM_VERSION = os.getenv("TEST_VLLM_VERSION", "default")


-@pytest.mark.cpu
-def test_init_distributed_environment():
-    """Tests whether vllm's init_distributed_environment
-    has the custom timeout argument"""
-    from vllm.distributed import init_distributed_environment
-
-    annotations = inspect.getfullargspec(
-        init_distributed_environment).annotations
+def test_mm_inputs():

-    if VLLM_VERSION == "vLLM:lowest":
-        assert 'timeout' \
-            not in annotations, ("we should remove compat code which is now"
-                                 " part of released vllm version")
-
-
-def test_request():
-
-    from vllm.v1.request import Request
-
-    annotations = inspect.getfullargspec(Request).annotations
-
-    if VLLM_VERSION == "vLLM:main":
-        assert 'multi_modal_kwargs' not in annotations
-        assert 'multi_modal_hashes' not in annotations
-        assert 'multi_modal_placeholders' not in annotations
-    elif VLLM_VERSION == "vLLM:lowest":
-        assert 'multi_modal_hashes' in annotations
-        assert 'multi_modal_placeholders' in annotations
-        assert 'multi_modal_kwargs' in annotations
-        # The compat code introduced in the PR below can now be removed:
-        # https://github.com/vllm-project/vllm-spyre/pull/463
-
-
-def test_model_runner_output():
-
-    from vllm.v1.outputs import ModelRunnerOutput
+    # Can remove "mm_kwargs", "mm_hashes", "mm_positions"
+    # (replaced by mm_features)
+    assert 'mm_kwargs' in dataclass_fields(NewRequestData)

-    annotations = inspect.getfullargspec(ModelRunnerOutput).annotations
-
-    if VLLM_VERSION == "vLLM:main":
-        assert 'spec_token_ids' not in annotations
-    elif VLLM_VERSION == "vLLM:lowest":
-        assert 'spec_token_ids' in annotations
-        # The compat code introduced in the PR below can now be removed:
-        # https://github.com/vllm-project/vllm-spyre/pull/463
-
-
-def test_pooling_metadata():
-
-    from vllm.v1.pool.metadata import PoolingMetadata
-
-    has_build_pooling_cursor = getattr(PoolingMetadata, "build_pooling_cursor",
-                                       False)
-
-    if VLLM_VERSION == "vLLM:main":
-        assert has_build_pooling_cursor
-    elif VLLM_VERSION == "vLLM:lowest":
-        assert not has_build_pooling_cursor
-        # The compat code introduced in the PR below can now be removed:
-        # https://github.com/vllm-project/vllm-spyre/pull/463
-
-
-def test_scheduler_output():
+def test_get_sampler():
+    if VLLM_VERSION == "vLLM:lowest":
+        try:
+            from vllm.model_executor.layers.sampler import (  # noqa
+                get_sampler)
+        except ImportError as e:
+            raise AssertionError(
+                "Remove backwards compatibility for get_sampler") from e

-    from vllm.v1.core.sched.output import SchedulerOutput
-    annotations = inspect.getfullargspec(SchedulerOutput).annotations
-
-    if VLLM_VERSION == "vLLM:main":
-        assert 'free_encoder_mm_hashes' in annotations
-    elif VLLM_VERSION == "vLLM:lowest":
-        assert 'free_encoder_mm_hashes' not in annotations
-        # The compat code introduced in the PR below can now be removed:
-        # https://github.com/vllm-project/vllm-spyre/pull/463
+
+def test_use_mla():
+    if VLLM_VERSION == "vLLM:lowest":
+        # Can remove backwards compatibility for use_mla
+        assert "use_mla" in dataclass_fields(FullAttentionSpec)
42 changes: 21 additions & 21 deletions uv.lock

Some generated files are not rendered by default.

13 changes: 10 additions & 3 deletions vllm_spyre/model_executor/model_loader/spyre.py
@@ -13,10 +13,11 @@
 from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.model_loader.weight_utils import (
     download_weights_from_hf)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.v1.outputs import SamplerOutput
+from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.sample.sampler import Sampler

 import vllm_spyre.envs as envs_spyre
 import vllm_spyre.utils as utils_spyre
@@ -60,7 +61,13 @@ def __init__(
         self.logits_processor = LogitsProcessor(
             vllm_config.model_config.hf_config.vocab_size,
             logits_as_input=True)
-        self.sampler = get_sampler()
+
+        try:
+            # Temporary backwards compatibility for 0.10.2
+            from vllm.model_executor.layers.sampler import get_sampler
+            self.sampler = get_sampler()
+        except (ImportError, ModuleNotFoundError):
+            self.sampler = Sampler()

         # boolean tensor of length batch size with indices:
         # True for unfinished sequences and
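The try/except above prefers the legacy `get_sampler()` while it exists and falls back to the v1 `Sampler` once the old module is gone. One way to exercise the fallback branch without installing an older vLLM, sketched as a hypothetical pytest test (not part of this PR), is to hide the legacy module during import:

```python
import builtins


def test_sampler_fallback(monkeypatch):
    """Force the ImportError path so the v1 Sampler() branch runs."""
    real_import = builtins.__import__

    def hide_legacy(name, *args, **kwargs):
        if name == "vllm.model_executor.layers.sampler":
            raise ImportError(name)
        return real_import(name, *args, **kwargs)

    monkeypatch.setattr(builtins, "__import__", hide_legacy)
    # Constructing the model under this patch takes the Sampler() branch.
```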