From ab3cda7787c7d3f9c4f474f608ac5af0116f30f1 Mon Sep 17 00:00:00 2001
From: Joe Runde
Date: Thu, 2 Oct 2025 11:16:24 -0600
Subject: [PATCH 01/13] :bug: implement better checking for granite

Signed-off-by: Joe Runde
---
 .../model_executor/model_loader/spyre.py |  2 +-
 vllm_spyre/platform.py                   | 77 +++++++++++++------
 2 files changed, 56 insertions(+), 23 deletions(-)

diff --git a/vllm_spyre/model_executor/model_loader/spyre.py b/vllm_spyre/model_executor/model_loader/spyre.py
index a94db1b6..e734bdac 100644
--- a/vllm_spyre/model_executor/model_loader/spyre.py
+++ b/vllm_spyre/model_executor/model_loader/spyre.py
@@ -388,7 +388,7 @@ def get_num_blocks_available(self) -> int:
 
         # TODO: replace the hard coded NUM_BLOCKS_SPYRE by calling a function
         # in torch_sendnn which returns the value set by the Spyre compiler.
-        if ('granite-3.3-8b-instruct' in self.model_config.model
+        if (SpyrePlatform.is_granite_33_8b(self.model_config)
                 and self.parallel_config.world_size == 4):
             # hard coded value for tensor parallel size 4 with the below model
             # https://huggingface.co/ibm-granite/granite-3.3-8b-instruct
diff --git a/vllm_spyre/platform.py b/vllm_spyre/platform.py
index 59a2f9f3..43e7a96d 100644
--- a/vllm_spyre/platform.py
+++ b/vllm_spyre/platform.py
@@ -16,6 +16,7 @@
 from typing import TYPE_CHECKING, Union
 
 import torch
+from transformers.models.granite import GraniteConfig
 from vllm.inputs import ProcessorInputs, PromptType
 from vllm.logger import init_logger
 from vllm.pooling_params import PoolingParams
@@ -209,28 +210,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             os.environ["VLLM_DT_MAX_BATCH_SIZE"] = str(
                 max(vllm_config.scheduler_config.max_num_seqs, 2))
 
-        # max product of batch size x tkv supported by the Spyre compiler
-        if ('granite-3.3-8b-instruct' in model_config.model
-                and parallel_config.world_size == 4):
-            # hard coded value for tensor parallel size 4 with the below model
-            # https://huggingface.co/ibm-granite/granite-3.3-8b-instruct
-            os.environ["VLLM_DT_MAX_BATCH_TKV_LIMIT"] = str(128 * 1024)
-            logger.info("Model granite-3.3-8b-instruct and tensor parallel " \
-                "size 4 detected. Using VLLM_DT_MAX_BATCH_TKV_LIMIT = %d",
-                128 * 1024)
+        # Hardcode some things for granite-3.3-8b-instruct
+        cls.configure_granite_33_8b(vllm_config.model_config)
 
-            # If no HDMA p2psize override was specified, set 256MB
-            if not os.getenv("FLEX_HDMA_P2PSIZE", None):
-                os.environ["FLEX_HDMA_P2PSIZE"] = str(1024 * 1024 * 256)
-                logger.info(
-                    "Model granite-3.3-8b-instruct and tensor parallel size 4 "
-                    "detected. Using FLEX_HDMA_P2PSIZE = %d",
-                    1024 * 1024 * 256)
-        else:
-            # default value for any other model/ tensor parallel size
-            default_max_batch_tkv_limit = \
-                vllm_config.model_config.max_model_len * \
-                vllm_config.scheduler_config.max_num_seqs
+        # max product of batch size x tkv supported by the Spyre compiler
+        default_max_batch_tkv_limit = \
+            vllm_config.model_config.max_model_len * \
+            vllm_config.scheduler_config.max_num_seqs
+
+        if not os.getenv("VLLM_DT_MAX_BATCH_TKV_LIMIT", None):
             os.environ["VLLM_DT_MAX_BATCH_TKV_LIMIT"] = str(
                 default_max_batch_tkv_limit)
             logger.info("No model / tensor parallel size specific value for " \
@@ -518,4 +505,50 @@ def get_max_output_tokens(self, prompt_len: int) -> int:
             if prompt_len <= shape['prompt_length']:
                 max_new_tokens = max(max_new_tokens, shape['new_tokens'])
 
-        return max_new_tokens
\ No newline at end of file
+        return max_new_tokens
+
+    @classmethod
+    def configure_granite_33_8b(cls, vllm_config: VllmConfig):
+        """
+        Configure hard coded values for the model
+        https://huggingface.co/ibm-granite/granite-3.3-8b-instruct
+        """
+
+        model_config = vllm_config.model_config
+        parallel_config = vllm_config.parallel_config
+
+        if not cls.is_granite_33_8b(model_config):
+            # Not granite
+            return
+
+        if parallel_config.world_size != 4:
+            # only override configs for TP=4
+            return
+
+        if not os.getenv("VLLM_DT_MAX_BATCH_TKV_LIMIT", None):
+            os.environ["VLLM_DT_MAX_BATCH_TKV_LIMIT"] = str(128 * 1024)
+            logger.info("Model granite-3.3-8b-instruct and tensor parallel " \
+                "size 4 detected. Using VLLM_DT_MAX_BATCH_TKV_LIMIT = %d",
+                128 * 1024)
+
+        # If no HDMA p2psize override was specified, set 256MB
+        if not os.getenv("FLEX_HDMA_P2PSIZE", None):
+            os.environ["FLEX_HDMA_P2PSIZE"] = str(1024 * 1024 * 256)
+            logger.info(
+                "Model granite-3.3-8b-instruct and tensor parallel size 4 "
+                "detected. Using FLEX_HDMA_P2PSIZE = %d", 1024 * 1024 * 256)
+
+    @classmethod
+    def is_granite_33_8b(cls, model_config: ModelConfig):
+        """Returns true if we have a model that looks like
+        ibm-granite/granite-3.3-8b-instruct"""
+        if not isinstance(model_config.hf_config, GraniteConfig):
+            # Not granite at all
+            return False
+
+        return (model_config.hf_config.num_hidden_layers == 40
+                and model_config.hf_config.max_position_embeddings == 131072
+                and model_config.hf_config.hidden_size == 4096
+                and model_config.hf_config.vocab_size == 49159
+                and model_config.hf_config.num_key_value_heads == 8
+                and model_config.hf_config.num_attention_heads == 32)

From 0c3c937f8f828561199788656e6ac1749a1eca0a Mon Sep 17 00:00:00 2001
From: Joe Runde
Date: Thu, 2 Oct 2025 11:20:24 -0600
Subject: [PATCH 02/13] :recycle: cleanup constants

Signed-off-by: Joe Runde
---
 vllm_spyre/platform.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/vllm_spyre/platform.py b/vllm_spyre/platform.py
index 43e7a96d..f7740b03 100644
--- a/vllm_spyre/platform.py
+++ b/vllm_spyre/platform.py
@@ -526,17 +526,19 @@ def configure_granite_33_8b(cls, vllm_config: VllmConfig):
             return
 
         if not os.getenv("VLLM_DT_MAX_BATCH_TKV_LIMIT", None):
-            os.environ["VLLM_DT_MAX_BATCH_TKV_LIMIT"] = str(128 * 1024)
+            tkv_128k = 128 * 1024
+            os.environ["VLLM_DT_MAX_BATCH_TKV_LIMIT"] = str(tkv_128k)
             logger.info("Model granite-3.3-8b-instruct and tensor parallel " \
                 "size 4 detected. Using VLLM_DT_MAX_BATCH_TKV_LIMIT = %d",
-                128 * 1024)
+                tkv_128k)
 
         # If no HDMA p2psize override was specified, set 256MB
         if not os.getenv("FLEX_HDMA_P2PSIZE", None):
-            os.environ["FLEX_HDMA_P2PSIZE"] = str(1024 * 1024 * 256)
+            p2psize_256m = 256 * 1024 * 1024
+            os.environ["FLEX_HDMA_P2PSIZE"] = str(p2psize_256m)
             logger.info(
                 "Model granite-3.3-8b-instruct and tensor parallel size 4 "
-                "detected. Using FLEX_HDMA_P2PSIZE = %d", 1024 * 1024 * 256)
+                "detected. Using FLEX_HDMA_P2PSIZE = %d", p2psize_256m)
 
     @classmethod
     def is_granite_33_8b(cls, model_config: ModelConfig):

