70 changes: 69 additions & 1 deletion tests/e2e/test_spyre_cb.py
@@ -3,10 +3,13 @@
Run `python -m pytest tests/e2e/test_spyre_cb.py`.
"""

import math

import pytest
from openai import BadRequestError
from spyre_util import (RemoteOpenAIServer, generate_spyre_vllm_output,
get_chicken_soup_prompts, get_spyre_model_list)
get_chicken_soup_prompts, get_spyre_backend_list,
get_spyre_model_list, skip_unsupported_tp_size)
from vllm import SamplingParams


@@ -155,3 +158,68 @@ def test_continuous_batching_with_long_contexts(model, monkeypatch):
# As long as no sequences have top candidate tokens with very close
# logprobs, the generated text should be identical.
assert vllm_cpu_results[i]["text"] == vllm_spyre_results[i]["text"]


@pytest.mark.cb
@pytest.mark.parametrize("model", get_spyre_model_list())
@pytest.mark.parametrize("backend", get_spyre_backend_list())
@pytest.mark.parametrize(
"tp_size",
[
pytest.param(4, marks=pytest.mark.multi),
],
ids=lambda val: f"TP({val})",
)
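# Note: batch_size is scaled down as prompt_len grows, so every case exercises
# a comparable total prompt load (roughly 16k-24k words per batch).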
@pytest.mark.parametrize(
"batch_size,prompt_len",
[
(32, 512),
(16, 1500),
(8, 3000),
(4, 5000),
(2, 9000),
(1, 17000),
],
)
def test_long_context_batches(
model: str,
backend: str,
tp_size: int,
batch_size: int,
prompt_len: int,
monkeypatch: pytest.MonkeyPatch,
):
"""Tests continuous batching with various batch sizes and prompt lengths."""

skip_unsupported_tp_size(tp_size, backend)

max_model_len = 32768
max_num_seqs = 32
max_tokens = 10
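# The longest case (a single 17000-word prompt plus 10 generated tokens)
# should still fit within the 32k-token context window, though the tokenized
# length will exceed the raw word count.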

# Repeat the base prompt until it is at least `prompt_len` words long
base_prompt = get_chicken_soup_prompts(1)[0]
words = base_prompt.split()
prompt = " ".join(words * math.ceil(prompt_len / len(words)))
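# For example, with prompt_len=512 and a hypothetical 60-word base prompt,
# the word list is repeated ceil(512 / 60) = 9 times, yielding 540 words.
# Word count is only a rough proxy for token count.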

prompts = [prompt] * batch_size

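# Greedy decoding (temperature=0) keeps the outputs deterministic; ignore_eos
# forces every sequence to generate the full max_tokens; logprobs=0 requests
# the logprob of the sampled token only.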
sampling_params = SamplingParams(max_tokens=max_tokens,
temperature=0,
ignore_eos=True,
logprobs=0)

results = generate_spyre_vllm_output(
model=model,
prompts=prompts,
max_model_len=max_model_len,
block_size=max_model_len,
sampling_params=sampling_params,
tensor_parallel_size=tp_size,
backend=backend,
max_num_seqs=max_num_seqs,
use_cb=True,
monkeypatch=monkeypatch,
)

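# Each prompt in the batch should yield exactly one completion.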
assert len(results) == batch_size