diff --git a/tests/scheduling_utils.py b/tests/scheduling_utils.py
index 0bf00fad6..bc395094c 100644
--- a/tests/scheduling_utils.py
+++ b/tests/scheduling_utils.py
@@ -57,8 +57,6 @@ def check_scheduler_inference_steps(
     # set env vars
     monkeypatch.setenv("VLLM_USE_V1", "1")
     monkeypatch.setenv("VLLM_SPYRE_DYNAMO_BACKEND", backend)
-    if available_blocks > 0:
-        monkeypatch.setenv("VLLM_SPYRE_N_BLOCKS", str(available_blocks))
     if use_cb:
         monkeypatch.setenv("VLLM_SPYRE_USE_CB", "1")
 
@@ -90,7 +88,9 @@ def check_scheduler_inference_steps(
         tokenizer=model,
         max_model_len=max_model_len,
         block_size=max_model_len,
-        max_num_seqs=max_num_seqs)
+        max_num_seqs=max_num_seqs,
+        num_gpu_blocks_override=available_blocks
+        if available_blocks > 0 else None)
     vllm_config = engine_args.create_engine_config()
     executor_class = Executor.get_class(vllm_config)
     engine_core = EngineCore(vllm_config=vllm_config,
diff --git a/vllm_spyre/envs.py b/vllm_spyre/envs.py
index 6d7148def..916778dd3 100644
--- a/vllm_spyre/envs.py
+++ b/vllm_spyre/envs.py
@@ -9,7 +9,6 @@
     VLLM_SPYRE_WARMUP_NEW_TOKENS: Optional[list[int]] = None
     VLLM_SPYRE_WARMUP_BATCH_SIZES: Optional[list[int]] = None
     VLLM_SPYRE_USE_CB: bool = False
-    VLLM_SPYRE_N_BLOCKS: int = 0
     VLLM_SPYRE_PERF_METRIC_LOGGING_ENABLED: int = 0
     VLLM_SPYRE_PERF_METRIC_LOGGING_DIR: str = "/tmp"
     VLLM_SPYRE_OVERRIDE_SIGNALS_HANDLER: bool = False
@@ -75,10 +74,6 @@ def _backend_backwards_compat() -> str:
     "VLLM_SPYRE_USE_CB":
     lambda: bool(int(os.getenv("VLLM_SPYRE_USE_CB", "0"))),
 
-    # Overriding the number of KV cache blocks available on Spyre (and CPU)
-    "VLLM_SPYRE_N_BLOCKS":
-    lambda: int(os.getenv("VLLM_SPYRE_N_BLOCKS", 0)),
-
     # Enable performance metric logging. This captures startup information
     # such as warmup times, and loading times. It is turned off by default.
     "VLLM_SPYRE_PERF_METRIC_LOGGING_ENABLED":
diff --git a/vllm_spyre/model_executor/model_loader/spyre.py b/vllm_spyre/model_executor/model_loader/spyre.py
index 25bc5e6e1..0e2b9a092 100644
--- a/vllm_spyre/model_executor/model_loader/spyre.py
+++ b/vllm_spyre/model_executor/model_loader/spyre.py
@@ -331,8 +331,9 @@ def __init__(
 
     def _set_past_key_value_states(self, num_blocks) -> None:
         # overwrite num_blocks for testing scheduler constraints
-        if envs_spyre.VLLM_SPYRE_N_BLOCKS > 0:
-            num_blocks = envs_spyre.VLLM_SPYRE_N_BLOCKS
+        num_blocks_override = SpyrePlatform.get_num_spyre_blocks_override()
+        if num_blocks_override > 0:
+            num_blocks = num_blocks_override
 
         # List[layers] of Tuple[k,v] of
         # Tensor[num_blocks, block_size, num_kv_heads, head_dim]
diff --git a/vllm_spyre/platform.py b/vllm_spyre/platform.py
index f8e22d1fa..80bc2a30f 100644
--- a/vllm_spyre/platform.py
+++ b/vllm_spyre/platform.py
@@ -43,6 +43,7 @@ class SpyrePlatform(Platform):
     supported_quantization: list[str] = ["gptq"]
     _warmup_shapes: Optional[tuple[dict[str, int], ...]] = None
     _block_size: int = 64  # hardcoded Spyre constraint for now
+    _num_spyre_blocks_override: int = -1  # override num of KV cache blocks
     _config: VllmConfig = None
 
     @classmethod
@@ -136,6 +137,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         # budget available to schedule a full batch
         if cache_config is not None:
             if envs.VLLM_USE_V1:
+                # overriding number of available Spyre blocks if not None
+                if cache_config.num_gpu_blocks_override:
+                    cls._num_spyre_blocks_override = \
+                        cache_config.num_gpu_blocks_override
                 # The V1 scheduler actually needs 2 blocks for each sequence...
                 cache_config.num_gpu_blocks_override = \
                     scheduler_config.max_num_seqs * 2
@@ -237,6 +242,10 @@ def get_warmup_shapes(
     def get_block_size(cls) -> int:
         return cls._block_size
 
+    @classmethod
+    def get_num_spyre_blocks_override(cls) -> int:
+        return cls._num_spyre_blocks_override
+
     @classmethod
     def supports_v1(cls, model_config: ModelConfig) -> bool:
         """Returns whether the current platform can support v1 for the supplied
diff --git a/vllm_spyre/v1/worker/spyre_model_runner.py b/vllm_spyre/v1/worker/spyre_model_runner.py
index 61a6cc88e..a6122ec9c 100644
--- a/vllm_spyre/v1/worker/spyre_model_runner.py
+++ b/vllm_spyre/v1/worker/spyre_model_runner.py
@@ -725,11 +725,12 @@ def finish_warmup(self) -> None:
 
     def _set_blocks(self, num_blocks: int) -> None:
         # overwrite num_blocks for testing scheduler constraints
-        if envs_spyre.VLLM_SPYRE_N_BLOCKS > 0:
+        num_blocks_override = SpyrePlatform.get_num_spyre_blocks_override()
+        if num_blocks_override > 0:
             logger.info(
                 "[WARMUP] Overriding number of KV cache blocks on "
-                "Spyre/CPU to %d.", envs_spyre.VLLM_SPYRE_N_BLOCKS)
-            num_blocks = envs_spyre.VLLM_SPYRE_N_BLOCKS
+                "Spyre/CPU to %d.", num_blocks_override)
+            num_blocks = num_blocks_override
 
         # set number of available blocks and populate block_pool
         self.n_blocks = num_blocks