Support disk radix cache #837

Open · wants to merge 22 commits into main
3 changes: 2 additions & 1 deletion lightllm/server/api_cli.py
@@ -205,6 +205,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument("--disable_dynamic_prompt_cache", action="store_true", help="disable dynamic prompt cache")

parser.add_argument("--chunked_prefill_size", type=int, default=4096, help="chunked prefill size")
parser.add_argument("--use_hi_dynamic_prompt_cache", action="store_true", help="enable hierachy prompt cache")
parser.add_argument("--disable_chunked_prefill", action="store_true", help="whether to disable chunked prefill")
parser.add_argument("--diverse_mode", action="store_true", help="diversity generation mode")
parser.add_argument("--token_healing_mode", action="store_true", help="code model infer mode")
@@ -311,7 +312,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
"--visual_infer_batch_size", type=int, default=1, help="number of images to process in each inference batch"
)
parser.add_argument(
"--visual_gpu_ids", nargs="+", type=int, default=None, help="List of GPU IDs to use, e.g., 0 1 2"
"--visual_gpu_ids", nargs="+", type=int, default=[0, 1, 2, 3, 4, 5, 6, 7], help="List of GPU IDs to use, e.g., 0 1 2"
)
parser.add_argument("--visual_tp", type=int, default=1, help="number of tensort parallel instances for ViT")
parser.add_argument("--visual_dp", type=int, default=1, help="number of data parallel instances for ViT")
4 changes: 4 additions & 0 deletions lightllm/server/api_start.py
@@ -152,6 +152,10 @@ def normal_or_p_d_start(args):
args.batch_max_tokens >= args.chunked_prefill_size
), "chunked prefill mode, batch_max_tokens must >= chunked_prefill_size"

# if use_hi_dynamic_prompt_cache, then use_dynamic_prompt_cache must be True
if args.use_hi_dynamic_prompt_cache:
assert not args.disable_dynamic_prompt_cache, "use_hi_dynamic_prompt_cache must be used with use_dynamic_prompt_cache"

# help to manage data stored on Ceph
if "s3://" in args.model_dir:
from lightllm.utils.petrel_helper import s3_model_prepare
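For reference, a minimal sketch of how the new flag and the check added in normal_or_p_d_start() fit together. The make_mini_parser helper below is hypothetical and only mirrors the two arguments relevant to this PR, not the full lightllm CLI:

import argparse

def make_mini_parser():
    # hypothetical stand-in for make_argument_parser(); only the two relevant flags
    parser = argparse.ArgumentParser()
    parser.add_argument("--disable_dynamic_prompt_cache", action="store_true")
    parser.add_argument("--use_hi_dynamic_prompt_cache", action="store_true")
    return parser

args = make_mini_parser().parse_args(["--use_hi_dynamic_prompt_cache"])

# same invariant as the assert added above: the hierarchical (disk-backed) cache
# only makes sense when the dynamic prompt cache itself is enabled
if args.use_hi_dynamic_prompt_cache:
    assert not args.disable_dynamic_prompt_cache, (
        "use_hi_dynamic_prompt_cache must be used with the dynamic prompt cache enabled"
    )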
1 change: 1 addition & 0 deletions lightllm/server/core/objs/start_args_type.py
@@ -43,6 +43,7 @@ class StartArgs:
router_max_wait_tokens: int = field(default=6)
disable_aggressive_schedule: bool = field(default=False)
disable_dynamic_prompt_cache: bool = field(default=False)
use_hi_dynamic_prompt_cache: bool = field(default=False)
chunked_prefill_size: int = field(default=8192)
disable_chunked_prefill: bool = field(default=False)
diverse_mode: bool = field(default=False)
128 changes: 128 additions & 0 deletions lightllm/server/router/dynamic_prompt/hiradix_cache.py
@@ -0,0 +1,128 @@
import torch
import time
import tempfile
import numpy as np
import torch.distributed as dist
from os.path import join
from .radix_cache import RadixCache, TreeNode, match
from typing import Tuple, Dict, Set, List
from lightllm.common.mem_manager import MemoryManager
from lightllm.utils.log_utils import init_logger
from threading import Lock
from enum import Enum
from .shared_arr import SharedArray
from kvcache.python.jit import PyLocalCacheService

logger = init_logger(__name__)

def wait_until_ready(task, timeout=10.0, check_interval=0.01):
start_time = time.time()
while not task.ready():
time.sleep(check_interval)
if time.time() - start_time > timeout:
logger.error("Current kv cache task not ready in time")
return False
return True

class LocalCacheManager:

def __init__(self, unique_name: str, rank_in_node: int, mem_manager):
tmp_dir = tempfile.mkdtemp(prefix=f"cache_{unique_name}_{rank_in_node}")
self.cache_file = join(tmp_dir, "cache_file")
all_buffers = mem_manager.kv_buffer
all_buffers = all_buffers.view(all_buffers.shape[0], all_buffers.shape[1], -1)

self.py_cache_service = PyLocalCacheService(
file=self.cache_file,
storage_size=128 * (1024 ** 3), # 128GB
num_shard=32,
kvcache_tensor=all_buffers,
num_worker=8
)

def insert(self, tokens, kv_page_indexer, start_pos=0):
t = self.py_cache_service.create(
tokens=tokens,
kv_page_indexer=kv_page_indexer,
mode="w",
start_pos=start_pos)
res = wait_until_ready(t)
if not res:
self.py_cache_service.az5(t)

def read(self, tokens, kv_page_indexer, start_pos=0):
t = self.py_cache_service.create(
tokens=tokens,
kv_page_indexer=kv_page_indexer,
mode="r",
start_pos=start_pos)
res = wait_until_ready(t)
return res

def query(self, tokens):
query_result = self.py_cache_service.query(tokens)
max_len = 0
for result in query_result:
if result:
max_len += 1
else:
break
return max_len * self.block_size

@property
def block_size(self,):
return self.py_cache_service.tokens_per_block

class HiRadixCache(RadixCache):
def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager):
super().__init__(unique_name, total_token_num, rank_in_node, mem_manager)
self.rank_in_node = rank_in_node
self.local_cache_manager = LocalCacheManager(
unique_name,
rank_in_node,
mem_manager,
)
self.is_hi_radix_cache = True
self.disk_cache_match_count = SharedArray(f"{unique_name}_disk_cache_match_count_{rank_in_node}", (1,), dtype=np.int64)
self.disk_cache_match_count.arr[0] = 0
self.total_match_count = SharedArray(f"{unique_name}_total_match_count_{rank_in_node}", (1,), dtype=np.int64)
self.total_match_count.arr[0] = 0
self.disk_cache_match_ratio = SharedArray(f"{unique_name}_disk_cache_match_ratio_{rank_in_node}", (1,), dtype=np.float32)
self.disk_cache_match_ratio.arr[0] = 0.0
logger.info(f"Initializing HiRadixCache {rank_in_node}")

def insert(self, key, value=None):
share_len = super().insert(key, value)
if share_len == 0:
return 0
self.local_cache_manager.insert(key, value)
return share_len

def match_prefix(self, key, update_refs=False):
assert len(key) != 0
self.total_match_count.arr[0] += 1
ans_value_list = []
ans_value = None
tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=False)
if tree_node.node_prefix_total_len != 0:
ans_value = torch.concat(ans_value_list)
max_len = 0
if tree_node.node_prefix_total_len < len(key):
max_len = self.local_cache_manager.query(key)
if max_len > tree_node.node_prefix_total_len:
pull_len = max_len - tree_node.node_prefix_total_len
self.disk_cache_match_count.arr[0] += 1
self.disk_cache_match_ratio.arr[0] = self.disk_cache_match_count.arr[0] / self.total_match_count.arr[0]
self.free_radix_cache_to_get_enough_token(pull_len)
buffers = self.mem_manager.alloc(pull_len)
start_pos = 0
if ans_value is not None:
buffers = torch.concat([ans_value, buffers])
start_pos = (tree_node.node_prefix_total_len - 1) // self.local_cache_manager.block_size * self.local_cache_manager.block_size
logger.debug(f"HiCache current match ratio {self.disk_cache_match_ratio.arr[0]}, pulled cache len {pull_len} from disk")
res = self.local_cache_manager.read(tokens=key[:max_len], kv_page_indexer=buffers, start_pos=start_pos)
if res:
super().insert(key[:max_len], buffers)
else:
self.mem_manager.free(buffers[tree_node.node_prefix_total_len:])
return super().match_prefix(key, update_refs=update_refs)
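To make the two-tier lookup above easier to follow in isolation, here is a heavily simplified, self-contained sketch of the same idea: match in the in-memory radix structure first, then ask a block-granular disk cache, and promote any longer disk hit back into memory. MemoryRadix, DiskCache, and BLOCK are illustrative assumptions for this sketch, not lightllm APIs:

BLOCK = 4  # tokens per disk block, analogous to tokens_per_block above

class MemoryRadix:
    """Toy stand-in for the in-GPU radix tree."""
    def __init__(self):
        self.store = {}  # prefix tuple -> cached value

    def longest_prefix(self, key):
        for length in range(len(key), 0, -1):
            if tuple(key[:length]) in self.store:
                return length
        return 0

    def insert(self, key, value):
        self.store[tuple(key)] = value

class DiskCache:
    """Toy stand-in for LocalCacheManager: hits are counted in whole blocks."""
    def __init__(self):
        self.blocks = set()

    def insert(self, key):
        for b in range(len(key) // BLOCK):
            self.blocks.add(tuple(key[: (b + 1) * BLOCK]))

    def query(self, key):
        hit = 0
        while (hit + 1) * BLOCK <= len(key) and tuple(key[: (hit + 1) * BLOCK]) in self.blocks:
            hit += 1
        return hit * BLOCK

def match_prefix(key, mem, disk):
    mem_len = mem.longest_prefix(key)
    disk_len = disk.query(key) if mem_len < len(key) else 0
    if disk_len > mem_len:
        # pull the extra blocks from disk and promote them into the radix structure
        mem.insert(key[:disk_len], key[:disk_len])
        return disk_len
    return mem_len

mem, disk = MemoryRadix(), DiskCache()
disk.insert(list(range(10)))                      # two full blocks reach the disk tier
print(match_prefix(list(range(10)), mem, disk))   # -> 8, promoted from the disk tier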
2 changes: 2 additions & 0 deletions lightllm/server/router/dynamic_prompt/radix_cache.py
@@ -116,6 +116,8 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager: Memo
)
self.tree_total_tokens_num.arr[0] = 0

self.is_hi_radix_cache = False

def insert(self, key, value=None):
if value is None:
value = key
1 change: 1 addition & 0 deletions lightllm/server/router/manager.py
@@ -164,6 +164,7 @@ async def wait_to_model_ready(self):
"return_all_prompt_logprobs": self.args.return_all_prompt_logprobs,
"use_reward_model": self.args.use_reward_model,
"disable_dynamic_prompt_cache": self.args.disable_dynamic_prompt_cache,
"use_hi_dynamic_prompt_cache": self.args.use_hi_dynamic_prompt_cache,
"data_type": self.args.data_type,
"eos_id": self.eos_id,
"diverse_mode": self.args.diverse_mode,
1 change: 1 addition & 0 deletions lightllm/server/router/model_infer/infer_batch.py
@@ -109,6 +109,7 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq", is_group_finis
self.radix_cache.dec_node_ref_counter(req.shared_kv_node)
req.shared_kv_node = None


def _save_promptcache_kvbuffer(self):
"""
save prompt cache kv buffer
11 changes: 10 additions & 1 deletion lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -12,6 +12,7 @@
from lightllm.utils.log_utils import init_logger
from lightllm.models import get_model
from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache
from lightllm.server.router.dynamic_prompt.hiradix_cache import HiRadixCache
from lightllm.server.router.model_infer.infer_batch import InferReq, InferSamplingParams
from lightllm.server.router.token_load import TokenLoad
from lightllm.common.basemodel.infer_lock import g_infer_state_lock, InferStateLock
@@ -51,6 +52,7 @@ def init_model(self, kvargs):
self.chunked_prefill_size = kvargs.get("chunked_prefill_size", None)
self.return_all_prompt_logprobs = kvargs.get("return_all_prompt_logprobs", False)
self.use_dynamic_prompt_cache = not kvargs.get("disable_dynamic_prompt_cache", False)
self.use_hi_dynamic_prompt_cache = kvargs.get("use_hi_dynamic_prompt_cache", False)
self.eos_id: List[int] = kvargs.get("eos_id", [2])
self.disable_cudagraph = kvargs.get("disable_cudagraph", False)

@@ -115,7 +117,14 @@ def init_model(self, kvargs):
self.model, self.is_multimodal = get_model(model_cfg, model_kvargs)
set_random_seed(2147483647)
self.radix_cache = (
RadixCache(
HiRadixCache(
get_unique_server_name(),
self.model.mem_manager.size,
self.rank_in_node,
mem_manager=self.model.mem_manager
)
if self.use_dynamic_prompt_cache and self.use_hi_dynamic_prompt_cache
else RadixCache(
get_unique_server_name(),
self.model.mem_manager.size,
self.rank_in_node,
155 changes: 155 additions & 0 deletions test/server/test_hicache.py
@@ -0,0 +1,155 @@
# test_hicache.py
import torch
import time
import random
from threading import Thread, Event
from queue import Queue
from lightllm.server.router.dynamic_prompt.cache_controller import (
HiCacheController,
CacheNode,
BLOCK_SIZE,
HiHostService,
HiHostTask,
)


class MockMemoryManager:
"""模拟内存管理器,仅返回连续的索引值"""

def __init__(self):
self.current_idx = 0
self.kvcache_store = {}

def alloc(self, size):
indices = list(range(self.current_idx, self.current_idx + size))
self.current_idx += size
self.store(indices, torch.tensor([[random.randint(0, 0xFFFF) for __ in range(512)] for _ in range(size)]))
return indices

def load_index_kv_buffer(self, index, load_tensor_dict):
self.kvcache_store[index] = load_tensor_dict["kv_buffer"]

def get_index_kv_buffer(self, index):
return {"kv_buffer": self.kvcache_store[index]}

def to_kvcache(self, indices):
assert all(
[idx in self.kvcache_store for idx in indices]
), f"Not all of {indices} are not found in kvcache_store"
return torch.tensor([self.kvcache_store[idx].tolist() for idx in indices])

def store(self, indices, value):
print(f"[TEST:MemManager] Storing {value.shape} at {indices}")
for idx, value_dim in zip(indices, range(value.shape[0])):
self.kvcache_store[idx] = value[value_dim]
print(f"[TEST:MemManager] Stored {value[value_dim].shape} at {idx}")
return indices

def free(self, indices):
print(f"[TEST:MemManager] Freeing {indices}")
for idx in indices:
del self.kvcache_store[idx]


def setup():
mem_manager = MockMemoryManager()
service = HiHostService()
hicache = HiCacheController(mem_manager)
hicache.service = service  # inject the mock service

indices = mem_manager.alloc(5)
print(mem_manager.to_kvcache(indices))

# pre-compute the size of a single token's KV cache
dummy_indices = mem_manager.alloc(1)
kvcache = mem_manager.to_kvcache(dummy_indices[:1])
token_size = kvcache.nelement() * kvcache.element_size()
print(f"[TEST] Single token KV cache size: {token_size} bytes, Block size: {BLOCK_SIZE}")

return mem_manager, service, hicache, token_size


def test_basic_write_read(mem_manager, hicache, token_size):
# compute how many tokens fit in one block
tokens_per_block = BLOCK_SIZE // token_size
print(f"[TEST] Each block can hold {tokens_per_block} tokens")

# generate test data that exactly fills one block
token_ids = list(range(tokens_per_block))
indices = mem_manager.alloc(len(token_ids))
kvcache = mem_manager.to_kvcache(indices)
print(f"[TEST] Generated KV cache with shape: {kvcache.shape}, type: {kvcache.dtype}")

# write to the cache
hicache.write(torch.tensor(token_ids), torch.tensor(indices))
time.sleep(2)

# wait for the tasks to finish
hicache.service.wait_till_all_finished()

mem_manager.free(indices)

# read back and verify
result = hicache.read(torch.tensor(token_ids))
result = mem_manager.to_kvcache(result.tolist())
assert result.eq(kvcache).all(), f"Retrieved kvcache: {result}, Expected kvcache: {kvcache}"
print("[TEST] Basic test passed. Retrieved kvcache\n\n")


def test_node_splitting(mem_manager, hicache, token_size):
tokens_per_block = BLOCK_SIZE // token_size
# generate data spanning more than one block
token_ids = list(range(12, 12 + tokens_per_block * 3 + 1))
indices = mem_manager.alloc(len(token_ids))
kvcache = mem_manager.to_kvcache(indices)

hicache.write(torch.tensor(token_ids), torch.tensor(indices))
time.sleep(2)
hicache.service.wait_till_all_finished()

# the root node should now have children
root = hicache.root
assert len(root.children) > 0
print(f"\nRoot node has {len(root.children)} children")

# read back the full sequence
result = hicache.read(torch.tensor(token_ids))
result = mem_manager.to_kvcache(result.tolist())
assert result.eq(kvcache).all(), f"Retrieved kvcache: {result}, Expected kvcache: {kvcache}"
print(f"[TEST] Node splitting test passed. Retrieved kvcache: {result.shape}\n\n")


def test_partial_read(mem_manager, hicache):
token_ids = [97, 98, 99, 100, 101, 102]
indices = mem_manager.alloc(len(token_ids))
kvcache = mem_manager.to_kvcache(indices)
hicache.write(torch.tensor(token_ids), torch.tensor(indices))
time.sleep(2)
hicache.service.wait_till_all_finished()

# query a partial prefix that exists
result = hicache.read(torch.tensor([97, 98, 99]))
result = mem_manager.to_kvcache(result.tolist())
assert result.eq(kvcache[:3]).all()
print("[TEST] Partial read passed")

# query a prefix that does not exist
result = hicache.read(torch.tensor([97, 98, 100]))
assert len(result) == 2
result = mem_manager.to_kvcache(result.tolist())
assert result.eq(kvcache[:2]).all()
print(f"[TEST] Non-existent prefix returned: {result.tolist()}")


def main():
mem_manager, service, hicache, token_size = setup()
try:
test_basic_write_read(mem_manager, hicache, token_size)
test_node_splitting(mem_manager, hicache, token_size)
test_partial_read(mem_manager, hicache)
finally:
service.shutdown()


if __name__ == "__main__":
main()
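Assuming the cache_controller module it imports (HiCacheController, HiHostService, BLOCK_SIZE, ...) is available on the Python path, the suite appears intended to be run directly as a script, e.g. python test/server/test_hicache.py; each test prints its progress and asserts that the KV cache read back through the hierarchical cache matches what was written.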