|
4 | 4 | """ |
5 | 5 |
|
6 | 6 | import pytest |
7 | | -from spyre_util import (generate_spyre_vllm_output, get_chicken_soup_prompts, |
8 | | - get_spyre_model_list) |
| 7 | +from openai import BadRequestError |
| 8 | +from spyre_util import (RemoteOpenAIServer, generate_spyre_vllm_output, |
| 9 | + get_chicken_soup_prompts, get_spyre_model_list) |
9 | 10 | from vllm import SamplingParams |
10 | 11 |
|
11 | 12 |
|
@@ -44,6 +45,64 @@ def test_cb_max_tokens( |
44 | 45 | monkeypatch=monkeypatch) |
45 | 46 |
|
46 | 47 |
|
@pytest.mark.cb
@pytest.mark.parametrize("cb", [True])
@pytest.mark.parametrize("model", get_spyre_model_list())
@pytest.mark.parametrize("max_model_len", [256])
@pytest.mark.parametrize("max_num_seqs", [2])
@pytest.mark.parametrize(
    "backend", [pytest.param("eager", marks=pytest.mark.cpu, id="eager")])
def test_api_cb_rejects_oversized_request(
    remote_openai_server: RemoteOpenAIServer,
    model: str,
    backend: str,
    cb: bool,
    max_model_len: int,
    max_num_seqs: int,
):
    """Verify the API rejects requests that exceed max_model_len with
    continuous batching (CB) enabled.

    The server is expected to answer with an HTTP 400 (surfaced by the
    OpenAI client as BadRequestError) rather than truncating or crashing.
    """

    client = remote_openai_server.get_client()
    # Repeating a short word max_model_len times guarantees the tokenized
    # prompt is at least max_model_len tokens, so prompt + max_tokens
    # necessarily overflows the model's context window.
    overflow_prompt = " ".join(["hi"] * max_model_len)
    max_tokens = 10

    # Match only the stable prefix of the server's error message so the
    # test does not break on wording changes after the context-length value.
    with pytest.raises(BadRequestError,
                       match="This model's maximum context length is"):
        client.completions.create(
            model=model,
            prompt=overflow_prompt,
            max_tokens=max_tokens,
        )
| 76 | + |
| 77 | + |
@pytest.mark.cb
@pytest.mark.parametrize("cb", [True])
@pytest.mark.parametrize("model", get_spyre_model_list())
@pytest.mark.parametrize("max_model_len", [256])
@pytest.mark.parametrize("max_num_seqs", [2])
@pytest.mark.parametrize(
    "backend", [pytest.param("eager", marks=pytest.mark.cpu, id="eager")])
def test_api_cb_generates_correct_max_tokens(
    remote_openai_server: RemoteOpenAIServer,
    model: str,
    backend: str,
    cb: bool,
    max_model_len: int,
    max_num_seqs: int,
):
    """Verify API generates the correct numbers of tokens with CB enabled"""

    requested_tokens = 10
    client = remote_openai_server.get_client()

    # Greedy decoding (temperature=0) with a single prompt; the server
    # should stop after exactly `requested_tokens` completion tokens.
    response = client.completions.create(
        model=model,
        prompt=get_chicken_soup_prompts(1),
        max_tokens=requested_tokens,
        temperature=0,
    )

    assert response.usage.completion_tokens == requested_tokens
| 104 | + |
| 105 | + |
47 | 106 | @pytest.mark.cb |
48 | 107 | @pytest.mark.spyre |
49 | 108 | @pytest.mark.xfail # TODO: remove once a spyre-base image supports this |
|
0 commit comments