diff --git a/tllm/commons/communicator.py b/tllm/commons/communicator.py
index b8cd3bf..d9626c2 100644
--- a/tllm/commons/communicator.py
+++ b/tllm/commons/communicator.py
@@ -68,3 +68,6 @@ def gather(self, x: torch.Tensor):
 
     def broadcast(self, x: torch.Tensor):
         dist.broadcast(x, src=0)
+
+    def broadcast_object(self, obj_list: List[Any]):
+        dist.broadcast_object_list(obj_list, src=0)
diff --git a/tllm/engine.py b/tllm/engine.py
index e202747..64ff092 100644
--- a/tllm/engine.py
+++ b/tllm/engine.py
@@ -88,11 +88,13 @@ def forward(self, inputs_embeds: torch.Tensor, uuid_str: str) -> ForwardResult:
             outputs = self.server.post_sync(
                 pp_idx, "/forward", data=self._prepare_forward_data(uuid_str, hidden_states, need_serialize=pp_idx == 0)
             )
-            hidden_states = deserialize_bfloat16_tensor(outputs.output) if pp_idx == 0 else outputs.output
+            hidden_states = (
+                deserialize_bfloat16_tensor(outputs.output) if pp_idx == self.pp_size - 1 else outputs.output
+            )
             s2 = time.time()
             comm_cost_time_list.append(s2 - s1 - outputs.cost_time)
 
-        hidden_states = torch.tensor(hidden_states).to(inputs_embeds.dtype).to(self.norm.weight.device)
+        hidden_states = hidden_states.to(self.norm.weight.device)
         hidden_states = self.norm(hidden_states)
         logits = self.lm_head(hidden_states)
         return ForwardResult(logits=logits, comm_cost_time_list=comm_cost_time_list)
@@ -121,8 +123,8 @@ def generate(self, input_ids: torch.Tensor, sampler: DecodeUtils, **kwargs) -> GenerateResult:
             input_embeds = self.embed_tokens(torch.tensor(generate_ids)).unsqueeze(0)
             if len(output_ids) == 1:
                 ttft_end_time = time.time()
-                logging.info(f"ttft communication cost time: {comm_cost_time_list}")
+                logging.info(f"ttft communication cost time: {','.join([f'{x:.4f}' for x in comm_cost_time_list])}")
             else:
-                logging.info(f"tpot communication cost time: {comm_cost_time_list}")
+                logging.info(f"tpot communication cost time: {','.join([f'{x:.4f}' for x in comm_cost_time_list])}")
 
         return GenerateResult(output_ids=output_ids, finish_reason=finish_reason, ttft=ttft_end_time - ttft_start_time)
diff --git a/tllm/rpc/client.py b/tllm/rpc/client.py
index dfc182b..62229ff 100644
--- a/tllm/rpc/client.py
+++ b/tllm/rpc/client.py
@@ -2,27 +2,19 @@
 from concurrent import futures
 import logging
 import os
-import pickle
 import time
 from typing import *
 
 import grpc
 
 from tllm.commons.communicator import Communicator, SingleNodeCommunicator
-from tllm.commons.convert import (
-    deserialize_bfloat16_tensor,
-    list_to_protobuf,
-    protobuf_to_list,
-    serialize_bfloat16_tensor,
-)
+from tllm.commons.convert import deserialize_bfloat16_tensor, serialize_bfloat16_tensor
 from tllm.models.llama import MyLlamaModel
 from tllm.rpc import schemas_pb2, schemas_pb2_grpc
-from tllm.utils import get_ip_address, tensor_to_list
 
 logging.basicConfig(level=logging.INFO)
 
 import torch
-import torch.distributed as dist
 from transformers import AutoConfig, LlamaForCausalLM
 
 
@@ -33,22 +25,17 @@ def __init__(self, config, model, rank: int, pp_rank: int):
         self.rank = rank
         self.pp_rank = pp_rank
         self.init_model_flag = False
-        # self.ip_addr = get_ip_address()
         self.ip_addr = "localhost"
         self.prefix_log_str = f"IP: [{self.ip_addr}]"
-        # TODO optimize: syncing the uuid str and shape this way can trigger libc++abi: terminating due to uncaught exception of type gloo::EnforceNotMet:
-        # tensor_data = torch.empty(86, dtype=torch.uint8)
-        tensor_data = torch.empty(83, dtype=torch.uint8)
+        uuid_shape_list = [None, None]
         if self.rank == 0:
             pass
         else:
             while True:
-                dist.recv(tensor_data, src=0)
-                serialized_data = bytes(tensor_data.numpy())
-                hidden_states_shape, uuid = pickle.loads(serialized_data)
-
+                self.config.comm.broadcast_object(uuid_shape_list)
+                uuid, hidden_states_shape = uuid_shape_list
                 hidden_states = torch.empty(hidden_states_shape, dtype=self.model.dtype)
-                dist.recv(hidden_states, src=0)
+                self.config.comm.broadcast(hidden_states)
 
                 _ = self.model.forward(hidden_states, uuid)
 
@@ -77,11 +64,8 @@ def Forward(self, request: schemas_pb2.ForwardRequest, context: grpc.ServicerContext):
         s1 = time.time()
         hidden_states = deserialize_bfloat16_tensor(request.hidden_states)
 
-        serialized_data = list(pickle.dumps((hidden_states.shape, request.uuid)))
-        tensor_data = torch.ByteTensor(serialized_data)
-        for rank in range(1, self.config.comm.world_size):
-            dist.send(tensor_data, dst=rank)
-            dist.send(hidden_states, dst=rank)
+        self.config.comm.broadcast_object([request.uuid, tuple(hidden_states.shape)])
+        self.config.comm.broadcast(hidden_states)
 
         output = self.model(hidden_states, request.uuid)
 
@@ -134,11 +118,12 @@ def start_grpc_server(config, model, port, rank, pp_rank):
     config.decoder_end_layer_idx = args.end_layer_idx
     config.comm = comm
 
+    dtype = torch.bfloat16
     s1 = time.time()
     state_dict = LlamaForCausalLM.from_pretrained(
-        args.model_path, trust_remote_code=True, device_map="cpu", torch_dtype=torch.bfloat16, low_cpu_mem_usage=True
+        args.model_path, trust_remote_code=True, device_map="cpu", torch_dtype=dtype, low_cpu_mem_usage=True
     ).state_dict()
-    model = MyLlamaModel(config)
+    model = MyLlamaModel(config).to(dtype)
     model.load_state_dict(state_dict)
     logging.info(f"[Rank: {config.comm.rank}] Cost time {time.time() - s1}")
     model.eval()
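
For reference, the new synchronization pattern, where rank 0 announces the request uuid and tensor shape via broadcast_object_list and then broadcasts the tensor itself, can be exercised standalone. A minimal sketch, assuming a gloo process group; run_worker, the TCP port, and the float32 payload are illustrative and not from the repo:

# Sketch of the broadcast handshake this diff introduces (assumed names, not repo code).
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run_worker(rank: int, world_size: int):
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29501", rank=rank, world_size=world_size
    )

    if rank == 0:
        hidden_states = torch.randn(1, 4, 8)  # float32 payload for the gloo backend
        # 1) Broadcast the request id and tensor shape as picklable objects.
        meta = ["request-uuid-0", tuple(hidden_states.shape)]
        dist.broadcast_object_list(meta, src=0)
        # 2) Broadcast the tensor itself; every rank ends up with identical data.
        dist.broadcast(hidden_states, src=0)
    else:
        # Non-zero ranks pass placeholders; broadcast_object_list fills them in place.
        meta = [None, None]
        dist.broadcast_object_list(meta, src=0)
        uuid, shape = meta
        # Allocate a buffer of the announced shape, then join the tensor broadcast.
        hidden_states = torch.empty(shape)
        dist.broadcast(hidden_states, src=0)
        print(f"rank {rank} got {uuid} with shape {tuple(hidden_states.shape)}")

    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(run_worker, args=(world_size,), nprocs=world_size)

Because broadcast_object_list pickles on the sender and fills the placeholder list in place on the receivers, the variable-length metadata no longer needs the fixed 83-byte uint8 buffer that the removed TODO flagged as a source of gloo::EnforceNotMet crashes.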
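
One note on the logging change: before Python 3.12 (PEP 701), an f-string delimited by double quotes cannot contain the double-quote character inside its replacement fields, so the separator in the rewritten logging lines above uses single quotes. A quick check of the format, with illustrative values:

times = [0.0123, 0.0456]
print(f"ttft communication cost time: {','.join([f'{x:.4f}' for x in times])}")
# -> ttft communication cost time: 0.0123,0.0456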