
Commit 263fb6e

format3
1 parent df90384 commit 263fb6e

2 files changed: 40 additions, 25 deletions

lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ def decode(self):
             prefill_reqs, is_chuncked_mode=False, is_multimodal=self.is_multimodal
         )
         logits = self.model.forward(**kwargs)
-
+
         self.store_hicache_after_prefill(run_reqs)
 
         self._overlap_req_init_and_filter(

test/server/test_hicache.py

Lines changed: 39 additions & 24 deletions
@@ -4,37 +4,47 @@
 import random
 from threading import Thread, Event
 from queue import Queue
-from lightllm.server.router.dynamic_prompt.cache_controller import HiCacheController, CacheNode, BLOCK_SIZE, HiHostService, HiHostTask
+from lightllm.server.router.dynamic_prompt.cache_controller import (
+    HiCacheController,
+    CacheNode,
+    BLOCK_SIZE,
+    HiHostService,
+    HiHostTask,
+)
+
 
 class MockMemoryManager:
     """Mock memory manager that simply hands out consecutive index values."""
+
     def __init__(self):
         self.current_idx = 0
         self.kvcache_store = {}
 
     def alloc(self, size):
         indices = list(range(self.current_idx, self.current_idx + size))
         self.current_idx += size
-        self.store(indices, torch.tensor([[random.randint(0, 0xffff) for __ in range(512)] for _ in range(size)]))
+        self.store(indices, torch.tensor([[random.randint(0, 0xFFFF) for __ in range(512)] for _ in range(size)]))
         return indices
-
+
     def load_index_kv_buffer(self, index, load_tensor_dict):
         self.kvcache_store[index] = load_tensor_dict["kv_buffer"]
-
+
     def get_index_kv_buffer(self, index):
         return {"kv_buffer": self.kvcache_store[index]}
-
+
     def to_kvcache(self, indices):
-        assert all([idx in self.kvcache_store for idx in indices]), f"Not all of {indices} are not found in kvcache_store"
+        assert all(
+            [idx in self.kvcache_store for idx in indices]
+        ), f"Not all of {indices} are not found in kvcache_store"
         return torch.tensor([self.kvcache_store[idx].tolist() for idx in indices])
-
+
     def store(self, indices, value):
         print(f"[TEST:MemManager] Storing {value.shape} at {indices}")
         for idx, value_dim in zip(indices, range(value.shape[0])):
             self.kvcache_store[idx] = value[value_dim]
             print(f"[TEST:MemManager] Stored {value[value_dim].shape} at {idx}")
         return indices
-
+
     def free(self, indices):
         print(f"[TEST:MemManager] Freeing {indices}")
         for idx in indices:
@@ -46,87 +56,91 @@ def setup():
     service = HiHostService()
     hicache = HiCacheController(mem_manager)
     hicache.service = service  # inject the mock service
-
+
     indices = mem_manager.alloc(5)
     print(mem_manager.to_kvcache(indices))
-
+
     # pre-compute the size of a single token
     dummy_indices = mem_manager.alloc(1)
     kvcache = mem_manager.to_kvcache(dummy_indices[:1])
     token_size = kvcache.nelement() * kvcache.element_size()
     print(f"[TEST] Single token KV cache size: {token_size} bytes, Block size: {BLOCK_SIZE}")
-
+
     return mem_manager, service, hicache, token_size
 
+
 def test_basic_write_read(mem_manager, hicache, token_size):
     # compute how many tokens fit in each block
     tokens_per_block = BLOCK_SIZE // token_size
     print(f"[TEST] Each block can hold {tokens_per_block} tokens")
-
+
     # generate test data that exactly fills one block
     token_ids = list(range(tokens_per_block))
     indices = mem_manager.alloc(len(token_ids))
     kvcache = mem_manager.to_kvcache(indices)
     print(f"[TEST] Generated KV cache with shape: {kvcache.shape}, type: {kvcache.dtype}")
-
+
     # write to the cache
     hicache.write(torch.tensor(token_ids), torch.tensor(indices))
     time.sleep(2)
-
+
     # wait for all tasks to finish
     hicache.service.wait_till_all_finished()
-
+
     mem_manager.free(indices)
-
+
     # read back and verify
     result = hicache.read(torch.tensor(token_ids))
     result = mem_manager.to_kvcache(result.tolist())
     assert result.eq(kvcache).all(), f"Retrieved kvcache: {result}, Expected kvcache: {kvcache}"
-    print(f"[TEST] Basic test passed. Retrieved kvcache\n\n")
+    print("[TEST] Basic test passed. Retrieved kvcache\n\n")
+
 
 def test_node_splitting(mem_manager, hicache, token_size):
     tokens_per_block = BLOCK_SIZE // token_size
     # generate data that spans more than one block
     token_ids = list(range(12, 12 + tokens_per_block * 3 + 1))
     indices = mem_manager.alloc(len(token_ids))
     kvcache = mem_manager.to_kvcache(indices)
-
+
     hicache.write(torch.tensor(token_ids), torch.tensor(indices))
     time.sleep(2)
     hicache.service.wait_till_all_finished()
-
+
     # the root node should now have children
     root = hicache.root
     assert len(root.children) > 0
     print(f"\nRoot node has {len(root.children)} children")
-
+
     # read the full sequence back
     result = hicache.read(torch.tensor(token_ids))
     result = mem_manager.to_kvcache(result.tolist())
     assert result.eq(kvcache).all(), f"Retrieved kvcache: {result}, Expected kvcache: {kvcache}"
     print(f"[TEST] Node splitting test passed. Retrieved kvcache: {result.shape}\n\n")
 
+
 def test_partial_read(mem_manager, hicache):
     token_ids = [97, 98, 99, 100, 101, 102]
     indices = mem_manager.alloc(len(token_ids))
     kvcache = mem_manager.to_kvcache(indices)
     hicache.write(torch.tensor(token_ids), torch.tensor(indices))
     time.sleep(2)
     hicache.service.wait_till_all_finished()
-
+
     # query a prefix that exists
     result = hicache.read(torch.tensor([97, 98, 99]))
     result = mem_manager.to_kvcache(result.tolist())
     assert result.eq(kvcache[:3]).all()
-    print(f"[TEST] Partial read passed")
-
+    print("[TEST] Partial read passed")
+
     # query a prefix that does not exist
     result = hicache.read(torch.tensor([97, 98, 100]))
     assert len(result) == 2
     result = mem_manager.to_kvcache(result.tolist())
     assert result.eq(kvcache[:2]).all()
     print(f"[TEST] Non-existent prefix returned: {result.tolist()}")
 
+
 def main():
     mem_manager, service, hicache, token_size = setup()
     try:
@@ -136,5 +150,6 @@ def main():
     finally:
         service.shutdown()
 
+
 if __name__ == "__main__":
-    main()
+    main()
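
For context on how this test file is exercised: it defines its own main() entry point, so it can be run directly as a script once lightllm and torch are importable. Below is a minimal usage sketch (not part of the commit) of the round trip the tests assert against MockMemoryManager; the import path is an assumption based on this repository's test layout.

# Hypothetical sketch, not part of the commit. Assumes the repository root is on
# sys.path so that test/server/test_hicache.py is importable as a module.
import torch

from test.server.test_hicache import MockMemoryManager

mm = MockMemoryManager()
indices = mm.alloc(3)        # three consecutive indices, each backed by a random 512-wide row
kv = mm.to_kvcache(indices)  # stacks the stored rows back into a (3, 512) tensor
assert kv.shape == (3, 512)
assert torch.equal(mm.to_kvcache(indices[:2]), kv[:2])  # a prefix slice matches what was stored
mm.free(indices)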
