diff --git a/tests/e2e/test_spyre_basic.py b/tests/e2e/test_spyre_basic.py
index 380d81d4..a06ce79e 100644
--- a/tests/e2e/test_spyre_basic.py
+++ b/tests/e2e/test_spyre_basic.py
@@ -4,7 +4,7 @@
 """
 
 import pytest
-from output_util import check_output_against_hf, generate_spyre_vllm_output
+from output_util import validate_vllm_vs_hf_output
 from spyre_util import (DecodeWarmupShapes, ModelInfo, create_random_request,
                         get_chicken_soup_prompts, patch_environment,
                         skip_unsupported_tp_size)
@@ -53,17 +53,15 @@ def test_output(model: ModelInfo, tp_size: int, backend: str, cb: int,
         logprobs=0,  # return logprobs of generated tokens only
         ignore_eos=True)
 
-    vllm_results = generate_spyre_vllm_output(
-        model=model,
-        prompts=prompts,
-        sampling_params=vllm_sampling_params,
-        tensor_parallel_size=tp_size,
-        backend=backend,
-        monkeypatch=monkeypatch,
-        max_model_len=max_model_len,
-        **kwargs)
-    check_output_against_hf(model, backend, max_new_tokens, vllm_results,
-                            prompts)
+    validate_vllm_vs_hf_output(model=model,
+                               prompts=prompts,
+                               sampling_params=vllm_sampling_params,
+                               tensor_parallel_size=tp_size,
+                               backend=backend,
+                               monkeypatch=monkeypatch,
+                               max_model_len=max_model_len,
+                               max_new_tokens=max_new_tokens,
+                               **kwargs)
 
 
 @pytest.mark.parametrize("backend", [
@@ -88,18 +86,15 @@ def test_output_sendnn_decoder(model: ModelInfo,
         logprobs=0,  # return logprobs of generated tokens only
         ignore_eos=True)
 
-    vllm_results = generate_spyre_vllm_output(
-        model=model,
-        prompts=prompts,
-        warmup_shapes=warmup_shapes,
-        max_model_len=2048,
-        sampling_params=vllm_sampling_params,
-        tensor_parallel_size=1,
-        backend=backend,
-        monkeypatch=monkeypatch)
-
-    check_output_against_hf(model, backend, max_new_tokens, vllm_results,
-                            prompts)
+    validate_vllm_vs_hf_output(model=model,
+                               prompts=prompts,
+                               warmup_shapes=warmup_shapes,
+                               max_model_len=2048,
+                               sampling_params=vllm_sampling_params,
+                               tensor_parallel_size=1,
+                               backend=backend,
+                               monkeypatch=monkeypatch,
+                               max_new_tokens=max_new_tokens)
 
 
 def test_batch_handling(model: ModelInfo, backend: str, cb: int, warmup_shapes,
@@ -134,18 +129,15 @@ def test_batch_handling(model: ModelInfo, backend: str, cb: int, warmup_shapes,
         "warmup_shapes": warmup_shapes
     }
 
-    vllm_results = generate_spyre_vllm_output(
-        model=model,
-        prompts=prompts,
-        max_model_len=max_model_len,
-        sampling_params=vllm_sampling_params,
-        tensor_parallel_size=1,
-        backend=backend,
-        monkeypatch=monkeypatch,
-        **kwargs)
-
-    check_output_against_hf(model, backend, max_new_tokens, vllm_results,
-                            prompts)
+    validate_vllm_vs_hf_output(model=model,
+                               prompts=prompts,
+                               max_model_len=max_model_len,
+                               sampling_params=vllm_sampling_params,
+                               tensor_parallel_size=1,
+                               backend=backend,
+                               monkeypatch=monkeypatch,
+                               max_new_tokens=max_new_tokens,
+                               **kwargs)
 
 
 def test_full_batch_scheduling(model: ModelInfo, backend: str, monkeypatch):
diff --git a/tests/e2e/test_spyre_max_new_tokens.py b/tests/e2e/test_spyre_max_new_tokens.py
index 48278c90..54503030 100644
--- a/tests/e2e/test_spyre_max_new_tokens.py
+++ b/tests/e2e/test_spyre_max_new_tokens.py
@@ -4,7 +4,7 @@
 """
 
 import pytest
-from output_util import check_output_against_hf, generate_spyre_vllm_output
+from output_util import validate_vllm_vs_hf_output
 from spyre_util import DecodeWarmupShapes, ModelInfo, get_chicken_soup_prompts
 from vllm import SamplingParams
 
@@ -59,15 +59,12 @@ def test_output(model: ModelInfo, stop_last: bool, max_model_len: int,
         "warmup_shapes": warmup_shapes
     })
 