From a05befee93657a0c783da2aa27f86d5638d401ae Mon Sep 17 00:00:00 2001
From: Joe Runde
Date: Thu, 2 Oct 2025 11:25:03 -0600
Subject: [PATCH 03/13] :bug: fixup call

Signed-off-by: Joe Runde
---
 vllm_spyre/platform.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm_spyre/platform.py b/vllm_spyre/platform.py
index f7740b03..086717e7 100644
--- a/vllm_spyre/platform.py
+++ b/vllm_spyre/platform.py
@@ -211,7 +211,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             max(vllm_config.scheduler_config.max_num_seqs, 2))
 
         # Hardcode some things for granite-3.3-8b-instruct
-        cls.configure_granite_33_8b(vllm_config.model_config)
+        cls.configure_granite_33_8b(vllm_config)
 
         # max product of batch size x tkv supported by the Spyre compiler
         default_max_batch_tkv_limit = \

From e83893cbf83cee708a0afadbfb3eff7d0812f6f2 Mon Sep 17 00:00:00 2001
From: Joe Runde
Date: Thu, 2 Oct 2025 15:33:48 -0600
Subject: [PATCH 04/13] :recycle: consolidate block overrides for granite

Signed-off-by: Joe Runde
---
 .../model_executor/model_loader/spyre.py   | 131 ++++--------------
 vllm_spyre/platform.py                     |  73 +++++-----
 vllm_spyre/v1/worker/spyre_model_runner.py |  56 ++++++--
 vllm_spyre/v1/worker/spyre_worker.py       |  18 +--
 4 files changed, 117 insertions(+), 161 deletions(-)

diff --git a/vllm_spyre/model_executor/model_loader/spyre.py b/vllm_spyre/model_executor/model_loader/spyre.py
index e734bdac..4e6f1bdb 100644
--- a/vllm_spyre/model_executor/model_loader/spyre.py
+++ b/vllm_spyre/model_executor/model_loader/spyre.py
@@ -9,7 +9,7 @@
 import torch.nn as nn
 from fms.models import get_model
 from transformers import PretrainedConfig
-from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
+from vllm.config import ModelConfig, VllmConfig
 from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -51,17 +51,17 @@ class SpyreCausalLM(nn.Module):
 
     def __init__(
         self,
-        model_config: ModelConfig,
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
+        vllm_config: VllmConfig,
         max_prompt_length: int,
         max_decode_length: int,
         rank: int,
     ) -> None:
         super().__init__()
+        self.vllm_config = vllm_config
 
         self.logits_processor = LogitsProcessor(
-            model_config.hf_config.vocab_size, logits_as_input=True)
+            vllm_config.model_config.hf_config.vocab_size,
+            logits_as_input=True)
         self.sampler = get_sampler()
 
         # boolean tensor of length batch size with indices:
