Skip to content
7 changes: 4 additions & 3 deletions vllm_spyre/model_executor/model_loader/spyre.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,10 +339,11 @@ def __init__(
'num_kv_heads'] = self.model_config.get_num_kv_heads(
self.parallel_config)

if self.config.model_type in {'llama', 'granite'}:
if self.config.model_type in {'llama', 'granite', 'granitemoehybrid'}:
self.kv_cache_specs['num_layers'] = self.config.num_hidden_layers
self.kv_cache_specs['head_dim'] = self.config.hidden_size // \
self.config.num_attention_heads
self.kv_cache_specs['head_dim'] = getattr(
self.model.config, "head_dim",
self.config.hidden_size // self.config.num_attention_heads)
elif self.config.model_type == 'gpt_bigcode':
self.kv_cache_specs['num_layers'] = self.config.n_layer
self.kv_cache_specs[
Expand Down