From bf8c9c8d56898f510799ad49fbeef63c23299042 Mon Sep 17 00:00:00 2001
From: ZeldaHuang
Date: Wed, 19 Nov 2025 20:32:31 +0800
Subject: [PATCH 1/3] fix

---
 mooncake-wheel/mooncake/mooncake_connector_v1.py | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/mooncake-wheel/mooncake/mooncake_connector_v1.py b/mooncake-wheel/mooncake/mooncake_connector_v1.py
index fd607460a..852122be8 100644
--- a/mooncake-wheel/mooncake/mooncake_connector_v1.py
+++ b/mooncake-wheel/mooncake/mooncake_connector_v1.py
@@ -22,7 +22,7 @@
 import torch
 import zmq
 
-from vllm.attention.selector import backend_name_to_enum, get_attn_backend
+from vllm.attention.selector import AttentionBackendEnum, get_attn_backend
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
@@ -30,8 +30,7 @@
     get_tp_group)
 from vllm.forward_context import ForwardContext
 from vllm.logger import init_logger
-from vllm.platforms import _Backend
-from vllm.utils import get_ip, make_zmq_path, make_zmq_socket
+from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
 from vllm.v1.attention.backends.utils import get_kv_cache_layout
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.request import RequestStatus
@@ -114,8 +113,9 @@ class MooncakeConnector(KVConnectorBase_V1):
     def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
         assert vllm_config.kv_transfer_config is not None
         assert vllm_config.kv_transfer_config.engine_id is not None
+        super().__init__(vllm_config, role)
         self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id
-
+        
         if role == KVConnectorRole.SCHEDULER:
             self.connector_scheduler: Optional[MooncakeConnectorScheduler] = \
                 MooncakeConnectorScheduler(vllm_config, self.engine_id)
@@ -425,12 +425,11 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
             self.model_config.dtype,
             self.cache_config.cache_dtype,
             self.block_size,
-            self.model_config.is_attention_free,
             use_mla=self.use_mla)
         self.backend_name = backend.get_name()
-        attn_backend = backend_name_to_enum(self.backend_name)
-        self._use_flashinfer = attn_backend == _Backend.FLASHINFER_VLLM_V1
-        self._use_pallas_v1 = attn_backend == _Backend.PALLAS_VLLM_V1
+        attn_backend = AttentionBackendEnum[self.backend_name]
+        self._use_flashinfer = attn_backend in [AttentionBackendEnum.FLASHINFER, AttentionBackendEnum.FLASHINFER_MLA]
+        self._use_pallas_v1 = attn_backend == AttentionBackendEnum.PALLAS
         self.kv_cache_layout = get_kv_cache_layout()
         logger.debug("Detected attention backend %s", self.backend_name)
         logger.debug("Detected kv cache layout %s", self.kv_cache_layout)

From 24e164d9086bb76ec7d5fa749f6a9c7826b56f5e Mon Sep 17 00:00:00 2001
From: ZeldaHuang
Date: Wed, 19 Nov 2025 20:38:41 +0800
Subject: [PATCH 2/3] lint

---
 mooncake-wheel/mooncake/mooncake_connector_v1.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mooncake-wheel/mooncake/mooncake_connector_v1.py b/mooncake-wheel/mooncake/mooncake_connector_v1.py
index 852122be8..79e8dc25b 100644
--- a/mooncake-wheel/mooncake/mooncake_connector_v1.py
+++ b/mooncake-wheel/mooncake/mooncake_connector_v1.py
@@ -115,7 +115,7 @@ def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
         assert vllm_config.kv_transfer_config.engine_id is not None
         super().__init__(vllm_config, role)
         self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id
-        
+
         if role == KVConnectorRole.SCHEDULER:
             self.connector_scheduler: Optional[MooncakeConnectorScheduler] = \
                 MooncakeConnectorScheduler(vllm_config, self.engine_id)

From 80baa5543eae763908dbcab36df3ab723a85bd54 Mon Sep 17 00:00:00 2001
From: ZeldaHuang
Date: Mon, 24 Nov 2025 19:03:30 +0800
Subject: [PATCH 3/3] adaption for old version

---
 .../mooncake/mooncake_connector_v1.py         | 29 ++++++++++++++-----
 1 file changed, 22 insertions(+), 7 deletions(-)

diff --git a/mooncake-wheel/mooncake/mooncake_connector_v1.py b/mooncake-wheel/mooncake/mooncake_connector_v1.py
index 79e8dc25b..8dc29b89e 100644
--- a/mooncake-wheel/mooncake/mooncake_connector_v1.py
+++ b/mooncake-wheel/mooncake/mooncake_connector_v1.py
@@ -9,6 +9,7 @@
 import contextlib
 import threading
 import time
+import importlib.metadata
 from collections import defaultdict
 from collections.abc import Iterator
 from concurrent.futures import ThreadPoolExecutor
@@ -22,7 +23,7 @@
 import torch
 import zmq
 
-from vllm.attention.selector import AttentionBackendEnum, get_attn_backend
+from vllm.attention.selector import get_attn_backend
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
@@ -30,7 +31,10 @@
     get_tp_group)
 from vllm.forward_context import ForwardContext
 from vllm.logger import init_logger
-from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
+try:
+    from vllm.utils import get_ip, make_zmq_path, make_zmq_socket
+except ImportError:
+    from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
 from vllm.v1.attention.backends.utils import get_kv_cache_layout
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.request import RequestStatus
@@ -115,7 +119,7 @@ def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
         assert vllm_config.kv_transfer_config.engine_id is not None
         super().__init__(vllm_config, role)
         self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id
-
+        
         if role == KVConnectorRole.SCHEDULER:
             self.connector_scheduler: Optional[MooncakeConnectorScheduler] = \
                 MooncakeConnectorScheduler(vllm_config, self.engine_id)
@@ -427,9 +431,20 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
             self.block_size,
             use_mla=self.use_mla)
         self.backend_name = backend.get_name()
-        attn_backend = AttentionBackendEnum[self.backend_name]
-        self._use_flashinfer = attn_backend in [AttentionBackendEnum.FLASHINFER, AttentionBackendEnum.FLASHINFER_MLA]
-        self._use_pallas_v1 = attn_backend == AttentionBackendEnum.PALLAS
+        vllm_version = importlib.metadata.version("vllm")
+        if vllm_version.startswith("0.11.0"):
+            from vllm.attention.selector import backend_name_to_enum
+            from vllm.platforms import _Backend
+            attn_backend = backend_name_to_enum(self.backend_name)
+            self._use_flashinfer = attn_backend == _Backend.FLASHINFER_VLLM_V1
+            self._use_pallas_v1 = attn_backend == _Backend.PALLAS_VLLM_V1
+        elif vllm_version.startswith("0.11.1") or vllm_version.startswith("0.11.2"):
+            from vllm.attention.selector import AttentionBackendEnum
+            attn_backend = AttentionBackendEnum[self.backend_name]
+            self._use_flashinfer = attn_backend in [AttentionBackendEnum.FLASHINFER, AttentionBackendEnum.FLASHINFER_MLA]
+            self._use_pallas_v1 = attn_backend == AttentionBackendEnum.PALLAS
+        else:
+            raise Exception("Unsupported vllm version %s", vllm_version)
         self.kv_cache_layout = get_kv_cache_layout()
         logger.debug("Detected attention backend %s", self.backend_name)
         logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
@@ -758,4 +773,4 @@ def group_concurrent_contiguous(
     src_groups = [g.tolist() for g in src_groups]
     dst_groups = [g.tolist() for g in dst_groups]
 
-    return src_groups, dst_groups
+    return src_groups, dst_groups
\ No newline at end of file
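
Note (sketch, not part of the series): the version gate added in PATCH 3/3 can be read on its own. The helper below is a hypothetical extraction of that logic; it assumes vLLM is installed and that the enum members and module paths shown in the diff above exist in the corresponding releases (0.11.0 versus 0.11.1/0.11.2).

import importlib.metadata


def detect_backend_flags(backend_name: str) -> tuple[bool, bool]:
    # Hypothetical helper mirroring the version gate in PATCH 3/3.
    # Returns (use_flashinfer, use_pallas) for the installed vLLM release.
    vllm_version = importlib.metadata.version("vllm")
    if vllm_version.startswith("0.11.0"):
        # 0.11.0 API: map the backend name string through backend_name_to_enum.
        from vllm.attention.selector import backend_name_to_enum
        from vllm.platforms import _Backend
        attn_backend = backend_name_to_enum(backend_name)
        return (attn_backend == _Backend.FLASHINFER_VLLM_V1,
                attn_backend == _Backend.PALLAS_VLLM_V1)
    if vllm_version.startswith(("0.11.1", "0.11.2")):
        # Newer API: AttentionBackendEnum is indexed directly by backend name.
        from vllm.attention.selector import AttentionBackendEnum
        attn_backend = AttentionBackendEnum[backend_name]
        return (attn_backend in (AttentionBackendEnum.FLASHINFER,
                                 AttentionBackendEnum.FLASHINFER_MLA),
                attn_backend == AttentionBackendEnum.PALLAS)
    raise RuntimeError(f"Unsupported vllm version {vllm_version}")

Factoring the check out this way would also make it straightforward to exercise both API generations in a unit test without constructing a full VllmConfig.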