-    vllm_results = generate_spyre_vllm_output(
-        model=model,
-        prompts=prompts,
-        sampling_params=vllm_sampling_params,
-        tensor_parallel_size=1,
-        backend=backend,
-        monkeypatch=monkeypatch,
-        max_model_len=max_model_len,
-        **kwargs)
-
-    check_output_against_hf(model, backend, hf_max_new_tokens, vllm_results,
-                            prompts)
+    validate_vllm_vs_hf_output(model=model,
+                               prompts=prompts,
+                               sampling_params=vllm_sampling_params,
+                               tensor_parallel_size=1,
+                               backend=backend,
+                               monkeypatch=monkeypatch,
+                               max_new_tokens=hf_max_new_tokens,
+                               max_model_len=max_model_len,
+                               **kwargs)
diff --git a/tests/e2e/test_spyre_stagger_basic.py b/tests/e2e/test_spyre_stagger_basic.py
index 422d53c1..a3707134 100644
--- a/tests/e2e/test_spyre_stagger_basic.py
+++ b/tests/e2e/test_spyre_stagger_basic.py
@@ -5,7 +5,7 @@
 """
 
 import pytest
-from output_util import check_output_against_hf, generate_spyre_vllm_output
+from output_util import validate_vllm_vs_hf_output
 from spyre_util import (ModelInfo, get_chicken_soup_prompts,
                         skip_unsupported_tp_size)
 from vllm import SamplingParams
@@ -44,14 +44,12 @@ def test_stagger_output(model: ModelInfo, tp_size: int, backend: str, cb: int,
         logprobs=0,  # return logprobs of generated tokens only
         ignore_eos=True)
 
-    vllm_results = generate_spyre_vllm_output(
-        model=model,
-        prompts=prompts,
-        sampling_params=vllm_sampling_params,
-        tensor_parallel_size=tp_size,
-        backend=backend,
-        monkeypatch=monkeypatch,
-        max_model_len=max_model_len,
-        **kwargs)
-    check_output_against_hf(model, backend, max_new_tokens, vllm_results,
-                            prompts)
+    validate_vllm_vs_hf_output(model=model,
+                               prompts=prompts,
+                               sampling_params=vllm_sampling_params,
+                               tensor_parallel_size=tp_size,
+                               backend=backend,
+                               monkeypatch=monkeypatch,
+                               max_model_len=max_model_len,
+                               max_new_tokens=max_new_tokens,
+                               **kwargs)
diff --git a/tests/e2e/test_spyre_warmup_shapes.py b/tests/e2e/test_spyre_warmup_shapes.py
index 7c13f428..977c59a2 100644
--- a/tests/e2e/test_spyre_warmup_shapes.py
+++ b/tests/e2e/test_spyre_warmup_shapes.py
@@ -4,12 +4,11 @@
 """
 
 import pytest
-from output_util import check_output_against_hf, generate_spyre_vllm_output
+from output_util import generate_spyre_vllm_output, validate_vllm_vs_hf_output
 from spyre_util import DecodeWarmupShapes, ModelInfo, get_chicken_soup_prompts
 from vllm import SamplingParams
 
 
-@pytest.mark.xfail(reason="Failing currently because of output mismatch")
 @pytest.mark.parametrize(
     "warmup_shapes",
     [[(64, 20, 4), (128, 20, 2)]])  # (prompt_length/new_tokens/batch_size)
@@ -42,18 +41,15 @@ def test_multiple_warmup_shapes(model: ModelInfo,
         logprobs=0,  # return logprobs of generated tokens only
         ignore_eos=True)
 
-    vllm_results = generate_spyre_vllm_output(
-        model=model,
-        prompts=prompts,
-        warmup_shapes=warmup_shapes,
-        max_model_len=2048,
-        sampling_params=vllm_sampling_params,
-        tensor_parallel_size=1,
-        backend=backend,
-        monkeypatch=monkeypatch)
-
-    check_output_against_hf(model, backend, max_new_tokens, vllm_results,
-                            prompts)
+    validate_vllm_vs_hf_output(model=model,
+                               prompts=prompts,
+                               warmup_shapes=warmup_shapes,
+                               max_model_len=2048,
+                               sampling_params=vllm_sampling_params,
+                               tensor_parallel_size=1,
+                               backend=backend,
+                               max_new_tokens=max_new_tokens,
+                               monkeypatch=monkeypatch)
 
 
 @pytest.mark.parametrize("prompts", [["Hello"]])
diff --git a/tests/llm_cache.py b/tests/llm_cache.py
index d6fca798..b545d299 100644
--- a/tests/llm_cache.py
+++ b/tests/llm_cache.py
@@ -1,6 +1,7 @@
 """Contains utilities for caching models (instantiated as vLLM endpoints)
 across test cases, to speed up test runtime."""
 
+import os
 from typing import Callable, Generic, Optional, TypeVar
 
 import pytest
