From cb73df230ea95244df1d1184d9c4f435031d4dbb Mon Sep 17 00:00:00 2001
From: lujianghu
Date: Mon, 27 Jan 2025 11:34:55 +0800
Subject: [PATCH] first add conversation kv cache

---
 benchmarks/run_async_requests.py   | 26 ++++++++++-----
 examples/run_engine.py             | 31 ++++++++++++++++-
 tllm/commons/cache.py              | 29 +++++++++++-----
 tllm/entrypoints/api_server.py     |  2 +-
 tllm/entrypoints/server_chat.py    | 53 +++++++++++++++++++++++++-----
 tllm/generate/llm_generator.py     | 18 +++++++---
 tllm/generate/message_processor.py | 34 ++++++++++++-------
 tllm/models/mlx/helper.py          | 28 +++++++++++-----
 tllm/models/torch/helper.py        |  2 +-
 tllm/models/torch/layers.py        |  1 +
 10 files changed, 169 insertions(+), 55 deletions(-)

diff --git a/benchmarks/run_async_requests.py b/benchmarks/run_async_requests.py
index 2e56976..659be1f 100644
--- a/benchmarks/run_async_requests.py
+++ b/benchmarks/run_async_requests.py
@@ -26,12 +26,20 @@ async def requests_func(messages: List[Dict[str, Any]]):
 
 
 def llm_message():
     messages1 = [{"role": "user", "content": "Hello, how are you?"}]
-    messages2 = [{"role": "user", "content": "Hello, What's your name?"}]
-    messages3 = [
-        {"role": "system", "content": "You are a helpful AI assistant."},
-        {"role": "user", "content": "今天天气怎么样"},
+    # messages2 = [{"role": "user", "content": "Hello, What's your name?"}]
+    # messages1 = [
+    #     {"role": "system", "content": "You are a helpful AI assistant."},
+    #     {"role": "user", "content": "今天天气怎么样"},
+    # ]
+    messages2 = [
+        {"role": "user", "content": "Hello, how are you?"},
+        {
+            "role": "assistant",
+            "content": "Hello! I'm Qwen, a large language model created by Alibaba Cloud. I'm here to assist you with any questions or tasks you might have. How can I help you today?",
+        },
+        {"role": "user", "content": "今天天气怎么样?"},
     ]
-    messages_list = [messages1, messages2, messages3]
+    messages_list = [messages1, messages2, messages2]
 
     return messages_list
 
@@ -60,10 +68,10 @@ def mllm_message():
 
 
 async def main(messages_list: List[List[Dict[str, Any]]]):
-    print("异步并发请求结果")
-    s1 = time.time()
-    await asyncio.gather(*[requests_func(messages) for messages in messages_list])
-    print(f"time cost: {time.time() - s1:.4f} s")
+    # print("异步并发请求结果")
+    # s1 = time.time()
+    # await asyncio.gather(*[requests_func(messages) for messages in messages_list])
+    # print(f"time cost: {time.time() - s1:.4f} s")
 
     print("单独请求结果")
     s1 = time.time()
diff --git a/examples/run_engine.py b/examples/run_engine.py
index 293195c..8ab9175 100644
--- a/examples/run_engine.py
+++ b/examples/run_engine.py
@@ -104,12 +104,41 @@ async def llm_generate(args, messages):
     engine = init_engine(args.model_path)
     await engine.start()
     messages = [{"role": "user", "content": "Hello, how are you?"}]
+    messages = [
+        {
+            "role": "system",
+            "content": "Given the following conversation, relevant context, and a follow up question, reply with an answer to the current question the user is asking. Return only your response to the question given the above information following the users instructions as needed.",
+        },
+        {"role": "user", "content": "hello"},
+    ]
     openai_serving_chat = OpenAIServing(engine, args)
+    # for _ in range(3):
+    request = ChatCompletionRequest(model="test", messages=messages, max_tokens=100)
+    response = await openai_serving_chat.create_chat_completion(request, None)
+    print(response)
+
+    messages = [
+        {"role": "user", "content": "Hello, how are you?"},
+        {
+            "role": "assistant",
+            "content": "Hello! I'm Qwen, a large language model created by Alibaba Cloud. I'm here to assist you with any questions or tasks you might have. How can I help you today?",
+        },
+        {"role": "user", "content": "今天天气怎么样?"},
+    ]
+    messages = [
+        {
+            "role": "system",
+            "content": "Given the following conversation, relevant context, and a follow up question, reply with an answer to the current question the user is asking. Return only your response to the question given the above information following the users instructions as needed.",
+        },
+        {"role": "user", "content": "hello"},
+        {"role": "assistant", "content": "Hello! How can I assist you today?"},
+        {"role": "user", "content": "今年天气怎么样"},
+    ]
     for _ in range(3):
         request = ChatCompletionRequest(model="test", messages=messages, max_tokens=100)
         response = await openai_serving_chat.create_chat_completion(request, None)
-        print(response)
+    print(response)
 
 
 async def image_generate(args):
diff --git a/tllm/commons/cache.py b/tllm/commons/cache.py
index a9563cb..23b0e1e 100644
--- a/tllm/commons/cache.py
+++ b/tllm/commons/cache.py
@@ -1,4 +1,5 @@
 # coding: utf-8
+import copy
 import time
 from typing import Dict, List, Optional, Tuple, Union
 
@@ -72,24 +73,36 @@ def add(self, uuid: str, seq_len: int, layer_cache_list: Optional[List[KVCache]]
     def build(self, seq_input, cache_manager):
         q_len_list, k_len_list = [], []
         position_ids_list = []
+        conv_len_list = []
 
         for uuid, q_len in zip(seq_input.uuid_list, seq_input.seq_len_list):
-            if uuid in cache_manager.cache_dict:
-                # kv_cache 是整个历史的 kv_cache
-                # 当 q_len 为 1 时,直接使用 kv_cache,使用历史的全部 token kv cache
-                # TODO: 当 q_len > 1 时,表示只需要使用前 q_len 的 kv_cache,后面的 kv_cache 需要重新计算
+            conv_len = -1
+            # decoding stage
+            if q_len == 1 and uuid in cache_manager.cache_dict:
                 layer_cache_list, cache_seq_len = cache_manager.get(uuid)
                 position_ids = array_func(cache_seq_len)
                 k_len_list.append(cache_seq_len + q_len)
+            # prefilling stage
             else:
-                layer_cache_list = None
-                position_ids = arange_func(q_len)
-                k_len_list.append(q_len)
+                # If this request continues an earlier conversation, reuse the historical kv_cache
+                chat_uuid, chat_len = uuid.rsplit("-", 1)
+                if uuid.count("-") == 2 and chat_uuid in cache_manager.cache_dict:
+                    layer_cache_list, cache_seq_len = cache_manager.get(chat_uuid)
+                    layer_cache_list = copy.deepcopy(layer_cache_list)
+                    position_ids = arange_func(q_len)
+                    k_len_list.append(q_len)
+                    conv_len = q_len - cache_seq_len
+                # A uuid seen for the first time, i.e. the first turn of the conversation
+                else:
+                    layer_cache_list = None
+                    position_ids = arange_func(q_len)
+                    k_len_list.append(q_len)
             q_len_list.append(q_len)
             position_ids_list.append(position_ids)
+            conv_len_list.append(conv_len)
 
             self.add(uuid, q_len, layer_cache_list)
-        return q_len_list, k_len_list, position_ids_list
+        return q_len_list, k_len_list, position_ids_list, conv_len_list
 
     def get_kv_cache(self, uuid: str) -> List[KVCache]:
         return self.cache_dict[uuid]["cache"]
diff --git a/tllm/entrypoints/api_server.py b/tllm/entrypoints/api_server.py
index 2fa667e..1929884 100644
--- a/tllm/entrypoints/api_server.py
+++ b/tllm/entrypoints/api_server.py
@@ -56,7 +56,7 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
     if openai_serving_chat is None:
         raise ValueError("OpenAIServing instance is not initialized")
     if raw_request.headers.get("authorization") == "Bearer anythingllm":
-        request.max_tokens = 8192
+        request.max_tokens = openai_serving_chat.max_model_len
    try:
         generator = await openai_serving_chat.create_chat_completion(request, raw_request)
         if request.stream:
diff --git a/tllm/entrypoints/server_chat.py b/tllm/entrypoints/server_chat.py
index b2e9d2f..0d919ab 100644
--- a/tllm/entrypoints/server_chat.py
+++ b/tllm/entrypoints/server_chat.py
@@ -33,16 +33,27 @@ def __init__(self, engine: AsyncEngine, args):
         self.engine = engine
         self.message_processor = MessageProcessor(self.engine.tok)
         self.model_name = os.path.basename(args.model_path)
+        self.response_role = "assistant"
+
+    @property
+    def max_model_len(self):
+        return 8192
 
     async def show_available_models(self):
-        model_cards = [ModelCard(id=self.model_name, max_model_len=8192, root="tllm", permission=[ModelPermission()])]
+        model_cards = [
+            ModelCard(id=self.model_name, max_model_len=self.max_model_len, root="tllm", permission=[ModelPermission()])
+        ]
         return ModelList(data=model_cards)
 
     async def create_chat_completion(self, request: ChatCompletionRequest, raw_request: Request):
         request_id = f"chat-{random_uuid()}"
         messages, mm_input_dict = await self.message_processor.parse_message(request.messages)
+        # print("messages", messages)
         input_ids = self.message_processor.preprocess(messages)
-        history_request_id, q_len = self.message_processor.fetch_request_id(input_ids)
+        history_request_id, q_len = self.message_processor.fetch_request_id(messages)
+
+        if history_request_id is not None:
+            request_id = history_request_id
 
         if request.temperature == 0.0:
             method = "greedy"
@@ -60,17 +71,26 @@ async def create_chat_completion(self, request: ChatCompletionRequest, raw_reque
         result_generator = self.engine.generate_stream(sequence_data)
 
         if request.stream:
-            return self.chat_completion_stream_generator(request, raw_request, request_id, result_generator)
+            return self.chat_completion_stream_generator(request, raw_request, request_id, result_generator, messages)
         else:
-            return await self.chat_completion_full_generator(request, raw_request, request_id, result_generator)
+            return await self.chat_completion_full_generator(
+                request, raw_request, request_id, result_generator, messages
+            )
 
     async def chat_completion_stream_generator(
-        self, request: ChatCompletionRequest, raw_request: Request, request_id: str, result_generator: AsyncIterator
+        self,
+        request: ChatCompletionRequest,
+        raw_request: Request,
+        request_id: str,
+        result_generator: AsyncIterator,
+        messages,
     ) -> AsyncIterator[str]:
         created_time = int(time.time())
         n = 1
         previous_texts = [""] * n
         try:
+            response_text = ""
+            num_generated_tokens = 0
             async for res in result_generator:
                 if raw_request is not None and await raw_request.is_disconnected():
                     # Abort the request if the client disconnects.
@@ -82,6 +102,8 @@ async def chat_completion_stream_generator(
                     delta_text = output.text
                     previous_texts[i] = output.text
 
+                    num_prompt_tokens = len(res.prompt_token_ids)
+                    num_generated_tokens += 1
 
                     # 根据 finish_reason 判断是否结束,分别处理
                     if output.finish_reason is not None:
@@ -95,21 +117,30 @@ async def chat_completion_stream_generator(
                            logprobs=None,
                            finish_reason=output.finish_reason,
                        )
+                    response_text += delta_text
                    chunk = ChatCompletionStreamResponse(
                        id=request_id, model=self.model_name, created=created_time, choices=[choice_data]
                    )
                    data = chunk.model_dump_json(exclude_unset=True)
                    yield f"data: {data}\n\n"
+
+            messages.append({"role": self.response_role, "content": response_text})
+            total_tokens = num_prompt_tokens + num_generated_tokens
+            self.message_processor.update_conversations_dict(request_id, messages, total_tokens - 2)
             yield "data: [DONE]\n\n"
         except asyncio.CancelledError:
             await self.engine.abort(request_id)
             create_error_response("Client disconnected")
 
     async def chat_completion_full_generator(
-        self, request: ChatCompletionRequest, raw_request: Request, request_id: str, result_generator: AsyncIterator
+        self,
+        request: ChatCompletionRequest,
+        raw_request: Request,
+        request_id: str,
+        result_generator: AsyncIterator,
+        messages,
     ) -> ChatCompletionResponse:
         final_res = None
-        role = "assistant"
         created_time = int(time.time())
         try:
             async for res in result_generator:
@@ -124,7 +155,7 @@ async def chat_completion_full_generator(
             create_error_response("Client disconnected")
 
         output = final_res.outputs[0]
-        message = ChatMessage(role=role, content=output.text)
+        message = ChatMessage(role=self.response_role, content=output.text)
 
         choice_data = ChatCompletionResponseChoice(
             index=output.index,
@@ -136,11 +167,12 @@ async def chat_completion_full_generator(
 
         num_prompt_tokens = len(final_res.prompt_token_ids)
         num_generated_tokens = sum(len(output.token_ids) for output in final_res.outputs)
+        total_tokens = num_prompt_tokens + num_generated_tokens
 
         usage = UsageInfo(
             prompt_tokens=num_prompt_tokens,
             completion_tokens=num_generated_tokens,
-            total_tokens=num_prompt_tokens + num_generated_tokens,
+            total_tokens=total_tokens,
         )
         response = ChatCompletionResponse(
             id=request_id,
@@ -149,4 +181,7 @@
             choices=[choice_data],
             usage=usage,
         )
+
+        messages.append({"role": self.response_role, "content": output.text})
+        self.message_processor.update_conversations_dict(request_id, messages, total_tokens - 1)
         return response
diff --git a/tllm/generate/llm_generator.py b/tllm/generate/llm_generator.py
index 249c0ca..30c7052 100644
--- a/tllm/generate/llm_generator.py
+++ b/tllm/generate/llm_generator.py
@@ -113,13 +113,21 @@ async def generate(self, request_list: List[SequenceRequestData]):
             # 如果是 prefilling,则为 input_ids; 否则,为 output_ids[-1]
             # input_ids: seq_len
             if sequence_request.is_prefill:
-                # if sequence_request.history_request_id:
-                #     uuid_list[-1] = sequence_request.history_request_id
                 if self.processor is not None:
                     mm_input_list.append(process_mm_input(sequence_request, self.processor, **self.mm_config))
-                input_ids_list.append(np.array(sequence_request.input_ids))
-                # seq_len_list.append(sequence_request.q_len)  # 需要搭配 history_request_id 使用
-                seq_len_list.append(len(sequence_request.input_ids))
+
+                # The message history already exists in the cache, so reuse the previous request_id
+                if sequence_request.history_request_id is not None:
+                    uuid_list[-1] = sequence_request.history_request_id
+                    new_q_len = len(sequence_request.input_ids) - sequence_request.q_len
+                    # print("new_q_len", len(sequence_request.input_ids[-new_q_len:]))
+                    input_ids_list.append(np.array(sequence_request.input_ids[-new_q_len:]))
+                    # print("q_len", len(sequence_request.input_ids))
+                    # input_ids_list.append(np.array(sequence_request.input_ids))
+                    seq_len_list.append(len(sequence_request.input_ids))
+                else:
+                    input_ids_list.append(np.array(sequence_request.input_ids))
+                    seq_len_list.append(len(sequence_request.input_ids))
             else:
                 input_ids_list.append(np.array([sequence_request.output_ids[-1]]))
                 seq_len_list.append(1)
diff --git a/tllm/generate/message_processor.py b/tllm/generate/message_processor.py
index 27a2431..90b3a6a 100644
--- a/tllm/generate/message_processor.py
+++ b/tllm/generate/message_processor.py
@@ -1,3 +1,4 @@
+import copy
 from typing import Dict, List, Optional, Tuple
 
 from PIL import Image
@@ -13,6 +14,7 @@ class MessageProcessor:
     def __init__(self, tok: TokenizerUtils):
         self.tok = tok
         self.role_set = {"user", "system", "assistant"}
+        self.conversations_dict: Dict[Tuple[int, ...], Tuple[str, int]] = {}
 
     async def read_image(self, image: UrlItem) -> ImageFile:
         if image.base64 is not None:
@@ -61,17 +63,25 @@ async def parse_message(self, messages: MESSAGES) -> Tuple[List[Dict[str, str]],
         return new_messages, mm_input_dict
 
     def preprocess(self, messages: List[Dict[str, str]]) -> List[int]:
-        return self.tok.preprocess(messages=messages).input_ids
+        input_ids = self.tok.preprocess(messages=messages).input_ids
+        return input_ids
 
-    def fetch_request_id(self, input_ids: List[int]) -> Tuple[Optional[str], int]:
-        # max_index, max_id = -1, -1
-        # for cache_input_ids, id_ in conversations_dict.items():
-        #     index = list_common_prefix(input_ids, cache_input_ids)
-        #     if index > max_index:
-        #         max_id = id_
-        #         max_index = index
-
-        # if max_index == 0 or max_id == -1:
-        #     return None, -1
-        # return max_id, max_index
+    def fetch_request_id(self, messages: List[MESSAGES]) -> Tuple[Optional[str], int]:
+        messages = copy.deepcopy(messages)
+        while len(messages) > 0 and messages[-1]["role"] == "user":
+            messages.pop()
+        if len(messages) == 0:
+            return None, -1
+        # Check whether the historical messages are already in conversations_dict
+        key_str = tuple(self.tok.preprocess(messages=messages, add_generation_prompt=False).input_ids)
+        if key_str in self.conversations_dict:
+            request_id, total_token_num = self.conversations_dict[key_str]
+            # TODO: there may still be a bug here
+            if len(key_str) <= total_token_num:
+                return None, -1
+            return request_id + f"-{len(messages)}", total_token_num
         return None, -1
+
+    def update_conversations_dict(self, request_id: str, messages: List[MESSAGES], total_token_num: int) -> None:
+        key_str = tuple(self.tok.preprocess(messages=messages, add_generation_prompt=False).input_ids)
+        self.conversations_dict[key_str] = (request_id, total_token_num)
diff --git a/tllm/models/mlx/helper.py b/tllm/models/mlx/helper.py
index 2c2bb64..0a61b8c 100644
--- a/tllm/models/mlx/helper.py
+++ b/tllm/models/mlx/helper.py
@@ -16,13 +16,22 @@ def greedy_decode(logits: mx.array) -> List[int]:
     return out.tolist()
 
 # TODO: first requests is too slow
-def build_mlx_mask(q_len_list: List[int], k_len_list: List[int]) -> mx.array:
-    mask_list = [
-        mx.tril(mx.ones((L, S), dtype=mx.bool_), k=0) if L > 1 else mx.ones((L, S), dtype=mx.bool_)
-        for (L, S) in zip(q_len_list, k_len_list)
-    ]
-
-    combined_mask = mx.zeros((sum(q_len_list), sum(k_len_list)), dtype=mx.bool_)
+def build_mlx_mask(q_len_list: List[int], k_len_list: List[int], conv_len_list: List[int]) -> mx.array:
+    mask_list = []
+    sum_q_len = sum(q_len_list)
+    sum_k_len = sum(k_len_list)
+    for q_len, k_len, conv_len in zip(q_len_list, k_len_list, conv_len_list):
+        # prefilling
+        if q_len > 1:
+            mask = mx.tril(mx.ones((q_len, k_len), dtype=mx.bool_), k=0)
+            if conv_len != -1:
+                sum_q_len -= q_len - conv_len
+                mask = mask[-conv_len:]
+        else:
+            mask = mx.ones((q_len, k_len), dtype=mx.bool_)
+        mask_list.append(mask)
+
+    combined_mask = mx.zeros((sum_q_len, sum_k_len), dtype=mx.bool_)
 
     l_index, r_index = 0, 0
     for mask in mask_list:
@@ -42,12 +51,13 @@ def build_forward_cache(
     num_key_value_heads: int = -1,
     head_dim: int = -1,
 ) -> AttentionData:
+    # TODO: no need to re-initialize this on every forward pass
     request_cache = RequestsCache(num_layers, max_seq_len, num_key_value_heads, head_dim)
-    q_len_list, k_len_list, position_ids_list = request_cache.build(seq_input, cache_manager)
+    q_len_list, k_len_list, position_ids_list, conv_len_list = request_cache.build(seq_input, cache_manager)
 
     return AttentionData(
         request_cache=request_cache,
-        attn_mask=build_mlx_mask(q_len_list, k_len_list),
+        attn_mask=build_mlx_mask(q_len_list, k_len_list, conv_len_list),
         uuid_list=seq_input.uuid_list,
         position_ids=mx.concatenate(position_ids_list, axis=-1),
     )
diff --git a/tllm/models/torch/helper.py b/tllm/models/torch/helper.py
index 851343b..cc7f37b 100644
--- a/tllm/models/torch/helper.py
+++ b/tllm/models/torch/helper.py
@@ -73,7 +73,7 @@ def build_forward_cache(
     head_dim: int = -1,
 ) -> AttentionData:
     request_cache = RequestsCache(num_layers, max_seq_len, num_key_value_heads, head_dim)
-    q_len_list, k_len_list, position_ids_list = request_cache.build(seq_input, cache_manager)
+    q_len_list, k_len_list, position_ids_list, _ = request_cache.build(seq_input, cache_manager)
 
     if ATTN_TYPE == "flash_attention":
         attn_mask = {
diff --git a/tllm/models/torch/layers.py b/tllm/models/torch/layers.py
index 3b2b285..0677cd1 100644
--- a/tllm/models/torch/layers.py
+++ b/tllm/models/torch/layers.py
@@ -152,6 +152,7 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
             self.num_heads * self.head_dim, self.hidden_size, self.world_size, self.rank, bias=o_proj_bias
         )
 
+        # The gain is larger for long sequences and large batches
        # self.max_seq_len = 1024
        # self._k_cache = zeros_func(self.max_seq_len, self.num_key_value_heads, self.head_dim)
        # self._v_cache = zeros_func(self.max_seq_len, self.num_key_value_heads, self.head_dim)
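
Editorial note (not part of the patch): the sketch below is an illustrative, simplified model of the conversation-prefix lookup that fetch_request_id / update_conversations_dict implement above. After a turn completes, the tokenized history is stored as a key; on the next turn the trailing user message(s) are dropped and the remaining prefix is looked up, so the engine can reuse the stored request_id and its KV cache. FakeTokenizer and ConversationCache are hypothetical stand-ins for TokenizerUtils and MessageProcessor, and the extra total_token_num sanity check from the patch is omitted here.

# Illustrative sketch only -- assumptions: FakeTokenizer replaces TokenizerUtils;
# real keys are tuples of chat-template token ids, not hashes.
from typing import Dict, List, Optional, Tuple


class FakeTokenizer:
    def encode(self, messages: List[Dict[str, str]]) -> Tuple[int, ...]:
        # Deterministic (within one process) stand-in for tok.preprocess(...).input_ids
        return tuple(hash((m["role"], m["content"])) % 10_000 for m in messages)


class ConversationCache:
    def __init__(self, tok: FakeTokenizer):
        self.tok = tok
        self.conversations_dict: Dict[Tuple[int, ...], Tuple[str, int]] = {}

    def fetch_request_id(self, messages: List[Dict[str, str]]) -> Tuple[Optional[str], int]:
        history = list(messages)
        # Drop the trailing user turn(s); only a previously answered prefix can be cached.
        while history and history[-1]["role"] == "user":
            history.pop()
        if not history:
            return None, -1
        key = self.tok.encode(history)
        if key in self.conversations_dict:
            request_id, total_token_num = self.conversations_dict[key]
            return f"{request_id}-{len(history)}", total_token_num
        return None, -1

    def update_conversations_dict(self, request_id: str, messages: List[Dict[str, str]], total_token_num: int) -> None:
        # Store the finished turn keyed by its tokenized history
        self.conversations_dict[self.tok.encode(messages)] = (request_id, total_token_num)


cache = ConversationCache(FakeTokenizer())
turn1 = [{"role": "user", "content": "Hello, how are you?"}]
turn1_done = turn1 + [{"role": "assistant", "content": "Hi! How can I help?"}]
cache.update_conversations_dict("chat-abc", turn1_done, total_token_num=42)

turn2 = turn1_done + [{"role": "user", "content": "What's the weather like today?"}]
print(cache.fetch_request_id(turn2))  # ('chat-abc-2', 42) -> prefix KV cache can be reused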