10 changes: 3 additions & 7 deletions tests/conftest.py
@@ -46,12 +46,8 @@ def pytest_generate_tests(metafunc):
# When -m full_model is called, all tests tagged with
# full_model mark will be injected with these custom values
if metafunc.definition.get_closest_marker("full_model"):
_add_param(
"model",
["ibm-granite/granite-3.3-8b-instruct"],
metafunc,
existing_markers,
)
_add_param("model", get_spyre_model_list(full_size_models=True),
metafunc, existing_markers)
_add_param(
"backend",
["sendnn"],
@@ -316,4 +312,4 @@ def caplog_vllm_spyre(temporary_enable_log_propagate, caplog):

@pytest.fixture(scope="function", autouse=True)
def clear_env_cache():
envs.clear_env_cache()
envs.clear_env_cache()
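With this change, `pytest -m full_model` injects whatever `get_spyre_model_list(full_size_models=True)` returns instead of a hardcoded model name. A minimal sketch of a test that would pick these values up (the test itself is hypothetical; the `full_model` marker and the `model`/`backend` parameters come from the conftest.py hook above):

```python
# Hypothetical test module; assumes the full_model marker and the model/backend
# parameters injected by pytest_generate_tests in tests/conftest.py.
import pytest


@pytest.mark.full_model
def test_full_size_decoder(model, backend):
    # Under `pytest -m full_model`, `model` is parametrized with the entries
    # returned by get_spyre_model_list(full_size_models=True) (currently the
    # ibm-granite/granite-3.3-8b-instruct ModelInfo) and `backend` with "sendnn".
    assert model is not None
    assert backend == "sendnn"
```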
8 changes: 6 additions & 2 deletions tests/llm_cache_util.py
@@ -183,8 +183,12 @@ def _get_model(item) -> str:
for key in MODEL_KEYS:
if key in params:
SortKey._assert_param(isinstance(params[key], str | ModelInfo),
"model must be a string", item)
return params[key]
"model must be a string or ModelInfo",
item)
model_or_info = params[key]
if isinstance(model_or_info, ModelInfo):
return model_or_info.name
return model_or_info
# No assumption about default model, we likely don't need an llm if this
# isn't set
return ""
46 changes: 30 additions & 16 deletions tests/spyre_util.py
@@ -217,12 +217,14 @@ def get_spyre_backend_list():
# get model names from env, if not set then use default models for each type.
# Multiple models can be specified with a comma separated list in
# VLLM_SPYRE_TEST_MODEL_LIST
def get_spyre_model_list(isEmbeddings=False, isScoring=False):
def get_spyre_model_list(isEmbeddings=False,
isScoring=False,
full_size_models=False):
"""Returns a list of pytest.params. The values are NamedTuples with a name
and revision field."""
user_test_model_list = os.environ.get("VLLM_SPYRE_TEST_MODEL_LIST")
if not user_test_model_list:
return _default_test_models(isEmbeddings, isScoring)
return _default_test_models(isEmbeddings, isScoring, full_size_models)

# User overridden model list
spyre_model_dir_path = get_spyre_model_dir_path()
@@ -243,7 +245,9 @@ def get_spyre_model_list(isEmbeddings=False, isScoring=False):
return test_model_list


def _default_test_models(isEmbeddings=False, isScoring=False):
def _default_test_models(isEmbeddings=False,
isScoring=False,
full_size_models=False):
"""Return the default set of test models as pytest parameterizations"""
if isEmbeddings:
model = ModelInfo(name="sentence-transformers/all-roberta-large-v1",
@@ -263,20 +267,30 @@ def _default_test_models(isEmbeddings=False, isScoring=False):
# We run tests for both the full-precision bf16 and fp8-quantized models,
# but by default the `pytest.mark.quantized` marker is de-selected unless
# the test command includes `-m quantized`.
tinygranite = ModelInfo(
name="ibm-ai-platform/micro-g3.3-8b-instruct-1b",
revision="6e9c6465a9d7e5e9fa35004a29f0c90befa7d23f")
tinygranite_fp8 = ModelInfo(
name="ibm-ai-platform/micro-g3.3-8b-instruct-1b-FP8",
revision="0dff8bacb968836dbbc7c2895c6d9ead0a05dc9e",
is_quantized=True)
if not full_size_models:
tinygranite = ModelInfo(
name="ibm-ai-platform/micro-g3.3-8b-instruct-1b",
revision="6e9c6465a9d7e5e9fa35004a29f0c90befa7d23f")
tinygranite_fp8 = ModelInfo(
name="ibm-ai-platform/micro-g3.3-8b-instruct-1b-FP8",
revision="0dff8bacb968836dbbc7c2895c6d9ead0a05dc9e",
is_quantized=True)
params = [
pytest.param(tinygranite,
marks=[pytest.mark.decoder],
id=tinygranite.name),
pytest.param(tinygranite_fp8,
marks=[pytest.mark.decoder, pytest.mark.quantized],
id=tinygranite_fp8.name)
]
return params

# Full sized decoders
# The granite 8b fp8 model is not publicly available yet
granite = ModelInfo(name="ibm-granite/granite-3.3-8b-instruct",
revision="51dd4bc2ade4059a6bd87649d68aa11e4fb2529b")
params = [
pytest.param(tinygranite,
marks=[pytest.mark.decoder],
id=tinygranite.name),
pytest.param(tinygranite_fp8,
marks=[pytest.mark.decoder, pytest.mark.quantized],
id=tinygranite_fp8.name)
pytest.param(granite, marks=[pytest.mark.decoder], id=granite.name),
]
return params
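For reference, a short sketch of how the two default selection paths are exercised (assuming the tests directory is importable and `VLLM_SPYRE_TEST_MODEL_LIST` is unset; the variable names are illustrative):

```python
# Illustrative usage only; the real caller for the full-size path is the
# pytest_generate_tests hook in tests/conftest.py.
from spyre_util import get_spyre_model_list

# Default path: the micro Granite decoders (bf16 plus the fp8 variant, which
# stays de-selected unless the run includes `-m quantized`).
tiny_params = get_spyre_model_list()

# New full-size path: granite-3.3-8b-instruct, selected by `-m full_model` runs.
full_params = get_spyre_model_list(full_size_models=True)

# A comma separated VLLM_SPYRE_TEST_MODEL_LIST still overrides both defaults.
```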
