File tree Expand file tree Collapse file tree 2 files changed +31
-2
lines changed Expand file tree Collapse file tree 2 files changed +31
-2
lines changed Original file line number Diff line number Diff line change @@ -87,6 +87,24 @@ def test_pooler_from_config():
8787 # https://github.com/vllm-project/vllm-spyre/pull/338
8888
8989
90+ @pytest .mark .cpu
91+ def test_pooler_default_args ():
92+
93+ from vllm .model_executor .layers .pooler import Pooler
94+ has_from_config = hasattr (Pooler , "from_config_with_defaults" )
95+
96+ if not has_from_config :
97+ annotations = inspect .getfullargspec (Pooler .for_embed ).annotations
98+ if VLLM_VERSION == "vLLM:main" :
99+ assert 'default_normalize' not in annotations
100+ assert 'default_softmax' not in annotations
101+ elif VLLM_VERSION == "vLLM:lowest" :
102+ assert 'default_normalize' in annotations
103+ assert 'default_softmax' in annotations
104+ # The compat code introduced in the PR below can now be removed:
105+ # https://github.com/vllm-project/vllm-spyre/pull/361
106+
107+
90108@pytest .mark .cpu
91109def test_engine_core_add_request ():
92110
Original file line number Diff line number Diff line change 1+ import inspect
12import math
23import time
34from abc import ABC , abstractmethod
@@ -1333,11 +1334,21 @@ def __init__(
13331334 normalize = True ,
13341335 softmax = False )
13351336 else :
1337+ # TODO: remove this when we no longer support vllm version pre this
1338+ # PR https://github.com/vllm-project/vllm/pull/20538 (post v0.10.0)
1339+ annotations = inspect .getfullargspec (Pooler .for_embed ).annotations
1340+ if ('default_normalize' in annotations
1341+ and 'default_softmax' in annotations ):
1342+ extra_args = {
1343+ 'default_normalize' : True ,
1344+ 'default_softmax' : False
1345+ }
1346+ else :
1347+ extra_args = {}
13361348 self .pooler = Pooler .for_embed (
13371349 pooler_config = pooler_config ,
13381350 default_pooling_type = PoolingType .CLS ,
1339- default_normalize = True ,
1340- default_softmax = False )
1351+ ** extra_args )
13411352
13421353 def build_input_batch (self ) -> PoolingInputBatch :
13431354 return PoolingInputBatch (
You can’t perform that action at this time.
0 commit comments