44"""
55
66import pytest
7- from spyre_util import (VLLM_VERSIONS , compare_results , create_random_request ,
7+ from spyre_util import (compare_results , create_random_request ,
88 generate_hf_output , generate_spyre_vllm_output ,
99 get_spyre_backend_list , get_spyre_model_list )
1010from vllm import EngineArgs , SamplingParams
3333 "warmup_shape" , [(64 , 20 , 4 ), (64 , 20 , 8 ), (128 , 20 , 4 ),
3434 (128 , 20 , 8 )]) # (prompt_length/new_tokens/batch_size)
3535@pytest .mark .parametrize ("backend" , get_spyre_backend_list ())
36- @pytest .mark .parametrize ("vllm_version" , VLLM_VERSIONS )
37- def test_output (
38- model : str ,
39- prompts : list [str ],
40- warmup_shape : tuple [int , int , int ],
41- backend : str ,
42- vllm_version : str ,
43- ) -> None :
36+ def test_output (model : str , prompts : list [str ],
37+ warmup_shape : tuple [int , int , int ], backend : str ) -> None :
4438 '''
4539 The warmup is based on a single shape. After the warmup,
4640 one request with the provided prompts is input to vLLM.
@@ -70,8 +64,7 @@ def test_output(
         block_size=2048,
         sampling_params=vllm_sampling_params,
         tensor_parallel_size=1,
-        backend=backend,
-        vllm_version=vllm_version)
+        backend=backend)
 
     hf_results = generate_hf_output(model=model,
                                     prompts=prompts,
@@ -94,13 +87,11 @@ def test_output(
 @pytest.mark.parametrize(
     "warmup_shape", [(64, 20, 4)])  # (prompt_length/new_tokens/batch_size)
 @pytest.mark.parametrize("backend", ["sendnn_decoder"])
-@pytest.mark.parametrize("vllm_version", VLLM_VERSIONS)
 def test_output_sendnn_decoder(
     model: str,
     prompts: list[str],
     warmup_shape: tuple[int, int, int],
     backend: str,
-    vllm_version: str,
 ) -> None:
     '''
     Tests the deprecated sendnn_decoder backend, which should fall-back to
@@ -123,8 +114,7 @@ def test_output_sendnn_decoder(
         block_size=2048,
         sampling_params=vllm_sampling_params,
         tensor_parallel_size=1,
-        backend=backend,
-        vllm_version=vllm_version)
+        backend=backend)
 
     hf_results = generate_hf_output(model=model,
                                     prompts=prompts,
@@ -141,11 +131,9 @@ def test_output_sendnn_decoder(
 
 @pytest.mark.parametrize("model", get_spyre_model_list())
 @pytest.mark.parametrize("backend", get_spyre_backend_list())
-@pytest.mark.parametrize("vllm_version", VLLM_VERSIONS)
 def test_batch_handling(
     model: str,
     backend: str,
-    vllm_version: str,
 ):
     """Test that the spyre worker correctly handles batches of requests that
     finish after different numbers of forward passes"""
@@ -178,8 +166,7 @@ def test_batch_handling(
         block_size=2048,
         sampling_params=vllm_sampling_params,
         tensor_parallel_size=1,
-        backend=backend,
-        vllm_version=vllm_version)
+        backend=backend)
 
     assert vllm_results[0]["text"] == " 3 2 "
     assert vllm_results[1]["text"] == " 6 5 4 3 2 "
@@ -189,10 +176,7 @@ def test_batch_handling(
 
 @pytest.mark.parametrize("model", get_spyre_model_list())
 @pytest.mark.parametrize("backend", get_spyre_backend_list())
-@pytest.mark.parametrize("vllm_version",
-                         [pytest.param("V1", marks=pytest.mark.v1, id="v1")])
-def test_full_batch_scheduling(model: str, backend: str, vllm_version: str,
-                               monkeypatch):
+def test_full_batch_scheduling(model: str, backend: str, monkeypatch):
     """Test that we can schedule a full batch of prompts."""
 
     # We need to ensure here that the max number of tokens in a full batch
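For orientation, here is a minimal sketch of the post-change test flow: generate output through vLLM on the selected Spyre backend, generate a Hugging Face reference, and compare, with no vllm_version parameter anywhere. Only the keyword arguments visible in the hunks above are taken from the source; the prompt values and the argument names not shown there (warmup_shapes, max_new_tokens) are assumptions for illustration, and the compare_results call is left as a comment because its signature does not appear in this diff.

# Hedged sketch only -- not the repository's exact test. Names marked
# "assumed" below do not appear in the diff above.
import pytest
from spyre_util import (generate_hf_output, generate_spyre_vllm_output,
                        get_spyre_backend_list, get_spyre_model_list)
from vllm import SamplingParams


@pytest.mark.parametrize("model", get_spyre_model_list())
@pytest.mark.parametrize("backend", get_spyre_backend_list())
def test_output_sketch(model: str, backend: str) -> None:
    prompts = ["Hello"]         # illustrative prompt
    warmup_shape = (64, 20, 4)  # (prompt_length, new_tokens, batch_size)
    sampling_params = SamplingParams(max_tokens=warmup_shape[1], temperature=0)

    # Spyre-backed vLLM generation; the call no longer takes vllm_version.
    vllm_results = generate_spyre_vllm_output(
        model=model,
        prompts=prompts,
        warmup_shapes=[warmup_shape],  # assumed keyword
        block_size=2048,
        sampling_params=sampling_params,
        tensor_parallel_size=1,
        backend=backend)

    # Hugging Face reference generation for the same prompts.
    hf_results = generate_hf_output(model=model,
                                    prompts=prompts,
                                    max_new_tokens=warmup_shape[1])  # assumed

    # Sanity check; the real tests call compare_results(...) from spyre_util
    # to check the vLLM text and logprobs against the Hugging Face reference.
    assert len(vllm_results) == len(hf_results) == len(prompts)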