bytes bfloat16 replace float32
wnma3mz committed Oct 2, 2024
1 parent 2b130d7 commit de44276
Showing 9 changed files with 84 additions and 66 deletions.
5 changes: 4 additions & 1 deletion README.md
@@ -53,4 +53,7 @@ bfloat 16 CPU
Llama-3.2-1B-Instruct single-machine: 10.96 token/s
Llama-3.2-1B-Instruct single-machine: 5.73 token/s (includes the time to generate the first token; TTFT is not easy to record with the transformers framework)

TODO: Meta-Llama-3-8B-Instruct in GPU

Multi-dimensional array implementation (float32): single-machine communication takes about 0.002 s (seq-len=1)
bytes implementation (float32): single-machine communication takes about 0.001 s (seq-len=1)
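
The roughly halved communication time comes mostly from cheaper serialization: a flat bytes buffer avoids building nested Python lists. A minimal sketch of the difference (not the repo's benchmark harness; the tensor shape is an assumed example):

```python
import time

import torch

hidden = torch.randn(1, 1, 2048)  # (batch, seq_len=1, hidden_dim) -- assumed shape

t0 = time.time()
as_list = hidden.tolist()            # nested lists, as the old MultiDimensionalArray path did
t1 = time.time()
as_bytes = hidden.numpy().tobytes()  # flat bytes, as the new path does
t2 = time.time()

print(f"nested lists: {t1 - t0:.6f} s")
print(f"raw bytes:    {t2 - t1:.6f} s")
```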
4 changes: 2 additions & 2 deletions examples/config.json
@@ -4,11 +4,11 @@
"url": "localhost:25001",
"tp_size": 2,
"master_port": 29501,
"layer_idx": [0, 11]
"layer_idx": [0, 1]
},
{
"pp_rank": 1,
"url": "192.168.0.101:25002",
"url": "localhost:25002",
"tp_size": 2,
"master_port": 29502,
"layer_idx": [11, 22]
17 changes: 17 additions & 0 deletions tllm/commons/convert.py
@@ -1,6 +1,9 @@
from typing import List

import torch

from tllm.rpc import schemas_pb2, schemas_pb2_grpc
from tllm.rpc.schemas_pb2 import BFloat16Tensor


def protobuf_to_list(proto_message):
@@ -80,3 +83,17 @@ def list_to_protobuf(data: List):
raise ValueError("Input data must be a list.")

return multi_array_proto


def serialize_bfloat16_tensor(tensor) -> BFloat16Tensor:
# TODO: support bfloat16
tensor_proto = BFloat16Tensor()
tensor_proto.shape.extend(tensor.shape)  # append the tensor shape
tensor_proto.data = tensor.to(torch.float32).numpy().tobytes()
return tensor_proto


def deserialize_bfloat16_tensor(tensor_proto) -> torch.Tensor:
data = torch.frombuffer(tensor_proto.data, dtype=torch.float32).to(torch.bfloat16)
tensor_data = data.view(*tensor_proto.shape)
return tensor_data
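
A quick round-trip check of the two helpers above (a sketch, assuming they are imported from tllm.commons.convert; note the wire format is float32 for now, per the TODO):

```python
import torch

from tllm.commons.convert import deserialize_bfloat16_tensor, serialize_bfloat16_tensor

x = torch.randn(2, 3, dtype=torch.bfloat16)
proto = serialize_bfloat16_tensor(x)
y = deserialize_bfloat16_tensor(proto)

assert y.dtype == torch.bfloat16
assert tuple(y.shape) == tuple(x.shape)
assert torch.equal(x, y)  # bfloat16 -> float32 -> bfloat16 round-trips losslessly
```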
17 changes: 10 additions & 7 deletions tllm/engine.py
@@ -9,12 +9,14 @@
from transformers import AutoConfig
from transformers.models.llama.modeling_llama import LlamaRMSNorm

from tllm.commons.convert import deserialize_bfloat16_tensor, serialize_bfloat16_tensor
from tllm.generate.decode_utils import DecodeUtils
from tllm.rpc.manager import RPCManager
from tllm.utils import tensor_to_list

finish_reason_type = Literal["length", "stop", None]

logging.basicConfig(level=logging.INFO)


@dataclass
class GenerateResult:
@@ -32,9 +34,9 @@ class GenerateEnd:

@dataclass
class ForwardResult:
- hidden_states: Optional[torch.Tensor] = None
logits: torch.Tensor
comm_cost_time_list: Optional[List[float]] = None
+ hidden_states: Optional[torch.Tensor] = None


def is_generate_end(output_ids: List[int], eos_token_id: int, max_new_tokens: int) -> GenerateEnd:
@@ -73,20 +75,21 @@ def from_pretrained(cls, model_path: str, weight_path: str, server: RPCManager,
model.eval()
return model

- def _prepare_forward_data(self, uuid_str: str, hidden_states: torch.Tensor) -> Dict[str, Any]:
-     return {"uuid": uuid_str, "hidden_states": tensor_to_list(hidden_states)}
+ def _prepare_forward_data(self, uuid_str: str, hidden_states: torch.Tensor, need_serialize: bool) -> Dict[str, Any]:
+     if need_serialize:
+         hidden_states = serialize_bfloat16_tensor(hidden_states)
+     return {"uuid": uuid_str, "hidden_states": hidden_states}

def forward(self, inputs_embeds: torch.Tensor, uuid_str: str) -> ForwardResult:
hidden_states = inputs_embeds
comm_cost_time_list = []
for pp_idx in range(self.pp_size):
s1 = time.time()
outputs = self.server.post_sync(
pp_idx, "/forward", data=self._prepare_forward_data(uuid_str, hidden_states)
pp_idx, "/forward", data=self._prepare_forward_data(uuid_str, hidden_states, need_serialize=pp_idx == 0)
)
+ hidden_states = deserialize_bfloat16_tensor(outputs.output) if pp_idx == 0 else outputs.output
s2 = time.time()
assert self.server.is_success(outputs), "Forward failed"
- hidden_states = outputs.hidden_states
comm_cost_time_list.append(s2 - s1 - outputs.cost_time)

hidden_states = torch.tensor(hidden_states).to(inputs_embeds.dtype).to(self.norm.weight.device)
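
The comm_cost_time_list bookkeeping in forward() subtracts the server-reported compute time from the wall-clock round trip to isolate pure communication cost, which is what the README numbers above measure. A toy illustration (all numbers made up for the example):

```python
s1 = 0.000                 # before the RPC
s2 = 0.013                 # after the RPC: 13 ms round trip
server_cost_time = 0.011   # outputs.cost_time as reported by the server
comm_cost = (s2 - s1) - server_cost_time
print(f"communication overhead: {comm_cost * 1000:.1f} ms")  # -> 2.0 ms
```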
14 changes: 9 additions & 5 deletions tllm/rpc/client.py
@@ -9,7 +9,12 @@
import grpc

from tllm.commons.communicator import Communicator, SingleNodeCommunicator
- from tllm.commons.convert import list_to_protobuf, protobuf_to_list
+ from tllm.commons.convert import (
+     deserialize_bfloat16_tensor,
+     list_to_protobuf,
+     protobuf_to_list,
+     serialize_bfloat16_tensor,
+ )
from tllm.models.llama import MyLlamaModel
from tllm.rpc import schemas_pb2, schemas_pb2_grpc
from tllm.utils import get_ip_address, tensor_to_list
@@ -66,12 +71,11 @@ def InitModel(self, request: schemas_pb2.ModelConfig, context: grpc.ServicerCont
def Forward(self, request: schemas_pb2.ForwardRequest, context: grpc.ServicerContext):
"""
@param request: ForwardRequest
- hidden_states: torch.Tensor
+ hidden_states: bytes
uuid: str
"""
s1 = time.time()
- hidden_states = protobuf_to_list(request.hidden_states)
- hidden_states = torch.tensor(hidden_states, dtype=self.model.dtype)
+ hidden_states = deserialize_bfloat16_tensor(request.hidden_states)

serialized_data = list(pickle.dumps((hidden_states.shape, request.uuid)))
tensor_data = torch.ByteTensor(serialized_data)
@@ -81,7 +85,7 @@ def Forward(self, request: schemas_pb2.ForwardRequest, context: grpc.ServicerCon

output = self.model(hidden_states, request.uuid)

- return_output = list_to_protobuf(tensor_to_list(output))
+ return_output = serialize_bfloat16_tensor(output)
cost_time = time.time() - s1
logging.info(f"{self.prefix_log_str} Forward pass cost time: {cost_time:.2f} s")

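The Forward handler above broadcasts the (shape, uuid) pair to the other tensor-parallel ranks as a pickled byte tensor before running the model. A standalone sketch of that step (values are assumed; the repo's Communicator and the actual dist.broadcast call are omitted):

```python
import pickle

import torch

shape, uuid = (1, 1, 2048), "example-uuid"  # assumed values

# Sender side: pickle the metadata and wrap it in a byte tensor.
serialized_data = list(pickle.dumps((shape, uuid)))
tensor_data = torch.ByteTensor(serialized_data)
# ... dist.broadcast(tensor_data, src=0) would run here ...

# Receiver side: recover the metadata from the broadcast bytes.
recv_shape, recv_uuid = pickle.loads(bytes(tensor_data.tolist()))
assert recv_shape == shape and recv_uuid == uuid
```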
2 changes: 1 addition & 1 deletion tllm/rpc/manager.py
@@ -38,7 +38,7 @@ def init_model(self, stub, data):
return stub.InitModel(request)

def forward(self, stub, data):
- request = schemas_pb2.ForwardRequest(uuid=data["uuid"], hidden_states=list_to_protobuf(data["hidden_states"]))
+ request = schemas_pb2.ForwardRequest(uuid=data["uuid"], hidden_states=data["hidden_states"])
return stub.Forward(request)

def health(self, stub):
29 changes: 5 additions & 24 deletions tllm/rpc/schemas.proto
@@ -2,31 +2,12 @@ syntax = "proto3";

package schemas;

- message Array {
-   repeated float elements = 1; // One-dimensional array of floats
- }
-
- message Matrix {
-   repeated Array rows = 1; // Two-dimensional array of floats
- }
-
- message Tensor {
-   repeated Matrix layers = 1; // Three-dimensional array of floats
+ message BFloat16Tensor {
+   bytes data = 1; // bfloat16 data stored as bytes
+   repeated int32 shape = 2; // shape information
}
-
- message BlockTensor {
-   repeated Tensor blocks = 1; // Four-dimensional array of floats
- }
-
- message MultiDimensionalArray {
-   // Union type to represent any dimensional array
-   oneof multi_array {
-     Array array = 1;
-     Matrix matrix = 2;
-     Tensor tensor = 3;
-     BlockTensor block_tensor = 4;
-   }
- }

message ModelConfig {
string model_name = 1;
@@ -39,7 +20,7 @@ message ModelConfig {

message ForwardRequest {
string uuid = 1;
- MultiDimensionalArray hidden_states = 2;
+ BFloat16Tensor hidden_states = 2;
}

message StatusResponse {
@@ -50,7 +31,7 @@ message StatusResponse {
message ForwardResponse {
string msg = 1;
int32 status = 2;
- MultiDimensionalArray output = 3;
+ BFloat16Tensor output = 3;
float cost_time = 4;
}
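
For illustration, constructing a ForwardRequest against the new schema might look like the sketch below (assumes the schemas_pb2 bindings have been regenerated from the .proto above; shape and uuid are example values):

```python
import torch

from tllm.rpc import schemas_pb2

hidden = torch.randn(1, 1, 2048)  # example hidden-state shape
req = schemas_pb2.ForwardRequest(
    uuid="example-uuid",
    hidden_states=schemas_pb2.BFloat16Tensor(
        data=hidden.to(torch.float32).numpy().tobytes(),  # float32 on the wire for now
        shape=list(hidden.shape),
    ),
)
print(len(req.SerializeToString()), "bytes on the wire")  # ~8 KiB of data plus overhead
```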

40 changes: 21 additions & 19 deletions tllm/rpc/schemas_pb2.py

Some generated files are not rendered by default.

22 changes: 15 additions & 7 deletions tllm/rpc/schemas_pb2.pyi
@@ -1,3 +1,6 @@
+ from google.protobuf.internal import containers as _containers
+ from google.protobuf import descriptor as _descriptor
+ from google.protobuf import message as _message
from typing import (
ClassVar as _ClassVar,
Iterable as _Iterable,
@@ -6,9 +9,6 @@ from typing import (
Union as _Union,
)

- from google.protobuf import descriptor as _descriptor, message as _message
- from google.protobuf.internal import containers as _containers

DESCRIPTOR: _descriptor.FileDescriptor

class Array(_message.Message):
@@ -35,6 +35,14 @@ class BlockTensor(_message.Message):
blocks: _containers.RepeatedCompositeFieldContainer[Tensor]
def __init__(self, blocks: _Optional[_Iterable[_Union[Tensor, _Mapping]]] = ...) -> None: ...

class BFloat16Tensor(_message.Message):
__slots__ = ("data", "shape")
DATA_FIELD_NUMBER: _ClassVar[int]
SHAPE_FIELD_NUMBER: _ClassVar[int]
data: bytes
shape: _containers.RepeatedScalarFieldContainer[int]
def __init__(self, data: _Optional[bytes] = ..., shape: _Optional[_Iterable[int]] = ...) -> None: ...

class MultiDimensionalArray(_message.Message):
__slots__ = ("array", "matrix", "tensor", "block_tensor")
ARRAY_FIELD_NUMBER: _ClassVar[int]
@@ -82,9 +90,9 @@ class ForwardRequest(_message.Message):
UUID_FIELD_NUMBER: _ClassVar[int]
HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
uuid: str
- hidden_states: MultiDimensionalArray
+ hidden_states: BFloat16Tensor
def __init__(
-     self, uuid: _Optional[str] = ..., hidden_states: _Optional[_Union[MultiDimensionalArray, _Mapping]] = ...
+     self, uuid: _Optional[str] = ..., hidden_states: _Optional[_Union[BFloat16Tensor, _Mapping]] = ...
) -> None: ...

class StatusResponse(_message.Message):
@@ -103,13 +111,13 @@ class ForwardResponse(_message.Message):
COST_TIME_FIELD_NUMBER: _ClassVar[int]
msg: str
status: int
- output: MultiDimensionalArray
+ output: BFloat16Tensor
cost_time: float
def __init__(
self,
msg: _Optional[str] = ...,
status: _Optional[int] = ...,
-     output: _Optional[_Union[MultiDimensionalArray, _Mapping]] = ...,
+     output: _Optional[_Union[BFloat16Tensor, _Mapping]] = ...,
cost_time: _Optional[float] = ...,
) -> None: ...

