diff --git a/vllm_spyre/model_executor/model_loader/spyre.py b/vllm_spyre/model_executor/model_loader/spyre.py
index 29e7fd149..a6ed819b8 100644
--- a/vllm_spyre/model_executor/model_loader/spyre.py
+++ b/vllm_spyre/model_executor/model_loader/spyre.py
@@ -300,6 +300,9 @@ def __init__(
             max_decode_length,
             sendnn_dynamic=True)
 
+        self.scheduler_config = scheduler_config
+        self.parallel_config = parallel_config
+
         # physical KV cache on AIU Spyre: will eventually not live in this class
         self.kv_cache_specs = {}
         self.kv_cache_specs['block_size'] = BLOCK_SIZE
@@ -324,18 +327,53 @@ def __init__(
         else:
             self.attention_name = "spyre_paged_attn"
 
-        # set num_blocks to the minimal value of 4 required for warmup
-        # is reset to the value returned by the Spyre compiler after warmup
-        # self._set_past_key_value_states(num_blocks=4)
-        num_blocks = scheduler_config.max_num_seqs * max_model_len // BLOCK_SIZE
-        self._set_past_key_value_states(num_blocks=num_blocks)
-
-        # mark the num_blocks dimension dynamic for Spyre compiler for warmup
-        # only, compiler will return the number of blocks it can accommodate.
-        # (This is not yet supported by the compiler)
-        # for layer in self.past_key_value_states:
-        #     for tensor in layer:
-        #         torch._dynamo.mark_dynamic(tensor, 0)
+    def get_num_blocks_available(self) -> int:
+        """Return the number of KV cache blocks/pages to use.
+
+        The value is currently hard coded per model / tensor parallel size.
+        Will eventually contain a function in torch_sendnn which reads the
+        actual value provided by the compiler for backend sendnn.
+        """
+        max_batch_size = self.scheduler_config.max_num_seqs
+        max_model_len = self.scheduler_config.max_model_len
+        block_size = self.kv_cache_specs['block_size']
+
+        # number of blocks covering max_model_len tokens (lower bound below)
+        min_req_num_blocks = max_model_len // block_size
+
+        # TODO: replace the hard coded num_blocks by calling a function
+        # in torch_sendnn which returns the value set by the Spyre compiler.
+        if ('granite-3.3-8b-instruct' in self.model_config.model
+                and self.parallel_config.world_size == 4):
+            # hard coded value for tensor parallel size 4 with the below model
+            # https://huggingface.co/ibm-granite/granite-3.3-8b-instruct
+            num_blocks = 2080
+            logger.info(
+                "Model granite-3.3-8b-instruct and tensor parallel "
+                "size 4 detected. Using NUM_BLOCKS_SPYRE = %d", num_blocks)
+        else:
+            # default value for any other model/ tensor parallel size
+            num_blocks = max_batch_size * min_req_num_blocks
+            logger.info(
+                "No model / tensor parallel size specific value for "
+                "the number of KV cache blocks available on Spyre found. "
+                "Using default value (max_batch_size * max_model_len / "
+                "block_size): %d", num_blocks)
+
+        # for debugging purposes the Spyre value is also used on cpu
+        # (dynamo backend 'eager'); only the label in the messages differs
+        device = ("Spyre" if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == 'sendnn'
+                  else "CPU")
+        assert num_blocks >= min_req_num_blocks, (
+            "Number of pages available on %s (%d) is not enough to "
+            "serve the current model (need at least %d pages)." %
+            (device, num_blocks, min_req_num_blocks))
+        max_concurrency = num_blocks * block_size / max_model_len
+        logger.info("%s KV cache size: %s tokens", device,
+                    num_blocks * block_size)
+        logger.info("Maximum concurrency for %s tokens per request: %.2fx",
+                    str(max_model_len), max_concurrency)
+        return num_blocks
 
     def _set_past_key_value_states(self, num_blocks) -> None:
         # overwrite num_blocks for testing scheduler constraints
diff --git a/vllm_spyre/v1/worker/spyre_model_runner.py b/vllm_spyre/v1/worker/spyre_model_runner.py
index 637f4adb5..a3b26a3f8 100644
--- a/vllm_spyre/v1/worker/spyre_model_runner.py
+++ b/vllm_spyre/v1/worker/spyre_model_runner.py
@@ -782,15 +782,6 @@ def __init__(
         self.req_ids2reserved_blocks: dict[str, int] = {}
         self.tkv: int = 0
 
-        # set self.block_pool to the minimal value of 4 required for warmup
-        # is reset to the value returned by the Spyre compiler after warmup
-        # self._set_blocks(num_blocks=4)
-        # for the time being we set this to num_blocks consistent with the
-        # cache dimension of ContinuousBatchingFmsModel.past_key_value_states
-        num_blocks = (vllm_config.scheduler_config.max_num_seqs *
-                      vllm_config.model_config.max_model_len //
-                      self.block_size)
-        self._set_blocks(num_blocks=num_blocks)
 
         # TODO: Remove this once we can prefill and decode in the same step
         self.prefill_batch = SamplingInputBatch(
@@ -803,11 +794,36 @@ def __init__(
             vocab_size=vllm_config.model_config.get_vocab_size(),
         )
 
+    def pre_warmup(self) -> None:
+        """Size the block pool and the KV cache before the warmup runs.
+
+        Eventually the warmup is meant to run with the minimal number of
+        2 blocks, and the number of blocks is then set to the value
+        returned by the Spyre compiler (see complete_warmup). Until this
+        feature is supported by the compiler: n_blocks_warmup = n_blocks_avail
+        """
+        n_blocks_warmup = self.model.model.get_num_blocks_available()
+        self._set_blocks(num_blocks=n_blocks_warmup)
+        self.model.model._set_past_key_value_states(num_blocks=n_blocks_warmup)
+
+        # Future code:
+
+        # self._set_blocks(num_blocks=2)
+        # self.model.model._set_past_key_value_states(num_blocks=2)
+
+        # mark the num_blocks dimension dynamic for Spyre compiler for warmup
+        # only, compiler will return the number of blocks it can accommodate.
+        # (This is not yet supported by the compiler)
+        # for layer in self.model.model.past_key_value_states:
+        #     for tensor in layer:
+        #         torch._dynamo.mark_dynamic(tensor, 0)
+
     def complete_warmup(self) -> None:
         super().complete_warmup()
-        # get the number or pages from the actual Spyre card after the warmup
-        # and set it accordingly in the model runner and the kv cache size
-        n_blocks_avail = self._get_num_blocks_available()
+        # get the number of pages from the actual Spyre card after the warmup
+        # and set it accordingly in the model runner and for the kv cache size
+        n_blocks_avail = self.model.model.get_num_blocks_available()
         self._set_blocks(num_blocks=n_blocks_avail)
         self.model.model._set_past_key_value_states(num_blocks=n_blocks_avail)
 
@@ -824,62 +840,6 @@ def _set_blocks(self, num_blocks: int) -> None:
         self.n_blocks = num_blocks
         self.block_pool = deque([i for i in range(self.n_blocks)])
 
-    def _get_num_blocks_available(self) -> int:
-        """Function returns the number of available blocks/pages.
-        Will eventually contain a function in torch_sendnn which reads
-        the actual value provided by the compiler for backend sendnn"""
-
-        max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
-        max_model_len = self.vllm_config.scheduler_config.max_model_len
-
-        min_req_num_blocks = max_model_len // self.block_size
-
-        # TODO: replace the hard coded NUM_BLOCKS_SPYRE by calling a function
-        # in torch_sendnn which returns the value set by the Spyre compiler.
-
-        if ('granite-3.3-8b-instruct' in self.model_config.model
-                and self.parallel_config.world_size == 4):
-            # hard coded value for tensor parallel size 4 with the below model
-            # https://huggingface.co/ibm-granite/granite-3.3-8b-instruct
-            NUM_BLOCKS_SPYRE = 2080
-            logger.info("Model granite-3.3-8b-instruct and tensor parallel " \
-                        "size 4 detected. Using NUM_BLOCKS_SPYRE = %d", 2080)
-        else:
-            # default value for any other model/ tensor parallel size
-            NUM_BLOCKS_SPYRE = max_batch_size * min_req_num_blocks
-            logger.info("No model / tensor parallel size specific value for" \
-                        "the number of KV cache blocks available on Spyre found. Using " \
-                        "default value (max_batch_size * max_model_len / block_size): %d",
-                        NUM_BLOCKS_SPYRE)
-
-        if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == 'sendnn':
-            num_blocks_spyre = NUM_BLOCKS_SPYRE
-            assert num_blocks_spyre >= min_req_num_blocks, (
-                "Number of pages available on Spyre (%d) is not enough to "
-                "serve the current model (need at least %d pages)." %
-                (num_blocks_spyre, min_req_num_blocks))
-            max_concurrency_spyre = num_blocks_spyre * self.block_size \
-                / max_model_len
-            logger.info("Spyre KV cache size: %s tokens",
-                        num_blocks_spyre * self.block_size)
-            logger.info("Maximum concurrency for %s tokens per request: %.2fx",
-                        str(max_model_len), max_concurrency_spyre)
-            return num_blocks_spyre
-        else:  # dynamo backend 'eager'
-            # for debugging purposes we also put the spyre value here for cpu
-            num_blocks_cpu = NUM_BLOCKS_SPYRE
-            assert num_blocks_cpu >= min_req_num_blocks, (
-                "Number of pages available on CPU (%d) is not enough to "
-                "serve the current model (need at least %d pages)." %
-                (num_blocks_cpu, min_req_num_blocks))
-            max_concurrency_cpu = num_blocks_cpu * self.block_size \
-                / max_model_len
-            logger.info("CPU KV cache size: %s tokens",
-                        num_blocks_cpu * self.block_size)
-            logger.info("Maximum concurrency for %s tokens per request: %.2fx",
-                        str(max_model_len), max_concurrency_cpu)
-            return num_blocks_cpu
-
     def update_states(self, scheduler_output):
         super().update_states(scheduler_output)
 
diff --git a/vllm_spyre/v1/worker/spyre_worker.py b/vllm_spyre/v1/worker/spyre_worker.py
index 4f466b268..f5be6b654 100644
--- a/vllm_spyre/v1/worker/spyre_worker.py
+++ b/vllm_spyre/v1/worker/spyre_worker.py
@@ -389,6 +389,8 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
                 pooling_params=None,
             ) for i in range(2))
 
+        model_runner.pre_warmup()
+
         with _maybe_warmup_context():
             self._dynamic_warmup(request=warmup_req,
                                  prompt_len=prompt_len,