splitting the build cache logic
wnma3mz committed Feb 6, 2025
1 parent 68e67b6 commit 6e86376
Showing 4 changed files with 59 additions and 51 deletions.
47 changes: 0 additions & 47 deletions tllm/commons/cache.py
@@ -134,53 +134,6 @@ def add(self, uuid: str, q_len: int, decoder_cache: Optional[DecoderCache] = Non
else decoder_cache
)

def build(self, seq_input: SeqInput, cache: Cache):
q_len_list, k_len_list = [], []
position_ids_list = []
hit_cache_len_list = []

for uuid, input_ids in zip(seq_input.uuid_list, seq_input.input_ids_list):
hit_cache_len = -1
q_len = len(input_ids)
# decoding phase
if cache.contains(uuid):
decoder_cache: DecoderCache = cache.get(uuid)
decoder_cache.set_q_len(q_len)
cache_seq_len = decoder_cache[0].kv_len
position_ids = array_func(cache_seq_len)
k_len_list.append(cache_seq_len + q_len)
# prefilling phase
else:
if ENABLE_PREFILL_CACHE:
hit_uuid, hit_cache_len = self.radix_tree.longest_common_prefix(input_ids)
else:
# prefix cache disabled
hit_uuid, hit_cache_len = None, -1
# hit an existing kv cache, reuse the historical cache
if hit_uuid is not None and cache.get(uuid) is not None:
hid_decoder_cache: DecoderCache = copy.deepcopy(cache.get(uuid))
# for identical inputs, avoid exceeding the cache length
if q_len <= hit_cache_len:
hit_cache_len = q_len - 2

hid_decoder_cache.truncate(hit_cache_len)
hid_decoder_cache.set_q_len(q_len - hit_cache_len)
decoder_cache = hid_decoder_cache
position_ids = arange_func(q_len)
k_len_list.append(q_len)
# no kv cache hit, create a new cache
else:
hit_cache_len = -1
decoder_cache = None
position_ids = arange_func(q_len)
k_len_list.append(q_len)
q_len_list.append(q_len)
position_ids_list.append(position_ids)
hit_cache_len_list.append(hit_cache_len)

self.add(uuid, q_len, decoder_cache)
return q_len_list, k_len_list, position_ids_list, hit_cache_len_list

def get_decoder_cache(self, uuid: str) -> DecoderCache:
return self.cache_dict[uuid]

58 changes: 57 additions & 1 deletion tllm/commons/cache_manager.py
@@ -1,4 +1,8 @@
from tllm.commons.cache import AttentionData, Cache, RequestsCache
import copy
from typing import List, Optional

from tllm import ENABLE_PREFILL_CACHE
from tllm.commons.cache import AttentionData, Cache, DecoderCache, RequestsCache, arange_func, array_func


class CacheManager:
@@ -26,6 +30,58 @@ def update_cache(self, seq_input):
self.request_cache.clear()
self.request_cache.insert_cache(seq_input)

def _build_single_cache(self, uuid: str, input_ids: List[int], cache: Cache):
hit_cache_len = -1
q_len = len(input_ids)
is_decoding = cache.contains(uuid)

# decoding phase
if is_decoding:
decoder_cache: DecoderCache = cache.get(uuid)
decoder_cache.set_q_len(q_len)
cache_seq_len = decoder_cache[0].kv_len
position_ids = array_func(cache_seq_len)
return q_len, cache_seq_len + q_len, hit_cache_len, position_ids, decoder_cache

# prefilling phase
if ENABLE_PREFILL_CACHE:
hit_uuid, hit_cache_len = self.request_cache.radix_tree.longest_common_prefix(input_ids)
else:
hit_uuid, hit_cache_len = None, -1

if hit_uuid is not None and cache.get(hit_uuid) is not None:
hid_decoder_cache: DecoderCache = copy.deepcopy(cache.get(hit_uuid))
# for identical inputs, avoid exceeding the cache length
if q_len <= hit_cache_len:
hit_cache_len = q_len - 2

hid_decoder_cache.truncate(hit_cache_len)
hid_decoder_cache.set_q_len(q_len - hit_cache_len)
decoder_cache = hid_decoder_cache
position_ids = arange_func(q_len)
return q_len, q_len, hit_cache_len, position_ids, decoder_cache
else:
hit_cache_len = -1
decoder_cache = None
position_ids = arange_func(q_len)
return q_len, q_len, hit_cache_len, position_ids, decoder_cache

def build_cache(self, seq_input, cache: Cache):
q_len_list, k_len_list = [], []
position_ids_list = []
hit_cache_len_list = []

for uuid, input_ids in zip(seq_input.uuid_list, seq_input.input_ids_list):
q_len, k_len, hit_cache_len, position_ids, decoder_cache = self._build_single_cache(uuid, input_ids, cache)

q_len_list.append(q_len)
k_len_list.append(k_len)
position_ids_list.append(position_ids)
hit_cache_len_list.append(hit_cache_len)

self.request_cache.add(uuid, q_len, decoder_cache)
return q_len_list, k_len_list, position_ids_list, hit_cache_len_list

def post_init(self, is_start_pp: bool, is_end_pp: bool):
self.is_start_pp = is_start_pp
self.is_end_pp = is_end_pp
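
For orientation, here is a minimal, self-contained sketch (not part of this commit) of the per-request decision that _build_single_cache isolates: decoding reuses the request's own cache, prefilling optionally reuses a prefix-cache hit, otherwise everything is computed from scratch. The dict-based stand-in cache, the prefix_hit tuple, and the function name below are illustrative assumptions; the real implementation works on Cache/DecoderCache and the radix tree shown above.

from typing import Dict, List, Optional, Tuple

def build_single_cache_sketch(
    uuid: str,
    input_ids: List[int],
    kv_lens: Dict[str, int],                       # stand-in: uuid -> cached kv length
    prefix_hit: Optional[Tuple[str, int]] = None,  # stand-in for a radix-tree prefix lookup
):
    """Return (q_len, k_len, hit_cache_len, position_ids) for one request."""
    q_len = len(input_ids)
    # decoding phase: the request already has its own cache, keys grow by q_len
    if uuid in kv_lens:
        kv_len = kv_lens[uuid]
        return q_len, kv_len + q_len, -1, [kv_len]
    # prefilling phase with a prefix-cache hit: reuse the common prefix
    if prefix_hit is not None:
        _, hit_cache_len = prefix_hit
        if q_len <= hit_cache_len:  # identical input: leave at least two fresh tokens
            hit_cache_len = q_len - 2
        return q_len, q_len, hit_cache_len, list(range(q_len))
    # prefilling phase with no hit: no cache is reused
    return q_len, q_len, -1, list(range(q_len))

build_cache then simply loops this decision over seq_input.uuid_list / seq_input.input_ids_list, accumulates the four lists, and registers each resulting decoder cache via self.request_cache.add, which is what the framework helpers below consume.
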
2 changes: 1 addition & 1 deletion tllm/models/mlx/helper.py
@@ -47,7 +47,7 @@ def build_mlx_mask(q_len_list: List[int], k_len_list: List[int], hit_cache_len_l

class MLXCacheManager(CacheManager):
def build_forward_cache(self, hidden_states: mx.array, seq_input: SeqInput) -> mx.array:
q_len_list, k_len_list, position_ids_list, hit_cache_len_list = self.request_cache.build(seq_input, self.cache)
q_len_list, k_len_list, position_ids_list, hit_cache_len_list = self.build_cache(seq_input, self.cache)

self.hit_cache_flag = any(x != -1 for x in hit_cache_len_list)

3 changes: 1 addition & 2 deletions tllm/models/torch/helper.py
@@ -51,8 +51,7 @@ def read_from_safetensors(file_path: str) -> Dict[str, torch.Tensor]:

class TorchCacheManager(CacheManager):
def build_forward_cache(self, hidden_states: torch.Tensor, seq_input: SeqInput) -> torch.Tensor:
# request_cache = RequestsCache(num_layers, max_seq_len, num_key_value_heads, head_dim)
q_len_list, k_len_list, position_ids_list, _ = self.request_cache.build(seq_input, self.cache)
q_len_list, k_len_list, position_ids_list, _ = self.build_cache(seq_input, self.cache)

self.hit_cache_flag = None
