Merged

31 commits
27d5a25
final fms API
yannicks1 May 16, 2025
de1ccf5
Merge branch 'main' into ysc-final-fms-api
yannicks1 May 27, 2025
242284e
Merge branch 'main' into ysc-final-fms-api
yannicks1 Jun 3, 2025
1f3940d
Merge branch 'main' into ysc-final-fms-api
yannicks1 Jun 4, 2025
18b3ed1
clearer separation of attention kwargs and explicitly naming attn_name
yannicks1 Jun 4, 2025
c69b208
Merge branch 'main' into ysc-final-fms-api
yannicks1 Jun 5, 2025
825df86
update cb test fms branch
yannicks1 Jun 5, 2025
1528140
merge decode batch size 2 (#215)
yannicks1 Jun 5, 2025
1d70fed
name change in fms
yannicks1 Jun 5, 2025
e7be2e9
fix import after name change
yannicks1 Jun 5, 2025
f63c361
:arrow_up: Update locked packages (#213)
joerunde Jun 5, 2025
5906e89
[CB] refactor left padding removal (#211)
yannicks1 Jun 5, 2025
d6b7735
fixed issue with warmup_context not capturing full generate (#219)
JRosenkranz Jun 6, 2025
5dce376
fix formatting
yannicks1 Jun 6, 2025
cd06756
apply to v0: fixed issue with warmup_context not capturing full generate
yannicks1 Jun 6, 2025
4fb1f42
:rewind: revert uv.lock to main branch changes
joerunde Jun 13, 2025
59fb0ff
:arrow_up: bump fms lower bound
joerunde Jun 13, 2025
c026cd2
:arrow_up: bump for vulnerability with httpcore
joerunde Jun 13, 2025
ddfa119
:arrow_up: bump for setuptools, tornado vulnerabilities
joerunde Jun 13, 2025
4da37ff
Merge branch 'main' into ysc-final-fms-api
joerunde Jun 13, 2025
e618501
:alembic: enable cb tests on main
joerunde Jun 13, 2025
14ede42
:art: fmt
joerunde Jun 13, 2025
90eb379
:zap: rollup utils tests to reduce number of jobs
joerunde Jun 13, 2025
8126efe
:fire: whoops, remove utils
joerunde Jun 13, 2025
50bdf30
Merge branch 'main' into ysc-final-fms-api
joerunde Jun 17, 2025
91be547
:bug: unmark num_blocks as dynamic for prefill
joerunde Jun 17, 2025
6203b9e
:alembic: try reverting graph changes
joerunde Jun 17, 2025
859ad9b
:bug: oops
joerunde Jun 17, 2025
9e7f017
:art: fmt
joerunde Jun 17, 2025
6ac36c5
remove obsolete comments
yannicks1 Jun 18, 2025
2a02d8c
set free blocks for warmup consistent with KV cache dimension
yannicks1 Jun 18, 2025
19 changes: 6 additions & 13 deletions .github/workflows/test.yml
@@ -43,17 +43,14 @@ jobs:
markers: "v0 and cpu and e2e"
flags: "--timeout=300"
- name: "V1-e2e"
markers: "v1 and cpu and e2e"
markers: "v1 and cpu and e2e and not cb"
flags: "--timeout=300 --forked"
- name: "V1-worker"
markers: "v1 and not e2e"
flags: "--timeout=300"
- name: "utils"
markers: "utils"
flags: "--timeout=300"
- name: "cb"
markers: "cb"
- name: "V1-cb"
markers: "v1 and cpu and cb"
flags: "--timeout=300 --forked"
- name: "V1-worker and utils"
markers: "v1 and not e2e or utils"
flags: "--timeout=300"

name: "${{ matrix.test_suite.name }} (${{ matrix.vllm_version.name }})"

@@ -163,10 +160,6 @@ jobs:
# `uv run`, to avoid having `uv run` re-sync any dependencies or
# re-install the vllm_sypre package from source
source .venv/bin/activate
if [ ${{ matrix.test_suite.markers }} == "cb" ]; then
# install custom fms branch
uv pip install git+https://github.com/foundation-model-stack/foundation-model-stack@paged_attn_mock --force-reinstall
fi
# commands to run if condition is true
python3 -m pytest ${{ matrix.test_suite.flags }} \
tests -v -m "${{ matrix.test_suite.markers }}"
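
One note on the consolidated "V1-worker and utils" suite: in pytest "-m" expressions, "and" binds tighter than "or", so the filter "v1 and not e2e or utils" selects "(v1 and not e2e) or utils". Below is a minimal, self-contained sketch of how that expression drives test selection; the marker names are assumed to be registered in the project's pytest configuration.

# marker_selection_sketch.py: illustrates the marker expression used by the
# "V1-worker and utils" CI suite (markers assumed to be registered).
import pytest

@pytest.mark.v1
def test_worker_unit():
    # selected: matches "v1 and not e2e"
    assert True

@pytest.mark.v1
@pytest.mark.e2e
def test_worker_e2e():
    # excluded from this suite: carries the e2e marker
    assert True

@pytest.mark.utils
def test_util_helper():
    # selected via the "or utils" branch
    assert True

if __name__ == "__main__":
    # roughly equivalent to the workflow invocation:
    #   pytest --timeout=300 tests -v -m "v1 and not e2e or utils"
    pytest.main(["-v", "-m", "v1 and not e2e or utils", __file__])
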
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -12,7 +12,7 @@ readme = "README.md"
license = {text = "Apache 2"}
dependencies = [
"fms-model-optimizer>=0.2.0",
"ibm-fms==1.0.0",
"ibm-fms==1.1.0",
"vllm>=0.9.0,!=0.9.1",
]
requires-python = ">=3.9"
20 changes: 7 additions & 13 deletions tests/e2e/test_spyre_cb.py
@@ -9,7 +9,7 @@

import pytest
from spyre_util import (create_random_request, generate_cb_spyre_vllm_output,
get_spyre_model_list)
get_spyre_backend_list, get_spyre_model_list)
from vllm import EngineArgs, SamplingParams
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core import EngineCore
@@ -18,16 +18,12 @@
from vllm_spyre.v1.core.scheduler import ContinuousBatchingSpyreScheduler


@pytest.mark.cb
@pytest.mark.v1
@pytest.mark.parametrize("max_num_seqs", [2, 3, 4],
ids=lambda val: f"max_num_seqs({val})")
@pytest.mark.parametrize("model", get_spyre_model_list())
@pytest.mark.parametrize(
"backend", [pytest.param("eager", marks=pytest.mark.cpu, id="eager")])
@pytest.mark.parametrize("cb",
[pytest.param(1, marks=pytest.mark.cb, id="cb")])
# commenting v1 since we don't want this test to run with v1 marker yet
# @pytest.mark.parametrize("vllm_version",
# [pytest.param("V1", marks=pytest.mark.v1, id="v1")])
@pytest.mark.parametrize("backend", get_spyre_backend_list())
@pytest.mark.parametrize(
"prompts",
[
@@ -53,9 +49,7 @@ def test_cb_handling(
model: str,
backend: str,
max_num_seqs: int,
cb: int,
prompts: list[str],
# vllm_version: str,
monkeypatch: pytest.MonkeyPatch,
):
"""Test that the spyre worker correctly handles
@@ -80,7 +74,7 @@
tensor_parallel_size=1,
backend=backend,
max_num_seqs=max_num_seqs,
use_cb=cb,
use_cb=1,
monkeypatch=monkeypatch,
)

@@ -654,9 +648,9 @@ def augment_checked_steps(


@pytest.mark.cb
@pytest.mark.v1
@pytest.mark.parametrize("model", get_spyre_model_list())
@pytest.mark.parametrize(
"backend", [pytest.param("eager", marks=pytest.mark.cpu, id="eager")])
@pytest.mark.parametrize("backend", get_spyre_backend_list())
@pytest.mark.parametrize("max_num_seqs", [2])
@pytest.mark.parametrize(
"seqs_max_tokens,prompts_lengths,steps_add_reqs,checked_steps,"
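
Switching the backend axis to get_spyre_backend_list() lets each backend entry carry its own markers instead of hard-coding the eager/cpu case. A hypothetical sketch of the shape such a helper could have follows; the real implementation lives in tests/spyre_util.py and may differ, and the backend names and the "spyre" marker used here are assumptions.

import pytest

def get_spyre_backend_list():
    # Hypothetical contents for illustration only.
    return [
        # CPU-capable eager backend: picked up by suites filtered with "... and cpu"
        pytest.param("eager", marks=pytest.mark.cpu, id="eager"),
        # Compiled backend for Spyre hardware: only selected on hardware pipelines
        pytest.param("sendnn", marks=pytest.mark.spyre, id="sendnn"),
    ]

With entries shaped like this, the CI filter "v1 and cpu and cb" still runs only the eager case, while a hardware pipeline can select the compiled backend without touching the test bodies.
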
49 changes: 25 additions & 24 deletions uv.lock

Some generated files are not rendered by default.

32 changes: 23 additions & 9 deletions vllm_spyre/model_executor/model_loader/spyre.py
@@ -314,13 +316,16 @@ def __init__(

# set num_blocks to the minimal value of 4 required for warmup
# is reset to the value returned by the Spyre compiler after warmup
self._set_past_key_value_states(num_blocks=4)
# self._set_past_key_value_states(num_blocks=4)
num_blocks = scheduler_config.max_num_seqs * max_model_len // BLOCK_SIZE
self._set_past_key_value_states(num_blocks=num_blocks)

# mark the num_blocks dimension dynamic for Spyre compiler for warmup
# only, compiler will return the number of blocks it can accommodate
for layer in self.past_key_value_states:
for tensor in layer:
torch._dynamo.mark_dynamic(tensor, 0)
# only, compiler will return the number of blocks it can accommodate.
# (This is not yet supported by the compiler)
# for layer in self.past_key_value_states:
# for tensor in layer:
# torch._dynamo.mark_dynamic(tensor, 0)

def _set_past_key_value_states(self, num_blocks) -> None:
# List[layers] of Tuple[k,v] of
@@ -353,17 +356,25 @@ def forward(
**extra_kwargs,
) -> torch.Tensor:

# import will be not be needed/ handled by FMS soon
import fms.utils.spyre.paged # noqa # pylint: disable=unused-import

# specify attention type for continuous batching
extra_kwargs['attn_name'] = "spyre_paged_attn"

# additional (paged) attention arguments
extra_kwargs['current_tkv_mask'] = current_tkv_mask
extra_kwargs['left_padded_prompt_mask'] = left_padded_prompt_mask
extra_kwargs['block_table'] = block_table
extra_kwargs['slot_mapping'] = slot_mapping

output = self.model(
input_ids,
position_ids=position_ids,
mask=mask,
past_key_value_states=self.past_key_value_states,
use_cache=use_cache,
only_last_token=only_last_token,
current_tkv_mask=current_tkv_mask,
left_padded_prompt_mask=left_padded_prompt_mask,
block_table=block_table,
slot_mapping=slot_mapping,
**extra_kwargs,
)
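
Passing attn_name through extra_kwargs means the attention variant is chosen by name on the FMS side rather than implied by which tensors are supplied. The following is a purely illustrative sketch of a name-based dispatch of this kind; it is not the actual FMS API (the real registration happens inside fms.utils.spyre.paged), and all names in it are assumptions.

from typing import Callable, Dict

import torch

# Illustrative registry keyed by attention name, e.g. "spyre_paged_attn"
# or "sdpa_bidirectional". Not the real FMS implementation.
ATTENTION_IMPLS: Dict[str, Callable[..., torch.Tensor]] = {}

def register_attention(name: str):
    def wrap(fn: Callable[..., torch.Tensor]):
        ATTENTION_IMPLS[name] = fn
        return fn
    return wrap

@register_attention("sdpa_bidirectional")
def sdpa_attention(q, k, v, **kwargs) -> torch.Tensor:
    # dense bidirectional attention, as used for static batching
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)

@register_attention("spyre_paged_attn")
def paged_attention(q, k, v, *, block_table=None, slot_mapping=None, **kwargs) -> torch.Tensor:
    # a real implementation would gather KV-cache blocks via block_table and
    # slot_mapping; this stub just falls back to dense attention
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)

def run_attention(q, k, v, *, attn_name: str, **extra_kwargs) -> torch.Tensor:
    return ATTENTION_IMPLS[attn_name](q, k, v, **extra_kwargs)
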

@@ -401,6 +412,9 @@ def forward(
**extra_kwargs,
) -> torch.Tensor:

# specify attention type for static batching
extra_kwargs['attn_name'] = "sdpa_bidirectional"

output = self.model(
input_ids,
position_ids=position_ids,
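Because the compiler cannot yet treat the num_blocks dimension as dynamic, the warmup KV cache is sized up front from the scheduler and model configuration, and the model runner diff below applies the same formula to its free-block accounting. A worked example with assumed values (illustrative numbers, not the defaults of any particular model):

# num_blocks sizing sketch; the values below are assumptions for illustration.
max_num_seqs = 4       # scheduler_config.max_num_seqs
max_model_len = 2048   # model_config.max_model_len
BLOCK_SIZE = 64        # tokens held by one KV-cache block (assumed)

# Worst case: every schedulable sequence grows to the full model length,
# so each needs max_model_len // BLOCK_SIZE blocks.
num_blocks = max_num_seqs * max_model_len // BLOCK_SIZE
assert num_blocks == 128  # 4 * 2048 // 64
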
8 changes: 7 additions & 1 deletion vllm_spyre/v1/worker/spyre_model_runner.py
@@ -578,7 +578,13 @@ def __init__(
self.tkv: int = 0
# set self.free_blocks to the minimal value of 4 required for warmup
# is reset to the value returned by the Spyre compiler after warmup
self._set_free_blocks(num_blocks=4)
# self._set_free_blocks(num_blocks=4)
# for the time being we set this to num_blocks consistent with the
# cache dimension of ContinuousBatchingFmsModel.past_key_value_states
num_blocks = (vllm_config.scheduler_config.max_num_seqs *
vllm_config.model_config.max_model_len //
self.BLOCK_SIZE)
self._set_free_blocks(num_blocks=num_blocks)
self.dummy_req_ids2blocks: list[int] = []

# TODO: Remove this once we can prefill and decode
Expand Down
Loading