diff --git a/tests/fixtures/models/granite-3.3-8b-instruct-config-only/config.json b/tests/fixtures/models/granite-3.3-8b-instruct-config-only/config.json
new file mode 100644
index 00000000..f731a4bb
--- /dev/null
+++ b/tests/fixtures/models/granite-3.3-8b-instruct-config-only/config.json
@@ -0,0 +1,32 @@
+{
+  "architectures": [
+    "GraniteForCausalLM"
+  ],
+  "attention_bias": false,
+  "attention_dropout": 0.0,
+  "attention_multiplier": 0.0078125,
+  "bos_token_id": 0,
+  "embedding_multiplier": 12.0,
+  "eos_token_id": 0,
+  "hidden_act": "silu",
+  "hidden_size": 4096,
+  "initializer_range": 0.02,
+  "intermediate_size": 12800,
+  "logits_scaling": 16.0,
+  "max_position_embeddings": 131072,
+  "mlp_bias": false,
+  "model_type": "granite",
+  "num_attention_heads": 32,
+  "num_hidden_layers": 40,
+  "num_key_value_heads": 8,
+  "pad_token_id": 0,
+  "residual_multiplier": 0.22,
+  "rms_norm_eps": 1e-05,
+  "rope_scaling": null,
+  "rope_theta": 10000000.0,
+  "tie_word_embeddings": true,
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.49.0",
+  "use_cache": true,
+  "vocab_size": 49159
+}
diff --git a/tests/fixtures/models/granite-3.3-micro-config-only/config.json b/tests/fixtures/models/granite-3.3-micro-config-only/config.json
new file mode 100644
index 00000000..67065dd6
--- /dev/null
+++ b/tests/fixtures/models/granite-3.3-micro-config-only/config.json
@@ -0,0 +1,32 @@
+{
+  "architectures": [
+    "GraniteForCausalLM"
+  ],
+  "attention_bias": false,
+  "attention_dropout": 0.0,
+  "attention_multiplier": 0.0078125,
+  "bos_token_id": 0,
+  "dtype": "bfloat16",
+  "embedding_multiplier": 12.0,
+  "eos_token_id": 0,
+  "hidden_act": "silu",
+  "hidden_size": 4096,
+  "initializer_range": 0.02,
+  "intermediate_size": 12800,
+  "logits_scaling": 16.0,
+  "max_position_embeddings": 131072,
+  "mlp_bias": false,
+  "model_type": "granite",
+  "num_attention_heads": 32,
+  "num_hidden_layers": 4,
+  "num_key_value_heads": 8,
+  "pad_token_id": 0,
+  "residual_multiplier": 0.22,
+  "rms_norm_eps": 1e-05,
+  "rope_scaling": null,
+  "rope_theta": 10000000.0,
+  "tie_word_embeddings": false,
+  "transformers_version": "4.56.1",
+  "use_cache": true,
+  "vocab_size": 49159
+}
diff --git a/tests/models/test_granite.py b/tests/models/test_granite.py
new file mode 100644
index 00000000..8d9cba0a
--- /dev/null
+++ b/tests/models/test_granite.py
@@ -0,0 +1,49 @@
+"""Tests for model-specific overrides for granite"""
+import os
+from pathlib import Path
+from unittest import mock
+
+import pytest
+from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig
+
+from vllm_spyre.platform import SpyrePlatform
+
+FIXTURES_PATH = Path(__file__).parent.parent / "fixtures" / "models"
+
+NO_SWAP_CONFIG = CacheConfig(swap_space=0.001)
+
+
+@pytest.mark.cpu
+def test_granite_3_8b_detection():
+    """Check that we can detect the model config for granite 3 8b"""
+
+    granite_3_8b_config = VllmConfig(model_config=ModelConfig(
+        model=str(FIXTURES_PATH / "granite-3.3-8b-instruct-config-only")),
+                                     cache_config=NO_SWAP_CONFIG)
+
+    granite_micro_config = VllmConfig(model_config=ModelConfig(
+        model=str(FIXTURES_PATH / "granite-3.3-micro-config-only")),
+                                      cache_config=NO_SWAP_CONFIG)
+
+    assert SpyrePlatform.is_granite_3_8b(granite_3_8b_config.model_config)
+
+    assert not SpyrePlatform.is_granite_3_8b(granite_micro_config.model_config)
+
+
+@pytest.mark.cpu
+def test_granite_3_8b_overrides():
+    """Check that the correct values are overridden for g3.3 8b"""
+
+    # Must ensure no env vars have been overridden before testing
+    with mock.patch.dict(os.environ, clear=True):
+        tp4_config = ParallelConfig(tensor_parallel_size=4)
+
+        granite_3_8b_config = VllmConfig(model_config=ModelConfig(
+            model=str(FIXTURES_PATH / "granite-3.3-8b-instruct-config-only")),
+                                         parallel_config=tp4_config,
+                                         cache_config=NO_SWAP_CONFIG)
+
+        assert granite_3_8b_config.cache_config.num_gpu_blocks_override == 2080
+
+        assert int(os.getenv("VLLM_DT_MAX_BATCH_TKV_LIMIT")) == 128 * 1024
+        assert int(os.getenv("FLEX_HDMA_P2PSIZE")) == 256 * 1024 * 1024
diff --git a/vllm_spyre/model_executor/model_loader/spyre.py b/vllm_spyre/model_executor/model_loader/spyre.py
index a94db1b6..f2b113f1 100644
--- a/vllm_spyre/model_executor/model_loader/spyre.py
+++ b/vllm_spyre/model_executor/model_loader/spyre.py
@@ -9,7 +9,7 @@
 import torch.nn as nn
 from fms.models import get_model
 from transformers import PretrainedConfig
