From 5594ab27655a33ae8cf6a8cde9b968321345b862 Mon Sep 17 00:00:00 2001 From: lujianghu Date: Mon, 27 Jan 2025 18:27:39 +0800 Subject: [PATCH] add radix tree for token level prefill cache --- benchmarks/comm/main.py | 4 +- benchmarks/run_async_requests.py | 17 +- examples/run_engine.py | 32 +- scripts/rpc_compile.sh | 2 +- tests/test_radix_tree.py | 91 ++++++ tllm/__init__.py | 2 + tllm/commons/cache.py | 155 ++++++---- tllm/commons/radix_tree.py | 71 +++++ tllm/engine.py | 47 +-- tllm/entrypoints/server_chat.py | 26 +- tllm/generate/llm_generator.py | 16 +- tllm/generate/message_processor.py | 21 -- tllm/grpc/master_service/master_server.py | 2 +- tllm/grpc/master_service/pending_requests.py | 2 - tllm/grpc/master_service/worker_manager.py | 11 +- tllm/grpc/proto/schemas.proto | 13 +- tllm/grpc/proto/schemas_pb2.py | 65 ++--- tllm/grpc/proto/schemas_pb2.pyi | 36 ++- tllm/grpc/proto/schemas_pb2_grpc.py | 292 ++++++++++--------- tllm/grpc/worker_service/master_manager.py | 16 +- tllm/grpc/worker_service/worker_server.py | 4 +- tllm/models/mlx/flux/transformer.py | 8 +- tllm/models/mlx/helper.py | 25 +- tllm/models/mlx/layers.py | 2 +- tllm/models/mlx/llama.py | 27 +- tllm/models/mlx/qwen.py | 42 ++- tllm/models/tinygrad/llama.py | 12 +- tllm/models/torch/llama.py | 7 +- tllm/models/torch/qwen.py | 7 +- tllm/schemas.py | 27 +- tllm/utils.py | 4 +- 31 files changed, 649 insertions(+), 437 deletions(-) create mode 100644 tests/test_radix_tree.py create mode 100644 tllm/commons/radix_tree.py diff --git a/benchmarks/comm/main.py b/benchmarks/comm/main.py index 43add75..d5f6dca 100644 --- a/benchmarks/comm/main.py +++ b/benchmarks/comm/main.py @@ -79,7 +79,9 @@ def run_client_test(self, matrix_shape, num_iterations=3): with grpc.insecure_channel(f"{self.host}:{self.port}") as channel: stub = schemas_pb2_grpc.RPCServiceStub(channel) - request = schemas_pb2.ForwardRequest(uuid=["123"], seq_len=[1], hidden_states=compress_bytes(byte_tensor)) + request = schemas_pb2.ForwardRequest( + uuid_list=["123"], input_ids_list=[1], hidden_states=compress_bytes(byte_tensor) + ) for _ in range(num_iterations): start_time = time.time() diff --git a/benchmarks/run_async_requests.py b/benchmarks/run_async_requests.py index 659be1f..b2b5d23 100644 --- a/benchmarks/run_async_requests.py +++ b/benchmarks/run_async_requests.py @@ -39,6 +39,14 @@ def llm_message(): }, {"role": "user", "content": "今天天气怎么样?"}, ] + # messages2 = [ + # {"role": "user", "content": "Hello, how are you?"}, + # { + # "role": "assistant", + # "content": "Hello! I'm just a virtual assistant, so I don't have feelings, but I'm here and ready to help you with whatever you need. How are you doing? 
😊", + # }, + # {"role": "user", "content": "今天天气怎么样?"}, + # ] messages_list = [messages1, messages2, messages2] return messages_list @@ -68,15 +76,16 @@ def mllm_message(): async def main(messages_list: List[List[Dict[str, Any]]]): - # print("异步并发请求结果") - # s1 = time.time() - # await asyncio.gather(*[requests_func(messages) for messages in messages_list]) - # print(f"time cost: {time.time() - s1:.4f} s") + print("异步并发请求结果") + s1 = time.time() + await asyncio.gather(*[requests_func(messages) for messages in messages_list]) + print(f"time cost: {time.time() - s1:.4f} s") print("单独请求结果") s1 = time.time() for message in messages_list: await requests_func(message) + print("=" * 20) print(f"time cost: {time.time() - s1:.4f} s") diff --git a/examples/run_engine.py b/examples/run_engine.py index 8ab9175..9fea04b 100644 --- a/examples/run_engine.py +++ b/examples/run_engine.py @@ -104,13 +104,13 @@ async def llm_generate(args, messages): engine = init_engine(args.model_path) await engine.start() messages = [{"role": "user", "content": "Hello, how are you?"}] - messages = [ - { - "role": "system", - "content": "Given the following conversation, relevant context, and a follow up question, reply with an answer to the current question the user is asking. Return only your response to the question given the above information following the users instructions as needed.", - }, - {"role": "user", "content": "hello"}, - ] + # messages = [ + # { + # "role": "system", + # "content": "Given the following conversation, relevant context, and a follow up question, reply with an answer to the current question the user is asking. Return only your response to the question given the above information following the users instructions as needed.", + # }, + # {"role": "user", "content": "hello"}, + # ] openai_serving_chat = OpenAIServing(engine, args) # for _ in range(3): @@ -126,15 +126,15 @@ async def llm_generate(args, messages): }, {"role": "user", "content": "今天天气怎么样?"}, ] - messages = [ - { - "role": "system", - "content": "Given the following conversation, relevant context, and a follow up question, reply with an answer to the current question the user is asking. Return only your response to the question given the above information following the users instructions as needed.", - }, - {"role": "user", "content": "hello"}, - {"role": "assistant", "content": "Hello! How can I assist you today?"}, - {"role": "user", "content": "今年天气怎么样"}, - ] + # messages = [ + # { + # "role": "system", + # "content": "Given the following conversation, relevant context, and a follow up question, reply with an answer to the current question the user is asking. Return only your response to the question given the above information following the users instructions as needed.", + # }, + # {"role": "user", "content": "hello"}, + # {"role": "assistant", "content": "Hello! How can I assist you today?"}, + # {"role": "user", "content": "今年天气怎么样"}, + # ] for _ in range(3): request = ChatCompletionRequest(model="test", messages=messages, max_tokens=100) response = await openai_serving_chat.create_chat_completion(request, None) diff --git a/scripts/rpc_compile.sh b/scripts/rpc_compile.sh index 8ed04b8..2c9e46b 100644 --- a/scripts/rpc_compile.sh +++ b/scripts/rpc_compile.sh @@ -1 +1 @@ -python3 -m grpc_tools.protoc -I=. --python_out=./ --pyi_out=./ --grpc_python_out=./ tllm/entrypoints/grpc/proto/schemas.proto \ No newline at end of file +python3 -m grpc_tools.protoc -I=. 
--python_out=./ --pyi_out=./ --grpc_python_out=./ tllm/grpc/proto/schemas.proto \ No newline at end of file diff --git a/tests/test_radix_tree.py b/tests/test_radix_tree.py new file mode 100644 index 0000000..29cbc22 --- /dev/null +++ b/tests/test_radix_tree.py @@ -0,0 +1,91 @@ +from tllm.commons.radix_tree import RadixTree + +if __name__ == "__main__": + tree = RadixTree() + tree.append_to_request([151646, 151646, 151644, 9707, 11, 1246, 525, 498, 30, 151645], "123") + tree.append_to_request([151648], "123") + tree.append_to_request([271], "123") + tree.append_to_request([151649], "123") + tree.append_to_request([271], "123") + tree.append_to_request([9707], "123") + tree.append_to_request([0], "123") + tree.append_to_request([358], "123") + tree.append_to_request([2776], "123") + tree.append_to_request([1101], "123") + tree.append_to_request([264], "123") + tree.append_to_request([4108], "123") + tree.append_to_request([17847], "123") + tree.append_to_request([11], "123") + tree.append_to_request([773], "123") + + input_ids = [ + 151646, + 151646, + 151644, + 9707, + 11, + 1246, + 525, + 498, + 30, + 151645, + 9707, + 0, + 358, + 2776, + 1101, + 264, + 4108, + 17847, + 11, + 773, + 358, + 1513, + 944, + 614, + 15650, + 11, + 714, + 358, + 2776, + 1588, + 323, + 5527, + 311, + 1492, + 498, + 448, + 8820, + 498, + 1184, + 13, + 2585, + 525, + 498, + 3730, + 30, + 26525, + 232, + 151643, + 151644, + 100644, + 104307, + 104472, + 11319, + 151645, + ] + longest = tree.longest_common_prefix(input_ids) + print("longest common prefix:", longest) + print("hit input ids", input_ids[: longest[1]]) + + # longest = tree.longest_common_prefix([1, 2, 3, 4, 6, 7, 8, 9]) + # print("longest common prefix:", longest) + + # longest = tree.longest_common_prefix([1, 2, 3, 4, 6, 7, 8, 9]) + # print("longest common prefix:", longest) + + # longest = tree.longest_common_prefix([1, 2, 3]) + # print("longest common prefix:", longest) + tree.remove(tree.request_id_map["123"].path) + longest = tree.longest_common_prefix([1, 2, 3, 4]) + print("longest common prefix:", longest) diff --git a/tllm/__init__.py b/tllm/__init__.py index 65a9e3d..2bcdb0f 100644 --- a/tllm/__init__.py +++ b/tllm/__init__.py @@ -8,6 +8,8 @@ class BackendEnum(Enum): MLX = 2 +ENABLE_PREFIX_CACHE = os.environ.get("TLLM_ENABLE_PREFIX_CACHE", "true").lower() == "true" +ENABLE_PREFIX_CACHE = False if importlib.util.find_spec("mlx"): BACKEND = BackendEnum.MLX elif importlib.util.find_spec("torch"): diff --git a/tllm/commons/cache.py b/tllm/commons/cache.py index 23b0e1e..9f6a796 100644 --- a/tllm/commons/cache.py +++ b/tllm/commons/cache.py @@ -1,10 +1,11 @@ # coding: utf-8 import copy import time -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional -from tllm import BACKEND, DEVICE, DTYPE, BackendEnum -from tllm.schemas import MIX_TENSOR +from tllm import BACKEND, DEVICE, DTYPE, ENABLE_PREFIX_CACHE, BackendEnum +from tllm.commons.radix_tree import RadixTree +from tllm.schemas import MIX_TENSOR, SeqInput if BACKEND == BackendEnum.MLX: import mlx.core as mx @@ -21,8 +22,6 @@ array_func = lambda x: torch.tensor([x], dtype=torch.long) arange_func = lambda x: torch.arange(0, x, dtype=torch.long) -KV_CACHE_TYPE = Tuple[MIX_TENSOR, MIX_TENSOR] - class KVCache: def __init__( @@ -35,21 +34,41 @@ def __init__( else: self.key_states = zeros_func(max_seq_len, num_key_value_heads, head_dim) self.value_states = zeros_func(max_seq_len, num_key_value_heads, head_dim) - self._len = 0 + self.kv_len = 0 + + def 
set_kv_len(self, kv_len: int): + self.kv_len = kv_len - def set_kv_len(self, len_: int): - self._len = len_ + +class DecoderCache: + def __init__( + self, num_layers: int, q_len: int, max_seq_len: int = -1, num_key_value_heads: int = -1, head_dim: int = -1 + ): + self.kv_cache_list = [KVCache(max_seq_len, num_key_value_heads, head_dim) for _ in range(num_layers)] + self.q_len = q_len @property def kv_len(self): - return self._len + return self.kv_cache_list[0].kv_len + + def __getitem__(self, idx: int) -> KVCache: + return self.kv_cache_list[idx] + + def truncate(self, len_: int): + for kv_cache in self.kv_cache_list: + kv_cache.key_states = kv_cache.key_states[:len_] + kv_cache.value_states = kv_cache.value_states[:len_] + kv_cache.set_kv_len(len_) + + def set_q_len(self, q_len: int): + self.q_len = q_len class RequestsCache: def __init__( self, num_layers: int, max_seq_len: int = -1, num_key_value_heads: int = -1, head_dim: int = -1 ) -> None: - self.cache_dict: Dict[str : Dict[str, Union[List[KVCache], int]]] = {} + self.cache_dict: Dict[str, DecoderCache] = {} self.num_layers = num_layers self.max_seq_len, self.num_key_value_heads, self.head_dim = max_seq_len, num_key_value_heads, head_dim if self.max_seq_len == -1: @@ -58,70 +77,85 @@ def __init__( else: # cat attention to save time self.update = self.update_no_cat + self.radix_tree = RadixTree() - def add(self, uuid: str, seq_len: int, layer_cache_list: Optional[List[KVCache]] = None): + def clear(self): + self.cache_dict.clear() + + def insert_cache(self, seq_input: SeqInput): + if ENABLE_PREFIX_CACHE: + for input_ids, request_id in zip(seq_input.input_ids_list, seq_input.uuid_list): + self.radix_tree.append_to_request(input_ids, request_id) + + def add(self, uuid: str, q_len: int, decoder_cache: Optional[DecoderCache] = None): # 保存每个 uuid 请求所有层的 cache - self.cache_dict[uuid] = { - "cache": ( - [KVCache(self.max_seq_len, self.num_key_value_heads, self.head_dim) for _ in range(self.num_layers)] - if layer_cache_list is None - else layer_cache_list - ), - "seq_len": seq_len, - } - - def build(self, seq_input, cache_manager): + self.cache_dict[uuid] = ( + DecoderCache(self.num_layers, q_len, self.max_seq_len, self.num_key_value_heads, self.head_dim) + if decoder_cache is None + else decoder_cache + ) + + def build(self, seq_input: SeqInput, cache_manager: "CacheManager"): q_len_list, k_len_list = [], [] position_ids_list = [] - conv_len_list = [] + hit_cache_len_list = [] - for uuid, q_len in zip(seq_input.uuid_list, seq_input.seq_len_list): - conv_len = -1 + for uuid, input_ids in zip(seq_input.uuid_list, seq_input.input_ids_list): + hit_cache_len = -1 + q_len = len(input_ids) # decoding 阶段 if q_len == 1 and uuid in cache_manager.cache_dict: - layer_cache_list, cache_seq_len = cache_manager.get(uuid) + decoder_cache: DecoderCache = cache_manager.get(uuid) + cache_seq_len = decoder_cache.kv_len position_ids = array_func(cache_seq_len) k_len_list.append(cache_seq_len + q_len) # prefilling 阶段 else: - # 如果是历史对话,则使用历史的 kv_cache - chat_uuid, chat_len = uuid.rsplit("-", 1) - if uuid.count("-") == 2 and chat_uuid in cache_manager.cache_dict: - layer_cache_list, cache_seq_len = cache_manager.get(chat_uuid) - layer_cache_list = copy.deepcopy(layer_cache_list) + # 命中了之前的 kv cache,使用历史 cache + if ENABLE_PREFIX_CACHE: + hit_uuid, hit_cache_len = self.radix_tree.longest_common_prefix(input_ids) + else: + hit_uuid, hit_cache_len = None, -1 # 不启用 prefix cache + if hit_uuid is not None and cache_manager.get(hit_uuid) is not None: + hid_decoder_cache: 
DecoderCache = copy.deepcopy(cache_manager.get(hit_uuid)) + # 相同请求时,避免过超过 cache 长度 + if q_len <= hit_cache_len: + hit_cache_len = q_len - 1 + + hid_decoder_cache.truncate(hit_cache_len) + hid_decoder_cache.set_q_len(q_len - hit_cache_len) + decoder_cache = hid_decoder_cache position_ids = arange_func(q_len) k_len_list.append(q_len) - conv_len = q_len - cache_seq_len - # 首次出现过的 uuid,第一次 conversation + # 未命中任何 kv cache,新建 cache else: - layer_cache_list = None + decoder_cache = None position_ids = arange_func(q_len) k_len_list.append(q_len) q_len_list.append(q_len) position_ids_list.append(position_ids) - conv_len_list.append(conv_len) + hit_cache_len_list.append(hit_cache_len) - self.add(uuid, q_len, layer_cache_list) - return q_len_list, k_len_list, position_ids_list, conv_len_list + self.add(uuid, q_len, decoder_cache) + return q_len_list, k_len_list, position_ids_list, hit_cache_len_list - def get_kv_cache(self, uuid: str) -> List[KVCache]: - return self.cache_dict[uuid]["cache"] + def get_kv_cache(self, uuid: str) -> DecoderCache: + return self.cache_dict[uuid] def get_layer_idx_kv_cache(self, uuid: str, layer_idx: int) -> KVCache: return self.get_kv_cache(uuid)[layer_idx] - def get_seq_len(self, uuid: str) -> int: - # 获取每个 uuid 请求的 key_states/value_states 的 seq_len - return self.cache_dict[uuid]["seq_len"] + def get_q_len(self, uuid: str) -> int: + # 获取每个 uuid 请求的 q_len + return self.get_kv_cache(uuid).q_len - def get_cache_seq_len(self, uuid: str, layer_idx: Optional[int] = 0) -> int: - # 获取每个 uuid 请求的 kv cache 的 seq_len - x = self.get_kv_cache(uuid)[layer_idx].kv_len - return x + def get_kv_len(self, uuid: str, layer_idx: Optional[int] = 0) -> int: + # 获取每个 uuid 请求的 kv cache 的 kv_len + return self.get_layer_idx_kv_cache(uuid, layer_idx).kv_len def get_offset_list(self, uuid_list: List[str], layer_idx: int) -> List[int]: # 获取每个 uuid 请求的 offset,用于 mlx framework 旋转位置编码 - return [self.get_cache_seq_len(uuid, layer_idx) for uuid in uuid_list] + return [self.get_kv_len(uuid, layer_idx) for uuid in uuid_list] def update_cat( self, @@ -136,7 +170,7 @@ def update_cat( start = 0 for uuid in uuid_list: kv_cache: KVCache = self.get_layer_idx_kv_cache(uuid, layer_idx) - interval = self.get_seq_len(uuid) + interval = self.get_q_len(uuid) end = start + interval cur_key_states, cur_value_states = key_states[start:end], value_states[start:end] if kv_cache.key_states is None: @@ -166,7 +200,7 @@ def update_no_cat( for uuid in uuid_list: kv_cache: KVCache = self.get_layer_idx_kv_cache(uuid, layer_idx) - end = start + self.get_seq_len(uuid) # 获取每个请求对应的区间 + end = start + self.get_q_len(uuid) # 获取每个请求对应的区间 cur_key_states, cur_value_states = key_states[start:end], value_states[start:end] if kv_cache.kv_len == 0: @@ -199,7 +233,7 @@ def update_tinygrad(self, key_states, value_states, uuid_list, layer_idx): for uuid in uuid_list: kv_cache = self.get_layer_idx_kv_cache(uuid, layer_idx) - interval = self.get_seq_len(uuid) + interval = self.get_q_len(uuid) end = start + interval cur_key, cur_value = key_states[start:end], value_states[start:end] @@ -225,31 +259,40 @@ def __init__( request_cache: RequestsCache, attn_mask: MIX_TENSOR, position_ids=None, + hit_cache_len_list=None, + q_len_list=None, ) -> None: self.uuid_list = uuid_list self.request_cache = request_cache self.attn_mask = attn_mask self.position_ids = position_ids # 只在 torch 下有意义 + self.hit_cache_len_list = hit_cache_len_list # 用于 PP=0 截断 hidden_states + self.q_len_list = q_len_list def get_kv_cache_list(self, uuid: str) -> List[KVCache]: return 
self.request_cache.get_kv_cache(uuid) - def get_cache_seq_len(self, uuid: str) -> int: - return self.request_cache.get_cache_seq_len(uuid) + def get_kv_len(self, uuid: str) -> int: + return self.request_cache.get_kv_len(uuid) class CacheManager: - # 管理每个节点的 cache kv_cache + # 管理每个节点的所有层 kv_cache # max_alive_time: 超过多久没有访问就删除,单位秒 def __init__(self, max_alive_time=60): self.max_alive_time = max_alive_time self.cache_dict = {} - def get(self, key) -> Tuple[AttentionData, int]: - return self.cache_dict.get(key)["cache"], self.cache_dict.get(key)["seq_len"] + def get(self, key) -> Optional[DecoderCache]: + if self.is_contain(key): + return self.cache_dict.get(key)["cache"] + return None + + def set(self, key, value: DecoderCache) -> None: + self.cache_dict[key] = {"cache": value, "ts": time.time()} - def set(self, key, value: List[KV_CACHE_TYPE], seq_len: int) -> None: - self.cache_dict[key] = {"cache": value, "ts": time.time(), "seq_len": seq_len} + def is_contain(self, key) -> bool: + return key in self.cache_dict def delete(self, key): self.cache_dict.pop(key) diff --git a/tllm/commons/radix_tree.py b/tllm/commons/radix_tree.py new file mode 100644 index 0000000..d930432 --- /dev/null +++ b/tllm/commons/radix_tree.py @@ -0,0 +1,71 @@ +from typing import Dict, List, Optional, Tuple + + +class Node: + def __init__(self, request_id: str): + self.children: Dict[int, Node] = {} + self.is_end = False + self.path = None + self.request_id = request_id + + def __repr__(self): + return f"Node({self.request_id}): path={self.path}; is_end={self.is_end}" + + +class RadixTree: + def __init__(self): + self.root = Node(None) # 根节点 + self.request_id_map: Dict[str, Node] = {} + + def insert(self, input_ids: List[int], request_id: str): + node = self.root + path = [] + for id_ in input_ids: + if id_ not in node.children: + node.children[id_] = Node(request_id) + node = node.children[id_] + path.append(id_) + node.path = path[:] + node.is_end = True + self.request_id_map[request_id] = node + + def append_to_request(self, input_ids: List[int], request_id: str): + if request_id not in self.request_id_map: + self.insert(input_ids, request_id) + return + node = self.request_id_map.pop(request_id) + path = node.path + node.is_end = False + for id_ in input_ids: + if id_ not in node.children: + node.children[id_] = Node(request_id) + node = node.children[id_] + path.append(id_) + node.path = path[:] + node.is_end = True + self.request_id_map[request_id] = node + + def longest_common_prefix(self, input_ids: List[int]) -> Tuple[Optional[str], int]: + # 返回最长的公共前缀 + node = self.root + longest = [] + for id_ in input_ids: + if id_ not in node.children: + return node.request_id, len(longest) - 1 if len(longest) > 0 else -1 + node = node.children[id_] + if node.path is not None and len(node.path) > len(longest): + longest = node.path[:] + return node.request_id, len(longest) - 1 if len(longest) > 0 else -1 + + def remove(self, input_ids: List[int]): + # 删除节点 + node = self.root + for id_ in input_ids: + if id_ not in node.children: + return + node = node.children[id_] + node.is_end = False + if len(node.children) == 0: + del self.request_id_map[node.request_id] + return + node.request_id = None diff --git a/tllm/engine.py b/tllm/engine.py index 39f62ba..245897e 100644 --- a/tllm/engine.py +++ b/tllm/engine.py @@ -7,49 +7,6 @@ from tllm.schemas import SequenceRequestData from tllm.singleton_logger import SingletonLogger -conversations_dict = {} # List[int] -> str, TODO LRU 缓存 - - -class Node: - def __init__(self): - 
self.children = {} # int -> Node - self.is_end_of_word = False # 是否是单词的结束 - self.path = None - - -class RadixTree: - def __init__(self): - self.root = Node() # 根节点 - - def insert(self, input_ids: List[int]): - node = self.root - path = [] - for id_ in input_ids: - if id_ not in node.children: - node.children[id_] = Node() - node = node.children[id_] - path.append(id_) - node.path = path[:] - node.is_end_of_word = True - - def longest_common_prefix(self, input_ids: List[int]) -> List[int]: - node = self.root - longest = [] - for id_ in input_ids: - if id_ not in node.children: - return longest - node = node.children[id_] - if node.path is not None and len(node.path) > len(longest): - longest = node.path[:] - return longest - - -def post_process(data: SequenceRequestData): - # 保存输入 + 输出 - token_ids = data.input_ids + data.output_ids - conversations_dict[token_ids] = data.history_request_id if data.history_request_id else data.request_id - return - class AsyncEngine: def __init__(self, generator: Union[LLMGenerator, ImageGenerator], sleep_time: float = 0.0, limit_size: int = 5): @@ -143,8 +100,8 @@ async def generate_stream(self, request_data: SequenceRequestData): async with request_data.condition: while not request_data.is_stop: await asyncio.wait_for(request_data.condition.wait(), request_data.timeout) - yield request_data.to_request_output() # 流式返回数据的内容,可以控制 - # post_process(request_data) + # 流式返回数据的内容 + yield request_data.to_request_output() try: if hasattr(request_data, "ttft_cost_time"): self.logger.info( diff --git a/tllm/entrypoints/server_chat.py b/tllm/entrypoints/server_chat.py index 0d919ab..1c0ec67 100644 --- a/tllm/entrypoints/server_chat.py +++ b/tllm/entrypoints/server_chat.py @@ -48,12 +48,7 @@ async def show_available_models(self): async def create_chat_completion(self, request: ChatCompletionRequest, raw_request: Request): request_id = f"chat-{random_uuid()}" messages, mm_input_dict = await self.message_processor.parse_message(request.messages) - # print("messages", messages) input_ids = self.message_processor.preprocess(messages) - history_request_id, q_len = self.message_processor.fetch_request_id(messages) - - if history_request_id is not None: - request_id = history_request_id if request.temperature == 0.0: method = "greedy" @@ -62,8 +57,6 @@ async def create_chat_completion(self, request: ChatCompletionRequest, raw_reque sequence_data = SequenceRequestData( request_id=request_id, - history_request_id=history_request_id, - q_len=q_len, sampling_params=request.to_sampling_params(self.engine.tok.tokenizer), input_ids=input_ids, multi_modal_inputs=mm_input_dict, @@ -71,11 +64,9 @@ async def create_chat_completion(self, request: ChatCompletionRequest, raw_reque result_generator = self.engine.generate_stream(sequence_data) if request.stream: - return self.chat_completion_stream_generator(request, raw_request, request_id, result_generator, messages) + return self.chat_completion_stream_generator(request, raw_request, request_id, result_generator) else: - return await self.chat_completion_full_generator( - request, raw_request, request_id, result_generator, messages - ) + return await self.chat_completion_full_generator(request, raw_request, request_id, result_generator) async def chat_completion_stream_generator( self, @@ -83,14 +74,11 @@ async def chat_completion_stream_generator( raw_request: Request, request_id: str, result_generator: AsyncIterator, - messages, ) -> AsyncIterator[str]: created_time = int(time.time()) n = 1 previous_texts = [""] * n try: - response_text = 
"" - num_generated_tokens = 0 async for res in result_generator: if raw_request is not None and await raw_request.is_disconnected(): # Abort the request if the client disconnects. @@ -102,8 +90,6 @@ async def chat_completion_stream_generator( delta_text = output.text previous_texts[i] = output.text - num_prompt_tokens = len(res.prompt_token_ids) - num_generated_tokens += 1 # 根据 finish_reason 判断是否结束,分别处理 if output.finish_reason is not None: @@ -117,16 +103,11 @@ async def chat_completion_stream_generator( logprobs=None, finish_reason=output.finish_reason, ) - response_text += delta_text chunk = ChatCompletionStreamResponse( id=request_id, model=self.model_name, created=created_time, choices=[choice_data] ) data = chunk.model_dump_json(exclude_unset=True) yield f"data: {data}\n\n" - - messages.append({"role": self.response_role, "content": response_text}) - total_tokens = num_prompt_tokens + num_generated_tokens - self.message_processor.update_conversations_dict(request_id, messages, total_tokens - 2) yield "data: [DONE]\n\n" except asyncio.CancelledError: await self.engine.abort(request_id) @@ -138,7 +119,6 @@ async def chat_completion_full_generator( raw_request: Request, request_id: str, result_generator: AsyncIterator, - messages, ) -> ChatCompletionResponse: final_res = None created_time = int(time.time()) @@ -182,6 +162,4 @@ async def chat_completion_full_generator( usage=usage, ) - messages.append({"role": self.response_role, "content": output.text}) - self.message_processor.update_conversations_dict(request_id, messages, total_tokens - 1) return response diff --git a/tllm/generate/llm_generator.py b/tllm/generate/llm_generator.py index 30c7052..4eecd6f 100644 --- a/tllm/generate/llm_generator.py +++ b/tllm/generate/llm_generator.py @@ -116,18 +116,8 @@ async def generate(self, request_list: List[SequenceRequestData]): if self.processor is not None: mm_input_list.append(process_mm_input(sequence_request, self.processor, **self.mm_config)) - # 已经存在的 message,需要更新使用之前的 request_id - if sequence_request.history_request_id is not None: - uuid_list[-1] = sequence_request.history_request_id - new_q_len = len(sequence_request.input_ids) - sequence_request.q_len - # print("new_q_len", len(sequence_request.input_ids[-new_q_len:])) - input_ids_list.append(np.array(sequence_request.input_ids[-new_q_len:])) - # print("q_len", len(sequence_request.input_ids)) - # input_ids_list.append(np.array(sequence_request.input_ids)) - seq_len_list.append(len(sequence_request.input_ids)) - else: - input_ids_list.append(np.array(sequence_request.input_ids)) - seq_len_list.append(len(sequence_request.input_ids)) + input_ids_list.append(np.array(sequence_request.input_ids)) + seq_len_list.append(len(sequence_request.input_ids)) else: input_ids_list.append(np.array([sequence_request.output_ids[-1]])) seq_len_list.append(1) @@ -140,7 +130,7 @@ async def generate(self, request_list: List[SequenceRequestData]): else: input_embeds = self.model.get_input_embeddings(input_ids, **mm_input) - seq_input = SeqInput(uuid_list=uuid_list, seq_len_list=seq_len_list) + seq_input = SeqInput(uuid_list=uuid_list, input_ids_list=input_ids_list) s0 = time.perf_counter() forward_result = await self.forward(input_embeds, seq_input) self.logger.debug(f"decoder cost time: {time.perf_counter() - s0:.4f}s") diff --git a/tllm/generate/message_processor.py b/tllm/generate/message_processor.py index 90b3a6a..cf3e443 100644 --- a/tllm/generate/message_processor.py +++ b/tllm/generate/message_processor.py @@ -14,7 +14,6 @@ class 
MessageProcessor: def __init__(self, tok: TokenizerUtils): self.tok = tok self.role_set = {"user", "system", "assistant"} - self.conversations_dict: Dict[MESSAGES, str] = {} async def read_image(self, image: UrlItem) -> ImageFile: if image.base64 is not None: @@ -65,23 +64,3 @@ async def parse_message(self, messages: MESSAGES) -> Tuple[List[Dict[str, str]], def preprocess(self, messages: List[Dict[str, str]]) -> List[int]: input_ids = self.tok.preprocess(messages=messages).input_ids return input_ids - - def fetch_request_id(self, messages: List[MESSAGES]) -> Optional[str]: - messages = copy.deepcopy(messages) - while len(messages) > 0 and messages[-1]["role"] == "user": - messages.pop() - if len(messages) == 0: - return None, -1 - # 查询历史 messages 是否在 conversations_dict 中 - key_str = tuple(self.tok.preprocess(messages=messages, add_generation_prompt=False).input_ids) - if key_str in self.conversations_dict: - request_id, total_token_num = self.conversations_dict[key_str] - # TODO: maybe have the bug - if len(key_str) <= total_token_num: - return None, -1 - return request_id + f"-{len(messages)}", total_token_num - return None, -1 - - def update_conversations_dict(self, request_id: str, messages: List[MESSAGES], total_token_num: int) -> None: - key_str = tuple(self.tok.preprocess(messages=messages, add_generation_prompt=False).input_ids) - self.conversations_dict[key_str] = (request_id, total_token_num) diff --git a/tllm/grpc/master_service/master_server.py b/tllm/grpc/master_service/master_server.py index f8eb570..d55ea0c 100644 --- a/tllm/grpc/master_service/master_server.py +++ b/tllm/grpc/master_service/master_server.py @@ -36,7 +36,7 @@ async def Forward( self, request: schemas_pb2.ForwardRequest, context: grpc.ServicerContext ) -> schemas_pb2.ForwardResponse: """处理从最后一个节点返回的结果""" - request_id = "-".join(x for x in list(request.uuid)) + request_id = "-".join(x for x in list(request.uuid_list)) try: self.pending_requests.complete_forward_request(request_id, request.hidden_states) diff --git a/tllm/grpc/master_service/pending_requests.py b/tllm/grpc/master_service/pending_requests.py index 31776b0..49b1a00 100644 --- a/tllm/grpc/master_service/pending_requests.py +++ b/tllm/grpc/master_service/pending_requests.py @@ -1,4 +1,3 @@ - import asyncio from typing import Any, Dict, Tuple @@ -60,4 +59,3 @@ def fail_status_request(self, trace_id: str, error: Exception): if not tracker.future.done(): tracker.future.set_exception(error) del self._status_requests[trace_id] - diff --git a/tllm/grpc/master_service/worker_manager.py b/tllm/grpc/master_service/worker_manager.py index 869bdf5..3e8f807 100644 --- a/tllm/grpc/master_service/worker_manager.py +++ b/tllm/grpc/master_service/worker_manager.py @@ -32,8 +32,9 @@ async def rpc_image_forward( stub.ImageForward(schemas_pb2.ImageForwardRequest(**forward_request)) -async def rpc_forward(stub, uuid, seq_len, hidden_states: schemas_pb2.BFloat16Tensor): - forward_request = {"uuid": uuid, "seq_len": seq_len, "hidden_states": hidden_states} +async def rpc_forward(stub, seq_input: SeqInput, hidden_states: schemas_pb2.BFloat16Tensor): + forward_request = seq_input.to_dict() + forward_request["hidden_states"] = hidden_states stub.Forward(schemas_pb2.ForwardRequest(**forward_request)) @@ -69,7 +70,7 @@ async def send_config(self, master_url: str, host_list: List[List[str]]): assert len(host_list) == self.client_size async def set_single_config(i: int) -> None: - url = master_url if i == self.client_size - 1 else host_list[i + 1][0] # TODO: 对于多个 PP,这里有 bug + 
url = master_url if i == self.client_size - 1 else host_list[i + 1][0] # TODO: 对于多个 PP,这里有 bug for stub in self.stub_list[i]: await rpc_set_config(stub, {"forward_url": url, "master_url": master_url, "pp_rank": i}) @@ -79,7 +80,7 @@ async def set_single_config(i: int) -> None: async def health_check(self) -> Tuple[int]: async def check_single_client(index: int) -> Tuple[int, bool]: try: - await rpc_health_check(self.stub_list[index][0]) # 只需要检查一个 + await rpc_health_check(self.stub_list[index][0]) # 只需要检查一个 return (index, True) except Exception as e: return (index, False) @@ -101,7 +102,7 @@ async def forward(self, hidden_states: MIX_TENSOR, seq_input: SeqInput) -> Tuple "-".join(x for x in seq_input.uuid_list), self.client_size ) for stub in self.stub_list[0]: - asyncio.create_task(rpc_forward(stub, seq_input.uuid_list, seq_input.seq_len_list, hidden_states)) + asyncio.create_task(rpc_forward(stub, seq_input, hidden_states)) await asyncio.sleep(0) try: output = await asyncio.wait_for(forward_future, timeout=PP_TIMEOUT) diff --git a/tllm/grpc/proto/schemas.proto b/tllm/grpc/proto/schemas.proto index b8a23fb..4268707 100644 --- a/tllm/grpc/proto/schemas.proto +++ b/tllm/grpc/proto/schemas.proto @@ -8,17 +8,20 @@ message BFloat16Tensor { repeated int32 shape = 2; // 形状信息 } +message InputIds { + repeated int32 input_ids = 1; +} + message ForwardRequest { - repeated string uuid = 1; - repeated int32 seq_len = 2; + repeated string uuid_list = 1; + repeated InputIds input_ids_list = 2; BFloat16Tensor hidden_states = 3; } message StatusRequest { repeated string uuid = 1; - repeated int32 seq_len = 2; - int32 pp_idx = 3; - float cost_time = 4; + int32 pp_idx = 2; + float cost_time = 3; } message StatusResponse { diff --git a/tllm/grpc/proto/schemas_pb2.py b/tllm/grpc/proto/schemas_pb2.py index eff3e06..e169235 100644 --- a/tllm/grpc/proto/schemas_pb2.py +++ b/tllm/grpc/proto/schemas_pb2.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! 
# NO CHECKED-IN PROTOBUF GENCODE -# source: tllm/entrypoints/grpc/proto/schemas.proto +# source: tllm/grpc/proto/schemas.proto # Protobuf Python Version: 5.28.1 """Generated protocol buffer code.""" from google.protobuf import ( @@ -13,47 +13,44 @@ from google.protobuf.internal import builder as _builder _runtime_version.ValidateProtobufRuntimeVersion( - _runtime_version.Domain.PUBLIC, - 5, - 28, - 1, - '', - 'tllm/entrypoints/grpc/proto/schemas.proto' + _runtime_version.Domain.PUBLIC, 5, 28, 1, "", "tllm/grpc/proto/schemas.proto" ) # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n)tllm/entrypoints/grpc/proto/schemas.proto\x12\x07schemas\"-\n\x0e\x42\x46loat16Tensor\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\"_\n\x0e\x46orwardRequest\x12\x0c\n\x04uuid\x18\x01 \x03(\t\x12\x0f\n\x07seq_len\x18\x02 \x03(\x05\x12.\n\rhidden_states\x18\x03 \x01(\x0b\x32\x17.schemas.BFloat16Tensor\"Q\n\rStatusRequest\x12\x0c\n\x04uuid\x18\x01 \x03(\t\x12\x0f\n\x07seq_len\x18\x02 \x03(\x05\x12\x0e\n\x06pp_idx\x18\x03 \x01(\x05\x12\x11\n\tcost_time\x18\x04 \x01(\x02\"-\n\x0eStatusResponse\x12\x0b\n\x03msg\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\x05\".\n\x0f\x46orwardResponse\x12\x0b\n\x03msg\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\x05\"-\n\x0eHealthResponse\x12\x0b\n\x03msg\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\x05\"\x07\n\x05\x45mpty\"L\n\x10SetConfigRequest\x12\x13\n\x0b\x66orward_url\x18\x01 \x01(\t\x12\x12\n\nmaster_url\x18\x02 \x01(\t\x12\x0f\n\x07pp_rank\x18\x03 \x01(\x05\"0\n\x11SetConfigResponse\x12\x0b\n\x03msg\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\x05\"\xed\x01\n\x13ImageForwardRequest\x12\x0c\n\x04uuid\x18\x01 \x03(\t\x12.\n\rhidden_states\x18\x02 \x01(\x0b\x32\x17.schemas.BFloat16Tensor\x12\x36\n\x15\x65ncoder_hidden_states\x18\x03 \x01(\x0b\x32\x17.schemas.BFloat16Tensor\x12\x30\n\x0ftext_embeddings\x18\x04 \x01(\x0b\x32\x17.schemas.BFloat16Tensor\x12\x0f\n\x07seq_len\x18\x05 \x01(\x05\x12\x0e\n\x06height\x18\x06 \x01(\x05\x12\r\n\x05width\x18\x07 \x01(\x05\x32\xc4\x02\n\nRPCService\x12\x39\n\x06Status\x12\x16.schemas.StatusRequest\x1a\x17.schemas.StatusResponse\x12<\n\x07\x46orward\x12\x17.schemas.ForwardRequest\x1a\x18.schemas.ForwardResponse\x12\x31\n\x06Health\x12\x0e.schemas.Empty\x1a\x17.schemas.HealthResponse\x12\x42\n\tSetConfig\x12\x19.schemas.SetConfigRequest\x1a\x1a.schemas.SetConfigResponse\x12\x46\n\x0cImageForward\x12\x1c.schemas.ImageForwardRequest\x1a\x18.schemas.ForwardResponseb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x1dtllm/grpc/proto/schemas.proto\x12\x07schemas"-\n\x0e\x42\x46loat16Tensor\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05"\x1d\n\x08InputIds\x12\x11\n\tinput_ids\x18\x01 \x03(\x05"~\n\x0e\x46orwardRequest\x12\x11\n\tuuid_list\x18\x01 \x03(\t\x12)\n\x0einput_ids_list\x18\x02 \x03(\x0b\x32\x11.schemas.InputIds\x12.\n\rhidden_states\x18\x03 \x01(\x0b\x32\x17.schemas.BFloat16Tensor"@\n\rStatusRequest\x12\x0c\n\x04uuid\x18\x01 \x03(\t\x12\x0e\n\x06pp_idx\x18\x02 \x01(\x05\x12\x11\n\tcost_time\x18\x03 \x01(\x02"-\n\x0eStatusResponse\x12\x0b\n\x03msg\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\x05".\n\x0f\x46orwardResponse\x12\x0b\n\x03msg\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\x05"-\n\x0eHealthResponse\x12\x0b\n\x03msg\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 
\x01(\x05"\x07\n\x05\x45mpty"L\n\x10SetConfigRequest\x12\x13\n\x0b\x66orward_url\x18\x01 \x01(\t\x12\x12\n\nmaster_url\x18\x02 \x01(\t\x12\x0f\n\x07pp_rank\x18\x03 \x01(\x05"0\n\x11SetConfigResponse\x12\x0b\n\x03msg\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\x05"\xed\x01\n\x13ImageForwardRequest\x12\x0c\n\x04uuid\x18\x01 \x03(\t\x12.\n\rhidden_states\x18\x02 \x01(\x0b\x32\x17.schemas.BFloat16Tensor\x12\x36\n\x15\x65ncoder_hidden_states\x18\x03 \x01(\x0b\x32\x17.schemas.BFloat16Tensor\x12\x30\n\x0ftext_embeddings\x18\x04 \x01(\x0b\x32\x17.schemas.BFloat16Tensor\x12\x0f\n\x07seq_len\x18\x05 \x01(\x05\x12\x0e\n\x06height\x18\x06 \x01(\x05\x12\r\n\x05width\x18\x07 \x01(\x05\x32\xc4\x02\n\nRPCService\x12\x39\n\x06Status\x12\x16.schemas.StatusRequest\x1a\x17.schemas.StatusResponse\x12<\n\x07\x46orward\x12\x17.schemas.ForwardRequest\x1a\x18.schemas.ForwardResponse\x12\x31\n\x06Health\x12\x0e.schemas.Empty\x1a\x17.schemas.HealthResponse\x12\x42\n\tSetConfig\x12\x19.schemas.SetConfigRequest\x1a\x1a.schemas.SetConfigResponse\x12\x46\n\x0cImageForward\x12\x1c.schemas.ImageForwardRequest\x1a\x18.schemas.ForwardResponseb\x06proto3' +) _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tllm.entrypoints.grpc.proto.schemas_pb2', _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "tllm.grpc.proto.schemas_pb2", _globals) if not _descriptor._USE_C_DESCRIPTORS: - DESCRIPTOR._loaded_options = None - _globals['_BFLOAT16TENSOR']._serialized_start=54 - _globals['_BFLOAT16TENSOR']._serialized_end=99 - _globals['_FORWARDREQUEST']._serialized_start=101 - _globals['_FORWARDREQUEST']._serialized_end=196 - _globals['_STATUSREQUEST']._serialized_start=198 - _globals['_STATUSREQUEST']._serialized_end=279 - _globals['_STATUSRESPONSE']._serialized_start=281 - _globals['_STATUSRESPONSE']._serialized_end=326 - _globals['_FORWARDRESPONSE']._serialized_start=328 - _globals['_FORWARDRESPONSE']._serialized_end=374 - _globals['_HEALTHRESPONSE']._serialized_start=376 - _globals['_HEALTHRESPONSE']._serialized_end=421 - _globals['_EMPTY']._serialized_start=423 - _globals['_EMPTY']._serialized_end=430 - _globals['_SETCONFIGREQUEST']._serialized_start=432 - _globals['_SETCONFIGREQUEST']._serialized_end=508 - _globals['_SETCONFIGRESPONSE']._serialized_start=510 - _globals['_SETCONFIGRESPONSE']._serialized_end=558 - _globals['_IMAGEFORWARDREQUEST']._serialized_start=561 - _globals['_IMAGEFORWARDREQUEST']._serialized_end=798 - _globals['_RPCSERVICE']._serialized_start=801 - _globals['_RPCSERVICE']._serialized_end=1125 + DESCRIPTOR._loaded_options = None + _globals["_BFLOAT16TENSOR"]._serialized_start = 42 + _globals["_BFLOAT16TENSOR"]._serialized_end = 87 + _globals["_INPUTIDS"]._serialized_start = 89 + _globals["_INPUTIDS"]._serialized_end = 118 + _globals["_FORWARDREQUEST"]._serialized_start = 120 + _globals["_FORWARDREQUEST"]._serialized_end = 246 + _globals["_STATUSREQUEST"]._serialized_start = 248 + _globals["_STATUSREQUEST"]._serialized_end = 312 + _globals["_STATUSRESPONSE"]._serialized_start = 314 + _globals["_STATUSRESPONSE"]._serialized_end = 359 + _globals["_FORWARDRESPONSE"]._serialized_start = 361 + _globals["_FORWARDRESPONSE"]._serialized_end = 407 + _globals["_HEALTHRESPONSE"]._serialized_start = 409 + _globals["_HEALTHRESPONSE"]._serialized_end = 454 + _globals["_EMPTY"]._serialized_start = 456 + _globals["_EMPTY"]._serialized_end = 463 + _globals["_SETCONFIGREQUEST"]._serialized_start = 465 + 
_globals["_SETCONFIGREQUEST"]._serialized_end = 541 + _globals["_SETCONFIGRESPONSE"]._serialized_start = 543 + _globals["_SETCONFIGRESPONSE"]._serialized_end = 591 + _globals["_IMAGEFORWARDREQUEST"]._serialized_start = 594 + _globals["_IMAGEFORWARDREQUEST"]._serialized_end = 831 + _globals["_RPCSERVICE"]._serialized_start = 834 + _globals["_RPCSERVICE"]._serialized_end = 1158 # @@protoc_insertion_point(module_scope) diff --git a/tllm/grpc/proto/schemas_pb2.pyi b/tllm/grpc/proto/schemas_pb2.pyi index 778afe8..bc45984 100644 --- a/tllm/grpc/proto/schemas_pb2.pyi +++ b/tllm/grpc/proto/schemas_pb2.pyi @@ -1,13 +1,7 @@ -from typing import ( - ClassVar as _ClassVar, - Iterable as _Iterable, - Mapping as _Mapping, - Optional as _Optional, - Union as _Union, -) - -from google.protobuf import descriptor as _descriptor, message as _message from google.protobuf.internal import containers as _containers +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union DESCRIPTOR: _descriptor.FileDescriptor @@ -19,27 +13,31 @@ class BFloat16Tensor(_message.Message): shape: _containers.RepeatedScalarFieldContainer[int] def __init__(self, data: _Optional[bytes] = ..., shape: _Optional[_Iterable[int]] = ...) -> None: ... +class InputIds(_message.Message): + __slots__ = ("input_ids",) + INPUT_IDS_FIELD_NUMBER: _ClassVar[int] + input_ids: _containers.RepeatedScalarFieldContainer[int] + def __init__(self, input_ids: _Optional[_Iterable[int]] = ...) -> None: ... + class ForwardRequest(_message.Message): - __slots__ = ("uuid", "seq_len", "hidden_states") - UUID_FIELD_NUMBER: _ClassVar[int] - SEQ_LEN_FIELD_NUMBER: _ClassVar[int] + __slots__ = ("uuid_list", "input_ids_list", "hidden_states") + UUID_LIST_FIELD_NUMBER: _ClassVar[int] + INPUT_IDS_LIST_FIELD_NUMBER: _ClassVar[int] HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int] - uuid: _containers.RepeatedScalarFieldContainer[str] - seq_len: _containers.RepeatedScalarFieldContainer[int] + uuid_list: _containers.RepeatedScalarFieldContainer[str] + input_ids_list: _containers.RepeatedCompositeFieldContainer[InputIds] hidden_states: BFloat16Tensor - def __init__(self, uuid: _Optional[_Iterable[str]] = ..., seq_len: _Optional[_Iterable[int]] = ..., hidden_states: _Optional[_Union[BFloat16Tensor, _Mapping]] = ...) -> None: ... + def __init__(self, uuid_list: _Optional[_Iterable[str]] = ..., input_ids_list: _Optional[_Iterable[_Union[InputIds, _Mapping]]] = ..., hidden_states: _Optional[_Union[BFloat16Tensor, _Mapping]] = ...) -> None: ... class StatusRequest(_message.Message): - __slots__ = ("uuid", "seq_len", "pp_idx", "cost_time") + __slots__ = ("uuid", "pp_idx", "cost_time") UUID_FIELD_NUMBER: _ClassVar[int] - SEQ_LEN_FIELD_NUMBER: _ClassVar[int] PP_IDX_FIELD_NUMBER: _ClassVar[int] COST_TIME_FIELD_NUMBER: _ClassVar[int] uuid: _containers.RepeatedScalarFieldContainer[str] - seq_len: _containers.RepeatedScalarFieldContainer[int] pp_idx: int cost_time: float - def __init__(self, uuid: _Optional[_Iterable[str]] = ..., seq_len: _Optional[_Iterable[int]] = ..., pp_idx: _Optional[int] = ..., cost_time: _Optional[float] = ...) -> None: ... + def __init__(self, uuid: _Optional[_Iterable[str]] = ..., pp_idx: _Optional[int] = ..., cost_time: _Optional[float] = ...) -> None: ... 
class StatusResponse(_message.Message): __slots__ = ("msg", "status") diff --git a/tllm/grpc/proto/schemas_pb2_grpc.py b/tllm/grpc/proto/schemas_pb2_grpc.py index 2878f2b..8325a36 100644 --- a/tllm/grpc/proto/schemas_pb2_grpc.py +++ b/tllm/grpc/proto/schemas_pb2_grpc.py @@ -4,25 +4,26 @@ import grpc -from tllm.grpc.proto import schemas_pb2 as tllm_dot_entrypoints_dot_grpc_dot_proto_dot_schemas__pb2 +from tllm.grpc.proto import schemas_pb2 as tllm_dot_grpc_dot_proto_dot_schemas__pb2 -GRPC_GENERATED_VERSION = '1.68.1' +GRPC_GENERATED_VERSION = "1.68.1" GRPC_VERSION = grpc.__version__ _version_not_supported = False try: from grpc._utilities import first_version_is_lower + _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) except ImportError: _version_not_supported = True if _version_not_supported: raise RuntimeError( - f'The grpc package installed is at version {GRPC_VERSION},' - + f' but the generated code in tllm/entrypoints/grpc/proto/schemas_pb2_grpc.py depends on' - + f' grpcio>={GRPC_GENERATED_VERSION}.' - + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' - + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' + f"The grpc package installed is at version {GRPC_VERSION}," + + f" but the generated code in tllm/grpc/proto/schemas_pb2_grpc.py depends on" + + f" grpcio>={GRPC_GENERATED_VERSION}." + + f" Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}" + + f" or downgrade your generated code using grpcio-tools<={GRPC_VERSION}." ) @@ -36,30 +37,35 @@ def __init__(self, channel): channel: A grpc.Channel. """ self.Status = channel.unary_unary( - '/schemas.RPCService/Status', - request_serializer=tllm_dot_entrypoints_dot_grpc_dot_proto_dot_schemas__pb2.StatusRequest.SerializeToString, - response_deserializer=tllm_dot_entrypoints_dot_grpc_dot_proto_dot_schemas__pb2.StatusResponse.FromString, - _registered_method=True) + "/schemas.RPCService/Status", + request_serializer=tllm_dot_grpc_dot_proto_dot_schemas__pb2.StatusRequest.SerializeToString, + response_deserializer=tllm_dot_grpc_dot_proto_dot_schemas__pb2.StatusResponse.FromString, + _registered_method=True, + ) self.Forward = channel.unary_unary( - '/schemas.RPCService/Forward', - request_serializer=tllm_dot_entrypoints_dot_grpc_dot_proto_dot_schemas__pb2.ForwardRequest.SerializeToString, - response_deserializer=tllm_dot_entrypoints_dot_grpc_dot_proto_dot_schemas__pb2.ForwardResponse.FromString, - _registered_method=True) + "/schemas.RPCService/Forward", + request_serializer=tllm_dot_grpc_dot_proto_dot_schemas__pb2.ForwardRequest.SerializeToString, + response_deserializer=tllm_dot_grpc_dot_proto_dot_schemas__pb2.ForwardResponse.FromString, + _registered_method=True, + ) self.Health = channel.unary_unary( - '/schemas.RPCService/Health', - request_serializer=tllm_dot_entrypoints_dot_grpc_dot_proto_dot_schemas__pb2.Empty.SerializeToString, - response_deserializer=tllm_dot_entrypoints_dot_grpc_dot_proto_dot_schemas__pb2.HealthResponse.FromString, - _registered_method=True) + "/schemas.RPCService/Health", + request_serializer=tllm_dot_grpc_dot_proto_dot_schemas__pb2.Empty.SerializeToString, + response_deserializer=tllm_dot_grpc_dot_proto_dot_schemas__pb2.HealthResponse.FromString, + _registered_method=True, + ) self.SetConfig = channel.unary_unary( - '/schemas.RPCService/SetConfig', - request_serializer=tllm_dot_entrypoints_dot_grpc_dot_proto_dot_schemas__pb2.SetConfigRequest.SerializeToString, - 
response_deserializer=tllm_dot_entrypoints_dot_grpc_dot_proto_dot_schemas__pb2.SetConfigResponse.FromString, - _registered_method=True) + "/schemas.RPCService/SetConfig", + request_serializer=tllm_dot_grpc_dot_proto_dot_schemas__pb2.SetConfigRequest.SerializeToString, + response_deserializer=tllm_dot_grpc_dot_proto_dot_schemas__pb2.SetConfigResponse.FromString, + _registered_method=True, + ) self.ImageForward = channel.unary_unary( - '/schemas.RPCService/ImageForward', - request_serializer=tllm_dot_entrypoints_dot_grpc_dot_proto_dot_schemas__pb2.ImageForwardRequest.SerializeToString, - response_deserializer=tllm_dot_entrypoints_dot_grpc_dot_proto_dot_schemas__pb2.ForwardResponse.FromString, - _registered_method=True) + "/schemas.RPCService/ImageForward", + request_serializer=tllm_dot_grpc_dot_proto_dot_schemas__pb2.ImageForwardRequest.SerializeToString, + response_deserializer=tllm_dot_grpc_dot_proto_dot_schemas__pb2.ForwardResponse.FromString, + _registered_method=True, + ) class RPCServiceServicer(object): @@ -68,89 +74,90 @@ class RPCServiceServicer(object): def Status(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def Forward(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def Health(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def SetConfig(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def ImageForward(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def add_RPCServiceServicer_to_server(servicer, server): rpc_method_handlers = { - 'Status': grpc.unary_unary_rpc_method_handler( - servicer.Status, - request_deserializer=tllm_dot_entrypoints_dot_grpc_dot_proto_dot_schemas__pb2.StatusRequest.FromString, - response_serializer=tllm_dot_entrypoints_dot_grpc_dot_proto_dot_schemas__pb2.StatusResponse.SerializeToString, - ), - 'Forward': grpc.unary_unary_rpc_method_handler( - servicer.Forward, - request_deserializer=tllm_dot_entrypoints_dot_grpc_dot_proto_dot_schemas__pb2.ForwardRequest.FromString, - response_serializer=tllm_dot_entrypoints_dot_grpc_dot_proto_dot_schemas__pb2.ForwardResponse.SerializeToString, - ), - 'Health': 
grpc.unary_unary_rpc_method_handler( - servicer.Health, - request_deserializer=tllm_dot_entrypoints_dot_grpc_dot_proto_dot_schemas__pb2.Empty.FromString, - response_serializer=tllm_dot_entrypoints_dot_grpc_dot_proto_dot_schemas__pb2.HealthResponse.SerializeToString, - ), - 'SetConfig': grpc.unary_unary_rpc_method_handler( - servicer.SetConfig, - request_deserializer=tllm_dot_entrypoints_dot_grpc_dot_proto_dot_schemas__pb2.SetConfigRequest.FromString, - response_serializer=tllm_dot_entrypoints_dot_grpc_dot_proto_dot_schemas__pb2.SetConfigResponse.SerializeToString, - ), - 'ImageForward': grpc.unary_unary_rpc_method_handler( - servicer.ImageForward, - request_deserializer=tllm_dot_entrypoints_dot_grpc_dot_proto_dot_schemas__pb2.ImageForwardRequest.FromString, - response_serializer=tllm_dot_entrypoints_dot_grpc_dot_proto_dot_schemas__pb2.ForwardResponse.SerializeToString, - ), + "Status": grpc.unary_unary_rpc_method_handler( + servicer.Status, + request_deserializer=tllm_dot_grpc_dot_proto_dot_schemas__pb2.StatusRequest.FromString, + response_serializer=tllm_dot_grpc_dot_proto_dot_schemas__pb2.StatusResponse.SerializeToString, + ), + "Forward": grpc.unary_unary_rpc_method_handler( + servicer.Forward, + request_deserializer=tllm_dot_grpc_dot_proto_dot_schemas__pb2.ForwardRequest.FromString, + response_serializer=tllm_dot_grpc_dot_proto_dot_schemas__pb2.ForwardResponse.SerializeToString, + ), + "Health": grpc.unary_unary_rpc_method_handler( + servicer.Health, + request_deserializer=tllm_dot_grpc_dot_proto_dot_schemas__pb2.Empty.FromString, + response_serializer=tllm_dot_grpc_dot_proto_dot_schemas__pb2.HealthResponse.SerializeToString, + ), + "SetConfig": grpc.unary_unary_rpc_method_handler( + servicer.SetConfig, + request_deserializer=tllm_dot_grpc_dot_proto_dot_schemas__pb2.SetConfigRequest.FromString, + response_serializer=tllm_dot_grpc_dot_proto_dot_schemas__pb2.SetConfigResponse.SerializeToString, + ), + "ImageForward": grpc.unary_unary_rpc_method_handler( + servicer.ImageForward, + request_deserializer=tllm_dot_grpc_dot_proto_dot_schemas__pb2.ImageForwardRequest.FromString, + response_serializer=tllm_dot_grpc_dot_proto_dot_schemas__pb2.ForwardResponse.SerializeToString, + ), } - generic_handler = grpc.method_handlers_generic_handler( - 'schemas.RPCService', rpc_method_handlers) + generic_handler = grpc.method_handlers_generic_handler("schemas.RPCService", rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) - server.add_registered_method_handlers('schemas.RPCService', rpc_method_handlers) + server.add_registered_method_handlers("schemas.RPCService", rpc_method_handlers) - # This class is part of an EXPERIMENTAL API. +# This class is part of an EXPERIMENTAL API. 
class RPCService(object): """Missing associated documentation comment in .proto file.""" @staticmethod - def Status(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + def Status( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): return grpc.experimental.unary_unary( request, target, - '/schemas.RPCService/Status', - tllm_dot_entrypoints_dot_grpc_dot_proto_dot_schemas__pb2.StatusRequest.SerializeToString, - tllm_dot_entrypoints_dot_grpc_dot_proto_dot_schemas__pb2.StatusResponse.FromString, + "/schemas.RPCService/Status", + tllm_dot_grpc_dot_proto_dot_schemas__pb2.StatusRequest.SerializeToString, + tllm_dot_grpc_dot_proto_dot_schemas__pb2.StatusResponse.FromString, options, channel_credentials, insecure, @@ -159,25 +166,28 @@ def Status(request, wait_for_ready, timeout, metadata, - _registered_method=True) + _registered_method=True, + ) @staticmethod - def Forward(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + def Forward( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): return grpc.experimental.unary_unary( request, target, - '/schemas.RPCService/Forward', - tllm_dot_entrypoints_dot_grpc_dot_proto_dot_schemas__pb2.ForwardRequest.SerializeToString, - tllm_dot_entrypoints_dot_grpc_dot_proto_dot_schemas__pb2.ForwardResponse.FromString, + "/schemas.RPCService/Forward", + tllm_dot_grpc_dot_proto_dot_schemas__pb2.ForwardRequest.SerializeToString, + tllm_dot_grpc_dot_proto_dot_schemas__pb2.ForwardResponse.FromString, options, channel_credentials, insecure, @@ -186,25 +196,28 @@ def Forward(request, wait_for_ready, timeout, metadata, - _registered_method=True) + _registered_method=True, + ) @staticmethod - def Health(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + def Health( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): return grpc.experimental.unary_unary( request, target, - '/schemas.RPCService/Health', - tllm_dot_entrypoints_dot_grpc_dot_proto_dot_schemas__pb2.Empty.SerializeToString, - tllm_dot_entrypoints_dot_grpc_dot_proto_dot_schemas__pb2.HealthResponse.FromString, + "/schemas.RPCService/Health", + tllm_dot_grpc_dot_proto_dot_schemas__pb2.Empty.SerializeToString, + tllm_dot_grpc_dot_proto_dot_schemas__pb2.HealthResponse.FromString, options, channel_credentials, insecure, @@ -213,25 +226,28 @@ def Health(request, wait_for_ready, timeout, metadata, - _registered_method=True) + _registered_method=True, + ) @staticmethod - def SetConfig(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + def SetConfig( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + 
timeout=None, + metadata=None, + ): return grpc.experimental.unary_unary( request, target, - '/schemas.RPCService/SetConfig', - tllm_dot_entrypoints_dot_grpc_dot_proto_dot_schemas__pb2.SetConfigRequest.SerializeToString, - tllm_dot_entrypoints_dot_grpc_dot_proto_dot_schemas__pb2.SetConfigResponse.FromString, + "/schemas.RPCService/SetConfig", + tllm_dot_grpc_dot_proto_dot_schemas__pb2.SetConfigRequest.SerializeToString, + tllm_dot_grpc_dot_proto_dot_schemas__pb2.SetConfigResponse.FromString, options, channel_credentials, insecure, @@ -240,25 +256,28 @@ def SetConfig(request, wait_for_ready, timeout, metadata, - _registered_method=True) + _registered_method=True, + ) @staticmethod - def ImageForward(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + def ImageForward( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): return grpc.experimental.unary_unary( request, target, - '/schemas.RPCService/ImageForward', - tllm_dot_entrypoints_dot_grpc_dot_proto_dot_schemas__pb2.ImageForwardRequest.SerializeToString, - tllm_dot_entrypoints_dot_grpc_dot_proto_dot_schemas__pb2.ForwardResponse.FromString, + "/schemas.RPCService/ImageForward", + tllm_dot_grpc_dot_proto_dot_schemas__pb2.ImageForwardRequest.SerializeToString, + tllm_dot_grpc_dot_proto_dot_schemas__pb2.ForwardResponse.FromString, options, channel_credentials, insecure, @@ -267,4 +286,5 @@ def ImageForward(request, wait_for_ready, timeout, metadata, - _registered_method=True) + _registered_method=True, + ) diff --git a/tllm/grpc/worker_service/master_manager.py b/tllm/grpc/worker_service/master_manager.py index a5cf023..594b0fa 100644 --- a/tllm/grpc/worker_service/master_manager.py +++ b/tllm/grpc/worker_service/master_manager.py @@ -22,9 +22,19 @@ def update_url(self, master_url: str, forward_url: str, pp_idx: int): self.forward_stub = schemas_pb2_grpc.RPCServiceStub(forward_channel) self.pp_idx = pp_idx - async def rpc_func(self, uuid, seq_len, hidden_states: schemas_pb2.BFloat16Tensor, cost_time: float): - forward_request = {"uuid": uuid, "seq_len": seq_len, "hidden_states": hidden_states} - status_request = {"uuid": uuid, "seq_len": seq_len, "pp_idx": self.pp_idx, "cost_time": cost_time} + async def rpc_func( + self, request: schemas_pb2.ForwardRequest, hidden_states: schemas_pb2.BFloat16Tensor, cost_time: float + ): + forward_request = { + "uuid_list": request.uuid_list, + "hidden_states": hidden_states, + "input_ids_list": request.input_ids_list, + } + status_request = { + "uuid": request.uuid_list, + "pp_idx": self.pp_idx, + "cost_time": cost_time, + } self.master_stub.Status(schemas_pb2.StatusRequest(**status_request)) self.forward_stub.Forward(schemas_pb2.ForwardRequest(**forward_request)) diff --git a/tllm/grpc/worker_service/worker_server.py b/tllm/grpc/worker_service/worker_server.py index 94bccd5..11b34c4 100644 --- a/tllm/grpc/worker_service/worker_server.py +++ b/tllm/grpc/worker_service/worker_server.py @@ -80,7 +80,7 @@ async def forward_func(self, request: schemas_pb2.ForwardRequest): convertor = Convertor() hidden_states = convertor.deserialize(request.hidden_states) - seq_input = SeqInput(uuid_list=list(request.uuid), seq_len_list=list(request.seq_len)) + seq_input = SeqInput.from_request_data(request) self.comm.debug_rank0(f"deserialize_tensor cost time: 
{time.perf_counter() - s1:.4f}")
@@ -95,7 +95,7 @@ async def forward_func(self, request: schemas_pb2.ForwardRequest):
         self.comm.debug_rank0("=" * 20)
 
         if self.comm.is_rank0():
-            await self.master_rpc_manager.rpc_func(request.uuid, request.seq_len, output, cost_time)
+            await self.master_rpc_manager.rpc_func(request, output, cost_time)
 
     async def SetConfig(
         self, request: schemas_pb2.SetConfigRequest, context: grpc.ServicerContext
diff --git a/tllm/models/mlx/flux/transformer.py b/tllm/models/mlx/flux/transformer.py
index 0a5086f..0077172 100644
--- a/tllm/models/mlx/flux/transformer.py
+++ b/tllm/models/mlx/flux/transformer.py
@@ -74,7 +74,7 @@ def __init__(self):
         super().__init__()
         self.num_hidden_layers = 38
         self.transformer = SingleTransformer(0, self.num_hidden_layers, self.num_hidden_layers)
-        self.cache_dict = CacheManager()
+        self.cache_manager = CacheManager()
 
     @classmethod
     def from_pretrained(cls, config, state_dict, **kwargs):
@@ -115,11 +115,11 @@ def __call__(
     ) -> mx.array:
         request_id = request_id_list[0]
 
-        if request_id in self.cache_dict.cache_dict:
-            image_rotary_emb, _ = self.cache_dict.get(request_id)
+        if self.cache_manager.is_contain(request_id):
+            image_rotary_emb = self.cache_manager.get(request_id)
         else:
             image_rotary_emb = self.get_image_rotary_emb(height, width, seq_len)
-            self.cache_dict.set(request_id, image_rotary_emb, -1)
-        self.cache_dict.check_alive()
+            self.cache_manager.set(request_id, image_rotary_emb)
+        self.cache_manager.check_alive()
 
         hidden_states = self.transformer(hidden_states, text_embeddings, image_rotary_emb, seq_len)
diff --git a/tllm/models/mlx/helper.py b/tllm/models/mlx/helper.py
index 0a61b8c..6ce9d21 100644
--- a/tllm/models/mlx/helper.py
+++ b/tllm/models/mlx/helper.py
@@ -16,17 +16,17 @@ def greedy_decode(logits: mx.array) -> List[int]:
     return out.tolist()
 
 
 # TODO: first requests is too slow
-def build_mlx_mask(q_len_list: List[int], k_len_list: List[int], conv_len_list: List[int]) -> mx.array:
+def build_mlx_mask(q_len_list: List[int], k_len_list: List[int], hit_cache_len_list: List[int]) -> mx.array:
     mask_list = []
     sum_q_len = sum(q_len_list)
     sum_k_len = sum(k_len_list)
-    for q_len, k_len, conv_len in zip(q_len_list, k_len_list, conv_len_list):
+    for q_len, k_len, hit_cache_len in zip(q_len_list, k_len_list, hit_cache_len_list):
         # prefilling
         if q_len > 1:
             mask = mx.tril(mx.ones((q_len, k_len), dtype=mx.bool_), k=0)
-            if conv_len != -1:
-                sum_q_len -= q_len - conv_len
-                mask = mask[-conv_len:]
+            if hit_cache_len != -1:
+                sum_q_len -= hit_cache_len
+                mask = mask[-(q_len - hit_cache_len) :]
         else:
             mask = mx.ones((q_len, k_len), dtype=mx.bool_)
         mask_list.append(mask)
@@ -44,22 +44,17 @@ def build_mlx_mask(q_len_list: List[int], k_len_list: List[int], conv_len_list:
 
 
 def build_forward_cache(
-    seq_input: SeqInput,
-    cache_manager: CacheManager,
-    num_layers: int,
-    max_seq_len: int = -1,
-    num_key_value_heads: int = -1,
-    head_dim: int = -1,
+    seq_input: SeqInput, cache_manager: CacheManager, request_cache: RequestsCache
 ) -> AttentionData:
-    # TODO: 不需要每次 forward 都初始化
-    request_cache = RequestsCache(num_layers, max_seq_len, num_key_value_heads, head_dim)
-    q_len_list, k_len_list, position_ids_list, conv_len_list = request_cache.build(seq_input, cache_manager)
+    q_len_list, k_len_list, position_ids_list, hit_cache_len_list = request_cache.build(seq_input, cache_manager)
     return AttentionData(
         request_cache=request_cache,
-        attn_mask=build_mlx_mask(q_len_list, k_len_list, conv_len_list),
+        attn_mask=build_mlx_mask(q_len_list, k_len_list, hit_cache_len_list),
         uuid_list=seq_input.uuid_list,
        position_ids=mx.concatenate(position_ids_list, axis=-1),
+        hit_cache_len_list=hit_cache_len_list,
+        q_len_list=q_len_list,
     )
diff --git a/tllm/models/mlx/layers.py b/tllm/models/mlx/layers.py
index 0f296dc..bd2e872 100644
--- a/tllm/models/mlx/layers.py
+++ b/tllm/models/mlx/layers.py
@@ -155,7 +155,7 @@ def _rope(self, xs: mx.array, request_cache: RequestsCache, uuid_list: List[str]
         x_list = []
         start = 0
         for uuid, offset in zip(uuid_list, offset_list):
-            end = start + request_cache.get_seq_len(uuid)
+            end = start + request_cache.get_q_len(uuid)
             x_list.append(self.rope(xs[start:end].transpose(1, 0, 2), offset).transpose(1, 0, 2))
             start = end
         return cat_func(x_list)
diff --git a/tllm/models/mlx/llama.py b/tllm/models/mlx/llama.py
index 67cc9e5..19c47ac 100644
--- a/tllm/models/mlx/llama.py
+++ b/tllm/models/mlx/llama.py
@@ -7,7 +7,7 @@
 from transformers import AutoConfig
 
 from tllm import DTYPE
-from tllm.commons.cache import CacheManager
+from tllm.commons.cache import CacheManager, RequestsCache
 from tllm.models.mlx.helper import build_forward_cache, get_last_hidden_states, quantization_func
 from tllm.models.mlx.layers import Decoder
 from tllm.models.weight_helper import default_merge_attn, default_merge_mlp, tensor_parallel_state_dict
@@ -60,11 +60,21 @@ def __init__(self, config: AutoConfig, is_merge: bool = True):
         self.max_seq_len = getattr(self.model.layers[-1].self_attn, "max_seq_len", -1)
         self.n_kv_heads = self.model.layers[-1].self_attn.n_kv_heads
         self.head_dim = self.model.layers[-1].self_attn.head_dim
+        self.request_cache = RequestsCache(self.num_layers, self.max_seq_len, self.n_kv_heads, self.head_dim)
 
     def __call__(self, hidden_states: mx.array, seq_input: SeqInput) -> mx.array:
-        attention_data = build_forward_cache(
-            seq_input, self.cache_manager, self.num_layers, self.max_seq_len, self.n_kv_heads, self.head_dim
-        )
+        attention_data = build_forward_cache(seq_input, self.cache_manager, self.request_cache)
+        # truncate hidden_states for tokens already covered by the prefix cache
+        if self.config.decoder_start_layer_idx == 0 and any(x != -1 for x in attention_data.hit_cache_len_list):
+            hidden_states_list = []
+            q_start = 0
+            for q_len, hit_cache_len in zip(attention_data.q_len_list, attention_data.hit_cache_len_list):
+                if hit_cache_len != -1:
+                    hidden_states_list.append(hidden_states[q_start : q_start + q_len][hit_cache_len:])
+                else:
+                    hidden_states_list.append(hidden_states[q_start : q_start + q_len])
+                q_start += q_len
+            hidden_states = mx.concat(hidden_states_list, axis=0)
 
         # cos, sin = self.rotary_emb(attention_data.position_ids)
         # attention_data.cos, attention_data.sin = mx.expand_dims(cos, axis=1), mx.expand_dims(sin, axis=1)
@@ -75,12 +85,15 @@ def __call__(self, hidden_states: mx.array, seq_input: SeqInput) -> mx.array:
         output = self.model(hidden_states, mask=mask, cache=attention_data)
 
         # TODO 异步保存 cache
-        for uuid, seq_len in zip(seq_input.uuid_list, seq_input.seq_len_list):
-            self.cache_manager.set(uuid, attention_data.get_kv_cache_list(uuid), attention_data.get_cache_seq_len(uuid))
+        for uuid in seq_input.uuid_list:
+            self.cache_manager.set(uuid, attention_data.get_kv_cache_list(uuid))
         self.cache_manager.check_alive()
+        self.request_cache.clear()
+        self.request_cache.insert_cache(seq_input)
 
         if self.config.decoder_end_layer_idx == self.config.num_hidden_layers:
-            output = get_last_hidden_states(output, seq_input.seq_len_list)
+            split_len_list = attention_data.q_len_list
+            output = get_last_hidden_states(output, split_len_list)
         return output
 
     @classmethod
diff --git a/tllm/models/mlx/qwen.py b/tllm/models/mlx/qwen.py
index a1fa65a..6500e96 100644
--- a/tllm/models/mlx/qwen.py
+++ b/tllm/models/mlx/qwen.py
@@ -6,8 +6,8 @@
 import numpy as np
 from transformers import AutoConfig
 
-from tllm import DTYPE
-from tllm.commons.cache import CacheManager
+from tllm import DTYPE, ENABLE_PREFIX_CACHE
+from tllm.commons.cache import CacheManager, RequestsCache
 from tllm.models.mlx.helper import build_forward_cache, get_last_hidden_states, quantization_func
 from tllm.models.mlx.layers import Decoder
 from tllm.models.weight_helper import default_merge_attn, default_merge_mlp
@@ -38,22 +38,48 @@ def __init__(self, config: AutoConfig, is_merge: bool = True):
         self.max_seq_len = getattr(self.model.layers[-1].self_attn, "max_seq_len", -1)
         self.n_kv_heads = self.model.layers[-1].self_attn.n_kv_heads
         self.head_dim = self.model.layers[-1].self_attn.head_dim
+        self.request_cache = RequestsCache(self.num_layers, self.max_seq_len, self.n_kv_heads, self.head_dim)
 
     def __call__(self, hidden_states: mx.array, seq_input: SeqInput) -> mx.array:
-        attention_data = build_forward_cache(
-            seq_input, self.cache_manager, self.num_layers, self.max_seq_len, self.n_kv_heads, self.head_dim
-        )
+        attention_data = build_forward_cache(seq_input, self.cache_manager, self.request_cache)
+        # truncate hidden_states for tokens already covered by the prefix cache
+        if (
+            ENABLE_PREFIX_CACHE
+            and self.config.decoder_start_layer_idx == 0
+            and any(x != -1 for x in attention_data.hit_cache_len_list)
+        ):
+            hidden_states_list = []
+            q_start = 0
+            for q_len, hit_cache_len in zip(attention_data.q_len_list, attention_data.hit_cache_len_list):
+                if hit_cache_len != -1:
+                    hidden_states_list.append(hidden_states[q_start : q_start + q_len][hit_cache_len:])
+                else:
+                    hidden_states_list.append(hidden_states[q_start : q_start + q_len])
+                q_start += q_len
+            hidden_states = mx.concat(hidden_states_list, axis=0)
 
         mask = attention_data.attn_mask
         mask = mask if mask is None else mask.astype(hidden_states.dtype)
         output = self.model(hidden_states, mask=mask, cache=attention_data)
 
-        for uuid, seq_len in zip(seq_input.uuid_list, seq_input.seq_len_list):
-            self.cache_manager.set(uuid, attention_data.get_kv_cache_list(uuid), attention_data.get_cache_seq_len(uuid))
+        for uuid in seq_input.uuid_list:
+            self.cache_manager.set(uuid, attention_data.get_kv_cache_list(uuid))
         self.cache_manager.check_alive()
+        self.request_cache.clear()
+        self.request_cache.insert_cache(seq_input)
 
         if self.config.decoder_end_layer_idx == self.config.num_hidden_layers:
-            output = get_last_hidden_states(output, seq_input.seq_len_list)
+            split_len_list = attention_data.q_len_list
+            # if any(x != -1 for x in attention_data.hit_cache_len_list):
+            #     q_start = 0
+            #     for i, (q_len, hit_cache_len) in enumerate(zip(attention_data.q_len_list, attention_data.hit_cache_len_list)):
+            #         if hit_cache_len != -1:
+            #             print("split_len_list[i]:", split_len_list[i])
+            #             print("q_len:", q_len)
+            #             print("hit_cache_len:", hit_cache_len)
+            #             split_len_list[i] = q_len - hit_cache_len
+            #         q_start += q_len
+            output = get_last_hidden_states(output, split_len_list)
         return output
 
     @classmethod
diff --git a/tllm/models/tinygrad/llama.py b/tllm/models/tinygrad/llama.py
index 7ebd7e7..bed8507 100644
--- a/tllm/models/tinygrad/llama.py
+++ b/tllm/models/tinygrad/llama.py
@@ -169,9 +169,10 @@ def build_tinygrad_mask(q_len_list: List[int], k_len_list: List[int]) -> Tensor:
 
 
 def build_forward_cache(seq_input: SeqInput, cache_manager: CacheManager, num_layers: int) -> AttentionData:
     request_cache = RequestsCache(num_layers)
     q_len_list, k_len_list = [], []
-    for uuid, 
q_len in zip(seq_input.uuid_list, ...): if uuid in cache_manager.cache_dict: - layer_cache_list, cache_seq_len = cache_manager.get(uuid) + layer_cache_list = cache_manager.get(uuid) + cache_seq_len = ... k_len_list.append(cache_seq_len + q_len) else: layer_cache_list = None @@ -291,10 +292,11 @@ def forward(self, hidden_states: Tensor, seq_input: SeqInput) -> Tensor: hidden_states = self.model(hidden_states, freqs_cis=freqs_cis, attention_data=attention_data) if self.config.decoder_end_layer_idx == self.config.num_hidden_layers: - hidden_states = get_last_hidden_states(hidden_states, seq_input.seq_len_list) + split_len_list = attention_data.q_len_list + hidden_states = get_last_hidden_states(hidden_states, split_len_list) - for uuid, seq_len in zip(seq_input.uuid_list, seq_input.seq_len_list): - self.cache_manager.set(uuid, attention_data.get_kv_cache_list(uuid), attention_data.get_cache_seq_len(uuid)) + for uuid in seq_input.uuid_list: + self.cache_manager.set(uuid, attention_data.get_kv_cache_list(uuid)) self.cache_manager.check_alive() return hidden_states diff --git a/tllm/models/torch/llama.py b/tllm/models/torch/llama.py index afe1e40..9d7cde8 100644 --- a/tllm/models/torch/llama.py +++ b/tllm/models/torch/llama.py @@ -113,10 +113,11 @@ def forward(self, hidden_states: torch.Tensor, seq_input: SeqInput) -> torch.Ten ) if self.config.decoder_end_layer_idx == self.config.num_hidden_layers: - hidden_states = get_last_hidden_states(hidden_states, seq_input.seq_len_list) + split_len_list = attention_data.q_len_list + hidden_states = get_last_hidden_states(hidden_states, split_len_list) - for uuid, seq_len in zip(seq_input.uuid_list, seq_input.seq_len_list): - self.cache_manager.set(uuid, attention_data.get_kv_cache_list(uuid), attention_data.get_cache_seq_len(uuid)) + for uuid in seq_input.uuid_list: + self.cache_manager.set(uuid, attention_data.get_kv_cache_list(uuid)) self.cache_manager.check_alive() return hidden_states diff --git a/tllm/models/torch/qwen.py b/tllm/models/torch/qwen.py index 0e56300..864e956 100644 --- a/tllm/models/torch/qwen.py +++ b/tllm/models/torch/qwen.py @@ -119,10 +119,11 @@ def forward(self, hidden_states: torch.Tensor, seq_input: SeqInput) -> torch.Ten ) if self.config.decoder_end_layer_idx == self.config.num_hidden_layers: - hidden_states = get_last_hidden_states(hidden_states, seq_input.seq_len_list) + split_len_list = attention_data.q_len_list + hidden_states = get_last_hidden_states(hidden_states, split_len_list) - for uuid, seq_len in zip(seq_input.uuid_list, seq_input.seq_len_list): - self.cache_manager.set(uuid, attention_data.get_kv_cache_list(uuid), attention_data.get_cache_seq_len(uuid)) + for uuid in seq_input.uuid_list: + self.cache_manager.set(uuid, attention_data.get_kv_cache_list(uuid)) self.cache_manager.check_alive() return hidden_states diff --git a/tllm/schemas.py b/tllm/schemas.py index 8529983..7919135 100644 --- a/tllm/schemas.py +++ b/tllm/schemas.py @@ -64,10 +64,34 @@ def __init__( self.stop_token_ids = stop_token_ids +from tllm.grpc.proto import schemas_pb2, schemas_pb2_grpc + + +def numpy_to_grpc_input_ids(input_ids_list: List[np.ndarray]) -> List[schemas_pb2.InputIds]: + rows = [] + for input_ids in input_ids_list: + row = schemas_pb2.InputIds(input_ids=input_ids.tolist()) + rows.append(row) + return rows + + @dataclass class SeqInput: uuid_list: List[str] - seq_len_list: List[int] + input_ids_list: List[List[int]] + + @classmethod + def from_request_data(cls, forward_request: "schemas_pb2.ForwardRequest"): + return cls( + 
uuid_list=list(forward_request.uuid_list), + input_ids_list=[list(input_ids.input_ids) for input_ids in forward_request.input_ids_list], + ) + + def to_dict(self): + return { + "uuid_list": self.uuid_list, + "input_ids_list": numpy_to_grpc_input_ids(self.input_ids_list), + } @dataclass @@ -151,7 +175,6 @@ class SequenceRequestData: sampling_params: SamplingParams multi_modal_inputs: Optional[Dict[str, List[Image.Image]]] = None - history_request_id: Optional[str] = None finish_reason_list: Optional[List[str]] = None output_ids: Optional[List[int]] = None # 最终生成的 token id diff --git a/tllm/utils.py b/tllm/utils.py index 3b66d25..c652c9e 100644 --- a/tllm/utils.py +++ b/tllm/utils.py @@ -1,8 +1,10 @@ +from typing import Tuple + from tllm import BACKEND, BackendEnum from tllm.grpc.master_service.master_server import MasterServer from tllm.grpc.master_service.pending_requests import PendingRequests from tllm.grpc.master_service.worker_manager import WorkerRPCManager -from typing import Tuple + def setup_seed(seed: int = 42): if BACKEND == BackendEnum.TORCH:
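
For reference, the new token-level request layout introduced above can be exercised as in the sketch below. This is an illustrative sketch only, not part of the patch: it assumes the regenerated schemas_pb2 module and the SeqInput dataclass from tllm/schemas.py shown above are importable, and the request ids and token ids are made-up values.

# Illustrative sketch (assumes the regenerated schemas_pb2 and the updated SeqInput above).
import numpy as np

from tllm.grpc.proto import schemas_pb2
from tllm.schemas import SeqInput

# Two requests batched into one forward call; each carries its full token ids,
# so the worker-side radix tree can match shared prefixes instead of relying on seq_len.
seq_input = SeqInput(
    uuid_list=["req-0", "req-1"],
    input_ids_list=[np.array([1, 2, 3, 4]), np.array([1, 2, 3, 9, 10])],
)

# to_dict() converts each ndarray into a repeated schemas_pb2.InputIds message,
# so it can be splatted straight into the new ForwardRequest fields.
request = schemas_pb2.ForwardRequest(**seq_input.to_dict())
# (the real flow also fills request.hidden_states with the serialized activations)

# A worker reconstructs the same structure from the wire format.
assert SeqInput.from_request_data(request).uuid_list == ["req-0", "req-1"]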