From 4519823165367dfc46c7ae9d7eba55796cc5ede9 Mon Sep 17 00:00:00 2001 From: lujianghu Date: Mon, 3 Feb 2025 13:33:14 +0800 Subject: [PATCH] refactor the cache manager --- run_engine.py | 2 +- run_janus_pro.py | 2 +- tllm/__init__.py | 1 + tllm/commons/attn.py | 2 + tllm/commons/cache.py | 91 +++++++++--------- tllm/commons/cache_manager.py | 37 ++++++++ .../commons/{manager.py => weight_manager.py} | 0 tllm/entrypoints/api_server.py | 2 +- tllm/grpc/worker_service/worker_server.py | 2 +- tllm/models/mlx/flux/transformer.py | 13 +-- tllm/models/mlx/helper.py | 93 +++++++++---------- tllm/models/mlx/llama.py | 45 ++++----- tllm/models/mlx/qwen2.py | 43 ++++----- tllm/models/tinygrad/llama.py | 26 ++---- tllm/models/torch/helper.py | 91 ++++++++++-------- tllm/models/torch/llama.py | 51 ++++------ tllm/models/torch/qwen2.py | 51 ++++------ 17 files changed, 284 insertions(+), 268 deletions(-) create mode 100644 tllm/commons/cache_manager.py rename tllm/commons/{manager.py => weight_manager.py} (100%) diff --git a/run_engine.py b/run_engine.py index 2beb57d..af3fdc1 100644 --- a/run_engine.py +++ b/run_engine.py @@ -28,8 +28,8 @@ def parse_args(): os.environ["TLLM_BACKEND"] = args.backend os.environ["TLLM_ATTN_BACKEND"] = args.attn_backend -from tllm.commons.manager import load_client_model, load_master_model from tllm.commons.tp_communicator import Communicator +from tllm.commons.weight_manager import load_client_model, load_master_model from tllm.engine import AsyncEngine from tllm.entrypoints.image_server.image_protocol import Text2ImageRequest from tllm.entrypoints.image_server.server_image import ImageServing diff --git a/run_janus_pro.py b/run_janus_pro.py index 48ffde4..249af5c 100644 --- a/run_janus_pro.py +++ b/run_janus_pro.py @@ -9,8 +9,8 @@ from PIL import Image import numpy as np -from tllm.commons.manager import load_client_model, load_master_model from tllm.commons.tp_communicator import Communicator +from tllm.commons.weight_manager import load_client_model, load_master_model from tllm.engine import AsyncEngine from tllm.entrypoints.protocol import ChatCompletionRequest, ChatCompletionResponse from tllm.entrypoints.server_chat import OpenAIServing diff --git a/tllm/__init__.py b/tllm/__init__.py index 06b59fe..de8017e 100644 --- a/tllm/__init__.py +++ b/tllm/__init__.py @@ -30,6 +30,7 @@ class BackendEnum(Enum): DEVICE = None from tllm.models.mlx import * elif BACKEND == BackendEnum.TORCH: + ENABLE_PREFILL_CACHE = False import torch DTYPE = torch.float16 diff --git a/tllm/commons/attn.py b/tllm/commons/attn.py index 212139c..2e1c612 100644 --- a/tllm/commons/attn.py +++ b/tllm/commons/attn.py @@ -17,6 +17,8 @@ class AttnBackendEnum(Enum): if os.environ.get("TLLM_ATTN_BACKEND", None): ATTN_BACKEND = AttnBackendEnum[os.environ["TLLM_ATTN_BACKEND"]] + else: + ATTN_BACKEND = AttnBackendEnum.TORCH if ATTN_BACKEND in [AttnBackendEnum.AUTO, AttnBackendEnum.VLLM] and importlib.util.find_spec("vllm"): from vllm.vllm_flash_attn import flash_attn_varlen_func diff --git a/tllm/commons/cache.py b/tllm/commons/cache.py index 9f9c8a1..3978bdd 100644 --- a/tllm/commons/cache.py +++ b/tllm/commons/cache.py @@ -1,7 +1,8 @@ # coding: utf-8 import copy +from dataclasses import dataclass import time -from typing import Dict, List, Optional +from typing import Dict, Generic, List, Optional, TypeVar from tllm import BACKEND, DEVICE, DTYPE, ENABLE_PREFILL_CACHE, BackendEnum from tllm.commons.radix_tree import RadixTree @@ -23,6 +24,48 @@ arange_func = lambda x: torch.arange(0, x, dtype=torch.long) +T 
= TypeVar("T") + + +@dataclass +class CacheEntry(Generic[T]): + value: T + timestamp: float + + +class Cache(Generic[T]): + def __init__(self, max_alive_time: int = 60): + self._cache: Dict[str, CacheEntry[T]] = {} + self.max_alive_time = max_alive_time + + def get(self, key: str) -> Optional[T]: + if not self.contains(key): + return None + entry = self._cache[key] + entry.timestamp = time.time() # 更新访问时间 + return entry.value + + def set(self, key: str, value: T) -> None: + self._cache[key] = CacheEntry(value=value, timestamp=time.time()) + + def contains(self, key: str) -> bool: + return key in self._cache + + def delete(self, key: str) -> None: + self._cache.pop(key, None) + + def clear(self) -> None: + self._cache.clear() + + def check_alive(self) -> None: + current_time = time.time() + expired_keys = [ + key for key, entry in self._cache.items() if current_time - entry.timestamp > self.max_alive_time + ] + for key in expired_keys: + self.delete(key) + + class KVCache: def __init__( self, max_seq_len: Optional[int] = -1, num_key_value_heads: Optional[int] = -1, head_dim: Optional[int] = -1 @@ -91,7 +134,7 @@ def add(self, uuid: str, q_len: int, decoder_cache: Optional[DecoderCache] = Non else decoder_cache ) - def build(self, seq_input: SeqInput, cache_manager: "CacheManager"): + def build(self, seq_input: SeqInput, cache: Cache): q_len_list, k_len_list = [], [] position_ids_list = [] hit_cache_len_list = [] @@ -100,8 +143,8 @@ def build(self, seq_input: SeqInput, cache_manager: "CacheManager"): hit_cache_len = -1 q_len = len(input_ids) # decoding 阶段 - if uuid in cache_manager.cache_dict: - decoder_cache: DecoderCache = cache_manager.get(uuid) + if cache.contains(uuid): + decoder_cache: DecoderCache = cache.get(uuid) decoder_cache.set_q_len(q_len) cache_seq_len = decoder_cache[0].kv_len position_ids = array_func(cache_seq_len) @@ -114,8 +157,8 @@ def build(self, seq_input: SeqInput, cache_manager: "CacheManager"): # 不启用 prefix cache hit_uuid, hit_cache_len = None, -1 # 命中了之前的 kv cache,使用历史 cache - if hit_uuid is not None and cache_manager.get(hit_uuid) is not None: - hid_decoder_cache: DecoderCache = copy.deepcopy(cache_manager.get(hit_uuid)) + if hit_uuid is not None and cache.get(uuid) is not None: + hid_decoder_cache: DecoderCache = copy.deepcopy(cache.get(uuid)) # 相同输入时,避免过超过 cache 长度 if q_len <= hit_cache_len: hit_cache_len = q_len - 2 @@ -259,50 +302,14 @@ def __init__( request_cache: RequestsCache, attn_mask: MIX_TENSOR, position_ids=None, - hit_cache_len_list=None, - q_len_list=None, ) -> None: self.uuid_list = uuid_list self.request_cache = request_cache self.attn_mask = attn_mask self.position_ids = position_ids # 只在 torch 下有意义 - self.hit_cache_len_list = hit_cache_len_list # 用于 PP=0 截断 hidden_states - self.q_len_list = q_len_list def get_decoder_cache(self, uuid: str) -> DecoderCache: return self.request_cache.get_decoder_cache(uuid) def get_kv_len(self, uuid: str) -> int: return self.request_cache.get_kv_len(uuid) - - -class CacheManager: - # 管理每个节点的所有层 kv_cache - # max_alive_time: 超过多久没有访问就删除,单位秒 - def __init__(self, max_alive_time=60): - self.max_alive_time = max_alive_time - self.cache_dict = {} - - def get(self, key) -> Optional[DecoderCache]: - if self.is_contain(key): - return self.cache_dict.get(key)["cache"] - return None - - def set(self, key, value: DecoderCache) -> None: - self.cache_dict[key] = {"cache": value, "ts": time.time()} - - def is_contain(self, key) -> bool: - return key in self.cache_dict - - def delete(self, key): - self.cache_dict.pop(key) - - def 
clear(self): - self.cache_dict.clear() - - def check_alive(self): - now = time.time() - key_list = list(self.cache_dict.keys()) - for key in key_list: - if now - self.cache_dict[key]["ts"] > self.max_alive_time: - self.cache_dict.pop(key) diff --git a/tllm/commons/cache_manager.py b/tllm/commons/cache_manager.py new file mode 100644 index 0000000..8f6b84e --- /dev/null +++ b/tllm/commons/cache_manager.py @@ -0,0 +1,37 @@ +from tllm.commons.cache import AttentionData, Cache, RequestsCache + + +class CacheManager: + # 管理每个节点的所有层 kv_cache + # max_alive_time: 超过多久没有访问就删除,单位秒 + def __init__(self, max_alive_time=60): + self.cache = Cache(max_alive_time) + self.request_cache: RequestsCache = None + self.is_start_pp: bool = None + self.is_end_pp: bool = None + self.attn_data: AttentionData = None + + def init_request_cache(self, num_layers: int, max_seq_len: int, n_kv_heads: int, head_dim: int): + self.request_cache = RequestsCache(num_layers, max_seq_len, n_kv_heads, head_dim) + + def contains(self, key) -> bool: + return self.cache.contains(key) + + def update_cache(self, seq_input): + for uuid in seq_input.uuid_list: + self.cache.set(uuid, self.attn_data.get_decoder_cache(uuid)) + self.cache.check_alive() + + if self.request_cache is not None: + self.request_cache.clear() + self.request_cache.insert_cache(seq_input) + + def post_init(self, is_start_pp: bool, is_end_pp: bool): + self.is_start_pp = is_start_pp + self.is_end_pp = is_end_pp + + def build_forward_cache(self, hidden_states, seq_input): + raise NotImplementedError + + def get_last_hidden_states(self, hidden_states): + raise NotImplementedError diff --git a/tllm/commons/manager.py b/tllm/commons/weight_manager.py similarity index 100% rename from tllm/commons/manager.py rename to tllm/commons/weight_manager.py diff --git a/tllm/entrypoints/api_server.py b/tllm/entrypoints/api_server.py index 0cea6d3..e4b481d 100644 --- a/tllm/entrypoints/api_server.py +++ b/tllm/entrypoints/api_server.py @@ -9,7 +9,7 @@ from fastapi.responses import HTMLResponse, JSONResponse, Response, StreamingResponse from tllm import CLIENT_SOCKET_PATH, MASTER_SOCKET_PATH -from tllm.commons.manager import load_master_model +from tllm.commons.weight_manager import load_master_model from tllm.engine import AsyncEngine from tllm.entrypoints.image_server.image_protocol import Text2ImageRequest, Text2ImageResponse from tllm.entrypoints.image_server.server_image import ImageServing diff --git a/tllm/grpc/worker_service/worker_server.py b/tllm/grpc/worker_service/worker_server.py index 5f24a44..42e2f4c 100644 --- a/tllm/grpc/worker_service/worker_server.py +++ b/tllm/grpc/worker_service/worker_server.py @@ -8,8 +8,8 @@ from tllm import CLIENT_SOCKET_PATH, GRPC_OPTIONS from tllm.commons.convert import Convertor -from tllm.commons.manager import load_client_model from tllm.commons.tp_communicator import BaseCommunicator, Communicator +from tllm.commons.weight_manager import load_client_model from tllm.entrypoints.utils import parse_handler_args, update_handler_args from tllm.grpc.proto import schemas_pb2, schemas_pb2_grpc from tllm.grpc.worker_service.http_client import HTTPClient diff --git a/tllm/models/mlx/flux/transformer.py b/tllm/models/mlx/flux/transformer.py index 0077172..d167b81 100644 --- a/tllm/models/mlx/flux/transformer.py +++ b/tllm/models/mlx/flux/transformer.py @@ -11,7 +11,9 @@ from mlx import nn import mlx.core as mx -from tllm.commons.cache import CacheManager +from tllm.commons.cache import Cache + +cache = Cache() def prepare_latent_image_ids(height: 
int, width: int) -> mx.array: @@ -74,7 +76,6 @@ def __init__(self): super().__init__() self.num_hidden_layers = 38 self.transformer = SingleTransformer(0, self.num_hidden_layers, self.num_hidden_layers) - self.cache_manager = CacheManager() @classmethod def from_pretrained(cls, config, state_dict, **kwargs): @@ -115,12 +116,12 @@ def __call__( ) -> mx.array: request_id = request_id_list[0] - if self.cache_manager.is_contain(request_id): - image_rotary_emb = self.cache_dict.get(request_id) + if cache.contains(request_id): + image_rotary_emb = cache.get(request_id) else: image_rotary_emb = self.get_image_rotary_emb(height, width, seq_len) - self.cache_dict.set(request_id, image_rotary_emb) - self.cache_dict.check_alive() + cache.set(request_id, image_rotary_emb) + cache.check_alive() hidden_states = self.transformer(hidden_states, text_embeddings, image_rotary_emb, seq_len) mx.eval(hidden_states) diff --git a/tllm/models/mlx/helper.py b/tllm/models/mlx/helper.py index 67509e3..d48e87e 100644 --- a/tllm/models/mlx/helper.py +++ b/tllm/models/mlx/helper.py @@ -7,7 +7,8 @@ import mlx.nn as nn from tllm import DTYPE, ENABLE_PREFILL_CACHE -from tllm.commons.cache import AttentionData, CacheManager, RequestsCache +from tllm.commons.cache import AttentionData +from tllm.commons.cache_manager import CacheManager from tllm.schemas import SeqInput @@ -44,19 +45,51 @@ def build_mlx_mask(q_len_list: List[int], k_len_list: List[int], hit_cache_len_l return final_mask -def build_forward_cache( - seq_input: SeqInput, cache_manager: CacheManager, request_cache: RequestsCache -) -> AttentionData: - q_len_list, k_len_list, position_ids_list, hit_cache_len_list = request_cache.build(seq_input, cache_manager) +class MLXCacheManager(CacheManager): + def build_forward_cache(self, hidden_states: mx.array, seq_input: SeqInput) -> mx.array: + q_len_list, k_len_list, position_ids_list, hit_cache_len_list = self.request_cache.build(seq_input, self.cache) + + self.hit_cache_flag = any(x != -1 for x in hit_cache_len_list) + + # 截断 hidden_states + if ENABLE_PREFILL_CACHE and self.is_start_pp and self.hit_cache_flag: + hidden_states_list = [] + q_start = 0 + for q_len, hit_cache_len in zip(q_len_list, hit_cache_len_list): + if hit_cache_len != -1: + hidden_states_list.append(hidden_states[q_start : q_start + q_len][hit_cache_len:]) + else: + hidden_states_list.append(hidden_states[q_start : q_start + q_len]) + q_start += q_len + hidden_states = mx.concat(hidden_states_list, axis=0) + + if hidden_states.dtype == mx.float16: # float16 is much slower than bfloat16 + hidden_states = hidden_states.astype(mx.bfloat16) + + self.q_len_list = q_len_list + self.hit_cache_len_list = hit_cache_len_list + self.attn_data = AttentionData( + request_cache=self.request_cache, + attn_mask=build_mlx_mask(q_len_list, k_len_list, hit_cache_len_list), + uuid_list=seq_input.uuid_list, + position_ids=mx.concatenate(position_ids_list, axis=-1), + ) + + return hidden_states - return AttentionData( - request_cache=request_cache, - attn_mask=build_mlx_mask(q_len_list, k_len_list, hit_cache_len_list), - uuid_list=seq_input.uuid_list, - position_ids=mx.concatenate(position_ids_list, axis=-1), - hit_cache_len_list=hit_cache_len_list, - q_len_list=q_len_list, - ) + def get_last_hidden_states(self, hidden_states: mx.array) -> mx.array: + split_len_list = self.q_len_list + if self.hit_cache_flag: + q_start = 0 + for i, (q_len, hit_cache_len) in enumerate(zip(self.q_len_list, self.hit_cache_len_list)): + if hit_cache_len != -1: + split_len_list[i] = q_len 
- hit_cache_len + q_start += q_len + if self.is_end_pp: + index_list = list(itertools.accumulate(split_len_list[:-1])) + seq_hidden_states = mx.split(hidden_states, index_list, axis=0) + hidden_states = mx.concat([x[-1:, :] for x in seq_hidden_states], axis=0) + return hidden_states def quantization_func(config, model, state_dict): @@ -82,40 +115,6 @@ def read_from_safetensors(file_path: str) -> Dict[str, mx.array]: return mx.load(file_path) -def get_last_hidden_states( - hit_cache_flag: bool, is_end_pp: bool, attention_data: AttentionData, hidden_states: mx.array -) -> mx.array: - split_len_list = attention_data.q_len_list - if hit_cache_flag: - q_start = 0 - for i, (q_len, hit_cache_len) in enumerate(zip(attention_data.q_len_list, attention_data.hit_cache_len_list)): - if hit_cache_len != -1: - split_len_list[i] = q_len - hit_cache_len - q_start += q_len - if is_end_pp: - index_list = list(itertools.accumulate(split_len_list[:-1])) - seq_hidden_states = mx.split(hidden_states, index_list, axis=0) - hidden_states = mx.concat([x[-1:, :] for x in seq_hidden_states], axis=0) - return hidden_states - - -def truncate_hidden_states( - hit_cache_flag: bool, is_start_pp: bool, attention_data: AttentionData, hidden_states: mx.array -) -> mx.array: - # 截断 hidden_states - if ENABLE_PREFILL_CACHE and is_start_pp and hit_cache_flag: - hidden_states_list = [] - q_start = 0 - for q_len, hit_cache_len in zip(attention_data.q_len_list, attention_data.hit_cache_len_list): - if hit_cache_len != -1: - hidden_states_list.append(hidden_states[q_start : q_start + q_len][hit_cache_len:]) - else: - hidden_states_list.append(hidden_states[q_start : q_start + q_len]) - q_start += q_len - hidden_states = mx.concat(hidden_states_list, axis=0) - return hidden_states - - def dict_to_dataclass(data: dict, name: str): """将字典转换为 dataclass diff --git a/tllm/models/mlx/llama.py b/tllm/models/mlx/llama.py index 2c9df45..f38cfdc 100644 --- a/tllm/models/mlx/llama.py +++ b/tllm/models/mlx/llama.py @@ -7,13 +7,7 @@ from transformers import AutoConfig from tllm import DTYPE -from tllm.commons.cache import CacheManager, RequestsCache -from tllm.models.mlx.helper import ( - build_forward_cache, - get_last_hidden_states, - quantization_func, - truncate_hidden_states, -) +from tllm.models.mlx.helper import MLXCacheManager, quantization_func from tllm.models.mlx.layers import Decoder from tllm.models.weight_helper import ( default_merge_attn, @@ -23,6 +17,8 @@ ) from tllm.schemas import SeqInput +cache_manager = MLXCacheManager() + def get_inv_freq_mx(dim, base): return 1.0 / (base ** (mx.arange(0, dim, 2, dtype=mx.int32).astype(mx.float32) / dim)) @@ -59,7 +55,6 @@ def __init__(self, config: AutoConfig, is_merge: bool = True): self.rank = args.comm.rank self.vocab_size = args.vocab_size - self.cache_manager = CacheManager() self.config = config self.model = Decoder(args, config.decoder_start_layer_idx, config.decoder_end_layer_idx, is_merge) self.num_layers = config.decoder_end_layer_idx - config.decoder_start_layer_idx @@ -72,35 +67,29 @@ def __init__(self, config: AutoConfig, is_merge: bool = True): # rope_type="default", # rope_scaling=1.0, # ) - self.max_seq_len = getattr(self.model.layers[-1].self_attn, "max_seq_len", -1) - self.n_kv_heads = self.model.layers[-1].self_attn.n_kv_heads - self.head_dim = self.model.layers[-1].self_attn.head_dim - self.request_cache = RequestsCache(self.num_layers, self.max_seq_len, self.n_kv_heads, self.head_dim) - self.is_start_pp = self.config.decoder_start_layer_idx == 0 - self.is_end_pp = 
self.config.decoder_end_layer_idx == self.config.num_hidden_layers + max_seq_len = getattr(self.model.layers[-1].self_attn, "max_seq_len", -1) + n_kv_heads = self.model.layers[-1].self_attn.n_kv_heads + head_dim = self.model.layers[-1].self_attn.head_dim + cache_manager.init_request_cache(self.num_layers, max_seq_len, n_kv_heads, head_dim) + + is_start_pp = self.config.decoder_start_layer_idx == 0 + is_end_pp = self.config.decoder_end_layer_idx == self.config.num_hidden_layers + cache_manager.post_init(is_start_pp, is_end_pp) def __call__(self, hidden_states: mx.array, seq_input: SeqInput) -> mx.array: - attention_data = build_forward_cache(seq_input, self.cache_manager, self.request_cache) - hit_cache_flag = any(x != -1 for x in attention_data.hit_cache_len_list) - hidden_states = truncate_hidden_states(hit_cache_flag, self.is_start_pp, attention_data, hidden_states) + hidden_states = cache_manager.build_forward_cache(hidden_states, seq_input) # cos, sin = self.rotary_emb(attention_data.position_ids) # attention_data.cos, attention_data.sin = mx.expand_dims(cos, axis=1), mx.expand_dims(sin, axis=1) - if hidden_states.dtype == mx.float16: # float16 is much slower than bfloat16 - hidden_states = hidden_states.astype(mx.bfloat16) - mask = attention_data.attn_mask + mask = cache_manager.attn_data.attn_mask mask = mask if mask is None else mask.astype(hidden_states.dtype) - output = self.model(hidden_states, mask=mask, cache=attention_data) + output = self.model(hidden_states, mask=mask, cache=cache_manager.attn_data) - # TODO 异步保存 cache - for uuid in seq_input.uuid_list: - self.cache_manager.set(uuid, attention_data.get_decoder_cache(uuid)) - self.cache_manager.check_alive() - self.request_cache.clear() - self.request_cache.insert_cache(seq_input) + # TODO 异步更新 cache + cache_manager.update_cache(seq_input) - output = get_last_hidden_states(hit_cache_flag, self.is_end_pp, attention_data, output) + output = cache_manager.get_last_hidden_states(output) return output @classmethod diff --git a/tllm/models/mlx/qwen2.py b/tllm/models/mlx/qwen2.py index 02b224d..5b928b5 100644 --- a/tllm/models/mlx/qwen2.py +++ b/tllm/models/mlx/qwen2.py @@ -7,17 +7,13 @@ from transformers import AutoConfig from tllm import DTYPE -from tllm.commons.cache import CacheManager, RequestsCache -from tllm.models.mlx.helper import ( - build_forward_cache, - get_last_hidden_states, - quantization_func, - truncate_hidden_states, -) +from tllm.models.mlx.helper import MLXCacheManager, quantization_func from tllm.models.mlx.layers import Decoder from tllm.models.weight_helper import default_merge_attn, default_merge_mlp, tie_word_embeddings_func from tllm.schemas import SeqInput +cache_manager = MLXCacheManager() + class MLXQwen2Model(nn.Module): def __init__(self, config: AutoConfig, is_merge: bool = True): @@ -35,35 +31,30 @@ def __init__(self, config: AutoConfig, is_merge: bool = True): args.attention_bias = True # for qwen args.o_proj_bias = False # for qwen self.vocab_size = args.vocab_size - self.cache_manager = CacheManager() self.config = config self.model = Decoder(args, config.decoder_start_layer_idx, config.decoder_end_layer_idx, is_merge) self.num_layers = config.decoder_end_layer_idx - config.decoder_start_layer_idx - self.max_seq_len = getattr(self.model.layers[-1].self_attn, "max_seq_len", -1) - self.n_kv_heads = self.model.layers[-1].self_attn.n_kv_heads - self.head_dim = self.model.layers[-1].self_attn.head_dim - self.request_cache = RequestsCache(self.num_layers, self.max_seq_len, self.n_kv_heads, 
self.head_dim) + max_seq_len = getattr(self.model.layers[-1].self_attn, "max_seq_len", -1) + n_kv_heads = self.model.layers[-1].self_attn.n_kv_heads + head_dim = self.model.layers[-1].self_attn.head_dim + cache_manager.init_request_cache(self.num_layers, max_seq_len, n_kv_heads, head_dim) - self.is_start_pp = self.config.decoder_start_layer_idx == 0 - self.is_end_pp = self.config.decoder_end_layer_idx == self.config.num_hidden_layers + is_start_pp = self.config.decoder_start_layer_idx == 0 + is_end_pp = self.config.decoder_end_layer_idx == self.config.num_hidden_layers + cache_manager.post_init(is_start_pp, is_end_pp) def __call__(self, hidden_states: mx.array, seq_input: SeqInput) -> mx.array: - attention_data = build_forward_cache(seq_input, self.cache_manager, self.request_cache) - hit_cache_flag = any(x != -1 for x in attention_data.hit_cache_len_list) - hidden_states = truncate_hidden_states(hit_cache_flag, self.is_start_pp, attention_data, hidden_states) + hidden_states = cache_manager.build_forward_cache(hidden_states, seq_input) - mask = attention_data.attn_mask + mask = cache_manager.attn_data.attn_mask mask = mask if mask is None else mask.astype(hidden_states.dtype) - output = self.model(hidden_states, mask=mask, cache=attention_data) + output = self.model(hidden_states, mask=mask, cache=cache_manager.attn_data) - for uuid in seq_input.uuid_list: - self.cache_manager.set(uuid, attention_data.get_decoder_cache(uuid)) - self.cache_manager.check_alive() - self.request_cache.clear() - self.request_cache.insert_cache(seq_input) + # TODO 异步更新 cache + cache_manager.update_cache(seq_input) - output = get_last_hidden_states(hit_cache_flag, self.is_end_pp, attention_data, output) + output = cache_manager.get_last_hidden_states(output) return output @classmethod @@ -75,7 +66,7 @@ def from_pretrained(cls, config: AutoConfig, state_dict: Dict[str, mx.array], ** state_dict = model.sanitize(state_dict, config.decoder_start_layer_idx, config.decoder_end_layer_idx) model = quantization_func(config, model, state_dict) - model.load_weights(list(state_dict.items())) # strict=False + model.load_weights(list(state_dict.items())) mx.eval(model.parameters()) model.eval() diff --git a/tllm/models/tinygrad/llama.py b/tllm/models/tinygrad/llama.py index 29633ac..053b014 100644 --- a/tllm/models/tinygrad/llama.py +++ b/tllm/models/tinygrad/llama.py @@ -9,7 +9,8 @@ from tinygrad.helpers import getenv from tinygrad.nn.state import load_state_dict, safe_load, torch_load -from tllm.commons.cache import AttentionData, CacheManager, RequestsCache +from tllm.commons.cache import AttentionData, RequestsCache +from tllm.commons.cache_manager import CacheManager from tllm.schemas import SeqInput # Edited from https://github.com/tinygrad/tinygrad/blob/master/extra/models/llama.py @@ -211,11 +212,13 @@ def __call__(self, x: Tensor, freqs_cis: Tensor, attention_data: AttentionData): return (h + self.mlp(self.post_attention_layernorm(h))).contiguous() +cache_manager = CacheManager() + + class TinyGradLlamaModel: def __init__(self, config, is_merge: bool = True, jit: bool = True): self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - self.cache_manager = CacheManager() self.config = config self.model = Decoder(config, config.decoder_start_layer_idx, config.decoder_end_layer_idx, is_merge) self.num_decoder_layers = config.decoder_end_layer_idx - config.decoder_start_layer_idx @@ -277,28 +280,19 @@ def forward(self, hidden_states: Tensor, seq_input: SeqInput) -> Tensor: @return: bs x seq_len x 
hidden_size """ # Not support multi requests - attention_data = build_forward_cache(seq_input, self.cache_manager, self.num_decoder_layers) - - if attention_data.attn_mask is not None: - attention_data.attn_mask = attention_data.attn_mask.cast(hidden_states.dtype).to(hidden_states.device) - q_len_list, k_len_list = attention_data.position_ids + hidden_states = cache_manager.build_forward_cache(hidden_states, seq_input) freqs_cis = [] - for q_len, k_len in zip(q_len_list, k_len_list): + for q_len, k_len in zip(cache_manager.attn_data.q_len_list, cache_manager.attn_data.k_len_list): tmp_freqs_cis = self.freqs_cis.cast(hidden_states.dtype).realize() tmp_freqs_cis = tmp_freqs_cis.shrink(((k_len - q_len, k_len), None, None, None)) freqs_cis.append(tmp_freqs_cis) - hidden_states = self.model(hidden_states, freqs_cis=freqs_cis, attention_data=attention_data) - - if self.config.decoder_end_layer_idx == self.config.num_hidden_layers: - split_len_list = attention_data.q_len_list - hidden_states = get_last_hidden_states(hidden_states, split_len_list) + hidden_states = self.model(hidden_states, freqs_cis=freqs_cis, attention_data=cache_manager.attn_data) - for uuid in seq_input.uuid_list: - self.cache_manager.set(uuid, attention_data.get_decoder_cache(uuid)) - self.cache_manager.check_alive() + cache_manager.update_cache(seq_input) + output = cache_manager.get_last_hidden_states(output) return hidden_states def __call__(self, hidden_states, seq_input): diff --git a/tllm/models/torch/helper.py b/tllm/models/torch/helper.py index a52f468..ce2ae05 100644 --- a/tllm/models/torch/helper.py +++ b/tllm/models/torch/helper.py @@ -4,8 +4,10 @@ from safetensors.torch import load_file as safe_load import torch +from tllm import DEVICE, DTYPE from tllm.commons.attn import ATTN_TYPE -from tllm.commons.cache import AttentionData, CacheManager, RequestsCache +from tllm.commons.cache import AttentionData +from tllm.commons.cache_manager import CacheManager from tllm.schemas import SeqInput @@ -14,13 +16,7 @@ def greedy_decode(logits: "torch.Tensor") -> List[int]: return torch.argmax(logits, dim=-1).tolist() -def get_last_hidden_states(hidden_states: torch.Tensor, seq_len_list: List[int]) -> torch.Tensor: - # 只取最后一个 token 的 hidden_states - seq_hidden_states = torch.split(hidden_states, [seq_len for seq_len in seq_len_list], dim=0) - return torch.cat([x[-1:, :] for x in seq_hidden_states], dim=0) - - -def build_mask(q_len_list: List[int], k_len_list: List[int]) -> "torch.Tensor": +def build_mask(q_len_list: List[int], k_len_list: List[int]) -> torch.Tensor: """ 构造多个请求的 casual mask @param @@ -53,32 +49,53 @@ def read_from_safetensors(file_path: str) -> Dict[str, torch.Tensor]: from xformers.ops import fmha -def build_forward_cache( - seq_input: SeqInput, - cache_manager: CacheManager, - num_layers: int, - max_seq_len: int = -1, - num_key_value_heads: int = -1, - head_dim: int = -1, -) -> AttentionData: - request_cache = RequestsCache(num_layers, max_seq_len, num_key_value_heads, head_dim) - q_len_list, k_len_list, position_ids_list, _ = request_cache.build(seq_input, cache_manager) - - if ATTN_TYPE == "flash_attention": - attn_mask = { - "cu_seqlens_q": torch.tensor([0] + list(itertools.accumulate(q_len_list)), dtype=torch.int32), - "cu_seqlens_k": torch.tensor([0] + list(itertools.accumulate(k_len_list)), dtype=torch.int32), - "max_seqlen_q": max(q_len_list), - "max_seqlen_k": max(k_len_list), - } - # elif ATTN_TYPE == "xformers": - # attn_mask = fmha.BlockDiagonalMask.from_seqlens(q_seqlen=q_len_list, 
kv_seqlen=k_len_list) - else: - attn_mask = build_mask(q_len_list, k_len_list) - return AttentionData( - request_cache=request_cache, - attn_mask=attn_mask, - uuid_list=seq_input.uuid_list, - position_ids=torch.cat(position_ids_list, dim=-1), - q_len_list=q_len_list, - ) +class TorchCacheManager(CacheManager): + def build_forward_cache(self, hidden_states: torch.Tensor, seq_input: SeqInput) -> torch.Tensor: + # request_cache = RequestsCache(num_layers, max_seq_len, num_key_value_heads, head_dim) + q_len_list, k_len_list, position_ids_list, _ = self.request_cache.build(seq_input, self.cache) + + self.hit_cache_flag = None + + if ATTN_TYPE == "flash_attention": + attn_mask = { + "cu_seqlens_q": torch.tensor([0] + list(itertools.accumulate(q_len_list)), dtype=torch.int32), + "cu_seqlens_k": torch.tensor([0] + list(itertools.accumulate(k_len_list)), dtype=torch.int32), + "max_seqlen_q": max(q_len_list), + "max_seqlen_k": max(k_len_list), + } + # elif ATTN_TYPE == "xformers": + # attn_mask = fmha.BlockDiagonalMask.from_seqlens(q_seqlen=q_len_list, kv_seqlen=k_len_list) + else: + attn_mask = build_mask(q_len_list, k_len_list) + + self.q_len_list = q_len_list + self.hit_cache_len_list = None + + if ATTN_TYPE == "flash_attention": + attn_mask = {k: v.to(DEVICE) if isinstance(v, torch.Tensor) else v for k, v in attn_mask.items()} + else: + attn_mask = attn_mask.to(DEVICE) + + self.attn_data = AttentionData( + request_cache=self.request_cache, + attn_mask=attn_mask, + uuid_list=seq_input.uuid_list, + position_ids=torch.cat(position_ids_list, dim=-1), + ) + + hidden_states = hidden_states.to(DTYPE).to(DEVICE) + return hidden_states + + def get_last_hidden_states(self, hidden_states: torch.Tensor) -> torch.Tensor: + split_len_list = self.q_len_list + # if self.hit_cache_flag: + # q_start = 0 + # for i, (q_len, hit_cache_len) in enumerate(zip(self.q_len_list, self.hit_cache_len_list)): + # if hit_cache_len != -1: + # split_len_list[i] = q_len - hit_cache_len + # q_start += q_len + if self.is_end_pp: + # 只取最后一个 token 的 hidden_states + seq_hidden_states = torch.split(hidden_states, [seq_len for seq_len in split_len_list], dim=0) + hidden_states = torch.cat([x[-1:, :] for x in seq_hidden_states], dim=0) + return hidden_states diff --git a/tllm/models/torch/llama.py b/tllm/models/torch/llama.py index f868789..941b63d 100644 --- a/tllm/models/torch/llama.py +++ b/tllm/models/torch/llama.py @@ -6,13 +6,13 @@ from transformers.models.llama.modeling_llama import LlamaRMSNorm, LlamaRotaryEmbedding from tllm import DEVICE, DTYPE -from tllm.commons.attn import ATTN_TYPE -from tllm.commons.cache import CacheManager -from tllm.models.torch.helper import build_forward_cache, get_last_hidden_states +from tllm.models.torch.helper import TorchCacheManager from tllm.models.torch.layers import Decoder from tllm.models.weight_helper import default_merge_attn, default_merge_mlp, tie_word_embeddings_func from tllm.schemas import SeqInput +cache_manager = TorchCacheManager() + class HFLlamaRotaryEmbedding(LlamaRotaryEmbedding): def forward(self, x, position_ids): @@ -43,15 +43,19 @@ def __init__(self, config, is_merge: bool = True): super().__init__() self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - self.cache_manager = CacheManager() self.config = config self.model = Decoder(config, config.decoder_start_layer_idx, config.decoder_end_layer_idx, is_merge) self.num_decoder_layers = config.decoder_end_layer_idx - config.decoder_start_layer_idx self.rotary_emb = HFLlamaRotaryEmbedding(config=config) - 
self.max_seq_len = getattr(self.model.layers[-1].self_attn, "max_seq_len", -1) - self.num_key_value_heads = self.model.layers[-1].self_attn.num_key_value_heads - self.head_dim = self.model.layers[-1].self_attn.head_dim + max_seq_len = getattr(self.model.layers[-1].self_attn, "max_seq_len", -1) + num_key_value_heads = self.model.layers[-1].self_attn.num_key_value_heads + head_dim = self.model.layers[-1].self_attn.head_dim + cache_manager.init_request_cache(self.num_decoder_layers, max_seq_len, num_key_value_heads, head_dim) + + is_start_pp = self.config.decoder_start_layer_idx == 0 + is_end_pp = self.config.decoder_end_layer_idx == self.config.num_hidden_layers + cache_manager.post_init(is_start_pp, is_end_pp) @classmethod def from_pretrained(cls, config, state_dict: Dict[str, torch.Tensor], is_merge: bool = True, **kwargs): @@ -106,36 +110,21 @@ def forward(self, hidden_states: torch.Tensor, seq_input: SeqInput) -> torch.Ten @return: seq_len x hidden_size """ - attention_data = build_forward_cache( - seq_input, - self.cache_manager, - self.num_decoder_layers, - self.max_seq_len, - self.num_key_value_heads, - self.head_dim, + hidden_states = cache_manager.build_forward_cache(hidden_states, seq_input) + + position_embeddings = self.rotary_emb( + hidden_states, cache_manager.attn_data.position_ids.to(hidden_states.device) ) - hidden_states = hidden_states.to(DTYPE).to(DEVICE) - position_embeddings = self.rotary_emb(hidden_states, attention_data.position_ids.to(DEVICE)) - if ATTN_TYPE == "flash_attention": - attention_data.attn_mask = { - k: v.to(DEVICE) if isinstance(v, torch.Tensor) else v for k, v in attention_data.attn_mask.items() - } - else: - attention_data.attn_mask = attention_data.attn_mask.to(DEVICE) hidden_states = self.model( - hidden_states, position_embeddings=position_embeddings, attention_data=attention_data + hidden_states, position_embeddings=position_embeddings, attention_data=cache_manager.attn_data ) - if self.config.decoder_end_layer_idx == self.config.num_hidden_layers: - split_len_list = attention_data.q_len_list - hidden_states = get_last_hidden_states(hidden_states, split_len_list) - - for uuid in seq_input.uuid_list: - self.cache_manager.set(uuid, attention_data.get_decoder_cache(uuid)) - self.cache_manager.check_alive() + # TODO 异步更新 cache + cache_manager.update_cache(seq_input) - return hidden_states + output = cache_manager.get_last_hidden_states(hidden_states) + return output class HFLlamaForCausalLM(nn.Module): diff --git a/tllm/models/torch/qwen2.py b/tllm/models/torch/qwen2.py index 7f43b9c..28564a9 100644 --- a/tllm/models/torch/qwen2.py +++ b/tllm/models/torch/qwen2.py @@ -6,13 +6,13 @@ from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm, Qwen2RotaryEmbedding from tllm import DEVICE, DTYPE -from tllm.commons.attn import ATTN_TYPE -from tllm.commons.cache import CacheManager -from tllm.models.torch.helper import build_forward_cache, get_last_hidden_states +from tllm.models.torch.helper import TorchCacheManager from tllm.models.torch.llama import Decoder from tllm.models.weight_helper import default_merge_attn, default_merge_mlp, tie_word_embeddings_func from tllm.schemas import SeqInput +cache_manager = TorchCacheManager() + class HFQwen2RotaryEmbedding(Qwen2RotaryEmbedding): def forward(self, x, position_ids): @@ -43,7 +43,6 @@ def __init__(self, config, is_merge: bool = True): super().__init__() self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - self.cache_manager = CacheManager() # for qwen config.attention_bias = 
True @@ -54,9 +53,14 @@ def __init__(self, config, is_merge: bool = True): self.num_decoder_layers = config.decoder_end_layer_idx - config.decoder_start_layer_idx self.rotary_emb = HFQwen2RotaryEmbedding(config=config) - self.max_seq_len = getattr(self.model.layers[-1].self_attn, "max_seq_len", -1) - self.num_key_value_heads = self.model.layers[-1].self_attn.num_key_value_heads - self.head_dim = self.model.layers[-1].self_attn.head_dim + max_seq_len = getattr(self.model.layers[-1].self_attn, "max_seq_len", -1) + num_key_value_heads = self.model.layers[-1].self_attn.num_key_value_heads + head_dim = self.model.layers[-1].self_attn.head_dim + cache_manager.init_request_cache(self.num_decoder_layers, max_seq_len, num_key_value_heads, head_dim) + + is_start_pp = self.config.decoder_start_layer_idx == 0 + is_end_pp = self.config.decoder_end_layer_idx == self.config.num_hidden_layers + cache_manager.post_init(is_start_pp, is_end_pp) @classmethod def from_pretrained(cls, config, state_dict: Dict[str, torch.Tensor], is_merge: bool = True, **kwargs): @@ -114,36 +118,21 @@ def forward(self, hidden_states: torch.Tensor, seq_input: SeqInput) -> torch.Ten @return: seq_len x hidden_size """ - attention_data = build_forward_cache( - seq_input, - self.cache_manager, - self.num_decoder_layers, - self.max_seq_len, - self.num_key_value_heads, - self.head_dim, + hidden_states = cache_manager.build_forward_cache(hidden_states, seq_input) + + position_embeddings = self.rotary_emb( + hidden_states, cache_manager.attn_data.position_ids.to(hidden_states.device) ) - hidden_states = hidden_states.to(DTYPE).to(DEVICE) - position_embeddings = self.rotary_emb(hidden_states, attention_data.position_ids.to(DEVICE)) - if ATTN_TYPE == "flash_attention": - attention_data.attn_mask = { - k: v.to(DEVICE) if isinstance(v, torch.Tensor) else v for k, v in attention_data.attn_mask.items() - } - else: - attention_data.attn_mask = attention_data.attn_mask.to(DEVICE) hidden_states = self.model( - hidden_states, position_embeddings=position_embeddings, attention_data=attention_data + hidden_states, position_embeddings=position_embeddings, attention_data=cache_manager.attn_data ) - if self.config.decoder_end_layer_idx == self.config.num_hidden_layers: - split_len_list = attention_data.q_len_list - hidden_states = get_last_hidden_states(hidden_states, split_len_list) - - for uuid in seq_input.uuid_list: - self.cache_manager.set(uuid, attention_data.get_decoder_cache(uuid)) - self.cache_manager.check_alive() + # TODO 异步更新 cache + cache_manager.update_cache(seq_input) - return hidden_states + output = cache_manager.get_last_hidden_states(hidden_states) + return output class HFQwen2ForCausalLM(nn.Module):
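
Usage sketch for the new generic `Cache` added in tllm/commons/cache.py (illustrative only, not part of the diff; the `"req-1"` key and the list payload are made-up stand-ins for a request uuid and its `DecoderCache`):

    from tllm.commons.cache import Cache

    cache: Cache[list] = Cache(max_alive_time=60)  # entries expire 60s after their last access
    cache.set("req-1", [1, 2, 3])
    assert cache.contains("req-1")
    value = cache.get("req-1")   # also refreshes the entry's timestamp
    cache.check_alive()          # drops every entry idle for longer than max_alive_time
    cache.delete("req-1")        # pop(key, None) underneath, so deleting a missing key is safe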
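
Call order the backend models follow after this refactor, as wired in tllm/models/mlx/llama.py and tllm/models/torch/llama.py (a minimal sketch, not a drop-in implementation: `decoder`, `forward` and the shape arguments are hypothetical placeholders; everything called on `cache_manager` is the API from tllm/commons/cache_manager.py and TorchCacheManager above):

    from tllm.models.torch.helper import TorchCacheManager

    cache_manager = TorchCacheManager()  # module-level singleton, one per backend model file

    # one-time wiring, normally done in the model's __init__
    cache_manager.init_request_cache(num_layers=2, max_seq_len=-1, n_kv_heads=8, head_dim=64)
    cache_manager.post_init(is_start_pp=True, is_end_pp=True)

    def forward(hidden_states, seq_input, decoder):
        # 1. build per-request KV caches, attention mask and position ids;
        #    may cast/move hidden_states and stores the AttentionData on the manager
        hidden_states = cache_manager.build_forward_cache(hidden_states, seq_input)
        # 2. run the decoder layers with the prepared AttentionData
        hidden_states = decoder(hidden_states, attention_data=cache_manager.attn_data)
        # 3. persist each uuid's DecoderCache and expire idle entries
        cache_manager.update_cache(seq_input)
        # 4. on the last pipeline stage, keep only the final token of every request
        return cache_manager.get_last_hidden_states(hidden_states)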