diff --git a/README.md b/README.md
index 7d1246e..2e1a99f 100644
--- a/README.md
+++ b/README.md
@@ -14,7 +14,7 @@
 2.1 (localhost)
 
 ```bash
-tllm.server --model_path mlx-community/Llama-3.2-1B-Instruct-4bit --hostname localhost --is_local --client_size 1
+tllm.server --model_path mlx-community/Llama-3.2-1B-Instruct-4bit
 ```
 
 2.2 (multi clients)
@@ -27,7 +27,6 @@
 tllm.client --hostname http://$YOUR_IP:8022
 ```
 
-
 3. testing
 
 ```bash
diff --git a/tllm/entrypoints/api_server.py b/tllm/entrypoints/api_server.py
index 1929884..2ff0a71 100644
--- a/tllm/entrypoints/api_server.py
+++ b/tllm/entrypoints/api_server.py
@@ -15,7 +15,7 @@
 from tllm.entrypoints.image_server.server_image import ImageServing
 from tllm.entrypoints.protocol import ChatCompletionRequest, ChatCompletionResponse
 from tllm.entrypoints.server_chat import OpenAIServing
-from tllm.entrypoints.utils import GRPCProcess, parse_master_args, serve_http, update_master_args
+from tllm.entrypoints.utils import GRPCProcess, is_local, parse_master_args, serve_http, update_master_args
 from tllm.entrypoints.websocket_manager import WebsocketManager
 from tllm.generate import ImageGenerator, LLMGenerator
 from tllm.grpc.master_service.worker_manager import WorkerRPCManager
@@ -151,7 +151,7 @@ async def update_model_url():
         for clients in host_list
     ]
     worker_rpc_manager.update_url(host_list)
-    master_url = args.hostname if args.is_local else f"{args.hostname}:{args.grpc_port}"
+    master_url = args.hostname if is_local(args.hostname) else f"{args.hostname}:{args.grpc_port}"
     await worker_rpc_manager.send_config(master_url, host_list)
     # Keep running health checks in the background; if a node goes down, its work needs to be reassigned
     await worker_rpc_manager.start_health_check()
@@ -225,7 +225,7 @@ async def run_server(args) -> None:
 
     uvicorn_kwargs = {"host": ["::", "0.0.0.0"], "port": args.http_port, "timeout_graceful_shutdown": 5}
 
-    if args.is_local:
+    if is_local(args.hostname):
         if os.path.isfile(MASTER_SOCKET_PATH):
             os.remove(MASTER_SOCKET_PATH)
         if os.path.isfile(CLIENT_SOCKET_PATH):
diff --git a/tllm/entrypoints/utils.py b/tllm/entrypoints/utils.py
index 23f0e92..faea790 100644
--- a/tllm/entrypoints/utils.py
+++ b/tllm/entrypoints/utils.py
@@ -13,6 +13,10 @@
 from tllm.singleton_logger import SingletonLogger
 
 
+def is_local(hostname: str) -> bool:
+    return hostname == "localhost"
+
+
 def parse_master_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -21,7 +25,7 @@
         required=True,
         help="Specify the path of the model file or huggingface repo. Like mlx-community/Llama-3.2-1B-Instruct-bf16",
     )
-    parser.add_argument("--hostname", type=str, help="The address of the client connection.")
+    parser.add_argument("--hostname", type=str, default="localhost", help="The address of the client connection.")
     parser.add_argument(
         "--grpc_port",
         type=int,
@@ -43,14 +47,9 @@
     parser.add_argument(
         "--client_size",
         type=int,
-        default=None,
+        default=1,
         help="The number of clients. If this parameter is not provided, the program will try to parse and automatically calculate the number from the model path.",
     )
-    parser.add_argument(
-        "--is_local",
-        action="store_true",
-        help="A boolean flag. If this parameter is specified in the command line, indicates that the model runs locally only",
-    )
     parser.add_argument(
         "--is_debug",
         action="store_true",
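
For reference, a minimal sketch of the behavior this patch introduces: the `--is_local` flag is removed, locality is now inferred from the hostname via `is_local()` (a plain check for `"localhost"`), and the new defaults `--hostname localhost` / `--client_size 1` reproduce the old single-machine command. The `build_master_url` wrapper and the port value below are illustrative assumptions, not part of the codebase.

```python
def is_local(hostname: str) -> bool:
    # Mirrors the helper added to tllm/entrypoints/utils.py:
    # locality is inferred from the hostname instead of a CLI flag.
    return hostname == "localhost"


def build_master_url(hostname: str, grpc_port: int) -> str:
    # Hypothetical wrapper around the logic in api_server.py's update_model_url():
    # local runs keep the bare hostname, remote runs append the gRPC port.
    return hostname if is_local(hostname) else f"{hostname}:{grpc_port}"


if __name__ == "__main__":
    # Port value is arbitrary, for demonstration only.
    print(build_master_url("localhost", 25001))     # -> "localhost"
    print(build_master_url("192.168.1.10", 25001))  # -> "192.168.1.10:25001"
```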