diff --git a/tests/e2e/test_spyre_basic.py b/tests/e2e/test_spyre_basic.py
index 19eab0102..bb338ae2e 100644
--- a/tests/e2e/test_spyre_basic.py
+++ b/tests/e2e/test_spyre_basic.py
@@ -4,10 +4,15 @@
 """
 
 import pytest
+from e2e.test_spyre_cb import create_random_request
 from spyre_util import (VLLM_VERSIONS, compare_results, generate_hf_output,
                         generate_spyre_vllm_output, get_spyre_backend_list,
                         get_spyre_model_list)
-from vllm import SamplingParams
+from vllm import EngineArgs, SamplingParams
+from vllm.v1.engine.core import EngineCore
+from vllm.v1.executor.abstract import Executor
+
+from vllm_spyre.v1.core.scheduler import StaticBatchingSpyreScheduler
 
 template = (
     "Below is an instruction that describes a task. Write a response that "
@@ -128,3 +133,60 @@ def test_batch_handling(
     assert vllm_results[1]["text"] == " 6 5 4 3 2 "
     assert vllm_results[2]["text"] == " 4 3 2 "
     assert vllm_results[3]["text"] == "6 5 4 3 2 "
+
+
+@pytest.mark.parametrize("model", get_spyre_model_list())
+@pytest.mark.parametrize("backend", get_spyre_backend_list())
+@pytest.mark.parametrize("vllm_version",
+                         [pytest.param("V1", marks=pytest.mark.v1, id="v1")])
+def test_full_batch_scheduling(model: str, backend: str, vllm_version: str,
+                               monkeypatch):
+    """Test that we can schedule a full batch of prompts."""
+
+    # We need to ensure here that the max number of tokens in a full batch
+    # is greater than the value set for `--max-num-batched-tokens`.
+    # This defaults to 2k in many cases for vllm.v1, which will cause problems
+    # when trying to schedule a static batch with more than 2k tokens.
+    # The plugin _should_ override this in config for the engine so that the
+    # scheduler can properly schedule a full batch.
+
+    # Here we set `--max-num-batched-tokens` to 64, and try to schedule a batch
+    # of 4 x 64-token prompts
+    max_batched_tokens = 64
+    batch_size = 4
+
+    # Set the static batching config
+    monkeypatch.setenv("VLLM_SPYRE_WARMUP_BATCH_SIZES", f"{batch_size}")
+    monkeypatch.setenv("VLLM_SPYRE_WARMUP_PROMPT_LENS",
+                       f"{max_batched_tokens}")
+    monkeypatch.setenv("VLLM_SPYRE_WARMUP_NEW_TOKENS", "20")
+
+    # So we can access the engine and scheduler in this process
+    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
+
+    monkeypatch.setenv("VLLM_USE_V1", "1")
+    monkeypatch.setenv("VLLM_SPYRE_DYNAMO_BACKEND", backend)
+
+    # Set up the engine
+    engine_args = EngineArgs(model=model,
+                             tokenizer=model,
+                             max_num_batched_tokens=max_batched_tokens,
+                             max_num_seqs=4)
+    vllm_config = engine_args.create_engine_config()
+    executor_class = Executor.get_class(vllm_config)
+    engine_core = EngineCore(vllm_config=vllm_config,
+                             executor_class=executor_class,
+                             log_stats=False)
+    scheduler: StaticBatchingSpyreScheduler = engine_core.scheduler
+
+    vllm_sampling_params = SamplingParams(max_tokens=20,
+                                          temperature=0,
+                                          logprobs=0)
+    for i in range(batch_size):
+        engine_core.add_request(
+            create_random_request(request_id=i,
+                                  num_tokens=max_batched_tokens,
+                                  sampling_params=vllm_sampling_params))
+    schedule = scheduler.schedule()
+
+    assert len(schedule.scheduled_new_reqs) == batch_size
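
The arithmetic this test exercises is worth spelling out: a static batch of
`batch_size` prompts, each padded to the warmup prompt length, needs
`batch_size * max_batched_tokens` tokens of scheduling budget in a single
step. A minimal sketch of that budget math, using the values from the test
above:

    # Budget math behind test_full_batch_scheduling (values copied from
    # the test above).
    batch_size = 4
    prompt_len = 64               # VLLM_SPYRE_WARMUP_PROMPT_LENS
    max_num_batched_tokens = 64   # passed to EngineArgs in the test

    tokens_needed = batch_size * prompt_len  # 256 tokens for one full batch

    # Without the plugin's config override, the base scheduler would not
    # place all four prompts in one step, since 256 > 64.
    assert tokens_needed > max_num_batched_tokens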
diff --git a/vllm_spyre/platform.py b/vllm_spyre/platform.py
index 655cd3889..c22ed107b 100644
--- a/vllm_spyre/platform.py
+++ b/vllm_spyre/platform.py
@@ -109,12 +109,15 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         # Cache and model config aren't set in the individual worker procs
         # These are set in the main engine process
 
-        # To disable any paged attention ops in the base scheduler, we both:
+        # To disable any paged attention ops in the base scheduler, we:
         # - Set the block size (in tokens) to the maximum sequence length
         #   so that the scheduler thinks an entire sequence will fit in
         #   one single block.
         # - Set the number of blocks to the maximum number of sequences, so
         #   the scheduler always thinks there's a block available
+        # - Set `max_num_batched_tokens` to the size of a full batch of
+        #   full-length requests, so that the scheduler will always have
+        #   token budget available to schedule a full batch
         if cache_config is not None:
             if envs.VLLM_USE_V1:
                 # The V1 scheduler actually needs 2 blocks for each sequence...
@@ -125,6 +128,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                     scheduler_config.max_num_seqs
                 cache_config.block_size = model_config.max_model_len
+                scheduler_config.max_num_batched_tokens = (
+                    model_config.max_model_len * scheduler_config.max_num_seqs)
 
                 logger.info(
                     "Overriding configurations based on warmup shapes. "
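
For a sense of scale, a back-of-the-envelope check of the new
`max_num_batched_tokens` override (the numbers below are hypothetical, not
taken from the patch):

    # Hypothetical example of the max_num_batched_tokens override in
    # check_and_update_config(); the values are illustrative only.
    max_model_len = 2048   # stands in for model_config.max_model_len
    max_num_seqs = 4       # stands in for scheduler_config.max_num_seqs

    # Size the per-step token budget to a full batch of full-length
    # requests, so a static batch always fits in one scheduling step.
    max_num_batched_tokens = max_model_len * max_num_seqs
    assert max_num_batched_tokens == 8192  # well above vLLM V1's ~2k default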