88from typing import Any
99
1010import pytest
11- < << << << HEAD
1211from spyre_util import (compare_results , create_random_request ,
1312 generate_hf_output , generate_spyre_vllm_output ,
14- get_spyre_model_list )
15- == == == =
16- from spyre_util import (create_random_request , generate_cb_spyre_vllm_output ,
1713 get_spyre_backend_list , get_spyre_model_list )
18- > >> >> >> origin / main
1914from vllm import EngineArgs , SamplingParams
2015from vllm .v1 .engine import EngineCoreRequest
2116from vllm .v1 .engine .core import EngineCore
2823 "appropriately completes the request. Be polite in your response to the "
2924 "user.\n \n ### Instruction:\n {}\n \n ### Response:" )
3025
31- < << << << HEAD
3226
3327@pytest .mark .cb
3428@pytest .mark .parametrize ("max_num_seqs" , [2 , 3 , 4 ],
3529 ids = lambda val : f"max_num_seqs({ val } )" )
3630@pytest .mark .parametrize ("model" , get_spyre_model_list ())
37- @pytest .mark .parametrize (
38- "backend" , [pytest .param ("eager" , marks = pytest .mark .cpu , id = "eager" )])
39- # commenting v1 since we don't want this test to run with v1 marker yet
40- # @pytest.mark.v1
31+ @pytest .mark .parametrize ("backend" , get_spyre_backend_list ())
4132@pytest .mark .parametrize ("prompts" , [[
4233 template .format ("Provide a list of instructions "
4334 "for preparing chicken soup." ),
4738 "how do I add multiple new columns in m for power query or power bi?" ),
4839 template .format ("Convert char to string in Java." ),
4940]])
50- == == == =
51- @pytest .mark .cb
52- @pytest .mark .v1
53- @pytest .mark .parametrize ("max_num_seqs" , [2 , 3 , 4 ],
54- ids = lambda val : f"max_num_seqs({ val } )" )
55- @pytest .mark .parametrize ("model" , get_spyre_model_list ())
56- @pytest .mark .parametrize ("backend" , get_spyre_backend_list ())
57- @pytest .mark .parametrize (
58- "prompts" ,
59- [
60- [
61- "7 6 5 4" ,
62- "10 9 8 7" ,
63- ],
64- [
65- "7 6 5 4" ,
66- "10 9 8 7" ,
67- "8 7 6 5" ,
68- ],
69- [
70- "7 6 5 4" ,
71- "10 9 8 7" ,
72- "8 7 6 5" ,
73- "9 8 7 6" ,
74- ],
75- ],
76- ids = lambda val : f"num_prompts({ len (val )} )" ,
77- )
78- > >> >> >> origin / main
7941def test_cb_handling (
8042 model : str ,
8143 backend : str ,
@@ -107,7 +69,6 @@ def test_cb_handling(
10769 backend = backend ,
10870 max_num_seqs = max_num_seqs ,
10971 use_cb = True ,
110- vllm_version = "V1" , # CB runs in V1 only
11172 monkeypatch = monkeypatch )
11273
11374 hf_results = generate_hf_output (model = model ,
@@ -124,7 +85,6 @@ def test_cb_handling(
12485
12586
12687@pytest .mark .cb
127- # @pytest.mark.v1
12888@pytest .mark .parametrize ("max_num_seqs" , [2 ])
12989@pytest .mark .parametrize ("model" , get_spyre_model_list ())
13090@pytest .mark .parametrize (
@@ -149,18 +109,16 @@ def test_cb_max_tokens(
149109 logprobs = 0 )
150110
151111 with pytest .raises (ValueError , match = "max model context length" ):
152- generate_spyre_vllm_output (
153- model = model ,
154- prompts = overflow_prompt ,
155- max_model_len = max_model_len ,
156- block_size = max_model_len ,
157- sampling_params = vllm_sampling_params ,
158- tensor_parallel_size = 1 ,
159- backend = backend ,
160- max_num_seqs = max_num_seqs ,
161- use_cb = True ,
162- vllm_version = "V1" , # CB runs in V1 only
163- monkeypatch = monkeypatch )
112+ generate_spyre_vllm_output (model = model ,
113+ prompts = overflow_prompt ,
114+ max_model_len = max_model_len ,
115+ block_size = max_model_len ,
116+ sampling_params = vllm_sampling_params ,
117+ tensor_parallel_size = 1 ,
118+ backend = backend ,
119+ max_num_seqs = max_num_seqs ,
120+ use_cb = True ,
121+ monkeypatch = monkeypatch )
164122
165123
166124def get_params_test_blocks_borders_aligned_prompts ():
@@ -683,7 +641,6 @@ def augment_checked_steps(
683641
684642
685643@pytest .mark .cb
686- @pytest .mark .v1
687644@pytest .mark .parametrize ("model" , get_spyre_model_list ())
688645@pytest .mark .parametrize ("backend" , get_spyre_backend_list ())
689646@pytest .mark .parametrize ("max_num_seqs" , [2 ])
0 commit comments