Skip to content

Commit 4d98151

Browse files
authored
Add CB API tests on the correct use of max_tokens (#339)
# Description Add tests to check whether the API generates the correct number of tokens when CB is enabled. ## Related Issues --------- Signed-off-by: Gabriel Marinho <[email protected]>
1 parent dd8d6e7 commit 4d98151

File tree

1 file changed

+61
-2
lines changed

1 file changed

+61
-2
lines changed

tests/e2e/test_spyre_cb.py

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
"""
55

66
import pytest
7-
from spyre_util import (generate_spyre_vllm_output, get_chicken_soup_prompts,
8-
get_spyre_model_list)
7+
from openai import BadRequestError
8+
from spyre_util import (RemoteOpenAIServer, generate_spyre_vllm_output,
9+
get_chicken_soup_prompts, get_spyre_model_list)
910
from vllm import SamplingParams
1011

1112

@@ -44,6 +45,64 @@ def test_cb_max_tokens(
4445
monkeypatch=monkeypatch)
4546

4647

48+
@pytest.mark.cb
@pytest.mark.parametrize("cb", [True])
@pytest.mark.parametrize("model", get_spyre_model_list())
@pytest.mark.parametrize("max_model_len", [256])
@pytest.mark.parametrize("max_num_seqs", [2])
@pytest.mark.parametrize(
    "backend", [pytest.param("eager", marks=pytest.mark.cpu, id="eager")])
def test__api_cb_rejects_oversized_request(
    remote_openai_server: RemoteOpenAIServer,
    model: str,
    backend: str,
    cb: bool,
    max_model_len: int,
    max_num_seqs: int,
):
    """With continuous batching enabled, the server must reject a completion
    request whose prompt overflows ``max_model_len`` with a 400 error."""

    # max_model_len whitespace-separated words guarantees the prompt plus
    # the requested completion exceeds the context window.
    too_long = " ".join("hi" for _ in range(max_model_len))
    client = remote_openai_server.get_client()

    with pytest.raises(BadRequestError,
                       match="This model's maximum context length is"):
        client.completions.create(model=model,
                                  prompt=too_long,
                                  max_tokens=10)
76+
77+
78+
@pytest.mark.cb
@pytest.mark.parametrize("cb", [True])
@pytest.mark.parametrize("model", get_spyre_model_list())
@pytest.mark.parametrize("max_model_len", [256])
@pytest.mark.parametrize("max_num_seqs", [2])
@pytest.mark.parametrize(
    "backend", [pytest.param("eager", marks=pytest.mark.cpu, id="eager")])
def test_api_cb_generates_correct_max_tokens(
    remote_openai_server: RemoteOpenAIServer,
    model: str,
    backend: str,
    cb: bool,
    max_model_len: int,
    max_num_seqs: int,
):
    """With continuous batching enabled, the server must emit exactly
    ``max_tokens`` completion tokens (per the reported usage stats)."""

    requested = 10
    # temperature=0 keeps generation deterministic for this check.
    completion = remote_openai_server.get_client().completions.create(
        model=model,
        prompt=get_chicken_soup_prompts(1),
        max_tokens=requested,
        temperature=0)

    assert completion.usage.completion_tokens == requested
104+
105+
47106
@pytest.mark.cb
48107
@pytest.mark.spyre
49108
@pytest.mark.xfail # TODO: remove once a spyre-base image supports this

0 commit comments

Comments
 (0)