Commit

tp comm update to broadcast_object

wnma3mz committed Oct 2, 2024
1 parent de44276 commit 61f7898
Showing 3 changed files with 19 additions and 29 deletions.
3 changes: 3 additions & 0 deletions tllm/commons/communicator.py
@@ -68,3 +68,6 @@ def gather(self, x: torch.Tensor):

def broadcast(self, x: torch.Tensor):
dist.broadcast(x, src=0)

def broadcast_object(self, obj_list: List[Any]):
dist.broadcast_object_list(obj_list, src=0)
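
For context on the new helper: dist.broadcast_object_list pickles arbitrary picklable objects on the source rank and unpickles them in place into every other rank's list, which is what lets the later hunks in this commit drop the manual pickle-into-byte-tensor round trip. A minimal, self-contained sketch of the pattern (a hypothetical two-process script, not part of the repo; gloo backend assumed, matching the error message quoted in client.py):

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank: int, world_size: int):
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29500", rank=rank, world_size=world_size
    )
    # Rank 0 supplies real values; other ranks pass a same-length placeholder list.
    obj_list = ["some-uuid", (1, 4, 8)] if rank == 0 else [None, None]
    dist.broadcast_object_list(obj_list, src=0)  # fills obj_list in place on every rank
    uuid, shape = obj_list
    # With the shape known, every rank can allocate a buffer for the tensor broadcast.
    x = torch.randn(shape) if rank == 0 else torch.empty(shape)
    dist.broadcast(x, src=0)
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)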
10 changes: 6 additions & 4 deletions tllm/engine.py
@@ -88,11 +88,13 @@ def forward(self, inputs_embeds: torch.Tensor, uuid_str: str) -> ForwardResult:
outputs = self.server.post_sync(
pp_idx, "/forward", data=self._prepare_forward_data(uuid_str, hidden_states, need_serialize=pp_idx == 0)
)
hidden_states = deserialize_bfloat16_tensor(outputs.output) if pp_idx == 0 else outputs.output
hidden_states = (
deserialize_bfloat16_tensor(outputs.output) if pp_idx == self.pp_size - 1 else outputs.output
)
s2 = time.time()
comm_cost_time_list.append(s2 - s1 - outputs.cost_time)

hidden_states = torch.tensor(hidden_states).to(inputs_embeds.dtype).to(self.norm.weight.device)
hidden_states = hidden_states.to(self.norm.weight.device)
hidden_states = self.norm(hidden_states)
logits = self.lm_head(hidden_states)
return ForwardResult(logits=logits, comm_cost_time_list=comm_cost_time_list)
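
Reading the hunk above: the pipeline loop serializes only the first request (the only hop that starts from a local tensor) and deserializes only the last response; intermediate stages relay the wire format untouched, which is why the condition moved from pp_idx == 0 to pp_idx == self.pp_size - 1. A hedged restatement of that control flow, with the loop line reconstructed from context (names follow the diff; self.pp_size and self.server are assumed attributes):

hidden_states = inputs_embeds
for pp_idx in range(self.pp_size):
    outputs = self.server.post_sync(
        pp_idx,
        "/forward",
        # Only the first hop starts from a local tensor, so only it serializes.
        data=self._prepare_forward_data(uuid_str, hidden_states, need_serialize=pp_idx == 0),
    )
    # Intermediate hops pass the serialized payload straight through; only the
    # final stage's output is decoded back into a tensor for norm + lm_head.
    hidden_states = (
        deserialize_bfloat16_tensor(outputs.output) if pp_idx == self.pp_size - 1 else outputs.output
    )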
@@ -121,8 +123,8 @@ def generate(self, input_ids: torch.Tensor, sampler: DecodeUtils, **kwargs) -> GenerateResult:
input_embeds = self.embed_tokens(torch.tensor(generate_ids)).unsqueeze(0)
if len(output_ids) == 1:
ttft_end_time = time.time()
logging.info(f"ttft communication cost time: {comm_cost_time_list}")
logging.info(f"ttft communication cost time: {",".join([f'{x:.4f}' for x in comm_cost_time_list])}")
else:
logging.info(f"tpot communication cost time: {comm_cost_time_list}")
logging.info(f"tpot communication cost time: {",".join([f'{x:.4f}' for x in comm_cost_time_list])}")

return GenerateResult(output_ids=output_ids, finish_reason=finish_reason, ttft=ttft_end_time - ttft_start_time)
35 changes: 10 additions & 25 deletions tllm/rpc/client.py
@@ -2,27 +2,19 @@
from concurrent import futures
import logging
import os
import pickle
import time
from typing import *

import grpc

from tllm.commons.communicator import Communicator, SingleNodeCommunicator
from tllm.commons.convert import (
deserialize_bfloat16_tensor,
list_to_protobuf,
protobuf_to_list,
serialize_bfloat16_tensor,
)
from tllm.commons.convert import deserialize_bfloat16_tensor, serialize_bfloat16_tensor
from tllm.models.llama import MyLlamaModel
from tllm.rpc import schemas_pb2, schemas_pb2_grpc
from tllm.utils import get_ip_address, tensor_to_list

logging.basicConfig(level=logging.INFO)

import torch
import torch.distributed as dist
from transformers import AutoConfig, LlamaForCausalLM


@@ -33,22 +25,17 @@ def __init__(self, config, model, rank: int, pp_rank: int):
self.rank = rank
self.pp_rank = pp_rank
self.init_model_flag = False
# self.ip_addr = get_ip_address()
self.ip_addr = "localhost"
self.prefix_log_str = f"IP: [{self.ip_addr}]"
# TODO: optimize; syncing the uuid str and shape this way can raise libc++abi: terminating due to uncaught exception of type gloo::EnforceNotMet:
# tensor_data = torch.empty(86, dtype=torch.uint8)
tensor_data = torch.empty(83, dtype=torch.uint8)
uuid_shape_list = [None, None]
if self.rank == 0:
pass
else:
while True:
dist.recv(tensor_data, src=0)
serialized_data = bytes(tensor_data.numpy())
hidden_states_shape, uuid = pickle.loads(serialized_data)

self.config.comm.broadcast_object(uuid_shape_list)
uuid, hidden_states_shape = uuid_shape_list
hidden_states = torch.empty(hidden_states_shape, dtype=self.model.dtype)
dist.recv(hidden_states, src=0)
self.config.comm.broadcast(hidden_states)

_ = self.model.forward(hidden_states, uuid)

@@ -77,11 +64,8 @@ def Forward(self, request: schemas_pb2.ForwardRequest, context: grpc.ServicerContext):
s1 = time.time()
hidden_states = deserialize_bfloat16_tensor(request.hidden_states)

serialized_data = list(pickle.dumps((hidden_states.shape, request.uuid)))
tensor_data = torch.ByteTensor(serialized_data)
for rank in range(1, self.config.comm.world_size):
dist.send(tensor_data, dst=rank)
dist.send(hidden_states, dst=rank)
self.config.comm.broadcast_object([request.uuid, tuple(hidden_states.shape)])
self.config.comm.broadcast(hidden_states)

output = self.model(hidden_states, request.uuid)
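
Taken together, the two client.py hunks above split the collective into a driver side and a worker side. A condensed, hypothetical view of the handshake (not code from the repo; comm wraps torch.distributed as in tllm/commons/communicator.py, and the [uuid, shape] list layout must match on both sides):

import torch

def rank0_forward(comm, model, hidden_states, uuid: str):
    # Rank 0, inside the gRPC handler: publish metadata first, then the tensor.
    comm.broadcast_object([uuid, tuple(hidden_states.shape)])
    comm.broadcast(hidden_states)
    return model(hidden_states, uuid)

def worker_loop(comm, model, dtype=torch.bfloat16):
    # Ranks > 0: block on the metadata broadcast, size a receive buffer from
    # the shape, then join the tensor broadcast. This replaces the old
    # pickle + dist.send/recv handshake that hit the gloo::EnforceNotMet abort.
    while True:
        uuid_shape_list = [None, None]
        comm.broadcast_object(uuid_shape_list)
        uuid, shape = uuid_shape_list
        hidden_states = torch.empty(shape, dtype=dtype)
        comm.broadcast(hidden_states)
        _ = model.forward(hidden_states, uuid)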

@@ -134,11 +118,12 @@ def start_grpc_server(config, model, port, rank, pp_rank):
config.decoder_end_layer_idx = args.end_layer_idx
config.comm = comm

dtype = torch.bfloat16
s1 = time.time()
state_dict = LlamaForCausalLM.from_pretrained(
args.model_path, trust_remote_code=True, device_map="cpu", torch_dtype=torch.bfloat16, low_cpu_mem_usage=True
args.model_path, trust_remote_code=True, device_map="cpu", torch_dtype=dtype, low_cpu_mem_usage=True
).state_dict()
model = MyLlamaModel(config)
model = MyLlamaModel(config).to(dtype)
model.load_state_dict(state_dict)
logging.info(f"[Rank: {config.comm.rank}] Cost time {time.time() - s1}")
model.eval()
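
One more note on the dtype hoist above: load_state_dict copies checkpoint values into the module's existing parameters and silently casts to the parameters' dtype, so without .to(dtype) a float32 MyLlamaModel would up-cast the bf16 weights and run in fp32. A small standalone illustration with a toy module (not from the repo):

import torch
import torch.nn as nn

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

bf16_state = {k: v.to(torch.bfloat16) for k, v in Tiny().state_dict().items()}

fp32_model = Tiny()                      # parameters default to float32
fp32_model.load_state_dict(bf16_state)   # values are cast up to the params' dtype
assert fp32_model.proj.weight.dtype == torch.float32

bf16_model = Tiny().to(torch.bfloat16)   # matches the commit's .to(dtype)
bf16_model.load_state_dict(bf16_state)
assert bf16_model.proj.weight.dtype == torch.bfloat16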
