32 changes: 32 additions & 0 deletions tests/fixtures/models/granite-3.3-8b-instruct-config-only/config.json
@@ -0,0 +1,32 @@
{
"architectures": [
"GraniteForCausalLM"
],
"attention_bias": false,
"attention_dropout": 0.0,
"attention_multiplier": 0.0078125,
"bos_token_id": 0,
"embedding_multiplier": 12.0,
"eos_token_id": 0,
"hidden_act": "silu",
"hidden_size": 4096,
"initializer_range": 0.02,
"intermediate_size": 12800,
"logits_scaling": 16.0,
"max_position_embeddings": 131072,
"mlp_bias": false,
"model_type": "granite",
"num_attention_heads": 32,
"num_hidden_layers": 40,
"num_key_value_heads": 8,
"pad_token_id": 0,
"residual_multiplier": 0.22,
"rms_norm_eps": 1e-05,
"rope_scaling": null,
"rope_theta": 10000000.0,
"tie_word_embeddings": true,
"torch_dtype": "bfloat16",
"transformers_version": "4.49.0",
"use_cache": true,
"vocab_size": 49159
}
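As a quick sanity check on the fixture values above, a minimal sketch of how the multipliers relate to the model dimensions; reading attention_multiplier as the per-head attention scale is an assumption about the Granite config conventions, not something stated in this PR:

# Illustrative only; values copied from the config above.
hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8
head_dim = hidden_size // num_attention_heads            # 128
assert 1 / head_dim == 0.0078125                          # matches "attention_multiplier"
assert num_attention_heads // num_key_value_heads == 4    # 4 query heads per KV head (GQA)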
32 changes: 32 additions & 0 deletions tests/fixtures/models/granite-3.3-micro-config-only/config.json
@@ -0,0 +1,32 @@
{
"architectures": [
"GraniteForCausalLM"
],
"attention_bias": false,
"attention_dropout": 0.0,
"attention_multiplier": 0.0078125,
"bos_token_id": 0,
"dtype": "bfloat16",
"embedding_multiplier": 12.0,
"eos_token_id": 0,
"hidden_act": "silu",
"hidden_size": 4096,
"initializer_range": 0.02,
"intermediate_size": 12800,
"logits_scaling": 16.0,
"max_position_embeddings": 131072,
"mlp_bias": false,
"model_type": "granite",
"num_attention_heads": 32,
"num_hidden_layers": 4,
"num_key_value_heads": 8,
"pad_token_id": 0,
"residual_multiplier": 0.22,
"rms_norm_eps": 1e-05,
"rope_scaling": null,
"rope_theta": 10000000.0,
"tie_word_embeddings": false,
"transformers_version": "4.56.1",
"use_cache": true,
"vocab_size": 49159
}
49 changes: 49 additions & 0 deletions tests/models/test_granite.py
@@ -0,0 +1,49 @@
"""Tests for model-specific overrides for granite"""
import os
from pathlib import Path
from unittest import mock

import pytest
from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig

from vllm_spyre.platform import SpyrePlatform

FIXTURES_PATH = Path(__file__).parent.parent / "fixtures" / "models"

NO_SWAP_CONFIG = CacheConfig(swap_space=0.001)


@pytest.mark.cpu
def test_granite_3_8b_detection():
"""Check that we can detect the model config for granite 3 8b"""

granite_3_8b_config = VllmConfig(model_config=ModelConfig(
model=str(FIXTURES_PATH / "granite-3.3-8b-instruct-config-only")),
cache_config=NO_SWAP_CONFIG)

granite_micro_config = VllmConfig(model_config=ModelConfig(
model=str(FIXTURES_PATH / "granite-3.3-micro-config-only")),
cache_config=NO_SWAP_CONFIG)

assert SpyrePlatform.is_granite_3_8b(granite_3_8b_config.model_config)

assert not SpyrePlatform.is_granite_3_8b(granite_micro_config.model_config)


@pytest.mark.cpu
def test_granite_3_8b_overrides():
"""Check that the correct values are overridden for g3.3 8b"""

# Must ensure no env vars have been overridden before testing
with mock.patch.dict(os.environ, clear=True):
tp4_config = ParallelConfig(tensor_parallel_size=4)

granite_3_8b_config = VllmConfig(model_config=ModelConfig(
model=str(FIXTURES_PATH / "granite-3.3-8b-instruct-config-only")),
parallel_config=tp4_config,
cache_config=NO_SWAP_CONFIG)

assert granite_3_8b_config.cache_config.num_gpu_blocks_override == 2080

assert int(os.getenv("VLLM_DT_MAX_BATCH_TKV_LIMIT")) == 128 * 1024
assert int(os.getenv("FLEX_HDMA_P2PSIZE")) == 256 * 1024 * 1024
130 changes: 23 additions & 107 deletions vllm_spyre/model_executor/model_loader/spyre.py
@@ -9,7 +9,7 @@
import torch.nn as nn
from fms.models import get_model
from transformers import PretrainedConfig
from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
from vllm.config import ModelConfig, VllmConfig
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -51,17 +51,15 @@ class SpyreCausalLM(nn.Module):

def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
vllm_config: VllmConfig,
max_prompt_length: int,
max_decode_length: int,
rank: int,
) -> None:
super().__init__()

self.logits_processor = LogitsProcessor(
model_config.hf_config.vocab_size, logits_as_input=True)
vllm_config.model_config.hf_config.vocab_size,
logits_as_input=True)
self.sampler = get_sampler()

# boolean tensor of length batch size with indices:
@@ -78,14 +76,10 @@ def __init__(

# FMS Model
if envs_spyre.VLLM_SPYRE_USE_CB:
self.model = ContinuousBatchingFmsModel(model_config,
parallel_config,
scheduler_config, rank)
self.model = ContinuousBatchingFmsModel(vllm_config, rank)
else:
self.model = StaticBatchingFmsModel(
model_config,
parallel_config,
scheduler_config,
vllm_config,
max_prompt_length,
max_decode_length,
rank,
@@ -155,32 +149,34 @@ class FmsModelBase(nn.Module):

def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
vllm_config: VllmConfig,
max_prompt_length: int,
max_decode_length: int,
rank: int,
sendnn_dynamic: bool,
) -> None:
super().__init__()

self.config: PretrainedConfig = model_config.hf_config
self.config: PretrainedConfig = vllm_config.model_config.hf_config

# Actual FMS model
self.model: nn.Module
self.model_config = model_config
self.model_config = vllm_config.model_config
self.parallel_config = vllm_config.parallel_config
self.cache_config = vllm_config.cache_config
self.scheduler_config = vllm_config.scheduler_config
self.dtype = self.get_dtype()

# Load the weights from the cached or downloaded files.
self.load_weights(
model_config=model_config,
model_config=self.model_config,
max_prompt_length=max_prompt_length,
max_decode_length=max_decode_length,
distributed_strategy="tp"
if parallel_config.world_size > 1 else None,
if self.parallel_config.world_size > 1 else None,
sendnn_dynamic=sendnn_dynamic,
rank=rank,
world_size=parallel_config.world_size,
world_size=self.parallel_config.world_size,
)

def load_weights(
@@ -321,37 +317,32 @@ class ContinuousBatchingFmsModel(FmsModelBase):

def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
vllm_config: VllmConfig,
rank: int,
) -> None:

BLOCK_SIZE = SpyrePlatform.get_block_size()
max_model_len = scheduler_config.max_model_len
max_model_len = vllm_config.scheduler_config.max_model_len

# edge case: prompt fills model length: can produce 1 token with prefill
max_prompt_length = max_model_len
# edge case: prompt will be padded to first block:
# can produce 1 token with prefill plus rest of model length
max_decode_length = max_model_len - BLOCK_SIZE + 1

super().__init__(model_config,
parallel_config,
super().__init__(vllm_config,
max_prompt_length,
max_decode_length,
rank,
sendnn_dynamic=True)

self.scheduler_config = scheduler_config
self.parallel_config = parallel_config
self.prefill_past_key_values = None

# physical KV cache on AIU Spyre: will eventually not live in this class
self.kv_cache_specs = {}
self.kv_cache_specs['block_size'] = BLOCK_SIZE
self.kv_cache_specs['num_kv_heads'] = model_config.get_num_kv_heads(
parallel_config)
self.kv_cache_specs[
'num_kv_heads'] = self.model_config.get_num_kv_heads(
self.parallel_config)

if self.config.model_type in {'llama', 'granite'}:
self.kv_cache_specs['num_layers'] = self.config.num_hidden_layers
@@ -375,81 +366,9 @@ def __init__(

self.current_scale: Optional[list[tuple]] = None

def get_num_blocks_available(self) -> int:
"""Function returns the number of available blocks/pages.
Will eventually contain a function in torch_sendnn which reads
the actual value provided by the compiler for backend sendnn"""

max_batch_size = self.scheduler_config.max_num_seqs
max_model_len = self.scheduler_config.max_model_len
block_size = self.kv_cache_specs['block_size']

min_req_num_blocks = max_model_len // block_size

# TODO: replace the hard coded NUM_BLOCKS_SPYRE by calling a function
# in torch_sendnn which returns the value set by the Spyre compiler.
if ('granite-3.3-8b-instruct' in self.model_config.model
and self.parallel_config.world_size == 4):
# hard coded value for tensor parallel size 4 with the below model
# https://huggingface.co/ibm-granite/granite-3.3-8b-instruct

# num_blocks_spyre must be multiple of max_batch_size
NUM_BLOCKS_SPYRE = max_batch_size * (2080 // max_batch_size)
logger.info(
"Model %s and tensor parallel "
"size %d detected. Using NUM_BLOCKS_SPYRE = %d",
self.model_config.model,
self.parallel_config.world_size,
NUM_BLOCKS_SPYRE,
)
else:
# default value for any other model/ tensor parallel size
NUM_BLOCKS_SPYRE = max_batch_size * min_req_num_blocks
logger.info("No model / tensor parallel size specific value for " \
"the number of KV cache blocks available on Spyre found. Using " \
"default value (max_batch_size * max_model_len / block_size): %d",
NUM_BLOCKS_SPYRE)

if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == 'sendnn':
num_blocks_spyre = NUM_BLOCKS_SPYRE
assert num_blocks_spyre >= min_req_num_blocks, (
"Number of pages available on Spyre (%d) is not enough to "
"serve the current model (need at least %d pages)." %
(num_blocks_spyre, min_req_num_blocks))
max_concurrency_spyre = num_blocks_spyre * block_size \
/ max_model_len
logger.info("Spyre KV cache size: %s tokens",
num_blocks_spyre * block_size)
logger.info("Maximum concurrency for %s tokens per request: %.2fx",
str(max_model_len), max_concurrency_spyre)

assert num_blocks_spyre % max_batch_size == 0, \
"num_blocks_spyre must be multiple of max_batch_size"
return num_blocks_spyre
else: # dynamo backend 'eager'
# for debugging purposes we also put the spyre value here for cpu
num_blocks_cpu = NUM_BLOCKS_SPYRE
assert num_blocks_cpu >= min_req_num_blocks, (
"Number of pages available on CPU (%d) is not enough to "
"serve the current model (need at least %d pages)." %
(num_blocks_cpu, min_req_num_blocks))
max_concurrency_cpu = num_blocks_cpu * block_size \
/ max_model_len
logger.info("CPU KV cache size: %s tokens",
num_blocks_cpu * block_size)
logger.info("Maximum concurrency for %s tokens per request: %.2fx",
str(max_model_len), max_concurrency_cpu)
return num_blocks_cpu

def set_past_key_value_states(self, num_blocks) -> None:
# overwrite num_blocks for testing scheduler constraints
num_blocks_override = SpyrePlatform.get_num_spyre_blocks_override()
if num_blocks_override > 0:
num_blocks = num_blocks_override

# List[layers] of Tuple[k,v] of
# Tensor[num_blocks, block_size, num_kv_heads, head_dim]

if not self.model_config.quantization:
self.past_key_value_states = [
(torch.zeros(num_blocks,
Expand Down Expand Up @@ -665,15 +584,12 @@ class StaticBatchingFmsModel(FmsModelBase):

def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
_: SchedulerConfig,
vllm_config: VllmConfig,
max_prompt_length: int,
max_decode_length: int,
rank: int,
) -> None:
super().__init__(model_config,
parallel_config,
super().__init__(vllm_config,
max_prompt_length,
max_decode_length,
rank,