Skip to content

Commit c54f20f

Browse files
committed
fixed problems
1 parent c1b2aaf commit c54f20f

File tree

1 file changed

+54
-34
lines changed

1 file changed

+54
-34
lines changed

lightllm/server/router/dynamic_prompt/hiradix_cache.py

Lines changed: 54 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import torch
2+
import torch.distributed as dist
23
from .radix_cache import RadixCache, TreeNode, match
34
from typing import Tuple, Dict, Set, List
45
from lightllm.common.mem_manager import MemoryManager
@@ -23,12 +24,16 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager, max_
2324
self.is_hi_radix_cache = True
2425
all_buffers = self.mem_manager.kv_buffer
2526
all_buffers = all_buffers.view(all_buffers.shape[0], all_buffers.shape[1], -1)
26-
self.py_cache_service = PyLocalCacheService(
27-
file="cache/cache_file",
28-
storage_size=128 * (1024 ** 3),
29-
num_shard=32,
30-
kvcache_tensor=all_buffers,
31-
num_worker=32,
27+
self.py_cache_service = (
28+
PyLocalCacheService(
29+
file="cache/cache_file",
30+
storage_size=128 * (1024 ** 3),
31+
num_shard=32,
32+
kvcache_tensor=all_buffers,
33+
num_worker=32,
34+
)
35+
if self.do_store
36+
else None
3237
)
3338
self.working_tasks = {}
3439
except Exception as e:
@@ -48,7 +53,7 @@ def insert_disk(self, req_id, key, value):
4853
logger.info(f"Created store task for req {req_id}.")
4954

5055
def abort_req_store_task(self, req_id):
51-
if not self.do_store:
56+
if not self.do_store or req_id not in self.working_tasks:
5257
return
5358
if self.working_tasks[req_id].ready():
5459
logger.info(f"Calling abort for req {req_id}, but is finished.")
@@ -126,48 +131,63 @@ def _reinsert_helper(self, node: TreeNode, key, value, ans_value_list: list, upd
126131
self.evict_tree_set.add(node)
127132

128133
def match_prefix(self, key, update_refs=False):
129-
st_time = time.time()
130134
assert len(key) != 0
131135
ans_value_list = []
132-
tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=False)
133-
# add a parameter if get long enough (>50%)
134-
first_query_time = time.time()
135-
logger.info(f"HiCache of [{self.rank_in_node}]: No.1 First GPU query took {first_query_time - st_time}")
136-
max_len = self._query_hi_cache(key) # x64
137-
hi_cache_query_time = time.time()
138-
logger.info(f"HiCache of [{self.rank_in_node}]: No.2 Disk query took {hi_cache_query_time - first_query_time}")
139-
logger.info(f"Matched {sum(len(s) for s in ans_value_list)} from gpu and {max_len} from disk.")
136+
pull_hi_cache_tensor = torch.tensor([0], dtype=torch.int64).cuda(self.rank_in_node)
137+
if self.do_store:
138+
# st_time = time.time()
139+
tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=False)
140+
# add a parameter if get long enough (>50%)
141+
# first_query_time = time.time()
142+
# logger.info(f"HiCache of [{self.rank_in_node}]: No.1 First GPU query took {first_query_time - st_time}s")
143+
max_len = self._query_hi_cache(key) # x64
144+
# hi_cache_q_time = time.time()
145+
# logger.info(f"HiCache of [{self.rank_in_node}]: No.2 Disk query {hi_cache_q_time - first_query_time}s")
146+
logger.info(f"Matched {sum(len(s) for s in ans_value_list)} from gpu and {max_len} from disk.")
147+
pull_hi_cache_tensor[0] = max_len if (max_len > sum(len(s) for s in ans_value_list)) else 0
148+
# hi_cache_q_time = time.time()
149+
dist.broadcast(pull_hi_cache_tensor, src=0)
150+
# logger.info(f"After broadcast on rank {self.rank_in_node}, tensor={pull_hi_cache_tensor}")
140151
pull_hi_cache = False
141-
if max_len > sum(len(s) for s in ans_value_list):
152+
# logger.info(f"Rank {self.rank_in_node}, {pull_hi_cache=} {pull_hi_cache_tensor=}")
153+
154+
if pull_hi_cache_tensor[0] == 0 and not self.do_store:
155+
tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=False)
156+
elif pull_hi_cache_tensor[0] > 0:
142157
pull_hi_cache = True
158+
max_len = pull_hi_cache_tensor[0]
143159
try:
144160
self.free_radix_cache_to_get_enough_token(max_len)
145161
except:
146-
if update_refs:
147-
tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs)
162+
logger.info(f"Unable to free on rank {self.rank_in_node}")
163+
pull_hi_cache_tensor[0] = 0
148164
pull_hi_cache = False
165+
ans_value_list = []
166+
tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs)
149167
if pull_hi_cache:
150168
buffers = self.mem_manager.alloc(max_len)
151-
before_pull_time = time.time()
152-
logger.info(
153-
f"HiCache of [{self.rank_in_node}]: No.2.5 Before pull took {before_pull_time - hi_cache_query_time}"
154-
)
155-
read_task = self.py_cache_service.create(tokens=key[:max_len], kv_page_indexer=buffers, mode="r")
156-
while not read_task.ready():
157-
time.sleep(0.1)
158-
hicache_pull_time = time.time()
159-
logger.info(f"HiCache of [{self.rank_in_node}]: No.3 Disk pull took {hicache_pull_time - before_pull_time}")
169+
# before_pull_time = time.time()
170+
# logger.info(
171+
# f"HiCache of [{self.rank_in_node}]: No.2.5 Before pull took {before_pull_time - hi_cache_q_time}"
172+
# )
173+
if self.do_store:
174+
read_task = self.py_cache_service.create(tokens=key[:max_len], kv_page_indexer=buffers, mode="r")
175+
while not read_task.ready():
176+
time.sleep(0.05)
177+
dist.broadcast(self.mem_manager.get_index_kv_buffer(buffers)["kv_buffer"], src=0)
178+
# hicache_pull_time = time.time()
179+
# logger.info(f"HiCache of [{self.rank_in_node}]: No.3 Disk pull {hicache_pull_time - before_pull_time}s")
160180
logger.info(f"HiCache pulled one cache with len = {max_len}")
161181
# maybe try: add a function to only insert middle part of kv cache
162182
self._insert_helper(self.root_node, key, buffers)
163-
insert_time = time.time()
164-
logger.info(f"HiCache of [{self.rank_in_node}]: No.4 Reinsert took {insert_time - hicache_pull_time}")
183+
# insert_time = time.time()
184+
# logger.info(f"HiCache of [{self.rank_in_node}]: No.4 Reinsert took {insert_time - hicache_pull_time}")
165185
ans_value_list = []
166186
tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs)
167-
logger.info(
168-
f"HiCache of [{self.rank_in_node}]: No.5 Re match prefix took {time.time() - insert_time}"
169-
+ f" matched {sum(len(s) for s in ans_value_list)} tokens"
170-
)
187+
# logger.info(
188+
# f"HiCache of [{self.rank_in_node}]: No.5 Re match prefix took {time.time() - insert_time}"
189+
# + f" matched {sum(len(s) for s in ans_value_list)} tokens"
190+
# )
171191
if tree_node != self.root_node:
172192
if len(ans_value_list) != 0:
173193
value = torch.concat(ans_value_list)

0 commit comments

Comments
 (0)