
Commit

update import class
wnma3mz committed Dec 21, 2024
1 parent bf10a44 commit de6cf52
Showing 12 changed files with 91 additions and 99 deletions.
2 changes: 1 addition & 1 deletion RoadMap.md
@@ -45,7 +45,7 @@
 - [ ] Auto Layer Split
 - [x] get free layer idx
 - [x] fix split layer pipeline
-- [ ] calculate layer memory and recommend split
+- [x] calculate layer memory and recommend split
 - [ ] split model before load
 - [x] Async Generation
 - [x] Multi-Sequence Batch=1
51 changes: 22 additions & 29 deletions examples/run_engine.py
@@ -37,40 +37,28 @@ class Args:
     is_debug: bool = False


-def init_engine(model_path):
-    model, tok = load_master_model(model_path)
+def init_engine(model_path: str) -> AsyncEngine:
+    model = load_master_model(model_path)
     rpc_manager = LocalRPCManager(model_path)
-    generator = LLMGenerator(rpc_manager, model, tok)
+    generator = LLMGenerator(rpc_manager, model)
     engine = AsyncEngine(generator)
     return engine


-def init_image_engine(model_path):
-    model, tok = load_master_model(model_path)
+def init_image_engine(model_path: str) -> AsyncEngine:
+    model = load_master_model(model_path)
     rpc_manager = LocalRPCManager(model_path)
-    generator = ImageGenerator(rpc_manager, model, tok)
+    generator = ImageGenerator(rpc_manager, model)
     engine = AsyncEngine(generator)
     return engine


-async def llm_generate():
-    args = Args()
-
-    engine = init_engine(args.model_path)
-    await engine.start()
+def llm_message():
     messages = [{"role": "user", "content": "Hello, how are you?"}]
-    openai_serving_chat = OpenAIServing(engine, args)
-
-    request = ChatCompletionRequest(model="test", messages=messages)
-    response = await openai_serving_chat.create_chat_completion(request, None)
-    print(response)
+    return messages


-async def mllm_generate():
-    args = Args()
-
-    engine = init_engine(args.model_path)
-    await engine.start()
+def mllm_message():
     messages = [
         {
             "role": "user",
@@ -80,17 +68,21 @@ async def mllm_generate():
             ],
         }
     ]
+    return messages
+
+
+async def llm_generate(args, messages):
+    engine = init_engine(args.model_path)
+    await engine.start()
     messages = [{"role": "user", "content": "Hello, how are you?"}]
     openai_serving_chat = OpenAIServing(engine, args)

     request = ChatCompletionRequest(model="test", messages=messages)
     response = await openai_serving_chat.create_chat_completion(request, None)
     print(response)


-async def image_generate():
-    args = Args()
-
-    prompt = "a little dog"
+async def image_generate(args):
+    prompt = "germanic romanticism painting of an obscure winter forest in a geocore landscape. Ambient landscape lighting, heavy shading, crystal night sky, stunning stars, topography"
     config = {
         "num_inference_steps": 3,
@@ -99,7 +91,7 @@ async def image_generate():
     }

     engine = init_image_engine(args.model_path)
-    _ = await engine.start()
+    await engine.start()

     image_serving = ImageServing(engine, args)

@@ -110,6 +102,7 @@ async def image_generate():


 if __name__ == "__main__":
-    asyncio.run(llm_generate())
-    # asyncio.run(mllm_generate())
-    # asyncio.run(image_generate())
+    args = Args()
+    asyncio.run(llm_generate(args, llm_message()))
+    # asyncio.run(llm_generate(args, mllm_message()))
+    # asyncio.run(image_generate(args))
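After this change the example separates message construction (`llm_message`, `mllm_message`) from generation, so every entry point is driven the same way. A minimal sketch of invoking them, mirroring the new `__main__` block above; it assumes the names defined in examples/run_engine.py and a valid `model_path` in `Args`:

```python
import asyncio

# Sketch only: relies on Args, llm_message, mllm_message, llm_generate and
# image_generate as defined in examples/run_engine.py of this commit.
args = Args()

# Plain-text chat completion, served through OpenAIServing
asyncio.run(llm_generate(args, llm_message()))

# Multimodal chat reuses the same coroutine with a different message builder
# asyncio.run(llm_generate(args, mllm_message()))

# Text-to-image generation, served through ImageServing
# asyncio.run(image_generate(args))
```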
12 changes: 7 additions & 5 deletions tllm/commons/manager.py
@@ -1,12 +1,11 @@
 import os
 import time
-from typing import List, Optional, Tuple
+from typing import Any, List

 from transformers import AutoConfig

 from tllm import BACKEND, BackendEnum
 from tllm.commons.communicator import BaseCommunicator
-from tllm.generate import LLMGenerator, TokenizerUtils
 from tllm.models.file_helper import find_weight_file, get_model_path
 from tllm.models.register import MODEL_REGISTER
 from tllm.models.weight_helper import load_gguf_weight, read_from_safetensors, tie_embedding_weights
@@ -65,6 +64,8 @@ def __init__(self, weights):
         return TransformerWeightHandler(weights)

     def _post_init(self):
+        from tllm.generate import TokenizerUtils
+
         if str(self.model_path).endswith(".gguf"):
             raise NotImplementedError("GGUF model not supported")
             # state_dict, config, _ = load_gguf_weight(str(self.model_path))
@@ -149,7 +150,7 @@ def _hf_read_client_weight(self, start_idx: int, end_idx: int):
         return state_dict


-def load_client_model(start_idx: int, end_idx: int, comm: BaseCommunicator, model_path: str):
+def load_client_model(start_idx: int, end_idx: int, comm: BaseCommunicator, model_path: str) -> Any:
     weight_manager = WeightManager(model_path)
     config = weight_manager.config

@@ -175,7 +176,7 @@ def load_client_model(start_idx: int, end_idx: int, comm: BaseCommunicator, model_path: str):
     return model


-def load_master_model(model_path: str) -> Tuple[LLMGenerator, TokenizerUtils]:
+def load_master_model(model_path: str) -> Any:
     weight_manager = WeightManager(model_path)
     state_dict = weight_manager.read_master_weight()
     if weight_manager.arch not in MODEL_REGISTER:
@@ -190,4 +191,5 @@ def load_master_model(model_path: str) -> Tuple[LLMGenerator, TokenizerUtils]:
     kwargs.update({"quantization_level": weight_manager.config.quantization_level})

     model = MY_CausalLM_CLASS.from_pretrained(weight_manager.config, state_dict, **kwargs)
-    return model, weight_manager.tok
+    model.tok = weight_manager.tok
+    return model
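The net effect of this file's change is a narrower return value: `load_master_model` now hands back only the model, with the tokenizer attached as an attribute. A minimal sketch of the new calling convention (the import path comes from this file; the model path is a placeholder):

```python
from tllm.commons.manager import load_master_model

# Sketch only: load_master_model now returns a single model object;
# the tokenizer rides along as model.tok instead of a second return value.
model = load_master_model("/path/to/model")  # placeholder path
tok = model.tok  # previously: model, tok = load_master_model(model_path)
```

Call sites such as examples/run_engine.py and tllm/entrypoints/api_server.py are updated to this convention elsewhere in the commit.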
24 changes: 7 additions & 17 deletions tllm/entrypoints/api_server.py
@@ -13,8 +13,8 @@
 from tllm.entrypoints.image_server.server_image import ImageServing
 from tllm.entrypoints.protocol import ChatCompletionRequest, ChatCompletionResponse
 from tllm.entrypoints.server_chat import OpenAIServing
-from tllm.entrypoints.utils import load_master_config, parse_master_args, serve_http
-from tllm.network.helper import get_free_port
+from tllm.entrypoints.utils import parse_master_args, serve_http, update_master_args
+from tllm.generate import ImageGenerator, LLMGenerator
 from tllm.network.manager import LocalRPCManager, RPCManager, WebsocketManager
 from tllm.schemas import InitModelRequest, InitModelResponse, RegisterClientRequest, RegisterClientResponse
 from tllm.singleton_logger import SingletonLogger
@@ -188,7 +188,7 @@ async def init_engine(args):
     logger = SingletonLogger.setup_master_logger()

     s1 = time.time()
-    model, tok = load_master_model(args.model_path)
+    model = load_master_model(args.model_path)
     total_layers = model.num_layers  # the layer count is required

     global ws_manager, rpc_manager
@@ -197,30 +197,20 @@ async def init_engine(args):
     rpc_manager, master_handler = await init_rpc_manager(
         args.model_path, ws_manager.client_size, args.grpc_port, args.is_local
     )
-
-    logger.info(f"Engine init Cost Time: {time.time() - s1:.4f}s. Total Layers: {total_layers}")
     if args.is_image:
-        from tllm.generate import ImageGenerator
-
-        generator = ImageGenerator(rpc_manager, model, tok)
+        generator = ImageGenerator(rpc_manager, model)
     else:
-        from tllm.generate import LLMGenerator
-
-        generator = LLMGenerator(rpc_manager, model, tok)
+        generator = LLMGenerator(rpc_manager, model)
     engine = AsyncEngine(generator)
+    logger.info(f"Engine init Cost Time: {time.time() - s1:.4f}s. Total Layers: {total_layers}")

     await engine.start()
     return engine


 async def run_server(args) -> None:
     SingletonLogger.set_level("DEBUG" if args.is_debug else "INFO")
-
-    if args.grpc_port is None:
-        args.grpc_port = get_free_port()
-
-    if args.config:
-        args = load_master_config(args.config, args)
+    args = update_master_args(args)

     engine = await init_engine(args)
     app = await init_app(engine, args)
27 changes: 6 additions & 21 deletions tllm/entrypoints/handler/handler.py
@@ -9,8 +9,7 @@
 from tllm import GRPC_OPTIONS
 from tllm.commons.communicator import BaseCommunicator, Communicator
 from tllm.commons.convert import Convertor
-from tllm.entrypoints.utils import load_handler_config, parse_handler_args
-from tllm.network.helper import get_free_port, get_ips
+from tllm.entrypoints.utils import parse_handler_args, update_handler_args
 from tllm.network.http_client import HTTPClient
 from tllm.network.manager import MasterRPCManager
 from tllm.rpc import schemas_pb2, schemas_pb2_grpc
@@ -112,10 +111,10 @@ async def Forward(
         """
         @param request: ForwardRequest
             hidden_states: bytes
-            uuid: str
-            seq_len: int
+            uuid: List[str]
+            seq_len: List[int]
         """
-        if not hasattr(self.http_client, "model") and self.http_client is None:
+        if hasattr(self.http_client, "model") is None:
             return schemas_pb2.ForwardResponse(msg="Model not initialized", status=400)
         if hasattr(self.manager, "master_stub") is None:
             return schemas_pb2.ForwardResponse(msg="Manager not initialized", status=400)
@@ -167,27 +166,13 @@ async def Health(self, request, context):


 async def run(args):
-    SingletonLogger.set_level("DEBUG" if args.is_debug else "INFO")
+    args, ip_addr_list = update_handler_args(args)
     comm = Communicator()
-    if args.grpc_port is None:
-        args.grpc_port = get_free_port()
-    if args.config is not None:
-        if args.client_idx is None:
-            raise ValueError("client_idx is required when config is provided")
-        args = load_handler_config(args.config, args, args.client_idx)
-
-    ip_addr_list = get_ips()
-    # If a hostname is specified, use only that hostname
-    if args.hostname is not None and isinstance(args.hostname, str):
-        ip_addr_list = [args.hostname]
-
-    if len(ip_addr_list) == 0:
-        raise ValueError("No available ip address")
-
+    SingletonLogger.set_level("DEBUG" if args.is_debug else "INFO")
     logger = SingletonLogger.setup_handler_logger(f"handler-{args.grpc_port}")

     rpc_servicer = RPCHandler(comm, logger, args.master_addr)

     try:
         if comm.rank == 0:
             await rpc_servicer.start(ip_addr_list, args.grpc_port)
4 changes: 0 additions & 4 deletions tllm/entrypoints/handler/master_handler.py
@@ -27,8 +27,6 @@ def update(self, count: int, result: Tuple[int, float]):


 class PendingRequests:
-    """Manages pending requests."""
-
     def __init__(self):
         self._forward_requests: Dict[str, asyncio.Future] = {}
         self._status_requests: Dict[str, StatusTracker] = {}
@@ -97,7 +95,6 @@ async def Forward(
     ) -> schemas_pb2.ForwardResponse:
         """Handle the result returned from the last node."""
         request_id = "-".join(x for x in list(request.uuid))
-        # self.logger.debug(f"Received result request id: {request_id}")

         try:
             self.pending_requests.complete_forward_request(request_id, request.hidden_states)
@@ -113,7 +110,6 @@ async def ImageForward(
     ) -> schemas_pb2.ForwardResponse:
         """Handle the result returned from the last node."""
         request_id = "-".join(x for x in list(request.uuid))
-        # self.logger.debug(f"Received result request id: {request_id}")

         try:
             self.pending_requests.complete_forward_request(request_id, request.hidden_states)
51 changes: 36 additions & 15 deletions tllm/entrypoints/utils.py
@@ -7,13 +7,14 @@
 from fastapi import FastAPI
 import uvicorn

+from tllm.network.helper import get_free_port, get_ips
 from tllm.singleton_logger import SingletonLogger


 def parse_master_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--hostname", type=str, required=True)
     parser.add_argument("--model_path", type=str, required=True)
+    parser.add_argument("--hostname", type=str, required=False)
     parser.add_argument("--grpc_port", type=int, default=None)
     parser.add_argument("--http_port", type=int, default=8022)
     parser.add_argument("--config", type=str, default=None, help="config file path")
@@ -42,23 +43,43 @@ def parse_handler_args():
     return parser.parse_args()


-def load_master_config(config_path: str, args):
-    with open(config_path, "r") as f:
-        config = json.load(f)
-    args.hostname = config["server"]["hostname"]
-    args.http_port = config["server"]["http_port"]
-    args.grpc_port = config["server"]["grpc_port"]
-    args.client_size = len(config["client"])
+def update_master_args(args):
+    if args.grpc_port is None:
+        args.grpc_port = get_free_port()
+
+    if args.config is not None:
+        with open(args.config, "r") as f:
+            config = json.load(f)
+        args.hostname = config["server"]["hostname"]
+        args.http_port = config["server"]["http_port"]
+        args.grpc_port = config["server"]["grpc_port"]
+        args.client_size = len(config["client"])
     return args


-def load_handler_config(config_path: str, args, idx: int):
-    with open(config_path, "r") as f:
-        config = json.load(f)
-    args.grpc_port = config["client"][idx]["grpc_port"]
-    args.hostname = config["client"][idx]["hostname"]
-    args.master_addr = f'http://{config["server"]["hostname"]}:{config["server"]["http_port"]}'
-    return args
+def update_handler_args(args):
+    if args.grpc_port is None:
+        args.grpc_port = get_free_port()
+
+    if args.config is not None:
+        if args.client_idx is None:
+            raise ValueError("client_idx is required when config is provided")
+        with open(args.config_path, "r") as f:
+            config = json.load(f)
+        args.grpc_port = config["client"][args.client_idx]["grpc_port"]
+        args.hostname = config["client"][args.client_idx]["hostname"]
+        args.master_addr = f'http://{config["server"]["hostname"]}:{config["server"]["http_port"]}'
+
+    # If a hostname is specified, use only that hostname
+    if args.hostname is not None and isinstance(args.hostname, str):
+        ip_addr_list = [args.hostname]
+    else:
+        ip_addr_list = get_ips()
+
+    if len(ip_addr_list) == 0:
+        raise ValueError("No available ip address")
+
+    return args, ip_addr_list


 async def serve_http(app: FastAPI, **uvicorn_kwargs: Dict):
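For orientation, the config layout these helpers expect can be read off the keys they access: a `server` entry with `hostname`, `http_port`, and `grpc_port`, plus a `client` list with per-client `grpc_port` and `hostname`. A sketch of writing such a file; the hosts, ports, and file name are placeholder assumptions:

```python
import json

# Hypothetical config matching the keys read by update_master_args / update_handler_args.
config = {
    "server": {"hostname": "192.168.1.10", "http_port": 8022, "grpc_port": 25001},
    "client": [
        {"hostname": "192.168.1.11", "grpc_port": 25002},
        {"hostname": "192.168.1.12", "grpc_port": 25003},
    ],
}

with open("config.json", "w") as f:  # placeholder file name
    json.dump(config, f, indent=2)
```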
5 changes: 3 additions & 2 deletions tllm/generate/image_generator.py
@@ -2,16 +2,17 @@
 from typing import List

 from tllm.img_helper import pil_image_to_base64
+from tllm.network.manager.rpc_manager import RPCManager
 from tllm.schemas import ForwardResult, ImageRequestData
 from tllm.singleton_logger import SingletonLogger


 class ImageGenerator:
-    def __init__(self, manager: "RPCManager", model, tok=None) -> None:
+    def __init__(self, manager: RPCManager, model) -> None:
         self.manager = manager
         self.model = model
         self.logger = SingletonLogger.setup_master_logger()
-        self.tok = tok
+        self.tok = None

     async def forward(self, image_request: ImageRequestData) -> ForwardResult:
         height, width = image_request.runtime_config.height, image_request.runtime_config.width