9 changes: 6 additions & 3 deletions tests/e2e/test_spyre_warmup_shapes.py
@@ -3,13 +3,19 @@
 Run `python -m pytest tests/test_spyre_warmup_shapes.py`.
 """
 
+import os
+
 import pytest
 from spyre_util import (compare_results, generate_hf_output,
                         generate_spyre_vllm_output, get_spyre_backend_list,
                         get_spyre_model_list)
 from vllm import SamplingParams
 
 
+# temporary for filtering until bug with caching gets fixed
+@pytest.mark.skipif(
+    os.environ.get("TORCH_SENDNN_CACHE_ENABLE") == "1",
+    reason="torch_sendnn caching is currently broken with this configuration")
 @pytest.mark.parametrize("model", get_spyre_model_list())
 @pytest.mark.parametrize("prompts", [
     7 * [
@@ -34,7 +40,6 @@ def test_output(
     warmup_shapes: list[tuple[int, int, int]],
     backend: str,
     vllm_version: str,
-    monkeypatch: pytest.MonkeyPatch,
 ) -> None:
     '''
     The warmup is based on two shapes, that 'overlap' each
@@ -55,8 +60,6 @@ def test_output(
     test using 'pytest --capture=no tests/spyre/test_spyre_warmup_shapes.py'
     After debugging, DISABLE_ASSERTS should be reset to 'False'.
     '''
-    # temporary until bug with caching gets fixed
-    monkeypatch.setenv("TORCH_SENDNN_CACHE_ENABLE", "0")
 
     max_new_tokens = max([t[1] for t in warmup_shapes])
 
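In summary, the diff stops forcing the cache off inside each test via monkeypatch.setenv("TORCH_SENDNN_CACHE_ENABLE", "0") and instead skips the whole test when caching is enabled in the environment. Below is a minimal, self-contained sketch of that env-var-gated skip pattern; the test name and body are hypothetical and only illustrate the mechanism, they are not part of the PR.

import os

import pytest


# Same skip condition as in the PR: if TORCH_SENDNN_CACHE_ENABLE=1 is set in
# the environment, pytest reports the test as skipped instead of the test
# silently overriding the value.
@pytest.mark.skipif(
    os.environ.get("TORCH_SENDNN_CACHE_ENABLE") == "1",
    reason="torch_sendnn caching is currently broken with this configuration")
def test_example_with_cache_disabled() -> None:
    # Hypothetical body: only runs when the caching env var is unset or not "1".
    assert os.environ.get("TORCH_SENDNN_CACHE_ENABLE") != "1"

Running TORCH_SENDNN_CACHE_ENABLE=1 pytest ... would then show the test as skipped with the reason above, rather than executing it with caching forcibly disabled.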