Merged

Commits (35)
3691629
Manage supported model configurations
ckadner Sep 5, 2025
1b80333
Reorganize import statements
ckadner Sep 5, 2025
3d8ad48
Lint docs
ckadner Sep 5, 2025
c8ce7da
use 'x86_64' instead of 'amd64'
ckadner Sep 5, 2025
4decb48
typecheck updates
ckadner Sep 5, 2025
09bde9e
more typecheck updates
ckadner Sep 5, 2025
564d9b0
run isort with suggested changes
ckadner Sep 5, 2025
4223ddf
reorganize imports as isort wants them
ckadner Sep 5, 2025
dd6bd49
CI: isort show suggested import changes
ckadner Sep 5, 2025
ca46ba8
update comments in config YAML
ckadner Sep 5, 2025
214a5ce
yapf
ckadner Sep 5, 2025
d67d7b1
run type-check with Python 3.10 by default
ckadner Sep 5, 2025
03c9c76
revert unrelated changes
ckadner Sep 5, 2025
5a67e9a
Merge branch 'main' into model_configs
ckadner Sep 5, 2025
6995cad
Merge branch 'main' into model_configs
ckadner Sep 8, 2025
12bb213
address review comments, add tests
ckadner Sep 22, 2025
654f480
Merge branch 'main' into model_configs
ckadner Sep 24, 2025
566ac37
lint
ckadner Sep 24, 2025
790de2f
remove f-strings from logging statements
ckadner Sep 24, 2025
de4544e
yapf is ruff
ckadner Sep 24, 2025
dca59ba
type-check
ckadner Sep 24, 2025
8a1205b
update supported configs
ckadner Sep 25, 2025
763a112
update supported parameters
ckadner Sep 25, 2025
b2f8649
Merge branch 'main' into model_configs
ckadner Sep 29, 2025
cbb7a1b
assert c.warmup_shapes is None if use_cb
ckadner Sep 30, 2025
cc0a393
update list of supported models
ckadner Sep 30, 2025
3f48a91
requested config `<=` supported config
ckadner Sep 30, 2025
c65aa9e
Validate that prompt + new_tokens <= max_model_len
ckadner Sep 30, 2025
437290a
type-check
ckadner Sep 30, 2025
4f7a804
remove option to error out on unsupported/unknown configuration
ckadner Oct 1, 2025
3970d84
remove configurations that are within the upper bound of another config
ckadner Oct 4, 2025
05405f4
verify config parameters adhere to restrictions
ckadner Oct 4, 2025
82ccf41
Merge branch 'main' into model_configs
ckadner Oct 8, 2025
1caae76
determine model from HF-config (config.json)
ckadner Oct 11, 2025
7f254cf
Merge branch 'main' into model_configs
ckadner Oct 15, 2025
11 changes: 11 additions & 0 deletions docs/user_guide/supported_models.md
@@ -42,3 +42,14 @@ configurations.
[Granite-Embedding-278m (Multilingual)]: https://huggingface.co/ibm-granite/granite-embedding-278m-multilingual
[BAAI/BGE-Reranker (v2-m3)]: https://huggingface.co/BAAI/bge-reranker-v2-m3
[BAAI/BGE-Reranker (Large)]: https://huggingface.co/BAAI/bge-reranker-large

## Runtime Validation

At runtime, the Spyre engine validates the requested model and configuration against the list
of supported models and configurations defined in
<gh-file:vllm_spyre/config/supported_configurations.yaml>. If the requested model or
configuration is not found in that list, a warning is logged.

```yaml
--8<-- "vllm_spyre/config/supported_configurations.yaml:supported-model-runtime-configurations"
```
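
For reference, here is a minimal sketch of invoking the new validator directly, assuming `vllm_spyre` is installed and the backend environment variables are set before the call; the model and values mirror a continuous-batching entry from `supported_configurations.yaml`:

```python
import os

# Validation only runs for the "sendnn" backend; any other backend is
# bypassed with an informational log message.
os.environ["VLLM_SPYRE_DYNAMO_BACKEND"] = "sendnn"
os.environ["VLLM_SPYRE_USE_CB"] = "1"  # continuous batching mode

from vllm_spyre.config.runtime_config_validator import (
    validate_runtime_configuration)

# Returns True if the requested configuration matches a supported entry,
# False (plus a logged warning) otherwise.
ok = validate_runtime_configuration("ibm-granite/granite-3.3-8b-instruct",
                                    tp_size=4,
                                    max_model_len=8192,
                                    max_num_seqs=4)
print("supported" if ok else "not supported")
```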
131 changes: 131 additions & 0 deletions tests/utils/test_model_config_validator.py
@@ -0,0 +1,131 @@
import logging

import pytest
import yaml
from pytest import LogCaptureFixture

from vllm_spyre.config import runtime_config_validator
from vllm_spyre.config.runtime_config_validator import (
validate_runtime_configuration as validate)


def setup_log_capture(caplog: LogCaptureFixture, level=logging.INFO):
"""
Setup log capture for the test.
"""
caplog.set_level(level)
if caplog.handler not in runtime_config_validator.logger.handlers:
runtime_config_validator.logger.addHandler(caplog.handler)


@pytest.mark.utils
@pytest.mark.cpu
def test_no_eager_validation(monkeypatch, caplog):
"""
Ensure that model runtime config validation is skipped when not running on
Spyre cards.
"""
setup_log_capture(caplog, level=logging.INFO)
with monkeypatch.context() as m:
m.setenv("VLLM_SPYRE_DYNAMO_BACKEND", "eager")
validate("test/model")
assert "validation bypassed" in caplog.text


@pytest.mark.utils
@pytest.mark.cpu
def test_model_not_supported(monkeypatch, caplog):
"""
Ensure we can run model runtime config validation when (pretending to) run
on Spyre cards.
"""
setup_log_capture(caplog, level=logging.INFO)
with monkeypatch.context() as m:
m.setenv("VLLM_SPYRE_DYNAMO_BACKEND", "sendnn")
validate("test/model")
assert "Model 'test/model' is not supported" in caplog.text


@pytest.mark.utils
@pytest.mark.cpu
def test_model_runtime_configurations_file_is_valid(monkeypatch, caplog):
"""
Validate that prompts are multiples of 64
Validate that prompt + new_tokens <= max_model_len
Validate that the batch size is <= a tested upper bound.
"""
setup_log_capture(caplog, level=logging.INFO)
with monkeypatch.context() as m:
m.setenv("VLLM_SPYRE_DYNAMO_BACKEND", "sendnn")
validate("test/model") # ensure configs got loaded
mrcs = runtime_config_validator.model_runtime_configs
assert len(mrcs) > 0
for mrc in mrcs:
for c in mrc.configs:
assert c.tp_size in [1, 2, 4, 8, 16, 32]
if c.cb:
assert c.warmup_shapes is None
assert c.max_model_len % 64 == 0
assert c.max_model_len <= 32 * 1024
assert c.max_num_seqs <= 32
else:
assert c.warmup_shapes is not None
for ws in c.warmup_shapes:
assert ws[0] % 64 == 0
assert ws[0] <= 32 * 1024
assert ws[2] in [1, 2, 4, 8, 16, 32, 64]


@pytest.mark.utils
@pytest.mark.cpu
def test_model_runtime_configurations(monkeypatch, caplog):
"""
Verify that various example model runtime configurations can get validated
against a small list of sample configurations.
"""
test_configs = yaml.safe_load("""
- model: "test/model"
configs: [
{ cb: True, tp_size: 1, max_model_len: 1024, max_num_seqs: 16 },
{ cb: True, tp_size: 4, max_model_len: 2048, max_num_seqs: 8 },
{ cb: True, tp_size: 4, max_model_len: 4096, max_num_seqs: 4 },
{ cb: True, tp_size: 4, max_model_len: 8192, max_num_seqs: 2 },
{ cb: False, tp_size: 1, warmup_shapes: [[64, 20, 4], [128, 20, 2]] },
{ cb: False, tp_size: 1, warmup_shapes: [[256, 20, 1]] },
{ cb: False, tp_size: 2, warmup_shapes: [[64, 20, 4]] },
]
""")
runtime_config_validator.initialize_supported_configurations(test_configs)

setup_log_capture(caplog, level=logging.INFO)

with monkeypatch.context() as m:
m.setenv("VLLM_SPYRE_DYNAMO_BACKEND", "sendnn")
m.setenv("VLLM_SPYRE_USE_CB", "1")
assert validate("test/model", 4, 2048, 8)
assert not validate("model/test", 4, 2048, 8)
# assert that individual values of a requested config can be less than
# the upper bound of a supported config
assert validate("test/model", 4, 1024, 8)
assert validate("test/model", 4, 2048, 4)

with monkeypatch.context() as m:
m.setenv("VLLM_SPYRE_DYNAMO_BACKEND", "sendnn")
m.setenv("VLLM_SPYRE_USE_CB", "0")
assert validate("test/model", 1, warmup_shapes=[[64, 20, 4]])
assert validate("test/model", 1, warmup_shapes=[[128, 20, 2]])
assert validate("test/model",
1,
warmup_shapes=[[64, 20, 4], [128, 20, 2]])
assert validate("test/model",
1,
warmup_shapes=[[128, 20, 2], [64, 20, 4]])
assert validate("test/model", 1, warmup_shapes=[[128, 19, 2]])
assert validate("test/model", 2, warmup_shapes=[[64, 19, 4]])
assert validate("test/model", 2, warmup_shapes=[[64, 19, 2]])
assert not validate(
"test/model", 2, warmup_shapes=[[64, 20, 4], [128, 20, 2]])
assert not validate("test/model",
1,
warmup_shapes=[[64, 20, 4], [128, 20, 2],
[256, 20, 1]])
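
Note that the new tests carry the `utils` and `cpu` pytest markers, so they can be selected and run without Spyre hardware.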
Empty file added vllm_spyre/config/__init__.py
Empty file.
191 changes: 191 additions & 0 deletions vllm_spyre/config/runtime_config_validator.py
@@ -0,0 +1,191 @@
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

import yaml
from vllm.logger import init_logger

from vllm_spyre import envs as envs_spyre

_config_file = Path(__file__).parent / "supported_configurations.yaml"

logger = init_logger(__name__)

# warmup_shape = [prompt_length, new_tokens, batch_size]
WarmupShapes = list[tuple[int, int, int]] | list[list[int]]


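# Note: `order=True` generates comparison methods that compare instances as
# tuples of their fields in declaration order (cb, tp_size, max_model_len,
# max_num_seqs); `warmup_shapes` is excluded via `compare=False` and is
# matched separately by is_warmup_shapes_supported().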
@dataclass(order=True)
class RuntimeConfiguration:
cb: bool = False
tp_size: int = 1
max_model_len: int = 0
max_num_seqs: int = 0
warmup_shapes: WarmupShapes | None = field(compare=False, default=None)

def __post_init__(self):
if self.warmup_shapes is not None:
self.warmup_shapes = [(ws[0], ws[1], ws[2])
if isinstance(ws, list) else ws
for ws in self.warmup_shapes] # yapf: disable


@dataclass
class ModelRuntimeConfiguration:
model: str
configs: list[RuntimeConfiguration] | None = None
ignore: bool = False

def __post_init__(self):
self.configs = [
RuntimeConfiguration(**cfg) if isinstance(cfg, dict) else cfg
for cfg in self.configs or []
]


model_runtime_configs: list[ModelRuntimeConfiguration] | None = None
ignored_models: set[str] = set()
runtime_configs_by_model: dict[str, list[RuntimeConfiguration]] = {}


def load_config_yaml() -> list[dict[str, Any]]:
with open(_config_file, encoding="utf-8") as f:
yaml_data = yaml.safe_load(f)
return yaml_data


def initialize_supported_configurations(yaml_data: list[dict[str, Any]]):
global model_runtime_configs, ignored_models, runtime_configs_by_model
model_runtime_configs = [
ModelRuntimeConfiguration(**config_dict) for config_dict in yaml_data
]
ignored_models = {mrc.model for mrc in model_runtime_configs if mrc.ignore}
runtime_configs_by_model = {
mrc.model: mrc.configs or []
for mrc in model_runtime_configs if not mrc.ignore
}


def initialize_supported_configurations_from_file():
yaml_data = load_config_yaml()
initialize_supported_configurations(yaml_data)


def validate_runtime_configuration(
model: str,
tp_size: int = 0,
max_model_len: int = 0,
max_num_seqs: int = 0,
warmup_shapes: WarmupShapes | None = None) -> bool:
"""
Verify if the requested model and configuration is supported by comparing
the requested configuration to all the supported configurations.
"""
# we only validate runtime configurations when running on Spyre cards
if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND != "sendnn":
logger.info(
"Model and runtime configuration validation bypassed for"
" backend '%s'", envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND)
return True

global model_runtime_configs
if model_runtime_configs is None:
initialize_supported_configurations_from_file()

if model in ignored_models:
logger.info("Model '%s' is ignored", model)
return True

if model not in runtime_configs_by_model:
logger.warning("Model '%s' is not supported", model)
return False

use_cb = envs_spyre.VLLM_SPYRE_USE_CB

requested_config = RuntimeConfiguration(
cb=use_cb,
tp_size=tp_size,
max_model_len=max_model_len if use_cb else 0,
max_num_seqs=max_num_seqs if use_cb else 0,
warmup_shapes=warmup_shapes if not use_cb else None)

supported_configs = runtime_configs_by_model.get(model, [])

matching_configs: list[RuntimeConfiguration] = list(
filter(
lambda supported_config: is_requested_config_supported(
requested_config=requested_config,
supported_config=supported_config),
supported_configs,
))

if len(matching_configs) == 0:
logger.warning(
"The requested configuration is not supported for"
" model '%s': %s", model, str(requested_config))
return False
else:
logger.info(
"The requested configuration is supported for"
" model '%s': %s", model, str(requested_config))
return True


def is_requested_config_supported(
requested_config: RuntimeConfiguration,
supported_config: RuntimeConfiguration) -> bool:
"""
Check if the requested configuration is supported by comparing the requested
configuration to all the supported configurations.
"""
# Don't use `if requested_configuration not in supported_configurations:...`
# since warmup shapes don't compare easily (excluded from dataclass __eq__)
# Instead, use filter here and do a set-compare for warmup_shapes separately
return (requested_config.cb == supported_config.cb
and requested_config <= supported_config
and (requested_config.cb or is_warmup_shapes_supported(
requested_config, supported_config)))


def is_warmup_shapes_supported(requested_config: RuntimeConfiguration,
supported_config: RuntimeConfiguration) -> bool:
"""
Check if the requested warmup_shapes are a subset of the supported
warmup_shapes. If a singular warmup_shape is requested, check
if its context length is less than or equal to the context length of a
supported warmup_shapes with the same (or larger) batch size.
"""
requested_shapes = requested_config.warmup_shapes or []
supported_shapes = supported_config.warmup_shapes or []
return (set(requested_shapes).issubset(set(supported_shapes))
or is_context_length_supported(requested_shapes, supported_shapes))


def is_context_length_supported(requested_shapes: WarmupShapes,
supported_shapes: WarmupShapes) -> bool:
"""
If a singular warmup_shape is requested, check if it's context length is
less than or equal to the context length for any of the supported
warmup_shapes with the same batch size (or larger supported batch size).
(context length = prompt_length + new_tokens)
"""
    if len(requested_shapes) != 1:
return False
request_batch_size = requested_shapes[0][2]
supported_shapes_with_matching_batch_size = [(ws[0], ws[1], ws[2])
for ws in supported_shapes
if request_batch_size <= ws[2]
]
return (
len(supported_shapes_with_matching_batch_size) > 0 and
(get_max_model_length(requested_shapes)
<= get_max_model_length(supported_shapes_with_matching_batch_size)))


def get_max_model_length(warmup_shapes: WarmupShapes) -> int:
"""
Return the maximum model length from the given warmup shapes.
"""
# max_model_len = prompt_length + new_tokens
# warmup_shape = [prompt_length, new_tokens, batch_size]
return max([ws[0] + ws[1] for ws in warmup_shapes or []])
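
To make these shape-matching rules concrete, the following small, self-contained sketch exercises both the exact-subset match and the single-shape context-length fallback; the shape values are illustrative only:

```python
from vllm_spyre.config.runtime_config_validator import (
    RuntimeConfiguration, is_warmup_shapes_supported)

supported = RuntimeConfiguration(
    cb=False, tp_size=1, warmup_shapes=[[64, 20, 4], [128, 20, 2]])

# An exact subset of the supported shapes is accepted.
requested = RuntimeConfiguration(
    cb=False, tp_size=1, warmup_shapes=[[64, 20, 4]])
assert is_warmup_shapes_supported(requested, supported)

# A single shape with a smaller total context (128 + 19 = 147 <= 148) at a
# batch size covered by a supported shape is also accepted.
requested = RuntimeConfiguration(
    cb=False, tp_size=1, warmup_shapes=[[128, 19, 2]])
assert is_warmup_shapes_supported(requested, supported)

# A context larger than any supported shape at batch size >= 4 is rejected
# (the only supported shape with batch size >= 4 allows 64 + 20 = 84 tokens).
requested = RuntimeConfiguration(
    cb=False, tp_size=1, warmup_shapes=[[256, 20, 4]])
assert not is_warmup_shapes_supported(requested, supported)
```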
57 changes: 57 additions & 0 deletions vllm_spyre/config/supported_configurations.yaml
@@ -0,0 +1,57 @@
# --8<-- [start:supported-model-runtime-configurations]

# Parameters:
# - cb: True, for continuous batching; False, for static batching mode
# - tp_size: tensor parallel size
# - max_model_len: context length (prompt_length + max_new_tokens)
# - max_num_seqs: number of sequences in a batch (per instance)
# - warmup_shapes: [(fixed_prompt_length, max_new_tokens, batch_size)]

- model: "ibm-granite/granite-3.3-8b-instruct"
configs: [
{ cb: False, tp_size: 1, warmup_shapes: [[2048, 1024, 16]] },
{ cb: False, tp_size: 4, warmup_shapes: [[6144, 2048, 1]] },
{ cb: False, tp_size: 4, warmup_shapes: [[7168, 1024, 1]] },
{ cb: False, tp_size: 4, warmup_shapes: [[7168, 1024, 4]] },
{ cb: True, tp_size: 1, max_model_len: 3072, max_num_seqs: 16 },
{ cb: True, tp_size: 1, max_model_len: 8192, max_num_seqs: 4 },
{ cb: True, tp_size: 2, max_model_len: 8192, max_num_seqs: 4 },
{ cb: True, tp_size: 4, max_model_len: 8192, max_num_seqs: 4 },
{ cb: True, tp_size: 4, max_model_len: 16384, max_num_seqs: 4 },
{ cb: True, tp_size: 4, max_model_len: 32768, max_num_seqs: 32 },
]
- model: "ibm-granite/granite-3.3-8b-instruct-FP8"
configs: [
{ cb: True, tp_size: 1, max_model_len: 3072, max_num_seqs: 16 },
{ cb: True, tp_size: 4, max_model_len: 8192, max_num_seqs: 4 },
{ cb: True, tp_size: 4, max_model_len: 16384, max_num_seqs: 4 },
{ cb: True, tp_size: 4, max_model_len: 32768, max_num_seqs: 32 },
]
- model: "ibm-granite/granite-embedding-125m-english"
configs: [
{ cb: False, tp_size: 1, warmup_shapes: [[512, 0, 1]] },
{ cb: False, tp_size: 1, warmup_shapes: [[512, 0, 64]] },
]
- model: "ibm-granite/granite-embedding-278m-multilingual"
configs: [
{ cb: False, tp_size: 1, warmup_shapes: [[512, 0, 1]] },
{ cb: False, tp_size: 1, warmup_shapes: [[512, 0, 64]] },
]
- model: "BAAI/bge-reranker-v2-m3"
configs: [
{ cb: False, tp_size: 1, warmup_shapes: [[8192, 0, 1]] },
]
- model: "BAAI/bge-reranker-large"
configs: [
{ cb: False, tp_size: 1, warmup_shapes: [[512, 0, 1]] },
{ cb: False, tp_size: 1, warmup_shapes: [[512, 0, 64]] },
]
- model: "sentence-transformers/all-roberta-large-v1"
configs: [
{ cb: False, tp_size: 1, warmup_shapes: [[64, 0, 4], [64, 0, 8], [128, 0, 4], [128, 0, 8]] },
]
# --8<-- [end:supported-model-runtime-configurations]
- model: "ibm-ai-platform/micro-g3.3-8b-instruct-1b"
ignore: True
- model: "ibm-ai-platform/micro-g3.3-8b-instruct-1b-FP8"
ignore: True
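
For contributors adding support for a new model, a hypothetical entry might look like the following; the model name and all values here are invented for illustration and should reflect configurations actually validated on Spyre hardware:

```yaml
- model: "example-org/example-model"
  configs: [
    # static batching: prompt length 1024 (a multiple of 64),
    # up to 256 new tokens, batch size 8
    { cb: False, tp_size: 1, warmup_shapes: [[1024, 256, 8]] },
    # continuous batching: up to 4096 context tokens, up to 8 sequences
    { cb: True, tp_size: 2, max_model_len: 4096, max_num_seqs: 8 },
  ]
```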