add radix tree for token level prefill cache
wnma3mz committed Jan 27, 2025
1 parent cb73df2 commit 5594ab2
Showing 31 changed files with 649 additions and 437 deletions.
4 changes: 3 additions & 1 deletion benchmarks/comm/main.py
@@ -79,7 +79,9 @@ def run_client_test(self, matrix_shape, num_iterations=3):

with grpc.insecure_channel(f"{self.host}:{self.port}") as channel:
stub = schemas_pb2_grpc.RPCServiceStub(channel)
- request = schemas_pb2.ForwardRequest(uuid=["123"], seq_len=[1], hidden_states=compress_bytes(byte_tensor))
+ request = schemas_pb2.ForwardRequest(
+     uuid_list=["123"], input_ids_list=[1], hidden_states=compress_bytes(byte_tensor)
+ )

for _ in range(num_iterations):
start_time = time.time()
17 changes: 13 additions & 4 deletions benchmarks/run_async_requests.py
@@ -39,6 +39,14 @@ def llm_message():
},
{"role": "user", "content": "今天天气怎么样?"},
]
+ # messages2 = [
+ # {"role": "user", "content": "Hello, how are you?"},
+ # {
+ # "role": "assistant",
+ # "content": "Hello! I'm just a virtual assistant, so I don't have feelings, but I'm here and ready to help you with whatever you need. How are you doing? 😊",
+ # },
+ # {"role": "user", "content": "今天天气怎么样?"},
+ # ]
messages_list = [messages1, messages2, messages2]
return messages_list

@@ -68,15 +76,16 @@ def mllm_message():


async def main(messages_list: List[List[Dict[str, Any]]]):
- # print("异步并发请求结果")
- # s1 = time.time()
- # await asyncio.gather(*[requests_func(messages) for messages in messages_list])
- # print(f"time cost: {time.time() - s1:.4f} s")
+ print("异步并发请求结果")
+ s1 = time.time()
+ await asyncio.gather(*[requests_func(messages) for messages in messages_list])
+ print(f"time cost: {time.time() - s1:.4f} s")

print("单独请求结果")
s1 = time.time()
for message in messages_list:
await requests_func(message)
+ print("=" * 20)
print(f"time cost: {time.time() - s1:.4f} s")


32 changes: 16 additions & 16 deletions examples/run_engine.py
@@ -104,13 +104,13 @@ async def llm_generate(args, messages):
engine = init_engine(args.model_path)
await engine.start()
messages = [{"role": "user", "content": "Hello, how are you?"}]
- messages = [
- {
- "role": "system",
- "content": "Given the following conversation, relevant context, and a follow up question, reply with an answer to the current question the user is asking. Return only your response to the question given the above information following the users instructions as needed.",
- },
- {"role": "user", "content": "hello"},
- ]
+ # messages = [
+ # {
+ #     "role": "system",
+ #     "content": "Given the following conversation, relevant context, and a follow up question, reply with an answer to the current question the user is asking. Return only your response to the question given the above information following the users instructions as needed.",
+ # },
+ # {"role": "user", "content": "hello"},
+ # ]
openai_serving_chat = OpenAIServing(engine, args)

# for _ in range(3):
@@ -126,15 +126,15 @@ async def llm_generate(args, messages):
},
{"role": "user", "content": "今天天气怎么样?"},
]
- messages = [
- {
- "role": "system",
- "content": "Given the following conversation, relevant context, and a follow up question, reply with an answer to the current question the user is asking. Return only your response to the question given the above information following the users instructions as needed.",
- },
- {"role": "user", "content": "hello"},
- {"role": "assistant", "content": "Hello! How can I assist you today?"},
- {"role": "user", "content": "今年天气怎么样"},
- ]
+ # messages = [
+ # {
+ #     "role": "system",
+ #     "content": "Given the following conversation, relevant context, and a follow up question, reply with an answer to the current question the user is asking. Return only your response to the question given the above information following the users instructions as needed.",
+ # },
+ # {"role": "user", "content": "hello"},
+ # {"role": "assistant", "content": "Hello! How can I assist you today?"},
+ # {"role": "user", "content": "今年天气怎么样"},
+ # ]
for _ in range(3):
request = ChatCompletionRequest(model="test", messages=messages, max_tokens=100)
response = await openai_serving_chat.create_chat_completion(request, None)
2 changes: 1 addition & 1 deletion scripts/rpc_compile.sh
@@ -1 +1 @@
- python3 -m grpc_tools.protoc -I=. --python_out=./ --pyi_out=./ --grpc_python_out=./ tllm/entrypoints/grpc/proto/schemas.proto
+ python3 -m grpc_tools.protoc -I=. --python_out=./ --pyi_out=./ --grpc_python_out=./ tllm/grpc/proto/schemas.proto
91 changes: 91 additions & 0 deletions tests/test_radix_tree.py
@@ -0,0 +1,91 @@
from tllm.commons.radix_tree import RadixTree

if __name__ == "__main__":
tree = RadixTree()
tree.append_to_request([151646, 151646, 151644, 9707, 11, 1246, 525, 498, 30, 151645], "123")
tree.append_to_request([151648], "123")
tree.append_to_request([271], "123")
tree.append_to_request([151649], "123")
tree.append_to_request([271], "123")
tree.append_to_request([9707], "123")
tree.append_to_request([0], "123")
tree.append_to_request([358], "123")
tree.append_to_request([2776], "123")
tree.append_to_request([1101], "123")
tree.append_to_request([264], "123")
tree.append_to_request([4108], "123")
tree.append_to_request([17847], "123")
tree.append_to_request([11], "123")
tree.append_to_request([773], "123")

input_ids = [
151646,
151646,
151644,
9707,
11,
1246,
525,
498,
30,
151645,
9707,
0,
358,
2776,
1101,
264,
4108,
17847,
11,
773,
358,
1513,
944,
614,
15650,
11,
714,
358,
2776,
1588,
323,
5527,
311,
1492,
498,
448,
8820,
498,
1184,
13,
2585,
525,
498,
3730,
30,
26525,
232,
151643,
151644,
100644,
104307,
104472,
11319,
151645,
]
longest = tree.longest_common_prefix(input_ids)
print("longest common prefix:", longest)
print("hit input ids", input_ids[: longest[1]])

# longest = tree.longest_common_prefix([1, 2, 3, 4, 6, 7, 8, 9])
# print("longest common prefix:", longest)

# longest = tree.longest_common_prefix([1, 2, 3, 4, 6, 7, 8, 9])
# print("longest common prefix:", longest)

# longest = tree.longest_common_prefix([1, 2, 3])
# print("longest common prefix:", longest)
tree.remove(tree.request_id_map["123"].path)
longest = tree.longest_common_prefix([1, 2, 3, 4])
print("longest common prefix:", longest)
2 changes: 2 additions & 0 deletions tllm/__init__.py
@@ -8,6 +8,8 @@ class BackendEnum(Enum):
MLX = 2


+ ENABLE_PREFIX_CACHE = os.environ.get("TLLM_ENABLE_PREFIX_CACHE", "true").lower() == "true"
+ ENABLE_PREFIX_CACHE = False
if importlib.util.find_spec("mlx"):
BACKEND = BackendEnum.MLX
elif importlib.util.find_spec("torch"):
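
The new ENABLE_PREFIX_CACHE flag reads the TLLM_ENABLE_PREFIX_CACHE environment variable (defaulting to true) and is then hard-coded to False, so the prefix cache ships disabled in this commit. As context for how such a flag is typically consumed during prefill, the snippet below is a hedged sketch, not code from the repository: only ENABLE_PREFIX_CACHE and the longest_common_prefix call mirror names from this diff; prefill, forward_fn, and reuse_prefix_len are hypothetical.

# Illustrative only: how a token-level prefix cache is commonly wired into prefill.
# ENABLE_PREFIX_CACHE and longest_common_prefix come from this diff; everything
# else here is a hypothetical stand-in, not the tllm engine's real code path.
from tllm import ENABLE_PREFIX_CACHE

def prefill(tree, input_ids, forward_fn):
    hit_len = 0
    if ENABLE_PREFIX_CACHE:
        _, hit_len = tree.longest_common_prefix(input_ids)  # tokens already cached
    # Reuse the cached KV states for input_ids[:hit_len] and only run the model
    # over the uncached suffix.
    return forward_fn(input_ids[hit_len:], reuse_prefix_len=hit_len)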
