diff --git a/tests/e2e/test_spyre_cb.py b/tests/e2e/test_spyre_cb.py
index ed6e7e893..150482331 100644
--- a/tests/e2e/test_spyre_cb.py
+++ b/tests/e2e/test_spyre_cb.py
@@ -4,8 +4,9 @@
 """
 
 import pytest
-from spyre_util import (generate_spyre_vllm_output, get_chicken_soup_prompts,
-                        get_spyre_model_list)
+from openai import BadRequestError
+from spyre_util import (RemoteOpenAIServer, generate_spyre_vllm_output,
+                        get_chicken_soup_prompts, get_spyre_model_list)
 from vllm import SamplingParams
 
 
@@ -44,6 +45,64 @@ def test_cb_max_tokens(
                                monkeypatch=monkeypatch)
 
 
+@pytest.mark.cb
+@pytest.mark.parametrize("cb", [True])
+@pytest.mark.parametrize("model", get_spyre_model_list())
+@pytest.mark.parametrize("max_model_len", [256])
+@pytest.mark.parametrize("max_num_seqs", [2])
+@pytest.mark.parametrize(
+    "backend", [pytest.param("eager", marks=pytest.mark.cpu, id="eager")])
+def test_api_cb_rejects_oversized_request(
+        remote_openai_server: RemoteOpenAIServer,
+        model: str,
+        backend: str,
+        cb: bool,
+        max_model_len: int,
+        max_num_seqs: int,
+):
+    """Verify API rejects requests that exceed max_model_len with CB enabled"""
+
+    client = remote_openai_server.get_client()
+    overflow_prompt = " ".join(["hi"] * max_model_len)
+    max_tokens = 10
+
+    with pytest.raises(BadRequestError,
+                       match="This model's maximum context length is"):
+        client.completions.create(
+            model=model,
+            prompt=overflow_prompt,
+            max_tokens=max_tokens,
+        )
+
+
+@pytest.mark.cb
+@pytest.mark.parametrize("cb", [True])
+@pytest.mark.parametrize("model", get_spyre_model_list())
+@pytest.mark.parametrize("max_model_len", [256])
+@pytest.mark.parametrize("max_num_seqs", [2])
+@pytest.mark.parametrize(
+    "backend", [pytest.param("eager", marks=pytest.mark.cpu, id="eager")])
+def test_api_cb_generates_correct_max_tokens(
+        remote_openai_server: RemoteOpenAIServer,
+        model: str,
+        backend: str,
+        cb: bool,
+        max_model_len: int,
+        max_num_seqs: int,
+):
+    """Verify API generates the correct number of tokens with CB enabled"""
+
+    client = remote_openai_server.get_client()
+    max_tokens = 10
+
+    response = client.completions.create(model=model,
+                                         prompt=get_chicken_soup_prompts(1),
+                                         max_tokens=max_tokens,
+                                         temperature=0)
+
+    assert response.usage.completion_tokens == max_tokens
+
+
 @pytest.mark.cb
 @pytest.mark.spyre
 @pytest.mark.xfail  # TODO: remove once a spyre-base image supports this