
Commit

refactor the cache manager
wnma3mz committed Feb 3, 2025
1 parent 218bb8a commit 4519823
Showing 17 changed files with 284 additions and 268 deletions.
2 changes: 1 addition & 1 deletion run_engine.py
@@ -28,8 +28,8 @@ def parse_args():
os.environ["TLLM_BACKEND"] = args.backend
os.environ["TLLM_ATTN_BACKEND"] = args.attn_backend

from tllm.commons.manager import load_client_model, load_master_model
from tllm.commons.tp_communicator import Communicator
from tllm.commons.weight_manager import load_client_model, load_master_model
from tllm.engine import AsyncEngine
from tllm.entrypoints.image_server.image_protocol import Text2ImageRequest
from tllm.entrypoints.image_server.server_image import ImageServing
2 changes: 1 addition & 1 deletion run_janus_pro.py
@@ -9,8 +9,8 @@
from PIL import Image
import numpy as np

from tllm.commons.manager import load_client_model, load_master_model
from tllm.commons.tp_communicator import Communicator
from tllm.commons.weight_manager import load_client_model, load_master_model
from tllm.engine import AsyncEngine
from tllm.entrypoints.protocol import ChatCompletionRequest, ChatCompletionResponse
from tllm.entrypoints.server_chat import OpenAIServing
1 change: 1 addition & 0 deletions tllm/__init__.py
@@ -30,6 +30,7 @@ class BackendEnum(Enum):
DEVICE = None
from tllm.models.mlx import *
elif BACKEND == BackendEnum.TORCH:
ENABLE_PREFILL_CACHE = False
import torch

DTYPE = torch.float16
2 changes: 2 additions & 0 deletions tllm/commons/attn.py
@@ -17,6 +17,8 @@ class AttnBackendEnum(Enum):

if os.environ.get("TLLM_ATTN_BACKEND", None):
ATTN_BACKEND = AttnBackendEnum[os.environ["TLLM_ATTN_BACKEND"]]
else:
ATTN_BACKEND = AttnBackendEnum.TORCH

if ATTN_BACKEND in [AttnBackendEnum.AUTO, AttnBackendEnum.VLLM] and importlib.util.find_spec("vllm"):
from vllm.vllm_flash_attn import flash_attn_varlen_func
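
For context, a minimal sketch of selecting the attention backend via the TLLM_ATTN_BACKEND environment variable (illustrative, not part of the commit); the hunk above adds an explicit fallback to AttnBackendEnum.TORCH when the variable is unset.

import os

# Must be set before tllm.commons.attn is imported; TORCH, VLLM and AUTO are the
# enum members visible in this hunk.
os.environ["TLLM_ATTN_BACKEND"] = "TORCH"

from tllm.commons.attn import ATTN_BACKEND

print(ATTN_BACKEND)  # AttnBackendEnum.TORCH
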
91 changes: 49 additions & 42 deletions tllm/commons/cache.py
@@ -1,7 +1,8 @@
# coding: utf-8
import copy
from dataclasses import dataclass
import time
from typing import Dict, List, Optional
from typing import Dict, Generic, List, Optional, TypeVar

from tllm import BACKEND, DEVICE, DTYPE, ENABLE_PREFILL_CACHE, BackendEnum
from tllm.commons.radix_tree import RadixTree
@@ -23,6 +24,48 @@
arange_func = lambda x: torch.arange(0, x, dtype=torch.long)


T = TypeVar("T")


@dataclass
class CacheEntry(Generic[T]):
    value: T
    timestamp: float


class Cache(Generic[T]):
    def __init__(self, max_alive_time: int = 60):
        self._cache: Dict[str, CacheEntry[T]] = {}
        self.max_alive_time = max_alive_time

    def get(self, key: str) -> Optional[T]:
        if not self.contains(key):
            return None
        entry = self._cache[key]
        entry.timestamp = time.time()  # refresh the access timestamp
        return entry.value

    def set(self, key: str, value: T) -> None:
        self._cache[key] = CacheEntry(value=value, timestamp=time.time())

    def contains(self, key: str) -> bool:
        return key in self._cache

    def delete(self, key: str) -> None:
        self._cache.pop(key, None)

    def clear(self) -> None:
        self._cache.clear()

    def check_alive(self) -> None:
        current_time = time.time()
        expired_keys = [
            key for key, entry in self._cache.items() if current_time - entry.timestamp > self.max_alive_time
        ]
        for key in expired_keys:
            self.delete(key)


class KVCache:
    def __init__(
        self, max_seq_len: Optional[int] = -1, num_key_value_heads: Optional[int] = -1, head_dim: Optional[int] = -1
@@ -91,7 +134,7 @@ def add(self, uuid: str, q_len: int, decoder_cache: Optional[DecoderCache] = None
else decoder_cache
)

def build(self, seq_input: SeqInput, cache_manager: "CacheManager"):
def build(self, seq_input: SeqInput, cache: Cache):
q_len_list, k_len_list = [], []
position_ids_list = []
hit_cache_len_list = []
@@ -100,8 +143,8 @@ def build(self, seq_input: SeqInput, cache_manager: "CacheManager"):
hit_cache_len = -1
q_len = len(input_ids)
# decoding phase
if uuid in cache_manager.cache_dict:
decoder_cache: DecoderCache = cache_manager.get(uuid)
if cache.contains(uuid):
decoder_cache: DecoderCache = cache.get(uuid)
decoder_cache.set_q_len(q_len)
cache_seq_len = decoder_cache[0].kv_len
position_ids = array_func(cache_seq_len)
@@ -114,8 +157,8 @@ def build(self, seq_input: SeqInput, cache_manager: "CacheManager"):
# prefix cache not enabled
hit_uuid, hit_cache_len = None, -1
# hit an earlier kv cache; reuse the historical cache
if hit_uuid is not None and cache_manager.get(hit_uuid) is not None:
hid_decoder_cache: DecoderCache = copy.deepcopy(cache_manager.get(hit_uuid))
if hit_uuid is not None and cache.get(uuid) is not None:
hid_decoder_cache: DecoderCache = copy.deepcopy(cache.get(uuid))
# for identical inputs, avoid exceeding the cached length
if q_len <= hit_cache_len:
hit_cache_len = q_len - 2
@@ -259,50 +302,14 @@ def __init__(
        request_cache: RequestsCache,
        attn_mask: MIX_TENSOR,
        position_ids=None,
        hit_cache_len_list=None,
        q_len_list=None,
    ) -> None:
        self.uuid_list = uuid_list
        self.request_cache = request_cache
        self.attn_mask = attn_mask
        self.position_ids = position_ids  # only meaningful for the torch backend
        self.hit_cache_len_list = hit_cache_len_list  # used at PP rank 0 to truncate hidden_states
        self.q_len_list = q_len_list

    def get_decoder_cache(self, uuid: str) -> DecoderCache:
        return self.request_cache.get_decoder_cache(uuid)

    def get_kv_len(self, uuid: str) -> int:
        return self.request_cache.get_kv_len(uuid)


class CacheManager:
    # manages the kv_cache of every layer on this node
    # max_alive_time: entries not accessed for this many seconds are deleted
    def __init__(self, max_alive_time=60):
        self.max_alive_time = max_alive_time
        self.cache_dict = {}

    def get(self, key) -> Optional[DecoderCache]:
        if self.is_contain(key):
            return self.cache_dict.get(key)["cache"]
        return None

    def set(self, key, value: DecoderCache) -> None:
        self.cache_dict[key] = {"cache": value, "ts": time.time()}

    def is_contain(self, key) -> bool:
        return key in self.cache_dict

    def delete(self, key):
        self.cache_dict.pop(key)

    def clear(self):
        self.cache_dict.clear()

    def check_alive(self):
        now = time.time()
        key_list = list(self.cache_dict.keys())
        for key in key_list:
            if now - self.cache_dict[key]["ts"] > self.max_alive_time:
                self.cache_dict.pop(key)
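
For context, a minimal usage sketch of the generic Cache added above (illustrative, not part of the commit): it behaves as a TTL map in which get refreshes an entry's timestamp and check_alive evicts entries idle longer than max_alive_time. The key and value below are made up.

from tllm.commons.cache import Cache

cache: Cache[list] = Cache(max_alive_time=60)  # entries expire 60s after their last access
cache.set("request-1", [1, 2, 3])
assert cache.contains("request-1")
print(cache.get("request-1"))  # [1, 2, 3]; the hit also refreshes the timestamp
cache.check_alive()            # evicts every entry idle for more than max_alive_time
cache.delete("request-1")
print(cache.get("request-1"))  # None
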
37 changes: 37 additions & 0 deletions tllm/commons/cache_manager.py
@@ -0,0 +1,37 @@
from tllm.commons.cache import AttentionData, Cache, RequestsCache


class CacheManager:
    # manages the kv_cache of every layer on this node
    # max_alive_time: entries not accessed for this many seconds are deleted
    def __init__(self, max_alive_time=60):
        self.cache = Cache(max_alive_time)
        self.request_cache: RequestsCache = None
        self.is_start_pp: bool = None
        self.is_end_pp: bool = None
        self.attn_data: AttentionData = None

    def init_request_cache(self, num_layers: int, max_seq_len: int, n_kv_heads: int, head_dim: int):
        self.request_cache = RequestsCache(num_layers, max_seq_len, n_kv_heads, head_dim)

    def contains(self, key) -> bool:
        return self.cache.contains(key)

    def update_cache(self, seq_input):
        for uuid in seq_input.uuid_list:
            self.cache.set(uuid, self.attn_data.get_decoder_cache(uuid))
        self.cache.check_alive()

        if self.request_cache is not None:
            self.request_cache.clear()
            self.request_cache.insert_cache(seq_input)

    def post_init(self, is_start_pp: bool, is_end_pp: bool):
        self.is_start_pp = is_start_pp
        self.is_end_pp = is_end_pp

    def build_forward_cache(self, hidden_states, seq_input):
        raise NotImplementedError

    def get_last_hidden_states(self, hidden_states):
        raise NotImplementedError
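
A rough sketch of the call order this interface implies (illustrative, not part of the commit); backend-specific subclasses such as the MLX one later in this commit override build_forward_cache and get_last_hidden_states. The NoopCacheManager subclass, the parameter values, and the commented flow are assumptions.

from tllm.commons.cache_manager import CacheManager

class NoopCacheManager(CacheManager):  # illustrative stand-in for a backend implementation
    def build_forward_cache(self, hidden_states, seq_input):
        # a real backend also builds self.attn_data (mask, position ids, request cache)
        return hidden_states

    def get_last_hidden_states(self, hidden_states):
        # a real backend keeps only each sequence's last token on the final PP rank
        return hidden_states

manager = NoopCacheManager(max_alive_time=60)
manager.init_request_cache(num_layers=2, max_seq_len=1024, n_kv_heads=8, head_dim=64)
manager.post_init(is_start_pp=True, is_end_pp=True)

# Per forward step the implied flow is:
#   hidden_states = manager.build_forward_cache(hidden_states, seq_input)
#   ... run the decoder layers against manager.attn_data ...
#   hidden_states = manager.get_last_hidden_states(hidden_states)
#   manager.update_cache(seq_input)  # persists per-uuid decoder caches and evicts stale ones
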
tllm/commons/manager.py → tllm/commons/weight_manager.py
File renamed without changes.
2 changes: 1 addition & 1 deletion tllm/entrypoints/api_server.py
@@ -9,7 +9,7 @@
from fastapi.responses import HTMLResponse, JSONResponse, Response, StreamingResponse

from tllm import CLIENT_SOCKET_PATH, MASTER_SOCKET_PATH
from tllm.commons.manager import load_master_model
from tllm.commons.weight_manager import load_master_model
from tllm.engine import AsyncEngine
from tllm.entrypoints.image_server.image_protocol import Text2ImageRequest, Text2ImageResponse
from tllm.entrypoints.image_server.server_image import ImageServing
2 changes: 1 addition & 1 deletion tllm/grpc/worker_service/worker_server.py
@@ -8,8 +8,8 @@

from tllm import CLIENT_SOCKET_PATH, GRPC_OPTIONS
from tllm.commons.convert import Convertor
from tllm.commons.manager import load_client_model
from tllm.commons.tp_communicator import BaseCommunicator, Communicator
from tllm.commons.weight_manager import load_client_model
from tllm.entrypoints.utils import parse_handler_args, update_handler_args
from tllm.grpc.proto import schemas_pb2, schemas_pb2_grpc
from tllm.grpc.worker_service.http_client import HTTPClient
13 changes: 7 additions & 6 deletions tllm/models/mlx/flux/transformer.py
@@ -11,7 +11,9 @@
from mlx import nn
import mlx.core as mx

from tllm.commons.cache import CacheManager
from tllm.commons.cache import Cache

cache = Cache()


def prepare_latent_image_ids(height: int, width: int) -> mx.array:
@@ -74,7 +76,6 @@ def __init__(self):
super().__init__()
self.num_hidden_layers = 38
self.transformer = SingleTransformer(0, self.num_hidden_layers, self.num_hidden_layers)
self.cache_manager = CacheManager()

@classmethod
def from_pretrained(cls, config, state_dict, **kwargs):
@@ -115,12 +116,12 @@ def __call__(
) -> mx.array:

request_id = request_id_list[0]
if self.cache_manager.is_contain(request_id):
image_rotary_emb = self.cache_dict.get(request_id)
if cache.contains(request_id):
image_rotary_emb = cache.get(request_id)
else:
image_rotary_emb = self.get_image_rotary_emb(height, width, seq_len)
self.cache_dict.set(request_id, image_rotary_emb)
self.cache_dict.check_alive()
cache.set(request_id, image_rotary_emb)
cache.check_alive()

hidden_states = self.transformer(hidden_states, text_embeddings, image_rotary_emb, seq_len)
mx.eval(hidden_states)
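
The change above replaces the per-model CacheManager with a module-level Cache that memoizes image_rotary_emb per request_id. A self-contained sketch of the same pattern (illustrative, not part of the commit; compute_rotary_emb and get_or_compute are hypothetical stand-ins for get_image_rotary_emb and the __call__ logic):

from tllm.commons.cache import Cache

cache = Cache(max_alive_time=60)

def compute_rotary_emb(height: int, width: int):
    # hypothetical stand-in for the real, relatively expensive computation
    return (height, width)

def get_or_compute(request_id: str, height: int, width: int):
    value = cache.get(request_id)  # a hit also refreshes the entry's timestamp
    if value is None:
        value = compute_rotary_emb(height, width)
        cache.set(request_id, value)
    cache.check_alive()  # drop entries for requests that have gone idle
    return value
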
93 changes: 46 additions & 47 deletions tllm/models/mlx/helper.py
@@ -7,7 +7,8 @@
import mlx.nn as nn

from tllm import DTYPE, ENABLE_PREFILL_CACHE
from tllm.commons.cache import AttentionData, CacheManager, RequestsCache
from tllm.commons.cache import AttentionData
from tllm.commons.cache_manager import CacheManager
from tllm.schemas import SeqInput


@@ -44,19 +45,51 @@ def build_mlx_mask(q_len_list: List[int], k_len_list: List[int], hit_cache_len_l
return final_mask


def build_forward_cache(
seq_input: SeqInput, cache_manager: CacheManager, request_cache: RequestsCache
) -> AttentionData:
q_len_list, k_len_list, position_ids_list, hit_cache_len_list = request_cache.build(seq_input, cache_manager)
class MLXCacheManager(CacheManager):
def build_forward_cache(self, hidden_states: mx.array, seq_input: SeqInput) -> mx.array:
q_len_list, k_len_list, position_ids_list, hit_cache_len_list = self.request_cache.build(seq_input, self.cache)

self.hit_cache_flag = any(x != -1 for x in hit_cache_len_list)

# truncate hidden_states
if ENABLE_PREFILL_CACHE and self.is_start_pp and self.hit_cache_flag:
hidden_states_list = []
q_start = 0
for q_len, hit_cache_len in zip(q_len_list, hit_cache_len_list):
if hit_cache_len != -1:
hidden_states_list.append(hidden_states[q_start : q_start + q_len][hit_cache_len:])
else:
hidden_states_list.append(hidden_states[q_start : q_start + q_len])
q_start += q_len
hidden_states = mx.concat(hidden_states_list, axis=0)

if hidden_states.dtype == mx.float16: # float16 is much slower than bfloat16
hidden_states = hidden_states.astype(mx.bfloat16)

self.q_len_list = q_len_list
self.hit_cache_len_list = hit_cache_len_list
self.attn_data = AttentionData(
request_cache=self.request_cache,
attn_mask=build_mlx_mask(q_len_list, k_len_list, hit_cache_len_list),
uuid_list=seq_input.uuid_list,
position_ids=mx.concatenate(position_ids_list, axis=-1),
)

return hidden_states

return AttentionData(
request_cache=request_cache,
attn_mask=build_mlx_mask(q_len_list, k_len_list, hit_cache_len_list),
uuid_list=seq_input.uuid_list,
position_ids=mx.concatenate(position_ids_list, axis=-1),
hit_cache_len_list=hit_cache_len_list,
q_len_list=q_len_list,
)
def get_last_hidden_states(self, hidden_states: mx.array) -> mx.array:
split_len_list = self.q_len_list
if self.hit_cache_flag:
q_start = 0
for i, (q_len, hit_cache_len) in enumerate(zip(self.q_len_list, self.hit_cache_len_list)):
if hit_cache_len != -1:
split_len_list[i] = q_len - hit_cache_len
q_start += q_len
if self.is_end_pp:
index_list = list(itertools.accumulate(split_len_list[:-1]))
seq_hidden_states = mx.split(hidden_states, index_list, axis=0)
hidden_states = mx.concat([x[-1:, :] for x in seq_hidden_states], axis=0)
return hidden_states


def quantization_func(config, model, state_dict):
@@ -82,40 +115,6 @@ def read_from_safetensors(file_path: str) -> Dict[str, mx.array]:
return mx.load(file_path)


def get_last_hidden_states(
hit_cache_flag: bool, is_end_pp: bool, attention_data: AttentionData, hidden_states: mx.array
) -> mx.array:
split_len_list = attention_data.q_len_list
if hit_cache_flag:
q_start = 0
for i, (q_len, hit_cache_len) in enumerate(zip(attention_data.q_len_list, attention_data.hit_cache_len_list)):
if hit_cache_len != -1:
split_len_list[i] = q_len - hit_cache_len
q_start += q_len
if is_end_pp:
index_list = list(itertools.accumulate(split_len_list[:-1]))
seq_hidden_states = mx.split(hidden_states, index_list, axis=0)
hidden_states = mx.concat([x[-1:, :] for x in seq_hidden_states], axis=0)
return hidden_states


def truncate_hidden_states(
hit_cache_flag: bool, is_start_pp: bool, attention_data: AttentionData, hidden_states: mx.array
) -> mx.array:
# truncate hidden_states
if ENABLE_PREFILL_CACHE and is_start_pp and hit_cache_flag:
hidden_states_list = []
q_start = 0
for q_len, hit_cache_len in zip(attention_data.q_len_list, attention_data.hit_cache_len_list):
if hit_cache_len != -1:
hidden_states_list.append(hidden_states[q_start : q_start + q_len][hit_cache_len:])
else:
hidden_states_list.append(hidden_states[q_start : q_start + q_len])
q_start += q_len
hidden_states = mx.concat(hidden_states_list, axis=0)
return hidden_states


def dict_to_dataclass(data: dict, name: str):
"""将字典转换为 dataclass
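
A small worked example of the prefill-cache bookkeeping that MLXCacheManager.build_forward_cache and get_last_hidden_states take over above (the batch sizes and hit lengths are invented for illustration):

# Two requests in one batch; values are made up for illustration.
q_len_list = [10, 4]          # prompt lengths in this forward pass
hit_cache_len_list = [6, -1]  # request 0 reuses 6 cached prefix tokens, request 1 has no hit

# With ENABLE_PREFILL_CACHE on the first PP rank, build_forward_cache keeps
# hidden_states rows 6..9 for request 0 and all 4 rows for request 1,
# so 8 rows flow through the decoder layers instead of 14.
split_len_list = [q - h if h != -1 else q for q, h in zip(q_len_list, hit_cache_len_list)]
assert split_len_list == [4, 4]

# On the last PP rank, get_last_hidden_states splits by split_len_list and keeps
# only the final row of each split, leaving one hidden state per request.
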