4
4
import random
5
5
from threading import Thread , Event
6
6
from queue import Queue
7
- from lightllm .server .router .dynamic_prompt .cache_controller import HiCacheController , CacheNode , BLOCK_SIZE , HiHostService , HiHostTask
7
+ from lightllm .server .router .dynamic_prompt .cache_controller import (
8
+ HiCacheController ,
9
+ CacheNode ,
10
+ BLOCK_SIZE ,
11
+ HiHostService ,
12
+ HiHostTask ,
13
+ )
14
+
8
15
9
16
class MockMemoryManager :
10
17
"""模拟内存管理器,仅返回连续的索引值"""
18
+
11
19
def __init__ (self ):
12
20
self .current_idx = 0
13
21
self .kvcache_store = {}
14
22
15
23
def alloc(self, size):
    """Hand out the next ``size`` consecutive indices and fill them with random data.

    Each new index gets a row of 512 random integers in ``[0, 0xFFFF]``,
    recorded through ``self.store``. Returns the list of allocated indices.
    """
    first = self.current_idx
    end = first + size
    indices = list(range(first, end))
    self.current_idx = end

    # One random 512-wide row per allocated slot.
    rows = []
    for _ in range(size):
        rows.append([random.randint(0, 0xFFFF) for __ in range(512)])

    self.store(indices, torch.tensor(rows))
    return indices
20
-
28
+
21
29
def load_index_kv_buffer(self, index, load_tensor_dict):
    """Save the ``"kv_buffer"`` entry of ``load_tensor_dict`` under ``index``."""
    buffer = load_tensor_dict["kv_buffer"]
    self.kvcache_store[index] = buffer
23
-
31
+
24
32
def get_index_kv_buffer(self, index):
    """Return the value stored at ``index`` wrapped as ``{"kv_buffer": value}``."""
    buffer = self.kvcache_store[index]
    return dict(kv_buffer=buffer)
26
-
34
+
27
35
def to_kvcache(self, indices):
    """Build one tensor from the rows stored at ``indices``, in the given order.

    Raises:
        AssertionError: if any requested index has no entry in
            ``self.kvcache_store``; the message lists the missing indices.
    """
    # Collect the absent indices up front so the failure message names them
    # (the previous wording was a double negative and said the opposite).
    missing = [idx for idx in indices if idx not in self.kvcache_store]
    assert not missing, f"Indices {missing} (requested: {indices}) are not found in kvcache_store"
    return torch.tensor([self.kvcache_store[idx].tolist() for idx in indices])
30
-
40
+
31
41
def store(self, indices, value):
    """Record row ``i`` of ``value`` under ``indices[i]`` and return ``indices``.

    Pairing stops at whichever of ``indices`` / ``value.shape[0]`` is shorter.
    Progress is printed for the whole call and for each stored row.
    """
    print(f"[TEST:MemManager] Storing {value.shape} at {indices} ")
    row_count = value.shape[0]
    for row_no, slot in enumerate(indices):
        if row_no >= row_count:
            # No more rows to pair with the remaining indices.
            break
        row = value[row_no]
        self.kvcache_store[slot] = row
        print(f"[TEST:MemManager] Stored {row.shape} at {slot} ")
    return indices
