12 changes: 9 additions & 3 deletions tests/e2e/test_spyre_basic.py
@@ -40,6 +40,7 @@ def test_output(
warmup_shape: tuple[int, int, int],
backend: str,
vllm_version: str,
monkeypatch: pytest.MonkeyPatch,
) -> None:
'''
The warmup is based on a single shape. After the warmup,
@@ -71,7 +72,8 @@ def test_output(
sampling_params=vllm_sampling_params,
tensor_parallel_size=1,
backend=backend,
vllm_version=vllm_version)
vllm_version=vllm_version,
monkeypatch=monkeypatch)

hf_results = generate_hf_output(model=model,
prompts=prompts,
@@ -101,6 +103,7 @@ def test_output_sendnn_decoder(
warmup_shape: tuple[int, int, int],
backend: str,
vllm_version: str,
monkeypatch: pytest.MonkeyPatch,
) -> None:
'''
Tests the deprecated sendnn_decoder backend, which should fall-back to
@@ -124,7 +127,8 @@ def test_output_sendnn_decoder(
sampling_params=vllm_sampling_params,
tensor_parallel_size=1,
backend=backend,
vllm_version=vllm_version)
vllm_version=vllm_version,
monkeypatch=monkeypatch)

hf_results = generate_hf_output(model=model,
prompts=prompts,
@@ -146,6 +150,7 @@ def test_batch_handling(
model: str,
backend: str,
vllm_version: str,
monkeypatch: pytest.MonkeyPatch,
):
"""Test that the spyre worker correctly handles batches of requests that
finish after different numbers of forward passes"""
@@ -179,7 +184,8 @@ def test_batch_handling(
sampling_params=vllm_sampling_params,
tensor_parallel_size=1,
backend=backend,
vllm_version=vllm_version)
vllm_version=vllm_version,
monkeypatch=monkeypatch)

assert vllm_results[0]["text"] == " 3 2 "
assert vllm_results[1]["text"] == " 6 5 4 3 2 "
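Across this file the change is mechanical: each test signature gains a monkeypatch: pytest.MonkeyPatch parameter and forwards it to generate_spyre_vllm_output. Below is a minimal sketch of what the updated helper in tests/spyre_util.py presumably does with it; the simplified signature and the environment variable names (VLLM_SPYRE_DYNAMO_BACKEND, VLLM_USE_V1) are assumptions for illustration, not taken from this diff.

import pytest
from vllm import LLM

def generate_spyre_vllm_output(model, prompts, sampling_params,
                               tensor_parallel_size, backend, vllm_version,
                               monkeypatch: pytest.MonkeyPatch, **kwargs):
    # Setting env vars through monkeypatch (rather than os.environ) means
    # they are reverted at test teardown, so one test's backend/engine
    # selection cannot leak into later tests in the session.
    monkeypatch.setenv("VLLM_SPYRE_DYNAMO_BACKEND", backend)  # assumed name
    monkeypatch.setenv("VLLM_USE_V1", "1" if vllm_version == "V1" else "0")
    llm = LLM(model=model,
              max_model_len=kwargs.get("max_model_len", 2048),
              tensor_parallel_size=tensor_parallel_size)
    outputs = llm.generate(prompts, sampling_params)
    return [{"text": out.outputs[0].text} for out in outputs]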
75 changes: 55 additions & 20 deletions tests/e2e/test_spyre_cb.py
@@ -8,16 +8,46 @@
from typing import Any

import pytest
<<<<<<< HEAD

Check failure on line 11 (GitHub Actions / lint-code (3.12), Ruff): tests/e2e/test_spyre_cb.py:11: SyntaxError: Expected a statement (reported at columns 1, 3, 5, and 7)
from spyre_util import (compare_results, create_random_request,
generate_hf_output, generate_spyre_vllm_output,
get_spyre_model_list)
=======

Check failure on line 15 (GitHub Actions / lint-code (3.12), Ruff): tests/e2e/test_spyre_cb.py:15: SyntaxError: Expected a statement (reported at columns 1, 3, 5, and 7)
from spyre_util import (create_random_request, generate_cb_spyre_vllm_output,

Check failure on line 16 (GitHub Actions / lint-code (3.12), Ruff): tests/e2e/test_spyre_cb.py:15:8: SyntaxError: Expected a statement
get_spyre_backend_list, get_spyre_model_list)
>>>>>>> origin/main

Check failure on line 18 (GitHub Actions / lint-code (3.12), Ruff): tests/e2e/test_spyre_cb.py:18:1: SyntaxError: Expected a statement
from vllm import EngineArgs, SamplingParams
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core import EngineCore
from vllm.v1.executor.abstract import Executor

from vllm_spyre.v1.core.scheduler import ContinuousBatchingSpyreScheduler

template = (
"Below is an instruction that describes a task. Write a response that "
"appropriately completes the request. Be polite in your response to the "
"user.\n\n### Instruction:\n{}\n\n### Response:")

<<<<<<< HEAD

@pytest.mark.cb
@pytest.mark.parametrize("max_num_seqs", [2, 3, 4],
ids=lambda val: f"max_num_seqs({val})")
@pytest.mark.parametrize("model", get_spyre_model_list())
@pytest.mark.parametrize(
"backend", [pytest.param("eager", marks=pytest.mark.cpu, id="eager")])
# commenting v1 since we don't want this test to run with v1 marker yet
# @pytest.mark.v1
@pytest.mark.parametrize("prompts", [[
template.format("Provide a list of instructions "
"for preparing chicken soup."),
template.format("Provide me a list of things that I can do with my "
"new found wealth."),
template.format(
"how do I add multiple new columns in m for power query or power bi?"),
template.format("Convert char to string in Java."),
]])
=======
@pytest.mark.cb
@pytest.mark.v1
@pytest.mark.parametrize("max_num_seqs", [2, 3, 4],
@@ -45,6 +75,7 @@
],
ids=lambda val: f"num_prompts({len(val)})",
)
>>>>>>> origin/main
def test_cb_handling(
model: str,
backend: str,
@@ -56,16 +87,17 @@
continuous batches of requests that
finish after different numbers of forward passes"""

vllm_sampling_params = SamplingParams(max_tokens=20,
max_tokens = 20

vllm_sampling_params = SamplingParams(max_tokens=max_tokens,
temperature=0,
stop="1",
ignore_eos=True,
logprobs=0)

# Ensure that both:
# - The model doesn't crash
# - The output sequences are correct
vllm_results = generate_cb_spyre_vllm_output(
vllm_results = generate_spyre_vllm_output(
model=model,
prompts=prompts,
max_model_len=2048,
@@ -74,30 +106,33 @@
tensor_parallel_size=1,
backend=backend,
max_num_seqs=max_num_seqs,
use_cb=1,
monkeypatch=monkeypatch,
)
use_cb=True,
vllm_version="V1", # CB runs in V1 only
monkeypatch=monkeypatch)

hf_results = generate_hf_output(model=model,
prompts=prompts,
max_new_tokens=max_tokens)

for i, prompt in enumerate(prompts):
assert (vllm_results[i]["text"] == [
" " + " ".join(
str(i)
for i in range(int(prompt.split()[-1]) - 1, 1, -1)) + " "
][0])
compare_results(model=model,
prompts=prompts,
warmup_shapes=[],
tensor_parallel_size=1,
backend=backend,
vllm_results=vllm_results,
hf_results=hf_results)


@pytest.mark.cb
# @pytest.mark.v1
@pytest.mark.parametrize("max_num_seqs", [2])
@pytest.mark.parametrize("model", get_spyre_model_list())
@pytest.mark.parametrize(
"backend", [pytest.param("eager", marks=pytest.mark.cpu, id="eager")])
@pytest.mark.parametrize("cb",
[pytest.param(1, marks=pytest.mark.cb, id="cb")])
# @pytest.mark.v1
def test_cb_max_tokens(
model: str,
backend: str,
max_num_seqs: int,
cb: int,
monkeypatch: pytest.MonkeyPatch,
):
"""Test that continuous batches of requests that
@@ -114,7 +149,7 @@
logprobs=0)

with pytest.raises(ValueError, match="max model context length"):
generate_cb_spyre_vllm_output(
generate_spyre_vllm_output(
model=model,
prompts=overflow_prompt,
max_model_len=max_model_len,
@@ -123,9 +158,9 @@
tensor_parallel_size=1,
backend=backend,
max_num_seqs=max_num_seqs,
use_cb=cb,
monkeypatch=monkeypatch,
)
use_cb=True,
vllm_version="V1", # CB runs in V1 only
monkeypatch=monkeypatch)


def get_params_test_blocks_borders_aligned_prompts():
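Once the conflict above is resolved, the net change in this file is that the dedicated generate_cb_spyre_vllm_output helper disappears: continuous-batching tests call the shared generate_spyre_vllm_output with use_cb=True and vllm_version="V1" (continuous batching runs only on the V1 engine), and test_cb_handling validates against HuggingFace reference output via compare_results instead of hand-rolled string assertions. A sketch of how the flag might be honored inside the helper; the branch and the env var name are assumptions for illustration, not taken from this diff:

# Hypothetical wiring inside generate_spyre_vllm_output:
if use_cb:
    # Continuous batching is V1-only, so guard against mismatched arguments.
    assert vllm_version == "V1", "use_cb=True requires vllm_version='V1'"
    monkeypatch.setenv("VLLM_SPYRE_USE_CB", "1")  # assumed variable name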
4 changes: 3 additions & 1 deletion tests/e2e/test_spyre_max_new_tokens.py
@@ -35,6 +35,7 @@ def test_output(
warmup_shape: tuple[int, int, int],
backend: str,
vllm_version: str,
monkeypatch: pytest.MonkeyPatch,
) -> None:
'''
The warmup is based on a single shape. After the warmup,
@@ -87,7 +88,8 @@ def test_output(
sampling_params=vllm_sampling_params,
tensor_parallel_size=1,
backend=backend,
vllm_version=vllm_version)
vllm_version=vllm_version,
monkeypatch=monkeypatch)

hf_results = generate_hf_output(model=model,
prompts=prompts,
4 changes: 3 additions & 1 deletion tests/e2e/test_spyre_seed.py
@@ -31,6 +31,7 @@ def test_seed(
warmup_shape: tuple[int, int, int],
backend: str,
vllm_version: str,
monkeypatch: pytest.MonkeyPatch,
) -> None:
'''
The warmup is based on a single shape. After the warmup,
@@ -60,7 +61,8 @@ def test_seed(
sampling_params=vllm_sampling_params,
tensor_parallel_size=1,
backend=backend,
vllm_version=vllm_version)
vllm_version=vllm_version,
monkeypatch=monkeypatch)

# compare all generated outputs against the first generated output
for vllm_result in vllm_results:
4 changes: 3 additions & 1 deletion tests/e2e/test_spyre_tensor_parallel.py
@@ -32,6 +32,7 @@ def test_output(
tp_size: int,
backend: str,
vllm_version: str,
monkeypatch: pytest.MonkeyPatch,
) -> None:
'''
The warmup is based on one or multiple shapes. After the warmup,
@@ -65,7 +66,8 @@ def test_output(
sampling_params=vllm_sampling_params,
tensor_parallel_size=tp_size,
backend=backend,
vllm_version=vllm_version)
vllm_version=vllm_version,
monkeypatch=monkeypatch)

hf_results = generate_hf_output(model=model,
prompts=prompts,
8 changes: 6 additions & 2 deletions tests/e2e/test_spyre_warmup_shapes.py
@@ -34,6 +34,7 @@ def test_output(
warmup_shapes: list[tuple[int, int, int]],
backend: str,
vllm_version: str,
monkeypatch: pytest.MonkeyPatch,
) -> None:
'''
The warmup is based on two shapes, that 'overlap' each
@@ -72,7 +73,8 @@ def test_output(
sampling_params=vllm_sampling_params,
tensor_parallel_size=1,
backend=backend,
vllm_version=vllm_version)
vllm_version=vllm_version,
monkeypatch=monkeypatch)

hf_results = generate_hf_output(model=model,
prompts=prompts,
@@ -99,6 +101,7 @@ def test_invalid_prompt_len(
warmup_shapes: list[tuple[int, int, int]],
backend: str,
vllm_version: str,
monkeypatch: pytest.MonkeyPatch,
) -> None:
'''
Expects an error to be raised if the warmup prompt length
@@ -119,4 +122,5 @@
sampling_params=vllm_sampling_params,
tensor_parallel_size=1,
backend=backend,
vllm_version=vllm_version)
vllm_version=vllm_version,
monkeypatch=monkeypatch)
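The remaining files (test_spyre_max_new_tokens.py, test_spyre_seed.py, test_spyre_tensor_parallel.py, test_spyre_warmup_shapes.py) receive the same mechanical change: each test signature gains monkeypatch: pytest.MonkeyPatch and forwards it to the helper. The isolation this buys is standard pytest behavior, sketched below with an assumed variable name; nothing here is specific to this repository.

import os
import pytest

def test_env_change_is_scoped(monkeypatch: pytest.MonkeyPatch):
    # monkeypatch.setenv is undone automatically at teardown, so backend
    # settings chosen by one e2e test cannot leak into the next.
    monkeypatch.setenv("VLLM_SPYRE_DYNAMO_BACKEND", "eager")  # assumed name
    assert os.environ["VLLM_SPYRE_DYNAMO_BACKEND"] == "eager"

def test_env_restored_afterwards():
    # Runs after the test above; the variable is gone again (assuming it
    # was not already set in the outer environment).
    assert "VLLM_SPYRE_DYNAMO_BACKEND" not in os.environ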