diff --git a/README.md b/README.md index b4e940f..3b2a9a0 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,14 @@ - [x] Performance Testing - [ ] Async Generation - [ ] Queuing mechanism -- [ ] Decoding Strategy +- [ ] Continuous Batch +- [ ] split model before load +- [ ] Streaming Output +- [ ] OpenAI API format +- [x] Decoding Strategy + - [x] Top-K Sampling + - [x] Top-P Sampling + - [x] Temperature Sampling - [ ] Model - [x] LLaMA - [ ] Multi-Model @@ -23,13 +30,14 @@ - [ ] MLX Framework - [ ] LORA Training - [ ] Shard Storage -- [ ] split model before load -- [ ] Streaming Output ### Performance - 2 GHz 四核Intel Core i5, 16 GB 3733 MHz LPDDR4X + - Llama-3.2-1B-Instruct 单机时间:10.96 token/s + - Llama-3.2-1B-Instruct 单机时间:5.73 token/s(包含首token生成的时间, transformers 框架 TTFT 时间不方便记录) + - Apple M3 Pro, 18 GB 在保证通信带宽的前提下,速度应当更快 @@ -47,11 +55,9 @@ bfloat 16 CPU | PP,TP | Llama-3.2-1B-Instruct | Llama-3.2-3B-Instruct | | ---- | --------- | --- | | 2,1(实际) | 5.49 token/s | 2.42 token/s | -| 2,2(实际) | 5.53 token/s | 2.46 token/s | +| 2,2(实际) | 5.66 token/s | 2.46 token/s | -Llama-3.2-1B-Instruct 单机时间:10.96 token/s -Llama-3.2-1B-Instruct 单机时间:5.73 token/s(包含首token生成的时间, transformers 框架 TTFT 时间不方便记录) TODO: Meta-Llama-3-8B-Instruct in GPU diff --git a/benchmarks/hf_model.py b/benchmarks/hf_model.py index 5506eda..a583ae4 100644 --- a/benchmarks/hf_model.py +++ b/benchmarks/hf_model.py @@ -25,6 +25,8 @@ output = model.generate(input_ids, max_new_tokens=20, do_sample=False) print("generate token: ", output[0]) + assert False + # 模拟 TTFT 时间 cost_time_list = [] for _ in range(10): diff --git a/benchmarks/run_engine.py b/benchmarks/run_engine.py new file mode 100644 index 0000000..4ea1af8 --- /dev/null +++ b/benchmarks/run_engine.py @@ -0,0 +1,34 @@ +import os +import time + +import torch + +from tllm.engine import MyLlamaForCausalLM +from tllm.generate.decode_utils import DecodeUtils +from tllm.generate.token_utils import TokenizerUtils +from tllm.protocol import ChatCompletionRequest +from tllm.rpc.manager import RPCManager + +if __name__ == "__main__": + BASE_PATH = "/Users/lujianghu/Documents/" + model_path = os.path.join(BASE_PATH, "Llama-3.2-1B-Instruct") + weight_path = os.path.join(model_path, "master_weight.pt") + + url_list = ["localhost:25001"] + server = RPCManager(url_list) + model = MyLlamaForCausalLM.from_pretrained(model_path, weight_path, server) + tok = TokenizerUtils(model_path) + + request = ChatCompletionRequest( + messages=[{"role": "user", "content": "Hello, how are you?"}], max_tokens=20, do_sample=False + ) + input_id_list = tok.preprocess(messages=request.messages).input_ids + + input_ids = torch.tensor(input_id_list).unsqueeze(0) + print("input_ids: ", input_ids) + + s1 = time.time() + output = model.generate( + input_ids, max_new_tokens=request.max_tokens, do_sample=request.do_sample, sampler=DecodeUtils("greedy") + ) + print(output) diff --git a/examples/config.json b/examples/config.json index b1f1693..64b2095 100644 --- a/examples/config.json +++ b/examples/config.json @@ -8,7 +8,7 @@ }, { "pp_rank": 1, - "url": "localhost:25002", + "url": "192.168.0.101:25002", "tp_size": 2, "master_port": 29502, "layer_idx": [11, 22] diff --git a/examples/run_single_client.sh b/examples/run_single_client.sh new file mode 100644 index 0000000..011e6f0 --- /dev/null +++ b/examples/run_single_client.sh @@ -0,0 +1,31 @@ +#!/bin/bash +MASTER_PORT=29501 +GRPC_PORT=25001 +BASE_PATH=/Users/lujianghu/Documents/ +MODE_SIZE=$1 +TP=$2 + + +if [ $MODE_SIZE == "1" ]; then + 
MODEL_PATH=$BASE_PATH/Llama-3.2-1B-Instruct + START_LAYER_IDX=0 + END_LAYER_IDX=16 +elif [ $MODE_SIZE == "3" ]; then + MODEL_PATH=$BASE_PATH/Llama-3.2-3B-Instruct + START_LAYER_IDX=0 + END_LAYER_IDX=28 +elif [ $MODE_SIZE == "8" ]; then + MODEL_PATH=$BASE_PATH/Meta-Llama-3-8B-Instruct + START_LAYER_IDX=0 + END_LAYER_IDX=32 +elif [ $MODE_SIZE == "70" ]; then + MODEL_PATH=$BASE_PATH/Meta-Llama-3-70B-Instruct +else + echo "Invalid mode size" + exit 1 +fi + +export OMP_NUM_THREADS=8; +export PYTHONPATH="./":$PYTHONPATH; + +torchrun --nproc_per_node=$TP --master_port=$MASTER_PORT tllm/rpc/client.py --port=$GRPC_PORT --start_layer_idx=$START_LAYER_IDX --end_layer_idx=$END_LAYER_IDX --model_path $MODEL_PATH diff --git a/exps/test_attn/causal_mask.py b/exps/test_attn/causal_mask.py new file mode 100644 index 0000000..23a914d --- /dev/null +++ b/exps/test_attn/causal_mask.py @@ -0,0 +1,53 @@ +from typing import List + +import torch + + +def merge_mask(mask_list: List[torch.Tensor], total_length: int) -> torch.Tensor: + combined_mask = torch.zeros((total_length, total_length), dtype=torch.bool) + + start_index = 0 + for mask in mask_list: + combined_mask[start_index : start_index + mask.size(0), start_index : start_index + mask.size(1)] = mask + start_index += mask.size(0) + + combined_attn_bias = torch.zeros(total_length, total_length, dtype=torch.float) + combined_attn_bias.masked_fill_(combined_mask.logical_not(), float("-inf")) + return combined_attn_bias + + +def build_mask(mask: torch.Tensor) -> torch.Tensor: + attn_bias = torch.zeros(mask.size(0), mask.size(1), dtype=torch.float) + attn_bias.masked_fill_(mask.logical_not(), float("-inf")) + return attn_bias + + +def build_qkv(bs, num_heads, seq_len, head_dim): + query = torch.randn(bs, num_heads, seq_len, head_dim) + key = torch.randn(bs, num_heads, seq_len, head_dim) + value = torch.randn(bs, num_heads, seq_len, head_dim) + return query, key, value + + +if __name__ == "__main__": + seq_len1, seq_len2 = 3, 4 + temp_mask = torch.ones(seq_len1, seq_len1, dtype=torch.bool).tril(diagonal=0) + temp_mask2 = torch.ones(seq_len2, seq_len2, dtype=torch.bool).tril(diagonal=0) + + combined_attn_bias = merge_mask([temp_mask, temp_mask2], seq_len1 + seq_len2) + + # bs, num_heads, seq_len, head_dim + query1, key1, value1 = build_qkv(1, 2, seq_len1, 4) + base_out1 = torch.nn.functional.scaled_dot_product_attention(query1, key1, value1, is_causal=True) + + query2, key2, value2 = build_qkv(1, 2, seq_len2, 4) + base_out2 = torch.nn.functional.scaled_dot_product_attention(query2, key2, value2, is_causal=True) + + query = torch.cat([query1, query2], dim=-2) + key = torch.cat([key1, key2], dim=-2) + value = torch.cat([value1, value2], dim=-2) + out = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=combined_attn_bias) + out1, out2 = torch.split(out, [seq_len1, seq_len2], dim=-2) + + print("torch.allclose(base_out1, out1)", torch.allclose(base_out1, out1)) + print("torch.allclose(base_out2, out2)", torch.allclose(base_out2, out2)) diff --git a/exps/test_attn/rope.py b/exps/test_attn/rope.py new file mode 100644 index 0000000..3b47285 --- /dev/null +++ b/exps/test_attn/rope.py @@ -0,0 +1,45 @@ +from typing import Tuple + +import torch +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb + + +def build_qkv(bs, num_heads, seq_len, head_dim): + query = torch.randn(bs, num_heads, seq_len, head_dim) + key = torch.randn(bs, num_heads, seq_len, head_dim) + value = torch.randn(bs, num_heads, seq_len, 
head_dim) + return query, key, value + + +if __name__ == "__main__": + seq_len1, seq_len2 = 3, 4 + head_dim = 4 + query1, key1, value1 = build_qkv(1, 2, seq_len1, head_dim) + query2, key2, value2 = build_qkv(1, 2, seq_len2, head_dim) + position_ids1 = torch.arange(seq_len1).unsqueeze(0) + position_ids2 = torch.arange(seq_len2).unsqueeze(0) + + rotary_emb = LlamaRotaryEmbedding( + head_dim, + max_position_embeddings=4096, + base=10000, + ) + cos1, sin1 = rotary_emb(value1, seq_len=seq_len1) + cos2, sin2 = rotary_emb(value2, seq_len=seq_len2) + + query1_out, key1_out = apply_rotary_pos_emb(query1, key1, cos1, sin1, position_ids1) + query2_out, key2_out = apply_rotary_pos_emb(query2, key2, cos2, sin2, position_ids2) + + large_cos1, large_sin1 = rotary_emb(value1, seq_len=seq_len1 * 10) + large_query1, large_key1 = apply_rotary_pos_emb(query1, key1, large_cos1, large_sin1, position_ids1) + + print("test large seq len", torch.allclose(large_query1, query1_out)) + + # merge + query = torch.cat([query1, query2], dim=-2) + key = torch.cat([key1, key2], dim=-2) + position_ids = torch.cat([position_ids1, position_ids2], dim=-1) + cos, sin = rotary_emb(value1, seq_len=max(seq_len1, seq_len2)) + query_out, key_out = apply_rotary_pos_emb(query, key, cos, sin, position_ids) + print("test merge rope for query", torch.allclose(query_out, torch.cat([query1_out, query2_out], dim=-2))) + print("test merge rope for key", torch.allclose(key_out, torch.cat([key1_out, key2_out], dim=-2))) diff --git a/pyproject.toml b/pyproject.toml index 9cb981c..270d18b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,9 @@ exclude = ''' | .venv | build | dist + | .pyi + | schemas_pb2.py + | schemas_pb2_grpc.py )/ ''' diff --git a/requirements.txt b/requirements.txt index f4bddf9..1fe7f8e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -transformers==4.36.2 +transformers==4.45.1 torch grpcio==1.66.2 grpcio-tools==1.66.2 diff --git a/tllm/commons/cache_manager.py b/tllm/commons/cache_manager.py index fbdc14c..1f572ea 100644 --- a/tllm/commons/cache_manager.py +++ b/tllm/commons/cache_manager.py @@ -10,7 +10,7 @@ def __init__(self, max_alive_time=60): self.cache_dict = {} def get(self, key) -> Any: - return self.cache_dict.get(key) + return self.cache_dict.get(key)["past_key_values"] def set(self, key, value: Any) -> None: self.cache_dict[key] = {"past_key_values": value, "ts": time.time()} diff --git a/tllm/commons/layers.py b/tllm/commons/layers.py index 2de1cbd..bc03b84 100644 --- a/tllm/commons/layers.py +++ b/tllm/commons/layers.py @@ -1,16 +1,18 @@ +from dataclasses import dataclass from typing import * import torch import torch.nn as nn from transformers.activations import ACT2FN -from transformers.cache_utils import Cache -from transformers.models.llama.modeling_llama import ( - LlamaConfig, - LlamaRMSNorm, - LlamaRotaryEmbedding, - apply_rotary_pos_emb, - repeat_kv, -) +from transformers.models.llama.modeling_llama import LlamaConfig, LlamaRMSNorm, apply_rotary_pos_emb, repeat_kv + + +@dataclass +class AttentionCache: + uuid_str_list: List[str] + position_ids: torch.Tensor + past_key_value: "DynamicCache" + attn_mask: torch.Tensor class BaseParallelLayer(nn.Module): @@ -150,14 +152,6 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): self.rank, ) self.o_proj = RowParallelLayer(self.num_heads * self.head_dim, self.hidden_size, self.world_size, self.rank) - self._init_rope() - - def _init_rope(self): - self.rotary_emb = LlamaRotaryEmbedding( - self.head_dim, - 
max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) def load_state_dict(self, state_dict: Dict) -> None: for key in ["o_proj"]: @@ -175,8 +169,8 @@ def load_state_dict(self, state_dict: Dict) -> None: def forward( self, hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional["Cache"] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_cache: AttentionCache, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -186,16 +180,15 @@ def forward( key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx - self.config.offset) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, attention_cache.position_ids + ) - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update( + if attention_cache.past_key_value is not None: + cache_kwargs = {"uuid_str_list": attention_cache.uuid_str_list} + key_states, value_states = attention_cache.past_key_value.update( key_states, value_states, self.layer_idx - self.config.offset, cache_kwargs ) @@ -204,18 +197,14 @@ def forward( # TODO: speed up the following line attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
- is_causal=self.is_causal and q_len > 1, + query_states, key_states, value_states, attn_mask=attention_cache.attn_mask ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, -1) attn_output = self.comm.all_reduce(self.o_proj(attn_output)) - return attn_output, None, past_key_value + return attn_output, None, attention_cache.past_key_value class MyLlamaDecoderLayer(nn.Module): @@ -245,8 +234,8 @@ def load_state_dict(self, state_dict: Dict): def forward( self, hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional["Cache"] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_cache: AttentionCache, ) -> Tuple[torch.Tensor, Optional["Cache"]]: residual = hidden_states @@ -255,8 +244,8 @@ def forward( # Self Attention hidden_states, _, past_key_value = self.self_attn( hidden_states=hidden_states, - position_ids=position_ids, - past_key_value=past_key_value, + attention_cache=attention_cache, + position_embeddings=position_embeddings, ) hidden_states = residual + hidden_states diff --git a/tllm/engine.py b/tllm/engine.py index 64ff092..e06ca10 100644 --- a/tllm/engine.py +++ b/tllm/engine.py @@ -12,6 +12,7 @@ from tllm.commons.convert import deserialize_bfloat16_tensor, serialize_bfloat16_tensor from tllm.generate.decode_utils import DecodeUtils from tllm.rpc.manager import RPCManager +from tllm.rpc.protocol import SeqInput finish_reason_type = Literal["length", "stop", None] @@ -75,22 +76,25 @@ def from_pretrained(cls, model_path: str, weight_path: str, server: RPCManager, model.eval() return model - def _prepare_forward_data(self, uuid_str: str, hidden_states: torch.Tensor, need_serialize: bool) -> Dict[str, Any]: + def _prepare_forward_data( + self, seq_input: SeqInput, hidden_states: torch.Tensor, need_serialize: bool + ) -> Dict[str, Any]: if need_serialize: hidden_states = serialize_bfloat16_tensor(hidden_states) - return {"uuid": uuid_str, "hidden_states": hidden_states} + return {"uuid": seq_input.uuid_str_list, "seq_len": seq_input.seq_len_list, "hidden_states": hidden_states} - def forward(self, inputs_embeds: torch.Tensor, uuid_str: str) -> ForwardResult: + def forward(self, inputs_embeds: torch.Tensor, seq_input: SeqInput) -> ForwardResult: hidden_states = inputs_embeds comm_cost_time_list = [] + last_pp_idx = self.pp_size - 1 for pp_idx in range(self.pp_size): s1 = time.time() outputs = self.server.post_sync( - pp_idx, "/forward", data=self._prepare_forward_data(uuid_str, hidden_states, need_serialize=pp_idx == 0) - ) - hidden_states = ( - deserialize_bfloat16_tensor(outputs.output) if pp_idx == self.pp_size - 1 else outputs.output + pp_idx, + "/forward", + data=self._prepare_forward_data(seq_input, hidden_states, need_serialize=pp_idx == 0), ) + hidden_states = deserialize_bfloat16_tensor(outputs.output) if pp_idx == last_pp_idx else outputs.output s2 = time.time() comm_cost_time_list.append(s2 - s1 - outputs.cost_time) @@ -108,8 +112,11 @@ def generate(self, input_ids: torch.Tensor, sampler: DecodeUtils, **kwargs) -> G finish_reason = None uuid_str = str(uuid.uuid4()) ttft_start_time, ttft_end_time = time.time(), time.time() + + seq_len = input_embeds.shape[1] + seq_input = SeqInput(uuid_str_list=[uuid_str], seq_len_list=[seq_len]) while True: - forward_result = self(input_embeds, uuid_str) + forward_result = self(input_embeds, seq_input) logits = forward_result.logits comm_cost_time_list = forward_result.comm_cost_time_list generate_ids = 
sampler.decode(logits) @@ -121,6 +128,7 @@ def generate(self, input_ids: torch.Tensor, sampler: DecodeUtils, **kwargs) -> G break input_embeds = self.embed_tokens(torch.tensor(generate_ids)).unsqueeze(0) + seq_input.seq_len_list = [1] if len(output_ids) == 1: ttft_end_time = time.time() logging.info(f"ttft communication cost time: {",".join([f'{x:.4f}' for x in comm_cost_time_list])}") diff --git a/tllm/generate/decode_utils.py b/tllm/generate/decode_utils.py index d596128..e5626a3 100644 --- a/tllm/generate/decode_utils.py +++ b/tllm/generate/decode_utils.py @@ -1,4 +1,4 @@ -from typing import List +from typing import * import torch import torch.nn.functional as F @@ -33,27 +33,45 @@ def temperature_scaling(logits: torch.Tensor, temperature: float = 1.0): class DecodeUtils: def __init__(self, method: str) -> None: self.method = method - assert self.method in ["greedy", "beam_search"] + assert self.method in ["greedy", "beam_search", "sampling"] - def decode(self, logits: torch.Tensor) -> List[int]: + def decode(self, logits: torch.Tensor, sampling_params: Optional[Dict[str, Any]] = None) -> List[int]: if self.method == "greedy": return self.greedy_decode(logits) elif self.method == "beam_search": - return self.beam_search_decode(logits) + return self.beam_search_decode(logits, sampling_params) + elif self.method == "sampling": + return self.sampling_decode(logits, sampling_params) def greedy_decode(self, logits: torch.Tensor) -> List[int]: # logits shape: [batch_size, sequence_length, vocab_size] return torch.argmax(logits[:, -1], dim=-1).tolist() + def sampling_decode(self, logits: torch.Tensor, sampling_params: Dict[str, Any]) -> List[int]: + top_k = sampling_params["top_k"] + top_p = sampling_params["top_p"] + temperature = sampling_params["temperature"] + temperature_scaled_logits = temperature_scaling(logits, temperature) + # Apply top-k sampling + if top_k > 0: + temperature_scaled_logits = top_k_sampling(temperature_scaled_logits, k=top_k) + + # Apply top-p sampling (nucleus sampling) + if top_p < 1.0: + temperature_scaled_logits = top_p_sampling(temperature_scaled_logits, p=top_p) + + return torch.argmax(temperature_scaled_logits, dim=-1).tolist() + def beam_search_decode( self, logits: torch.Tensor, - beam_width: int = 3, - max_len: int = 20, - top_k: int = 0, - top_p: float = 1.0, - temperature: float = 1.0, + sampling_params: Dict[str, Any], ) -> List[List[int]]: + max_len = sampling_params["max_len"] + top_k = sampling_params["top_k"] + top_p = sampling_params["top_p"] + temperature = sampling_params["temperature"] + beam_width = sampling_params["beam_width"] batch_size, sequence_length, vocab_size = logits.size() # Initialize the beam search diff --git a/tllm/models/cache.py b/tllm/models/cache.py new file mode 100644 index 0000000..856550c --- /dev/null +++ b/tllm/models/cache.py @@ -0,0 +1,56 @@ +from typing import * + +import torch +from transformers.cache_utils import DynamicCache + + +class SeqDynamicCache: + def __init__(self) -> None: + self.cache_dict: Dict[Any] = {} + + def add(self, uuid_str: str, seq_len: int, cache: Optional[DynamicCache] = None): + self.cache_dict.update({uuid_str: {"cache": DynamicCache() if cache is None else cache, "seq_len": seq_len}}) + + def get_cache(self, uuid_str: str) -> DynamicCache: + return self.cache_dict[uuid_str]["cache"] + + def get_seq_len(self, uuid_str: str) -> int: + return self.cache_dict[uuid_str]["seq_len"] + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: 
Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + Return: + A tuple containing the updated key and value states. + """ + # key_states: bsz x num_heads x seq_len x head_dim + uuid_str_list = cache_kwargs.get("uuid_str_list", None) + seq_len_list = [self.get_seq_len(uuid_str) for uuid_str in uuid_str_list] + seq_key_states = torch.split(key_states, seq_len_list, dim=-2) + seq_value_states = torch.split(value_states, seq_len_list, dim=-2) + + key_states_list, value_states_list = [], [] + for uuid_str, key_states, value_states in zip(uuid_str_list, seq_key_states, seq_value_states): + key, value = self.get_cache(uuid_str).update(key_states, value_states, layer_idx) + key_states_list.append(key) + value_states_list.append(value) + + cat_key_states, cat_value_states = torch.cat(key_states_list, dim=1), torch.cat(value_states_list, dim=1) + return cat_key_states, cat_value_states diff --git a/tllm/models/llama.py b/tllm/models/llama.py index 332f745..4d86f51 100644 --- a/tllm/models/llama.py +++ b/tllm/models/llama.py @@ -2,11 +2,33 @@ import torch import torch.nn as nn -from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding from tllm.commons.cache_manager import CacheManager -from tllm.commons.layers import MyLlamaDecoderLayer +from tllm.commons.layers import AttentionCache, MyLlamaDecoderLayer +from tllm.models.cache import SeqDynamicCache +from tllm.rpc.protocol import SeqInput + + +def build_mask(seq_len_list: List[int]) -> torch.Tensor: + """ + 构造多个请求的 casual mask + @param seq_len_list: 每个请求的 seq_len + + @return: 一个 mask,形状为 total_length x total_length,其中 total_length 是所有请求的 seq_len 之和 + """ + mask_list = [torch.ones(seq_len, seq_len, dtype=torch.bool).tril(diagonal=0) for seq_len in seq_len_list] + total_length = sum(seq_len_list) + + combined_mask = torch.zeros((total_length, total_length), dtype=torch.bool) + + start_index = 0 + for mask in mask_list: + combined_mask[start_index : start_index + mask.size(0), start_index : start_index + mask.size(1)] = mask + start_index += mask.size(0) + + return combined_mask class Decoder(nn.Module): @@ -24,15 +46,15 @@ def load_state_dict(self, state_dict: Dict) -> None: def forward( self, hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_cache: AttentionCache, ): next_decoder_cache = None for i, layer in enumerate(self.decoder): layer_outputs = layer( hidden_states, - position_ids=position_ids, - past_key_value=past_key_value, + position_embeddings=position_embeddings, + attention_cache=attention_cache, ) hidden_states = layer_outputs[0] @@ -50,31 +72,54 @@ def __init__(self, config): self.cache_manager = CacheManager() self.config = config self.decoder = Decoder(config, config.decoder_start_layer_idx, config.decoder_end_layer_idx) + self.rotary_emb = LlamaRotaryEmbedding(config=config) def 
load_state_dict(self, state_dict: Dict) -> None: self.decoder.load_state_dict(state_dict) - def forward(self, hidden_states: torch.Tensor, uuid_str: Optional[str] = None) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, seq_input: SeqInput) -> torch.Tensor: """ @param hidden_states: bs x seq_len x hidden_size - @param position_ids: bs x seq_len - @param uuid_str: 可选的 uuid,用于区分不同的请求。如果 uuid_str 存在,则使用缓存的 kv cache,否则使用新的 kv cache + @param seq_input: + uuid_str_list: List[str]: 每个请求的 uuid + seq_len_list: List[int]: 每个请求的 seq_len + 如果 uuid_str 存在,则使用缓存的 kv cache,否则使用新的 kv cache @return: bs x seq_len x hidden_size """ - if uuid_str in self.cache_manager.cache_dict: - kv_cache_seq_len = self.cache_manager.cache_dict[uuid_str]["past_key_values"].get_seq_length() - position_ids = torch.tensor([kv_cache_seq_len], dtype=torch.long).unsqueeze(0) - past_key_values = self.cache_manager.get(uuid_str)["past_key_values"] - else: - position_ids = torch.arange(hidden_states.size(1), dtype=torch.long).unsqueeze(0) - past_key_values = DynamicCache() + position_ids_list = [] + past_key_values = SeqDynamicCache() + max_position_ids, max_seq_len = None, -1 + seq_len_list = [] + for uuid_str, seq_len in zip(seq_input.uuid_str_list, seq_input.seq_len_list): + if uuid_str in self.cache_manager.cache_dict: + kv_cache = self.cache_manager.get(uuid_str) + position_ids = torch.tensor([kv_cache.get_seq_length()], dtype=torch.long).unsqueeze(0) + past_key_values.add(uuid_str, seq_len, cache=kv_cache) + else: + seq_len = hidden_states.size(1) + position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0) + past_key_values.add(uuid_str, seq_len) + if seq_len > max_seq_len: + max_seq_len = seq_len + max_position_ids = position_ids + position_ids_list.append(position_ids) + seq_len_list.append(seq_len) + + attention_cache = AttentionCache( + position_ids=torch.cat(position_ids_list, dim=0).to(self.device), + past_key_value=past_key_values, + uuid_str_list=seq_input.uuid_str_list, + attn_mask=build_mask(seq_len_list), + ) + hidden_states = hidden_states.to(self.device) - position_ids = position_ids.to(self.device) - output = self.decoder(hidden_states, position_ids=position_ids, past_key_value=past_key_values) + position_embeddings = self.rotary_emb(hidden_states, max_position_ids) + output = self.decoder(hidden_states, position_embeddings=position_embeddings, attention_cache=attention_cache) - self.cache_manager.set(uuid_str, output.past_key_values) - self.cache_manager.check_alive() + for uuid_str, seq_len in zip(seq_input.uuid_str_list, seq_input.seq_len_list): + self.cache_manager.set(uuid_str, output.past_key_values.get_cache(uuid_str)) + self.cache_manager.check_alive() return output.last_hidden_state @property diff --git a/tllm/rpc/client.py b/tllm/rpc/client.py index 62229ff..050dcca 100644 --- a/tllm/rpc/client.py +++ b/tllm/rpc/client.py @@ -11,6 +11,7 @@ from tllm.commons.convert import deserialize_bfloat16_tensor, serialize_bfloat16_tensor from tllm.models.llama import MyLlamaModel from tllm.rpc import schemas_pb2, schemas_pb2_grpc +from tllm.rpc.protocol import SeqInput logging.basicConfig(level=logging.INFO) @@ -33,11 +34,11 @@ def __init__(self, config, model, rank: int, pp_rank: int): else: while True: self.config.comm.broadcast_object(uuid_shape_list) - uuid, hidden_states_shape = uuid_shape_list + seq_input, hidden_states_shape = uuid_shape_list hidden_states = torch.empty(hidden_states_shape, dtype=self.model.dtype) self.config.comm.broadcast(hidden_states) - _ = 
self.model.forward(hidden_states, uuid) + _ = self.model.forward(hidden_states, seq_input=seq_input) def InitModel(self, request: schemas_pb2.ModelConfig, context: grpc.ServicerContext): """ @@ -64,10 +65,11 @@ def Forward(self, request: schemas_pb2.ForwardRequest, context: grpc.ServicerCon s1 = time.time() hidden_states = deserialize_bfloat16_tensor(request.hidden_states) - self.config.comm.broadcast_object([request.uuid, tuple(hidden_states.shape)]) + seq_input = SeqInput(uuid_str_list=list(request.uuid), seq_len_list=list(request.seq_len)) + self.config.comm.broadcast_object([seq_input, tuple(hidden_states.shape)]) self.config.comm.broadcast(hidden_states) - output = self.model(hidden_states, request.uuid) + output = self.model(hidden_states, seq_input) return_output = serialize_bfloat16_tensor(output) cost_time = time.time() - s1 diff --git a/tllm/rpc/manager.py b/tllm/rpc/manager.py index 0d4e8ff..084e369 100644 --- a/tllm/rpc/manager.py +++ b/tllm/rpc/manager.py @@ -38,7 +38,7 @@ def init_model(self, stub, data): return stub.InitModel(request) def forward(self, stub, data): - request = schemas_pb2.ForwardRequest(uuid=data["uuid"], hidden_states=data["hidden_states"]) + request = schemas_pb2.ForwardRequest(**data) return stub.Forward(request) def health(self, stub): diff --git a/tllm/rpc/protocol.py b/tllm/rpc/protocol.py new file mode 100644 index 0000000..762bebe --- /dev/null +++ b/tllm/rpc/protocol.py @@ -0,0 +1,8 @@ +from dataclasses import dataclass +from typing import List + + +@dataclass +class SeqInput: + uuid_str_list: List[str] + seq_len_list: List[int] diff --git a/tllm/rpc/schemas.proto b/tllm/rpc/schemas.proto index f4e63df..c7d7b88 100644 --- a/tllm/rpc/schemas.proto +++ b/tllm/rpc/schemas.proto @@ -19,8 +19,9 @@ message ModelConfig { } message ForwardRequest { - string uuid = 1; - BFloat16Tensor hidden_states = 2; + repeated string uuid = 1; + repeated int32 seq_len = 2; + BFloat16Tensor hidden_states = 3; } message StatusResponse { diff --git a/tllm/rpc/schemas_pb2.py b/tllm/rpc/schemas_pb2.py index d4cad91..73bd253 100644 --- a/tllm/rpc/schemas_pb2.py +++ b/tllm/rpc/schemas_pb2.py @@ -19,7 +19,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x16tllm/rpc/schemas.proto\x12\x07schemas"\x19\n\x05\x41rray\x12\x10\n\x08\x65lements\x18\x01 \x03(\x02"&\n\x06Matrix\x12\x1c\n\x04rows\x18\x01 \x03(\x0b\x32\x0e.schemas.Array")\n\x06Tensor\x12\x1f\n\x06layers\x18\x01 \x03(\x0b\x32\x0f.schemas.Matrix".\n\x0b\x42lockTensor\x12\x1f\n\x06\x62locks\x18\x01 \x03(\x0b\x32\x0f.schemas.Tensor"-\n\x0e\x42\x46loat16Tensor\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05"\xbb\x01\n\x15MultiDimensionalArray\x12\x1f\n\x05\x61rray\x18\x01 \x01(\x0b\x32\x0e.schemas.ArrayH\x00\x12!\n\x06matrix\x18\x02 \x01(\x0b\x32\x0f.schemas.MatrixH\x00\x12!\n\x06tensor\x18\x03 \x01(\x0b\x32\x0f.schemas.TensorH\x00\x12,\n\x0c\x62lock_tensor\x18\x04 \x01(\x0b\x32\x14.schemas.BlockTensorH\x00\x42\r\n\x0bmulti_array"\x8c\x01\n\x0bModelConfig\x12\x12\n\nmodel_name\x18\x01 \x01(\t\x12\x0f\n\x07pp_rank\x18\x02 \x01(\x05\x12\x17\n\x0flayer_idx_start\x18\x03 \x01(\x05\x12\x15\n\rlayer_idx_end\x18\x04 \x01(\x05\x12\x12\n\nmaster_url\x18\x05 \x01(\t\x12\x14\n\x0cnext_pp_rank\x18\x06 \x01(\x05"N\n\x0e\x46orwardRequest\x12\x0c\n\x04uuid\x18\x01 \x01(\t\x12.\n\rhidden_states\x18\x02 \x01(\x0b\x32\x17.schemas.BFloat16Tensor"-\n\x0eStatusResponse\x12\x0b\n\x03msg\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\x05"j\n\x0f\x46orwardResponse\x12\x0b\n\x03msg\x18\x01 
\x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\x05\x12\'\n\x06output\x18\x03 \x01(\x0b\x32\x17.schemas.BFloat16Tensor\x12\x11\n\tcost_time\x18\x04 \x01(\x02"-\n\x0eHealthResponse\x12\x0b\n\x03msg\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\x05"4\n\x15InitModelFlagResponse\x12\x0b\n\x03msg\x18\x01 \x01(\x08\x12\x0e\n\x06status\x18\x02 \x01(\x05"\x07\n\x05\x45mpty2\xfa\x01\n\nRPCService\x12:\n\tInitModel\x12\x14.schemas.ModelConfig\x1a\x17.schemas.StatusResponse\x12<\n\x07\x46orward\x12\x17.schemas.ForwardRequest\x1a\x18.schemas.ForwardResponse\x12\x31\n\x06Health\x12\x0e.schemas.Empty\x1a\x17.schemas.HealthResponse\x12?\n\rInitModelFlag\x12\x0e.schemas.Empty\x1a\x1e.schemas.InitModelFlagResponseb\x06proto3' + b'\n\x16tllm/rpc/schemas.proto\x12\x07schemas"-\n\x0e\x42\x46loat16Tensor\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05"\x8c\x01\n\x0bModelConfig\x12\x12\n\nmodel_name\x18\x01 \x01(\t\x12\x0f\n\x07pp_rank\x18\x02 \x01(\x05\x12\x17\n\x0flayer_idx_start\x18\x03 \x01(\x05\x12\x15\n\rlayer_idx_end\x18\x04 \x01(\x05\x12\x12\n\nmaster_url\x18\x05 \x01(\t\x12\x14\n\x0cnext_pp_rank\x18\x06 \x01(\x05"_\n\x0e\x46orwardRequest\x12\x0c\n\x04uuid\x18\x01 \x03(\t\x12\x0f\n\x07seq_len\x18\x02 \x03(\x05\x12.\n\rhidden_states\x18\x03 \x01(\x0b\x32\x17.schemas.BFloat16Tensor"-\n\x0eStatusResponse\x12\x0b\n\x03msg\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\x05"j\n\x0f\x46orwardResponse\x12\x0b\n\x03msg\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\x05\x12\'\n\x06output\x18\x03 \x01(\x0b\x32\x17.schemas.BFloat16Tensor\x12\x11\n\tcost_time\x18\x04 \x01(\x02"-\n\x0eHealthResponse\x12\x0b\n\x03msg\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\x05"4\n\x15InitModelFlagResponse\x12\x0b\n\x03msg\x18\x01 \x01(\x08\x12\x0e\n\x06status\x18\x02 \x01(\x05"\x07\n\x05\x45mpty2\xfa\x01\n\nRPCService\x12:\n\tInitModel\x12\x14.schemas.ModelConfig\x1a\x17.schemas.StatusResponse\x12<\n\x07\x46orward\x12\x17.schemas.ForwardRequest\x1a\x18.schemas.ForwardResponse\x12\x31\n\x06Health\x12\x0e.schemas.Empty\x1a\x17.schemas.HealthResponse\x12?\n\rInitModelFlag\x12\x0e.schemas.Empty\x1a\x1e.schemas.InitModelFlagResponseb\x06proto3' ) _globals = globals() @@ -27,32 +27,22 @@ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "tllm.rpc.schemas_pb2", _globals) if not _descriptor._USE_C_DESCRIPTORS: DESCRIPTOR._loaded_options = None - _globals["_ARRAY"]._serialized_start = 35 - _globals["_ARRAY"]._serialized_end = 60 - _globals["_MATRIX"]._serialized_start = 62 - _globals["_MATRIX"]._serialized_end = 100 - _globals["_TENSOR"]._serialized_start = 102 - _globals["_TENSOR"]._serialized_end = 143 - _globals["_BLOCKTENSOR"]._serialized_start = 145 - _globals["_BLOCKTENSOR"]._serialized_end = 191 - _globals["_BFLOAT16TENSOR"]._serialized_start = 193 - _globals["_BFLOAT16TENSOR"]._serialized_end = 238 - _globals["_MULTIDIMENSIONALARRAY"]._serialized_start = 241 - _globals["_MULTIDIMENSIONALARRAY"]._serialized_end = 428 - _globals["_MODELCONFIG"]._serialized_start = 431 - _globals["_MODELCONFIG"]._serialized_end = 571 - _globals["_FORWARDREQUEST"]._serialized_start = 573 - _globals["_FORWARDREQUEST"]._serialized_end = 651 - _globals["_STATUSRESPONSE"]._serialized_start = 653 - _globals["_STATUSRESPONSE"]._serialized_end = 698 - _globals["_FORWARDRESPONSE"]._serialized_start = 700 - _globals["_FORWARDRESPONSE"]._serialized_end = 806 - _globals["_HEALTHRESPONSE"]._serialized_start = 808 - _globals["_HEALTHRESPONSE"]._serialized_end = 853 - _globals["_INITMODELFLAGRESPONSE"]._serialized_start = 855 
- _globals["_INITMODELFLAGRESPONSE"]._serialized_end = 907 - _globals["_EMPTY"]._serialized_start = 909 - _globals["_EMPTY"]._serialized_end = 916 - _globals["_RPCSERVICE"]._serialized_start = 919 - _globals["_RPCSERVICE"]._serialized_end = 1169 + _globals["_BFLOAT16TENSOR"]._serialized_start = 35 + _globals["_BFLOAT16TENSOR"]._serialized_end = 80 + _globals["_MODELCONFIG"]._serialized_start = 83 + _globals["_MODELCONFIG"]._serialized_end = 223 + _globals["_FORWARDREQUEST"]._serialized_start = 225 + _globals["_FORWARDREQUEST"]._serialized_end = 320 + _globals["_STATUSRESPONSE"]._serialized_start = 322 + _globals["_STATUSRESPONSE"]._serialized_end = 367 + _globals["_FORWARDRESPONSE"]._serialized_start = 369 + _globals["_FORWARDRESPONSE"]._serialized_end = 475 + _globals["_HEALTHRESPONSE"]._serialized_start = 477 + _globals["_HEALTHRESPONSE"]._serialized_end = 522 + _globals["_INITMODELFLAGRESPONSE"]._serialized_start = 524 + _globals["_INITMODELFLAGRESPONSE"]._serialized_end = 576 + _globals["_EMPTY"]._serialized_start = 578 + _globals["_EMPTY"]._serialized_end = 585 + _globals["_RPCSERVICE"]._serialized_start = 588 + _globals["_RPCSERVICE"]._serialized_end = 838 # @@protoc_insertion_point(module_scope) diff --git a/tllm/rpc/schemas_pb2.pyi b/tllm/rpc/schemas_pb2.pyi index 731b8e9..37b04a3 100644 --- a/tllm/rpc/schemas_pb2.pyi +++ b/tllm/rpc/schemas_pb2.pyi @@ -11,30 +11,6 @@ from typing import ( DESCRIPTOR: _descriptor.FileDescriptor -class Array(_message.Message): - __slots__ = ("elements",) - ELEMENTS_FIELD_NUMBER: _ClassVar[int] - elements: _containers.RepeatedScalarFieldContainer[float] - def __init__(self, elements: _Optional[_Iterable[float]] = ...) -> None: ... - -class Matrix(_message.Message): - __slots__ = ("rows",) - ROWS_FIELD_NUMBER: _ClassVar[int] - rows: _containers.RepeatedCompositeFieldContainer[Array] - def __init__(self, rows: _Optional[_Iterable[_Union[Array, _Mapping]]] = ...) -> None: ... - -class Tensor(_message.Message): - __slots__ = ("layers",) - LAYERS_FIELD_NUMBER: _ClassVar[int] - layers: _containers.RepeatedCompositeFieldContainer[Matrix] - def __init__(self, layers: _Optional[_Iterable[_Union[Matrix, _Mapping]]] = ...) -> None: ... - -class BlockTensor(_message.Message): - __slots__ = ("blocks",) - BLOCKS_FIELD_NUMBER: _ClassVar[int] - blocks: _containers.RepeatedCompositeFieldContainer[Tensor] - def __init__(self, blocks: _Optional[_Iterable[_Union[Tensor, _Mapping]]] = ...) -> None: ... - class BFloat16Tensor(_message.Message): __slots__ = ("data", "shape") DATA_FIELD_NUMBER: _ClassVar[int] @@ -43,24 +19,6 @@ class BFloat16Tensor(_message.Message): shape: _containers.RepeatedScalarFieldContainer[int] def __init__(self, data: _Optional[bytes] = ..., shape: _Optional[_Iterable[int]] = ...) -> None: ... -class MultiDimensionalArray(_message.Message): - __slots__ = ("array", "matrix", "tensor", "block_tensor") - ARRAY_FIELD_NUMBER: _ClassVar[int] - MATRIX_FIELD_NUMBER: _ClassVar[int] - TENSOR_FIELD_NUMBER: _ClassVar[int] - BLOCK_TENSOR_FIELD_NUMBER: _ClassVar[int] - array: Array - matrix: Matrix - tensor: Tensor - block_tensor: BlockTensor - def __init__( - self, - array: _Optional[_Union[Array, _Mapping]] = ..., - matrix: _Optional[_Union[Matrix, _Mapping]] = ..., - tensor: _Optional[_Union[Tensor, _Mapping]] = ..., - block_tensor: _Optional[_Union[BlockTensor, _Mapping]] = ..., - ) -> None: ... 
- class ModelConfig(_message.Message): __slots__ = ("model_name", "pp_rank", "layer_idx_start", "layer_idx_end", "master_url", "next_pp_rank") MODEL_NAME_FIELD_NUMBER: _ClassVar[int] @@ -86,13 +44,18 @@ class ModelConfig(_message.Message): ) -> None: ... class ForwardRequest(_message.Message): - __slots__ = ("uuid", "hidden_states") + __slots__ = ("uuid", "seq_len", "hidden_states") UUID_FIELD_NUMBER: _ClassVar[int] + SEQ_LEN_FIELD_NUMBER: _ClassVar[int] HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int] - uuid: str + uuid: _containers.RepeatedScalarFieldContainer[str] + seq_len: _containers.RepeatedScalarFieldContainer[int] hidden_states: BFloat16Tensor def __init__( - self, uuid: _Optional[str] = ..., hidden_states: _Optional[_Union[BFloat16Tensor, _Mapping]] = ... + self, + uuid: _Optional[_Iterable[str]] = ..., + seq_len: _Optional[_Iterable[int]] = ..., + hidden_states: _Optional[_Union[BFloat16Tensor, _Mapping]] = ..., ) -> None: ... class StatusResponse(_message.Message):
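
Note (illustrative, not part of the patch): the schema change above turns ForwardRequest into a batched message — `uuid` and `seq_len` become repeated fields and the hidden states of all requests travel as one tensor concatenated along the sequence dimension. The sketch below shows how such a request could be assembled on the calling side. It reuses `serialize_bfloat16_tensor` and the generated `schemas_pb2` module introduced or touched in this diff; the uuid values, the hidden size of 2048, and the random tensor are assumptions for illustration only.

import torch

from tllm.commons.convert import serialize_bfloat16_tensor
from tllm.rpc import schemas_pb2

# Two requests with different prompt lengths share one forward pass.
seq_len_a, seq_len_b = 3, 4
hidden_size = 2048  # assumed value for illustration

# Hidden states are concatenated along the sequence dimension, so the server
# receives a single [1, seq_len_a + seq_len_b, hidden_size] tensor.
hidden_states = torch.randn(1, seq_len_a + seq_len_b, hidden_size, dtype=torch.bfloat16)

request = schemas_pb2.ForwardRequest(
    uuid=["req-a", "req-b"],          # repeated string uuid = 1 (made-up ids)
    seq_len=[seq_len_a, seq_len_b],   # repeated int32 seq_len = 2
    hidden_states=serialize_bfloat16_tensor(hidden_states),
)

# On the serving side, Forward() rebuilds
#   SeqInput(uuid_str_list=list(request.uuid), seq_len_list=list(request.seq_len)),
# build_mask([3, 4]) yields the block-diagonal causal mask for the combined
# sequence, and SeqDynamicCache keeps a separate DynamicCache per uuid.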
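
Note (illustrative, not part of the patch): the README checklist also marks the decoding strategies (top-k, top-p, temperature) as done, exposed through the new "sampling" mode of DecodeUtils. The snippet below is a small usage sketch under assumed values: the logits tensor is random stand-in data, the vocabulary size is arbitrary, and the sampling_params keys match the ones read by sampling_decode in this diff.

import torch

from tllm.generate.decode_utils import DecodeUtils

# Random stand-in for a decode-step logits tensor: [batch_size, seq_len, vocab_size].
logits = torch.randn(1, 1, 32000)

sampler = DecodeUtils("sampling")
token_ids = sampler.decode(
    logits,
    sampling_params={"top_k": 50, "top_p": 0.9, "temperature": 0.7},
)
print(token_ids)  # nested list of token ids, one per position per batch element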