44"""
55
66import pytest
7- from spyre_util import (VLLM_VERSIONS , compare_results , create_random_request ,
7+ from spyre_util import (compare_results , create_random_request ,
88 generate_hf_output , generate_spyre_vllm_output ,
99 get_spyre_backend_list , get_spyre_model_list )
1010from vllm import EngineArgs , SamplingParams
3333 "warmup_shape" , [(64 , 20 , 4 ), (64 , 20 , 8 ), (128 , 20 , 4 ),
3434 (128 , 20 , 8 )]) # (prompt_length/new_tokens/batch_size)
3535@pytest .mark .parametrize ("backend" , get_spyre_backend_list ())
36- @pytest .mark .parametrize ("vllm_version" , VLLM_VERSIONS )
3736def test_output (
3837 model : str ,
3938 prompts : list [str ],
4039 warmup_shape : tuple [int , int , int ],
4140 backend : str ,
42- vllm_version : str ,
4341 monkeypatch : pytest .MonkeyPatch ,
4442) -> None :
4543 '''
@@ -72,7 +70,6 @@ def test_output(
         sampling_params=vllm_sampling_params,
         tensor_parallel_size=1,
         backend=backend,
-        vllm_version=vllm_version,
         monkeypatch=monkeypatch)
 
     hf_results = generate_hf_output(model=model,
@@ -96,13 +93,11 @@ def test_output(
 @pytest.mark.parametrize(
     "warmup_shape", [(64, 20, 4)])  # (prompt_length/new_tokens/batch_size)
 @pytest.mark.parametrize("backend", ["sendnn_decoder"])
-@pytest.mark.parametrize("vllm_version", VLLM_VERSIONS)
 def test_output_sendnn_decoder(
     model: str,
     prompts: list[str],
     warmup_shape: tuple[int, int, int],
     backend: str,
-    vllm_version: str,
     monkeypatch: pytest.MonkeyPatch,
 ) -> None:
     '''
@@ -127,7 +122,6 @@ def test_output_sendnn_decoder(
         sampling_params=vllm_sampling_params,
         tensor_parallel_size=1,
         backend=backend,
-        vllm_version=vllm_version,
         monkeypatch=monkeypatch)
 
     hf_results = generate_hf_output(model=model,
@@ -145,11 +139,9 @@ def test_output_sendnn_decoder(
 
 @pytest.mark.parametrize("model", get_spyre_model_list())
 @pytest.mark.parametrize("backend", get_spyre_backend_list())
-@pytest.mark.parametrize("vllm_version", VLLM_VERSIONS)
 def test_batch_handling(
     model: str,
     backend: str,
-    vllm_version: str,
     monkeypatch: pytest.MonkeyPatch,
 ):
     """Test that the spyre worker correctly handles batches of requests that
@@ -184,7 +176,6 @@ def test_batch_handling(
         sampling_params=vllm_sampling_params,
         tensor_parallel_size=1,
         backend=backend,
-        vllm_version=vllm_version,
         monkeypatch=monkeypatch)
 
     assert vllm_results[0]["text"] == " 3 2 "
@@ -195,10 +186,7 @@ def test_batch_handling(
 
 @pytest.mark.parametrize("model", get_spyre_model_list())
 @pytest.mark.parametrize("backend", get_spyre_backend_list())
-@pytest.mark.parametrize("vllm_version",
-                         [pytest.param("V1", marks=pytest.mark.v1, id="v1")])
-def test_full_batch_scheduling(model: str, backend: str, vllm_version: str,
-                               monkeypatch):
+def test_full_batch_scheduling(model: str, backend: str, monkeypatch):
     """Test that we can schedule a full batch of prompts."""
 
     # We need to ensure here that the max number of tokens in a full batch
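
Note on what the removed decorators did: stacked @pytest.mark.parametrize decorators combine as a cross-product, so the deleted "vllm_version" axis multiplied every other parameter combination. A minimal sketch of the mechanism, using illustrative names ("engine", "shape") that are not from this repo:

    import pytest

    # Stacked parametrize decorators multiply: 2 engines x 3 shapes = 6 cases.
    @pytest.mark.parametrize("engine", ["V0", "V1"])
    @pytest.mark.parametrize("shape", [(64, 20, 4), (64, 20, 8), (128, 20, 4)])
    def test_matrix(engine: str, shape: tuple[int, int, int]) -> None:
        prompt_length, new_tokens, batch_size = shape
        assert engine in ("V0", "V1")
        assert batch_size in (4, 8)

Dropping the "engine" axis here would cut the collected cases from 6 to 3 without touching the remaining parameters, which is the effect this commit has by removing the "vllm_version" axis.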
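
Each test also threads pytest's built-in monkeypatch fixture through to generate_spyre_vllm_output, presumably so the helper can set per-test environment or configuration without leaking state across tests. A generic sketch of that pattern, assuming a hypothetical variable name (the real helper in spyre_util decides what it actually sets):

    import os
    import pytest

    def configure_backend(backend: str, monkeypatch: pytest.MonkeyPatch) -> None:
        # Hypothetical variable name, for illustration only; monkeypatch
        # undoes the change automatically when the test finishes.
        monkeypatch.setenv("SPYRE_BACKEND", backend)

    def test_backend_env(monkeypatch: pytest.MonkeyPatch) -> None:
        configure_backend("eager", monkeypatch)
        assert os.environ["SPYRE_BACKEND"] == "eager"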