tests/e2e/test_spyre_cb.py (80 additions & 2 deletions)
@@ -10,10 +10,11 @@
import pytest
from openai import BadRequestError
from spyre_util import (RemoteOpenAIServer, check_output_against_hf,
- compare_results, generate_spyre_vllm_output,
+ compare_results, create_seq_prompt, extract_output,
+ force_engine_shutdown, generate_spyre_vllm_output,
get_chicken_soup_prompts, get_spyre_model_list,
skip_unsupported_tp_size)
- from vllm import SamplingParams
+ from vllm import LLM, SamplingParams


@pytest.mark.parametrize("model", get_spyre_model_list())
@@ -163,6 +164,83 @@ def test_continuous_batching_with_long_contexts(model, monkeypatch):
assert vllm_cpu_results[i]["text"] == vllm_spyre_results[i]["text"]


@pytest.mark.cb
@pytest.mark.parametrize("model", get_spyre_model_list())
@pytest.mark.parametrize(
"backend", [pytest.param("sendnn", marks=pytest.mark.spyre, id="sendnn")])
@pytest.mark.parametrize(
"tp_size",
[
pytest.param(4, marks=pytest.mark.multi),
],
ids=lambda val: f"TP({val})",
)
def test_long_context_batches(
model: str,
backend: str,
tp_size: int,
monkeypatch: pytest.MonkeyPatch,
):
"""Tests continuous batching with various batch sizes and prompt lengths."""

skip_unsupported_tp_size(tp_size, backend)

monkeypatch.setenv("VLLM_SPYRE_USE_CB", "1")
monkeypatch.setenv("VLLM_SPYRE_DYNAMO_BACKEND", backend)
monkeypatch.setenv("VLLM_SPYRE_OVERRIDE_SIGNALS_HANDLER", "1")

max_model_len = 32768
max_num_seqs = 32
max_tokens = 10

# (batch_size, prompt_length) pairs
batch_token_pairs = [
(32, 512),
(16, 1500),
(8, 3000),
(4, 5000),
(2, 9000),
(1, 17000),
]

vllm_model = LLM(
model=model,
tokenizer=model,
max_model_len=max_model_len,
max_num_seqs=max_num_seqs,
block_size=2048,
tensor_parallel_size=tp_size,
)

sampling_params = SamplingParams(
max_tokens=max_tokens,
temperature=0,
ignore_eos=True,
logprobs=0,
)

for batch_size, token_len in batch_token_pairs:
prompt = create_seq_prompt(model, token_length=token_len)
prompts = [prompt] * batch_size

vllm_outputs = vllm_model.generate(prompts, sampling_params)

results = []
for req_output in vllm_outputs:
result = extract_output(req_output)
results.append(result)

check_output_against_hf(
model=model,
backend=backend,
max_new_tokens=max_tokens,
vllm_results=results,
prompts=prompts,
)

force_engine_shutdown(vllm_model)


@pytest.mark.spyre
@pytest.mark.cb
@pytest.mark.parametrize(
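The new test_long_context_batches test trades batch size against prompt length: as prompts grow longer, the batch shrinks, so every request stays within the engine configuration set at the top of the test. A minimal standalone sketch (not part of the PR) that checks the (batch_size, prompt_length) pairs against the limits used in the test:

# Limits copied from the test above; this check is illustrative only.
MAX_MODEL_LEN = 32768
MAX_NUM_SEQS = 32
MAX_NEW_TOKENS = 10

batch_token_pairs = [
    (32, 512),
    (16, 1500),
    (8, 3000),
    (4, 5000),
    (2, 9000),
    (1, 17000),
]

for batch_size, prompt_len in batch_token_pairs:
    # Each sequence needs room for its prompt plus the generated tokens.
    assert prompt_len + MAX_NEW_TOKENS <= MAX_MODEL_LEN
    # The scheduler never holds more than max_num_seqs sequences at once.
    assert batch_size <= MAX_NUM_SEQS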
tests/e2e/test_spyre_static_batching_limits.py (3 additions & 3 deletions)
@@ -38,9 +38,9 @@ def test_max_prompt_len_and_new_tokens(model: str,
# Craft a request with a prompt that is slightly too long for the warmup
# shape
prompt = create_text_prompt(model,
- min_tokens=max_prompt_length,
- max_tokens=max_prompt_length + max_new_tokens -
- 1)
+ min_token_length=max_prompt_length,
+ max_token_length=max_prompt_length +
+ max_new_tokens - 1)
sampling_params = SamplingParams(max_tokens=1)

with pytest.raises(ValueError, match="warmup"):
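The only functional change in this file is the keyword rename on the create_text_prompt call (min_tokens/max_tokens become min_token_length/max_token_length), presumably so the helper's arguments are not confused with the max_tokens field of SamplingParams on the very next line. A hedged sketch of the renamed call, with placeholder values rather than the warmup limits defined earlier in this test:

from spyre_util import create_text_prompt  # test helper, shown later in this diff

# Placeholder warmup limits, for illustration only.
max_prompt_length = 64
max_new_tokens = 20

# The crafted prompt tokenizes to more than max_prompt_length tokens but to
# fewer than max_prompt_length + max_new_tokens - 1, so it is slightly too
# long for the warmup shape and should raise the ValueError the test expects.
prompt = create_text_prompt("gpt2",  # placeholder tokenizer, not a Spyre model
                            min_token_length=max_prompt_length,
                            max_token_length=max_prompt_length +
                            max_new_tokens - 1)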
tests/spyre_util.py (47 additions & 18 deletions)
@@ -175,6 +175,24 @@ def patch_warmup_shapes(warmup_shapes: Union[list[tuple[int, int, int]],
','.join(str(val) for val in warmup_new_tokens))


def extract_output(req_output):
"""Extract text, token_ids, tokens, and logprobs from request output."""

result = {}
result['text'] = req_output.outputs[0].text

# TODO: Workaround for V1, if request does not fit in a warmup shape
# token_ids may be filled with -1.
token_ids = [t for t in req_output.outputs[0].token_ids if t >= 0]
result['token_ids'] = tuple(token_ids)
result['tokens'] = tuple(req_output.outputs[0].logprobs[i][t].decoded_token
for i, t in enumerate(token_ids))
result['logprobs'] = tuple(req_output.outputs[0].logprobs[i][t].logprob
for i, t in enumerate(token_ids))

return result


# vLLM / Spyre
def generate_spyre_vllm_output(
model: str,
@@ -226,20 +244,7 @@ def generate_spyre_vllm_output(
results = []

for req_output in vllm_outputs:
- result = {}
- result['text'] = req_output.outputs[0].text
- # TODO: Workaround for V1, if request does not fit in a warmup shape
- # token_ids may be filled with -1.
- token_ids = [t for t in req_output.outputs[0].token_ids if t >= 0]
- result['token_ids'] = tuple(token_ids)
- result['tokens'] = tuple([
- req_output.outputs[0].logprobs[i][t].decoded_token
- for i, t in enumerate(result['token_ids'])
- ])
- result['logprobs'] = tuple([
- req_output.outputs[0].logprobs[i][t].logprob
- for i, t in enumerate(result['token_ids'])
- ])
+ result = extract_output(req_output)
results.append(result)

force_engine_shutdown(vllm_model)
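The loop above now delegates the per-request unpacking to the new extract_output helper. A hedged sketch of the dict it returns, using SimpleNamespace stand-ins for vLLM's RequestOutput/CompletionOutput objects (the attribute names match what the helper reads; the values are made up):

from types import SimpleNamespace

from spyre_util import extract_output

# Fake objects exposing only the attributes extract_output touches.
completion = SimpleNamespace(
    text="Hello",
    token_ids=[15496],
    logprobs=[{15496: SimpleNamespace(decoded_token="Hello", logprob=-0.12)}],
)
req_output = SimpleNamespace(outputs=[completion])

result = extract_output(req_output)
# result == {'text': 'Hello', 'token_ids': (15496,),
#            'tokens': ('Hello',), 'logprobs': (-0.12,)}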
@@ -554,26 +559,50 @@ def _default_test_models(isEmbeddings=False):
return params


- def create_text_prompt(model: str, min_tokens: int, max_tokens: int) -> str:
+ def create_text_prompt(model: str, min_token_length: int,
+ max_token_length: int) -> str:
"""Create a text prompt for the specified model that will tokenize to within
the specified token length range."""
tokenizer = AutoTokenizer.from_pretrained(model)
pepper = "🌶️"
pepper_tokens = len(tokenizer.encode(pepper, add_special_tokens=False))

# Find a good starting number of peppers
- prompt = pepper * (min_tokens // pepper_tokens + 1)
+ prompt = pepper * (min_token_length // pepper_tokens + 1)

# And add more until we're over the minimum token length
- while len(tokenizer.encode(prompt)) <= min_tokens:
+ while len(tokenizer.encode(prompt)) <= min_token_length:
prompt += pepper

# Make sure this prompt is within the specified range
- assert min_tokens < len(tokenizer.encode(prompt)) < max_tokens
+ assert min_token_length < len(tokenizer.encode(prompt)) < max_token_length

return prompt


def create_seq_prompt(model: str, token_length: int) -> str:
"""Create a repeating sequential number prompt for the specified
model that will tokenize to exactly the specified token length."""

tokenizer = AutoTokenizer.from_pretrained(model)

# 20-token pattern
pattern = "0 1 2 3 4 5 6 7 8 9 "

# Repeat to token_length
repeat_count = (token_length // 20) + 1
text_prompt = pattern * repeat_count

# Tokenize and slice
tokens = tokenizer.encode(text_prompt)[:token_length]

# Assert exact token length
assert len(tokens) == token_length, \
f"Token length mismatch: {len(tokens)} != {token_length}"

return tokenizer.decode(tokens)


def create_random_request(
request_id: int,
num_tokens: int,
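create_seq_prompt gives the long-context test deterministic prompts of a known token length by repeating a digit pattern, slicing the encoded token ids, and decoding them back to text. A hedged usage sketch with a placeholder tokenizer (the real tests pass models from get_spyre_model_list()):

from transformers import AutoTokenizer

from spyre_util import create_seq_prompt

model = "gpt2"  # placeholder tokenizer, not a Spyre test model
prompt = create_seq_prompt(model, token_length=512)

# The helper slices the encoded pattern to exactly 512 tokens before decoding;
# re-encoding should land at (or very close to) the requested length,
# depending on how cleanly the tokenizer round-trips the decoded text.
tokenizer = AutoTokenizer.from_pretrained(model)
print(len(tokenizer.encode(prompt)))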