Skip to content

Commit 443cb83

Browse files
authored
[CB] hard-code number of Spyre blocks to 2080 (#362)
### [CB] hard-code number of Spyre blocks to 2080 We have received a number from the compiler team for the [model](https://huggingface.co/ibm-granite/granite-3.3-8b-instruct) of interest and hard-code it here. --------- Signed-off-by: Yannick Schnider <[email protected]> Signed-off-by: Yannick Schnider <[email protected]>
1 parent e8ca058 commit 443cb83

File tree

1 file changed

+21
-8
lines changed

1 file changed

+21
-8
lines changed

vllm_spyre/v1/worker/spyre_model_runner.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -840,15 +840,27 @@ def _get_num_blocks_available(self) -> int:
840840
max_model_len = self.vllm_config.scheduler_config.max_model_len
841841

842842
min_req_num_blocks = max_model_len // self.block_size
843-
# min_req_num_blocks is not enough blocks for the following test:
844-
# tests/e2e/test_spyre_cb.py::test_scheduler_cb_steps_tkv
845-
# [seqs_max_tokens4-prompts_lengths4-steps_add_reqs4-
846-
# checked_steps4-256-False-2-eager-llama-194m]
843+
844+
# TODO: replace the hard coded NUM_BLOCKS_SPYRE by calling a function
845+
# in torch_sendnn which returns the value set by the Spyre compiler.
846+
847+
if ('granite-3.3-8b-instruct' in self.model_config.model
848+
and self.parallel_config.world_size == 4):
849+
# hard coded value for tensor parallel size 4 with the below model
850+
# https://huggingface.co/ibm-granite/granite-3.3-8b-instruct
851+
NUM_BLOCKS_SPYRE = 2080
852+
logger.info("Model granite-3.3-8b-instruct and tensor parallel " \
853+
"size 4 detected. Using NUM_BLOCKS_SPYRE = %d", 2080)
854+
else:
855+
# default value for any other model/ tensor parallel size
856+
NUM_BLOCKS_SPYRE = max_batch_size * min_req_num_blocks
857+
logger.info("No model / tensor parallel size specific value for " \
858+
"the number of KV cache blocks available on Spyre found. Using " \
859+
"default value (max_batch_size * max_model_len / block_size): %d",
860+
NUM_BLOCKS_SPYRE)
847861

848862
if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == 'sendnn':
849-
# TODO: replace num_blocks_spyre by calling a function in
850-
# torch_sendnn which returns the value set by the Spyre compiler
851-
num_blocks_spyre = max_batch_size * min_req_num_blocks
863+
num_blocks_spyre = NUM_BLOCKS_SPYRE
852864
assert num_blocks_spyre >= min_req_num_blocks, (
853865
"Number of pages available on Spyre (%d) is not enough to "
854866
"serve the current model (need at least %d pages)." %
@@ -861,7 +873,8 @@ def _get_num_blocks_available(self) -> int:
861873
str(max_model_len), max_concurrency_spyre)
862874
return num_blocks_spyre
863875
else: # dynamo backend 'eager'
864-
num_blocks_cpu = max_batch_size * min_req_num_blocks
876+
# for debugging purposes we also put the spyre value here for cpu
877+
num_blocks_cpu = NUM_BLOCKS_SPYRE
865878
assert num_blocks_cpu >= min_req_num_blocks, (
866879
"Number of pages available on CPU (%d) is not enough to "
867880
"serve the current model (need at least %d pages)." %

0 commit comments

Comments (0)