update forward seq_len=1 batch
wnma3mz committed Oct 3, 2024
1 parent 61f7898 commit d9a382b
Showing 21 changed files with 417 additions and 163 deletions.
18 changes: 12 additions & 6 deletions README.md
@@ -12,7 +12,14 @@
- [x] Performance Testing
- [ ] Async Generation
- [ ] Queuing mechanism
- [ ] Decoding Strategy
- [ ] Continuous Batch
- [ ] split model before load
- [ ] Streaming Output
- [ ] OpenAI API format
- [x] Decoding Strategy
- [x] Top-K Sampling
- [x] Top-P Sampling
- [x] Temperature Sampling
- [ ] Model
- [x] LLaMA
- [ ] Multi-Model
@@ -23,13 +30,14 @@
- [ ] MLX Framework
- [ ] LORA Training
- [ ] Shard Storage
- [ ] split model before load
- [ ] Streaming Output

### Performance


- 2 GHz quad-core Intel Core i5, 16 GB 3733 MHz LPDDR4X
- Llama-3.2-1B-Instruct single machine: 10.96 token/s
- Llama-3.2-1B-Instruct single machine: 5.73 token/s (includes the time to generate the first token; TTFT is not easy to record with the transformers framework)

- Apple M3 Pro, 18 GB

With adequate communication bandwidth, the speed should be higher.
@@ -47,11 +55,9 @@ bfloat 16 CPU
| PP,TP | Llama-3.2-1B-Instruct | Llama-3.2-3B-Instruct |
| ---- | --------- | --- |
| 2,1 (actual) | 5.49 token/s | 2.42 token/s |
| 2,2 (actual) | 5.53 token/s | 2.46 token/s |
| 2,2 (actual) | 5.66 token/s | 2.46 token/s |


Llama-3.2-1B-Instruct single machine: 10.96 token/s
Llama-3.2-1B-Instruct single machine: 5.73 token/s (includes the time to generate the first token; TTFT is not easy to record with the transformers framework)

TODO: Meta-Llama-3-8B-Instruct in GPU

2 changes: 2 additions & 0 deletions benchmarks/hf_model.py
@@ -25,6 +25,8 @@
output = model.generate(input_ids, max_new_tokens=20, do_sample=False)
print("generate token: ", output[0])

assert False

# simulate TTFT time
cost_time_list = []
for _ in range(10):
34 changes: 34 additions & 0 deletions benchmarks/run_engine.py
@@ -0,0 +1,34 @@
import os
import time

import torch

from tllm.engine import MyLlamaForCausalLM
from tllm.generate.decode_utils import DecodeUtils
from tllm.generate.token_utils import TokenizerUtils
from tllm.protocol import ChatCompletionRequest
from tllm.rpc.manager import RPCManager

if __name__ == "__main__":
BASE_PATH = "/Users/lujianghu/Documents/"
model_path = os.path.join(BASE_PATH, "Llama-3.2-1B-Instruct")
weight_path = os.path.join(model_path, "master_weight.pt")

url_list = ["localhost:25001"]
server = RPCManager(url_list)
model = MyLlamaForCausalLM.from_pretrained(model_path, weight_path, server)
tok = TokenizerUtils(model_path)

request = ChatCompletionRequest(
messages=[{"role": "user", "content": "Hello, how are you?"}], max_tokens=20, do_sample=False
)
input_id_list = tok.preprocess(messages=request.messages).input_ids

input_ids = torch.tensor(input_id_list).unsqueeze(0)
print("input_ids: ", input_ids)

s1 = time.time()
output = model.generate(
input_ids, max_new_tokens=request.max_tokens, do_sample=request.do_sample, sampler=DecodeUtils("greedy")
)
print(output)
2 changes: 1 addition & 1 deletion examples/config.json
@@ -8,7 +8,7 @@
},
{
"pp_rank": 1,
"url": "localhost:25002",
"url": "192.168.0.101:25002",
"tp_size": 2,
"master_port": 29502,
"layer_idx": [11, 22]
31 changes: 31 additions & 0 deletions examples/run_single_client.sh
@@ -0,0 +1,31 @@
#!/bin/bash
MASTER_PORT=29501
GRPC_PORT=25001
BASE_PATH=/Users/lujianghu/Documents/
MODE_SIZE=$1
TP=$2


if [ $MODE_SIZE == "1" ]; then
MODEL_PATH=$BASE_PATH/Llama-3.2-1B-Instruct
START_LAYER_IDX=0
END_LAYER_IDX=16
elif [ $MODE_SIZE == "3" ]; then
MODEL_PATH=$BASE_PATH/Llama-3.2-3B-Instruct
START_LAYER_IDX=0
END_LAYER_IDX=28
elif [ $MODE_SIZE == "8" ]; then
MODEL_PATH=$BASE_PATH/Meta-Llama-3-8B-Instruct
START_LAYER_IDX=0
END_LAYER_IDX=32
elif [ $MODE_SIZE == "70" ]; then
MODEL_PATH=$BASE_PATH/Meta-Llama-3-70B-Instruct
else
echo "Invalid mode size"
exit 1
fi

export OMP_NUM_THREADS=8;
export PYTHONPATH="./":$PYTHONPATH;

torchrun --nproc_per_node=$TP --master_port=$MASTER_PORT tllm/rpc/client.py --port=$GRPC_PORT --start_layer_idx=$START_LAYER_IDX --end_layer_idx=$END_LAYER_IDX --model_path $MODEL_PATH
53 changes: 53 additions & 0 deletions exps/test_attn/causal_mask.py
@@ -0,0 +1,53 @@
from typing import List

import torch


def merge_mask(mask_list: List[torch.Tensor], total_length: int) -> torch.Tensor:
combined_mask = torch.zeros((total_length, total_length), dtype=torch.bool)

start_index = 0
for mask in mask_list:
combined_mask[start_index : start_index + mask.size(0), start_index : start_index + mask.size(1)] = mask
start_index += mask.size(0)

combined_attn_bias = torch.zeros(total_length, total_length, dtype=torch.float)
combined_attn_bias.masked_fill_(combined_mask.logical_not(), float("-inf"))
return combined_attn_bias


def build_mask(mask: torch.Tensor) -> torch.Tensor:
attn_bias = torch.zeros(mask.size(0), mask.size(1), dtype=torch.float)
attn_bias.masked_fill_(mask.logical_not(), float("-inf"))
return attn_bias


def build_qkv(bs, num_heads, seq_len, head_dim):
query = torch.randn(bs, num_heads, seq_len, head_dim)
key = torch.randn(bs, num_heads, seq_len, head_dim)
value = torch.randn(bs, num_heads, seq_len, head_dim)
return query, key, value


if __name__ == "__main__":
seq_len1, seq_len2 = 3, 4
temp_mask = torch.ones(seq_len1, seq_len1, dtype=torch.bool).tril(diagonal=0)
temp_mask2 = torch.ones(seq_len2, seq_len2, dtype=torch.bool).tril(diagonal=0)

combined_attn_bias = merge_mask([temp_mask, temp_mask2], seq_len1 + seq_len2)

# bs, num_heads, seq_len, head_dim
query1, key1, value1 = build_qkv(1, 2, seq_len1, 4)
base_out1 = torch.nn.functional.scaled_dot_product_attention(query1, key1, value1, is_causal=True)

query2, key2, value2 = build_qkv(1, 2, seq_len2, 4)
base_out2 = torch.nn.functional.scaled_dot_product_attention(query2, key2, value2, is_causal=True)

query = torch.cat([query1, query2], dim=-2)
key = torch.cat([key1, key2], dim=-2)
value = torch.cat([value1, value2], dim=-2)
out = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=combined_attn_bias)
out1, out2 = torch.split(out, [seq_len1, seq_len2], dim=-2)

print("torch.allclose(base_out1, out1)", torch.allclose(base_out1, out1))
print("torch.allclose(base_out2, out2)", torch.allclose(base_out2, out2))
45 changes: 45 additions & 0 deletions exps/test_attn/rope.py
@@ -0,0 +1,45 @@
from typing import Tuple

import torch
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb


def build_qkv(bs, num_heads, seq_len, head_dim):
query = torch.randn(bs, num_heads, seq_len, head_dim)
key = torch.randn(bs, num_heads, seq_len, head_dim)
value = torch.randn(bs, num_heads, seq_len, head_dim)
return query, key, value


if __name__ == "__main__":
seq_len1, seq_len2 = 3, 4
head_dim = 4
query1, key1, value1 = build_qkv(1, 2, seq_len1, head_dim)
query2, key2, value2 = build_qkv(1, 2, seq_len2, head_dim)
position_ids1 = torch.arange(seq_len1).unsqueeze(0)
position_ids2 = torch.arange(seq_len2).unsqueeze(0)

rotary_emb = LlamaRotaryEmbedding(
head_dim,
max_position_embeddings=4096,
base=10000,
)
cos1, sin1 = rotary_emb(value1, seq_len=seq_len1)
cos2, sin2 = rotary_emb(value2, seq_len=seq_len2)

query1_out, key1_out = apply_rotary_pos_emb(query1, key1, cos1, sin1, position_ids1)
query2_out, key2_out = apply_rotary_pos_emb(query2, key2, cos2, sin2, position_ids2)

large_cos1, large_sin1 = rotary_emb(value1, seq_len=seq_len1 * 10)
large_query1, large_key1 = apply_rotary_pos_emb(query1, key1, large_cos1, large_sin1, position_ids1)

print("test large seq len", torch.allclose(large_query1, query1_out))

# merge
query = torch.cat([query1, query2], dim=-2)
key = torch.cat([key1, key2], dim=-2)
position_ids = torch.cat([position_ids1, position_ids2], dim=-1)
cos, sin = rotary_emb(value1, seq_len=max(seq_len1, seq_len2))
query_out, key_out = apply_rotary_pos_emb(query, key, cos, sin, position_ids)
print("test merge rope for query", torch.allclose(query_out, torch.cat([query1_out, query2_out], dim=-2)))
print("test merge rope for key", torch.allclose(key_out, torch.cat([key1_out, key2_out], dim=-2)))
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -9,6 +9,9 @@ exclude = '''
| .venv
| build
| dist
| .pyi
| schemas_pb2.py
| schemas_pb2_grpc.py
)/
'''

2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,4 +1,4 @@
transformers==4.36.2
transformers==4.45.1
torch
grpcio==1.66.2
grpcio-tools==1.66.2
2 changes: 1 addition & 1 deletion tllm/commons/cache_manager.py
@@ -10,7 +10,7 @@ def __init__(self, max_alive_time=60):
self.cache_dict = {}

def get(self, key) -> Any:
return self.cache_dict.get(key)
return self.cache_dict.get(key)["past_key_values"]

def set(self, key, value: Any) -> None:
self.cache_dict[key] = {"past_key_values": value, "ts": time.time()}
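
For context, a minimal sketch of the CacheManager after this change, reconstructed from the hunk above; everything beyond __init__/get/set (the clear_expired helper and the max_alive_time bookkeeping) is an assumption, not code from the repository:

```python
import time
from typing import Any


class CacheManager:
    def __init__(self, max_alive_time=60):
        self.max_alive_time = max_alive_time
        self.cache_dict = {}

    def get(self, key) -> Any:
        # After this commit, get() unwraps the stored dict and returns only the KV cache.
        # Note: it assumes the key exists; .get(key) returning None would raise a TypeError here.
        return self.cache_dict.get(key)["past_key_values"]

    def set(self, key, value: Any) -> None:
        self.cache_dict[key] = {"past_key_values": value, "ts": time.time()}

    def clear_expired(self) -> None:
        # Hypothetical helper: drop entries whose timestamp is older than max_alive_time.
        now = time.time()
        for key in [k for k, v in self.cache_dict.items() if now - v["ts"] > self.max_alive_time]:
            del self.cache_dict[key]
```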
61 changes: 25 additions & 36 deletions tllm/commons/layers.py
@@ -1,16 +1,18 @@
from dataclasses import dataclass
from typing import *

import torch
import torch.nn as nn
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache
from transformers.models.llama.modeling_llama import (
LlamaConfig,
LlamaRMSNorm,
LlamaRotaryEmbedding,
apply_rotary_pos_emb,
repeat_kv,
)
from transformers.models.llama.modeling_llama import LlamaConfig, LlamaRMSNorm, apply_rotary_pos_emb, repeat_kv


@dataclass
class AttentionCache:
uuid_str_list: List[str]
position_ids: torch.Tensor
past_key_value: "DynamicCache"
attn_mask: torch.Tensor


class BaseParallelLayer(nn.Module):
@@ -150,14 +152,6 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
self.rank,
)
self.o_proj = RowParallelLayer(self.num_heads * self.head_dim, self.hidden_size, self.world_size, self.rank)
self._init_rope()

def _init_rope(self):
self.rotary_emb = LlamaRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)

def load_state_dict(self, state_dict: Dict) -> None:
for key in ["o_proj"]:
@@ -175,8 +169,8 @@ def load_state_dict(self, state_dict: Dict) -> None:
def forward(
self,
hidden_states: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional["Cache"] = None,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_cache: AttentionCache,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()

@@ -186,16 +180,15 @@ def forward(
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx - self.config.offset)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
cos, sin = position_embeddings

query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, attention_cache.position_ids
)

if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(
if attention_cache.past_key_value is not None:
cache_kwargs = {"uuid_str_list": attention_cache.uuid_str_list}
key_states, value_states = attention_cache.past_key_value.update(
key_states, value_states, self.layer_idx - self.config.offset, cache_kwargs
)

@@ -204,18 +197,14 @@

# TODO: speed up the following line
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal=self.is_causal and q_len > 1,
query_states, key_states, value_states, attn_mask=attention_cache.attn_mask
)

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, -1)

attn_output = self.comm.all_reduce(self.o_proj(attn_output))
return attn_output, None, past_key_value
return attn_output, None, attention_cache.past_key_value


class MyLlamaDecoderLayer(nn.Module):
@@ -245,8 +234,8 @@ def load_state_dict(self, state_dict: Dict):
def forward(
self,
hidden_states: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional["Cache"] = None,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_cache: AttentionCache,
) -> Tuple[torch.Tensor, Optional["Cache"]]:
residual = hidden_states

Expand All @@ -255,8 +244,8 @@ def forward(
# Self Attention
hidden_states, _, past_key_value = self.self_attn(
hidden_states=hidden_states,
position_ids=position_ids,
past_key_value=past_key_value,
attention_cache=attention_cache,
position_embeddings=position_embeddings,
)
hidden_states = residual + hidden_states

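
The refactor replaces the per-layer rotary embedding and the position_ids/past_key_value arguments with a shared (cos, sin) tuple plus an AttentionCache, and hands scaled_dot_product_attention an explicit attn_mask instead of the old is_causal-and-q_len>1 special case, which is what makes a batched seq_len == 1 forward possible. Below is a minimal sketch of how a caller might assemble these inputs for one decode step; the dataclass mirrors the diff above (with a loosened cache type), while the request id, mask shape, and placeholder cache are assumptions:

```python
from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class AttentionCache:  # mirrors tllm/commons/layers.py in this commit
    uuid_str_list: List[str]
    position_ids: torch.Tensor
    past_key_value: Optional[object]  # e.g. a DynamicCache; typed loosely for this sketch
    attn_mask: torch.Tensor


kv_len, q_len = 17, 1                      # 17 cached tokens, decoding 1 new token
position_ids = torch.tensor([[kv_len]])    # the new token's position equals the cache length

# A row of zeros lets the single query attend to every cached key plus itself;
# -inf entries would mask out positions belonging to other requests in a merged batch.
attn_mask = torch.zeros(q_len, kv_len + q_len)

attention_cache = AttentionCache(
    uuid_str_list=["request-uuid"],        # hypothetical request identifier
    position_ids=position_ids,
    past_key_value=None,                   # stands in for the per-request KV cache
    attn_mask=attn_mask,
)

# (cos, sin) would be computed once per forward pass from position_ids and passed to every
# decoder layer as position_embeddings, e.g.:
#   hidden_states, _ = decoder_layer(hidden_states, position_embeddings=(cos, sin),
#                                    attention_cache=attention_cache)
print(attention_cache.attn_mask.shape)     # torch.Size([1, 18])
```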
