Skip to content

Commit e8ca058

Browse files
authored
[embedding] support newest vllm main branch (#361)
### [embedding] support newest vllm main branch parameters `default_normalize` and `default_softmax` are deprecated --------- Signed-off-by: Yannick Schnider <[email protected]>
1 parent 8fab64f commit e8ca058

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

tests/utils/test_upstream_compatibility.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,24 @@ def test_pooler_from_config():
8787
# https://github.com/vllm-project/vllm-spyre/pull/338
8888

8989

90+
@pytest.mark.cpu
91+
def test_pooler_default_args():
92+
93+
from vllm.model_executor.layers.pooler import Pooler
94+
has_from_config = hasattr(Pooler, "from_config_with_defaults")
95+
96+
if not has_from_config:
97+
annotations = inspect.getfullargspec(Pooler.for_embed).annotations
98+
if VLLM_VERSION == "vLLM:main":
99+
assert 'default_normalize' not in annotations
100+
assert 'default_softmax' not in annotations
101+
elif VLLM_VERSION == "vLLM:lowest":
102+
assert 'default_normalize' in annotations
103+
assert 'default_softmax' in annotations
104+
# The compat code introduced in the PR below can now be removed:
105+
# https://github.com/vllm-project/vllm-spyre/pull/361
106+
107+
90108
@pytest.mark.cpu
91109
def test_engine_core_add_request():
92110

vllm_spyre/v1/worker/spyre_model_runner.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import inspect
12
import math
23
import time
34
from abc import ABC, abstractmethod
@@ -1333,11 +1334,21 @@ def __init__(
13331334
normalize=True,
13341335
softmax=False)
13351336
else:
1337+
# TODO: remove this when we no longer support vllm version pre this
1338+
# PR https://github.com/vllm-project/vllm/pull/20538 (post v0.10.0)
1339+
annotations = inspect.getfullargspec(Pooler.for_embed).annotations
1340+
if ('default_normalize' in annotations
1341+
and 'default_softmax' in annotations):
1342+
extra_args = {
1343+
'default_normalize': True,
1344+
'default_softmax': False
1345+
}
1346+
else:
1347+
extra_args = {}
13361348
self.pooler = Pooler.for_embed(
13371349
pooler_config=pooler_config,
13381350
default_pooling_type=PoolingType.CLS,
1339-
default_normalize=True,
1340-
default_softmax=False)
1351+
**extra_args)
13411352

13421353
def build_input_batch(self) -> PoolingInputBatch:
13431354
return PoolingInputBatch(

0 commit comments

Comments
 (0)