diff --git a/tllm/commons/cache.py b/tllm/commons/cache.py
index fc3535b..85e5eb6 100644
--- a/tllm/commons/cache.py
+++ b/tllm/commons/cache.py
@@ -134,53 +134,6 @@ def add(self, uuid: str, q_len: int, decoder_cache: Optional[DecoderCache] = Non
             else decoder_cache
         )
 
-    def build(self, seq_input: SeqInput, cache: Cache):
-        q_len_list, k_len_list = [], []
-        position_ids_list = []
-        hit_cache_len_list = []
-
-        for uuid, input_ids in zip(seq_input.uuid_list, seq_input.input_ids_list):
-            hit_cache_len = -1
-            q_len = len(input_ids)
-            # decoding stage
-            if cache.contains(uuid):
-                decoder_cache: DecoderCache = cache.get(uuid)
-                decoder_cache.set_q_len(q_len)
-                cache_seq_len = decoder_cache[0].kv_len
-                position_ids = array_func(cache_seq_len)
-                k_len_list.append(cache_seq_len + q_len)
-            # prefilling stage
-            else:
-                if ENABLE_PREFILL_CACHE:
-                    hit_uuid, hit_cache_len = self.radix_tree.longest_common_prefix(input_ids)
-                else:
-                    # prefix cache disabled
-                    hit_uuid, hit_cache_len = None, -1
-                # hit an earlier kv cache; reuse the historical cache
-                if hit_uuid is not None and cache.get(uuid) is not None:
-                    hid_decoder_cache: DecoderCache = copy.deepcopy(cache.get(uuid))
-                    # for identical inputs, avoid exceeding the cache length
-                    if q_len <= hit_cache_len:
-                        hit_cache_len = q_len - 2
-
-                    hid_decoder_cache.truncate(hit_cache_len)
-                    hid_decoder_cache.set_q_len(q_len - hit_cache_len)
-                    decoder_cache = hid_decoder_cache
-                    position_ids = arange_func(q_len)
-                    k_len_list.append(q_len)
-                # no kv cache hit; create a new cache
-                else:
-                    hit_cache_len = -1
-                    decoder_cache = None
-                    position_ids = arange_func(q_len)
-                    k_len_list.append(q_len)
-            q_len_list.append(q_len)
-            position_ids_list.append(position_ids)
-            hit_cache_len_list.append(hit_cache_len)
-
-            self.add(uuid, q_len, decoder_cache)
-        return q_len_list, k_len_list, position_ids_list, hit_cache_len_list
-
     def get_decoder_cache(self, uuid: str) -> DecoderCache:
         return self.cache_dict[uuid]
 
diff --git a/tllm/commons/cache_manager.py b/tllm/commons/cache_manager.py
index eecc545..712190d 100644
--- a/tllm/commons/cache_manager.py
+++ b/tllm/commons/cache_manager.py
@@ -1,4 +1,8 @@
-from tllm.commons.cache import AttentionData, Cache, RequestsCache
+import copy
+from typing import List, Optional
+
+from tllm import ENABLE_PREFILL_CACHE
+from tllm.commons.cache import AttentionData, Cache, DecoderCache, RequestsCache, arange_func, array_func
 
 
 class CacheManager:
@@ -26,6 +30,58 @@ def update_cache(self, seq_input):
         self.request_cache.clear()
         self.request_cache.insert_cache(seq_input)
 
+    def _build_single_cache(self, uuid: str, input_ids: List[int], cache: Cache):
+        hit_cache_len = -1
+        q_len = len(input_ids)
+        is_decoding = cache.contains(uuid)
+
+        # decoding stage
+        if is_decoding:
+            decoder_cache: DecoderCache = cache.get(uuid)
+            decoder_cache.set_q_len(q_len)
+            cache_seq_len = decoder_cache[0].kv_len
+            position_ids = array_func(cache_seq_len)
+            return q_len, cache_seq_len + q_len, hit_cache_len, position_ids, decoder_cache
+
+        # prefilling stage
+        if ENABLE_PREFILL_CACHE:
+            hit_uuid, hit_cache_len = self.request_cache.radix_tree.longest_common_prefix(input_ids)
+        else:
+            hit_uuid, hit_cache_len = None, -1
+
+        if hit_uuid is not None and cache.get(hit_uuid) is not None:
+            hid_decoder_cache: DecoderCache = copy.deepcopy(cache.get(hit_uuid))
+            # for identical inputs, avoid exceeding the cache length
+            if q_len <= hit_cache_len:
+                hit_cache_len = q_len - 2
+
+            hid_decoder_cache.truncate(hit_cache_len)
+            hid_decoder_cache.set_q_len(q_len - hit_cache_len)
+            decoder_cache = hid_decoder_cache
+            position_ids = arange_func(q_len)
+            return q_len, q_len, hit_cache_len, position_ids, decoder_cache
+        else:
+            hit_cache_len = -1
+            decoder_cache = None
+            position_ids = arange_func(q_len)
+            return q_len, q_len, hit_cache_len, position_ids, decoder_cache
+
+    def build_cache(self, seq_input, cache: Cache):
+        q_len_list, k_len_list = [], []
+        position_ids_list = []
+        hit_cache_len_list = []
+
+        for uuid, input_ids in zip(seq_input.uuid_list, seq_input.input_ids_list):
+            q_len, k_len, hit_cache_len, position_ids, decoder_cache = self._build_single_cache(uuid, input_ids, cache)
+
+            q_len_list.append(q_len)
+            k_len_list.append(k_len)
+            position_ids_list.append(position_ids)
+            hit_cache_len_list.append(hit_cache_len)
+
+            self.request_cache.add(uuid, q_len, decoder_cache)
+        return q_len_list, k_len_list, position_ids_list, hit_cache_len_list
+
     def post_init(self, is_start_pp: bool, is_end_pp: bool):
         self.is_start_pp = is_start_pp
         self.is_end_pp = is_end_pp
diff --git a/tllm/models/mlx/helper.py b/tllm/models/mlx/helper.py
index d50f56a..c1dd357 100644
--- a/tllm/models/mlx/helper.py
+++ b/tllm/models/mlx/helper.py
@@ -47,7 +47,7 @@ def build_mlx_mask(q_len_list: List[int], k_len_list: List[int], hit_cache_len_l
 
 class MLXCacheManager(CacheManager):
     def build_forward_cache(self, hidden_states: mx.array, seq_input: SeqInput) -> mx.array:
-        q_len_list, k_len_list, position_ids_list, hit_cache_len_list = self.request_cache.build(seq_input, self.cache)
+        q_len_list, k_len_list, position_ids_list, hit_cache_len_list = self.build_cache(seq_input, self.cache)
 
         self.hit_cache_flag = any(x != -1 for x in hit_cache_len_list)
 
diff --git a/tllm/models/torch/helper.py b/tllm/models/torch/helper.py
index 7c9fd01..2073179 100644
--- a/tllm/models/torch/helper.py
+++ b/tllm/models/torch/helper.py
@@ -51,8 +51,7 @@ def read_from_safetensors(file_path: str) -> Dict[str, torch.Tensor]:
 
 class TorchCacheManager(CacheManager):
     def build_forward_cache(self, hidden_states: torch.Tensor, seq_input: SeqInput) -> torch.Tensor:
-        # request_cache = RequestsCache(num_layers, max_seq_len, num_key_value_heads, head_dim)
-        q_len_list, k_len_list, position_ids_list, _ = self.request_cache.build(seq_input, self.cache)
+        q_len_list, k_len_list, position_ids_list, _ = self.build_cache(seq_input, self.cache)
 
         self.hit_cache_flag = None
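
Note: the prefilling branch clamps hit_cache_len to q_len - 2 whenever the incoming prompt is no longer than the matched cached prefix, so the tail of an identical prompt is always recomputed instead of being served entirely from cache. A self-contained sketch of that clamp (function name and values are hypothetical; only the arithmetic mirrors _build_single_cache):

    # Illustration of the prefix-hit clamp in _build_single_cache.
    # `q_len` is the prompt length; `hit_cache_len` is the radix-tree match length.
    def clamp_hit(q_len: int, hit_cache_len: int) -> int:
        # Identical (or shorter) input: reuse at most q_len - 2 cached tokens
        # so the last positions are always recomputed.
        if q_len <= hit_cache_len:
            return q_len - 2
        return hit_cache_len

    assert clamp_hit(q_len=8, hit_cache_len=5) == 5  # partial hit: reuse as-is
    assert clamp_hit(q_len=8, hit_cache_len=8) == 6  # full hit: recompute last 2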