fix batch async request
wnma3mz committed Jan 27, 2025
1 parent 5594ab2 commit b52788a
Showing 8 changed files with 20 additions and 27 deletions.
25 changes: 12 additions & 13 deletions tllm/commons/cache.py
@@ -47,10 +47,6 @@ def __init__(
         self.kv_cache_list = [KVCache(max_seq_len, num_key_value_heads, head_dim) for _ in range(num_layers)]
         self.q_len = q_len
 
-    @property
-    def kv_len(self):
-        return self.kv_cache_list[0].kv_len
-
     def __getitem__(self, idx: int) -> KVCache:
         return self.kv_cache_list[idx]
 
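With the DecoderCache-level kv_len property removed, callers index a specific layer and read its KVCache.kv_len directly (layer 0 is used, since every layer presumably holds the same cached length). A one-line sketch of the replacement read, where decoder_cache stands for a DecoderCache instance:

cache_seq_len = decoder_cache[0].kv_len  # was: decoder_cache.kv_len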
@@ -104,18 +100,20 @@ def build(self, seq_input: SeqInput, cache_manager: "CacheManager"):
             hit_cache_len = -1
             q_len = len(input_ids)
             # decoding phase
-            if q_len == 1 and uuid in cache_manager.cache_dict:
+            if uuid in cache_manager.cache_dict:
                 decoder_cache: DecoderCache = cache_manager.get(uuid)
-                cache_seq_len = decoder_cache.kv_len
+                decoder_cache.set_q_len(q_len)
+                cache_seq_len = decoder_cache[0].kv_len
                 position_ids = array_func(cache_seq_len)
                 k_len_list.append(cache_seq_len + q_len)
             # prefilling phase
             else:
-                # hit a previous kv cache: reuse the historical cache
                 if ENABLE_PREFIX_CACHE:
                     hit_uuid, hit_cache_len = self.radix_tree.longest_common_prefix(input_ids)
                 else:
-                    hit_uuid, hit_cache_len = None, -1  # prefix cache disabled
+                    # prefix cache disabled
+                    hit_uuid, hit_cache_len = None, -1
+                # hit a previous kv cache: reuse the historical cache
                 if hit_uuid is not None and cache_manager.get(hit_uuid) is not None:
                     hid_decoder_cache: DecoderCache = copy.deepcopy(cache_manager.get(hit_uuid))
                     # for an identical request, avoid exceeding the cache length
@@ -139,15 +137,15 @@ def build(self, seq_input: SeqInput, cache_manager: "CacheManager"):
             self.add(uuid, q_len, decoder_cache)
         return q_len_list, k_len_list, position_ids_list, hit_cache_len_list
 
-    def get_kv_cache(self, uuid: str) -> DecoderCache:
+    def get_decoder_cache(self, uuid: str) -> DecoderCache:
         return self.cache_dict[uuid]
 
     def get_layer_idx_kv_cache(self, uuid: str, layer_idx: int) -> KVCache:
-        return self.get_kv_cache(uuid)[layer_idx]
+        return self.get_decoder_cache(uuid)[layer_idx]
 
     def get_q_len(self, uuid: str) -> int:
         # get the q_len for each uuid's request
-        return self.get_kv_cache(uuid).q_len
+        return self.get_decoder_cache(uuid).q_len
 
     def get_kv_len(self, uuid: str, layer_idx: Optional[int] = 0) -> int:
         # get the kv_len of the kv cache for each uuid's request
@@ -173,6 +171,7 @@ def update_cat(
             interval = self.get_q_len(uuid)
             end = start + interval
             cur_key_states, cur_value_states = key_states[start:end], value_states[start:end]
+
             if kv_cache.key_states is None:
                 kv_cache.key_states, kv_cache.value_states = cur_key_states, cur_value_states
             else:
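The else branch is cut off above; a hedged sketch of the whole update step follows, with an illustrative helper name. It slices this request's rows out of the batched states, then either initializes or extends the per-layer cache; np.concatenate stands in for whatever backend concat the real code uses, and appending along axis 0 is an assumption based on the slicing shown.

import numpy as np

def update_layer_cache(kv_cache, key_states, value_states, start, interval):
    # slice out this request's tokens from the batched key/value states
    end = start + interval
    cur_key_states, cur_value_states = key_states[start:end], value_states[start:end]
    if kv_cache.key_states is None:
        # first tokens for this request: adopt the slices directly
        kv_cache.key_states, kv_cache.value_states = cur_key_states, cur_value_states
    else:
        # assumption: append along the sequence axis (the axis sliced above)
        kv_cache.key_states = np.concatenate([kv_cache.key_states, cur_key_states], axis=0)
        kv_cache.value_states = np.concatenate([kv_cache.value_states, cur_value_states], axis=0)
    return end  # next request's start offset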
@@ -269,8 +268,8 @@ def __init__(
         self.hit_cache_len_list = hit_cache_len_list  # used when PP=0 to truncate hidden_states
         self.q_len_list = q_len_list
 
-    def get_kv_cache_list(self, uuid: str) -> List[KVCache]:
-        return self.request_cache.get_kv_cache(uuid)
+    def get_decoder_cache(self, uuid: str) -> DecoderCache:
+        return self.request_cache.get_decoder_cache(uuid)
 
     def get_kv_len(self, uuid: str) -> int:
         return self.request_cache.get_kv_len(uuid)
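Taken together, the renamed accessors keep the whole DecoderCache (the per-layer KVCaches plus the request's q_len) behind one lookup. A hedged usage sketch, assuming a populated request_cache and an existing request id; the variable names are illustrative:

uuid = "some-request-id"                                  # hypothetical request id
decoder_cache = request_cache.get_decoder_cache(uuid)     # whole DecoderCache for the request
layer_kv = request_cache.get_layer_idx_kv_cache(uuid, 0)  # KVCache for a single layer
assert layer_kv is decoder_cache[0]                       # same object, reached two ways
q_len = request_cache.get_q_len(uuid)                     # query length recorded at build time
kv_len = request_cache.get_kv_len(uuid, layer_idx=0)      # cached sequence length for that layer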
5 changes: 1 addition & 4 deletions tllm/generate/llm_generator.py
@@ -107,20 +107,17 @@ async def generate(self, request_list: List[SequenceRequestData]):
         input_ids: List[int]
         """
-        uuid_list, input_ids_list, seq_len_list, mm_input_list = [], [], [], []
+        uuid_list, input_ids_list, mm_input_list = [], [], []
         for sequence_request in request_list:
             uuid_list.append(sequence_request.request_id)
             # if prefilling, use input_ids; otherwise, use output_ids[-1]
             # input_ids: seq_len
             if sequence_request.is_prefill:
                 if self.processor is not None:
                     mm_input_list.append(process_mm_input(sequence_request, self.processor, **self.mm_config))
 
                 input_ids_list.append(np.array(sequence_request.input_ids))
-                seq_len_list.append(len(sequence_request.input_ids))
             else:
                 input_ids_list.append(np.array([sequence_request.output_ids[-1]]))
-                seq_len_list.append(1)
 
         mm_input = merge_mm_input(mm_input_list)
         input_ids = np.concatenate(input_ids_list, axis=-1)
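The dropped seq_len_list is presumably redundant now that the cache layer records q_len per request during build (set_q_len in cache.py above): prefill requests contribute their whole prompt, decode requests contribute only their newest token, and everything is flattened into one array. A minimal sketch of the batching that remains, with an illustrative helper name and the multimodal branch omitted:

import numpy as np

def build_batch_input_ids(request_list):
    # prefill: whole prompt; decode: only the last generated token
    input_ids_list = []
    for sequence_request in request_list:
        if sequence_request.is_prefill:
            input_ids_list.append(np.array(sequence_request.input_ids))
        else:
            input_ids_list.append(np.array([sequence_request.output_ids[-1]]))
    return np.concatenate(input_ids_list, axis=-1)  # one flat batch for the model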
2 changes: 1 addition & 1 deletion tllm/models/mlx/llama.py
@@ -86,7 +86,7 @@ def __call__(self, hidden_states: mx.array, seq_input: SeqInput) -> mx.array:
 
         # TODO: save the cache asynchronously
         for uuid in seq_input.uuid_list:
-            self.cache_manager.set(uuid, attention_data.get_kv_cache_list(uuid))
+            self.cache_manager.set(uuid, attention_data.get_decoder_cache(uuid))
         self.cache_manager.check_alive()
         self.request_cache.clear()
         self.request_cache.insert_cache(seq_input)
2 changes: 1 addition & 1 deletion tllm/models/mlx/qwen.py
@@ -63,7 +63,7 @@ def __call__(self, hidden_states: mx.array, seq_input: SeqInput) -> mx.array:
         output = self.model(hidden_states, mask=mask, cache=attention_data)
 
         for uuid in seq_input.uuid_list:
-            self.cache_manager.set(uuid, attention_data.get_kv_cache_list(uuid))
+            self.cache_manager.set(uuid, attention_data.get_decoder_cache(uuid))
         self.cache_manager.check_alive()
         self.request_cache.clear()
         self.request_cache.insert_cache(seq_input)
2 changes: 1 addition & 1 deletion tllm/models/tinygrad/llama.py
@@ -296,7 +296,7 @@ def forward(self, hidden_states: Tensor, seq_input: SeqInput) -> Tensor:
         hidden_states = get_last_hidden_states(hidden_states, split_len_list)
 
         for uuid in seq_input.uuid_list:
-            self.cache_manager.set(uuid, attention_data.get_kv_cache_list(uuid))
+            self.cache_manager.set(uuid, attention_data.get_decoder_cache(uuid))
         self.cache_manager.check_alive()
 
         return hidden_states
2 changes: 1 addition & 1 deletion tllm/models/torch/llama.py
@@ -117,7 +117,7 @@ def forward(self, hidden_states: torch.Tensor, seq_input: SeqInput) -> torch.Tensor:
         hidden_states = get_last_hidden_states(hidden_states, split_len_list)
 
         for uuid in seq_input.uuid_list:
-            self.cache_manager.set(uuid, attention_data.get_kv_cache_list(uuid))
+            self.cache_manager.set(uuid, attention_data.get_decoder_cache(uuid))
         self.cache_manager.check_alive()
 
         return hidden_states
2 changes: 1 addition & 1 deletion tllm/models/torch/qwen.py
@@ -123,7 +123,7 @@ def forward(self, hidden_states: torch.Tensor, seq_input: SeqInput) -> torch.Tensor:
         hidden_states = get_last_hidden_states(hidden_states, split_len_list)
 
         for uuid in seq_input.uuid_list:
-            self.cache_manager.set(uuid, attention_data.get_kv_cache_list(uuid))
+            self.cache_manager.set(uuid, attention_data.get_decoder_cache(uuid))
         self.cache_manager.check_alive()
 
         return hidden_states
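All five backends make the same one-line change at the point where per-request caches are persisted after a forward pass: the cache manager now stores the whole DecoderCache rather than a bare List[KVCache], so q_len travels with the per-layer caches. A hedged sketch of the consequence for a later lookup, with illustrative names:

uuid = "some-request-id"                                 # hypothetical request id
decoder_cache = attention_data.get_decoder_cache(uuid)   # DecoderCache, not List[KVCache]
cache_manager.set(uuid, decoder_cache)                   # persist after the forward pass
cached = cache_manager.get(uuid)                         # the DecoderCache comes back out
print(cached.q_len, cached[0].kv_len)                    # request-level and per-layer state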
7 changes: 2 additions & 5 deletions tllm/schemas.py
@@ -6,6 +6,8 @@
 import numpy as np
 from pydantic import BaseModel
 
+from tllm.grpc.proto import schemas_pb2, schemas_pb2_grpc
+
 finish_reason_type = Literal["length", "stop", None]
 modal_type = Literal["text", "image_url"]
 
@@ -64,9 +66,6 @@ def __init__(
         self.stop_token_ids = stop_token_ids
 
 
-from tllm.grpc.proto import schemas_pb2, schemas_pb2_grpc
-
-
 def numpy_to_grpc_input_ids(input_ids_list: List[np.ndarray]) -> List[schemas_pb2.InputIds]:
     rows = []
     for input_ids in input_ids_list:
@@ -188,7 +187,6 @@ class SequenceRequestData:
     timeout: int = 100000  # total timeout for the request
     is_stop: bool = False
     is_prefill: bool = True
-    q_len: int = -1
 
     condition: asyncio.Condition = field(default_factory=asyncio.Condition)
 
@@ -198,7 +196,6 @@ def __post_init__(self):
         self.generate_text = None
         self.finish_reason_list = [None] * self.sampling_params.n
         self.decode_start_ts = None
-        self.q_len = len(self.input_ids) if self.q_len == -1 else self.q_len
 
     def __repr__(self) -> str:
         return f"request_id={self.request_id}; output_ids={self.output_ids}"
