4 changes: 4 additions & 0 deletions vllm_gaudi/platform.py
@@ -145,6 +145,10 @@ def is_pin_memory_available(cls):
def get_punica_wrapper(cls) -> str:
return "vllm_gaudi.lora.punica_wrapper.punica_hpu.PunicaWrapperHPU"

@classmethod
def support_hybrid_kv_cache(cls) -> bool:
return True

@classmethod
def get_device_communicator_cls(cls) -> str:
return "vllm_gaudi.distributed.device_communicators.hpu_communicator.HpuCommunicator" # noqa
164 changes: 161 additions & 3 deletions vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -4641,16 +4641,146 @@ def _dummy_run(self, max_num_batched_tokens: int) -> None:
self._prepare_dummy_scenario(prompt_cfg, decode_cfg)
return

def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize the attention backends and attention metadata builders.
"""
assert len(self.attn_groups) == 0, "Attention backends are already initialized"

class AttentionGroupKey(NamedTuple):
attn_backend: type[AttentionBackend]
kv_cache_spec: KVCacheSpec

def get_attn_backends_for_group(
kv_cache_group_spec: KVCacheGroupSpec,
) -> tuple[dict[AttentionGroupKey, list[str]], set[type[AttentionBackend]]]:
layer_type = cast(type[Any], AttentionLayerBase)
layers = get_layers_from_vllm_config(
self.vllm_config, layer_type, kv_cache_group_spec.layer_names
)
attn_backends = {}
attn_backend_layers = defaultdict(list)
# Dedupe based on the full class name; this is a bit safer than using
# the class itself as the key, because dynamically created attention
# backend subclasses (e.g. ChunkedLocalAttention) can be distinct
# objects per layer unless they are cached correctly.
for layer_name in kv_cache_group_spec.layer_names:
attn_backend = layers[layer_name].get_attn_backend()

if layer_name in self.kv_sharing_fast_prefill_eligible_layers:
attn_backend = create_fast_prefill_custom_backend(
"FastPrefill",
attn_backend, # type: ignore[arg-type]
)

full_cls_name = attn_backend.full_cls_name()
layer_kv_cache_spec = kv_cache_group_spec.kv_cache_spec
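# A UniformTypeKVCacheSpecs wraps per-layer specs of one spec type; pick
# this layer's own entry so the dedup key reflects the layer's actual spec.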
if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[layer_name]
key = (full_cls_name, layer_kv_cache_spec)
attn_backends[key] = AttentionGroupKey(
attn_backend, layer_kv_cache_spec
)
attn_backend_layers[key].append(layer_name)
return (
{attn_backends[k]: v for k, v in attn_backend_layers.items()},
set(group_key.attn_backend for group_key in attn_backends.values()),
)

def create_attn_groups(
attn_backends_map: dict[AttentionGroupKey, list[str]],
kv_cache_group_id: int,
) -> list[AttentionGroup]:
attn_groups: list[AttentionGroup] = []
for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items():
attn_group = AttentionGroup(
attn_backend,
layer_names,
kv_cache_spec,
kv_cache_group_id,
)

attn_groups.append(attn_group)
return attn_groups

attention_backend_maps = []
attention_backend_list = []
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
attn_backends = get_attn_backends_for_group(kv_cache_group_spec)
attention_backend_maps.append(attn_backends[0])
attention_backend_list.append(attn_backends[1])

# Resolve cudagraph_mode before actually initializing metadata_builders
self._check_and_update_cudagraph_mode(
attention_backend_list, kv_cache_config.kv_cache_groups
)

for i, attn_backend_map in enumerate(attention_backend_maps):
self.attn_groups.append(create_attn_groups(attn_backend_map, i))
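A toy version of the grouping rule above (sketch only, not from the diff): layers are keyed by the backend's fully qualified class name plus the layer's KV cache spec, so dynamically created subclasses that compare unequal as objects still land in the same attention group.

from collections import defaultdict

def group_layers_by_backend_and_spec(layer_to_backend, layer_to_spec):
    # One group per (full class name, kv_cache_spec) pair.
    representative = {}           # key -> a backend class kept for the group
    members = defaultdict(list)   # key -> layer names sharing that key
    for layer, backend_cls in layer_to_backend.items():
        key = (f"{backend_cls.__module__}.{backend_cls.__qualname__}",
               layer_to_spec[layer])
        representative[key] = backend_cls
        members[key].append(layer)
    return [(representative[k], k[1], layers) for k, layers in members.items()]

# Two distinct dynamically created classes with the same full name dedupe
# into a single group, which keying on the class objects would fail to do.
_A = type("ChunkedLocal", (), {})
_B = type("ChunkedLocal", (), {})
groups = group_layers_by_backend_and_spec(
    {"layers.0.attn": _A, "layers.1.attn": _B},
    {"layers.0.attn": "full_spec", "layers.1.attn": "full_spec"},
)
assert len(groups) == 1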

def _update_hybrid_attention_mamba_layout(
self, kv_caches: dict[str, torch.Tensor]
) -> None:
"""
Update the attention layers' KV cache layout from (2, num_blocks, ...)
to (num_blocks, 2, ...); the logical tensor shape stays the same.

Args:
kv_caches: The KV cache buffer of each layer.
"""

for group in self._kv_cache_spec_attn_group_iterator():
kv_cache_spec = group.kv_cache_spec
for layer_name in group.layer_names:
kv_cache = kv_caches[layer_name]
if isinstance(kv_cache_spec, AttentionSpec) and kv_cache.shape[0] == 2:
assert kv_cache.shape[1] != 2, (
"Fail to determine whether the layout is "
"(2, num_blocks, ...) or (num_blocks, 2, ...) for "
f"a tensor of shape {kv_cache.shape}"
)
hidden_size = kv_cache.shape[2:].numel()
kv_cache.as_strided_(
size=kv_cache.shape,
stride=(hidden_size, 2 * hidden_size, *kv_cache.stride()[2:]),
)
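A hedged, self-contained illustration of the stride trick above (sizes are made up): the in-place as_strided_ keeps the logical (2, num_blocks, ...) shape but makes K and V of the same block adjacent in the underlying storage, which is what the (num_blocks, 2, ...) layout means here. In the runner this runs on freshly allocated caches, so only the layout matters, not the values.

import torch

num_blocks, hidden = 4, 6                   # hidden = numel of the tail dims
kv = torch.zeros(2, num_blocks, hidden)     # contiguous: K plane then V plane
assert kv.stride() == (num_blocks * hidden, hidden, 1)

kv.as_strided_(size=kv.shape,
               stride=(hidden, 2 * hidden, *kv.stride()[2:]))
# Element [i, j, :] now starts at storage offset (2 * j + i) * hidden, i.e.
# block j's K chunk and V chunk sit back to back in memory.
assert kv[1, 0].data_ptr() - kv[0, 0].data_ptr() == hidden * kv.element_size()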

def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize KV cache based on `kv_cache_config`.
Args:
kv_cache_config: Configuration for the KV cache, including the KV
cache size of each layer
"""
if len(kv_cache_config.kv_cache_groups) > 1:
raise NotImplementedError("Hybrid models with more than one KV cache type are not "
"supported yet.")
block_sizes = [
kv_cache_group.kv_cache_spec.block_size
for kv_cache_group in kv_cache_config.kv_cache_groups
]
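# Hybrid models can introduce per-group block sizes that differ from the
# default cache_config.block_size; in that case the persistent input batch,
# which keeps one block table per KV cache group, has to be rebuilt below.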
if block_sizes != [self.cache_config.block_size]:
assert self.cache_config.cpu_offload_gb == 0, (
"Cannot re-initialize the input batch when CPU weight "
"offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501
"for more details.")
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=max(self.max_model_len, self.max_encoder_len),
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
block_sizes=block_sizes,
is_spec_decode=bool(self.vllm_config.speculative_config),
logitsprocs=self.input_batch.logitsprocs,
is_pooling_model=self.is_pooling_model,
num_speculative_tokens=(
self.vllm_config.speculative_config.num_speculative_tokens
if self.vllm_config.speculative_config else 0),
)

self.initialize_attn_backend(kv_cache_config)
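# The attention groups created above are needed later in this method: the
# hybrid attention/Mamba layout fixup walks
# self._kv_cache_spec_attn_group_iterator() to find the attention layers.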


# build a map from layer_name -> KVCacheTensor
tensor_map: dict[str, KVCacheTensor] = {}
@@ -4664,6 +4794,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
assert len(kv_cache_tensor.shared_by) == 1
kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size

has_mamba = False
for kv_cache_group in kv_cache_config.kv_cache_groups:
kv_cache_spec = kv_cache_group.kv_cache_spec
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
@@ -4692,6 +4823,30 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
value_cache = None
for layer_name in kv_cache_tensor.shared_by:
kv_caches[layer_name] = (key_cache, value_cache)
elif isinstance(kv_cache_spec, MambaSpec):
has_mamba = True
raw_tensor = kv_cache_raw_tensors[layer_name]
state_tensors = []
storage_offset_bytes = 0
for (shape, dtype) in zip(kv_cache_spec.shapes,
kv_cache_spec.dtypes):
dtype_size = get_dtype_size(dtype)
num_element_per_page = (
kv_cache_spec.page_size_bytes // dtype_size)
target_shape = (num_blocks, *shape)
stride = torch.empty(target_shape).stride()
target_stride = (num_element_per_page, *stride[1:])
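# Dim 0 advances a whole page (num_element_per_page), so block i's state
# starts i pages into the raw buffer; the remaining dims keep a dense
# layout, and storage_offset_bytes tracks where this state component
# starts within each page.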
assert storage_offset_bytes % dtype_size == 0
tensor = torch.as_strided(
raw_tensor.view(dtype),
size=target_shape,
stride=target_stride,
storage_offset=storage_offset_bytes // dtype_size,
)
state_tensors.append(tensor)
storage_offset_bytes += stride[0] * dtype_size

kv_caches[layer_name] = state_tensors
else:
# TODO: add new branches when introducing more types of
# KV cache specs.
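A worked, self-contained sketch of the MambaSpec branch above (shapes, dtypes and page size are invented for the demo): each block owns one page of the flat per-layer buffer, and each state component is a strided view whose block dimension steps a whole page, starting at that component's byte offset inside the page.

import math
import torch

num_blocks = 3
shapes = ((4, 8), (16,))                    # e.g. conv state and SSM state
dtypes = (torch.bfloat16, torch.float32)

def dtype_size(d):
    return torch.empty((), dtype=d).element_size()

page_size_bytes = sum(math.prod(s) * dtype_size(d)
                      for s, d in zip(shapes, dtypes))
raw = torch.zeros(num_blocks * page_size_bytes, dtype=torch.uint8)

state_tensors = []
offset_bytes = 0
for shape, dtype in zip(shapes, dtypes):
    elems_per_page = page_size_bytes // dtype_size(dtype)
    dense = torch.empty((num_blocks, *shape)).stride()   # dense intra-page strides
    assert offset_bytes % dtype_size(dtype) == 0
    view = torch.as_strided(raw.view(dtype),
                            size=(num_blocks, *shape),
                            stride=(elems_per_page, *dense[1:]),
                            storage_offset=offset_bytes // dtype_size(dtype))
    state_tensors.append(view)
    offset_bytes += dense[0] * dtype_size(dtype)         # next component in the page

assert offset_bytes == page_size_bytes                   # components exactly fill a page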
@@ -4722,6 +4877,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
logger.info("Allocating unified persistent batch took %.4f GB of host memory",
m.consumed_host_memory / float(2**30))

if has_mamba:
self._update_hybrid_attention_mamba_layout(kv_caches)

htorch.hpu.synchronize()

def get_kv_caches_4D(self, kv_caches) -> dict[str, torch.Tensor]: