93 changes: 90 additions & 3 deletions tests/e2e/test_spyre_cb.py
@@ -5,9 +5,11 @@

import pytest
from openai import BadRequestError
-from spyre_util import (RemoteOpenAIServer, generate_spyre_vllm_output,
-                        get_chicken_soup_prompts, get_spyre_model_list)
-from vllm import SamplingParams
+from spyre_util import (RemoteOpenAIServer, create_text_prompt,
+                        force_engine_shutdown, generate_spyre_vllm_output,
+                        get_chicken_soup_prompts, get_spyre_backend_list,
+                        get_spyre_model_list, skip_unsupported_tp_size)
+from vllm import LLM, SamplingParams


@pytest.mark.cb
@@ -155,3 +157,88 @@ def test_continuous_batching_with_long_contexts(model, monkeypatch):
        # As long as no sequences have top candidate tokens with very close
        # logprobs, the generated text should be identical.
        assert vllm_cpu_results[i]["text"] == vllm_spyre_results[i]["text"]


@pytest.mark.cb
@pytest.mark.parametrize("model", get_spyre_model_list())
@pytest.mark.parametrize("backend", get_spyre_backend_list())
@pytest.mark.parametrize(
    "tp_size",
    [
        pytest.param(4, marks=pytest.mark.multi),
    ],
    ids=lambda val: f"TP({val})",
)
def test_long_context_batches(
    model: str,
    backend: str,
    tp_size: int,
    monkeypatch: pytest.MonkeyPatch,
):
"""Tests continuous batching with various batch sizes and prompt lengths."""

monkeypatch.setenv("VLLM_SPYRE_USE_CB", "1")
monkeypatch.setenv("VLLM_SPYRE_DYNAMO_BACKEND", backend)
monkeypatch.setenv("VLLM_SPYRE_OVERRIDE_SIGNALS_HANDLER", "1")

max_model_len = 32768
max_num_seqs = 32
max_tokens = 10
    # (batch_size, prompt_len) pairs: longer prompts are paired with smaller
    # batches, so every case stays within max_num_seqs and max_model_len.
    cases = [
        (32, 512),
        (16, 1500),
        (8, 3000),
        (4, 5000),
        (2, 9000),
        (1, 17000),
    ]

    skip_unsupported_tp_size(tp_size, backend)

    vllm_model = LLM(
        model=model,
        tokenizer=model,
        max_model_len=max_model_len,
        max_num_seqs=max_num_seqs,
        block_size=max_model_len,
        tensor_parallel_size=tp_size,
    )

    sampling_params = SamplingParams(
        max_tokens=max_tokens,
        temperature=0,
        ignore_eos=True,
        logprobs=0,
    )

    for batch_size, prompt_len in cases:
        prompt = create_text_prompt(model,
                                    min_tokens=prompt_len,
                                    max_tokens=prompt_len + 1)
        prompts = [prompt] * batch_size

        vllm_outputs = vllm_model.generate(prompts, sampling_params)

        results = []
        for req_output in vllm_outputs:
            # Keep only valid (non-negative) token ids.
            token_ids = [t for t in req_output.outputs[0].token_ids if t >= 0]
            results.append({
                "text":
                req_output.outputs[0].text,
                "token_ids":
                tuple(token_ids),
                "tokens":
                tuple([
                    req_output.outputs[0].logprobs[i][t].decoded_token
                    for i, t in enumerate(token_ids)
                ]),
                "logprobs":
                tuple([
                    req_output.outputs[0].logprobs[i][t].logprob
                    for i, t in enumerate(token_ids)
                ]),
            })

        assert len(results) == batch_size

    force_engine_shutdown(vllm_model)
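
For readers tracing the logprob indexing in the inner loop above: with logprobs=0, vLLM attaches, for each generated token, a dict mapping token id to a Logprob object that carries .logprob and .decoded_token for at least the sampled token. A minimal sketch of walking that structure, assuming the vllm_model and sampling_params defined in the test (the prompt string here is arbitrary):

# Minimal sketch, not part of the test: inspect per-token logprobs of one
# request, reusing vllm_model and sampling_params from the test above.
request_output = vllm_model.generate(["Tell me a story."], sampling_params)[0]
completion = request_output.outputs[0]
for step, token_id in enumerate(completion.token_ids):
    # completion.logprobs[step] is a dict {token_id: Logprob}; the sampled
    # token is always present, so indexing by token_id is safe here.
    info = completion.logprobs[step][token_id]
    print(step, repr(info.decoded_token), info.logprob)

The results dict built in the test collects the same fields (text, token ids, decoded tokens, logprobs) per request, mirroring the result structure compared elsewhere in this file.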