Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 0 additions & 33 deletions tests/e2e/test_spyre_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,39 +64,6 @@ def test_output(model: ModelInfo, tp_size: int, backend: str, cb: int,
**kwargs)


@pytest.mark.parametrize(
    "backend",
    [
        pytest.param("sendnn_decoder",
                     marks=pytest.mark.spyre,
                     id="sendnn_decoder"),
    ],
)
def test_output_sendnn_decoder(model: ModelInfo,
                               warmup_shapes: DecodeWarmupShapes, backend: str,
                               monkeypatch: pytest.MonkeyPatch,
                               use_llm_cache) -> None:
    """
    Exercise the deprecated ``sendnn_decoder`` backend, which is expected to
    fall back to ``sendnn``, by comparing vLLM output against HF output for a
    single prompt.
    """
    # The second entry of the first warmup shape is used as the number of
    # tokens to generate
    first_shape = warmup_shapes[0]
    max_new_tokens = first_shape[1]

    single_prompt = get_chicken_soup_prompts(1)

    greedy_params = SamplingParams(
        max_tokens=max_new_tokens,
        temperature=0,
        logprobs=0,  # return logprobs of generated tokens only
        ignore_eos=True,
    )

    validate_vllm_vs_hf_output(
        model=model,
        prompts=single_prompt,
        warmup_shapes=warmup_shapes,
        max_model_len=2048,
        sampling_params=greedy_params,
        tensor_parallel_size=1,
        backend=backend,
        monkeypatch=monkeypatch,
        max_new_tokens=max_new_tokens,
    )


def test_batch_handling(model: ModelInfo, backend: str, cb: int, warmup_shapes,
max_num_seqs: int, max_model_len: int,
monkeypatch: pytest.MonkeyPatch, use_llm_cache):
Expand Down
50 changes: 50 additions & 0 deletions tests/utils/test_envs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""Tests for our environment configs"""

import os

import pytest

from vllm_spyre import envs

pytestmark = pytest.mark.cpu


def test_env_vars_are_cached(monkeypatch):
    """A setting read through ``envs`` keeps its first value on later reads."""
    setting = "VLLM_SPYRE_NUM_CPUS"

    monkeypatch.setenv(setting, "42")
    first_read = envs.VLLM_SPYRE_NUM_CPUS
    assert first_read == 42

    # Later reads do not go back to the environment, so changing the
    # variable now must not change what ``envs`` reports
    monkeypatch.setenv(setting, "77")
    second_read = envs.VLLM_SPYRE_NUM_CPUS
    assert second_read == 42


def test_env_vars_override(monkeypatch):
    """``envs.override`` changes both the cached value and ``os.environ``."""
    setting = "VLLM_SPYRE_NUM_CPUS"

    monkeypatch.setenv(setting, "42")
    assert envs.VLLM_SPYRE_NUM_CPUS == 42

    # override() writes the raw string to the environment and refreshes the
    # cached, parsed value in one step
    envs.override(setting, "77")

    cached_value = envs.VLLM_SPYRE_NUM_CPUS
    raw_value = os.getenv(setting)
    assert cached_value == 77
    assert raw_value == "77"


def test_env_vars_override_with_bad_value(monkeypatch):
    """``envs.override`` raises when the new value cannot be parsed."""
    setting = "VLLM_SPYRE_NUM_CPUS"

    monkeypatch.setenv(setting, "42")
    assert envs.VLLM_SPYRE_NUM_CPUS == 42

    # The override parses the value straight away, so a non-numeric string
    # surfaces the int() conversion error to the caller
    expect_parse_failure = pytest.raises(ValueError,
                                         match=r"invalid literal for int")
    with expect_parse_failure:
        envs.override(setting, "notanumber")


def test_env_vars_override_for_invalid_config():
    """Overriding a name that is not a known setting raises ``ValueError``."""
    unknown_setting = "VLLM_SPYRE_NOT_A_CONFIG"

    expect_unknown_setting = pytest.raises(ValueError,
                                           match=r"not a known setting")
    with expect_unknown_setting:
        envs.override(unknown_setting, "nothing")


def test_sendnn_decoder_backwards_compat(monkeypatch):
    """The deprecated ``sendnn_decoder`` backend name resolves to ``sendnn``."""
    # Review remark left on this test by a collaborator:
    #   "🌶️ 🌶️ Test simplification"

    # Configuring the deprecated `sendnn_decoder` backend will swap to the
    # new `sendnn` backend instead
    monkeypatch.setenv("VLLM_SPYRE_DYNAMO_BACKEND", "sendnn_decoder")

    resolved_backend = envs.VLLM_SPYRE_DYNAMO_BACKEND
    assert resolved_backend == "sendnn"
13 changes: 3 additions & 10 deletions vllm_spyre/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,10 @@

def override(name: str, value: str) -> None:
    """Set a known setting to ``value`` in the environment and in the cache.

    The raw string is written to ``os.environ[name]`` and the parser
    registered for ``name`` in ``environment_variables`` is then called to
    produce the value stored in ``_cache[name]``.

    If the parser rejects the new value, the environment variable is put
    back to what it was before the call (or removed if it was unset) and the
    cache is left untouched, so a failed override never leaves the process
    environment and the cache disagreeing with each other.

    Args:
        name: Name of the setting; must be a key of ``environment_variables``.
        value: Raw string value to store in the environment.

    Raises:
        ValueError: If ``name`` is not a known setting.
        Exception: Whatever the registered parser raises for an unparsable
            value (e.g. ``ValueError`` from ``int()``) is re-raised as-is.
    """
    if name not in environment_variables:
        raise ValueError(f"The variable {name} is not a known "
                         "setting and cannot be overridden")

    previous_value = os.environ.get(name)
    os.environ[name] = value
    try:
        # The parsers read from os.environ, so the new value has to be in
        # place before the parser is called
        parsed_value = environment_variables[name]()
    except Exception:
        # Roll back so the rejected string does not linger in the environment
        if previous_value is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = previous_value
        raise

    _cache[name] = parsed_value


def clear_env_cache():
Expand Down