37
-
47
+
38
48
def free (self , indices ):
39
49
print (f"[TEST:MemManager] Freeing { indices } " )
40
50
for idx in indices :
@@ -46,87 +56,91 @@ def setup():
46
56
service = HiHostService ()
47
57
hicache = HiCacheController (mem_manager )
48
58
hicache .service = service # 注入模拟服务
49
-
59
+
50
60
indices = mem_manager .alloc (5 )
51
61
print (mem_manager .to_kvcache (indices ))
52
-
62
+
53
63
# 预先计算单token大小
54
64
dummy_indices = mem_manager .alloc (1 )
55
65
kvcache = mem_manager .to_kvcache (dummy_indices [:1 ])
56
66
token_size = kvcache .nelement () * kvcache .element_size ()
57
67
print (f"[TEST] Single token KV cache size: { token_size } bytes, Block size: { BLOCK_SIZE } " )
58
-
68
+
59
69
return mem_manager , service , hicache , token_size
60
70
71
+
61
72
def test_basic_write_read(mem_manager, hicache, token_size):
    """Write exactly one block's worth of tokens, free them, and read them back.

    The KV snapshot taken right after allocation is the expected value; the
    data returned by ``hicache.read`` must match it element-wise.
    """
    # How many tokens fit into a single block.
    tokens_per_block = BLOCK_SIZE // token_size
    print(f"[TEST] Each block can hold {tokens_per_block} tokens")

    # Test data that fills one block exactly.
    block_tokens = [*range(tokens_per_block)]
    slots = mem_manager.alloc(len(block_tokens))
    expected = mem_manager.to_kvcache(slots)
    print(f"[TEST] Generated KV cache with shape: {expected.shape} , type: {expected.dtype} ")

    # Write into the cache.
    token_tensor = torch.tensor(block_tokens)
    slot_tensor = torch.tensor(slots)
    hicache.write(token_tensor, slot_tensor)
    time.sleep(2)

    # Wait for the outstanding tasks to complete.
    hicache.service.wait_till_all_finished()

    mem_manager.free(slots)

    # Read back and verify.
    fetched_slots = hicache.read(torch.tensor(block_tokens))
    actual = mem_manager.to_kvcache(fetched_slots.tolist())
    assert actual.eq(expected).all(), f"Retrieved kvcache: {actual} , Expected kvcache: {expected} "
    print("[TEST] Basic test passed. Retrieved kvcache\n \n ")
97
+
86
98
87
99
def test_node_splitting(mem_manager, hicache, token_size):
    """Write a sequence longer than three blocks and check it reads back intact.

    Also asserts that ``hicache.root`` has at least one child after the write.
    """
    tokens_per_block = BLOCK_SIZE // token_size

    # Data that spans more than one block: three full blocks plus one token.
    first_token = 12
    token_count = tokens_per_block * 3 + 1
    long_tokens = list(range(first_token, first_token + token_count))
    slots = mem_manager.alloc(len(long_tokens))
    expected = mem_manager.to_kvcache(slots)

    token_tensor = torch.tensor(long_tokens)
    slot_tensor = torch.tensor(slots)
    hicache.write(token_tensor, slot_tensor)
    time.sleep(2)
    hicache.service.wait_till_all_finished()

    # The root node is expected to have children now.
    root = hicache.root
    assert len(root.children) > 0
    print(f"\n Root node has {len(root.children)} children")

    # Read the complete sequence back.
    fetched_slots = hicache.read(torch.tensor(long_tokens))
    actual = mem_manager.to_kvcache(fetched_slots.tolist())
    assert actual.eq(expected).all(), f"Retrieved kvcache: {actual} , Expected kvcache: {expected} "
    print(f"[TEST] Node splitting test passed. Retrieved kvcache: {actual.shape} \n \n ")
108
120
121
+
109
122
def test_partial_read(mem_manager, hicache):
    """Check prefix lookups against a six-token sequence written to the cache.

    A stored prefix must return its matching KV rows; a query that diverges
    at the third token must return only the two matching leading entries.
    """
    token_ids = list(range(97, 103))  # 97..102 inclusive
    slots = mem_manager.alloc(len(token_ids))
    expected = mem_manager.to_kvcache(slots)

    hicache.write(torch.tensor(token_ids), torch.tensor(slots))
    time.sleep(2)
    hicache.service.wait_till_all_finished()

    # Query a prefix that exists.
    known_prefix = torch.tensor([97, 98, 99])
    prefix_slots = hicache.read(known_prefix)
    prefix_kv = mem_manager.to_kvcache(prefix_slots.tolist())
    assert prefix_kv.eq(expected[:3]).all()
    print("[TEST] Partial read passed")

    # Query a prefix whose last token does not exist in the cache.
    diverging_query = torch.tensor([97, 98, 100])
    matched_slots = hicache.read(diverging_query)
    assert len(matched_slots) == 2
    matched_kv = mem_manager.to_kvcache(matched_slots.tolist())
    assert matched_kv.eq(expected[:2]).all()
    print(f"[TEST] Non-existent prefix returned: {matched_kv.tolist()} ")
129
142
143
+
130
144
def main ():
131
145
mem_manager , service , hicache , token_size = setup ()
132
146
try :
@@ -136,5 +150,6 @@ def main():
136
150
finally :
137
151
service .shutdown ()
138
152
153
+
139
154
if __name__ == "__main__" :
140
- main ()
155
+ main ()
0 commit comments