diff --git a/README.md b/README.md
index 6b5d59d..0e67aa5 100644
--- a/README.md
+++ b/README.md
@@ -42,6 +42,31 @@
 python benchmarks/run_async_requests.py
 ```
 
+### Config
+
+In `examples/config.json`:
+
+```json
+// The number of clients determines how many pieces the model is split into
+{
+    "server": {
+        "grpc_port": 25001,     // gRPC port of the server; every client reports its status here and the last client sends back the computed result
+        "http_port": 8022,      // HTTP port of the server, serving the API and the WebSocket service
+        "hostname": "mac-mini"  // hostname of the server; an IP such as 192.168.1.10 also works, and it must be reachable from every client
+    },
+    "client": [
+        {
+            "grpc_port": 25002,  // gRPC port of the first client
+            "hostname": "m3pro"  // hostname of the first client; must be reachable from the server and the other clients
+        },
+        {
+            "grpc_port": 25003,  // gRPC port of the second client
+            "hostname": "m3"     // hostname of the second client; must be reachable from the server and the other clients
+        }
+    ]
+}
+```
+
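+To launch with a config, pass the same file to the server and to every client; `--client_idx` selects the entry in the `client` list. A minimal sketch, mirroring the commented-out commands in `examples/run_server.sh` and `examples/run_client.sh`:
+
+```bash
+# on the server node
+python3 -m tllm.entrypoints.api_server --hostname mac-mini --model_path Qwen/Qwen2-VL-2B-Instruct --config examples/config.json
+
+# on each client node, pass its index in the "client" list
+python3 -m tllm.entrypoints.handler.handler --master_addr http://mac-mini:8022 --config examples/config.json --client_idx 0
+```
+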
 ### Features
 
 - [x] Support Multi-Requests
diff --git a/examples/config.json b/examples/config.json
index bda1ade..ea499eb 100644
--- a/examples/config.json
+++ b/examples/config.json
@@ -1,19 +1,17 @@
 {
     "server": {
-        "grpc": 25001,
-        "http": 8022,
-        "ip_addr": "mac-mini"
+        "grpc_port": 25001,
+        "http_port": 8022,
+        "hostname": "mac-mini"
     },
     "client": [
         {
-            "grpc": 25002,
-            "ip_addr": "m3pro",
-            "master_addr": "http://mac-mini:8022"
+            "grpc_port": 25002,
+            "hostname": "m3pro"
         },
         {
-            "grpc": 25003,
-            "ip_addr": "m3",
-            "master_addr": "http://mac-mini:8022"
+            "grpc_port": 25003,
+            "hostname": "m3"
         }
     ]
 }
\ No newline at end of file
diff --git a/examples/config_one.json b/examples/config_one.json
new file mode 100644
index 0000000..e9d12d5
--- /dev/null
+++ b/examples/config_one.json
@@ -0,0 +1,13 @@
+{
+    "server": {
+        "grpc_port": 25001,
+        "http_port": 8022,
+        "hostname": "mac-mini"
+    },
+    "client": [
+        {
+            "grpc_port": 25002,
+            "hostname": "mac-mini"
+        }
+    ]
+}
\ No newline at end of file
diff --git a/examples/run_client.sh b/examples/run_client.sh
index d93aba8..7d7e916 100644
--- a/examples/run_client.sh
+++ b/examples/run_client.sh
@@ -5,4 +5,5 @@ MASTER_URL=http://mac-mini:8022
 
 export OMP_NUM_THREADS=8;
 export PYTHONPATH="./":$PYTHONPATH;
-python3 -m tllm.entrypoints.handler.handler --master_addr $MASTER_URL --is_debug
\ No newline at end of file
+python3 -m tllm.entrypoints.handler.handler --master_addr $MASTER_URL --is_debug
+# python3 -m tllm.entrypoints.handler.handler --master_addr $MASTER_URL --is_debug --config examples/config_one.json --client_idx 0
\ No newline at end of file
diff --git a/examples/run_server.sh b/examples/run_server.sh
index 4701138..38545a5 100644
--- a/examples/run_server.sh
+++ b/examples/run_server.sh
@@ -5,10 +5,6 @@ MODEL_PATH=Qwen/Qwen2-VL-2B-Instruct
 MASTER_HOSTNAME=mac-mini
 
 export PYTHONPATH="./":$PYTHONPATH;
 
-# num_hidden_layers
-# 1B  16
-# 3B  28
-# 8B  32
-# 70B 70
-python3 -m tllm.entrypoints.api_server --ip_addr $MASTER_HOSTNAME --model_path $MODEL_PATH --is_debug
\ No newline at end of file
+python3 -m tllm.entrypoints.api_server --hostname $MASTER_HOSTNAME --model_path $MODEL_PATH --is_debug
+# python3 -m tllm.entrypoints.api_server --hostname $MASTER_HOSTNAME --model_path $MODEL_PATH --is_debug --config examples/config_one.json
\ No newline at end of file
diff --git a/examples/run_single_server.sh b/examples/run_single_server.sh
index f96d0c5..606bd92 100644
--- a/examples/run_single_server.sh
+++ b/examples/run_single_server.sh
@@ -2,10 +2,9 @@
 MODEL_PATH=/Users/lujianghu/Documents/Llama-3.2-1B-Instruct
 # MODEL_PATH=Qwen/Qwen2-VL-2B-Instruct
 # MODEL_PATH=mlx-community/Meta-Llama-3.1-8B-Instruct-4bit
-MASTER_HOSTNAME=m3pro
 
 export PYTHONPATH="./":$PYTHONPATH;
 
-python3 -m tllm.entrypoints.api_server --ip_addr $MASTER_HOSTNAME --model_path $MODEL_PATH --is_local --is_debug
+python3 -m tllm.entrypoints.api_server --model_path $MODEL_PATH --is_local --is_debug
diff --git a/flux_examples/run_server.sh b/flux_examples/run_server.sh
index 95fd90e..2b8ef87 100644
--- a/flux_examples/run_server.sh
+++ b/flux_examples/run_server.sh
@@ -5,4 +5,4 @@ MASTER_HOSTNAME=mac-mini
 
 export PYTHONPATH="./":$PYTHONPATH;
 
-python3 -m tllm.entrypoints.api_server --ip_addr $MASTER_HOSTNAME --model_path $MODEL_PATH --is_debug
\ No newline at end of file
+python3 -m tllm.entrypoints.api_server --hostname $MASTER_HOSTNAME --model_path $MODEL_PATH --client_size 1 --is_debug
\ No newline at end of file
diff --git a/flux_examples/run_single_server.sh b/flux_examples/run_single_server.sh
index c5f78c3..ae6ddae 100644
--- a/flux_examples/run_single_server.sh
+++ b/flux_examples/run_single_server.sh
@@ -1,8 +1,7 @@
 #!/bin/bash
 MODEL_PATH=/Users/lujianghu/Documents/flux/schnell_4bit
-MASTER_HOSTNAME=mac-mini
 
 export PYTHONPATH="./":$PYTHONPATH;
 
-python3 -m tllm.entrypoints.api_server --ip_addr $MASTER_HOSTNAME --model_path $MODEL_PATH --is_local --is_debug --is_image
+python3 -m tllm.entrypoints.api_server --model_path $MODEL_PATH --client_size 1 --is_local --is_debug --is_image
diff --git a/tllm/entrypoints/api_server.py b/tllm/entrypoints/api_server.py
index e49b982..860f422 100644
--- a/tllm/entrypoints/api_server.py
+++ b/tllm/entrypoints/api_server.py
@@ -1,5 +1,4 @@
 import asyncio
-import json
 import os
 import time
 from typing import Union
@@ -14,7 +13,7 @@
 from tllm.entrypoints.image_server.server_image import ImageServing
 from tllm.entrypoints.protocol import ChatCompletionRequest, ChatCompletionResponse
 from tllm.entrypoints.server_chat import OpenAIServing
-from tllm.entrypoints.utils import parse_args, serve_http
+from tllm.entrypoints.utils import load_master_config, parse_master_args, serve_http
 from tllm.network.helper import get_free_port
 from tllm.network.manager import LocalRPCManager, RPCManager, WebsocketManager
 from tllm.schemas import InitModelRequest, InitModelResponse, RegisterClientRequest, RegisterClientResponse
@@ -48,7 +47,7 @@ async def get_index():
 
 @app.post("/v1/chat/completions")
 async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request) -> ChatCompletionResponse:
-    if not ws_manager.has_full_model and not is_local:
+    if not ws_manager.has_full_model and not args.is_local:
         raise ValueError("No available Full Node to process the request")
     if openai_serving_chat is None:
         raise ValueError("OpenAIServing instance is not initialized")
@@ -66,7 +65,7 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
 
 @app.post("/v1/completions")
 async def create_completion(request: ChatCompletionRequest, raw_request: Request) -> ChatCompletionResponse:
-    if not ws_manager.has_full_model and not is_local:
+    if not ws_manager.has_full_model and not args.is_local:
         raise ValueError("No available Full Node to process the request")
     if openai_serving_chat is None:
         raise ValueError("OpenAIServing instance is not initialized")
@@ -80,7 +79,7 @@ async def create_completion(request: ChatCompletionRequest, raw_request: Request
 
 @app.post("/v1/create_image")
 async def create_image(request: Text2ImageRequest, raw_request: Request) -> Text2ImageResponse:
-    if not ws_manager.has_full_model and not is_local:
+    if not ws_manager.has_full_model and not args.is_local:
         raise ValueError("No available Full Node to process the request")
     if image_serving is None:
         raise ValueError("ImageServing instance is not initialized")
@@ -142,7 +141,7 @@ async def update_model_url():
     host_list = ws_manager.set_connect_clients()
     if len(host_list) > 0:
         rpc_manager.update_url(host_list)
-        await rpc_manager.send_config(f"{args.ip_addr}:{args.grpc_port}", host_list)
+        await rpc_manager.send_config(f"{args.hostname}:{args.grpc_port}", host_list)
     # Keep running health checks in the background; if a node goes down, its layers need to be reassigned
     await rpc_manager.start_health_check()
 
@@ -174,9 +173,7 @@ async def init_model_func(
 async def init_app(engine: AsyncEngine, args):
     global app
     global logger, openai_serving_chat, image_serving
-    global is_local
     logger = SingletonLogger.setup_master_logger()
-    is_local = args.is_local
     logger.info("args: %s", args)
 
     if args.is_image:
@@ -196,7 +193,7 @@ async def init_engine(args):
 
     global ws_manager, rpc_manager
 
-    ws_manager = WebsocketManager(total_layers, args.model_path)
+    ws_manager = WebsocketManager(total_layers, args.model_path, client_size=args.client_size)
     rpc_manager, master_handler = await init_rpc_manager(
         args.model_path, ws_manager.client_size, args.grpc_port, args.is_local
     )
@@ -223,11 +220,7 @@ async def run_server(args) -> None:
         args.grpc_port = get_free_port()
 
     if args.config:
-        with open(args.config, "r") as f:
-            config = json.load(f)
-        args.ip_addr = config["server"]["ip_addr"]
-        args.http_port = config["server"]["http_port"]
-        args.grpc_port = config["server"]["grpc_port"]
+        args = load_master_config(args.config, args)
 
     engine = await init_engine(args)
     app = await init_app(engine, args)
@@ -238,5 +231,5 @@ async def run_server(args) -> None:
 
 
 if __name__ == "__main__":
-    args = parse_args()
+    args = parse_master_args()
     asyncio.run(run_server(args))
diff --git a/tllm/entrypoints/handler/handler.py b/tllm/entrypoints/handler/handler.py
index 6da9a38..eb1f252 100644
--- a/tllm/entrypoints/handler/handler.py
+++ b/tllm/entrypoints/handler/handler.py
@@ -1,8 +1,5 @@
-import argparse
 import asyncio
 from concurrent import futures
-import json
-import logging
 import time
 from typing import List
 import uuid
@@ -12,6 +9,7 @@
 from tllm import GRPC_OPTIONS
 from tllm.commons.communicator import BaseCommunicator, Communicator
 from tllm.commons.convert import Convertor
+from tllm.entrypoints.utils import load_handler_config, parse_handler_args
 from tllm.network.helper import get_free_port, get_ips
 from tllm.network.http_client import HTTPClient
 from tllm.network.manager import MasterRPCManager
@@ -168,44 +166,31 @@ async def Health(self, request, context):
         return schemas_pb2.HealthResponse(msg="Healthy", status=200)
 
 
-def parse_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--port", type=int, default=None, help="port for the gRPC service")
-    parser.add_argument(
-        "--master_addr", type=str, required=True, help="HTTP address of the master, e.g. http://192.168.x.y:8022"
-    )
-    parser.add_argument("--ip_addr", type=str, default=None, help="IP address given to the master to connect to, e.g. 192.168.x.y")
-    parser.add_argument("--is_debug", action="store_true")
-    parser.add_argument("--config", type=str, default=None, help="config file path")
-    return parser.parse_args()
-
-
 async def run(args):
     comm = Communicator()
-    if args.port is None:
-        args.port = get_free_port()
+    if args.grpc_port is None:
+        args.grpc_port = get_free_port()
 
     if args.config is not None:
-        with open(args.config, "r") as f:
-            config = json.load(f)
-        # TODO
-        args.port = config["client"][0]["grpc"]
-        args.master_addr = config["client"][0]["master_addr"]
-        args.ip_addr = config["client"][0]["ip_addr"]
+        if args.client_idx is None:
+            raise ValueError("client_idx is required when config is provided")
+        args = load_handler_config(args.config, args, args.client_idx)
 
     ip_addr_list = get_ips()
 
-    if args.ip_addr and isinstance(args.ip_addr, str):
-        ip_addr_list = [args.ip_addr]
+    # If a hostname is specified, use only that hostname
+    if args.hostname is not None and isinstance(args.hostname, str):
+        ip_addr_list = [args.hostname]
+
     if len(ip_addr_list) == 0:
         raise ValueError("No available ip address")
 
     SingletonLogger.set_level("DEBUG" if args.is_debug else "INFO")
-    logger = SingletonLogger.setup_handler_logger(f"handler-{args.port}")
+    logger = SingletonLogger.setup_handler_logger(f"handler-{args.grpc_port}")
     rpc_servicer = RPCHandler(comm, logger, args.master_addr)
     try:
         if comm.rank == 0:
-            await rpc_servicer.start(ip_addr_list, args.port)
+            await rpc_servicer.start(ip_addr_list, args.grpc_port)
     except Exception as e:
         await rpc_servicer.stop()
         logger.error(f"Error occurred: {str(e)}")
@@ -215,5 +200,5 @@ async def run(args):
 
 
 if __name__ == "__main__":
-    args = parse_args()
+    args = parse_handler_args()
     asyncio.run(run(args))
diff --git a/tllm/entrypoints/utils.py b/tllm/entrypoints/utils.py
index 86a5a4c..4d97ab8 100644
--- a/tllm/entrypoints/utils.py
+++ b/tllm/entrypoints/utils.py
@@ -1,5 +1,6 @@
 import argparse
 import asyncio
+import json
 import signal
 from typing import Dict
 
@@ -9,19 +10,57 @@
 from tllm.singleton_logger import SingletonLogger
 
 
-def parse_args():
+def parse_master_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--ip_addr", type=str, required=True)
+    parser.add_argument("--hostname", type=str, required=True)
     parser.add_argument("--model_path", type=str, required=True)
     parser.add_argument("--grpc_port", type=int, default=None)
    parser.add_argument("--http_port", type=int, default=8022)
     parser.add_argument("--config", type=str, default=None, help="config file path")
+    parser.add_argument(
+        "--client_size",
+        type=int,
+        default=None,
+        help="number of clients; if not provided, it is derived automatically from the model size parsed from the model path",
+    )
     parser.add_argument("--is_local", action="store_true")
     parser.add_argument("--is_debug", action="store_true")
     parser.add_argument("--is_image", action="store_true")
     return parser.parse_args()
 
 
+def parse_handler_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--grpc_port", type=int, default=None, help="port for the gRPC service")
+    parser.add_argument(
+        "--master_addr", type=str, required=True, help="HTTP address of the master, e.g. http://192.168.x.y:8022"
+    )
+    parser.add_argument("--hostname", type=str, default=None, help="IP address given to the master to connect to, e.g. 192.168.x.y")
+    parser.add_argument("--is_debug", action="store_true")
+    parser.add_argument("--config", type=str, default=None, help="config file path")
+    parser.add_argument("--client_idx", type=int, default=None, help="the client index in the config file")
+    return parser.parse_args()
+
+
+def load_master_config(config_path: str, args):
+    with open(config_path, "r") as f:
+        config = json.load(f)
+    args.hostname = config["server"]["hostname"]
+    args.http_port = config["server"]["http_port"]
+    args.grpc_port = config["server"]["grpc_port"]
+    args.client_size = len(config["client"])
+    return args
+
+
+def load_handler_config(config_path: str, args, idx: int):
+    with open(config_path, "r") as f:
+        config = json.load(f)
+    args.grpc_port = config["client"][idx]["grpc_port"]
+    args.hostname = config["client"][idx]["hostname"]
+    args.master_addr = f'http://{config["server"]["hostname"]}:{config["server"]["http_port"]}'
+    return args
+
+
 async def serve_http(app: FastAPI, **uvicorn_kwargs: Dict):
     logger = SingletonLogger.setup_master_logger()
 
diff --git a/tllm/models/file_helper.py b/tllm/models/file_helper.py
index e6752c6..598fe4a 100644
--- a/tllm/models/file_helper.py
+++ b/tllm/models/file_helper.py
@@ -91,27 +91,29 @@ def parse_model_size(model_name: str) -> float:
             break
         except:
             pass
-    if model_size == -1:
-        print("Model size not found in model name. Defaulting to 1B")
-        return 1.0
-    # assert model_size > 0, f"Invalid model name: {model_name}"
+    assert model_size > 0, f"Invalid model name: {model_name}"
     return model_size
 
 
-def split_model_layers(model_size: float, total_layers: int) -> Tuple[int, List[Tuple[int, int]]]:
-    # Decide the number of clients and each client's layer range from the model size and layer count
+def auto_set_client_size(model_size: float) -> int:
+    # num_hidden_layers
+    # 1B  16
+    # 3B  28
+    # 8B  32
+    # 70B 70
     if model_size < 4:
-        client_size = 1
+        return 1
     elif model_size <= 8:
-        client_size = 2
+        return 2
     elif model_size <= 32:
-        client_size = 4
+        return 4
     elif model_size <= 72:
-        client_size = 8
+        return 8
     else:
         raise ValueError(f"Model size {model_size} is too large")
 
+
+def split_model_layers(total_layers: int, client_size: int) -> List[Tuple[int, int]]:
+    # Split the total layers into contiguous per-client ranges
     each_client_layers = total_layers // client_size
-    return client_size, [
-        (start_idx, start_idx + each_client_layers) for start_idx in range(0, total_layers, each_client_layers)
-    ]
+    return [(start_idx, start_idx + each_client_layers) for start_idx in range(0, total_layers, each_client_layers)]
diff --git a/tllm/network/manager/websocket_manager.py b/tllm/network/manager/websocket_manager.py
index beb744f..76991f2 100644
--- a/tllm/network/manager/websocket_manager.py
+++ b/tllm/network/manager/websocket_manager.py
@@ -1,25 +1,27 @@
 import random
-from typing import Dict, List, Set, Tuple
+from typing import Dict, List, Optional, Set, Tuple
 
 from fastapi import WebSocket
 
-from tllm.models.file_helper import parse_model_size, split_model_layers
+from tllm.models.file_helper import auto_set_client_size, parse_model_size, split_model_layers
 from tllm.network.helper import find_continuous_path, tcp_ping_test
 from tllm.schemas import ClientData, InitModelRequest, InitModelResponse, RegisterClientRequest, RegisterClientResponse
 
 
 class WebsocketManager:
-    def __init__(self, total_layers: int, model_name: str, skip_parse: bool = False):
+    def __init__(self, total_layers: int, model_name: str, client_size: Optional[int] = None):
         self.total_layers = total_layers
         self.model_name = model_name
         self.clients: Dict[str, ClientData] = {}  # connected clients, client_id -> ClientData
         self.monitor_websockets: Set[WebSocket] = set()  # WebSocket connections from the monitoring frontend
         self.connect_clients = []
-        if skip_parse:
-            self.client_size, self.layer_info = split_model_layers(1, total_layers)
+        if client_size is None:
+            model_size = parse_model_size(model_name)
+            self.client_size = auto_set_client_size(model_size)
         else:
-            self.client_size, self.layer_info = split_model_layers(parse_model_size(model_name), total_layers)
+            self.client_size = client_size
+        self.layer_info = split_model_layers(total_layers, self.client_size)
         self.client_info = [[start_idx, end_idx, 0] for start_idx, end_idx in self.layer_info]  # connection stats per layer range
 
     def get_free_layer(self) -> Tuple[int, int, int]: