Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ jobs:
- name: "worker and utils"
markers: "not e2e"
flags: "--timeout=300"
- name: "compatibility"
markers: "compat"
flags: "--timeout=300"

name: "${{ matrix.test_suite.name }} (${{ matrix.vllm_version.name }})"

Expand Down Expand Up @@ -90,6 +93,7 @@ jobs:
if: (steps.changed-src-files.outputs.any_changed == 'true' && matrix.vllm_version.repo)
run: |
uv add ${{ matrix.vllm_version.repo }}
echo "TEST_VLLM_VERSION=main" >> "$GITHUB_ENV"

- name: "Install vLLM with Spyre plugin"
if: steps.changed-src-files.outputs.any_changed == 'true'
Expand Down
25 changes: 25 additions & 0 deletions tests/utils/test_upstream_compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import os

import pytest

# Module-level pytest marker: applies the "compat" mark to every test in
# this file so the suite can be selected by marker expression.
pytestmark = pytest.mark.compat

# Identifier of the vLLM build under test, read from the environment
# variable TEST_VLLM_VERSION; falls back to "default" when it is unset.
VLLM_VERSION = os.environ.get("TEST_VLLM_VERSION", "default")


def test_vllm_bert_support(monkeypatch):
    '''
    Test if the vllm version under test already has Bert support for V1.

    Reads the ``supports_v0_only`` attribute of ``BertEmbeddingModel``
    (treated as ``False`` when the attribute is absent) and checks it
    against the expectation for the version selected by ``VLLM_VERSION``:

    * ``"main"``: the attribute must be falsy.
    * anything else: the attribute must be truthy; a failure here signals
      that the compatibility workarounds can be removed.
    '''

    # Imported lazily so that collecting this module does not require
    # importing the vLLM model code.
    from vllm.model_executor.models.bert import BertEmbeddingModel

    bert_supports_v0_only = getattr(BertEmbeddingModel, "supports_v0_only",
                                    False)

    if VLLM_VERSION == "main":
        assert not bert_supports_v0_only, (
            "BertEmbeddingModel still reports supports_v0_only on the "
            "vLLM main branch.")
    else:
        # Adjacent string literals are concatenated verbatim, so the
        # separating space must be part of one of the fragments.
        assert bert_supports_v0_only, (
            "The currently supported vLLM version already "
            "supports Bert in V1. Remove the compatibility workarounds.")
9 changes: 7 additions & 2 deletions vllm_spyre/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
if sys.platform.startswith("darwin"):
if sys.modules.get('triton'):
del sys.modules['triton']

import inspect
import math
import operator
import os
Expand Down Expand Up @@ -71,8 +71,13 @@ class SpyrePlatform(Platform):
def device_type(cls):
# Returns cls._device_type. As a side effect it may replace
# ModelConfig.is_v1_compatible with the `is_v1_compatible` name that is
# in scope here (defined outside this view).
# TODO: temporary hack while BertModels
# inherit SupportsV0Only in vllm upstream.
import vllm.model_executor.models as me_models
from vllm.config import ModelConfig
# NOTE(review): this text comes from a rendered diff without +/- markers;
# the unconditional assignment below looks like the pre-change line that
# the guarded assignment further down replaces. Confirm against the file.
ModelConfig.is_v1_compatible = is_v1_compatible

# no need to patch after the model_config change
# The guard inspects the argument names of ModelRegistry.is_v1_compatible
# and only applies the patch when 'model_config' is not one of them.
if 'model_config' not in \
inspect.getfullargspec(me_models.ModelRegistry.is_v1_compatible).args:
ModelConfig.is_v1_compatible = is_v1_compatible
return cls._device_type

@classmethod
Expand Down
Loading