@@ -78,14 +78,10 @@ def __init__(
 
         # FMS Model
         if envs_spyre.VLLM_SPYRE_USE_CB:
-            self.model = ContinuousBatchingFmsModel(model_config,
-                                                    parallel_config,
-                                                    scheduler_config, rank)
+            self.model = ContinuousBatchingFmsModel(vllm_config, rank)
         else:
             self.model = StaticBatchingFmsModel(
-                model_config,
-                parallel_config,
-                scheduler_config,
+                vllm_config,
                 max_prompt_length,
                 max_decode_length,
                 rank,
@@ -155,8 +151,7 @@ class FmsModelBase(nn.Module):
 
     def __init__(
         self,
-        model_config: ModelConfig,
-        parallel_config: ParallelConfig,
+        vllm_config: VllmConfig,
         max_prompt_length: int,
         max_decode_length: int,
         rank: int,
@@ -164,23 +159,27 @@ def __init__(
     ) -> None:
         super().__init__()
 
-        self.config: PretrainedConfig = model_config.hf_config
+        self.config: PretrainedConfig = vllm_config.model_config.hf_config
 
         # Actual FMS model
         self.model: nn.Module
-        self.model_config = model_config
+        self.vllm_config = vllm_config
+        self.model_config = vllm_config.model_config
+        self.parallel_config = vllm_config.parallel_config
+        self.cache_config = vllm_config.cache_config
+        self.scheduler_config = vllm_config.scheduler_config
         self.dtype = self.get_dtype()
 
         # Load the weights from the cached or downloaded files.
         self.load_weights(
-            model_config=model_config,
+            model_config=self.model_config,
             max_prompt_length=max_prompt_length,
            max_decode_length=max_decode_length,
             distributed_strategy="tp"
-            if parallel_config.world_size > 1 else None,
+            if self.parallel_config.world_size > 1 else None,
             sendnn_dynamic=sendnn_dynamic,
             rank=rank,
-            world_size=parallel_config.world_size,
+            world_size=self.parallel_config.world_size,
         )
 
     def load_weights(
@@ -321,14 +320,11 @@ class ContinuousBatchingFmsModel(FmsModelBase):
 
     def __init__(
         self,
-        model_config: ModelConfig,
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
+        vllm_config: VllmConfig,
         rank: int,
     ) -> None:
-
         BLOCK_SIZE = SpyrePlatform.get_block_size()
-        max_model_len = scheduler_config.max_model_len
+        max_model_len = vllm_config.scheduler_config.max_model_len
 
         # edge case: prompt fills model length: can produce 1 token with prefill
         max_prompt_length = max_model_len
@@ -336,22 +332,20 @@ def __init__(
         # can produce 1 token with prefill plus rest of model length
         max_decode_length = max_model_len - BLOCK_SIZE + 1
 
-        super().__init__(model_config,
-                         parallel_config,
+        super().__init__(vllm_config,
                          max_prompt_length,
                          max_decode_length,
                          rank,
                          sendnn_dynamic=True)
 
-        self.scheduler_config = scheduler_config
-        self.parallel_config = parallel_config
         self.prefill_past_key_values = None
 
         # physical KV cache on AIU Spyre: will eventually not live in this class
         self.kv_cache_specs = {}
         self.kv_cache_specs['block_size'] = BLOCK_SIZE
-        self.kv_cache_specs['num_kv_heads'] = model_config.get_num_kv_heads(
-            parallel_config)
+        self.kv_cache_specs[
+            'num_kv_heads'] = self.model_config.get_num_kv_heads(
+                self.parallel_config)
 
         if self.config.model_type in {'llama', 'granite'}:
             self.kv_cache_specs['num_layers'] = self.config.num_hidden_layers
@@ -375,81 +369,9 @@ def __init__(
 
         self.current_scale: Optional[list[tuple]] = None
 
-    def get_num_blocks_available(self) -> int:
-        """Function returns the number of available blocks/pages.
-        Will eventually contain a function in torch_sendnn which reads
-        the actual value provided by the compiler for backend sendnn"""
-
-        max_batch_size = self.scheduler_config.max_num_seqs
-        max_model_len = self.scheduler_config.max_model_len
-        block_size = self.kv_cache_specs['block_size']
-
-        min_req_num_blocks = max_model_len // block_size
-
-        # TODO: replace the hard coded NUM_BLOCKS_SPYRE by calling a function
-        # in torch_sendnn which returns the value set by the Spyre compiler.
-        if (SpyrePlatform.is_granite_33_8b(self.model_config)
-                and self.parallel_config.world_size == 4):
-            # hard coded value for tensor parallel size 4 with the below model
-            # https://huggingface.co/ibm-granite/granite-3.3-8b-instruct
-
-            # num_blocks_spyre must be multiple of max_batch_size
-            NUM_BLOCKS_SPYRE = max_batch_size * (2080 // max_batch_size)
-            logger.info(
-                "Model %s and tensor parallel "
-                "size %d detected. Using NUM_BLOCKS_SPYRE = %d",
-                self.model_config.model,
-                self.parallel_config.world_size,
-                NUM_BLOCKS_SPYRE,
-            )
-        else:
-            # default value for any other model/ tensor parallel size
-            NUM_BLOCKS_SPYRE = max_batch_size * min_req_num_blocks
-            logger.info("No model / tensor parallel size specific value for " \
-                "the number of KV cache blocks available on Spyre found. Using " \
-                "default value (max_batch_size * max_model_len / block_size): %d",
-                NUM_BLOCKS_SPYRE)
-
-        if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == 'sendnn':
-            num_blocks_spyre = NUM_BLOCKS_SPYRE
-            assert num_blocks_spyre >= min_req_num_blocks, (
-                "Number of pages available on Spyre (%d) is not enough to "
-                "serve the current model (need at least %d pages)." %
-                (num_blocks_spyre, min_req_num_blocks))
-            max_concurrency_spyre = num_blocks_spyre * block_size \
-                / max_model_len
-            logger.info("Spyre KV cache size: %s tokens",
-                        num_blocks_spyre * block_size)
-            logger.info("Maximum concurrency for %s tokens per request: %.2fx",
-                        str(max_model_len), max_concurrency_spyre)
-
-            assert num_blocks_spyre % max_batch_size == 0, \
-                "num_blocks_spyre must be multiple of max_batch_size"
-            return num_blocks_spyre
-        else:  # dynamo backend 'eager'
-            # for debugging purposes we also put the spyre value here for cpu
-            num_blocks_cpu = NUM_BLOCKS_SPYRE
-            assert num_blocks_cpu >= min_req_num_blocks, (
-                "Number of pages available on CPU (%d) is not enough to "
-                "serve the current model (need at least %d pages)." %
-                (num_blocks_cpu, min_req_num_blocks))
-            max_concurrency_cpu = num_blocks_cpu * block_size \
-                / max_model_len
-            logger.info("CPU KV cache size: %s tokens",
-                        num_blocks_cpu * block_size)
-            logger.info("Maximum concurrency for %s tokens per request: %.2fx",
-                        str(max_model_len), max_concurrency_cpu)
-            return num_blocks_cpu
-
     def set_past_key_value_states(self, num_blocks) -> None:
-        # overwrite num_blocks for testing scheduler constraints
-        num_blocks_override = SpyrePlatform.get_num_spyre_blocks_override()
-        if num_blocks_override > 0:
-            num_blocks = num_blocks_override
-
         # List[layers] of Tuple[k,v] of
         # Tensor[num_blocks, block_size, num_kv_heads, head_dim]
-
         if not self.model_config.quantization:
             self.past_key_value_states = [
                 (torch.zeros(num_blocks,
@@ -665,15 +587,12 @@ class StaticBatchingFmsModel(FmsModelBase):
 
     def __init__(
         self,
-        model_config: ModelConfig,
-        parallel_config: ParallelConfig,
-        _: SchedulerConfig,
+        vllm_config: VllmConfig,
         max_prompt_length: int,
         max_decode_length: int,
         rank: int,
     ) -> None:
-        super().__init__(model_config,
-                         parallel_config,
+        super().__init__(vllm_config,
                          max_prompt_length,
                          max_decode_length,
                          rank,
diff --git a/vllm_spyre/platform.py b/vllm_spyre/platform.py
index 086717e7..f9b8143d 100644
--- a/vllm_spyre/platform.py
+++ b/vllm_spyre/platform.py
@@ -63,7 +63,6 @@ class SpyrePlatform(Platform):
     supported_quantization: list[str] = ["gptq", "compressed-tensors"]
     _warmup_shapes: tuple[dict[str, int], ...] | None = None
     _block_size: int = 64  # hardcoded Spyre constraint for now
-    _num_spyre_blocks_override: int = -1  # override num of KV cache blocks
     _config: VllmConfig = None
 
     # Backend for dynamic compilation ops
@@ -168,30 +167,19 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         # - Set the block size (in tokens) to the maximum sequence length
         #   so that the scheduler thinks an entire sequence will fit in
         #   one single block.
-        # - Set the number of blocks to the maximum number of sequences, so
-        #   the scheduler always thinks there's a block available
         # - Set `max_num_batched_tokens` to the size of a full batch of full
         #   length requests, so that the scheduler will always have token
         #   budget available to schedule a full batch
         if cache_config is not None:
-            # overriding number of available Spyre blocks if not None
-            if cache_config.num_gpu_blocks_override:
-                cls._num_spyre_blocks_override = \
-                    cache_config.num_gpu_blocks_override
-            # The V1 scheduler actually needs 2 blocks for each sequence...
-            cache_config.num_gpu_blocks_override = \
-                scheduler_config.max_num_seqs * 2
-
             cache_config.block_size = model_config.max_model_len
             scheduler_config.max_num_batched_tokens = (
                 model_config.max_model_len * scheduler_config.max_num_seqs)
 
             logger.info(
                 "Overriding configurations based on warmup shapes. "
                 "max_model_len=%d, max_num_seqs=%d, block_size=%d, "
-                "num_gpu_blocks_override=%d, max_num_batched_tokens=%d",
-                model_config.max_model_len, scheduler_config.max_num_seqs,
-                cache_config.block_size, cache_config.num_gpu_blocks_override,
+                "max_num_batched_tokens=%d", model_config.max_model_len,
+                scheduler_config.max_num_seqs, cache_config.block_size,
                 scheduler_config.max_num_batched_tokens)
 
         # set env vars for torch_sendnn to consume
@@ -211,13 +199,15 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 max(vllm_config.scheduler_config.max_num_seqs, 2))
 
         # Hardcode some things for granite-3.3-8b-instruct
-        cls.configure_granite_33_8b(vllm_config)
+        if cls.is_granite_3_8b(vllm_config.model_config):
+            cls.configure_granite_3_8b(vllm_config)
 
-        # max product of batch size x tkv supported by the Spyre compiler
-        default_max_batch_tkv_limit = \
-            vllm_config.model_config.max_model_len * \
-            vllm_config.scheduler_config.max_num_seqs
         if not os.getenv("VLLM_DT_MAX_BATCH_TKV_LIMIT", None):
+            # max product of batch size x tkv supported by the Spyre compiler
+            default_max_batch_tkv_limit = \
+                vllm_config.model_config.max_model_len * \
+                vllm_config.scheduler_config.max_num_seqs
+
             os.environ["VLLM_DT_MAX_BATCH_TKV_LIMIT"] = str(
                 default_max_batch_tkv_limit)
             logger.info("No model / tensor parallel size specific value for " \
@@ -303,10 +293,6 @@ def get_warmup_shapes(cls, scheduler_config) -> tuple[dict[str, int], ...]:
     def get_block_size(cls) -> int:
         return cls._block_size
 
-    @classmethod
-    def get_num_spyre_blocks_override(cls) -> int:
-        return cls._num_spyre_blocks_override
-
     @classmethod
     def supports_v1(cls, model_config: ModelConfig) -> bool:
         """Returns whether the current platform can support v1 for the supplied
@@ -508,40 +494,61 @@ def get_max_output_tokens(self, prompt_len: int) -> int:
         return max_new_tokens
 
     @classmethod
-    def configure_granite_33_8b(cls, vllm_config: VllmConfig):
+    def configure_granite_3_8b(cls, vllm_config: VllmConfig):
         """
         Configure hard coded values for the model
        https://huggingface.co/ibm-granite/granite-3.3-8b-instruct
         """
-
-        model_config = vllm_config.model_config
         parallel_config = vllm_config.parallel_config
 
-        if not cls.is_granite_33_8b(model_config):
-            # Not granite
-            return
-
         if parallel_config.world_size != 4:
             # only override configs for TP=4
             return
 
+        tkv_128k = 128 * 1024
         if not os.getenv("VLLM_DT_MAX_BATCH_TKV_LIMIT", None):
-            tkv_128k = 128 * 1024
             os.environ["VLLM_DT_MAX_BATCH_TKV_LIMIT"] = str(tkv_128k)
             logger.info("Model granite-3.3-8b-instruct and tensor parallel " \
                 "size 4 detected. Using VLLM_DT_MAX_BATCH_TKV_LIMIT = %d",
                 tkv_128k)
+        else:
+            logger.warning(
+                "VLLM_DT_MAX_BATCH_TKV_LIMIT was set to %d, not "
+                "overriding to the granite-3.3-8b-instruct default of %d",
+                os.getenv("VLLM_DT_MAX_BATCH_TKV_LIMIT"), tkv_128k)
 
         # If no HDMA p2psize override was specified, set 256MB
+        p2psize_256m = 256 * 1024 * 1024
         if not os.getenv("FLEX_HDMA_P2PSIZE", None):
-            p2psize_256m = 256 * 1024 * 1024
             os.environ["FLEX_HDMA_P2PSIZE"] = str(p2psize_256m)
             logger.info(
                 "Model granite-3.3-8b-instruct and tensor parallel size 4 "
                 "detected. Using FLEX_HDMA_P2PSIZE = %d", p2psize_256m)
+        else:
+            logger.warning(
+                "FLEX_HDMA_P2PSIZE was set to %d, not using the "
+                "granite-3.3-8b-instruct default of %d",
+                os.getenv("FLEX_HDMA_P2PSIZE", None), p2psize_256m)
+
+        # Override the total number of KV cache blocks based on what we know
+        # will fit. (Unless user already set `--num-gpu-blocks-override`)
+        # TODO: remove this once we have correct free memory info available
+        blocks_override = 2080
+        if vllm_config.cache_config.num_gpu_blocks_override is not None:
+            vllm_config.cache_config.num_gpu_blocks_override = blocks_override
+            logger.info(
+                "Model granite-3.3-8b-instruct and tensor parallel size 4 "
+                "detected. Overriding available KV Cache blocks to %d",
+                blocks_override)
+        else:
+            logger.warning(
+                "--num-gpu-blocks-override was set to %d, not using the "
+                "granite-3.3-8b-instruct default of %d",
+                vllm_config.cache_config.num_gpu_blocks_override,
+                blocks_override)
 
     @classmethod
-    def is_granite_33_8b(cls, model_config: ModelConfig):
+    def is_granite_3_8b(cls, model_config: ModelConfig):
         """Returns true if we have a model that looks like
         ibm-granite/granite-3.3-8b-instruct"""
         if not isinstance(model_config.hf_config, GraniteConfig):
diff --git a/vllm_spyre/v1/worker/spyre_model_runner.py b/vllm_spyre/v1/worker/spyre_model_runner.py
index 005ab915..437fcbfd 100644
--- a/vllm_spyre/v1/worker/spyre_model_runner.py
+++ b/vllm_spyre/v1/worker/spyre_model_runner.py
@@ -288,9 +288,7 @@ def load_model(self, prompt_lens: Iterable[int],
         max_pad_length = max(prompt_lens)
         max_decode_length = max(num_decode_tokens)
         self.model = SpyreCausalLM(
-            self.model_config,
-            parallel_config=self.parallel_config,
-            scheduler_config=self.scheduler_config,
+            vllm_config=self.vllm_config,
             max_prompt_length=max_pad_length,
             max_decode_length=max_decode_length,
             rank=self.rank,
@@ -829,7 +827,7 @@ def pre_warmup(self) -> None:
 
         # Note: Until this feature is supported by the compiler we have to set:
         # n_blocks_warmup = n_blocks_avail
-        n_blocks_warmup = self.model.model.get_num_blocks_available()
+        n_blocks_warmup = self.get_total_spyre_blocks()
 
         self._set_blocks(num_blocks=n_blocks_warmup)
         self.model.model.set_past_key_value_states(num_blocks=n_blocks_warmup)
@@ -849,23 +847,55 @@ def complete_warmup(self) -> None:
         super().complete_warmup()
         # get the number of pages from the actual Spyre card after the warmup
         # and set it accordingly in the model runner and for the kv cache size
-        n_blocks_avail = self.model.model.get_num_blocks_available()
+        n_blocks_avail = self.get_total_spyre_blocks()
         self._set_blocks(num_blocks=n_blocks_avail)
         self.model.model.set_past_key_value_states(num_blocks=n_blocks_avail)
 
     def _set_blocks(self, num_blocks: int) -> None:
-        # overwrite num_blocks for testing scheduler constraints
-        num_blocks_override = SpyrePlatform.get_num_spyre_blocks_override()
-        if num_blocks_override > 0:
-            logger.info(
-                "[WARMUP] Overriding number of KV cache blocks on "
-                "Spyre/CPU to %d.", num_blocks_override)
-            num_blocks = num_blocks_override
-
         # set number of available blocks and populate block_pool
         self.n_blocks = num_blocks
         self.block_pool = deque([i for i in range(self.n_blocks)])
 
+    def get_total_spyre_blocks(self) -> int:
+        """Returns the total number of KV cache blocks available for spyre.
+        This currently returns the number of blocks required for a full-sized
+        batch, which may be greater than the available memory.
+
+        Until a correct available memory api is available, the number of blocks
+        must be overridden with a known good value via
+        cache_config.num_gpu_blocks_override
+        """
+        max_batch_size = self.scheduler_config.max_num_seqs
+        max_model_len = self.scheduler_config.max_model_len
+        block_size = SpyrePlatform.get_block_size()
+        min_req_num_blocks = max_model_len // block_size
+
+        blocks_override = self.cache_config.num_gpu_blocks_override
+        if blocks_override is not None and blocks_override > 0:
+            num_blocks = blocks_override
+        else:
+            num_blocks = max_batch_size * min_req_num_blocks
+
+        # Total number of blocks needs to be a multiple of the batch size
+        # (spyre constraint) so round it down
+        num_blocks = max_batch_size * (num_blocks // max_batch_size)
+
+        if num_blocks < min_req_num_blocks:
+            raise ValueError(
+                f"Number of pages available on Spyre {num_blocks} is not "
+                f"enough to serve the current model (need at least "
+                f"{min_req_num_blocks} pages).")
+
+        max_concurrency = num_blocks * block_size / max_model_len
+        backend = "Spyre" if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == 'sendnn' \
+            else "CPU"
+        logger.info("%s KV cache size: %s tokens", backend,
+                    num_blocks * block_size)
+        logger.info("Maximum concurrency for %s tokens per request: %.2fx",
+                    str(max_model_len), max_concurrency)
+
+        return num_blocks
+
     def update_states(self, scheduler_output):
         super().update_states(scheduler_output)
diff --git a/vllm_spyre/v1/worker/spyre_worker.py b/vllm_spyre/v1/worker/spyre_worker.py
index 20849d0b..c314d3c7 100644
--- a/vllm_spyre/v1/worker/spyre_worker.py
+++ b/vllm_spyre/v1/worker/spyre_worker.py
@@ -181,16 +181,16 @@ def determine_available_memory(self) -> int:
         The number of device blocks (called "gpu blocks" in most places)
         can also be overridden by `--num-gpu-blocks-override`, which is set
         under `vllm_config.cache_config.num_gpu_blocks_override`.
+
+        🌶️🌶️🌶️ The result from this method _only_ applies to the KV Cache
+        management in vLLM's core scheduler. This does _not_ apply to the KV
+        cache management handled directly by the vllm-spyre worker and model
+        runner. We return a minimal value here to make the vllm scheduler happy.
         """
-        # Currently we override vllm_config.cache_config.num_gpu_blocks_override
-        # in platform.py, so this value is only used by vllm to check that the
-        # number of gpu blocks will fit in available memory.
-        # Since we also return dummy values for the kv cache spec, this check is
-        # meaningless and we can just return a large value to ensure vllm does
-        # not raise a validation error.
-        # TODO: Return the real available device memory when we implement real
-        # kv-caching.
-        return 1 << 64
+        # The fake kv_cache config specified by the model runner sets 4 bytes
+        # per token.
+        return (4 * self.scheduler_config.max_model_len *
+                self.scheduler_config.max_num_seqs)
 
     def initialize_from_config(self,
                                kv_cache_configs: list[KVCacheConfig]) -> None:

From 298f5bcd068b74bf3198196c8d658e153e48d153 Mon Sep 17 00:00:00 2001
From: Joe Runde
Date: Fri, 3 Oct 2025 08:45:29 -0600
Subject: [PATCH 05/13] :bug: fixup vllm kv cache

Signed-off-by: Joe Runde
---
 vllm_spyre/v1/worker/spyre_worker.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/vllm_spyre/v1/worker/spyre_worker.py b/vllm_spyre/v1/worker/spyre_worker.py
index c314d3c7..65aba7c1 100644
--- a/vllm_spyre/v1/worker/spyre_worker.py
+++ b/vllm_spyre/v1/worker/spyre_worker.py
@@ -189,8 +189,17 @@ def determine_available_memory(self) -> int:
         """
         # The fake kv_cache config specified by the model runner sets 4 bytes
         # per token.
-        return (4 * self.scheduler_config.max_model_len *
-                self.scheduler_config.max_num_seqs)
+        accurate_fake_kv_cache_size = (4 *
+                                       self.scheduler_config.max_model_len *
+                                       self.scheduler_config.max_num_seqs)
+
+        # The vLLM scheduler reserves a null block in its kv-cache, so we need
+        # at least one more block to allow for proper scheduling. We double
+        # the cache size here to ensure that the vllm scheduler always has
+        # blocks available. This causes the log message from vLLM about its
+        # KV cache capacity to be double the log message from vllm-spyre.
+        # This can probably be fixed in a nicer way.
+        return 2 * accurate_fake_kv_cache_size
 
     def initialize_from_config(self,
                                kv_cache_configs: list[KVCacheConfig]) -> None:

From 1ef6dcdb0d4711a0066b8dbe92fa0564ca5aae51 Mon Sep 17 00:00:00 2001
From: Joe Runde
Date: Fri, 3 Oct 2025 16:10:37 -0600
Subject: [PATCH 06/13] :bug: add config tests and fix 8b bug

Signed-off-by: Joe Runde
---
 .../config.json                               | 32 ++++++++++++++++
 .../granite-3.3-micro-config-only/config.json | 32 ++++++++++++++++
 tests/models/test_granite.py                  | 37 +++++++++++++++++++
 vllm_spyre/platform.py                        |  2 +-
 4 files changed, 102 insertions(+), 1 deletion(-)
 create mode 100644 tests/fixtures/models/granite-3.3-8b-instruct-config-only/config.json
 create mode 100644 tests/fixtures/models/granite-3.3-micro-config-only/config.json
 create mode 100644 tests/models/test_granite.py

diff --git a/tests/fixtures/models/granite-3.3-8b-instruct-config-only/config.json b/tests/fixtures/models/granite-3.3-8b-instruct-config-only/config.json
new file mode 100644
index 00000000..f731a4bb
--- /dev/null
+++ b/tests/fixtures/models/granite-3.3-8b-instruct-config-only/config.json
@@ -0,0 +1,32 @@
+{
+  "architectures": [
+    "GraniteForCausalLM"
+  ],
+  "attention_bias": false,
+  "attention_dropout": 0.0,
+  "attention_multiplier": 0.0078125,
+  "bos_token_id": 0,
+  "embedding_multiplier": 12.0,
+  "eos_token_id": 0,
+  "hidden_act": "silu",
+  "hidden_size": 4096,
+  "initializer_range": 0.02,
+  "intermediate_size": 12800,
+  "logits_scaling": 16.0,
+  "max_position_embeddings": 131072,
+  "mlp_bias": false,
+  "model_type": "granite",
+  "num_attention_heads": 32,
+  "num_hidden_layers": 40,
+  "num_key_value_heads": 8,
+  "pad_token_id": 0,
+  "residual_multiplier": 0.22,
+  "rms_norm_eps": 1e-05,
+  "rope_scaling": null,
+  "rope_theta": 10000000.0,
+  "tie_word_embeddings": true,
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.49.0",
+  "use_cache": true,
+  "vocab_size": 49159
+}
diff --git a/tests/fixtures/models/granite-3.3-micro-config-only/config.json b/tests/fixtures/models/granite-3.3-micro-config-only/config.json
new file mode 100644
index 00000000..67065dd6
--- /dev/null
+++ b/tests/fixtures/models/granite-3.3-micro-config-only/config.json
@@ -0,0 +1,32 @@
+{
+  "architectures": [
+    "GraniteForCausalLM"
+  ],
+  "attention_bias": false,
+  "attention_dropout": 0.0,
+  "attention_multiplier": 0.0078125,
+  "bos_token_id": 0,
+  "dtype": "bfloat16",
+  "embedding_multiplier": 12.0,
+  "eos_token_id": 0,
+  "hidden_act": "silu",
+  "hidden_size": 4096,
+  "initializer_range": 0.02,
+  "intermediate_size": 12800,
+  "logits_scaling": 16.0,
+  "max_position_embeddings": 131072,
+  "mlp_bias": false,
+  "model_type": "granite",
+  "num_attention_heads": 32,
+  "num_hidden_layers": 4,
+  "num_key_value_heads": 8,
+  "pad_token_id": 0,
+  "residual_multiplier": 0.22,
+  "rms_norm_eps": 1e-05,
+  "rope_scaling": null,
+  "rope_theta": 10000000.0,
+  "tie_word_embeddings": false,
+  "transformers_version": "4.56.1",
+  "use_cache": true,
+  "vocab_size": 49159
+}
diff --git a/tests/models/test_granite.py b/tests/models/test_granite.py
new file mode 100644
index 00000000..5a573b75
--- /dev/null
+++ b/tests/models/test_granite.py
@@ -0,0 +1,37 @@
+"""Tests for model-specific overrides for granite"""
+from pathlib import Path
+
+import pytest
+from vllm.config import ModelConfig, ParallelConfig, VllmConfig
+
+from vllm_spyre.platform import SpyrePlatform
+
+FIXTURES_PATH = Path(__file__).parent.parent / "fixtures" / "models"
+
+
+@pytest.mark.cpu
+def test_granite_3_8b_detection():
+    """Check that we can detect the model config for granite 3 8b"""
+
+    granite_3_8b_config = VllmConfig(model_config=ModelConfig(
+        model=str(FIXTURES_PATH / "granite-3.3-8b-instruct-config-only")))
+
+    granite_micro_config = VllmConfig(model_config=ModelConfig(
+        model=str(FIXTURES_PATH / "granite-3.3-micro-config-only")))
+
+    assert SpyrePlatform.is_granite_3_8b(granite_3_8b_config.model_config)
+
+    assert not SpyrePlatform.is_granite_3_8b(granite_micro_config.model_config)
+
+
+@pytest.mark.cpu
+def test_granite_3_8b_overrides():
+    """Check that the correct values are overridden for g3.3 8b"""
+
+    tp4_config = ParallelConfig(tensor_parallel_size=4)
+
+    granite_3_8b_config = VllmConfig(model_config=ModelConfig(
+        model=str(FIXTURES_PATH / "granite-3.3-8b-instruct-config-only")),
+                                     parallel_config=tp4_config)
+
+    assert granite_3_8b_config.cache_config.num_gpu_blocks_override == 2080
diff --git a/vllm_spyre/platform.py b/vllm_spyre/platform.py
index 178d7b02..61ae0ae0 100644
--- a/vllm_spyre/platform.py
+++ b/vllm_spyre/platform.py
@@ -555,7 +555,7 @@ def configure_granite_3_8b(cls, vllm_config: VllmConfig):
         # will fit. (Unless user already set `--num-gpu-blocks-override`)
         # TODO: remove this once we have correct free memory info available
         blocks_override = 2080
-        if vllm_config.cache_config.num_gpu_blocks_override is not None:
+        if vllm_config.cache_config.num_gpu_blocks_override is None:
             vllm_config.cache_config.num_gpu_blocks_override = blocks_override
             logger.info(
                 "Model granite-3.3-8b-instruct and tensor parallel size 4 "

From c8dbfe7b4c3c7151cef7ed00b67b3bc6f167d068 Mon Sep 17 00:00:00 2001
From: Joe Runde
Date: Fri, 3 Oct 2025 16:45:53 -0600
Subject: [PATCH 07/13] :test_tube: check for env var overrides too

Signed-off-by: Joe Runde
---
 tests/models/test_granite.py | 17 ++++++++++++-----
 vllm_spyre/platform.py       |  2 +-
 2 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/tests/models/test_granite.py b/tests/models/test_granite.py
index 5a573b75..b4102c85 100644
--- a/tests/models/test_granite.py
+++ b/tests/models/test_granite.py
@@ -1,5 +1,7 @@
 """Tests for model-specific overrides for granite"""
+import os
 from pathlib import Path
+from unittest import mock
 
 import pytest
 from vllm.config import ModelConfig, ParallelConfig, VllmConfig
@@ -28,10 +30,15 @@ def test_granite_3_8b_detection():
 def test_granite_3_8b_overrides():
     """Check that the correct values are overridden for g3.3 8b"""
 
-    tp4_config = ParallelConfig(tensor_parallel_size=4)
+    # Must ensure no env vars have been overridden before testing
+    with mock.patch.dict(os.environ, clear=True):
+        tp4_config = ParallelConfig(tensor_parallel_size=4)
 
-    granite_3_8b_config = VllmConfig(model_config=ModelConfig(
-        model=str(FIXTURES_PATH / "granite-3.3-8b-instruct-config-only")),
-                                     parallel_config=tp4_config)
+        granite_3_8b_config = VllmConfig(model_config=ModelConfig(
+            model=str(FIXTURES_PATH / "granite-3.3-8b-instruct-config-only")),
+                                         parallel_config=tp4_config)
+
+        assert granite_3_8b_config.cache_config.num_gpu_blocks_override == 2080
 
-    assert granite_3_8b_config.cache_config.num_gpu_blocks_override == 2080
+        assert int(os.getenv("VLLM_DT_MAX_BATCH_TKV_LIMIT")) == 128 * 1024
+        assert int(os.getenv("FLEX_HDMA_P2PSIZE")) == 256 * 1024 * 1024
diff --git a/vllm_spyre/platform.py b/vllm_spyre/platform.py
index 61ae0ae0..09993c8b 100644
--- a/vllm_spyre/platform.py
+++ b/vllm_spyre/platform.py
@@ -534,7 +534,7 @@ def configure_granite_3_8b(cls, vllm_config: VllmConfig):
                 tkv_128k)
         else:
             logger.warning(
-                "VLLM_DT_MAX_BATCH_TKV_LIMIT was set to %d, not "
+                "VLLM_DT_MAX_BATCH_TKV_LIMIT was set to %s, not "
                 "overriding to the granite-3.3-8b-instruct default of %d",
                 os.getenv("VLLM_DT_MAX_BATCH_TKV_LIMIT"), tkv_128k)
 

From 9dae41ba992870cb1763c005210052f918652f35 Mon Sep 17 00:00:00 2001
From: Joe Runde
Date: Fri, 3 Oct 2025 16:52:40 -0600
Subject: [PATCH 08/13] :bug: set less swap space

Signed-off-by: Joe Runde
---
 tests/models/test_granite.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/tests/models/test_granite.py b/tests/models/test_granite.py
index b4102c85..8d9cba0a 100644
--- a/tests/models/test_granite.py
+++ b/tests/models/test_granite.py
@@ -4,22 +4,26 @@
 from unittest import mock
 
 import pytest
-from vllm.config import ModelConfig, ParallelConfig, VllmConfig
+from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig
 
 from vllm_spyre.platform import SpyrePlatform
 
 FIXTURES_PATH = Path(__file__).parent.parent / "fixtures" / "models"
 
+NO_SWAP_CONFIG = CacheConfig(swap_space=0.001)
+
 
 @pytest.mark.cpu
 def test_granite_3_8b_detection():
     """Check that we can detect the model config for granite 3 8b"""
 
     granite_3_8b_config = VllmConfig(model_config=ModelConfig(
-        model=str(FIXTURES_PATH / "granite-3.3-8b-instruct-config-only")))
+        model=str(FIXTURES_PATH / "granite-3.3-8b-instruct-config-only")),
+                                     cache_config=NO_SWAP_CONFIG)
 
     granite_micro_config = VllmConfig(model_config=ModelConfig(
-        model=str(FIXTURES_PATH / "granite-3.3-micro-config-only")))
+        model=str(FIXTURES_PATH / "granite-3.3-micro-config-only")),
+                                      cache_config=NO_SWAP_CONFIG)
 
     assert SpyrePlatform.is_granite_3_8b(granite_3_8b_config.model_config)
 

From 51c3103d76b37b1bac8853ae7b666f54f1a04cbb Mon Sep 17 00:00:00 2001
From: Joe Runde
Date: Fri, 3 Oct 2025 17:14:51 -0600
Subject: [PATCH 09/13] Update vllm_spyre/platform.py

Co-authored-by: Travis Johnson
Signed-off-by: Joe Runde
---
 vllm_spyre/platform.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm_spyre/platform.py b/vllm_spyre/platform.py
index 09993c8b..0feb0ff0 100644
--- a/vllm_spyre/platform.py
+++ b/vllm_spyre/platform.py
@@ -547,7 +547,7 @@ def configure_granite_3_8b(cls, vllm_config: VllmConfig):
                 "detected. Using FLEX_HDMA_P2PSIZE = %d", p2psize_256m)
         else:
             logger.warning(
-                "FLEX_HDMA_P2PSIZE was set to %d, not using the "
+                "FLEX_HDMA_P2PSIZE was set to %s, not using the "
                 "granite-3.3-8b-instruct default of %d",
                 os.getenv("FLEX_HDMA_P2PSIZE", None), p2psize_256m)
 

From 09735e84e18d39232f28f33a10d7d3f4b24eab6e Mon Sep 17 00:00:00 2001
From: Joe Runde
Date: Fri, 3 Oct 2025 17:15:00 -0600
Subject: [PATCH 10/13] Update vllm_spyre/platform.py

Co-authored-by: Travis Johnson
Signed-off-by: Joe Runde
---
 vllm_spyre/platform.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm_spyre/platform.py b/vllm_spyre/platform.py
index 0feb0ff0..779322ea 100644
--- a/vllm_spyre/platform.py
+++ b/vllm_spyre/platform.py
@@ -202,7 +202,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if cls.is_granite_3_8b(vllm_config.model_config):
             cls.configure_granite_3_8b(vllm_config)
 
-        if not os.getenv("VLLM_DT_MAX_BATCH_TKV_LIMIT", None):
+        if not os.getenv("VLLM_DT_MAX_BATCH_TKV_LIMIT"):
             # max product of batch size x tkv supported by the Spyre compiler
             default_max_batch_tkv_limit = \
                 vllm_config.model_config.max_model_len * \

From 87b875da106ae9ead3bb390d6f16ba37ac1d034a Mon Sep 17 00:00:00 2001
From: Joe Runde
Date: Fri, 3 Oct 2025 17:17:28 -0600
Subject: [PATCH 11/13] :recycle: remove default None from getenv

Signed-off-by: Joe Runde
---
 vllm_spyre/platform.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/vllm_spyre/platform.py b/vllm_spyre/platform.py
index 779322ea..f41b9f16 100644
--- a/vllm_spyre/platform.py
+++ b/vllm_spyre/platform.py
@@ -527,7 +527,7 @@ def configure_granite_3_8b(cls, vllm_config: VllmConfig):
             return
 
         tkv_128k = 128 * 1024
-        if not os.getenv("VLLM_DT_MAX_BATCH_TKV_LIMIT", None):
+        if not os.getenv("VLLM_DT_MAX_BATCH_TKV_LIMIT"):
             os.environ["VLLM_DT_MAX_BATCH_TKV_LIMIT"] = str(tkv_128k)
             logger.info("Model granite-3.3-8b-instruct and tensor parallel " \
                 "size 4 detected. Using VLLM_DT_MAX_BATCH_TKV_LIMIT = %d",
@@ -540,7 +540,7 @@ def configure_granite_3_8b(cls, vllm_config: VllmConfig):
 
         # If no HDMA p2psize override was specified, set 256MB
         p2psize_256m = 256 * 1024 * 1024
-        if not os.getenv("FLEX_HDMA_P2PSIZE", None):
+        if not os.getenv("FLEX_HDMA_P2PSIZE"):
             os.environ["FLEX_HDMA_P2PSIZE"] = str(p2psize_256m)
             logger.info(
                 "Model granite-3.3-8b-instruct and tensor parallel size 4 "
@@ -549,7 +549,7 @@ def configure_granite_3_8b(cls, vllm_config: VllmConfig):
             logger.warning(
                 "FLEX_HDMA_P2PSIZE was set to %s, not using the "
                 "granite-3.3-8b-instruct default of %d",
-                os.getenv("FLEX_HDMA_P2PSIZE", None), p2psize_256m)
+                os.getenv("FLEX_HDMA_P2PSIZE"), p2psize_256m)
 
         # Override the total number of KV cache blocks based on what we know
         # will fit. (Unless user already set `--num-gpu-blocks-override`)

From 8991f8d8210a24bde4b477a6f20328a02d1408a1 Mon Sep 17 00:00:00 2001
From: Joe Runde
Date: Fri, 3 Oct 2025 17:19:10 -0600
Subject: [PATCH 12/13] :recycle: add elif checks for warnings

Signed-off-by: Joe Runde
---
 vllm_spyre/platform.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/vllm_spyre/platform.py b/vllm_spyre/platform.py
index f41b9f16..685dd844 100644
--- a/vllm_spyre/platform.py
+++ b/vllm_spyre/platform.py
@@ -532,7 +532,7 @@ def configure_granite_3_8b(cls, vllm_config: VllmConfig):
             logger.info("Model granite-3.3-8b-instruct and tensor parallel " \
                 "size 4 detected. Using VLLM_DT_MAX_BATCH_TKV_LIMIT = %d",
                 tkv_128k)
-        else:
+        elif os.getenv("VLLM_DT_MAX_BATCH_TKV_LIMIT") != str(tkv_128k):
             logger.warning(
                 "VLLM_DT_MAX_BATCH_TKV_LIMIT was set to %s, not "
                 "overriding to the granite-3.3-8b-instruct default of %d",
@@ -545,7 +545,7 @@ def configure_granite_3_8b(cls, vllm_config: VllmConfig):
             logger.info(
                 "Model granite-3.3-8b-instruct and tensor parallel size 4 "
                 "detected. Using FLEX_HDMA_P2PSIZE = %d", p2psize_256m)
-        else:
+        elif os.getenv("FLEX_HDMA_P2PSIZE") != str(p2psize_256m):
             logger.warning(
                 "FLEX_HDMA_P2PSIZE was set to %s, not using the "
                 "granite-3.3-8b-instruct default of %d",
@@ -561,7 +561,8 @@ def configure_granite_3_8b(cls, vllm_config: VllmConfig):
                 "Model granite-3.3-8b-instruct and tensor parallel size 4 "
                 "detected. Overriding available KV Cache blocks to %d",
                 blocks_override)
-        else:
+        elif (vllm_config.cache_config.num_gpu_blocks_override
+              != blocks_override):
             logger.warning(
                 "--num-gpu-blocks-override was set to %d, not using the "
                 "granite-3.3-8b-instruct default of %d",

From 9d9c87451f72f22871d14da3a85aa55a972ca97c Mon Sep 17 00:00:00 2001
From: Joe Runde
Date: Fri, 3 Oct 2025 17:20:37 -0600
Subject: [PATCH 13/13] :recycle: remove unused model.vllm_config

Signed-off-by: Joe Runde
---
 vllm_spyre/model_executor/model_loader/spyre.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/vllm_spyre/model_executor/model_loader/spyre.py b/vllm_spyre/model_executor/model_loader/spyre.py
index 4e6f1bdb..f2b113f1 100644
--- a/vllm_spyre/model_executor/model_loader/spyre.py
+++ b/vllm_spyre/model_executor/model_loader/spyre.py
@@ -57,8 +57,6 @@ def __init__(
         rank: int,
     ) -> None:
         super().__init__()
-        self.vllm_config = vllm_config
-
         self.logits_processor = LogitsProcessor(
             vllm_config.model_config.hf_config.vocab_size,
             logits_as_input=True)
@@ -163,7 +161,6 @@ def __init__(
 
         # Actual FMS model
         self.model: nn.Module
-        self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.parallel_config = vllm_config.parallel_config
         self.cache_config = vllm_config.cache_config