add prefill cache for async request
wnma3mz committed Jan 27, 2025
1 parent b52788a commit 016f8a9
Showing 8 changed files with 108 additions and 69 deletions.
25 changes: 16 additions & 9 deletions README.md
@@ -14,7 +14,7 @@
2.1 (no communication)

```bash
tllm.server --model_path mlx-community/Llama-3.2-1B-Instruct-4bit --hostname localhost --is_local
tllm.server --model_path mlx-community/Llama-3.2-1B-Instruct-4bit --hostname localhost --is_local --client_size 1
```

2.2 (with communication)
@@ -37,26 +37,32 @@ python3 benchmarks/run_async_requests.py
In `examples/config.json`

```json
// The number of clients determines how many parts the model is split into
{
"server": {
"grpc_port": 25001, // the server's gRPC port, used by each client to send status data and by the last client to send back the computed result
"http_port": 8022, // the server's HTTP port, serving the API and the WebSocket service
"hostname": "mac-mini" // the server's hostname; an IP such as 192.168.1.10 can be used instead; it must be reachable by every client
"grpc_port": 25001,
"http_port": 8022,
"hostname": "mac-mini"
},
"client": [
{
"grpc_port": 25002, // gRPC port of the first client
"hostname": "m3pro" // hostname of the first client; it must be reachable by the server and the other clients
"grpc_port": 25002,
"hostname": "m3pro"
},
{
"grpc_port": 25003, // gRPC port of the second client
"hostname": "m3" // hostname of the second client; it must be reachable by the server and the other clients
"grpc_port": 25003,
"hostname": "m3"
}
]
}
```

- The number of clients determines how many parts the model is split into
- `server.grpc_port`: the server's gRPC port, used by each client to send status data and by the last client to send back the computed result
- `server.http_port`: the server's HTTP port, serving the API and the WebSocket service
- `server.hostname`: the server's hostname; an IP such as 192.168.1.10 can be used instead; it must be reachable by every client
- `client.grpc_port`: the client's gRPC port
- `client.hostname`: the client's hostname; it must be reachable by the server and the other clients

### Features

- [X] Support Multi-Requests
@@ -75,6 +81,7 @@ In `examples/config.json`
- [X] Attention
- [X] xformers
- [X] flash-attn
- [x] Prefill-Cache (Token-Level)
- [ ] PageAttention

### Performance
4 changes: 2 additions & 2 deletions tllm/__init__.py
@@ -8,8 +8,8 @@ class BackendEnum(Enum):
MLX = 2


ENABLE_PREFIX_CACHE = os.environ.get("TLLM_ENABLE_PREFIX_CACHE", "true").lower() == "true"
ENABLE_PREFIX_CACHE = False
ENABLE_PREFILL_CACHE = os.environ.get("TLLM_ENABLE_PREFILL_CACHE", "true").lower() == "true"

if importlib.util.find_spec("mlx"):
BACKEND = BackendEnum.MLX
elif importlib.util.find_spec("torch"):
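The renamed `ENABLE_PREFILL_CACHE` flag is read from the `TLLM_ENABLE_PREFILL_CACHE` environment variable at import time, so it has to be set before `tllm` is imported. A minimal sketch of toggling it off (the surrounding script is illustrative, not part of the repository):

```python
import os

# Must be set before importing tllm; the flag is evaluated at import time.
os.environ["TLLM_ENABLE_PREFILL_CACHE"] = "false"

import tllm

# Expected to print False with the variable set as above.
print(tllm.ENABLE_PREFILL_CACHE)
```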
6 changes: 3 additions & 3 deletions tllm/commons/cache.py
@@ -3,7 +3,7 @@
import time
from typing import Dict, List, Optional

from tllm import BACKEND, DEVICE, DTYPE, ENABLE_PREFIX_CACHE, BackendEnum
from tllm import BACKEND, DEVICE, DTYPE, ENABLE_PREFILL_CACHE, BackendEnum
from tllm.commons.radix_tree import RadixTree
from tllm.schemas import MIX_TENSOR, SeqInput

@@ -79,7 +79,7 @@ def clear(self):
self.cache_dict.clear()

def insert_cache(self, seq_input: SeqInput):
if ENABLE_PREFIX_CACHE:
if ENABLE_PREFILL_CACHE:
for input_ids, request_id in zip(seq_input.input_ids_list, seq_input.uuid_list):
self.radix_tree.append_to_request(input_ids, request_id)

@@ -108,7 +108,7 @@ def build(self, seq_input: SeqInput, cache_manager: "CacheManager"):
k_len_list.append(cache_seq_len + q_len)
# prefill phase
else:
if ENABLE_PREFIX_CACHE:
if ENABLE_PREFILL_CACHE:
hit_uuid, hit_cache_len = self.radix_tree.longest_common_prefix(input_ids)
else:
# prefix cache disabled
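`insert_cache` records each request's token ids in a `RadixTree`, and `build` later calls `longest_common_prefix` to find how many leading tokens of a new prompt are already cached. A toy stand-in to illustrate that interface (the real `tllm.commons.radix_tree.RadixTree` is a tree structure; the dict-based matcher below is only an assumption about its observable behaviour):

```python
from typing import Dict, List, Tuple


class ToyPrefixIndex:
    """Dict-based illustration of append_to_request / longest_common_prefix."""

    def __init__(self) -> None:
        self._requests: Dict[str, List[int]] = {}

    def append_to_request(self, input_ids: List[int], request_id: str) -> None:
        # Remember the token ids already processed for this request.
        self._requests[request_id] = list(input_ids)

    def longest_common_prefix(self, input_ids: List[int]) -> Tuple[str, int]:
        # Return (request_id, matched_token_count); -1 means no usable hit.
        best_id, best_len = "", -1
        for request_id, cached_ids in self._requests.items():
            match = 0
            for a, b in zip(cached_ids, input_ids):
                if a != b:
                    break
                match += 1
            if match > 0 and match > best_len:
                best_id, best_len = request_id, match
        return best_id, best_len


index = ToyPrefixIndex()
index.append_to_request([1, 2, 3, 4, 5], "req-0")
print(index.longest_common_prefix([1, 2, 3, 9]))  # ('req-0', 3)
```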
34 changes: 29 additions & 5 deletions tllm/grpc/proto/schemas_pb2.pyi
@@ -1,7 +1,13 @@
from google.protobuf.internal import containers as _containers
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union
from typing import (
ClassVar as _ClassVar,
Iterable as _Iterable,
Mapping as _Mapping,
Optional as _Optional,
Union as _Union,
)

DESCRIPTOR: _descriptor.FileDescriptor

@@ -27,7 +33,12 @@ class ForwardRequest(_message.Message):
uuid_list: _containers.RepeatedScalarFieldContainer[str]
input_ids_list: _containers.RepeatedCompositeFieldContainer[InputIds]
hidden_states: BFloat16Tensor
def __init__(self, uuid_list: _Optional[_Iterable[str]] = ..., input_ids_list: _Optional[_Iterable[_Union[InputIds, _Mapping]]] = ..., hidden_states: _Optional[_Union[BFloat16Tensor, _Mapping]] = ...) -> None: ...
def __init__(
self,
uuid_list: _Optional[_Iterable[str]] = ...,
input_ids_list: _Optional[_Iterable[_Union[InputIds, _Mapping]]] = ...,
hidden_states: _Optional[_Union[BFloat16Tensor, _Mapping]] = ...,
) -> None: ...

class StatusRequest(_message.Message):
__slots__ = ("uuid", "pp_idx", "cost_time")
@@ -37,7 +48,9 @@ class StatusRequest(_message.Message):
uuid: _containers.RepeatedScalarFieldContainer[str]
pp_idx: int
cost_time: float
def __init__(self, uuid: _Optional[_Iterable[str]] = ..., pp_idx: _Optional[int] = ..., cost_time: _Optional[float] = ...) -> None: ...
def __init__(
self, uuid: _Optional[_Iterable[str]] = ..., pp_idx: _Optional[int] = ..., cost_time: _Optional[float] = ...
) -> None: ...

class StatusResponse(_message.Message):
__slots__ = ("msg", "status")
@@ -75,7 +88,9 @@ class SetConfigRequest(_message.Message):
forward_url: str
master_url: str
pp_rank: int
def __init__(self, forward_url: _Optional[str] = ..., master_url: _Optional[str] = ..., pp_rank: _Optional[int] = ...) -> None: ...
def __init__(
self, forward_url: _Optional[str] = ..., master_url: _Optional[str] = ..., pp_rank: _Optional[int] = ...
) -> None: ...

class SetConfigResponse(_message.Message):
__slots__ = ("msg", "status")
Expand All @@ -101,4 +116,13 @@ class ImageForwardRequest(_message.Message):
seq_len: int
height: int
width: int
def __init__(self, uuid: _Optional[_Iterable[str]] = ..., hidden_states: _Optional[_Union[BFloat16Tensor, _Mapping]] = ..., encoder_hidden_states: _Optional[_Union[BFloat16Tensor, _Mapping]] = ..., text_embeddings: _Optional[_Union[BFloat16Tensor, _Mapping]] = ..., seq_len: _Optional[int] = ..., height: _Optional[int] = ..., width: _Optional[int] = ...) -> None: ...
def __init__(
self,
uuid: _Optional[_Iterable[str]] = ...,
hidden_states: _Optional[_Union[BFloat16Tensor, _Mapping]] = ...,
encoder_hidden_states: _Optional[_Union[BFloat16Tensor, _Mapping]] = ...,
text_embeddings: _Optional[_Union[BFloat16Tensor, _Mapping]] = ...,
seq_len: _Optional[int] = ...,
height: _Optional[int] = ...,
width: _Optional[int] = ...,
) -> None: ...
37 changes: 32 additions & 5 deletions tllm/models/mlx/helper.py
@@ -5,7 +5,7 @@
import mlx.core as mx
import mlx.nn as nn

from tllm import DTYPE
from tllm import DTYPE, ENABLE_PREFILL_CACHE
from tllm.commons.cache import AttentionData, CacheManager, RequestsCache
from tllm.schemas import SeqInput

@@ -87,8 +87,35 @@ def read_from_safetensors(file_path: str, key_list: List[str]) -> Dict[str, mx.a
return tensors


def get_last_hidden_states(hidden_states: mx.array, seq_len_list: List[int]) -> mx.array:
index_list = list(itertools.accumulate(seq_len_list[:-1]))
seq_hidden_states = mx.split(hidden_states, index_list, axis=0)
hidden_states = mx.concat([x[-1:, :] for x in seq_hidden_states], axis=0)
def get_last_hidden_states(
hit_cache_flag: bool, is_end_pp: bool, attention_data: AttentionData, hidden_states: mx.array
) -> mx.array:
split_len_list = attention_data.q_len_list
if hit_cache_flag:
q_start = 0
for i, (q_len, hit_cache_len) in enumerate(zip(attention_data.q_len_list, attention_data.hit_cache_len_list)):
if hit_cache_len != -1:
split_len_list[i] = q_len - hit_cache_len
q_start += q_len
if is_end_pp:
index_list = list(itertools.accumulate(split_len_list[:-1]))
seq_hidden_states = mx.split(hidden_states, index_list, axis=0)
hidden_states = mx.concat([x[-1:, :] for x in seq_hidden_states], axis=0)
return hidden_states


def truncate_hidden_states(
hit_cache_flag: bool, is_start_pp: bool, attention_data: AttentionData, hidden_states: mx.array
) -> mx.array:
# Truncate hidden_states
if ENABLE_PREFILL_CACHE and is_start_pp and hit_cache_flag:
hidden_states_list = []
q_start = 0
for q_len, hit_cache_len in zip(attention_data.q_len_list, attention_data.hit_cache_len_list):
if hit_cache_len != -1:
hidden_states_list.append(hidden_states[q_start : q_start + q_len][hit_cache_len:])
else:
hidden_states_list.append(hidden_states[q_start : q_start + q_len])
q_start += q_len
hidden_states = mx.concat(hidden_states_list, axis=0)
return hidden_states
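
Together these two helpers shrink the prefill work when a prefix hit occurs: `truncate_hidden_states` drops the already-cached rows at the first pipeline stage, and `get_last_hidden_states` adjusts its split lengths by the same amounts at the last stage. A small worked example of just the length arithmetic (the batch sizes below are made up for illustration):

```python
import itertools

# Hypothetical prefill batch of two requests.
q_len_list = [10, 6]          # prompt lengths
hit_cache_len_list = [4, -1]  # request 0 reuses 4 cached tokens, request 1 has no hit

# truncate_hidden_states keeps only the non-cached rows of each request.
rows_kept = sum(q - (h if h != -1 else 0) for q, h in zip(q_len_list, hit_cache_len_list))
print(rows_kept)  # 12, i.e. (10 - 4) + 6

# get_last_hidden_states splits by the adjusted lengths and keeps each segment's last row.
split_len_list = [q - h if h != -1 else q for q, h in zip(q_len_list, hit_cache_len_list)]
index_list = list(itertools.accumulate(split_len_list[:-1]))
print(split_len_list, index_list)  # [6, 6] [6] -> two rows survive, one per request
```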
27 changes: 12 additions & 15 deletions tllm/models/mlx/llama.py
@@ -8,7 +8,12 @@

from tllm import DTYPE
from tllm.commons.cache import CacheManager, RequestsCache
from tllm.models.mlx.helper import build_forward_cache, get_last_hidden_states, quantization_func
from tllm.models.mlx.helper import (
build_forward_cache,
get_last_hidden_states,
quantization_func,
truncate_hidden_states,
)
from tllm.models.mlx.layers import Decoder
from tllm.models.weight_helper import default_merge_attn, default_merge_mlp, tensor_parallel_state_dict
from tllm.schemas import SeqInput
@@ -62,19 +67,13 @@ def __init__(self, config: AutoConfig, is_merge: bool = True):
self.head_dim = self.model.layers[-1].self_attn.head_dim
self.request_cache = RequestsCache(self.num_layers, self.max_seq_len, self.n_kv_heads, self.head_dim)

self.is_start_pp = self.config.decoder_start_layer_idx == 0
self.is_end_pp = self.config.decoder_end_layer_idx == self.config.num_hidden_layers

def __call__(self, hidden_states: mx.array, seq_input: SeqInput) -> mx.array:
attention_data = build_forward_cache(seq_input, self.cache_manager, self.request_cache)
# Truncate hidden_states
if self.config.decoder_start_layer_idx == 0 and any(x != -1 for x in attention_data.hit_cache_len_list):
hidden_states_list = []
q_start = 0
for q_len, hit_cache_len in zip(attention_data.q_len_list, attention_data.hit_cache_len_list):
if hit_cache_len != -1:
hidden_states_list.append(hidden_states[q_start:q_len][hit_cache_len:])
else:
hidden_states_list.append(hidden_states[q_start:q_len])
q_start += q_len
hidden_states = mx.concat(hidden_states_list, axis=0)
hit_cache_flag = any(x != -1 for x in attention_data.hit_cache_len_list)
hidden_states = truncate_hidden_states(hit_cache_flag, self.is_start_pp, attention_data, hidden_states)

# cos, sin = self.rotary_emb(attention_data.position_ids)
# attention_data.cos, attention_data.sin = mx.expand_dims(cos, axis=1), mx.expand_dims(sin, axis=1)
@@ -91,9 +90,7 @@ def __call__(self, hidden_states: mx.array, seq_input: SeqInput) -> mx.array:
self.request_cache.clear()
self.request_cache.insert_cache(seq_input)

if self.config.decoder_end_layer_idx == self.config.num_hidden_layers:
split_len_list = attention_data.q_len_list
output = get_last_hidden_states(output, split_len_list)
output = get_last_hidden_states(hit_cache_flag, self.is_end_pp, attention_data, output)
return output

@classmethod
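The duplicated layer-index checks are now computed once in `__init__` as `is_start_pp` / `is_end_pp`. A minimal sketch of that boundary logic for a hypothetical two-stage split of a 32-layer model (the config values are invented for illustration):

```python
# Hypothetical final pipeline stage holding layers 16..31 of a 32-layer model.
decoder_start_layer_idx = 16
decoder_end_layer_idx = 32
num_hidden_layers = 32

# Only the first stage truncates cached prefix rows of the input hidden states.
is_start_pp = decoder_start_layer_idx == 0
# Only the last stage extracts the final hidden state of each sequence.
is_end_pp = decoder_end_layer_idx == num_hidden_layers

print(is_start_pp, is_end_pp)  # False True
```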
42 changes: 13 additions & 29 deletions tllm/models/mlx/qwen.py
@@ -6,9 +6,14 @@
import numpy as np
from transformers import AutoConfig

from tllm import DTYPE, ENABLE_PREFIX_CACHE
from tllm import DTYPE
from tllm.commons.cache import CacheManager, RequestsCache
from tllm.models.mlx.helper import build_forward_cache, get_last_hidden_states, quantization_func
from tllm.models.mlx.helper import (
build_forward_cache,
get_last_hidden_states,
quantization_func,
truncate_hidden_states,
)
from tllm.models.mlx.layers import Decoder
from tllm.models.weight_helper import default_merge_attn, default_merge_mlp
from tllm.schemas import SeqInput
@@ -40,23 +45,13 @@ def __init__(self, config: AutoConfig, is_merge: bool = True):
self.head_dim = self.model.layers[-1].self_attn.head_dim
self.request_cache = RequestsCache(self.num_layers, self.max_seq_len, self.n_kv_heads, self.head_dim)

self.is_start_pp = self.config.decoder_start_layer_idx == 0
self.is_end_pp = self.config.decoder_end_layer_idx == self.config.num_hidden_layers

def __call__(self, hidden_states: mx.array, seq_input: SeqInput) -> mx.array:
attention_data = build_forward_cache(seq_input, self.cache_manager, self.request_cache)
# Truncate hidden_states
if (
ENABLE_PREFIX_CACHE
and self.config.decoder_start_layer_idx == 0
and any(x != -1 for x in attention_data.hit_cache_len_list)
):
hidden_states_list = []
q_start = 0
for q_len, hit_cache_len in zip(attention_data.q_len_list, attention_data.hit_cache_len_list):
if hit_cache_len != -1:
hidden_states_list.append(hidden_states[q_start:q_len][hit_cache_len:])
else:
hidden_states_list.append(hidden_states[q_start:q_len])
q_start += q_len
hidden_states = mx.concat(hidden_states_list, axis=0)
hit_cache_flag = any(x != -1 for x in attention_data.hit_cache_len_list)
hidden_states = truncate_hidden_states(hit_cache_flag, self.is_start_pp, attention_data, hidden_states)

mask = attention_data.attn_mask
mask = mask if mask is None else mask.astype(hidden_states.dtype)
@@ -68,18 +63,7 @@ def __call__(self, hidden_states: mx.array, seq_input: SeqInput) -> mx.array:
self.request_cache.clear()
self.request_cache.insert_cache(seq_input)

if self.config.decoder_end_layer_idx == self.config.num_hidden_layers:
split_len_list = attention_data.q_len_list
# if any(x != -1 for x in attention_data.hit_cache_len_list):
# q_start = 0
# for i, (q_len, hit_cache_len) in enumerate(zip(attention_data.q_len_list, attention_data.hit_cache_len_list)):
# if hit_cache_len != -1:
# print("split_len_list[i]:", split_len_list[i])
# print("q_len:", q_len)
# print("hit_cache_len:", hit_cache_len)
# split_len_list[i] = q_len - hit_cache_len
# q_start += q_len
output = get_last_hidden_states(output, split_len_list)
output = get_last_hidden_states(hit_cache_flag, self.is_end_pp, attention_data, output)
return output

@classmethod
2 changes: 1 addition & 1 deletion tllm/models/utils.py
@@ -10,4 +10,4 @@ def is_generate_end(output_ids: List[int], eos_token_ids: Set[int], max_tokens:
if output_ids[-1] in eos_token_ids:
return GenerateEnd(finish_reason="stop", is_end=True)

return GenerateEnd(finish_reason=None, is_end=False)
return GenerateEnd(finish_reason=None, is_end=False)
