Commit f76c391

[Tests] Add long context batch tests (#365)
# Description

- Cleans up `create_text_prompt`
- Adds `create_seq_prompt`
- Adds six test cases:
  - 32 prompts with 512 tokens each
  - 16 prompts with 1.5k tokens each
  - 8 prompts with 3k tokens each
  - 4 prompts with 5k tokens each
  - 2 prompts with 9k tokens each
  - 1 prompt with 17k tokens

Signed-off-by: Rafael Vasquez <[email protected]>
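Taken together, the six cases trade batch size against prompt length: each batch carries roughly 16K-24K prompt tokens in aggregate, and every individual prompt stays well below the 32768-token `max_model_len` that the new test configures. A quick tally of those totals (an illustrative sketch, not part of the commit):

```python
# Sketch only: reproduces the (batch_size, prompt_length) pairs from the
# new test and tallies the aggregate prompt tokens per batch.
batch_token_pairs = [(32, 512), (16, 1500), (8, 3000),
                     (4, 5000), (2, 9000), (1, 17000)]

for batch_size, prompt_len in batch_token_pairs:
    total = batch_size * prompt_len
    print(f"{batch_size:>2} prompts x {prompt_len:>5} tokens = {total:>5} total prompt tokens")

# Prints totals between 16384 and 24000; each prompt is at most 17000
# tokens, under the 32768-token max_model_len used by the test.
```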
1 parent 9ac2d65 commit f76c391

File tree: 3 files changed, +130 -23 lines


tests/e2e/test_spyre_cb.py

Lines changed: 80 additions & 2 deletions
```diff
@@ -10,10 +10,11 @@
 import pytest
 from openai import BadRequestError
 from spyre_util import (RemoteOpenAIServer, check_output_against_hf,
-                        compare_results, generate_spyre_vllm_output,
+                        compare_results, create_seq_prompt, extract_output,
+                        force_engine_shutdown, generate_spyre_vllm_output,
                         get_chicken_soup_prompts, get_spyre_model_list,
                         skip_unsupported_tp_size)
-from vllm import SamplingParams
+from vllm import LLM, SamplingParams


 @pytest.mark.parametrize("model", get_spyre_model_list())
@@ -163,6 +164,83 @@ def test_continuous_batching_with_long_contexts(model, monkeypatch):
         assert vllm_cpu_results[i]["text"] == vllm_spyre_results[i]["text"]


+@pytest.mark.cb
+@pytest.mark.parametrize("model", get_spyre_model_list())
+@pytest.mark.parametrize(
+    "backend", [pytest.param("sendnn", marks=pytest.mark.spyre, id="sendnn")])
+@pytest.mark.parametrize(
+    "tp_size",
+    [
+        pytest.param(4, marks=pytest.mark.multi),
+    ],
+    ids=lambda val: f"TP({val})",
+)
+def test_long_context_batches(
+    model: str,
+    backend: str,
+    tp_size: int,
+    monkeypatch: pytest.MonkeyPatch,
+):
+    """Tests continuous batching with various batch sizes and prompt lengths."""
+
+    skip_unsupported_tp_size(tp_size, backend)
+
+    monkeypatch.setenv("VLLM_SPYRE_USE_CB", "1")
+    monkeypatch.setenv("VLLM_SPYRE_DYNAMO_BACKEND", backend)
+    monkeypatch.setenv("VLLM_SPYRE_OVERRIDE_SIGNALS_HANDLER", "1")
+
+    max_model_len = 32768
+    max_num_seqs = 32
+    max_tokens = 10
+
+    # (batch_size, prompt_length) pairs
+    batch_token_pairs = [
+        (32, 512),
+        (16, 1500),
+        (8, 3000),
+        (4, 5000),
+        (2, 9000),
+        (1, 17000),
+    ]
+
+    vllm_model = LLM(
+        model=model,
+        tokenizer=model,
+        max_model_len=max_model_len,
+        max_num_seqs=max_num_seqs,
+        block_size=2048,
+        tensor_parallel_size=tp_size,
+    )
+
+    sampling_params = SamplingParams(
+        max_tokens=max_tokens,
+        temperature=0,
+        ignore_eos=True,
+        logprobs=0,
+    )
+
+    for batch_size, token_len in batch_token_pairs:
+        prompt = create_seq_prompt(model, token_length=token_len)
+        prompts = [prompt] * batch_size
+
+        vllm_outputs = vllm_model.generate(prompts, sampling_params)
+
+        results = []
+        for req_output in vllm_outputs:
+            result = extract_output(req_output)
+            results.append(result)
+
+        check_output_against_hf(
+            model=model,
+            backend=backend,
+            max_new_tokens=max_tokens,
+            vllm_results=results,
+            prompts=prompts,
+        )
+
+    force_engine_shutdown(vllm_model)
+
+
 @pytest.mark.spyre
 @pytest.mark.cb
 @pytest.mark.parametrize(
```

tests/e2e/test_spyre_static_batching_limits.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -38,9 +38,9 @@ def test_max_prompt_len_and_new_tokens(model: str,
     # Craft a request with a prompt that is slightly too long for the warmup
     # shape
     prompt = create_text_prompt(model,
-                                min_tokens=max_prompt_length,
-                                max_tokens=max_prompt_length + max_new_tokens -
-                                1)
+                                min_token_length=max_prompt_length,
+                                max_token_length=max_prompt_length +
+                                max_new_tokens - 1)
     sampling_params = SamplingParams(max_tokens=1)

     with pytest.raises(ValueError, match="warmup"):
```

tests/spyre_util.py

Lines changed: 47 additions & 18 deletions
```diff
@@ -175,6 +175,24 @@ def patch_warmup_shapes(warmup_shapes: Union[list[tuple[int, int, int]],
                        ','.join(str(val) for val in warmup_new_tokens))


+def extract_output(req_output):
+    """Extract text, token_ids, tokens, and logprobs from request output."""
+
+    result = {}
+    result['text'] = req_output.outputs[0].text
+
+    # TODO: Workaround for V1, if request does not fit in a warmup shape
+    # token_ids may be filled with -1.
+    token_ids = [t for t in req_output.outputs[0].token_ids if t >= 0]
+    result['token_ids'] = tuple(token_ids)
+    result['tokens'] = tuple(req_output.outputs[0].logprobs[i][t].decoded_token
+                             for i, t in enumerate(token_ids))
+    result['logprobs'] = tuple(req_output.outputs[0].logprobs[i][t].logprob
+                               for i, t in enumerate(token_ids))
+
+    return result
+
+
 # vLLM / Spyre
 def generate_spyre_vllm_output(
     model: str,
@@ -226,20 +244,7 @@ def generate_spyre_vllm_output(
     results = []

     for req_output in vllm_outputs:
-        result = {}
-        result['text'] = req_output.outputs[0].text
-        # TODO: Workaround for V1, if request does not fit in a warmup shape
-        # token_ids may be filled with -1.
-        token_ids = [t for t in req_output.outputs[0].token_ids if t >= 0]
-        result['token_ids'] = tuple(token_ids)
-        result['tokens'] = tuple([
-            req_output.outputs[0].logprobs[i][t].decoded_token
-            for i, t in enumerate(result['token_ids'])
-        ])
-        result['logprobs'] = tuple([
-            req_output.outputs[0].logprobs[i][t].logprob
-            for i, t in enumerate(result['token_ids'])
-        ])
+        result = extract_output(req_output)
         results.append(result)

     force_engine_shutdown(vllm_model)
@@ -554,26 +559,50 @@ def _default_test_models(isEmbeddings=False):
     return params


-def create_text_prompt(model: str, min_tokens: int, max_tokens: int) -> str:
+def create_text_prompt(model: str, min_token_length: int,
+                       max_token_length: int) -> str:
     """Create a text prompt for the specified model that will tokenize to within
     the specified token length range."""
     tokenizer = AutoTokenizer.from_pretrained(model)
     pepper = "🌶️"
     pepper_tokens = len(tokenizer.encode(pepper, add_special_tokens=False))

     # Find a good starting number of peppers
-    prompt = pepper * (min_tokens // pepper_tokens + 1)
+    prompt = pepper * (min_token_length // pepper_tokens + 1)

     # And add more until we're over the minimum token length
-    while len(tokenizer.encode(prompt)) <= min_tokens:
+    while len(tokenizer.encode(prompt)) <= min_token_length:
         prompt += pepper

     # Make sure this prompt is within the specified range
-    assert min_tokens < len(tokenizer.encode(prompt)) < max_tokens
+    assert min_token_length < len(tokenizer.encode(prompt)) < max_token_length

     return prompt


+def create_seq_prompt(model: str, token_length: int) -> str:
+    """Create a repeating sequential number prompt for the specified
+    model that will tokenize to exactly the specified token length."""
+
+    tokenizer = AutoTokenizer.from_pretrained(model)
+
+    # 20-token pattern
+    pattern = "0 1 2 3 4 5 6 7 8 9 "
+
+    # Repeat to token_length
+    repeat_count = (token_length // 20) + 1
+    text_prompt = pattern * repeat_count
+
+    # Tokenize and slice
+    tokens = tokenizer.encode(text_prompt)[:token_length]
+
+    # Assert exact token length
+    assert len(tokens) == token_length, \
+        f"Token length mismatch: {len(tokens)} != {token_length}"
+
+    return tokenizer.decode(tokens)
+
+
 def create_random_request(
     request_id: int,
     num_tokens: int,
```
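For orientation, here is a minimal usage sketch of the new `create_seq_prompt` helper. It is not part of the commit; it assumes the `tests/` directory is on the import path (as it is when running the suite) and uses `gpt2` as a stand-in tokenizer, whereas the tests pull models from `get_spyre_model_list()`.

```python
# Illustrative sketch only. Assumes the tests/ directory is importable and
# that the "gpt2" tokenizer is available locally or can be downloaded; the
# real tests use models from get_spyre_model_list().
from transformers import AutoTokenizer

from spyre_util import create_seq_prompt

model = "gpt2"

# Build a prompt that the helper slices to exactly 512 tokens.
prompt = create_seq_prompt(model, token_length=512)

# Sanity-check by re-encoding; most tokenizers round-trip the decoded text
# to the same count, though some add special tokens on encode.
tokenizer = AutoTokenizer.from_pretrained(model)
print(len(tokenizer.encode(prompt)))
```

The companion `extract_output` helper returns a dict with `text`, `token_ids`, `tokens`, and `logprobs` keys; the new `test_long_context_batches` collects one such dict per request and passes the list to `check_output_against_hf`.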
