diff --git a/vllm_gaudi/platform.py b/vllm_gaudi/platform.py
index b4764d644..5ed3b0c6e 100644
--- a/vllm_gaudi/platform.py
+++ b/vllm_gaudi/platform.py
@@ -145,6 +145,10 @@ def is_pin_memory_available(cls):
     def get_punica_wrapper(cls) -> str:
         return "vllm_gaudi.lora.punica_wrapper.punica_hpu.PunicaWrapperHPU"
 
+    @classmethod
+    def support_hybrid_kv_cache(cls) -> bool:
+        return True
+
     @classmethod
     def get_device_communicator_cls(cls) -> str:
         return "vllm_gaudi.distributed.device_communicators.hpu_communicator.HpuCommunicator"  # noqa
diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py
index 240d7b02b..3e239a152 100644
--- a/vllm_gaudi/v1/worker/hpu_model_runner.py
+++ b/vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -4641,6 +4641,111 @@ def _dummy_run(self, max_num_batched_tokens: int) -> None:
         self._prepare_dummy_scenario(prompt_cfg, decode_cfg)
         return
 
+    def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
+        """
+        Initialize the attention backends and attention metadata builders.
+        """
+        assert len(self.attn_groups) == 0, "Attention backends are already initialized"
+
+        class AttentionGroupKey(NamedTuple):
+            attn_backend: type[AttentionBackend]
+            kv_cache_spec: KVCacheSpec
+
+        def get_attn_backends_for_group(
+            kv_cache_group_spec: KVCacheGroupSpec,
+        ) -> tuple[dict[AttentionGroupKey, list[str]], set[type[AttentionBackend]]]:
+            layer_type = cast(type[Any], AttentionLayerBase)
+            layers = get_layers_from_vllm_config(
+                self.vllm_config, layer_type, kv_cache_group_spec.layer_names
+            )
+            attn_backends = {}
+            attn_backend_layers = defaultdict(list)
+            # Dedupe based on the full class name; this is a bit safer than
+            # using the class itself as the key, because dynamically created
+            # attention backend subclasses (e.g. ChunkedLocalAttention) end
+            # up as a different class object per layer unless they are
+            # cached correctly.
+            for layer_name in kv_cache_group_spec.layer_names:
+                attn_backend = layers[layer_name].get_attn_backend()
+
+                if layer_name in self.kv_sharing_fast_prefill_eligible_layers:
+                    attn_backend = create_fast_prefill_custom_backend(
+                        "FastPrefill",
+                        attn_backend,  # type: ignore[arg-type]
+                    )
+
+                full_cls_name = attn_backend.full_cls_name()
+                layer_kv_cache_spec = kv_cache_group_spec.kv_cache_spec
+                if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
+                    layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[layer_name]
+                key = (full_cls_name, layer_kv_cache_spec)
+                attn_backends[key] = AttentionGroupKey(
+                    attn_backend, layer_kv_cache_spec
+                )
+                attn_backend_layers[key].append(layer_name)
+            return (
+                {attn_backends[k]: v for k, v in attn_backend_layers.items()},
+                set(group_key.attn_backend for group_key in attn_backends.values()),
+            )
+
+        def create_attn_groups(
+            attn_backends_map: dict[AttentionGroupKey, list[str]],
+            kv_cache_group_id: int,
+        ) -> list[AttentionGroup]:
+            attn_groups: list[AttentionGroup] = []
+            for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items():
+                attn_group = AttentionGroup(
+                    attn_backend,
+                    layer_names,
+                    kv_cache_spec,
+                    kv_cache_group_id,
+                )
+
+                attn_groups.append(attn_group)
+            return attn_groups
+
+        attention_backend_maps = []
+        attention_backend_list = []
+        for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
+            attn_backends = get_attn_backends_for_group(kv_cache_group_spec)
+            attention_backend_maps.append(attn_backends[0])
+            attention_backend_list.append(attn_backends[1])
+
+        # Resolve cudagraph_mode before actually initializing the metadata builders
+        self._check_and_update_cudagraph_mode(
+            attention_backend_list, kv_cache_config.kv_cache_groups
+        )
+
+        for i, attn_backend_map in enumerate(attention_backend_maps):
+            self.attn_groups.append(create_attn_groups(attn_backend_map, i))
+
+    def _update_hybrid_attention_mamba_layout(
+        self, kv_caches: dict[str, torch.Tensor]
+    ) -> None:
+        """
+        Update the layout of attention layers from (2, num_blocks, ...) to
+        (num_blocks, 2, ...).
+
+        Args:
+            kv_caches: The KV cache buffer of each layer.
+        """
+
+        for group in self._kv_cache_spec_attn_group_iterator():
+            kv_cache_spec = group.kv_cache_spec
+            for layer_name in group.layer_names:
+                kv_cache = kv_caches[layer_name]
+                if isinstance(kv_cache_spec, AttentionSpec) and kv_cache.shape[0] == 2:
+                    assert kv_cache.shape[1] != 2, (
+                        "Fail to determine whether the layout is "
+                        "(2, num_blocks, ...) or (num_blocks, 2, ...) for "
+                        f"a tensor of shape {kv_cache.shape}"
+                    )
+                    hidden_size = kv_cache.shape[2:].numel()
+                    kv_cache.as_strided_(
+                        size=kv_cache.shape,
+                        stride=(hidden_size, 2 * hidden_size, *kv_cache.stride()[2:]),
+                    )
+
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize KV cache based on `kv_cache_config`.
@@ -4648,9 +4753,34 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             kv_cache_config: Configuration for the KV cache, including the KV cache
                 size of each layer
         """
-        if len(kv_cache_config.kv_cache_groups) > 1:
-            raise NotImplementedError("Hybrid models with more than one KV cache type are not "
-                                      "supported yet.")
+        # if len(kv_cache_config.kv_cache_groups) > 1:
+        block_sizes = [
+            kv_cache_group.kv_cache_spec.block_size
+            for kv_cache_group in kv_cache_config.kv_cache_groups
+        ]
+        if block_sizes != [self.cache_config.block_size]:
+            assert self.cache_config.cpu_offload_gb == 0, (
+                "Cannot re-initialize the input batch when CPU weight "
+                "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 "  # noqa: E501
+                "for more details.")
+            self.input_batch = InputBatch(
+                max_num_reqs=self.max_num_reqs,
+                max_model_len=max(self.max_model_len, self.max_encoder_len),
+                max_num_batched_tokens=self.max_num_tokens,
+                device=self.device,
+                pin_memory=self.pin_memory,
+                vocab_size=self.model_config.get_vocab_size(),
+                block_sizes=block_sizes,
+                is_spec_decode=bool(self.vllm_config.speculative_config),
+                logitsprocs=self.input_batch.logitsprocs,
+                is_pooling_model=self.is_pooling_model,
+                num_speculative_tokens=(
+                    self.vllm_config.speculative_config.num_speculative_tokens
+                    if self.vllm_config.speculative_config else 0),
+            )
+
+        self.initialize_attn_backend(kv_cache_config)
+
         # build a map from layer_name -> KVCacheTensor
         tensor_map: dict[str, KVCacheTensor] = {}
 
@@ -4664,6 +4794,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             assert len(kv_cache_tensor.shared_by) == 1
             kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
 
+        has_mamba = False
         for kv_cache_group in kv_cache_config.kv_cache_groups:
             kv_cache_spec = kv_cache_group.kv_cache_spec
             for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
@@ -4692,6 +4823,30 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
                     value_cache = None
                     for layer_name in kv_cache_tensor.shared_by:
                         kv_caches[layer_name] = (key_cache, value_cache)
+                elif isinstance(kv_cache_spec, MambaSpec):
+                    has_mamba = True
+                    raw_tensor = kv_cache_raw_tensors[layer_name]
+                    state_tensors = []
+                    storage_offset_bytes = 0
+                    for (shape, dtype) in zip(kv_cache_spec.shapes,
+                                              kv_cache_spec.dtypes):
+                        dtype_size = get_dtype_size(dtype)
+                        num_element_per_page = (
+                            kv_cache_spec.page_size_bytes // dtype_size)
+                        target_shape = (num_blocks, *shape)
+                        stride = torch.empty(target_shape).stride()
+                        target_stride = (num_element_per_page, *stride[1:])
+                        assert storage_offset_bytes % dtype_size == 0
+                        tensor = torch.as_strided(
+                            raw_tensor.view(dtype),
+                            size=target_shape,
+                            stride=target_stride,
+                            storage_offset=storage_offset_bytes // dtype_size,
+                        )
+                        state_tensors.append(tensor)
+                        storage_offset_bytes += stride[0] * dtype_size
+
+                    kv_caches[layer_name] = state_tensors
                 else:
                     # TODO: add new branches when introducing more types of
                     # KV cache specs.
@@ -4722,6 +4877,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             logger.info("Allocating unified persistent batch took %.4f GB of host memory",
                         m.consumed_host_memory / float(2**30))
 
+        if has_mamba:
+            self._update_hybrid_attention_mamba_layout(kv_caches)
+
         htorch.hpu.synchronize()
 
     def get_kv_caches_4D(self, kv_caches) -> dict[str, torch.Tensor]:
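Note on the dedupe comment inside `get_attn_backends_for_group`: the snippet below is a standalone toy, not vLLM code. `FlashAttnBackend`, `make_wrapped_backend`, the exact `full_cls_name` return value, and the layer names are invented for illustration; it only shows why grouping layers by the backend's full class name is safer than grouping by the class object when a wrapper subclass is created per layer without caching.

```python
from collections import defaultdict


class FlashAttnBackend:
    @classmethod
    def full_cls_name(cls):
        return (cls.__module__, cls.__qualname__)


def make_wrapped_backend():
    # Simulates an uncached per-layer wrapper: every call builds a brand-new
    # subclass object, even though it is logically the same backend.
    return type("FastPrefillFlashAttnBackend", (FlashAttnBackend,), {})


layers = {f"layer.{i}": make_wrapped_backend() for i in range(3)}

by_class = defaultdict(list)
by_full_name = defaultdict(list)
for layer_name, backend in layers.items():
    by_class[backend].append(layer_name)
    by_full_name[backend.full_cls_name()].append(layer_name)

print(len(by_class))      # 3 -> one attention group per layer (undesired)
print(len(by_full_name))  # 1 -> all layers share one group (intended)
```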
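For `_update_hybrid_attention_mamba_layout`, here is a minimal PyTorch sketch of the same in-place re-striding on toy sizes (`num_blocks` and `hidden` are made up, and the "pages shared with Mamba" motivation is a reading of when the method is called, not something stated in the patch): the storage physically holds `(num_blocks, 2, hidden)` pages, while the attention view keeps the logical shape `(2, num_blocks, hidden)`; only the strides change, so no data is copied.

```python
import torch

num_blocks, hidden = 4, 8  # toy sizes

# Physical storage: one page per block, holding that block's K half followed
# by its V half.
physical = torch.arange(num_blocks * 2 * hidden, dtype=torch.float32).view(
    num_blocks, 2, hidden)

# The attention backend addresses the same storage as (2, num_blocks, hidden).
# Re-striding it in place, as the new method does, makes index [kv, block]
# resolve to the K or V half of block `block`'s physical page.
kv_view = physical.view(2, num_blocks, hidden)
kv_view.as_strided_(size=(2, num_blocks, hidden),
                    stride=(hidden, 2 * hidden, 1))

assert torch.equal(kv_view[0, 1], physical[1, 0])  # K of block 1, same memory
assert torch.equal(kv_view[1, 1], physical[1, 1])  # V of block 1, same memory
```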
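The `MambaSpec` branch carves several per-layer state tensors out of one raw paged buffer with `torch.as_strided`: each state is contiguous within a page, and consecutive blocks sit one full page apart. The sketch below reproduces that carving on invented shapes with a float32-only page (the real shapes, dtypes, and page size come from `kv_cache_spec`), and checks that writing block 1 of each state only touches page 1 of the raw buffer.

```python
import math

import torch

num_blocks = 3
shapes = [(4, 6), (2, 5)]            # invented per-block state shapes
dtype, dtype_size = torch.float32, 4

page_size_bytes = sum(dtype_size * math.prod(s) for s in shapes)
raw_tensor = torch.zeros(num_blocks * page_size_bytes, dtype=torch.uint8)
num_element_per_page = page_size_bytes // dtype_size

state_tensors = []
storage_offset_bytes = 0
for shape in shapes:
    target_shape = (num_blocks, *shape)
    # Contiguous strides inside one block, but a full page between blocks.
    stride = torch.empty(target_shape).stride()
    target_stride = (num_element_per_page, *stride[1:])
    state_tensors.append(torch.as_strided(
        raw_tensor.view(dtype),
        size=target_shape,
        stride=target_stride,
        storage_offset=storage_offset_bytes // dtype_size,
    ))
    storage_offset_bytes += stride[0] * dtype_size

state_tensors[0][1].fill_(1.0)
state_tensors[1][1].fill_(2.0)
floats = raw_tensor.view(dtype)
page1 = floats[num_element_per_page:2 * num_element_per_page]
assert page1[:24].eq(1.0).all() and page1[24:].eq(2.0).all()
assert floats[:num_element_per_page].eq(0).all()  # page 0 untouched
```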