diff --git a/tests/e2e/test_spyre_cb.py b/tests/e2e/test_spyre_cb.py
index 181fe9b13..53914aba4 100644
--- a/tests/e2e/test_spyre_cb.py
+++ b/tests/e2e/test_spyre_cb.py
@@ -10,10 +10,11 @@
 import pytest
 from openai import BadRequestError
 from spyre_util import (RemoteOpenAIServer, check_output_against_hf,
-                        compare_results, generate_spyre_vllm_output,
+                        compare_results, create_seq_prompt, extract_output,
+                        force_engine_shutdown, generate_spyre_vllm_output,
                         get_chicken_soup_prompts, get_spyre_model_list,
                         skip_unsupported_tp_size)
-from vllm import SamplingParams
+from vllm import LLM, SamplingParams
 
 
 @pytest.mark.parametrize("model", get_spyre_model_list())
@@ -163,6 +164,83 @@ def test_continuous_batching_with_long_contexts(model, monkeypatch):
         assert vllm_cpu_results[i]["text"] == vllm_spyre_results[i]["text"]
 
 
+@pytest.mark.cb
+@pytest.mark.parametrize("model", get_spyre_model_list())
+@pytest.mark.parametrize(
+    "backend", [pytest.param("sendnn", marks=pytest.mark.spyre, id="sendnn")])
+@pytest.mark.parametrize(
+    "tp_size",
+    [
+        pytest.param(4, marks=pytest.mark.multi),
+    ],
+    ids=lambda val: f"TP({val})",
+)
+def test_long_context_batches(
+    model: str,
+    backend: str,
+    tp_size: int,
+    monkeypatch: pytest.MonkeyPatch,
+):
+    """Tests continuous batching with various batch sizes and prompt lengths."""
+
+    skip_unsupported_tp_size(tp_size, backend)
+
+    monkeypatch.setenv("VLLM_SPYRE_USE_CB", "1")
+    monkeypatch.setenv("VLLM_SPYRE_DYNAMO_BACKEND", backend)
+    monkeypatch.setenv("VLLM_SPYRE_OVERRIDE_SIGNALS_HANDLER", "1")
+
+    max_model_len = 32768
+    max_num_seqs = 32
+    max_tokens = 10
+
+    # (batch_size, prompt_length) pairs
+    batch_token_pairs = [
+        (32, 512),
+        (16, 1500),
+        (8, 3000),
+        (4, 5000),
+        (2, 9000),
+        (1, 17000),
+    ]
+
+    vllm_model = LLM(
+        model=model,
+        tokenizer=model,
+        max_model_len=max_model_len,
+        max_num_seqs=max_num_seqs,
+        block_size=2048,
+        tensor_parallel_size=tp_size,
+    )
+
+    sampling_params = SamplingParams(
+        max_tokens=max_tokens,
+        temperature=0,
+        ignore_eos=True,
+        logprobs=0,
+    )
+
+    for batch_size, token_len in batch_token_pairs:
+        prompt = create_seq_prompt(model, token_length=token_len)
+        prompts = [prompt] * batch_size
+
+        vllm_outputs = vllm_model.generate(prompts, sampling_params)
+
+        results = []
+        for req_output in vllm_outputs:
+            result = extract_output(req_output)
+            results.append(result)
+
+        check_output_against_hf(
+            model=model,
+            backend=backend,
+            max_new_tokens=max_tokens,
+            vllm_results=results,
+            prompts=prompts,
+        )
+
+    force_engine_shutdown(vllm_model)
+
+
 @pytest.mark.spyre
 @pytest.mark.cb
 @pytest.mark.parametrize(
diff --git a/tests/e2e/test_spyre_static_batching_limits.py b/tests/e2e/test_spyre_static_batching_limits.py
index 633cf6a65..a075d4879 100644
--- a/tests/e2e/test_spyre_static_batching_limits.py
+++ b/tests/e2e/test_spyre_static_batching_limits.py
@@ -38,9 +38,9 @@ def test_max_prompt_len_and_new_tokens(model: str,
     # Craft a request with a prompt that is slightly too long for the warmup
     # shape
     prompt = create_text_prompt(model,
-                                min_tokens=max_prompt_length,
-                                max_tokens=max_prompt_length + max_new_tokens -
-                                1)
+                                min_token_length=max_prompt_length,
+                                max_token_length=max_prompt_length +
+                                max_new_tokens - 1)
     sampling_params = SamplingParams(max_tokens=1)
 
     with pytest.raises(ValueError, match="warmup"):
diff --git a/tests/spyre_util.py b/tests/spyre_util.py
index 5d3c766e4..36f8d702c 100644
--- a/tests/spyre_util.py
+++ b/tests/spyre_util.py
@@ -175,6 +175,24 @@ def patch_warmup_shapes(warmup_shapes: Union[list[tuple[int, int, int]],
                        ','.join(str(val) for val in warmup_new_tokens))
 
 
+def extract_output(req_output):
+    """Extract text, token_ids, tokens, and logprobs from request output."""
+
+    result = {}
+    result['text'] = req_output.outputs[0].text
+
+    # TODO: Workaround for V1, if request does not fit in a warmup shape
+    # token_ids may be filled with -1.
+    token_ids = [t for t in req_output.outputs[0].token_ids if t >= 0]
+    result['token_ids'] = tuple(token_ids)
+    result['tokens'] = tuple(req_output.outputs[0].logprobs[i][t].decoded_token
+                             for i, t in enumerate(token_ids))
+    result['logprobs'] = tuple(req_output.outputs[0].logprobs[i][t].logprob
+                               for i, t in enumerate(token_ids))
+
+    return result
+
+
 # vLLM / Spyre
 def generate_spyre_vllm_output(
     model: str,
@@ -226,20 +244,7 @@ def generate_spyre_vllm_output(
 
     results = []
     for req_output in vllm_outputs:
-        result = {}
-        result['text'] = req_output.outputs[0].text
-        # TODO: Workaround for V1, if request does not fit in a warmup shape
-        # token_ids may be filled with -1.
-        token_ids = [t for t in req_output.outputs[0].token_ids if t >= 0]
-        result['token_ids'] = tuple(token_ids)
-        result['tokens'] = tuple([
-            req_output.outputs[0].logprobs[i][t].decoded_token
-            for i, t in enumerate(result['token_ids'])
-        ])
-        result['logprobs'] = tuple([
-            req_output.outputs[0].logprobs[i][t].logprob
-            for i, t in enumerate(result['token_ids'])
-        ])
+        result = extract_output(req_output)
         results.append(result)
 
     force_engine_shutdown(vllm_model)
@@ -554,7 +559,8 @@ def _default_test_models(isEmbeddings=False):
     return params
 
 
-def create_text_prompt(model: str, min_tokens: int, max_tokens: int) -> str:
+def create_text_prompt(model: str, min_token_length: int,
+                       max_token_length: int) -> str:
     """Create a text prompt for the specified model that will tokenize to
     within the specified token length range."""
     tokenizer = AutoTokenizer.from_pretrained(model)
@@ -562,18 +568,41 @@ def create_text_prompt(model: str, min_tokens: int, max_tokens: int) -> str:
     pepper_tokens = len(tokenizer.encode(pepper, add_special_tokens=False))
 
     # Find a good starting number of peppers
-    prompt = pepper * (min_tokens // pepper_tokens + 1)
+    prompt = pepper * (min_token_length // pepper_tokens + 1)
 
     # And add more until we're over the minimum token length
-    while len(tokenizer.encode(prompt)) <= min_tokens:
+    while len(tokenizer.encode(prompt)) <= min_token_length:
         prompt += pepper
 
     # Make sure this prompt is within the specified range
-    assert min_tokens < len(tokenizer.encode(prompt)) < max_tokens
+    assert min_token_length < len(tokenizer.encode(prompt)) < max_token_length
 
     return prompt
 
 
+def create_seq_prompt(model: str, token_length: int) -> str:
+    """Create a repeating sequential number prompt for the specified
+    model that will tokenize to exactly the specified token length."""
+
+    tokenizer = AutoTokenizer.from_pretrained(model)
+
+    # 20-token pattern
+    pattern = "0 1 2 3 4 5 6 7 8 9 "
+
+    # Repeat to token_length
+    repeat_count = (token_length // 20) + 1
+    text_prompt = pattern * repeat_count
+
+    # Tokenize and slice
+    tokens = tokenizer.encode(text_prompt)[:token_length]
+
+    # Assert exact token length
+    assert len(tokens) == token_length, \
+        f"Token length mismatch: {len(tokens)} != {token_length}"
+
+    return tokenizer.decode(tokens)
+
+
 def create_random_request(
     request_id: int,
     num_tokens: int,
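
For reference, a minimal sketch of how the two new spyre_util helpers compose outside of test_long_context_batches, assuming only the signatures added above (create_seq_prompt(model, token_length) and extract_output(req_output)); the helper name run_exact_length_batch and the engine arguments shown here are illustrative, not part of the patch.

    from vllm import LLM, SamplingParams
    from spyre_util import create_seq_prompt, extract_output

    def run_exact_length_batch(model: str, token_length: int, batch_size: int):
        # Build a prompt that tokenizes to exactly `token_length` tokens.
        prompt = create_seq_prompt(model, token_length=token_length)
        llm = LLM(model=model, max_num_seqs=batch_size)
        params = SamplingParams(max_tokens=10,
                                temperature=0,
                                ignore_eos=True,
                                logprobs=0)
        outputs = llm.generate([prompt] * batch_size, params)
        # Normalize each RequestOutput into the dict shape (text, token_ids,
        # tokens, logprobs) that the comparison helpers expect.
        return [extract_output(o) for o in outputs]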