-from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
+from vllm.config import ModelConfig, VllmConfig
 from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -51,17 +51,16 @@ class SpyreCausalLM(nn.Module):
 
     def __init__(
         self,
-        model_config: ModelConfig,
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
+        vllm_config: VllmConfig,
         max_prompt_length: int,
         max_decode_length: int,
         rank: int,
     ) -> None:
         super().__init__()
 
-        self.logits_processor = LogitsProcessor(
-            model_config.hf_config.vocab_size, logits_as_input=True)
+        self.logits_processor = LogitsProcessor(
+            vllm_config.model_config.hf_config.vocab_size,
+            logits_as_input=True)
         self.sampler = get_sampler()
 
         # boolean tensor of length batch size with indices:
@@ -78,14 +76,10 @@ def __init__(
 
         # FMS Model
         if envs_spyre.VLLM_SPYRE_USE_CB:
-            self.model = ContinuousBatchingFmsModel(model_config,
-                                                    parallel_config,
-                                                    scheduler_config, rank)
+            self.model = ContinuousBatchingFmsModel(vllm_config, rank)
         else:
             self.model = StaticBatchingFmsModel(
-                model_config,
-                parallel_config,
-                scheduler_config,
+                vllm_config,
                 max_prompt_length,
                 max_decode_length,
                 rank,
@@ -155,8 +149,7 @@ class FmsModelBase(nn.Module):
 
     def __init__(
         self,
-        model_config: ModelConfig,
-        parallel_config: ParallelConfig,
+        vllm_config: VllmConfig,
         max_prompt_length: int,
         max_decode_length: int,
         rank: int,
@@ -164,23 +157,26 @@ def __init__(
     ) -> None:
         super().__init__()
 
-        self.config: PretrainedConfig = model_config.hf_config
+        self.config: PretrainedConfig = vllm_config.model_config.hf_config
 
         # Actual FMS model
         self.model: nn.Module
-        self.model_config = model_config
+        self.model_config = vllm_config.model_config
+        self.parallel_config = vllm_config.parallel_config
+        self.cache_config = vllm_config.cache_config
+        self.scheduler_config = vllm_config.scheduler_config
         self.dtype = self.get_dtype()
 
         # Load the weights from the cached or downloaded files.
         self.load_weights(
-            model_config=model_config,
+            model_config=self.model_config,
             max_prompt_length=max_prompt_length,
             max_decode_length=max_decode_length,
             distributed_strategy="tp"
-            if parallel_config.world_size > 1 else None,
+            if self.parallel_config.world_size > 1 else None,
             sendnn_dynamic=sendnn_dynamic,
             rank=rank,
-            world_size=parallel_config.world_size,
+            world_size=self.parallel_config.world_size,
         )
 
     def load_weights(
@@ -321,14 +317,11 @@ class ContinuousBatchingFmsModel(FmsModelBase):
 
     def __init__(
         self,
-        model_config: ModelConfig,
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
+        vllm_config: VllmConfig,
         rank: int,
     ) -> None:
-
         BLOCK_SIZE = SpyrePlatform.get_block_size()
-        max_model_len = scheduler_config.max_model_len
+        max_model_len = vllm_config.scheduler_config.max_model_len
 
         # edge case: prompt fills model length: can produce 1 token with prefill
         max_prompt_length = max_model_len
@@ -336,22 +329,20 @@ def __init__(
         # can produce 1 token with prefill plus rest of model length
         max_decode_length = max_model_len - BLOCK_SIZE + 1
 
-        super().__init__(model_config,
-                         parallel_config,
+        super().__init__(vllm_config,
                          max_prompt_length,
                          max_decode_length,
                          rank,
                          sendnn_dynamic=True)
 
-        self.scheduler_config = scheduler_config
-        self.parallel_config = parallel_config
         self.prefill_past_key_values = None
 
         # physical KV cache on AIU Spyre: will eventually not live in this class
         self.kv_cache_specs = {}
         self.kv_cache_specs['block_size'] = BLOCK_SIZE
-        self.kv_cache_specs['num_kv_heads'] = model_config.get_num_kv_heads(
-            parallel_config)
+        self.kv_cache_specs[
+            'num_kv_heads'] = self.model_config.get_num_kv_heads(
+                self.parallel_config)
 
         if self.config.model_type in {'llama', 'granite'}:
             self.kv_cache_specs['num_layers'] = self.config.num_hidden_layers
@@ -375,81 +366,9 @@ def __init__(
 
         self.current_scale: Optional[list[tuple]] = None
 
-    def get_num_blocks_available(self) -> int:
-        """Function returns the number of available blocks/pages.
-        Will eventually contain a function in torch_sendnn which reads
-        the actual value provided by the compiler for backend sendnn"""
-
-        max_batch_size = self.scheduler_config.max_num_seqs
-        max_model_len = self.scheduler_config.max_model_len
-        block_size = self.kv_cache_specs['block_size']
-
-        min_req_num_blocks = max_model_len // block_size
-
-        # TODO: replace the hard coded NUM_BLOCKS_SPYRE by calling a function
-        # in torch_sendnn which returns the value set by the Spyre compiler.
-        if ('granite-3.3-8b-instruct' in self.model_config.model
-                and self.parallel_config.world_size == 4):
-            # hard coded value for tensor parallel size 4 with the below model
-            # https://huggingface.co/ibm-granite/granite-3.3-8b-instruct
-
-            # num_blocks_spyre must be multiple of max_batch_size
-            NUM_BLOCKS_SPYRE = max_batch_size * (2080 // max_batch_size)
-            logger.info(
-                "Model %s and tensor parallel "
-                "size %d detected. Using NUM_BLOCKS_SPYRE = %d",
-                self.model_config.model,
-                self.parallel_config.world_size,
-                NUM_BLOCKS_SPYRE,
-            )
-        else:
-            # default value for any other model/ tensor parallel size
-            NUM_BLOCKS_SPYRE = max_batch_size * min_req_num_blocks
-            logger.info("No model / tensor parallel size specific value for " \
-                "the number of KV cache blocks available on Spyre found. Using " \
-                "default value (max_batch_size * max_model_len / block_size): %d",
-                NUM_BLOCKS_SPYRE)
-
-        if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == 'sendnn':
-            num_blocks_spyre = NUM_BLOCKS_SPYRE
-            assert num_blocks_spyre >= min_req_num_blocks, (
-                "Number of pages available on Spyre (%d) is not enough to "
-                "serve the current model (need at least %d pages)." %
-                (num_blocks_spyre, min_req_num_blocks))
-            max_concurrency_spyre = num_blocks_spyre * block_size \
-                / max_model_len
-            logger.info("Spyre KV cache size: %s tokens",
-                        num_blocks_spyre * block_size)
-            logger.info("Maximum concurrency for %s tokens per request: %.2fx",
-                        str(max_model_len), max_concurrency_spyre)
-
-            assert num_blocks_spyre % max_batch_size == 0, \
-                "num_blocks_spyre must be multiple of max_batch_size"
-            return num_blocks_spyre
-        else:  # dynamo backend 'eager'
-            # for debugging purposes we also put the spyre value here for cpu
-            num_blocks_cpu = NUM_BLOCKS_SPYRE
-            assert num_blocks_cpu >= min_req_num_blocks, (
-                "Number of pages available on CPU (%d) is not enough to "
-                "serve the current model (need at least %d pages)." %
-                (num_blocks_cpu, min_req_num_blocks))
-            max_concurrency_cpu = num_blocks_cpu * block_size \
-                / max_model_len
-            logger.info("CPU KV cache size: %s tokens",
-                        num_blocks_cpu * block_size)
-            logger.info("Maximum concurrency for %s tokens per request: %.2fx",
-                        str(max_model_len), max_concurrency_cpu)
-            return num_blocks_cpu
-
     def set_past_key_value_states(self, num_blocks) -> None:
-        # overwrite num_blocks for testing scheduler constraints
-        num_blocks_override = SpyrePlatform.get_num_spyre_blocks_override()
-        if num_blocks_override > 0:
-            num_blocks = num_blocks_override
-
         # List[layers] of Tuple[k,v] of
         # Tensor[num_blocks, block_size, num_kv_heads, head_dim]
-
         if not self.model_config.quantization:
             self.past_key_value_states = [
                 (torch.zeros(num_blocks,
@@ -665,15 +584,12 @@ class StaticBatchingFmsModel(FmsModelBase):
 
     def __init__(
         self,
-        model_config: ModelConfig,
-        parallel_config: ParallelConfig,
-        _: SchedulerConfig,
+        vllm_config: VllmConfig,
         max_prompt_length: int,
         max_decode_length: int,
         rank: int,
     ) -> None:
-        super().__init__(model_config,
-                         parallel_config,
+        super().__init__(vllm_config,
                          max_prompt_length,
                          max_decode_length,
                          rank,
diff --git a/vllm_spyre/platform.py b/vllm_spyre/platform.py
index ec986f10..685dd844 100644
--- a/vllm_spyre/platform.py
+++ b/vllm_spyre/platform.py
@@ -16,6 +16,7 @@
 from typing import TYPE_CHECKING, Union
 
 import torch
+from transformers.models.granite import GraniteConfig
 from vllm.inputs import ProcessorInputs, PromptType
 from vllm.logger import init_logger
 from vllm.pooling_params import PoolingParams
@@ -62,7 +63,6 @@ class SpyrePlatform(Platform):
     supported_quantization: list[str] = ["gptq", "compressed-tensors"]
     _warmup_shapes: tuple[dict[str, int], ...] | None = None
     _block_size: int = 64  # hardcoded Spyre constraint for now
-    _num_spyre_blocks_override: int = -1  # override num of KV cache blocks
     _config: VllmConfig = None
 
     # Backend for dynamic compilation ops
@@ -167,20 +167,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             # - Set the block size (in tokens) to the maximum sequence length
             #   so that the scheduler thinks an entire sequence will fit in
             #   one single block.
-            # - Set the number of blocks to the maximum number of sequences, so
-            #   the scheduler always thinks there's a block available
             # - Set `max_num_batched_tokens` to the size of a full batch of full
             #   length requests, so that the scheduler will always have token
             #   budget available to schedule a full batch
             if cache_config is not None:
-                # overriding number of available Spyre blocks if not None
-                if cache_config.num_gpu_blocks_override:
-                    cls._num_spyre_blocks_override = \
-                        cache_config.num_gpu_blocks_override
-                # The V1 scheduler actually needs 2 blocks for each sequence...
-                cache_config.num_gpu_blocks_override = \
-                    scheduler_config.max_num_seqs * 2
-
                 cache_config.block_size = model_config.max_model_len
                 scheduler_config.max_num_batched_tokens = (
                     model_config.max_model_len * scheduler_config.max_num_seqs)
@@ -188,9 +178,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             logger.info(
                 "Overriding configurations based on warmup shapes. "
                 "max_model_len=%d, max_num_seqs=%d, block_size=%d, "
-                "num_gpu_blocks_override=%d, max_num_batched_tokens=%d",
-                model_config.max_model_len, scheduler_config.max_num_seqs,
-                cache_config.block_size, cache_config.num_gpu_blocks_override,
+                "max_num_batched_tokens=%d", model_config.max_model_len,
+                scheduler_config.max_num_seqs, cache_config.block_size,
                 scheduler_config.max_num_batched_tokens)
 
         # set env vars for torch_sendnn to consume
@@ -209,28 +198,16 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
 
             os.environ["VLLM_DT_MAX_BATCH_SIZE"] = str(
                 max(vllm_config.scheduler_config.max_num_seqs, 2))
 
-            # max product of batch size x tkv supported by the Spyre compiler
-            if ('granite-3.3-8b-instruct' in model_config.model
-                    and parallel_config.world_size == 4):
-                # hard coded value for tensor parallel size 4 with the below model
-                # https://huggingface.co/ibm-granite/granite-3.3-8b-instruct
-                os.environ["VLLM_DT_MAX_BATCH_TKV_LIMIT"] = str(128 * 1024)
-                logger.info("Model granite-3.3-8b-instruct and tensor parallel " \
-                    "size 4 detected. Using VLLM_DT_MAX_BATCH_TKV_LIMIT = %d",
-                    128 * 1024)
-
-                # If no HDMA p2psize override was specified, set 256MB
-                if not os.getenv("FLEX_HDMA_P2PSIZE", None):
-                    os.environ["FLEX_HDMA_P2PSIZE"] = str(1024 * 1024 * 256)
-                    logger.info(
-                        "Model granite-3.3-8b-instruct and tensor parallel size 4 "
-                        "detected. Using FLEX_HDMA_P2PSIZE = %d",
-                        1024 * 1024 * 256)
-            else:
-                # default value for any other model/ tensor parallel size
+            # Hardcode some things for granite-3.3-8b-instruct
+            if cls.is_granite_3_8b(vllm_config.model_config):
+                cls.configure_granite_3_8b(vllm_config)
+
+            if not os.getenv("VLLM_DT_MAX_BATCH_TKV_LIMIT"):
+                # max product of batch size x tkv supported by the Spyre compiler
                 default_max_batch_tkv_limit = \
                     vllm_config.model_config.max_model_len * \
                     vllm_config.scheduler_config.max_num_seqs
+                os.environ["VLLM_DT_MAX_BATCH_TKV_LIMIT"] = str(
                     default_max_batch_tkv_limit)
                 logger.info("No model / tensor parallel size specific value for " \
@@ -316,10 +293,6 @@ def get_warmup_shapes(cls, scheduler_config) -> tuple[dict[str, int], ...]:
     def get_block_size(cls) -> int:
         return cls._block_size
 
-    @classmethod
-    def get_num_spyre_blocks_override(cls) -> int:
-        return cls._num_spyre_blocks_override
-
     @classmethod
     def supports_v1(cls, model_config: ModelConfig) -> bool:
         """Returns whether the current platform can support v1 for the supplied
@@ -540,3 +513,73 @@ def get_max_output_tokens(self, prompt_len: int) -> int:
             max_new_tokens = max(max_new_tokens, shape['new_tokens'])
 
         return max_new_tokens
+
+    @classmethod
+    def configure_granite_3_8b(cls, vllm_config: VllmConfig):
+        """
+        Configure hard coded values for the model
+        https://huggingface.co/ibm-granite/granite-3.3-8b-instruct
+        """
+        parallel_config = vllm_config.parallel_config
+
+        if parallel_config.world_size != 4:
+            # only override configs for TP=4
+            return
+
+        tkv_128k = 128 * 1024
+        if not os.getenv("VLLM_DT_MAX_BATCH_TKV_LIMIT"):
+            os.environ["VLLM_DT_MAX_BATCH_TKV_LIMIT"] = str(tkv_128k)
+            logger.info("Model granite-3.3-8b-instruct and tensor parallel " \
+                "size 4 detected. Using VLLM_DT_MAX_BATCH_TKV_LIMIT = %d",
+                tkv_128k)
+        elif os.getenv("VLLM_DT_MAX_BATCH_TKV_LIMIT") != str(tkv_128k):
+            logger.warning(
+                "VLLM_DT_MAX_BATCH_TKV_LIMIT was set to %s, not "
+                "overriding to the granite-3.3-8b-instruct default of %d",
+                os.getenv("VLLM_DT_MAX_BATCH_TKV_LIMIT"), tkv_128k)
+
+        # If no HDMA p2psize override was specified, set 256MB
+        p2psize_256m = 256 * 1024 * 1024
+        if not os.getenv("FLEX_HDMA_P2PSIZE"):
+            os.environ["FLEX_HDMA_P2PSIZE"] = str(p2psize_256m)
+            logger.info(
+                "Model granite-3.3-8b-instruct and tensor parallel size 4 "
+                "detected. Using FLEX_HDMA_P2PSIZE = %d", p2psize_256m)
+        elif os.getenv("FLEX_HDMA_P2PSIZE") != str(p2psize_256m):
+            logger.warning(
+                "FLEX_HDMA_P2PSIZE was set to %s, not using the "
+                "granite-3.3-8b-instruct default of %d",
+                os.getenv("FLEX_HDMA_P2PSIZE"), p2psize_256m)
+
+        # Override the total number of KV cache blocks based on what we know
+        # will fit. (Unless user already set `--num-gpu-blocks-override`)
+        # TODO: remove this once we have correct free memory info available
+        blocks_override = 2080
+        if vllm_config.cache_config.num_gpu_blocks_override is None:
+            vllm_config.cache_config.num_gpu_blocks_override = blocks_override
+            logger.info(
+                "Model granite-3.3-8b-instruct and tensor parallel size 4 "
+                "detected. Overriding available KV Cache blocks to %d",
+                blocks_override)
+        elif (vllm_config.cache_config.num_gpu_blocks_override
+              != blocks_override):
+            logger.warning(
+                "--num-gpu-blocks-override was set to %d, not using the "
+                "granite-3.3-8b-instruct default of %d",
+                vllm_config.cache_config.num_gpu_blocks_override,
+                blocks_override)
+
+    @classmethod
+    def is_granite_3_8b(cls, model_config: ModelConfig):
+        """Returns true if we have a model that looks like
+        ibm-granite/granite-3.3-8b-instruct"""
+        if not isinstance(model_config.hf_config, GraniteConfig):
+            # Not granite at all
+            return False
+
+        return (model_config.hf_config.num_hidden_layers == 40
+                and model_config.hf_config.max_position_embeddings == 131072
+                and model_config.hf_config.hidden_size == 4096
+                and model_config.hf_config.vocab_size == 49159
+                and model_config.hf_config.num_key_value_heads == 8
+                and model_config.hf_config.num_attention_heads == 32)
diff --git a/vllm_spyre/v1/worker/spyre_model_runner.py b/vllm_spyre/v1/worker/spyre_model_runner.py
index 005ab915..437fcbfd 100644
--- a/vllm_spyre/v1/worker/spyre_model_runner.py
+++ b/vllm_spyre/v1/worker/spyre_model_runner.py
@@ -288,9 +288,7 @@ def load_model(self, prompt_lens: Iterable[int],
         max_pad_length = max(prompt_lens)
         max_decode_length = max(num_decode_tokens)
         self.model = SpyreCausalLM(
-            self.model_config,
-            parallel_config=self.parallel_config,
-            scheduler_config=self.scheduler_config,
+            vllm_config=self.vllm_config,
             max_prompt_length=max_pad_length,
             max_decode_length=max_decode_length,
             rank=self.rank,
@@ -829,7 +827,7 @@ def pre_warmup(self) -> None:
         # Note: Until this feature is supported by the compiler we have to set:
         # n_blocks_warmup = n_blocks_avail
 
-        n_blocks_warmup = self.model.model.get_num_blocks_available()
+        n_blocks_warmup = self.get_total_spyre_blocks()
 
         self._set_blocks(num_blocks=n_blocks_warmup)
         self.model.model.set_past_key_value_states(num_blocks=n_blocks_warmup)
@@ -849,23 +847,55 @@ def complete_warmup(self) -> None:
         super().complete_warmup()
         # get the number or pages from the actual Spyre card after the warmup
         # and set it accordingly in the model runner and for the kv cache size
-        n_blocks_avail = self.model.model.get_num_blocks_available()
+        n_blocks_avail = self.get_total_spyre_blocks()
 
         self._set_blocks(num_blocks=n_blocks_avail)
         self.model.model.set_past_key_value_states(num_blocks=n_blocks_avail)
 
     def _set_blocks(self, num_blocks: int) -> None:
-        # overwrite num_blocks for testing scheduler constraints
-        num_blocks_override = SpyrePlatform.get_num_spyre_blocks_override()
-        if num_blocks_override > 0:
-            logger.info(
-                "[WARMUP] Overriding number of KV cache blocks on "
-                "Spyre/CPU to %d.", num_blocks_override)
-            num_blocks = num_blocks_override
-
         # set number of available blocks and populate block_pool
         self.n_blocks = num_blocks
         self.block_pool = deque([i for i in range(self.n_blocks)])
 
+    def get_total_spyre_blocks(self) -> int:
+        """Returns the total number of KV cache blocks available for spyre.
+        This currently returns the number of blocks required for a full-sized
+        batch, which may be greater than the available memory.
+
+        Until a correct available memory api is available, the number of blocks
+        must be overridden with a known good value via
+        cache_config.num_gpu_blocks_override
+        """
+        max_batch_size = self.scheduler_config.max_num_seqs
+        max_model_len = self.scheduler_config.max_model_len
+        block_size = SpyrePlatform.get_block_size()
+        min_req_num_blocks = max_model_len // block_size
+
+        blocks_override = self.cache_config.num_gpu_blocks_override
+        if blocks_override is not None and blocks_override > 0:
+            num_blocks = blocks_override
+        else:
+            num_blocks = max_batch_size * min_req_num_blocks
+
+        # Total number of blocks needs to be a multiple of the batch size
+        # (spyre constraint) so round it down
+        num_blocks = max_batch_size * (num_blocks // max_batch_size)
+
+        if num_blocks < min_req_num_blocks:
+            raise ValueError(
+                f"Number of pages available on Spyre {num_blocks} is not "
+                f"enough to serve the current model (need at least "
+                f"{min_req_num_blocks} pages).")
+
+        max_concurrency = num_blocks * block_size / max_model_len
+        backend = "Spyre" if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == 'sendnn' \
+            else "CPU"
+        logger.info("%s KV cache size: %s tokens", backend,
+                    num_blocks * block_size)
+        logger.info("Maximum concurrency for %s tokens per request: %.2fx",
+                    str(max_model_len), max_concurrency)
+
+        return num_blocks
+
     def update_states(self, scheduler_output):
         super().update_states(scheduler_output)
diff --git a/vllm_spyre/v1/worker/spyre_worker.py b/vllm_spyre/v1/worker/spyre_worker.py
index 20849d0b..65aba7c1 100644
--- a/vllm_spyre/v1/worker/spyre_worker.py
+++ b/vllm_spyre/v1/worker/spyre_worker.py
@@ -181,16 +181,25 @@ def determine_available_memory(self) -> int:
         The number of device blocks (called "gpu blocks" in most places)
         can also be overridden by `--num-gpu-blocks-override`, which is set
         under `vllm_config.cache_config.num_gpu_blocks_override`.
+
+        🌶️🌶️🌶️ The result from this method _only_ applies to the KV Cache
+        management in vLLM's core scheduler. This does _not_ apply to the KV
+        cache management handled directly by the vllm-spyre worker and model
+        runner. We return a minimal value here to make the vllm scheduler happy.
         """
-        # Currently we override vllm_config.cache_config.num_gpu_blocks_override
-        # in platform.py, so this value is only used by vllm to check that the
-        # number of gpu blocks will fit in available memory.
-        # Since we also return dummy values for the kv cache spec, this check is
-        # meaningless and we can just return a large value to ensure vllm does
-        # not raise a validation error.
-        # TODO: Return the real available device memory when we implement real
-        # kv-caching.
-        return 1 << 64
+        # The fake kv_cache config specified by the model runner sets 4 bytes
+        # per token.
+        accurate_fake_kv_cache_size = (4 *
+                                       self.scheduler_config.max_model_len *
+                                       self.scheduler_config.max_num_seqs)
+
+        # The vLLM scheduler reserves a null block in its kv-cache, so we need
+        # at least one more block to allow for proper scheduling. We double
+        # the cache size here to ensure that the vllm scheduler always has
+        # blocks available. This causes the log message from vLLM about its
+        # KV cache capacity to be double the log message from vllm-spyre.
+        # This can probably be fixed in a nicer way.
+        return 2 * accurate_fake_kv_cache_size
 
     def initialize_from_config(self,
                                kv_cache_configs: list[KVCacheConfig]) -> None: