70 changes: 58 additions & 12 deletions vllm_spyre/model_executor/model_loader/spyre.py
@@ -300,6 +300,9 @@ def __init__(
max_decode_length,
sendnn_dynamic=True)

self.scheduler_config = scheduler_config
self.parallel_config = parallel_config

# physical KV cache on AIU Spyre: will eventually not live in this class
self.kv_cache_specs = {}
self.kv_cache_specs['block_size'] = BLOCK_SIZE
@@ -324,18 +327,61 @@ def __init__(
else:
self.attention_name = "spyre_paged_attn"

# set num_blocks to the minimal value of 4 required for warmup
# is reset to the value returned by the Spyre compiler after warmup
# self._set_past_key_value_states(num_blocks=4)
num_blocks = scheduler_config.max_num_seqs * max_model_len // BLOCK_SIZE
self._set_past_key_value_states(num_blocks=num_blocks)

# mark the num_blocks dimension dynamic for Spyre compiler for warmup
# only, compiler will return the number of blocks it can accommodate.
# (This is not yet supported by the compiler)
# for layer in self.past_key_value_states:
# for tensor in layer:
# torch._dynamo.mark_dynamic(tensor, 0)
def get_num_blocks_available(self) -> int:
"""Function returns the number of available blocks/pages.
Will eventually contain a function in torch_sendnn which reads
the actual value provided by the compiler for backend sendnn"""

max_batch_size = self.scheduler_config.max_num_seqs
max_model_len = self.scheduler_config.max_model_len
block_size = self.kv_cache_specs['block_size']

min_req_num_blocks = max_model_len // block_size

# TODO: replace the hard coded NUM_BLOCKS_SPYRE by calling a function
# in torch_sendnn which returns the value set by the Spyre compiler.
if ('granite-3.3-8b-instruct' in self.model_config.model
and self.parallel_config.world_size == 4):
# hard coded value for tensor parallel size 4 with the below model
# https://huggingface.co/ibm-granite/granite-3.3-8b-instruct
NUM_BLOCKS_SPYRE = 2080
logger.info("Model granite-3.3-8b-instruct and tensor parallel " \
"size 4 detected. Using NUM_BLOCKS_SPYRE = %d", 2080)
else:
# default value for any other model / tensor parallel size
NUM_BLOCKS_SPYRE = max_batch_size * min_req_num_blocks
logger.info("No model / tensor parallel size specific value for " \
"the number of KV cache blocks available on Spyre found. Using " \
"default value (max_batch_size * max_model_len / block_size): %d",
NUM_BLOCKS_SPYRE)

if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == 'sendnn':
num_blocks_spyre = NUM_BLOCKS_SPYRE
assert num_blocks_spyre >= min_req_num_blocks, (
"Number of pages available on Spyre (%d) is not enough to "
"serve the current model (need at least %d pages)." %
(num_blocks_spyre, min_req_num_blocks))
max_concurrency_spyre = num_blocks_spyre * block_size \
/ max_model_len
logger.info("Spyre KV cache size: %s tokens",
num_blocks_spyre * block_size)
logger.info("Maximum concurrency for %s tokens per request: %.2fx",
str(max_model_len), max_concurrency_spyre)
return num_blocks_spyre
else: # dynamo backend 'eager'
# for debugging purposes, we also use the Spyre value here for CPU
num_blocks_cpu = NUM_BLOCKS_SPYRE
assert num_blocks_cpu >= min_req_num_blocks, (
"Number of pages available on CPU (%d) is not enough to "
"serve the current model (need at least %d pages)." %
(num_blocks_cpu, min_req_num_blocks))
max_concurrency_cpu = num_blocks_cpu * block_size \
/ max_model_len
logger.info("CPU KV cache size: %s tokens",
num_blocks_cpu * block_size)
logger.info("Maximum concurrency for %s tokens per request: %.2fx",
str(max_model_len), max_concurrency_cpu)
return num_blocks_cpu

def _set_past_key_value_states(self, num_blocks) -> None:
# overwrite num_blocks for testing scheduler constraints
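The sizing arithmetic in `get_num_blocks_available()` above is easy to check by hand. A minimal, self-contained sketch follows; all concrete values, including the block size, are illustrative assumptions, not values taken from this PR:

```python
# Illustrative sketch of the KV cache block accounting in
# get_num_blocks_available(); all concrete values are assumptions.
BLOCK_SIZE = 64          # tokens per KV cache block (assumed)
max_batch_size = 32      # scheduler_config.max_num_seqs (assumed)
max_model_len = 2048     # scheduler_config.max_model_len (assumed)

# Minimum blocks needed so a single request of maximum length fits.
min_req_num_blocks = max_model_len // BLOCK_SIZE       # 32

# Default when no compiler-provided value is known: enough blocks for
# max_batch_size simultaneous requests of maximum length.
num_blocks = max_batch_size * min_req_num_blocks       # 1024

kv_cache_tokens = num_blocks * BLOCK_SIZE              # 65536
max_concurrency = kv_cache_tokens / max_model_len      # 32.0

assert num_blocks >= min_req_num_blocks
print(f"KV cache size: {kv_cache_tokens} tokens, "
      f"max concurrency: {max_concurrency:.2f}x")
```

With these assumed values the default allocation yields exactly one full-length KV cache per sequence slot, which is why the reported maximum concurrency comes out to the batch size.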
92 changes: 25 additions & 67 deletions vllm_spyre/v1/worker/spyre_model_runner.py
@@ -782,15 +782,6 @@ def __init__(
self.req_ids2reserved_blocks: dict[str, int] = {}

self.tkv: int = 0
# set self.block_pool to the minimal value of 4 required for warmup
# is reset to the value returned by the Spyre compiler after warmup
# self._set_blocks(num_blocks=4)
# for the time being we set this to num_blocks consistent with the
# cache dimension of ContinuousBatchingFmsModel.past_key_value_states
num_blocks = (vllm_config.scheduler_config.max_num_seqs *
vllm_config.model_config.max_model_len //
self.block_size)
self._set_blocks(num_blocks=num_blocks)

# TODO: Remove this once we can prefill and decode in the same step
self.prefill_batch = SamplingInputBatch(
@@ -803,11 +794,34 @@ def __init__(
vocab_size=vllm_config.model_config.get_vocab_size(),
)

def pre_warmup(self) -> None:
# Set the number of KV cache blocks to the minimal value of 2, which is
# required for warmup. After the warmup, the number of blocks will be
# set to the value returned by the Spyre compiler (see complete_warmup).
# Note: until this feature is supported by the compiler, we have to set
# n_blocks_warmup = n_blocks_avail.

n_blocks_warmup = self.model.model.get_num_blocks_available()
self._set_blocks(num_blocks=n_blocks_warmup)
self.model.model._set_past_key_value_states(num_blocks=n_blocks_warmup)

# Future code:

# self._set_blocks(num_blocks=2)
# self.model.model._set_past_key_value_states(num_blocks=2)

# mark the num_blocks dimension dynamic for Spyre compiler for warmup
# only, compiler will return the number of blocks it can accommodate.
# (This is not yet supported by the compiler)
# for layer in self.model.model.past_key_value_states:
# for tensor in layer:
# torch._dynamo.mark_dynamic(tensor, 0)

def complete_warmup(self) -> None:
super().complete_warmup()
# get the number of pages from the actual Spyre card after the warmup
# and set it accordingly in the model runner and the kv cache size
n_blocks_avail = self._get_num_blocks_available()
# and set it accordingly in the model runner and for the kv cache size
n_blocks_avail = self.model.model.get_num_blocks_available()
self._set_blocks(num_blocks=n_blocks_avail)
self.model.model._set_past_key_value_states(num_blocks=n_blocks_avail)

@@ -824,62 +838,6 @@ def _set_blocks(self, num_blocks: int) -> None:
self.n_blocks = num_blocks
self.block_pool = deque([i for i in range(self.n_blocks)])

def _get_num_blocks_available(self) -> int:
"""Function returns the number of available blocks/pages.
Will eventually contain a function in torch_sendnn which reads
the actual value provided by the compiler for backend sendnn"""

max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
max_model_len = self.vllm_config.scheduler_config.max_model_len

min_req_num_blocks = max_model_len // self.block_size

# TODO: replace the hard coded NUM_BLOCKS_SPYRE by calling a function
# in torch_sendnn which returns the value set by the Spyre compiler.

if ('granite-3.3-8b-instruct' in self.model_config.model
and self.parallel_config.world_size == 4):
# hard coded value for tensor parallel size 4 with the below model
# https://huggingface.co/ibm-granite/granite-3.3-8b-instruct
NUM_BLOCKS_SPYRE = 2080
logger.info("Model granite-3.3-8b-instruct and tensor parallel " \
"size 4 detected. Using NUM_BLOCKS_SPYRE = %d", 2080)
else:
# default value for any other model / tensor parallel size
NUM_BLOCKS_SPYRE = max_batch_size * min_req_num_blocks
logger.info("No model / tensor parallel size specific value for " \
"the number of KV cache blocks available on Spyre found. Using " \
"default value (max_batch_size * max_model_len / block_size): %d",
NUM_BLOCKS_SPYRE)

if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == 'sendnn':
num_blocks_spyre = NUM_BLOCKS_SPYRE
assert num_blocks_spyre >= min_req_num_blocks, (
"Number of pages available on Spyre (%d) is not enough to "
"serve the current model (need at least %d pages)." %
(num_blocks_spyre, min_req_num_blocks))
max_concurrency_spyre = num_blocks_spyre * self.block_size \
/ max_model_len
logger.info("Spyre KV cache size: %s tokens",
num_blocks_spyre * self.block_size)
logger.info("Maximum concurrency for %s tokens per request: %.2fx",
str(max_model_len), max_concurrency_spyre)
return num_blocks_spyre
else: # dynamo backend 'eager'
# for debugging purposes, we also use the Spyre value here for CPU
num_blocks_cpu = NUM_BLOCKS_SPYRE
assert num_blocks_cpu >= min_req_num_blocks, (
"Number of pages available on CPU (%d) is not enough to "
"serve the current model (need at least %d pages)." %
(num_blocks_cpu, min_req_num_blocks))
max_concurrency_cpu = num_blocks_cpu * self.block_size \
/ max_model_len
logger.info("CPU KV cache size: %s tokens",
num_blocks_cpu * self.block_size)
logger.info("Maximum concurrency for %s tokens per request: %.2fx",
str(max_model_len), max_concurrency_cpu)
return num_blocks_cpu

def update_states(self, scheduler_output):

super().update_states(scheduler_output)
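The `_set_blocks()` helper retained above backs the runner's free-page bookkeeping with a `deque`. A toy sketch of how such a pool behaves, including the post-warmup resize, is shown below; the `BlockPool` class and its `allocate`/`free` methods are hypothetical illustrations, not part of the PR:

```python
from collections import deque

class BlockPool:
    """Toy free-list of KV cache pages, mirroring _set_blocks()."""

    def __init__(self, num_blocks: int) -> None:
        self.set_blocks(num_blocks)

    def set_blocks(self, num_blocks: int) -> None:
        # Reset the pool: all block ids become free again, as in
        # _set_blocks(), which rebuilds the deque from scratch.
        self.n_blocks = num_blocks
        self.block_pool = deque(range(num_blocks))

    def allocate(self) -> int:
        # Hand out the lowest-numbered free block; raises IndexError
        # if the pool is exhausted.
        return self.block_pool.popleft()

    def free(self, block_id: int) -> None:
        self.block_pool.append(block_id)

pool = BlockPool(num_blocks=4)      # small pool, e.g. for warmup
b = pool.allocate()
pool.free(b)
pool.set_blocks(num_blocks=2080)    # resized after warmup, e.g. to a
                                    # compiler-reported block count
```

Rebuilding the deque on resize keeps the bookkeeping trivially consistent: any block ids handed out before the resize are invalidated wholesale, which matches resetting the pool between warmup and serving.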
2 changes: 2 additions & 0 deletions vllm_spyre/v1/worker/spyre_worker.py
@@ -389,6 +389,8 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
pooling_params=None,
) for i in range(2))

model_runner.pre_warmup()

with _maybe_warmup_context():
self._dynamic_warmup(request=warmup_req,
prompt_len=prompt_len,
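The new `model_runner.pre_warmup()` call fixes the hook ordering: the KV cache is sized before the warmup request executes, and `complete_warmup()` re-sizes it afterwards. A simplified sketch of that sequence follows; `run_warmup_request` is a hypothetical stand-in for the worker's `_dynamic_warmup(...)` call:

```python
# Hypothetical, simplified warmup sequence for illustration only.

def run_warmup_request(model_runner) -> None:
    """Stand-in for the worker's _dynamic_warmup(...) call."""
    pass

def warmup(model_runner) -> None:
    # 1. Size the KV cache so the warmup request can run at all.
    model_runner.pre_warmup()
    # 2. Execute the warmup request under the warmup context.
    run_warmup_request(model_runner)
    # 3. Adopt the block count reported for the Spyre card after warmup.
    model_runner.complete_warmup()
```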