diff --git a/README.md b/README.md
index 6ee83b2..79dfb33 100644
--- a/README.md
+++ b/README.md
@@ -14,7 +14,7 @@ 2.1 (no communication)
 
 ```bash
-tllm.server --model_path mlx-community/Llama-3.2-1B-Instruct-4bit --hostname localhost --is_local
+tllm.server --model_path mlx-community/Llama-3.2-1B-Instruct-4bit --hostname localhost --is_local --client_size 1
 ```
 
 2.2 (with communication)
@@ -37,26 +37,32 @@ python3 benchmarks/run_async_requests.py
 In `examples/config.json`
 
 ```json
-// The number of clients determines how many parts the model is split into
 {
     "server": {
-        "grpc_port": 25001, // the server's gRPC port, used by every client to report its status and by the last client to send back the computed result
-        "http_port": 8022, // the server's HTTP port, which serves the API endpoints and the WebSocket service
-        "hostname": "mac-mini" // the server's hostname; an IP address such as 192.168.1.10 can be used instead, and it must be reachable by the clients
+        "grpc_port": 25001,
+        "http_port": 8022,
+        "hostname": "mac-mini"
     },
     "client": [
         {
-            "grpc_port": 25002, // the first client's gRPC port
-            "hostname": "m3pro" // the first client's hostname; it must be reachable by the server and the other clients
+            "grpc_port": 25002,
+            "hostname": "m3pro"
         },
         {
-            "grpc_port": 25003, // the second client's gRPC port
-            "hostname": "m3" // the second client's hostname; it must be reachable by the server and the other clients
+            "grpc_port": 25003,
+            "hostname": "m3"
         }
     ]
 }
 ```
+
+- The number of clients determines how many parts the model is split into.
+- `server.grpc_port`: the server's gRPC port, used by every client to report its status and by the last client to send back the computed result
+- `server.http_port`: the server's HTTP port, which serves the API endpoints and the WebSocket service
+- `server.hostname`: the server's hostname; an IP address such as 192.168.1.10 can be used instead, and it must be reachable by the clients
+- `client.grpc_port`: the client's gRPC port
+- `client.hostname`: the client's hostname; it must be reachable by the server and the other clients
+
 ### Features
 
 - [X] Support Multi-Requests
@@ -75,6 +81,7 @@ In `examples/config.json`
 - [X] Attention
   - [X] xformers
   - [X] flash-attn
+  - [X] Prefill-Cache (Token-Level)
   - [ ] PageAttention
 
 ### Performance
diff --git a/tllm/__init__.py b/tllm/__init__.py
index 2bcdb0f..2b9f7cc 100644
--- a/tllm/__init__.py
+++ b/tllm/__init__.py
@@ -8,8 +8,8 @@ class BackendEnum(Enum):
     MLX = 2
 
 
-ENABLE_PREFIX_CACHE = os.environ.get("TLLM_ENABLE_PREFIX_CACHE", "true").lower() == "true"
-ENABLE_PREFIX_CACHE = False
+ENABLE_PREFILL_CACHE = os.environ.get("TLLM_ENABLE_PREFILL_CACHE", "true").lower() == "true"
+
 if importlib.util.find_spec("mlx"):
     BACKEND = BackendEnum.MLX
 elif importlib.util.find_spec("torch"):
diff --git a/tllm/commons/cache.py b/tllm/commons/cache.py
index 783eaa2..d2f5af7 100644
--- a/tllm/commons/cache.py
+++ b/tllm/commons/cache.py
@@ -3,7 +3,7 @@ import time
 
 from typing import Dict, List, Optional
 
-from tllm import BACKEND, DEVICE, DTYPE, ENABLE_PREFIX_CACHE, BackendEnum
+from tllm import BACKEND, DEVICE, DTYPE, ENABLE_PREFILL_CACHE, BackendEnum
 from tllm.commons.radix_tree import RadixTree
 from tllm.schemas import MIX_TENSOR, SeqInput
 
@@ -79,7 +79,7 @@ def clear(self):
         self.cache_dict.clear()
 
     def insert_cache(self, seq_input: SeqInput):
-        if ENABLE_PREFIX_CACHE:
+        if ENABLE_PREFILL_CACHE:
             for input_ids, request_id in zip(seq_input.input_ids_list, seq_input.uuid_list):
                 self.radix_tree.append_to_request(input_ids, request_id)
 
@@ -108,7 +108,7 @@ def build(self, seq_input: SeqInput, cache_manager: "CacheManager"):
                 k_len_list.append(cache_seq_len + q_len)
             # prefill stage
             else:
-                if ENABLE_PREFIX_CACHE:
+                if ENABLE_PREFILL_CACHE:
                     hit_uuid, hit_cache_len = self.radix_tree.longest_common_prefix(input_ids)
                 else:
                     # prefix cache disabled
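Note: after the rename above, the prefill cache is toggled with the `TLLM_ENABLE_PREFILL_CACHE` environment variable (default `true`, read in `tllm/__init__.py`). The sketch below is illustrative only and is not the repo's `RadixTree`: it shows what a token-level hit means, i.e. `hit_cache_len` is conceptually the length of the token prefix shared with a previously cached request (`-1` meaning no hit).

```python
# Illustrative sketch only -- not tllm's RadixTree implementation.
# RadixTree.longest_common_prefix(input_ids) returns (hit_uuid, hit_cache_len);
# conceptually, hit_cache_len is the shared-prefix length with a cached request.
from itertools import takewhile

def common_prefix_len(a: list, b: list) -> int:
    """Length of the shared token-id prefix of two sequences."""
    return sum(1 for _ in takewhile(lambda p: p[0] == p[1], zip(a, b)))

cached_prompt = [1, 15, 7, 7, 42, 9]  # token ids of an earlier request
new_prompt = [1, 15, 7, 7, 99]        # new request reusing the first 4 tokens

hit_cache_len = common_prefix_len(cached_prompt, new_prompt)
# With TLLM_ENABLE_PREFILL_CACHE=true (the default), only new_prompt[hit_cache_len:]
# needs to be prefilled; the KV cache for the first hit_cache_len positions is reused.
print(hit_cache_len, new_prompt[hit_cache_len:])  # 4 [99]
```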
diff --git a/tllm/grpc/proto/schemas_pb2.pyi b/tllm/grpc/proto/schemas_pb2.pyi
index bc45984..036c822 100644
--- a/tllm/grpc/proto/schemas_pb2.pyi
+++ b/tllm/grpc/proto/schemas_pb2.pyi
@@ -1,7 +1,13 @@
 from google.protobuf.internal import containers as _containers
 from google.protobuf import descriptor as _descriptor
 from google.protobuf import message as _message
-from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union
+from typing import (
+    ClassVar as _ClassVar,
+    Iterable as _Iterable,
+    Mapping as _Mapping,
+    Optional as _Optional,
+    Union as _Union,
+)
 
 DESCRIPTOR: _descriptor.FileDescriptor
 
@@ -27,7 +33,12 @@ class ForwardRequest(_message.Message):
     uuid_list: _containers.RepeatedScalarFieldContainer[str]
     input_ids_list: _containers.RepeatedCompositeFieldContainer[InputIds]
     hidden_states: BFloat16Tensor
-    def __init__(self, uuid_list: _Optional[_Iterable[str]] = ..., input_ids_list: _Optional[_Iterable[_Union[InputIds, _Mapping]]] = ..., hidden_states: _Optional[_Union[BFloat16Tensor, _Mapping]] = ...) -> None: ...
+    def __init__(
+        self,
+        uuid_list: _Optional[_Iterable[str]] = ...,
+        input_ids_list: _Optional[_Iterable[_Union[InputIds, _Mapping]]] = ...,
+        hidden_states: _Optional[_Union[BFloat16Tensor, _Mapping]] = ...,
+    ) -> None: ...
 
 class StatusRequest(_message.Message):
     __slots__ = ("uuid", "pp_idx", "cost_time")
@@ -37,7 +48,9 @@ class StatusRequest(_message.Message):
     uuid: _containers.RepeatedScalarFieldContainer[str]
     pp_idx: int
     cost_time: float
-    def __init__(self, uuid: _Optional[_Iterable[str]] = ..., pp_idx: _Optional[int] = ..., cost_time: _Optional[float] = ...) -> None: ...
+    def __init__(
+        self, uuid: _Optional[_Iterable[str]] = ..., pp_idx: _Optional[int] = ..., cost_time: _Optional[float] = ...
+    ) -> None: ...
 
 class StatusResponse(_message.Message):
     __slots__ = ("msg", "status")
@@ -75,7 +88,9 @@ class SetConfigRequest(_message.Message):
     forward_url: str
     master_url: str
     pp_rank: int
-    def __init__(self, forward_url: _Optional[str] = ..., master_url: _Optional[str] = ..., pp_rank: _Optional[int] = ...) -> None: ...
+    def __init__(
+        self, forward_url: _Optional[str] = ..., master_url: _Optional[str] = ..., pp_rank: _Optional[int] = ...
+    ) -> None: ...
 
 class SetConfigResponse(_message.Message):
     __slots__ = ("msg", "status")
@@ -101,4 +116,13 @@ class ImageForwardRequest(_message.Message):
     seq_len: int
     height: int
     width: int
-    def __init__(self, uuid: _Optional[_Iterable[str]] = ..., hidden_states: _Optional[_Union[BFloat16Tensor, _Mapping]] = ..., encoder_hidden_states: _Optional[_Union[BFloat16Tensor, _Mapping]] = ..., text_embeddings: _Optional[_Union[BFloat16Tensor, _Mapping]] = ..., seq_len: _Optional[int] = ..., height: _Optional[int] = ..., width: _Optional[int] = ...) -> None: ...
+    def __init__(
+        self,
+        uuid: _Optional[_Iterable[str]] = ...,
+        hidden_states: _Optional[_Union[BFloat16Tensor, _Mapping]] = ...,
+        encoder_hidden_states: _Optional[_Union[BFloat16Tensor, _Mapping]] = ...,
+        text_embeddings: _Optional[_Union[BFloat16Tensor, _Mapping]] = ...,
+        seq_len: _Optional[int] = ...,
+        height: _Optional[int] = ...,
+        width: _Optional[int] = ...,
+    ) -> None: ...
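These stubs are only re-wrapped by the formatter; the message fields themselves are unchanged. For orientation, a hedged usage sketch (assuming the generated module is importable as `tllm.grpc.proto.schemas_pb2`, the path in the diff header; field names are taken from the stub above):

```python
# Hedged sketch: constructing one of the messages whose stubs are reformatted above.
from tllm.grpc.proto import schemas_pb2

status = schemas_pb2.StatusRequest(uuid=["req-0", "req-1"], pp_idx=1, cost_time=0.031)
print(status.pp_idx, list(status.uuid))  # 1 ['req-0', 'req-1']
```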
diff --git a/tllm/models/mlx/helper.py b/tllm/models/mlx/helper.py
index 6ce9d21..7d09941 100644
--- a/tllm/models/mlx/helper.py
+++ b/tllm/models/mlx/helper.py
@@ -5,7 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn
 
-from tllm import DTYPE
+from tllm import DTYPE, ENABLE_PREFILL_CACHE
 from tllm.commons.cache import AttentionData, CacheManager, RequestsCache
 from tllm.schemas import SeqInput
 
@@ -87,8 +87,35 @@ def read_from_safetensors(file_path: str, key_list: List[str]) -> Dict[str, mx.a
     return tensors
 
 
-def get_last_hidden_states(hidden_states: mx.array, seq_len_list: List[int]) -> mx.array:
-    index_list = list(itertools.accumulate(seq_len_list[:-1]))
-    seq_hidden_states = mx.split(hidden_states, index_list, axis=0)
-    hidden_states = mx.concat([x[-1:, :] for x in seq_hidden_states], axis=0)
+def get_last_hidden_states(
+    hit_cache_flag: bool, is_end_pp: bool, attention_data: AttentionData, hidden_states: mx.array
+) -> mx.array:
+    split_len_list = attention_data.q_len_list
+    if hit_cache_flag:
+        q_start = 0
+        for i, (q_len, hit_cache_len) in enumerate(zip(attention_data.q_len_list, attention_data.hit_cache_len_list)):
+            if hit_cache_len != -1:
+                split_len_list[i] = q_len - hit_cache_len
+            q_start += q_len
+    if is_end_pp:
+        index_list = list(itertools.accumulate(split_len_list[:-1]))
+        seq_hidden_states = mx.split(hidden_states, index_list, axis=0)
+        hidden_states = mx.concat([x[-1:, :] for x in seq_hidden_states], axis=0)
+    return hidden_states
+
+
+def truncate_hidden_states(
+    hit_cache_flag: bool, is_start_pp: bool, attention_data: AttentionData, hidden_states: mx.array
+) -> mx.array:
+    # Truncate hidden_states
+    if ENABLE_PREFILL_CACHE and is_start_pp and hit_cache_flag:
+        hidden_states_list = []
+        q_start = 0
+        for q_len, hit_cache_len in zip(attention_data.q_len_list, attention_data.hit_cache_len_list):
+            if hit_cache_len != -1:
+                hidden_states_list.append(hidden_states[q_start : q_start + q_len][hit_cache_len:])
+            else:
+                hidden_states_list.append(hidden_states[q_start : q_start + q_len])
+            q_start += q_len
+        hidden_states = mx.concat(hidden_states_list, axis=0)
     return hidden_states
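A worked example of the bookkeeping these two helpers do (a minimal numpy sketch, not the mlx code above; `q_len_list` and `hit_cache_len_list` are made-up values): request 1 has a 4-token prefix hit, so the first pipeline stage keeps 5 + 3 = 8 rows, and the last stage keeps one row per request.

```python
# Minimal numpy sketch of the truncate_hidden_states / get_last_hidden_states bookkeeping.
import numpy as np

hidden_dim = 4
q_len_list = [5, 7]           # tokens per request in this forward pass
hit_cache_len_list = [-1, 4]  # request 0: no hit; request 1: 4-token prefix reused

# First pipeline stage (is_start_pp): drop the hidden states of cached prefix positions.
hidden_states = np.random.randn(sum(q_len_list), hidden_dim)
kept, q_start = [], 0
for q_len, hit in zip(q_len_list, hit_cache_len_list):
    chunk = hidden_states[q_start : q_start + q_len]
    kept.append(chunk if hit == -1 else chunk[hit:])
    q_start += q_len
hidden_states = np.concatenate(kept, axis=0)  # shape (5 + 3, hidden_dim)

# Last pipeline stage (is_end_pp): keep only the final position of each request.
split_len_list = [q - (h if h != -1 else 0) for q, h in zip(q_len_list, hit_cache_len_list)]
segments = np.split(hidden_states, np.cumsum(split_len_list[:-1]), axis=0)
last_hidden = np.concatenate([seg[-1:] for seg in segments], axis=0)
print(hidden_states.shape, last_hidden.shape)  # (8, 4) (2, 4)
```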
diff --git a/tllm/models/mlx/llama.py b/tllm/models/mlx/llama.py
index 7168552..582ec96 100644
--- a/tllm/models/mlx/llama.py
+++ b/tllm/models/mlx/llama.py
@@ -8,7 +8,12 @@ from tllm import DTYPE
 from tllm.commons.cache import CacheManager, RequestsCache
-from tllm.models.mlx.helper import build_forward_cache, get_last_hidden_states, quantization_func
+from tllm.models.mlx.helper import (
+    build_forward_cache,
+    get_last_hidden_states,
+    quantization_func,
+    truncate_hidden_states,
+)
 from tllm.models.mlx.layers import Decoder
 from tllm.models.weight_helper import default_merge_attn, default_merge_mlp, tensor_parallel_state_dict
 from tllm.schemas import SeqInput
@@ -62,19 +67,13 @@ def __init__(self, config: AutoConfig, is_merge: bool = True):
         self.head_dim = self.model.layers[-1].self_attn.head_dim
         self.request_cache = RequestsCache(self.num_layers, self.max_seq_len, self.n_kv_heads, self.head_dim)
 
+        self.is_start_pp = self.config.decoder_start_layer_idx == 0
+        self.is_end_pp = self.config.decoder_end_layer_idx == self.config.num_hidden_layers
+
     def __call__(self, hidden_states: mx.array, seq_input: SeqInput) -> mx.array:
         attention_data = build_forward_cache(seq_input, self.cache_manager, self.request_cache)
-        # Truncate hidden_states
-        if self.config.decoder_start_layer_idx == 0 and any(x != -1 for x in attention_data.hit_cache_len_list):
-            hidden_states_list = []
-            q_start = 0
-            for q_len, hit_cache_len in zip(attention_data.q_len_list, attention_data.hit_cache_len_list):
-                if hit_cache_len != -1:
-                    hidden_states_list.append(hidden_states[q_start:q_len][hit_cache_len:])
-                else:
-                    hidden_states_list.append(hidden_states[q_start:q_len])
-                q_start += q_len
-            hidden_states = mx.concat(hidden_states_list, axis=0)
+        hit_cache_flag = any(x != -1 for x in attention_data.hit_cache_len_list)
+        hidden_states = truncate_hidden_states(hit_cache_flag, self.is_start_pp, attention_data, hidden_states)
 
         # cos, sin = self.rotary_emb(attention_data.position_ids)
         # attention_data.cos, attention_data.sin = mx.expand_dims(cos, axis=1), mx.expand_dims(sin, axis=1)
@@ -91,9 +90,7 @@ def __call__(self, hidden_states: mx.array, seq_input: SeqInput) -> mx.array:
         self.request_cache.clear()
         self.request_cache.insert_cache(seq_input)
 
-        if self.config.decoder_end_layer_idx == self.config.num_hidden_layers:
-            split_len_list = attention_data.q_len_list
-            output = get_last_hidden_states(output, split_len_list)
+        output = get_last_hidden_states(hit_cache_flag, self.is_end_pp, attention_data, output)
         return output
 
     @classmethod
diff --git a/tllm/models/mlx/qwen.py b/tllm/models/mlx/qwen.py
index 44ec323..089dc0b 100644
--- a/tllm/models/mlx/qwen.py
+++ b/tllm/models/mlx/qwen.py
@@ -6,9 +6,14 @@ import numpy as np
 from transformers import AutoConfig
 
-from tllm import DTYPE, ENABLE_PREFIX_CACHE
+from tllm import DTYPE
 from tllm.commons.cache import CacheManager, RequestsCache
-from tllm.models.mlx.helper import build_forward_cache, get_last_hidden_states, quantization_func
+from tllm.models.mlx.helper import (
+    build_forward_cache,
+    get_last_hidden_states,
+    quantization_func,
+    truncate_hidden_states,
+)
 from tllm.models.mlx.layers import Decoder
 from tllm.models.weight_helper import default_merge_attn, default_merge_mlp
 from tllm.schemas import SeqInput
@@ -40,23 +45,13 @@ def __init__(self, config: AutoConfig, is_merge: bool = True):
         self.head_dim = self.model.layers[-1].self_attn.head_dim
         self.request_cache = RequestsCache(self.num_layers, self.max_seq_len, self.n_kv_heads, self.head_dim)
 
+        self.is_start_pp = self.config.decoder_start_layer_idx == 0
+        self.is_end_pp = self.config.decoder_end_layer_idx == self.config.num_hidden_layers
+
     def __call__(self, hidden_states: mx.array, seq_input: SeqInput) -> mx.array:
         attention_data = build_forward_cache(seq_input, self.cache_manager, self.request_cache)
-        # Truncate hidden_states
-        if (
-            ENABLE_PREFIX_CACHE
-            and self.config.decoder_start_layer_idx == 0
-            and any(x != -1 for x in attention_data.hit_cache_len_list)
-        ):
-            hidden_states_list = []
-            q_start = 0
-            for q_len, hit_cache_len in zip(attention_data.q_len_list, attention_data.hit_cache_len_list):
-                if hit_cache_len != -1:
-                    hidden_states_list.append(hidden_states[q_start:q_len][hit_cache_len:])
-                else:
-                    hidden_states_list.append(hidden_states[q_start:q_len])
-                q_start += q_len
-            hidden_states = mx.concat(hidden_states_list, axis=0)
+        hit_cache_flag = any(x != -1 for x in attention_data.hit_cache_len_list)
+        hidden_states = truncate_hidden_states(hit_cache_flag, self.is_start_pp, attention_data, hidden_states)
 
         mask = attention_data.attn_mask
         mask = mask if mask is None else mask.astype(hidden_states.dtype)
@@ -68,18 +63,7 @@ def __call__(self, hidden_states: mx.array, seq_input: SeqInput) -> mx.array:
         self.request_cache.clear()
         self.request_cache.insert_cache(seq_input)
 
-        if self.config.decoder_end_layer_idx == self.config.num_hidden_layers:
-            split_len_list = attention_data.q_len_list
-            # if any(x != -1 for x in attention_data.hit_cache_len_list):
-            #     q_start = 0
-            #     for i, (q_len, hit_cache_len) in enumerate(zip(attention_data.q_len_list, attention_data.hit_cache_len_list)):
-            #         if hit_cache_len != -1:
-            #             print("split_len_list[i]:", split_len_list[i])
-            #             print("q_len:", q_len)
-            #             print("hit_cache_len:", hit_cache_len)
-            #             split_len_list[i] = q_len - hit_cache_len
-            #         q_start += q_len
-            output = get_last_hidden_states(output, split_len_list)
+        output = get_last_hidden_states(hit_cache_flag, self.is_end_pp, attention_data, output)
         return output
 
     @classmethod
diff --git a/tllm/models/utils.py b/tllm/models/utils.py
index 01dc84e..ff697b2 100644
--- a/tllm/models/utils.py
+++ b/tllm/models/utils.py
@@ -10,4 +10,4 @@ def is_generate_end(output_ids: List[int], eos_token_ids: Set[int], max_tokens:
     if output_ids[-1] in eos_token_ids:
         return GenerateEnd(finish_reason="stop", is_end=True)
 
-    return GenerateEnd(finish_reason=None, is_end=False)
\ No newline at end of file
+    return GenerateEnd(finish_reason=None, is_end=False)
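For reference, a hedged usage sketch of the function touched by the last hunk (import path taken from the diff header; `GenerateEnd` is assumed to expose `is_end` and `finish_reason` as attributes, matching how it is constructed above; the values are made up):

```python
# Hedged usage sketch for is_generate_end; argument names come from the hunk header.
from tllm.models.utils import is_generate_end

end = is_generate_end(output_ids=[10, 42, 2], eos_token_ids={2}, max_tokens=16)
if end.is_end:
    print(end.finish_reason)  # "stop", because the last generated token is an EOS id
```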