diff --git a/tllm/commons/cache_manager.py b/tllm/commons/cache_manager.py index 712190d..edbf0a5 100644 --- a/tllm/commons/cache_manager.py +++ b/tllm/commons/cache_manager.py @@ -1,5 +1,5 @@ import copy -from typing import List, Optional +from typing import List from tllm import ENABLE_PREFILL_CACHE from tllm.commons.cache import AttentionData, Cache, DecoderCache, RequestsCache, arange_func, array_func @@ -35,7 +35,6 @@ def _build_single_cache(self, uuid: str, input_ids: List[int], cache: Cache): q_len = len(input_ids) is_decoding = cache.contains(uuid) - # decoding 阶段 if is_decoding: decoder_cache: DecoderCache = cache.get(uuid) decoder_cache.set_q_len(q_len) @@ -43,28 +42,26 @@ def _build_single_cache(self, uuid: str, input_ids: List[int], cache: Cache): position_ids = array_func(cache_seq_len) return q_len, cache_seq_len + q_len, hit_cache_len, position_ids, decoder_cache - # prefilling 阶段 + # In Prefill if ENABLE_PREFILL_CACHE: hit_uuid, hit_cache_len = self.request_cache.radix_tree.longest_common_prefix(input_ids) else: hit_uuid, hit_cache_len = None, -1 + position_ids = arange_func(q_len) if hit_uuid is not None and cache.get(hit_uuid) is not None: hid_decoder_cache: DecoderCache = copy.deepcopy(cache.get(hit_uuid)) - # 相同输入时,避免过超过 cache 长度 + # Do not exceed the cache length for the same input if q_len <= hit_cache_len: hit_cache_len = q_len - 2 hid_decoder_cache.truncate(hit_cache_len) hid_decoder_cache.set_q_len(q_len - hit_cache_len) decoder_cache = hid_decoder_cache - position_ids = arange_func(q_len) - return q_len, q_len, hit_cache_len, position_ids, decoder_cache else: hit_cache_len = -1 decoder_cache = None - position_ids = arange_func(q_len) - return q_len, q_len, hit_cache_len, position_ids, decoder_cache + return q_len, q_len, hit_cache_len, position_ids, decoder_cache def build_cache(self, seq_input, cache: Cache): q_len_list, k_len_list = [], [] @@ -76,8 +73,8 @@ def build_cache(self, seq_input, cache: Cache): q_len_list.append(q_len) k_len_list.append(k_len) - position_ids_list.append(position_ids) hit_cache_len_list.append(hit_cache_len) + position_ids_list.append(position_ids) self.request_cache.add(uuid, q_len, decoder_cache) return q_len_list, k_len_list, position_ids_list, hit_cache_len_list