@@ -178,6 +179,12 @@ def get_engine(
         revision = None
         model_name = model
 
+    # Register golden token injector if not disabled
+    disable_golden_token = \
+        bool(int(os.getenv("VLLM_SPYRE_TEST_DISABLE_GOLDEN_TOKEN", "0")))
+    logits_processors = [] if disable_golden_token else \
+        [GoldenTokenInjector]
+
     # 🌶️🌶️🌶️
     # Messing with the blocks and context length by either:
     # - setting context < 512 tokens
@@ -192,15 +199,13 @@
     # Spyre compilation. This seems more robust and helps that all tests in
     # tests/e2e/test_spyre_cb_inference_steps.py pass on Spyre.
     max_num_seqs_compiled = 1 << (max_num_seqs - 1).bit_length()
-    engine_args = EngineArgs(
-        model=model_name,
-        tokenizer=model_name,
-        max_model_len=max(max_model_len, 512),
-        max_num_seqs=max_num_seqs_compiled,
-        num_gpu_blocks_override=None,
-        revision=revision,
-        # We always include it, but does not means we always use it
-        logits_processors=[GoldenTokenInjector])
+    engine_args = EngineArgs(model=model_name,
+                             tokenizer=model_name,
+                             max_model_len=max(max_model_len, 512),
+                             max_num_seqs=max_num_seqs_compiled,
+                             num_gpu_blocks_override=None,
+                             revision=revision,
+                             logits_processors=logits_processors)
 
     vllm_config = engine_args.create_engine_config()
     executor_class = Executor.get_class(vllm_config)
diff --git a/tests/output_util.py b/tests/output_util.py
index 37d751da..a413f266 100644
--- a/tests/output_util.py
+++ b/tests/output_util.py
@@ -19,8 +19,10 @@
 
 DISABLE_ASSERTS = False  # used for debugging
 
-ISCLOSE_ABS_TOL = 0.08
-ISCLOSE_ABS_TOL_QUANTIZATION = 0.125
+ISCLOSE_ABS_TOL = \
+    float(os.environ.get("VLLM_SPYRE_TEST_ABS_TOL", '0.08'))
+ISCLOSE_ABS_TOL_QUANTIZATION = \
+    float(os.environ.get("VLLM_SPYRE_TEST_QUANTIZED_ABS_TOL", '0.125'))
 
 HF_RESULT_CACHE = HFResultCache()
 
@@ -370,6 +372,88 @@
     return results
 
 
+def setup_golden_token(
+    model: ModelInfo,
+    sampling_params: Union[SamplingParams, list[SamplingParams]],
+    hf_outputs: list[dict[str, Any]],
+) -> Union[SamplingParams, list[SamplingParams]]:
+
+    abs_tol = ISCLOSE_ABS_TOL_QUANTIZATION if model.is_quantized \
+        else ISCLOSE_ABS_TOL
+
+    if isinstance(sampling_params, SamplingParams):
+        # Single Sampling params case
+        hf = hf_outputs[0]
+        sampling_params.extra_args = {
+            "golden_token_injector": {
+                "expected_token_ids": hf['token_ids'],
+                "expected_logprobs": hf['logprobs'],
+                "error_threshold": abs_tol,
+                "label": "#0"
+            }
+        }
+        return sampling_params
+
+    # Multiple sampling params case
+    assert len(sampling_params) == len(hf_outputs)
+    for idx, (param, hf) in enumerate(zip(sampling_params, hf_outputs)):
+        param.extra_args = {
+            "golden_token_injector": {
+                "expected_token_ids": hf['token_ids'],
+                "expected_logprobs": hf['logprobs'],
+                "error_threshold": abs_tol,
+                "label": f"#{idx}"
+            }
+        }
+    return sampling_params
+
+
+def validate_vllm_vs_hf_output(
+    model: ModelInfo,
+    prompts: Union[list[str], list[list[int]]],
+    max_model_len: int,
+    max_new_tokens: Union[int, list[int]],
+    sampling_params: Union[SamplingParams, list[SamplingParams]],
+    tensor_parallel_size: int,
+    backend: str,
+    monkeypatch: pytest.MonkeyPatch,
+    warmup_shapes: DecodeWarmupShapes | None = None,
+    max_num_seqs: Optional[int] = None,
+    use_cb: bool = False,
+    use_golden_token=True,
+) -> None:
+    hf_outputs = generate_hf_output(
+        model=model,
+        prompts=prompts,
+        max_new_tokens=max_new_tokens,
+        ignore_eos=True,
+    )
+
+    if use_golden_token:
+        sampling_params = setup_golden_token(model, sampling_params,
+                                             hf_outputs)
+
+    vllm_results = generate_spyre_vllm_output(
+        model=model,
+        prompts=prompts,
+        max_model_len=max_model_len,
+        sampling_params=sampling_params,
+        tensor_parallel_size=tensor_parallel_size,
+        backend=backend,
+        monkeypatch=monkeypatch,
+        warmup_shapes=warmup_shapes,
+        max_num_seqs=max_num_seqs,
+        use_cb=use_cb,
+    )
+
+    compare_results(model=model,
+                    tensor_parallel_size=1,
+                    backend=backend,
+                    vllm_results=vllm_results,
+                    hf_results=hf_outputs,
+                    prompts=prompts)
+
+
 # vLLM / Spyre
 def generate_spyre_vllm_output(
     model: str | ModelInfo,