From b52788a8e6b3e19b078cf755159113ddc1f7fe2e Mon Sep 17 00:00:00 2001
From: lujianghu
Date: Mon, 27 Jan 2025 20:58:57 +0800
Subject: [PATCH] fix batch async request

---
 tllm/commons/cache.py          | 25 ++++++++++++-------------
 tllm/generate/llm_generator.py |  5 +----
 tllm/models/mlx/llama.py       |  2 +-
 tllm/models/mlx/qwen.py        |  2 +-
 tllm/models/tinygrad/llama.py  |  2 +-
 tllm/models/torch/llama.py     |  2 +-
 tllm/models/torch/qwen.py      |  2 +-
 tllm/schemas.py                |  7 ++-----
 8 files changed, 20 insertions(+), 27 deletions(-)

diff --git a/tllm/commons/cache.py b/tllm/commons/cache.py
index 9f6a796..783eaa2 100644
--- a/tllm/commons/cache.py
+++ b/tllm/commons/cache.py
@@ -47,10 +47,6 @@ def __init__(
         self.kv_cache_list = [KVCache(max_seq_len, num_key_value_heads, head_dim) for _ in range(num_layers)]
         self.q_len = q_len
 
-    @property
-    def kv_len(self):
-        return self.kv_cache_list[0].kv_len
-
     def __getitem__(self, idx: int) -> KVCache:
         return self.kv_cache_list[idx]
 
@@ -104,18 +100,20 @@ def build(self, seq_input: SeqInput, cache_manager: "CacheManager"):
             hit_cache_len = -1
             q_len = len(input_ids)
             # decoding stage
-            if q_len == 1 and uuid in cache_manager.cache_dict:
+            if uuid in cache_manager.cache_dict:
                 decoder_cache: DecoderCache = cache_manager.get(uuid)
-                cache_seq_len = decoder_cache.kv_len
+                decoder_cache.set_q_len(q_len)
+                cache_seq_len = decoder_cache[0].kv_len
                 position_ids = array_func(cache_seq_len)
                 k_len_list.append(cache_seq_len + q_len)
             # prefilling stage
             else:
-                # hit a previous kv cache; reuse the historical cache
                 if ENABLE_PREFIX_CACHE:
                     hit_uuid, hit_cache_len = self.radix_tree.longest_common_prefix(input_ids)
                 else:
-                    hit_uuid, hit_cache_len = None, -1  # prefix cache disabled
+                    # prefix cache disabled
+                    hit_uuid, hit_cache_len = None, -1
+                # hit a previous kv cache; reuse the historical cache
                 if hit_uuid is not None and cache_manager.get(hit_uuid) is not None:
                     hid_decoder_cache: DecoderCache = copy.deepcopy(cache_manager.get(hit_uuid))
                     # for identical requests, avoid exceeding the cache length
@@ -139,15 +137,15 @@ def build(self, seq_input: SeqInput, cache_manager: "CacheManager"):
             self.add(uuid, q_len, decoder_cache)
         return q_len_list, k_len_list, position_ids_list, hit_cache_len_list
 
-    def get_kv_cache(self, uuid: str) -> DecoderCache:
+    def get_decoder_cache(self, uuid: str) -> DecoderCache:
         return self.cache_dict[uuid]
 
     def get_layer_idx_kv_cache(self, uuid: str, layer_idx: int) -> KVCache:
-        return self.get_kv_cache(uuid)[layer_idx]
+        return self.get_decoder_cache(uuid)[layer_idx]
 
     def get_q_len(self, uuid: str) -> int:
         # get the q_len of each uuid's request
-        return self.get_kv_cache(uuid).q_len
+        return self.get_decoder_cache(uuid).q_len
 
     def get_kv_len(self, uuid: str, layer_idx: Optional[int] = 0) -> int:
         # get the kv_len of each uuid's kv cache
@@ -173,6 +171,7 @@ def update_cat(
         interval = self.get_q_len(uuid)
         end = start + interval
         cur_key_states, cur_value_states = key_states[start:end], value_states[start:end]
+
         if kv_cache.key_states is None:
             kv_cache.key_states, kv_cache.value_states = cur_key_states, cur_value_states
         else:
@@ -269,8 +268,8 @@ def __init__(
         self.hit_cache_len_list = hit_cache_len_list  # used at PP=0 to truncate hidden_states
         self.q_len_list = q_len_list
 
-    def get_kv_cache_list(self, uuid: str) -> List[KVCache]:
-        return self.request_cache.get_kv_cache(uuid)
+    def get_decoder_cache(self, uuid: str) -> DecoderCache:
+        return self.request_cache.get_decoder_cache(uuid)
 
     def get_kv_len(self, uuid: str) -> int:
        return self.request_cache.get_kv_len(uuid)
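
Note on the cache.py hunks above: the decode branch previously required q_len == 1, so a request that was already cached but re-entered build() with more than one token during a batched async step fell through to the prefill branch. The branch now keys only on uuid membership and refreshes the step length via decoder_cache.set_q_len(q_len). That setter and the layer-indexed kv_len access replace the removed kv_len property, but set_q_len itself is not shown in this patch; below is a minimal sketch of the DecoderCache shape these hunks imply (the setter body is an assumption, and the real KVCache presumably allocates buffers rather than ignoring its arguments):

    # Sketch, not from this patch: DecoderCache as implied by the hunks above.
    from typing import List

    class KVCache:
        def __init__(self, max_seq_len: int, num_key_value_heads: int, head_dim: int):
            self.kv_len = 0  # tokens currently stored for this layer
            self.key_states = None
            self.value_states = None

    class DecoderCache:
        def __init__(self, num_layers: int, q_len: int, max_seq_len: int,
                     num_key_value_heads: int, head_dim: int):
            self.kv_cache_list: List[KVCache] = [
                KVCache(max_seq_len, num_key_value_heads, head_dim) for _ in range(num_layers)
            ]
            self.q_len = q_len

        def set_q_len(self, q_len: int) -> None:
            # refreshed on every step, now that the decode branch
            # no longer assumes q_len == 1
            self.q_len = q_len

        def __getitem__(self, idx: int) -> KVCache:
            # callers index a layer (decoder_cache[0].kv_len) instead of
            # using the removed kv_len property
            return self.kv_cache_list[idx]
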
diff --git a/tllm/generate/llm_generator.py b/tllm/generate/llm_generator.py
index 4eecd6f..aeaf3f6 100644
--- a/tllm/generate/llm_generator.py
+++ b/tllm/generate/llm_generator.py
@@ -107,20 +107,17 @@ async def generate(self, request_list: List[SequenceRequestData]):
 
             input_ids: List[int]
         """
 
-        uuid_list, input_ids_list, seq_len_list, mm_input_list = [], [], [], []
+        uuid_list, input_ids_list, mm_input_list = [], [], []
         for sequence_request in request_list:
             uuid_list.append(sequence_request.request_id)
             # if prefilling, use input_ids; otherwise, use output_ids[-1]
-            # input_ids: seq_len
             if sequence_request.is_prefill:
                 if self.processor is not None:
                     mm_input_list.append(process_mm_input(sequence_request, self.processor, **self.mm_config))
                 input_ids_list.append(np.array(sequence_request.input_ids))
-                seq_len_list.append(len(sequence_request.input_ids))
             else:
                 input_ids_list.append(np.array([sequence_request.output_ids[-1]]))
-                seq_len_list.append(1)
 
         mm_input = merge_mm_input(mm_input_list)
        input_ids = np.concatenate(input_ids_list, axis=-1)
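
Note on the generate() hunk above: the generator batches full prompts from prefill requests and single last-output tokens from decode requests into one flat id stream. The per-request lengths that seq_len_list carried are now recoverable from the per-uuid q_len held in the request cache, so the list was dead weight. A small illustration of the concatenation, with hypothetical token values:

    # Sketch (simplified from the hunk above): one step batching two prefill
    # requests and two decoding requests into a single flat input_ids array.
    import numpy as np

    prefill_prompts = [[1, 2, 3, 4], [5, 6]]  # hypothetical prompt token ids
    decode_last_tokens = [42, 7]              # hypothetical last output tokens

    input_ids_list = [np.array(ids) for ids in prefill_prompts]
    input_ids_list += [np.array([tok]) for tok in decode_last_tokens]

    input_ids = np.concatenate(input_ids_list, axis=-1)
    print(input_ids)  # one flat batch: [ 1  2  3  4  5  6 42  7]
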
diff --git a/tllm/models/mlx/llama.py b/tllm/models/mlx/llama.py
index 19c47ac..7168552 100644
--- a/tllm/models/mlx/llama.py
+++ b/tllm/models/mlx/llama.py
@@ -86,7 +86,7 @@ def __call__(self, hidden_states: mx.array, seq_input: SeqInput) -> mx.array:
 
         # TODO: save the cache asynchronously
         for uuid in seq_input.uuid_list:
-            self.cache_manager.set(uuid, attention_data.get_kv_cache_list(uuid))
+            self.cache_manager.set(uuid, attention_data.get_decoder_cache(uuid))
         self.cache_manager.check_alive()
         self.request_cache.clear()
         self.request_cache.insert_cache(seq_input)
diff --git a/tllm/models/mlx/qwen.py b/tllm/models/mlx/qwen.py
index 6500e96..44ec323 100644
--- a/tllm/models/mlx/qwen.py
+++ b/tllm/models/mlx/qwen.py
@@ -63,7 +63,7 @@ def __call__(self, hidden_states: mx.array, seq_input: SeqInput) -> mx.array:
         output = self.model(hidden_states, mask=mask, cache=attention_data)
 
         for uuid in seq_input.uuid_list:
-            self.cache_manager.set(uuid, attention_data.get_kv_cache_list(uuid))
+            self.cache_manager.set(uuid, attention_data.get_decoder_cache(uuid))
         self.cache_manager.check_alive()
         self.request_cache.clear()
         self.request_cache.insert_cache(seq_input)
diff --git a/tllm/models/tinygrad/llama.py b/tllm/models/tinygrad/llama.py
index bed8507..29633ac 100644
--- a/tllm/models/tinygrad/llama.py
+++ b/tllm/models/tinygrad/llama.py
@@ -296,7 +296,7 @@ def forward(self, hidden_states: Tensor, seq_input: SeqInput) -> Tensor:
             hidden_states = get_last_hidden_states(hidden_states, split_len_list)
 
         for uuid in seq_input.uuid_list:
-            self.cache_manager.set(uuid, attention_data.get_kv_cache_list(uuid))
+            self.cache_manager.set(uuid, attention_data.get_decoder_cache(uuid))
         self.cache_manager.check_alive()
         return hidden_states
 
diff --git a/tllm/models/torch/llama.py b/tllm/models/torch/llama.py
index 9d7cde8..0ac6ec2 100644
--- a/tllm/models/torch/llama.py
+++ b/tllm/models/torch/llama.py
@@ -117,7 +117,7 @@ def forward(self, hidden_states: torch.Tensor, seq_input: SeqInput) -> torch.Ten
             hidden_states = get_last_hidden_states(hidden_states, split_len_list)
 
         for uuid in seq_input.uuid_list:
-            self.cache_manager.set(uuid, attention_data.get_kv_cache_list(uuid))
+            self.cache_manager.set(uuid, attention_data.get_decoder_cache(uuid))
         self.cache_manager.check_alive()
         return hidden_states
 
diff --git a/tllm/models/torch/qwen.py b/tllm/models/torch/qwen.py
index 864e956..95b8ae7 100644
--- a/tllm/models/torch/qwen.py
+++ b/tllm/models/torch/qwen.py
@@ -123,7 +123,7 @@ def forward(self, hidden_states: torch.Tensor, seq_input: SeqInput) -> torch.Ten
             hidden_states = get_last_hidden_states(hidden_states, split_len_list)
 
         for uuid in seq_input.uuid_list:
-            self.cache_manager.set(uuid, attention_data.get_kv_cache_list(uuid))
+            self.cache_manager.set(uuid, attention_data.get_decoder_cache(uuid))
         self.cache_manager.check_alive()
         return hidden_states
 
diff --git a/tllm/schemas.py b/tllm/schemas.py
index 7919135..449caf7 100644
--- a/tllm/schemas.py
+++ b/tllm/schemas.py
@@ -6,6 +6,8 @@
 import numpy as np
 from pydantic import BaseModel
 
+from tllm.grpc.proto import schemas_pb2, schemas_pb2_grpc
+
 finish_reason_type = Literal["length", "stop", None]
 
 modal_type = Literal["text", "image_url"]
@@ -64,9 +66,6 @@ def __init__(
         self.stop_token_ids = stop_token_ids
 
 
-from tllm.grpc.proto import schemas_pb2, schemas_pb2_grpc
-
-
 def numpy_to_grpc_input_ids(input_ids_list: List[np.ndarray]) -> List[schemas_pb2.InputIds]:
     rows = []
     for input_ids in input_ids_list:
@@ -188,7 +187,6 @@ class SequenceRequestData:
     timeout: int = 100000  # total timeout for the request
     is_stop: bool = False
     is_prefill: bool = True
-    q_len: int = -1
 
     condition: asyncio.Condition = field(default_factory=asyncio.Condition)
 
@@ -198,7 +196,6 @@ def __post_init__(self):
         self.generate_text = None
         self.finish_reason_list = [None] * self.sampling_params.n
         self.decode_start_ts = None
-        self.q_len = len(self.input_ids) if self.q_len == -1 else self.q_len
 
     def __repr__(self) -> str:
        return f"request_id={self.request_id}; output_ids={self.output_ids}"
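
Note on the schemas.py hunks above: the grpc import moves from mid-module to the top of the file (standard PEP 8 placement), and SequenceRequestData loses its q_len field along with the __post_init__ line that derived it, since the per-step q_len now lives on DecoderCache and is refreshed via set_q_len. A sketch of the slimmed dataclass, abridged to the fields visible in this patch (the request_id and input_ids typing is an assumption):

    # Sketch: SequenceRequestData after this patch, abridged. q_len is no
    # longer stored on the request object; it is tracked on DecoderCache.
    import asyncio
    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class SequenceRequestData:
        request_id: str
        input_ids: List[int] = field(default_factory=list)
        timeout: int = 100000  # total timeout for the request
        is_stop: bool = False
        is_prefill: bool = True
        condition: asyncio.Condition = field(default_factory=asyncio.Condition)
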