Skip to content

Commit

Permalink
clean cache manager code
Browse files Browse the repository at this point in the history
  • Loading branch information
wnma3mz committed Feb 6, 2025
1 parent 6e86376 commit fa6fcb4
Showing 1 changed file with 6 additions and 9 deletions.
15 changes: 6 additions & 9 deletions tllm/commons/cache_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import copy
from typing import List, Optional
from typing import List

from tllm import ENABLE_PREFILL_CACHE
from tllm.commons.cache import AttentionData, Cache, DecoderCache, RequestsCache, arange_func, array_func
Expand Down Expand Up @@ -35,36 +35,33 @@ def _build_single_cache(self, uuid: str, input_ids: List[int], cache: Cache):
q_len = len(input_ids)
is_decoding = cache.contains(uuid)

# decoding 阶段
if is_decoding:
decoder_cache: DecoderCache = cache.get(uuid)
decoder_cache.set_q_len(q_len)
cache_seq_len = decoder_cache[0].kv_len
position_ids = array_func(cache_seq_len)
return q_len, cache_seq_len + q_len, hit_cache_len, position_ids, decoder_cache

# prefilling 阶段
# In Prefill
if ENABLE_PREFILL_CACHE:
hit_uuid, hit_cache_len = self.request_cache.radix_tree.longest_common_prefix(input_ids)
else:
hit_uuid, hit_cache_len = None, -1

position_ids = arange_func(q_len)
if hit_uuid is not None and cache.get(hit_uuid) is not None:
hid_decoder_cache: DecoderCache = copy.deepcopy(cache.get(hit_uuid))
# 相同输入时,避免过超过 cache 长度
# Do not exceed the cache length for the same input
if q_len <= hit_cache_len:
hit_cache_len = q_len - 2

hid_decoder_cache.truncate(hit_cache_len)
hid_decoder_cache.set_q_len(q_len - hit_cache_len)
decoder_cache = hid_decoder_cache
position_ids = arange_func(q_len)
return q_len, q_len, hit_cache_len, position_ids, decoder_cache
else:
hit_cache_len = -1
decoder_cache = None
position_ids = arange_func(q_len)
return q_len, q_len, hit_cache_len, position_ids, decoder_cache
return q_len, q_len, hit_cache_len, position_ids, decoder_cache

def build_cache(self, seq_input, cache: Cache):
q_len_list, k_len_list = [], []
Expand All @@ -76,8 +73,8 @@ def build_cache(self, seq_input, cache: Cache):

q_len_list.append(q_len)
k_len_list.append(k_len)
position_ids_list.append(position_ids)
hit_cache_len_list.append(hit_cache_len)
position_ids_list.append(position_ids)

self.request_cache.add(uuid, q_len, decoder_cache)
return q_len_list, k_len_list, position_ids_list, hit_cache_len_list
Expand Down

0 comments on commit fa6fcb4

Please sign in to comment.