Commit
update more metrics
wnma3mz committed Oct 2, 2024
1 parent 0385213 commit 2b130d7
Showing 11 changed files with 135 additions and 113 deletions.
31 changes: 15 additions & 16 deletions README.md
@@ -9,10 +9,7 @@
- [x] pipeline-parallel by grpc
- [x] tensor-parallel by torch.dist
- [x] Merge Linear
- [ ] Performance Testing
- [ ] CPU llama-1B
- [ ] GPU llama-1B
- [ ] CPU llama-8B
- [x] Performance Testing
- [ ] Async Generation
- [ ] Queuing mechanism
- [ ] Decoding Strategy
@@ -37,21 +34,23 @@

Provided the communication bandwidth is sufficient, this should be faster.

bfloat 16 (need update)
| PP,TP | Llama-3.2-3B-Instruct | Meta-Llama-3-8B-Instruct | Meta-Llama-3-70B-Instruct |
| ---- | --------- | --- | ---------- |
| 2,1 (actual) | 3.66 token/s | | |
| 2,2 (actual) | | | |

Since the tokenizers may differ, the input tokens vary slightly, but they are essentially the same.

Token generation speed (excluding the time to generate the first token)
bfloat 16 CPU
| PP,TP | Llama-3.2-1B-Instruct | Llama-3.2-3B-Instruct |
| ---- | --------- | --- |
| 2,1 (actual) | 8.04 token/s | 3.01 token/s |
| 2,2 (actual) | 7.38 token/s | 2.51 token/s |

float 32
| PP,TP | TinyLlama-1.1B-Chat-v1.0 | Llama-3.2-1B-Instruct |
Including first-token generation time
| PP,TP | Llama-3.2-1B-Instruct | Llama-3.2-3B-Instruct |
| ---- | --------- | --- |
| 1,1 (baseline) | 6.37 token/s; 17.98 token/s | |
| 2,1 (single-machine simulation) | 5.91 token/s | |
| 2,2 (single-machine simulation) | 5.46 token/s | |
| 2,1 (actual) | 6.59 token/s | 5.43 token/s |
| 2,2 (actual) | 6.44 token/s | 5.63 token/s |
| 2,1 (actual) | 5.49 token/s | 2.42 token/s |
| 2,2 (actual) | 5.53 token/s | 2.46 token/s |


Llama-3.2-1B-Instruct single-machine speed: 10.96 token/s
Llama-3.2-1B-Instruct single-machine speed: 5.73 token/s (including first-token generation time; TTFT is not easy to record with the transformers framework)
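
The figures that exclude the first token are measured the same way as in benchmarks/hf_model.py below: TTFT is approximated with repeated single-token generations, then subtracted from the full-generation wall time. A minimal sketch of that calculation, assuming a loaded Hugging Face model and a tokenized prompt (the function name is illustrative):

```python
import time

import torch

def tokens_per_second(model, input_ids, max_new_tokens=20, runs=10):
    # Approximate TTFT: time repeated single-token generations and take a
    # trimmed mean (drop the fastest and slowest run).
    ttft_times = []
    for _ in range(runs):
        start = time.time()
        with torch.no_grad():
            model.generate(input_ids, max_new_tokens=1, do_sample=False)
        ttft_times.append(time.time() - start)
    ttft = sum(sorted(ttft_times)[1:-1]) / (runs - 2)

    # Time full generations, subtract the estimated TTFT, and report token/s.
    gen_times = []
    for _ in range(runs):
        start = time.time()
        with torch.no_grad():
            model.generate(input_ids, max_new_tokens=max_new_tokens, do_sample=False)
        gen_times.append(time.time() - start - ttft)
    mean_time = sum(sorted(gen_times)[1:-1]) / (runs - 2)
    return max_new_tokens / mean_time
```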

TODO: Meta-Llama-3-8B-Instruct on GPU
62 changes: 28 additions & 34 deletions benchmarks/hf_model.py
@@ -2,54 +2,48 @@
from typing import *

import torch
from transformers import AutoTokenizer, LlamaForCausalLM


def load_model_and_tokenizer(model_path: str) -> Tuple[LlamaForCausalLM, AutoTokenizer]:
model = LlamaForCausalLM.from_pretrained(model_path, trust_remote_code=True, device_map="cpu")
# https://gist.github.com/wnma3mz/64db3e69616b819de346635b7bfa1d36
# model.load_custom_weights(model.state_dict())
tok = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=True)
return model, tok


formatted_prompt = "### Human: {}### Assistant:"


def tokenize_message(tok: AutoTokenizer, messages: List[Dict[str, str]]) -> List[int]:
inputs = formatted_prompt.format(messages[0]["content"])
# inputs = "Hello, how are you?"
# inputs = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
input_ids = tok.encode(inputs, add_special_tokens=True)
while input_ids[0] == input_ids[1] == tok.bos_token_id:
# input_ids = input_ids[1:]
input_ids.pop(0)
return input_ids
from transformers import LlamaForCausalLM

from tllm.generate.token_utils import TokenizerUtils

if __name__ == "__main__":
# model_path = "/Users/jianghulu/Documents/TinyLlama-1.1B-Chat-v0.1"
model_path = "/Users/lujianghu/Documents/TinyLlama-1.1B-Chat-v1.0"
model, tok = load_model_and_tokenizer(model_path)
model.eval()
model_path = "/Users/lujianghu/Documents/Llama-3.2-1B-Instruct"
tok = TokenizerUtils(model_path)
model = LlamaForCausalLM.from_pretrained(
model_path, trust_remote_code=True, device_map="cpu", torch_dtype=torch.bfloat16
)

messages = [{"role": "user", "content": "Hello, how are you?"}]
input_id_list = tokenize_message(tok, messages)
input_id_list = tok.preprocess(messages=messages).input_ids
input_ids = torch.tensor(input_id_list).unsqueeze(0)

model.eval()
print("input_ids: ", input_ids)
# output = model.generate(input_ids, max_new_tokens=50, tokenizer=tok, eos_token_id=[0, tok.eos_token_id])
# print(tok.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True))

# warmup
with torch.no_grad():
output = model.generate(input_ids, max_new_tokens=20, do_sample=False)
print("generate token: ", output[0])

time_list = []
# Estimate TTFT (time to first token)
cost_time_list = []
for _ in range(10):
s1 = time.time()
with torch.no_grad():
output = model.generate(input_ids, max_new_tokens=1, do_sample=False)
cost_time_list.append(time.time() - s1)
cost_time_list = sorted(cost_time_list)[1:-1]
ttft = sum(cost_time_list) / len(cost_time_list)
print("ttft: ", ttft)

# Measure generation time, subtracting the estimated TTFT
cost_time_list = []
max_new_tokens = 20
for _ in range(10):
s1 = time.time()
with torch.no_grad():
output = model.generate(input_ids, max_new_tokens=max_new_tokens, do_sample=False)
time_list.append(time.time() - s1)
# print(tok.decode(output[0][input_ids.shape[1] :], skip_special_tokens=True))
print("token/s: ", max_new_tokens / (sum(time_list) / len(time_list)))
cost_time_list.append(time.time() - s1 - ttft)
print("cost_time_list: ", cost_time_list)
cost_time_list = sorted(cost_time_list)[1:-1]
print("token/s: ", max_new_tokens / (sum(cost_time_list) / len(cost_time_list)))
19 changes: 19 additions & 0 deletions benchmarks/test_cost_time.py
@@ -0,0 +1,19 @@
import requests
from tqdm import tqdm

if __name__ == "__main__":
    url = "http://localhost:8000/v1/chat/completions"
    headers = {"Content-Type": "application/json", "Authorization": "Bearer common"}
    data = {"messages": [{"role": "user", "content": "Hello, how are you?"}]}
    # warm-up
    response = requests.post(url, headers=headers, json=data)

    cost_time_list = []
    for _ in tqdm(range(10)):
        response = requests.post(url, headers=headers, json=data)
        cost_time_list.append(response.json()["cost_time"])
    # Drop the max and min values (trimmed mean)
    cost_time_list = sorted(cost_time_list)[1:-1]
    test_tokens = 20  # number of tokens assumed to be generated per request
    print("cost time: ", sum(cost_time_list) / len(cost_time_list))
    print("token/s: ", test_tokens / (sum(cost_time_list) / len(cost_time_list)))
14 changes: 7 additions & 7 deletions examples/run.sh
@@ -1,16 +1,16 @@
#!/bin/bash
BASE_PATH=/Users/lujianghu/Documents/
MODE_SIZE=$1

if [ $MODE_SIZE == "1.1" ]; then
MODEL_PATH=/Users/lujianghu/Documents/TinyLlama-1.1B-Chat-v1.0
elif [ $MODE_SIZE == "1" ]; then
MODEL_PATH=/Users/lujianghu/Documents/Llama-3.2-1B-Instruct

if [ $MODE_SIZE == "1" ]; then
MODEL_PATH=$BASE_PATH/Llama-3.2-1B-Instruct
elif [ $MODE_SIZE == "3" ]; then
MODEL_PATH=/Users/lujianghu/Documents/Llama-3.2-3B-Instruct
MODEL_PATH=$BASE_PATH/Llama-3.2-3B-Instruct
elif [ $MODE_SIZE == "8" ]; then
MODEL_PATH=/Users/lujianghu/Documents/Meta-Llama-3-8B-Instruct
MODEL_PATH=$BASE_PATH/Meta-Llama-3-8B-Instruct
elif [ $MODE_SIZE == "70" ]; then
MODEL_PATH=/Users/lujianghu/Documents/Meta-Llama-3-70B-Instruct
MODEL_PATH=$BASE_PATH/Meta-Llama-3-70B-Instruct
else
echo "Invalid mode size"
exit 1
50 changes: 21 additions & 29 deletions examples/run_client.sh
@@ -1,75 +1,67 @@
#!/bin/bash
MASTER_PORT=29501
GRPC_PORT=25001
BASE_PATH=/Users/lujianghu/Documents/
MODE_SIZE=$1
TP=$2
PP=$3
PP_IDX=$3

# PP is currently limited to 0 and 1
if [ $PP != "0" ] && [ $PP != "1" ]; then
# PP_IDX is currently limited to 0 and 1
# MODE_SIZE is currently limited to 1, 3, 8, or 70
if [ $PP_IDX != "0" ] && [ $PP_IDX != "1" ]; then
echo "Invalid pp size"
exit 1
fi

if [ $MODE_SIZE == "1.1" ]; then
MODEL_PATH=/Users/lujianghu/Documents/TinyLlama-1.1B-Chat-v1.0
if [ $PP == "0" ]; then
START_LAYER_IDX=0
END_LAYER_IDX=11
elif [ $PP == "1" ]; then
START_LAYER_IDX=11
END_LAYER_IDX=22
else
echo "Invalid pp size"
exit 1
fi
elif [ $MODE_SIZE == "1" ]; then
MODEL_PATH=/Users/lujianghu/Documents/Llama-3.2-1B-Instruct
if [ $PP == "0" ]; then

if [ $MODE_SIZE == "1" ]; then
MODEL_PATH=$BASE_PATH/Llama-3.2-1B-Instruct
if [ $PP_IDX == "0" ]; then
START_LAYER_IDX=0
END_LAYER_IDX=8
elif [ $PP == "1" ]; then
elif [ $PP_IDX == "1" ]; then
START_LAYER_IDX=8
END_LAYER_IDX=16
else
echo "Invalid pp size"
exit 1
fi
elif [ $MODE_SIZE == "3" ]; then
MODEL_PATH=/Users/lujianghu/Documents/Llama-3.2-3B-Instruct
if [ $PP == "0" ]; then
MODEL_PATH=$BASE_PATH/Llama-3.2-3B-Instruct
if [ $PP_IDX == "0" ]; then
START_LAYER_IDX=0
END_LAYER_IDX=14
elif [ $PP == "1" ]; then
elif [ $PP_IDX == "1" ]; then
START_LAYER_IDX=14
END_LAYER_IDX=28
else
echo "Invalid pp size"
exit 1
fi
elif [ $MODE_SIZE == "8" ]; then
MODEL_PATH=/Users/lujianghu/Documents/Meta-Llama-3-8B-Instruct
if [ $PP == "0" ]; then
MODEL_PATH=$BASE_PATH/Meta-Llama-3-8B-Instruct
if [ $PP_IDX == "0" ]; then
START_LAYER_IDX=0
END_LAYER_IDX=16
elif [ $PP == "1" ]; then
elif [ $PP_IDX == "1" ]; then
START_LAYER_IDX=16
END_LAYER_IDX=32
else
echo "Invalid pp size"
exit 1
fi
elif [ $MODE_SIZE == "70" ]; then
MODEL_PATH=/Users/lujianghu/Documents/Meta-Llama-3-70B-Instruct
MODEL_PATH=$BASE_PATH/Meta-Llama-3-70B-Instruct
else
echo "Invalid mode size"
exit 1
fi

export OMP_NUM_THREADS=8;
export PYTHONPATH="./tllm":$PYTHONPATH;
export PYTHONPATH="./":$PYTHONPATH;

if [ $PP == "0" ]; then
if [ $PP_IDX == "0" ]; then
torchrun --nproc_per_node=$TP --master_port=$MASTER_PORT tllm/rpc/client.py --port=$GRPC_PORT --start_layer_idx=$START_LAYER_IDX --end_layer_idx=$END_LAYER_IDX --model_path $MODEL_PATH
elif [ $PP == "1" ]; then
elif [ $PP_IDX == "1" ]; then
torchrun --nproc_per_node=$TP --master_port=$(($MASTER_PORT+1)) tllm/rpc/client.py --port=$(($GRPC_PORT+1)) --start_layer_idx=$START_LAYER_IDX --end_layer_idx=$END_LAYER_IDX --model_path $MODEL_PATH
fi
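
run_client.sh hard-codes the per-stage layer ranges shown above (0–8 / 8–16 for the 16-layer 1B model, 0–14 / 14–28 for the 3B, 0–16 / 16–32 for the 8B). A hypothetical helper that derives the same even split is sketched below; the function is illustrative and not part of the repository:

```python
from typing import List, Tuple

def split_layers(num_layers: int, pp_size: int) -> List[Tuple[int, int]]:
    """Evenly split decoder layers into (start_layer_idx, end_layer_idx) ranges,
    end-exclusive, one per pipeline stage."""
    base, rem = divmod(num_layers, pp_size)
    ranges, start = [], 0
    for i in range(pp_size):
        end = start + base + (1 if i < rem else 0)
        ranges.append((start, end))
        start = end
    return ranges

# Matches the ranges hard-coded above:
assert split_layers(16, 2) == [(0, 8), (8, 16)]    # Llama-3.2-1B-Instruct
assert split_layers(28, 2) == [(0, 14), (14, 28)]  # Llama-3.2-3B-Instruct
assert split_layers(32, 2) == [(0, 16), (16, 32)]  # Meta-Llama-3-8B-Instruct
```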
12 changes: 7 additions & 5 deletions requirements.txt
@@ -1,10 +1,12 @@
transformers>=4.41.2
transformers==4.36.2
torch
grpcio
grpcio-tools
protobuf
grpcio==1.66.2
grpcio-tools==1.66.2
protobuf==5.28.0
requests
fastapi
tiktoken
tqdm
uvicorn
uvicorn
accelerate
sentencepiece
32 changes: 27 additions & 5 deletions tllm/engine.py
@@ -1,4 +1,6 @@
from dataclasses import dataclass
import logging
import time
from typing import *
import uuid

@@ -19,6 +21,7 @@ class GenerateResult:
output_ids: List[int]
finish_reason: Optional[finish_reason_type] = None
output_text: Optional[str] = None
ttft: Optional[float] = None


@dataclass
@@ -27,6 +30,13 @@ class GenerateEnd:
is_end: bool


@dataclass
class ForwardResult:
logits: torch.Tensor
hidden_states: Optional[torch.Tensor] = None
comm_cost_time_list: Optional[List[float]] = None


def is_generate_end(output_ids: List[int], eos_token_id: int, max_new_tokens: int) -> GenerateEnd:
if len(output_ids) >= max_new_tokens:
return GenerateEnd(finish_reason="length", is_end=True)
@@ -66,19 +76,23 @@ def from_pretrained(cls, model_path: str, weight_path: str, server: RPCManager,
def _prepare_forward_data(self, uuid_str: str, hidden_states: torch.Tensor) -> Dict[str, Any]:
return {"uuid": uuid_str, "hidden_states": tensor_to_list(hidden_states)}

def forward(self, inputs_embeds: torch.Tensor, uuid_str: str) -> Tuple[torch.Tensor, None]:
def forward(self, inputs_embeds: torch.Tensor, uuid_str: str) -> ForwardResult:
hidden_states = inputs_embeds
comm_cost_time_list = []
for pp_idx in range(self.pp_size):
s1 = time.time()
outputs = self.server.post_sync(
pp_idx, "/forward", data=self._prepare_forward_data(uuid_str, hidden_states)
)
s2 = time.time()
assert self.server.is_success(outputs), "Forward failed"
hidden_states = self.server.fetch_list_output(outputs)
hidden_states = outputs.hidden_states
comm_cost_time_list.append(s2 - s1 - outputs.cost_time)

hidden_states = torch.tensor(hidden_states).to(inputs_embeds.dtype).to(self.norm.weight.device)
hidden_states = self.norm(hidden_states)
logits = self.lm_head(hidden_states)
return logits, None
return ForwardResult(logits=logits, comm_cost_time_list=comm_cost_time_list)

@torch.no_grad()
def generate(self, input_ids: torch.Tensor, sampler: DecodeUtils, **kwargs) -> GenerateResult:
@@ -88,8 +102,11 @@ def generate(self, input_ids: torch.Tensor, sampler: DecodeUtils, **kwargs) -> G
output_ids: List[int] = []
finish_reason = None
uuid_str = str(uuid.uuid4())
ttft_start_time, ttft_end_time = time.time(), time.time()
while True:
logits, _ = self(input_embeds, uuid_str)
forward_result = self(input_embeds, uuid_str)
logits = forward_result.logits
comm_cost_time_list = forward_result.comm_cost_time_list
generate_ids = sampler.decode(logits)
output_ids.append(generate_ids[0])

@@ -99,5 +116,10 @@ def generate(self, input_ids: torch.Tensor, sampler: DecodeUtils, **kwargs) -> G
break

input_embeds = self.embed_tokens(torch.tensor(generate_ids)).unsqueeze(0)
if len(output_ids) == 1:
ttft_end_time = time.time()
logging.info(f"ttft communication cost time: {comm_cost_time_list}")
else:
logging.info(f"tpot communication cost time: {comm_cost_time_list}")

return GenerateResult(output_ids=output_ids, finish_reason=finish_reason)
return GenerateResult(output_ids=output_ids, finish_reason=finish_reason, ttft=ttft_end_time - ttft_start_time)
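
In this commit, forward() measures per-stage communication overhead as the round-trip time minus the worker-reported compute time (s2 - s1 - outputs.cost_time), and generate() records TTFT as the wall time until the first decoded token, returned via GenerateResult.ttft. A hypothetical caller sketch follows; the engine/sampler construction and the max_new_tokens keyword argument are assumptions, not the confirmed API:

```python
import torch

# `model` is assumed to be an engine instance built via the from_pretrained(...)
# classmethod above, and `sampler` a DecodeUtils instance; token ids are placeholders.
input_ids = torch.tensor([[1, 2, 3, 4]])
result = model.generate(input_ids, sampler, max_new_tokens=20)

print("finish_reason:", result.finish_reason)
print("generated tokens:", len(result.output_ids))
if result.ttft is not None:
    print(f"time to first token: {result.ttft:.3f} s")
```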