 import torch
+import torch.distributed as dist
 from .radix_cache import RadixCache, TreeNode, match
 from typing import Tuple, Dict, Set, List
 from lightllm.common.mem_manager import MemoryManager
@@ -23,12 +24,16 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager, max_
             self.is_hi_radix_cache = True
             all_buffers = self.mem_manager.kv_buffer
             all_buffers = all_buffers.view(all_buffers.shape[0], all_buffers.shape[1], -1)
-            self.py_cache_service = PyLocalCacheService(
-                file="cache/cache_file",
-                storage_size=128 * (1024 ** 3),
-                num_shard=32,
-                kvcache_tensor=all_buffers,
-                num_worker=32,
+            self.py_cache_service = (
+                PyLocalCacheService(
+                    file="cache/cache_file",
+                    storage_size=128 * (1024 ** 3),
+                    num_shard=32,
+                    kvcache_tensor=all_buffers,
+                    num_worker=32,
+                )
+                if self.do_store
+                else None
             )
             self.working_tasks = {}
         except Exception as e:
@@ -48,7 +53,7 @@ def insert_disk(self, req_id, key, value):
         logger.info(f"Created store task for req {req_id}.")

     def abort_req_store_task(self, req_id):
-        if not self.do_store:
+        if not self.do_store or req_id not in self.working_tasks:
             return
         if self.working_tasks[req_id].ready():
             logger.info(f"Calling abort for req {req_id}, but is finished.")
@@ -126,48 +131,63 @@ def _reinsert_helper(self, node: TreeNode, key, value, ans_value_list: list, upd
             self.evict_tree_set.add(node)

     def match_prefix(self, key, update_refs=False):
-        st_time = time.time()
         assert len(key) != 0
         ans_value_list = []
-        tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=False)
-        # add a parameter if get long enough (>50%)
-        first_query_time = time.time()
-        logger.info(f"HiCache of [{self.rank_in_node}]: No.1 First GPU query took {first_query_time - st_time}")
-        max_len = self._query_hi_cache(key)  # x64
-        hi_cache_query_time = time.time()
-        logger.info(f"HiCache of [{self.rank_in_node}]: No.2 Disk query took {hi_cache_query_time - first_query_time}")
-        logger.info(f"Matched {sum(len(s) for s in ans_value_list)} from gpu and {max_len} from disk.")
+        pull_hi_cache_tensor = torch.tensor([0], dtype=torch.int64).cuda(self.rank_in_node)
+        if self.do_store:
+            # st_time = time.time()
+            tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=False)
+            # add a parameter if get long enough (>50%)
+            # first_query_time = time.time()
+            # logger.info(f"HiCache of [{self.rank_in_node}]: No.1 First GPU query took {first_query_time - st_time}s")
+            max_len = self._query_hi_cache(key)  # x64
+            # hi_cache_q_time = time.time()
+            # logger.info(f"HiCache of [{self.rank_in_node}]: No.2 Disk query {hi_cache_q_time - first_query_time}s")
+            logger.info(f"Matched {sum(len(s) for s in ans_value_list)} from gpu and {max_len} from disk.")
+            pull_hi_cache_tensor[0] = max_len if (max_len > sum(len(s) for s in ans_value_list)) else 0
+            # hi_cache_q_time = time.time()
+        dist.broadcast(pull_hi_cache_tensor, src=0)
+        # logger.info(f"After broadcast on rank {self.rank_in_node}, tensor={pull_hi_cache_tensor}")
         pull_hi_cache = False
-        if max_len > sum(len(s) for s in ans_value_list):
+        # logger.info(f"Rank {self.rank_in_node}, {pull_hi_cache=} {pull_hi_cache_tensor=}")
+
+        if pull_hi_cache_tensor[0] == 0 and not self.do_store:
+            tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=False)
+        elif pull_hi_cache_tensor[0] > 0:
             pull_hi_cache = True
+            max_len = pull_hi_cache_tensor[0]
             try:
                 self.free_radix_cache_to_get_enough_token(max_len)
             except:
-                if update_refs:
-                    tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs)
+                logger.info(f"Unable to free on rank {self.rank_in_node}")
+                pull_hi_cache_tensor[0] = 0
                 pull_hi_cache = False
+                ans_value_list = []
+                tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs)
         if pull_hi_cache:
             buffers = self.mem_manager.alloc(max_len)
-            before_pull_time = time.time()
-            logger.info(
-                f"HiCache of [{self.rank_in_node}]: No.2.5 Before pull took {before_pull_time - hi_cache_query_time}"
-            )
-            read_task = self.py_cache_service.create(tokens=key[:max_len], kv_page_indexer=buffers, mode="r")
-            while not read_task.ready():
-                time.sleep(0.1)
-            hicache_pull_time = time.time()
-            logger.info(f"HiCache of [{self.rank_in_node}]: No.3 Disk pull took {hicache_pull_time - before_pull_time}")
+            # before_pull_time = time.time()
+            # logger.info(
+            #     f"HiCache of [{self.rank_in_node}]: No.2.5 Before pull took {before_pull_time - hi_cache_q_time}"
+            # )
+            if self.do_store:
+                read_task = self.py_cache_service.create(tokens=key[:max_len], kv_page_indexer=buffers, mode="r")
+                while not read_task.ready():
+                    time.sleep(0.05)
+            dist.broadcast(self.mem_manager.get_index_kv_buffer(buffers)["kv_buffer"], src=0)
+            # hicache_pull_time = time.time()
+            # logger.info(f"HiCache of [{self.rank_in_node}]: No.3 Disk pull {hicache_pull_time - before_pull_time}s")
             logger.info(f"HiCache pulled one cache with len = {max_len}")
             # maybe try: add a function to only insert middle part of kv cache
             self._insert_helper(self.root_node, key, buffers)
-            insert_time = time.time()
-            logger.info(f"HiCache of [{self.rank_in_node}]: No.4 Reinsert took {insert_time - hicache_pull_time}")
+            # insert_time = time.time()
+            # logger.info(f"HiCache of [{self.rank_in_node}]: No.4 Reinsert took {insert_time - hicache_pull_time}")
             ans_value_list = []
             tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs)
-            logger.info(
-                f"HiCache of [{self.rank_in_node}]: No.5 Re match prefix took {time.time() - insert_time}"
-                + f" matched {sum(len(s) for s in ans_value_list)} tokens"
-            )
+            # logger.info(
+            #     f"HiCache of [{self.rank_in_node}]: No.5 Re match prefix took {time.time() - insert_time}"
+            #     + f" matched {sum(len(s) for s in ans_value_list)} tokens"
+            # )
         if tree_node != self.root_node:
             if len(ans_value_list) != 0:
                 value = torch.concat(ans_value_list)
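# --- Reviewer sketch, not part of the diff ----------------------------------
# Why both dist.broadcast() calls in match_prefix sit outside the
# `if self.do_store:` blocks: broadcast is a collective, so every rank in the
# process group must enter it, including ranks that never touch the disk
# cache. Below is a minimal standalone illustration of the same "rank 0
# decides, everyone receives" pattern; the process group is assumed to be
# already initialized, and the function and tensor names are illustrative.
import torch
import torch.distributed as dist

def broadcast_disk_match_len(do_store: bool, disk_match_len: int, device) -> int:
    # The storing rank (src=0 in the diff) writes its disk-match length into a
    # small tensor; after the collective every rank holds the same value and
    # can decide locally whether a disk pull will follow.
    pull_len = torch.tensor([0], dtype=torch.int64, device=device)
    if do_store:
        pull_len[0] = disk_match_len
    dist.broadcast(pull_len, src=0)
    return int(pull_len.item())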