first add conversation kv cache
wnma3mz committed Jan 27, 2025
1 parent c495bb5 commit cb73df2
Showing 10 changed files with 169 additions and 55 deletions.
26 changes: 17 additions & 9 deletions benchmarks/run_async_requests.py
@@ -26,12 +26,20 @@ async def requests_func(messages: List[Dict[str, Any]]):

def llm_message():
messages1 = [{"role": "user", "content": "Hello, how are you?"}]
messages2 = [{"role": "user", "content": "Hello, What's your name?"}]
messages3 = [
{"role": "system", "content": "You are a helpful AI assistant."},
{"role": "user", "content": "今天天气怎么样"},
# messages2 = [{"role": "user", "content": "Hello, What's your name?"}]
# messages1 = [
# {"role": "system", "content": "You are a helpful AI assistant."},
# {"role": "user", "content": "今天天气怎么样"},
# ]
messages2 = [
{"role": "user", "content": "Hello, how are you?"},
{
"role": "assistant",
"content": "Hello! I'm Qwen, a large language model created by Alibaba Cloud. I'm here to assist you with any questions or tasks you might have. How can I help you today?",
},
{"role": "user", "content": "今天天气怎么样?"},
]
messages_list = [messages1, messages2, messages3]
messages_list = [messages1, messages2, messages2]
return messages_list


@@ -60,10 +68,10 @@ def mllm_message():


async def main(messages_list: List[List[Dict[str, Any]]]):
print("异步并发请求结果")
s1 = time.time()
await asyncio.gather(*[requests_func(messages) for messages in messages_list])
print(f"time cost: {time.time() - s1:.4f} s")
# print("异步并发请求结果")
# s1 = time.time()
# await asyncio.gather(*[requests_func(messages) for messages in messages_list])
# print(f"time cost: {time.time() - s1:.4f} s")

print("单独请求结果")
s1 = time.time()
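For context, the reworked messages2 plays the second turn of a conversation: the original user turn, the assistant reply it produced, and a new user question, so the server can match the shared prefix against its stored conversation KV cache. Below is a minimal client-side sketch of that flow against an OpenAI-compatible chat endpoint; the base URL, port, and model name are illustrative assumptions, not values taken from this repository.

# Hypothetical two-turn client; endpoint URL and model name are placeholders.
import requests

BASE_URL = "http://localhost:8022/v1/chat/completions"  # assumed address

def chat(messages, max_tokens=100):
    # Standard OpenAI-compatible chat completion request.
    resp = requests.post(BASE_URL, json={"model": "test", "messages": messages, "max_tokens": max_tokens})
    resp.raise_for_status()
    return resp.json()["choices"][0]["message"]["content"]

history = [{"role": "user", "content": "Hello, how are you?"}]
reply = chat(history)  # first turn: full prefill, conversation cache is stored server-side
history += [
    {"role": "assistant", "content": reply},
    {"role": "user", "content": "What's the weather like today?"},
]
reply2 = chat(history)  # follow-up: the shared prefix can reuse the stored KV cache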
31 changes: 30 additions & 1 deletion examples/run_engine.py
@@ -104,12 +104,41 @@ async def llm_generate(args, messages):
engine = init_engine(args.model_path)
await engine.start()
messages = [{"role": "user", "content": "Hello, how are you?"}]
messages = [
{
"role": "system",
"content": "Given the following conversation, relevant context, and a follow up question, reply with an answer to the current question the user is asking. Return only your response to the question given the above information following the users instructions as needed.",
},
{"role": "user", "content": "hello"},
]
openai_serving_chat = OpenAIServing(engine, args)

# for _ in range(3):
request = ChatCompletionRequest(model="test", messages=messages, max_tokens=100)
response = await openai_serving_chat.create_chat_completion(request, None)
print(response)

messages = [
{"role": "user", "content": "Hello, how are you?"},
{
"role": "assistant",
"content": "Hello! I'm Qwen, a large language model created by Alibaba Cloud. I'm here to assist you with any questions or tasks you might have. How can I help you today?",
},
{"role": "user", "content": "今天天气怎么样?"},
]
messages = [
{
"role": "system",
"content": "Given the following conversation, relevant context, and a follow up question, reply with an answer to the current question the user is asking. Return only your response to the question given the above information following the users instructions as needed.",
},
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "Hello! How can I assist you today?"},
{"role": "user", "content": "今年天气怎么样"},
]
for _ in range(3):
request = ChatCompletionRequest(model="test", messages=messages, max_tokens=100)
response = await openai_serving_chat.create_chat_completion(request, None)
print(response)
print(response)


async def image_generate(args):
29 changes: 21 additions & 8 deletions tllm/commons/cache.py
@@ -1,4 +1,5 @@
# coding: utf-8
import copy
import time
from typing import Dict, List, Optional, Tuple, Union

@@ -72,24 +73,36 @@ def add(self, uuid: str, seq_len: int, layer_cache_list: Optional[List[KVCache]]
def build(self, seq_input, cache_manager):
q_len_list, k_len_list = [], []
position_ids_list = []
conv_len_list = []

for uuid, q_len in zip(seq_input.uuid_list, seq_input.seq_len_list):
if uuid in cache_manager.cache_dict:
# kv_cache holds the KV cache of the entire history
# When q_len is 1, use the kv_cache directly, i.e. the KV cache of all historical tokens
# TODO: when q_len > 1, only the first q_len entries of the kv_cache are needed; the rest must be recomputed
conv_len = -1
# decoding stage
if q_len == 1 and uuid in cache_manager.cache_dict:
layer_cache_list, cache_seq_len = cache_manager.get(uuid)
position_ids = array_func(cache_seq_len)
k_len_list.append(cache_seq_len + q_len)
# prefilling stage
else:
layer_cache_list = None
position_ids = arange_func(q_len)
k_len_list.append(q_len)
# If this continues an earlier conversation, reuse that conversation's kv_cache
chat_uuid, chat_len = uuid.rsplit("-", 1)
if uuid.count("-") == 2 and chat_uuid in cache_manager.cache_dict:
layer_cache_list, cache_seq_len = cache_manager.get(chat_uuid)
layer_cache_list = copy.deepcopy(layer_cache_list)
position_ids = arange_func(q_len)
k_len_list.append(q_len)
conv_len = q_len - cache_seq_len
# uuid seen for the first time, i.e. the first turn of a conversation
else:
layer_cache_list = None
position_ids = arange_func(q_len)
k_len_list.append(q_len)
q_len_list.append(q_len)
position_ids_list.append(position_ids)
conv_len_list.append(conv_len)

self.add(uuid, q_len, layer_cache_list)
return q_len_list, k_len_list, position_ids_list
return q_len_list, k_len_list, position_ids_list, conv_len_list

def get_kv_cache(self, uuid: str) -> List[KVCache]:
return self.cache_dict[uuid]["cache"]
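A condensed sketch of the prefill branch added above: when the incoming uuid carries a "-<message count>" suffix that points at a stored parent conversation, that conversation's cached KV is deep-copied so the new turn can extend it, and conv_len records how many prompt tokens are genuinely new (for example, a 75-token follow-up prompt on top of a 60-token cached history gives conv_len = 15). The helper below is an illustration of this logic under those assumptions, not the exact class in the repository.

import copy

def resolve_prefill_cache(uuid: str, q_len: int, cache_manager):
    """Sketch of the conversation-cache lookup performed during prefill."""
    layer_cache_list, conv_len = None, -1
    if uuid.count("-") == 2:
        chat_uuid, _turn = uuid.rsplit("-", 1)
        if chat_uuid in cache_manager.cache_dict:
            # Continuing an earlier conversation: copy its cached KV so this turn extends it.
            layer_cache_list, cache_seq_len = cache_manager.get(chat_uuid)
            layer_cache_list = copy.deepcopy(layer_cache_list)
            conv_len = q_len - cache_seq_len  # prompt tokens not yet covered by the cache
    # layer_cache_list stays None (and conv_len stays -1) for the first turn of a conversation.
    return layer_cache_list, conv_len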
2 changes: 1 addition & 1 deletion tllm/entrypoints/api_server.py
@@ -56,7 +56,7 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
if openai_serving_chat is None:
raise ValueError("OpenAIServing instance is not initialized")
if raw_request.headers.get("authorization") == "Bearer anythingllm":
request.max_tokens = 8192
request.max_tokens = openai_serving_chat.max_model_len
try:
generator = await openai_serving_chat.create_chat_completion(request, raw_request)
if request.stream:
53 changes: 44 additions & 9 deletions tllm/entrypoints/server_chat.py
@@ -33,16 +33,27 @@ def __init__(self, engine: AsyncEngine, args):
self.engine = engine
self.message_processor = MessageProcessor(self.engine.tok)
self.model_name = os.path.basename(args.model_path)
self.response_role = "assistant"

@property
def max_model_len(self):
return 8192

async def show_available_models(self):
model_cards = [ModelCard(id=self.model_name, max_model_len=8192, root="tllm", permission=[ModelPermission()])]
model_cards = [
ModelCard(id=self.model_name, max_model_len=self.max_model_len, root="tllm", permission=[ModelPermission()])
]
return ModelList(data=model_cards)

async def create_chat_completion(self, request: ChatCompletionRequest, raw_request: Request):
request_id = f"chat-{random_uuid()}"
messages, mm_input_dict = await self.message_processor.parse_message(request.messages)
# print("messages", messages)
input_ids = self.message_processor.preprocess(messages)
history_request_id, q_len = self.message_processor.fetch_request_id(input_ids)
history_request_id, q_len = self.message_processor.fetch_request_id(messages)

if history_request_id is not None:
request_id = history_request_id

if request.temperature == 0.0:
method = "greedy"
@@ -60,17 +71,26 @@ async def create_chat_completion(self, request: ChatCompletionRequest, raw_reque
result_generator = self.engine.generate_stream(sequence_data)

if request.stream:
return self.chat_completion_stream_generator(request, raw_request, request_id, result_generator)
return self.chat_completion_stream_generator(request, raw_request, request_id, result_generator, messages)
else:
return await self.chat_completion_full_generator(request, raw_request, request_id, result_generator)
return await self.chat_completion_full_generator(
request, raw_request, request_id, result_generator, messages
)

async def chat_completion_stream_generator(
self, request: ChatCompletionRequest, raw_request: Request, request_id: str, result_generator: AsyncIterator
self,
request: ChatCompletionRequest,
raw_request: Request,
request_id: str,
result_generator: AsyncIterator,
messages,
) -> AsyncIterator[str]:
created_time = int(time.time())
n = 1
previous_texts = [""] * n
try:
response_text = ""
num_generated_tokens = 0
async for res in result_generator:
if raw_request is not None and await raw_request.is_disconnected():
# Abort the request if the client disconnects.
@@ -82,6 +102,8 @@ async def chat_completion_stream_generator(

delta_text = output.text
previous_texts[i] = output.text
num_prompt_tokens = len(res.prompt_token_ids)
num_generated_tokens += 1

# Use finish_reason to decide whether generation has finished, and handle each case separately
if output.finish_reason is not None:
@@ -95,21 +117,30 @@ async def chat_completion_stream_generator(
logprobs=None,
finish_reason=output.finish_reason,
)
response_text += delta_text
chunk = ChatCompletionStreamResponse(
id=request_id, model=self.model_name, created=created_time, choices=[choice_data]
)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"

messages.append({"role": self.response_role, "content": response_text})
total_tokens = num_prompt_tokens + num_generated_tokens
self.message_processor.update_conversations_dict(request_id, messages, total_tokens - 2)
yield "data: [DONE]\n\n"
except asyncio.CancelledError:
await self.engine.abort(request_id)
create_error_response("Client disconnected")

async def chat_completion_full_generator(
self, request: ChatCompletionRequest, raw_request: Request, request_id: str, result_generator: AsyncIterator
self,
request: ChatCompletionRequest,
raw_request: Request,
request_id: str,
result_generator: AsyncIterator,
messages,
) -> ChatCompletionResponse:
final_res = None
role = "assistant"
created_time = int(time.time())
try:
async for res in result_generator:
@@ -124,7 +155,7 @@ async def chat_completion_full_generator(
create_error_response("Client disconnected")

output = final_res.outputs[0]
message = ChatMessage(role=role, content=output.text)
message = ChatMessage(role=self.response_role, content=output.text)

choice_data = ChatCompletionResponseChoice(
index=output.index,
@@ -136,11 +167,12 @@ async def chat_completion_full_generator(

num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(len(output.token_ids) for output in final_res.outputs)
total_tokens = num_prompt_tokens + num_generated_tokens

usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
total_tokens=total_tokens,
)
response = ChatCompletionResponse(
id=request_id,
@@ -149,4 +181,7 @@ async def chat_completion_full_generator(
choices=[choice_data],
usage=usage,
)

messages.append({"role": self.response_role, "content": output.text})
self.message_processor.update_conversations_dict(request_id, messages, total_tokens - 1)
return response
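Taken together, the changes above bracket generation with two bookkeeping steps: before generating, look up a stored conversation whose tokenized history matches the incoming messages and reuse its request_id; after generating, append the assistant reply and store the extended conversation for the next follow-up. A simplified sketch of that lifecycle follows; engine_generate and the uuid helper are stand-ins for the repository's own machinery.

from uuid import uuid4  # stand-in for the repo's random_uuid()

async def handle_chat(request, message_processor, engine_generate):
    """Illustrative request lifecycle for the conversation KV cache."""
    messages, _mm = await message_processor.parse_message(request.messages)
    request_id = f"chat-{uuid4().hex}"
    # Reuse the request_id of a stored conversation whose token prefix matches this request.
    history_request_id, q_len = message_processor.fetch_request_id(messages)
    if history_request_id is not None:
        request_id = history_request_id
    output_text, total_tokens = await engine_generate(request_id, messages, q_len)
    # Store the conversation including the new assistant turn, keyed by its token ids,
    # so the next follow-up request can find and extend it.
    messages.append({"role": "assistant", "content": output_text})
    message_processor.update_conversations_dict(request_id, messages, total_tokens - 1)
    return output_text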
18 changes: 13 additions & 5 deletions tllm/generate/llm_generator.py
@@ -113,13 +113,21 @@ async def generate(self, request_list: List[SequenceRequestData]):
# For prefilling, this is input_ids; otherwise, it is output_ids[-1]
# input_ids: seq_len
if sequence_request.is_prefill:
# if sequence_request.history_request_id:
# uuid_list[-1] = sequence_request.history_request_id
if self.processor is not None:
mm_input_list.append(process_mm_input(sequence_request, self.processor, **self.mm_config))
input_ids_list.append(np.array(sequence_request.input_ids))
# seq_len_list.append(sequence_request.q_len)  # must be used together with history_request_id
seq_len_list.append(len(sequence_request.input_ids))

# Messages that already exist: reuse the earlier request_id
if sequence_request.history_request_id is not None:
uuid_list[-1] = sequence_request.history_request_id
new_q_len = len(sequence_request.input_ids) - sequence_request.q_len
# print("new_q_len", len(sequence_request.input_ids[-new_q_len:]))
input_ids_list.append(np.array(sequence_request.input_ids[-new_q_len:]))
# print("q_len", len(sequence_request.input_ids))
# input_ids_list.append(np.array(sequence_request.input_ids))
seq_len_list.append(len(sequence_request.input_ids))
else:
input_ids_list.append(np.array(sequence_request.input_ids))
seq_len_list.append(len(sequence_request.input_ids))
else:
input_ids_list.append(np.array([sequence_request.output_ids[-1]]))
seq_len_list.append(1)
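The prefill branch above trims the prompt when a request carries a history_request_id: q_len holds the number of tokens already covered by the matched conversation cache, so only the trailing tokens are embedded and prefilled, while seq_len_list still records the full prompt length. A minimal sketch of that selection, under the same assumption about q_len:

import numpy as np

def select_prefill_tokens(input_ids, q_len, history_request_id):
    # Sketch: choose which prompt tokens are actually fed into prefill.
    if history_request_id is not None:
        new_q_len = len(input_ids) - q_len      # tokens not yet covered by the cached history
        return np.array(input_ids[-new_q_len:])
    return np.array(input_ids)                  # no matching history: prefill the whole prompt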
34 changes: 22 additions & 12 deletions tllm/generate/message_processor.py
@@ -1,3 +1,4 @@
import copy
from typing import Dict, List, Optional, Tuple

from PIL import Image
@@ -13,6 +14,7 @@ class MessageProcessor:
def __init__(self, tok: TokenizerUtils):
self.tok = tok
self.role_set = {"user", "system", "assistant"}
self.conversations_dict: Dict[MESSAGES, str] = {}

async def read_image(self, image: UrlItem) -> ImageFile:
if image.base64 is not None:
@@ -61,17 +63,25 @@ async def parse_message(self, messages: MESSAGES) -> Tuple[List[Dict[str, str]],
return new_messages, mm_input_dict

def preprocess(self, messages: List[Dict[str, str]]) -> List[int]:
return self.tok.preprocess(messages=messages).input_ids
input_ids = self.tok.preprocess(messages=messages).input_ids
return input_ids

def fetch_request_id(self, input_ids: List[int]) -> Tuple[Optional[str], int]:
# max_index, max_id = -1, -1
# for cache_input_ids, id_ in conversations_dict.items():
# index = list_common_prefix(input_ids, cache_input_ids)
# if index > max_index:
# max_id = id_
# max_index = index

# if max_index == 0 or max_id == -1:
# return None, -1
# return max_id, max_index
def fetch_request_id(self, messages: List[MESSAGES]) -> Optional[str]:
messages = copy.deepcopy(messages)
while len(messages) > 0 and messages[-1]["role"] == "user":
messages.pop()
if len(messages) == 0:
return None, -1
# Check whether the historical messages are already present in conversations_dict
key_str = tuple(self.tok.preprocess(messages=messages, add_generation_prompt=False).input_ids)
if key_str in self.conversations_dict:
request_id, total_token_num = self.conversations_dict[key_str]
# TODO: this may have a bug
if len(key_str) <= total_token_num:
return None, -1
return request_id + f"-{len(messages)}", total_token_num
return None, -1

def update_conversations_dict(self, request_id: str, messages: List[MESSAGES], total_token_num: int) -> None:
key_str = tuple(self.tok.preprocess(messages=messages, add_generation_prompt=False).input_ids)
self.conversations_dict[key_str] = (request_id, total_token_num)
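A hedged usage sketch of the two new methods: update_conversations_dict keys a finished conversation by its token ids, and fetch_request_id strips trailing user turns from an incoming message list, tokenizes the remaining history into the same kind of key, and returns the stored request_id (suffixed with the matched message count) together with the stored token count. The tokenizer construction and the concrete values below are illustrative assumptions.

tok = TokenizerUtils("path/to/model")  # assumed constructor arguments
mp = MessageProcessor(tok)

finished = [
    {"role": "user", "content": "hello"},
    {"role": "assistant", "content": "Hello! How can I assist you today?"},
]
# After the first turn completes, the server stores the conversation and its token count.
mp.update_conversations_dict("chat-abc123", finished, total_token_num=42)  # hypothetical values

follow_up = finished + [{"role": "user", "content": "What's the weather like today?"}]
# The trailing user turn is dropped, the remaining history is tokenized into the lookup key,
# and the stored entry is returned as ("chat-abc123-2", 42), provided the tokenized key is
# longer than the stored token count; otherwise fetch_request_id falls back to (None, -1).
request_id, cached_len = mp.fetch_request_id(follow_up)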
