diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 6094e4272..9bda3f6f1 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -43,17 +43,14 @@ jobs: markers: "v0 and cpu and e2e" flags: "--timeout=300" - name: "V1-e2e" - markers: "v1 and cpu and e2e" + markers: "v1 and cpu and e2e and not cb" flags: "--timeout=300 --forked" - - name: "V1-worker" - markers: "v1 and not e2e" - flags: "--timeout=300" - - name: "utils" - markers: "utils" - flags: "--timeout=300" - - name: "cb" - markers: "cb" + - name: "V1-cb" + markers: "v1 and cpu and cb" flags: "--timeout=300 --forked" + - name: "V1-worker and utils" + markers: "v1 and not e2e or utils" + flags: "--timeout=300" name: "${{ matrix.test_suite.name }} (${{ matrix.vllm_version.name }})" @@ -163,10 +160,6 @@ jobs: # `uv run`, to avoid having `uv run` re-sync any dependencies or # re-install the vllm_sypre package from source source .venv/bin/activate - if [ ${{ matrix.test_suite.markers }} == "cb" ]; then - # install custom fms branch - uv pip install git+https://github.com/foundation-model-stack/foundation-model-stack@paged_attn_mock --force-reinstall - fi # commands to run if condition is true python3 -m pytest ${{ matrix.test_suite.flags }} \ tests -v -m "${{ matrix.test_suite.markers }}" diff --git a/pyproject.toml b/pyproject.toml index 78711203d..e5be4e999 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ readme = "README.md" license = {text = "Apache 2"} dependencies = [ "fms-model-optimizer>=0.2.0", - "ibm-fms==1.0.0", + "ibm-fms==1.1.0", "vllm>=0.9.0,!=0.9.1", ] requires-python = ">=3.9" diff --git a/tests/e2e/test_spyre_cb.py b/tests/e2e/test_spyre_cb.py index c9df1c0bc..c57db3531 100644 --- a/tests/e2e/test_spyre_cb.py +++ b/tests/e2e/test_spyre_cb.py @@ -9,7 +9,7 @@ import pytest from spyre_util import (create_random_request, generate_cb_spyre_vllm_output, - get_spyre_model_list) + get_spyre_backend_list, get_spyre_model_list) from vllm import EngineArgs, SamplingParams from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.core import EngineCore @@ -18,16 +18,12 @@ from vllm_spyre.v1.core.scheduler import ContinuousBatchingSpyreScheduler +@pytest.mark.cb +@pytest.mark.v1 @pytest.mark.parametrize("max_num_seqs", [2, 3, 4], ids=lambda val: f"max_num_seqs({val})") @pytest.mark.parametrize("model", get_spyre_model_list()) -@pytest.mark.parametrize( - "backend", [pytest.param("eager", marks=pytest.mark.cpu, id="eager")]) -@pytest.mark.parametrize("cb", - [pytest.param(1, marks=pytest.mark.cb, id="cb")]) -# commenting v1 since we don't want this test to run with v1 marker yet -# @pytest.mark.parametrize("vllm_version", -# [pytest.param("V1", marks=pytest.mark.v1, id="v1")]) +@pytest.mark.parametrize("backend", get_spyre_backend_list()) @pytest.mark.parametrize( "prompts", [ @@ -53,9 +49,7 @@ def test_cb_handling( model: str, backend: str, max_num_seqs: int, - cb: int, prompts: list[str], - # vllm_version: str, monkeypatch: pytest.MonkeyPatch, ): """Test that the spyre worker correctly handles @@ -80,7 +74,7 @@ def test_cb_handling( tensor_parallel_size=1, backend=backend, max_num_seqs=max_num_seqs, - use_cb=cb, + use_cb=1, monkeypatch=monkeypatch, ) @@ -654,9 +648,9 @@ def augment_checked_steps( @pytest.mark.cb +@pytest.mark.v1 @pytest.mark.parametrize("model", get_spyre_model_list()) -@pytest.mark.parametrize( - "backend", [pytest.param("eager", marks=pytest.mark.cpu, id="eager")]) +@pytest.mark.parametrize("backend", get_spyre_backend_list()) @pytest.mark.parametrize("max_num_seqs", [2]) @pytest.mark.parametrize( "seqs_max_tokens,prompts_lengths,steps_add_reqs,checked_steps," diff --git a/uv.lock b/uv.lock index 7d617a3cc..af21ef57b 100644 --- a/uv.lock +++ b/uv.lock @@ -1247,11 +1247,11 @@ wheels = [ [[package]] name = "h11" -version = "0.14.0" +version = "0.16.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f5/38/3af3d3633a34a3316095b39c8e8fb4853a28a536e55d347bd8d8e9a14b03/h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d", size = 100418, upload-time = "2022-09-25T15:40:01.519Z" } +sdist = { url = "https://files.pythonhosted.org/packages/01/ee/02a2c011bdab74c6fb3c75474d40b3052059d95df7e73351460c8588d963/h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1", size = 101250, upload-time = "2025-04-24T03:35:25.427Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/95/04/ff642e65ad6b90db43e668d70ffb6736436c7ce41fcc549f4e9472234127/h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761", size = 58259, upload-time = "2022-09-25T15:39:59.68Z" }, + { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" }, ] [[package]] @@ -1271,15 +1271,15 @@ wheels = [ [[package]] name = "httpcore" -version = "1.0.7" +version = "1.0.9" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "certifi", marker = "python_full_version >= '3.10'" }, { name = "h11", marker = "python_full_version >= '3.10'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6a/41/d7d0a89eb493922c37d343b607bc1b5da7f5be7e383740b4753ad8943e90/httpcore-1.0.7.tar.gz", hash = "sha256:8551cb62a169ec7162ac7be8d4817d561f60e08eaa485234898414bb5a8a0b4c", size = 85196, upload-time = "2024-11-15T12:30:47.531Z" } +sdist = { url = "https://files.pythonhosted.org/packages/06/94/82699a10bca87a5556c9c59b5963f2d039dbd239f25bc2a63907a05a14cb/httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8", size = 85484, upload-time = "2025-04-24T22:06:22.219Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/87/f5/72347bc88306acb359581ac4d52f23c0ef445b57157adedb9aee0cd689d2/httpcore-1.0.7-py3-none-any.whl", hash = "sha256:a3fff8f43dc260d5bd363d9f9cf1830fa3a458b332856f34282de498ed420edd", size = 78551, upload-time = "2024-11-15T12:30:45.782Z" }, + { url = "https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55", size = 78784, upload-time = "2025-04-24T22:06:20.566Z" }, ] [[package]] @@ -1366,13 +1366,13 @@ hf-xet = [ [[package]] name = "ibm-fms" -version = "1.0.0" +version = "1.1.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "torch", marker = "python_full_version >= '3.10' and sys_platform == 'never'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/f9/c5/960cbde1eb640115e7ffab3878ed0da3ee46364810fd790ba7a2d607bc2c/ibm_fms-1.0.0-py3-none-any.whl", hash = "sha256:b943744cc15777f8a4971f0feefd6bb089beec549e29cec8bbd9999c17b62ce5", size = 160644, upload-time = "2025-05-16T18:57:35.2Z" }, + { url = "https://files.pythonhosted.org/packages/e6/53/29c8588a74e7909756201e85e28c761641e80ec02205fbbfab86fe09b0e4/ibm_fms-1.1.0-py3-none-any.whl", hash = "sha256:f0fed9a07f1f166e8e676b4060c0f40e43371a72894afbf3a8ad98e61e1bf07e", size = 166993, upload-time = "2025-06-13T19:57:33.557Z" }, ] [[package]] @@ -4199,11 +4199,11 @@ wheels = [ [[package]] name = "setuptools" -version = "78.1.0" +version = "79.0.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a9/5a/0db4da3bc908df06e5efae42b44e75c81dd52716e10192ff36d0c1c8e379/setuptools-78.1.0.tar.gz", hash = "sha256:18fd474d4a82a5f83dac888df697af65afa82dec7323d09c3e37d1f14288da54", size = 1367827, upload-time = "2025-03-25T22:49:35.332Z" } +sdist = { url = "https://files.pythonhosted.org/packages/bb/71/b6365e6325b3290e14957b2c3a804a529968c77a049b2ed40c095f749707/setuptools-79.0.1.tar.gz", hash = "sha256:128ce7b8f33c3079fd1b067ecbb4051a66e8526e7b65f6cec075dfc650ddfa88", size = 1367909, upload-time = "2025-04-23T22:20:59.241Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/54/21/f43f0a1fa8b06b32812e0975981f4677d28e0f3271601dc88ac5a5b83220/setuptools-78.1.0-py3-none-any.whl", hash = "sha256:3e386e96793c8702ae83d17b853fb93d3e09ef82ec62722e61da5cd22376dcd8", size = 1256108, upload-time = "2025-03-25T22:49:33.13Z" }, + { url = "https://files.pythonhosted.org/packages/0d/6d/b4752b044bf94cb802d88a888dc7d288baaf77d7910b7dedda74b5ceea0c/setuptools-79.0.1-py3-none-any.whl", hash = "sha256:e147c0549f27767ba362f9da434eab9c5dc0045d5304feb602a0af001089fc51", size = 1256281, upload-time = "2025-04-23T22:20:56.768Z" }, ] [[package]] @@ -4474,20 +4474,21 @@ dependencies = [ [[package]] name = "tornado" -version = "6.4.2" +version = "6.5.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/59/45/a0daf161f7d6f36c3ea5fc0c2de619746cc3dd4c76402e9db545bd920f63/tornado-6.4.2.tar.gz", hash = "sha256:92bad5b4746e9879fd7bf1eb21dce4e3fc5128d71601f80005afa39237ad620b", size = 501135, upload-time = "2024-11-22T03:06:38.036Z" } +sdist = { url = "https://files.pythonhosted.org/packages/51/89/c72771c81d25d53fe33e3dca61c233b665b2780f21820ba6fd2c6793c12b/tornado-6.5.1.tar.gz", hash = "sha256:84ceece391e8eb9b2b95578db65e920d2a61070260594819589609ba9bc6308c", size = 509934, upload-time = "2025-05-22T18:15:38.788Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/26/7e/71f604d8cea1b58f82ba3590290b66da1e72d840aeb37e0d5f7291bd30db/tornado-6.4.2-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:e828cce1123e9e44ae2a50a9de3055497ab1d0aeb440c5ac23064d9e44880da1", size = 436299, upload-time = "2024-11-22T03:06:20.162Z" }, - { url = "https://files.pythonhosted.org/packages/96/44/87543a3b99016d0bf54fdaab30d24bf0af2e848f1d13d34a3a5380aabe16/tornado-6.4.2-cp38-abi3-macosx_10_9_x86_64.whl", hash = "sha256:072ce12ada169c5b00b7d92a99ba089447ccc993ea2143c9ede887e0937aa803", size = 434253, upload-time = "2024-11-22T03:06:22.39Z" }, - { url = "https://files.pythonhosted.org/packages/cb/fb/fdf679b4ce51bcb7210801ef4f11fdac96e9885daa402861751353beea6e/tornado-6.4.2-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1a017d239bd1bb0919f72af256a970624241f070496635784d9bf0db640d3fec", size = 437602, upload-time = "2024-11-22T03:06:24.214Z" }, - { url = "https://files.pythonhosted.org/packages/4f/3b/e31aeffffc22b475a64dbeb273026a21b5b566f74dee48742817626c47dc/tornado-6.4.2-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c36e62ce8f63409301537222faffcef7dfc5284f27eec227389f2ad11b09d946", size = 436972, upload-time = "2024-11-22T03:06:25.559Z" }, - { url = "https://files.pythonhosted.org/packages/22/55/b78a464de78051a30599ceb6983b01d8f732e6f69bf37b4ed07f642ac0fc/tornado-6.4.2-cp38-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bca9eb02196e789c9cb5c3c7c0f04fb447dc2adffd95265b2c7223a8a615ccbf", size = 437173, upload-time = "2024-11-22T03:06:27.584Z" }, - { url = "https://files.pythonhosted.org/packages/79/5e/be4fb0d1684eb822c9a62fb18a3e44a06188f78aa466b2ad991d2ee31104/tornado-6.4.2-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:304463bd0772442ff4d0f5149c6f1c2135a1fae045adf070821c6cdc76980634", size = 437892, upload-time = "2024-11-22T03:06:28.933Z" }, - { url = "https://files.pythonhosted.org/packages/f5/33/4f91fdd94ea36e1d796147003b490fe60a0215ac5737b6f9c65e160d4fe0/tornado-6.4.2-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:c82c46813ba483a385ab2a99caeaedf92585a1f90defb5693351fa7e4ea0bf73", size = 437334, upload-time = "2024-11-22T03:06:30.428Z" }, - { url = "https://files.pythonhosted.org/packages/2b/ae/c1b22d4524b0e10da2f29a176fb2890386f7bd1f63aacf186444873a88a0/tornado-6.4.2-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:932d195ca9015956fa502c6b56af9eb06106140d844a335590c1ec7f5277d10c", size = 437261, upload-time = "2024-11-22T03:06:32.458Z" }, - { url = "https://files.pythonhosted.org/packages/b5/25/36dbd49ab6d179bcfc4c6c093a51795a4f3bed380543a8242ac3517a1751/tornado-6.4.2-cp38-abi3-win32.whl", hash = "sha256:2876cef82e6c5978fde1e0d5b1f919d756968d5b4282418f3146b79b58556482", size = 438463, upload-time = "2024-11-22T03:06:34.71Z" }, - { url = "https://files.pythonhosted.org/packages/61/cc/58b1adeb1bb46228442081e746fcdbc4540905c87e8add7c277540934edb/tornado-6.4.2-cp38-abi3-win_amd64.whl", hash = "sha256:908b71bf3ff37d81073356a5fadcc660eb10c1476ee6e2725588626ce7e5ca38", size = 438907, upload-time = "2024-11-22T03:06:36.71Z" }, + { url = "https://files.pythonhosted.org/packages/77/89/f4532dee6843c9e0ebc4e28d4be04c67f54f60813e4bf73d595fe7567452/tornado-6.5.1-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:d50065ba7fd11d3bd41bcad0825227cc9a95154bad83239357094c36708001f7", size = 441948, upload-time = "2025-05-22T18:15:20.862Z" }, + { url = "https://files.pythonhosted.org/packages/15/9a/557406b62cffa395d18772e0cdcf03bed2fff03b374677348eef9f6a3792/tornado-6.5.1-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:9e9ca370f717997cb85606d074b0e5b247282cf5e2e1611568b8821afe0342d6", size = 440112, upload-time = "2025-05-22T18:15:22.591Z" }, + { url = "https://files.pythonhosted.org/packages/55/82/7721b7319013a3cf881f4dffa4f60ceff07b31b394e459984e7a36dc99ec/tornado-6.5.1-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b77e9dfa7ed69754a54c89d82ef746398be82f749df69c4d3abe75c4d1ff4888", size = 443672, upload-time = "2025-05-22T18:15:24.027Z" }, + { url = "https://files.pythonhosted.org/packages/7d/42/d11c4376e7d101171b94e03cef0cbce43e823ed6567ceda571f54cf6e3ce/tornado-6.5.1-cp39-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:253b76040ee3bab8bcf7ba9feb136436a3787208717a1fb9f2c16b744fba7331", size = 443019, upload-time = "2025-05-22T18:15:25.735Z" }, + { url = "https://files.pythonhosted.org/packages/7d/f7/0c48ba992d875521ac761e6e04b0a1750f8150ae42ea26df1852d6a98942/tornado-6.5.1-cp39-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:308473f4cc5a76227157cdf904de33ac268af770b2c5f05ca6c1161d82fdd95e", size = 443252, upload-time = "2025-05-22T18:15:27.499Z" }, + { url = "https://files.pythonhosted.org/packages/89/46/d8d7413d11987e316df4ad42e16023cd62666a3c0dfa1518ffa30b8df06c/tornado-6.5.1-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:caec6314ce8a81cf69bd89909f4b633b9f523834dc1a352021775d45e51d9401", size = 443930, upload-time = "2025-05-22T18:15:29.299Z" }, + { url = "https://files.pythonhosted.org/packages/78/b2/f8049221c96a06df89bed68260e8ca94beca5ea532ffc63b1175ad31f9cc/tornado-6.5.1-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:13ce6e3396c24e2808774741331638ee6c2f50b114b97a55c5b442df65fd9692", size = 443351, upload-time = "2025-05-22T18:15:31.038Z" }, + { url = "https://files.pythonhosted.org/packages/76/ff/6a0079e65b326cc222a54720a748e04a4db246870c4da54ece4577bfa702/tornado-6.5.1-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:5cae6145f4cdf5ab24744526cc0f55a17d76f02c98f4cff9daa08ae9a217448a", size = 443328, upload-time = "2025-05-22T18:15:32.426Z" }, + { url = "https://files.pythonhosted.org/packages/49/18/e3f902a1d21f14035b5bc6246a8c0f51e0eef562ace3a2cea403c1fb7021/tornado-6.5.1-cp39-abi3-win32.whl", hash = "sha256:e0a36e1bc684dca10b1aa75a31df8bdfed656831489bc1e6a6ebed05dc1ec365", size = 444396, upload-time = "2025-05-22T18:15:34.205Z" }, + { url = "https://files.pythonhosted.org/packages/7b/09/6526e32bf1049ee7de3bebba81572673b19a2a8541f795d887e92af1a8bc/tornado-6.5.1-cp39-abi3-win_amd64.whl", hash = "sha256:908e7d64567cecd4c2b458075589a775063453aeb1d2a1853eedb806922f568b", size = 444840, upload-time = "2025-05-22T18:15:36.1Z" }, + { url = "https://files.pythonhosted.org/packages/55/a7/535c44c7bea4578e48281d83c615219f3ab19e6abc67625ef637c73987be/tornado-6.5.1-cp39-abi3-win_arm64.whl", hash = "sha256:02420a0eb7bf617257b9935e2b754d1b63897525d8a289c9d65690d580b4dcf7", size = 443596, upload-time = "2025-05-22T18:15:37.433Z" }, ] [[package]] @@ -4815,7 +4816,7 @@ lint = [ [package.metadata] requires-dist = [ { name = "fms-model-optimizer", specifier = ">=0.2.0" }, - { name = "ibm-fms", specifier = "==1.0.0" }, + { name = "ibm-fms", specifier = "==1.1.0" }, { name = "vllm", specifier = ">=0.9.0,!=0.9.1" }, ] diff --git a/vllm_spyre/model_executor/model_loader/spyre.py b/vllm_spyre/model_executor/model_loader/spyre.py index a6465f520..d827a1b00 100644 --- a/vllm_spyre/model_executor/model_loader/spyre.py +++ b/vllm_spyre/model_executor/model_loader/spyre.py @@ -314,13 +314,16 @@ def __init__( # set num_blocks to the minimal value of 4 required for warmup # is reset to the value returned by the Spyre compiler after warmup - self._set_past_key_value_states(num_blocks=4) + # self._set_past_key_value_states(num_blocks=4) + num_blocks = scheduler_config.max_num_seqs * max_model_len // BLOCK_SIZE + self._set_past_key_value_states(num_blocks=num_blocks) # mark the num_blocks dimension dynamic for Spyre compiler for warmup - # only, compiler will return the number of blocks it can accommodate - for layer in self.past_key_value_states: - for tensor in layer: - torch._dynamo.mark_dynamic(tensor, 0) + # only, compiler will return the number of blocks it can accommodate. + # (This is not yet supported by the compiler) + # for layer in self.past_key_value_states: + # for tensor in layer: + # torch._dynamo.mark_dynamic(tensor, 0) def _set_past_key_value_states(self, num_blocks) -> None: # List[layers] of Tuple[k,v] of @@ -353,6 +356,18 @@ def forward( **extra_kwargs, ) -> torch.Tensor: + # import will be not be needed/ handled by FMS soon + import fms.utils.spyre.paged # noqa # pylint: disable=unused-import + + # specify attention type for continuous batching + extra_kwargs['attn_name'] = "spyre_paged_attn" + + # additional (paged) attention arguments + extra_kwargs['current_tkv_mask'] = current_tkv_mask + extra_kwargs['left_padded_prompt_mask'] = left_padded_prompt_mask + extra_kwargs['block_table'] = block_table + extra_kwargs['slot_mapping'] = slot_mapping + output = self.model( input_ids, position_ids=position_ids, @@ -360,10 +375,6 @@ def forward( past_key_value_states=self.past_key_value_states, use_cache=use_cache, only_last_token=only_last_token, - current_tkv_mask=current_tkv_mask, - left_padded_prompt_mask=left_padded_prompt_mask, - block_table=block_table, - slot_mapping=slot_mapping, **extra_kwargs, ) @@ -401,6 +412,9 @@ def forward( **extra_kwargs, ) -> torch.Tensor: + # specify attention type for static batching + extra_kwargs['attn_name'] = "sdpa_bidirectional" + output = self.model( input_ids, position_ids=position_ids, diff --git a/vllm_spyre/v1/worker/spyre_model_runner.py b/vllm_spyre/v1/worker/spyre_model_runner.py index 74d28b6bf..4b22a1802 100644 --- a/vllm_spyre/v1/worker/spyre_model_runner.py +++ b/vllm_spyre/v1/worker/spyre_model_runner.py @@ -578,7 +578,13 @@ def __init__( self.tkv: int = 0 # set self.free_blocks to the minimal value of 4 required for warmup # is reset to the value returned by the Spyre compiler after warmup - self._set_free_blocks(num_blocks=4) + # self._set_free_blocks(num_blocks=4) + # for the time being we set this to num_blocks consistent with the + # cache dimension of ContinuousBatchingFmsModel.past_key_value_states + num_blocks = (vllm_config.scheduler_config.max_num_seqs * + vllm_config.model_config.max_model_len // + self.BLOCK_SIZE) + self._set_free_blocks(num_blocks=num_blocks) self.dummy_req_ids2blocks: list[int] = [] # TODO: Remove this once we can prefill and decode diff --git a/vllm_spyre/v1/worker/spyre_worker.py b/vllm_spyre/v1/worker/spyre_worker.py index b8b1a5467..712489ea7 100644 --- a/vllm_spyre/v1/worker/spyre_worker.py +++ b/vllm_spyre/v1/worker/spyre_worker.py @@ -61,9 +61,10 @@ def get_kv_cache_spec(self) -> KVCacheSpec: def compile_or_warm_up_model(self) -> None: """Prepare model for execution through compilation/warmup.""" - # TO DO: implement warmup for continuous batching + if envs_spyre.VLLM_SPYRE_USE_CB: - self._warmup_spyre_dynamic_size(self.restricted_tokens) + with _maybe_warmup_context(): + self._warmup_spyre_dynamic_size(self.restricted_tokens) return num_shape_combinations = len(self.spyre_warmup_shapes) @@ -89,8 +90,10 @@ def compile_or_warm_up_model(self) -> None: logger.info( "Warming up for prompt length %d, decoding %d tokens with " "batch size %d", prompt_len, num_decode_tokens, batch_size) - self._warmup_spyre_fixed_size(prompt_len, num_decode_tokens, - self.restricted_tokens, batch_size) + with _maybe_warmup_context(): + self._warmup_spyre_fixed_size(prompt_len, num_decode_tokens, + self.restricted_tokens, + batch_size) all_warmup_end_t = time.time() all_warmup_total_t = all_warmup_end_t - all_warmup_start_t self.perf_metrics.log("total warmup time", all_warmup_total_t) @@ -382,8 +385,7 @@ def _warmup_spyre_dynamic_size(self, special_token_ids): ) logger.info("Warmup decode 1/1...") - with _maybe_warmup_context(): - self.execute_model(scheduler_output) + self.execute_model(scheduler_output) # Needed to clean up the data of model runner scheduler_output = SchedulerOutput( @@ -452,7 +454,6 @@ def _get_num_blocks_available(self) -> int: str(max_model_len), max_concurrency_spyre) return num_blocks_spyre else: # dynamo backend 'eager' - # TODO: how do we get a meaningful value for CPU here num_blocks_cpu = max_batch_size * min_req_num_blocks assert num_blocks_cpu >= min_req_num_blocks, ( "Number of pages available on CPU (%d) is not enough to " @@ -536,9 +537,8 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens, # First full forward pass logger.info("Warmup forward pass 1/2...") - with _maybe_warmup_context(): - self._warmup_model_forward_pass(scheduler_output, dummy_requests, - cached_requests, num_decode_tokens) + self._warmup_model_forward_pass(scheduler_output, dummy_requests, + cached_requests, num_decode_tokens) self.perf_metrics.log("warmup 1 time", time.time() - warmup_start_t, batch_size=batch_size, diff --git a/vllm_spyre/worker/spyre_worker.py b/vllm_spyre/worker/spyre_worker.py index 53da2e1a6..60b4845ab 100644 --- a/vllm_spyre/worker/spyre_worker.py +++ b/vllm_spyre/worker/spyre_worker.py @@ -218,8 +218,9 @@ def load_model(self): print(f"[SpyreWorker] Warming up for prompt length {prompt_len}, " f"decoding {num_decode_tokens} tokens with batch " f"size {batch_size}") - self._warmup_spyre_fixed_size(prompt_len, num_decode_tokens, - restricted_tokens, batch_size) + with _maybe_warmup_context(): + self._warmup_spyre_fixed_size(prompt_len, num_decode_tokens, + restricted_tokens, batch_size) all_warmup_end_t = time.time() all_warmup_total_t = all_warmup_end_t - all_warmup_start_t self.perf_metrics.log("total warmup time", all_warmup_total_t) @@ -262,13 +263,11 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens, f"{prompt_len} and max output tokens {num_decode_tokens}.") print("[SpyreWorker] warmup 1/2...") - with _maybe_warmup_context(): - # TODO: torch_sendnn.CleanGraph() should be necessary? - # warmup 1st forward pass - self._warmup_model_forward_pass(warmup_tokens_tensor, - valid_token_ids_tensor, prompt_len, - num_decode_tokens, batch_size, - extra_kwargs) + # warmup 1st forward pass + self._warmup_model_forward_pass(warmup_tokens_tensor, + valid_token_ids_tensor, prompt_len, + num_decode_tokens, batch_size, + extra_kwargs) self.perf_metrics.log("warmup 1 time", time.time() - warmup_start_t, batch_size=batch_size,