65 changes: 64 additions & 1 deletion tests/e2e/test_spyre_basic.py
@@ -4,10 +4,15 @@
"""

import pytest
from e2e.test_spyre_cb import create_random_request
from spyre_util import (VLLM_VERSIONS, compare_results, generate_hf_output,
generate_spyre_vllm_output, get_spyre_backend_list,
get_spyre_model_list)
from vllm import SamplingParams
from vllm import EngineArgs, SamplingParams
from vllm.v1.engine.core import EngineCore
from vllm.v1.executor.abstract import Executor

from vllm_spyre.v1.core.scheduler import StaticBatchingSpyreScheduler

template = (
"Below is an instruction that describes a task. Write a response that "
@@ -128,3 +133,61 @@ def test_batch_handling(
assert vllm_results[1]["text"] == " 6 5 4 3 2 "
assert vllm_results[2]["text"] == " 4 3 2 "
assert vllm_results[3]["text"] == "6 5 4 3 2 "


@pytest.mark.parametrize("model", get_spyre_model_list())
@pytest.mark.parametrize("backend", get_spyre_backend_list())
@pytest.mark.parametrize("vllm_version",
[pytest.param("V1", marks=pytest.mark.v1, id="v1")])
def test_full_batch_scheduling(model: str, backend: str, vllm_version: str,
monkeypatch):
"""Test that we can schedule a full batch of prompts."""

# We need to ensure here that the max number of tokens in a full batch
# is greater than the value set for `--max-num-batched-tokens`.
# This defaults to 2k in many cases for vllm.v1, which will cause problems
# when trying to schedule a static batch with more than 2k tokens.
# The plugin _should_ override this in config for the engine so that the
# scheduler can properly schedule a full batch.

# Here we set `--max-num-batched-tokens` to 64, and try to schedule a batch
# of 4 x 64-token prompts
max_batched_tokens = 64
batch_size = 4

# set batching config
monkeypatch.setenv("VLLM_SPYRE_WARMUP_BATCH_SIZES", f"{batch_size}")
monkeypatch.setenv("VLLM_SPYRE_WARMUP_PROMPT_LENS",
f"{max_batched_tokens}")
monkeypatch.setenv("VLLM_SPYRE_WARMUP_NEW_TOKENS", "20")

# So we can access the engine and scheduler in this process
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

monkeypatch.setenv("VLLM_USE_V1", "1")
monkeypatch.setenv("VLLM_SPYRE_DYNAMO_BACKEND", backend)

# Setup the engine
engine_args = EngineArgs(model=model,
tokenizer=model,
max_num_batched_tokens=max_batched_tokens,
max_num_seqs=4)
vllm_config = engine_args.create_engine_config()
executor_class = Executor.get_class(vllm_config)
engine_core = EngineCore(vllm_config=vllm_config,
executor_class=executor_class,
log_stats=False)
scheduler: StaticBatchingSpyreScheduler = engine_core.scheduler

vllm_sampling_params = SamplingParams(max_tokens=20,
temperature=0,
stop="1",
Collaborator review comment: I guess we don't need stop="1" here, looks like a copy-paste relic :)
logprobs=0)
for i in range(batch_size):
engine_core.add_request(
create_random_request(request_id=i,
num_tokens=max_batched_tokens,
sampling_params=vllm_sampling_params))
schedule = scheduler.schedule()

assert len(schedule.scheduled_new_reqs) == batch_size
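
The comment at the top of test_full_batch_scheduling leans on some simple token-budget arithmetic. Below is a minimal sketch of that arithmetic, not part of the PR: the max_model_len value is an assumed example, and the override formula is the one added to platform.py in the next file.

# Hypothetical arithmetic only -- not the PR's code. It illustrates why the
# plugin must override max_num_batched_tokens for the test above to pass.
batch_size = 4        # VLLM_SPYRE_WARMUP_BATCH_SIZES
prompt_len = 64       # VLLM_SPYRE_WARMUP_PROMPT_LENS
cli_budget = 64       # --max-num-batched-tokens passed via EngineArgs

# Tokens the scheduler must admit in a single step to schedule the full batch.
full_batch_tokens = batch_size * prompt_len   # 256

# With the CLI value alone, a full static batch would not fit in one step.
assert full_batch_tokens > cli_budget

# The plugin's override (max_model_len * max_num_seqs, see platform.py below)
# always covers a full batch of full-length requests.
max_model_len = 2048        # assumed example value
max_num_seqs = batch_size
overridden_budget = max_model_len * max_num_seqs
assert full_batch_tokens <= overridden_budget
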
7 changes: 6 additions & 1 deletion vllm_spyre/platform.py
@@ -109,12 +109,15 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
# Cache and model config aren't set in the individual worker procs
# These are set in the main engine process

# To disable any paged attention ops in the base scheduler, we both:
# To disable any paged attention ops in the base scheduler, we:
# - Set the block size (in tokens) to the maximum sequence length
# so that the scheduler thinks an entire sequence will fit in
# one single block.
# - Set the number of blocks to the maximum number of sequences, so
# the scheduler always thinks there's a block available
# - Set `max_num_batched_tokens` to the size of a full batch of full
# length requests, so that the scheduler will always have token
# budget available to schedule a full batch
if cache_config is not None:
if envs.VLLM_USE_V1:
# The V1 scheduler actually needs 2 blocks for each sequence...
@@ -125,6 +128,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
scheduler_config.max_num_seqs

cache_config.block_size = model_config.max_model_len
scheduler_config.max_num_batched_tokens = (
model_config.max_model_len * scheduler_config.max_num_seqs)

logger.info(
"Overriding configurations based on warmup shapes. "
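
Taken together, the hunks above amount to three overrides on the scheduler and cache config. Below is a condensed, self-contained sketch of that logic; FakeSchedulerConfig and apply_spyre_overrides are illustrative names rather than the plugin's API, and the real logic lives in check_and_update_config in vllm_spyre/platform.py.

# Condensed sketch of the config overrides described in the comment above;
# not the plugin's exact code.
from dataclasses import dataclass

@dataclass
class FakeSchedulerConfig:
    max_num_seqs: int
    max_num_batched_tokens: int = 2048  # typical vllm.v1-style default

def apply_spyre_overrides(max_model_len: int, cfg: FakeSchedulerConfig):
    # One "block" large enough to hold an entire sequence, so the base
    # scheduler never has to account for paged-attention blocks.
    block_size = max_model_len
    # The V1 scheduler actually needs 2 blocks per sequence, so provision
    # the block pool accordingly.
    num_blocks = 2 * cfg.max_num_seqs
    # New in this PR: token budget for a full batch of full-length requests,
    # so a whole static batch can be scheduled in a single step.
    cfg.max_num_batched_tokens = max_model_len * cfg.max_num_seqs
    return block_size, num_blocks

cfg = FakeSchedulerConfig(max_num_seqs=4)
block_size, num_blocks = apply_spyre_overrides(max_model_len=2048, cfg=cfg)
assert (block_size, num_blocks) == (2048, 8)
assert cfg.max_num_batched_tokens == 2048 * 4
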