Skip to content

Commit 15d3587

Browse files
authored
[CB] Override number of Spyre blocks: replace env var with top level argument (#331)
### [CB] Override number of Spyre blocks: replace env var with top level argument Deprecates the env var `VLLM_SPYRE_N_BLOCKS` and uses the top-level argument `num_gpu_blocks_override` instead to override the number of available Spyre KV cache blocks. Signed-off-by: Yannick Schnider <[email protected]>
1 parent e17fc39 commit 15d3587

File tree

5 files changed

+19
-13
lines changed

5 files changed

+19
-13
lines changed

tests/scheduling_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,6 @@ def check_scheduler_inference_steps(
5757
# set env vars
5858
monkeypatch.setenv("VLLM_USE_V1", "1")
5959
monkeypatch.setenv("VLLM_SPYRE_DYNAMO_BACKEND", backend)
60-
if available_blocks > 0:
61-
monkeypatch.setenv("VLLM_SPYRE_N_BLOCKS", str(available_blocks))
6260
if use_cb:
6361
monkeypatch.setenv("VLLM_SPYRE_USE_CB", "1")
6462

@@ -90,7 +88,9 @@ def check_scheduler_inference_steps(
9088
tokenizer=model,
9189
max_model_len=max_model_len,
9290
block_size=max_model_len,
93-
max_num_seqs=max_num_seqs)
91+
max_num_seqs=max_num_seqs,
92+
num_gpu_blocks_override=available_blocks
93+
if available_blocks > 0 else None)
9494
vllm_config = engine_args.create_engine_config()
9595
executor_class = Executor.get_class(vllm_config)
9696
engine_core = EngineCore(vllm_config=vllm_config,

vllm_spyre/envs.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
VLLM_SPYRE_WARMUP_NEW_TOKENS: Optional[list[int]] = None
1010
VLLM_SPYRE_WARMUP_BATCH_SIZES: Optional[list[int]] = None
1111
VLLM_SPYRE_USE_CB: bool = False
12-
VLLM_SPYRE_N_BLOCKS: int = 0
1312
VLLM_SPYRE_PERF_METRIC_LOGGING_ENABLED: int = 0
1413
VLLM_SPYRE_PERF_METRIC_LOGGING_DIR: str = "/tmp"
1514
VLLM_SPYRE_OVERRIDE_SIGNALS_HANDLER: bool = False
@@ -75,10 +74,6 @@ def _backend_backwards_compat() -> str:
7574
"VLLM_SPYRE_USE_CB":
7675
lambda: bool(int(os.getenv("VLLM_SPYRE_USE_CB", "0"))),
7776

78-
# Overriding the number of KV cache blocks available on Spyre (and CPU)
79-
"VLLM_SPYRE_N_BLOCKS":
80-
lambda: int(os.getenv("VLLM_SPYRE_N_BLOCKS", 0)),
81-
8277
# Enable performance metric logging. This captures startup information
8378
# such as warmup times, and loading times. It is turned off by default.
8479
"VLLM_SPYRE_PERF_METRIC_LOGGING_ENABLED":

vllm_spyre/model_executor/model_loader/spyre.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,8 +331,9 @@ def __init__(
331331

332332
def _set_past_key_value_states(self, num_blocks) -> None:
333333
# overwrite num_blocks for testing scheduler constraints
334-
if envs_spyre.VLLM_SPYRE_N_BLOCKS > 0:
335-
num_blocks = envs_spyre.VLLM_SPYRE_N_BLOCKS
334+
num_blocks_override = SpyrePlatform.get_num_spyre_blocks_override()
335+
if num_blocks_override > 0:
336+
num_blocks = num_blocks_override
336337

337338
# List[layers] of Tuple[k,v] of
338339
# Tensor[num_blocks, block_size, num_kv_heads, head_dim]

vllm_spyre/platform.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class SpyrePlatform(Platform):
4343
supported_quantization: list[str] = ["gptq"]
4444
_warmup_shapes: Optional[tuple[dict[str, int], ...]] = None
4545
_block_size: int = 64 # hardcoded Spyre constraint for now
46+
_num_spyre_blocks_override: int = -1 # override num of KV cache blocks
4647
_config: VllmConfig = None
4748

4849
@classmethod
@@ -136,6 +137,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
136137
# budget available to schedule a full batch
137138
if cache_config is not None:
138139
if envs.VLLM_USE_V1:
140+
# overriding number of available Spyre blocks if not None
141+
if cache_config.num_gpu_blocks_override:
142+
cls._num_spyre_blocks_override = \
143+
cache_config.num_gpu_blocks_override
139144
# The V1 scheduler actually needs 2 blocks for each sequence...
140145
cache_config.num_gpu_blocks_override = \
141146
scheduler_config.max_num_seqs * 2
@@ -237,6 +242,10 @@ def get_warmup_shapes(
237242
def get_block_size(cls) -> int:
238243
return cls._block_size
239244

245+
@classmethod
246+
def get_num_spyre_blocks_override(cls) -> int:
247+
return cls._num_spyre_blocks_override
248+
240249
@classmethod
241250
def supports_v1(cls, model_config: ModelConfig) -> bool:
242251
"""Returns whether the current platform can support v1 for the supplied

vllm_spyre/v1/worker/spyre_model_runner.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -725,11 +725,12 @@ def finish_warmup(self) -> None:
725725

726726
def _set_blocks(self, num_blocks: int) -> None:
727727
# overwrite num_blocks for testing scheduler constraints
728-
if envs_spyre.VLLM_SPYRE_N_BLOCKS > 0:
728+
num_blocks_override = SpyrePlatform.get_num_spyre_blocks_override()
729+
if num_blocks_override > 0:
729730
logger.info(
730731
"[WARMUP] Overriding number of KV cache blocks on "
731-
"Spyre/CPU to %d.", envs_spyre.VLLM_SPYRE_N_BLOCKS)
732-
num_blocks = envs_spyre.VLLM_SPYRE_N_BLOCKS
732+
"Spyre/CPU to %d.", num_blocks_override)
733+
num_blocks = num_blocks_override
733734

734735
# set number of available blocks and populate block_pool
735736
self.n_blocks = num_blocks

0 commit comments

Comments
 (0)