diff --git a/RoadMap.md b/RoadMap.md
index 3af28db..b03306e 100644
--- a/RoadMap.md
+++ b/RoadMap.md
@@ -45,7 +45,7 @@
 - [ ] Auto Layer Split
   - [x] get free layer idx
   - [x] fix split layer pipeline
-  - [ ] calculate layer memory and recommend split
+  - [x] calculate layer memory and recommend split
   - [ ] split model before load
 - [x] Async Generation
   - [x] Multi-Sequence Batch=1
diff --git a/examples/run_engine.py b/examples/run_engine.py
index d9a13f1..8e22feb 100644
--- a/examples/run_engine.py
+++ b/examples/run_engine.py
@@ -37,40 +37,28 @@ class Args:
     is_debug: bool = False
 
 
-def init_engine(model_path):
-    model, tok = load_master_model(model_path)
+def init_engine(model_path: str) -> AsyncEngine:
+    model = load_master_model(model_path)
     rpc_manager = LocalRPCManager(model_path)
-    generator = LLMGenerator(rpc_manager, model, tok)
+    generator = LLMGenerator(rpc_manager, model)
     engine = AsyncEngine(generator)
     return engine
 
 
-def init_image_engine(model_path):
-    model, tok = load_master_model(model_path)
+def init_image_engine(model_path: str) -> AsyncEngine:
+    model = load_master_model(model_path)
     rpc_manager = LocalRPCManager(model_path)
-    generator = ImageGenerator(rpc_manager, model, tok)
+    generator = ImageGenerator(rpc_manager, model)
     engine = AsyncEngine(generator)
     return engine
 
 
-async def llm_generate():
-    args = Args()
-
-    engine = init_engine(args.model_path)
-    await engine.start()
+def llm_message():
     messages = [{"role": "user", "content": "Hello, how are you?"}]
-    openai_serving_chat = OpenAIServing(engine, args)
+    return messages
 
-    request = ChatCompletionRequest(model="test", messages=messages)
-    response = await openai_serving_chat.create_chat_completion(request, None)
-    print(response)
-
-
-async def mllm_generate():
-    args = Args()
-    engine = init_engine(args.model_path)
-    await engine.start()
+
+def mllm_message():
     messages = [
         {
             "role": "user",
@@ -80,6 +68,13 @@ async def mllm_generate():
             ],
         }
     ]
+    return messages
+
+
+async def llm_generate(args, messages):
+    engine = init_engine(args.model_path)
+    await engine.start()
+    messages = [{"role": "user", "content": "Hello, how are you?"}]
 
     openai_serving_chat = OpenAIServing(engine, args)
     request = ChatCompletionRequest(model="test", messages=messages)
@@ -87,10 +82,7 @@ async def mllm_generate():
     print(response)
 
 
-async def image_generate():
-    args = Args()
-
-    prompt = "a little dog"
+async def image_generate(args):
     prompt = "germanic romanticism painting of an obscure winter forest in a geocore landscape. Ambient landscape lighting, heavy shading, crystal night sky, stunning stars, topography"
     config = {
         "num_inference_steps": 3,
@@ -99,7 +91,7 @@ async def image_generate():
     }
 
     engine = init_image_engine(args.model_path)
-    _ = await engine.start()
+    await engine.start()
 
     image_serving = ImageServing(engine, args)
 
@@ -110,6 +102,7 @@ async def image_generate():
 
 
 if __name__ == "__main__":
-    asyncio.run(llm_generate())
-    # asyncio.run(mllm_generate())
-    # asyncio.run(image_generate())
+    args = Args()
+    asyncio.run(llm_generate(args, llm_message()))
+    # asyncio.run(llm_generate(args, mllm_message()))
+    # asyncio.run(image_generate(args))
diff --git a/tllm/commons/manager.py b/tllm/commons/manager.py
index 3e047f7..f9661a5 100644
--- a/tllm/commons/manager.py
+++ b/tllm/commons/manager.py
@@ -1,12 +1,11 @@
 import os
 import time
-from typing import List, Optional, Tuple
+from typing import Any, List
 
 from transformers import AutoConfig
 
 from tllm import BACKEND, BackendEnum
 from tllm.commons.communicator import BaseCommunicator
-from tllm.generate import LLMGenerator, TokenizerUtils
 from tllm.models.file_helper import find_weight_file, get_model_path
 from tllm.models.register import MODEL_REGISTER
 from tllm.models.weight_helper import load_gguf_weight, read_from_safetensors, tie_embedding_weights
@@ -65,6 +64,8 @@ def __init__(self, weights):
         return TransformerWeightHandler(weights)
 
     def _post_init(self):
+        from tllm.generate import TokenizerUtils
+
         if str(self.model_path).endswith(".gguf"):
             raise NotImplementedError("GGUF model not supported")
             # state_dict, config, _ = load_gguf_weight(str(self.model_path))
@@ -149,7 +150,7 @@ def _hf_read_client_weight(self, start_idx: int, end_idx: int):
         return state_dict
 
 
-def load_client_model(start_idx: int, end_idx: int, comm: BaseCommunicator, model_path: str):
+def load_client_model(start_idx: int, end_idx: int, comm: BaseCommunicator, model_path: str) -> Any:
     weight_manager = WeightManager(model_path)
 
     config = weight_manager.config
@@ -175,7 +176,7 @@ def load_client_model(start_idx: int, end_idx: int, comm: BaseCommunicator, mode
     return model
 
 
-def load_master_model(model_path: str) -> Tuple[LLMGenerator, TokenizerUtils]:
+def load_master_model(model_path: str) -> Any:
     weight_manager = WeightManager(model_path)
     state_dict = weight_manager.read_master_weight()
     if weight_manager.arch not in MODEL_REGISTER:
@@ -190,4 +191,5 @@ def load_master_model(model_path: str) -> Tuple[LLMGenerator, TokenizerUtils]:
         kwargs.update({"quantization_level": weight_manager.config.quantization_level})
 
     model = MY_CausalLM_CLASS.from_pretrained(weight_manager.config, state_dict, **kwargs)
-    return model, weight_manager.tok
+    model.tok = weight_manager.tok
+    return model
diff --git a/tllm/entrypoints/api_server.py b/tllm/entrypoints/api_server.py
index 860f422..8235969 100644
--- a/tllm/entrypoints/api_server.py
+++ b/tllm/entrypoints/api_server.py
@@ -13,8 +13,8 @@
 from tllm.entrypoints.image_server.server_image import ImageServing
 from tllm.entrypoints.protocol import ChatCompletionRequest, ChatCompletionResponse
 from tllm.entrypoints.server_chat import OpenAIServing
-from tllm.entrypoints.utils import load_master_config, parse_master_args, serve_http
-from tllm.network.helper import get_free_port
+from tllm.entrypoints.utils import parse_master_args, serve_http, update_master_args
+from tllm.generate import ImageGenerator, LLMGenerator
 from tllm.network.manager import LocalRPCManager, RPCManager, WebsocketManager
 from tllm.schemas import InitModelRequest, InitModelResponse, RegisterClientRequest, RegisterClientResponse
 from tllm.singleton_logger import SingletonLogger
@@ -188,7 +188,7 @@ async def init_engine(args):
     logger = SingletonLogger.setup_master_logger()
     s1 = time.time()
-    model, tok = load_master_model(args.model_path)
+    model = load_master_model(args.model_path)
     total_layers = model.num_layers  # the number of layers is required
 
     global ws_manager, rpc_manager
 
@@ -197,17 +197,12 @@ async def init_engine(args):
     rpc_manager, master_handler = await init_rpc_manager(
         args.model_path, ws_manager.client_size, args.grpc_port, args.is_local
     )
-
+    logger.info(f"Engine init Cost Time: {time.time() - s1:.4f}s. Total Layers: {total_layers}")
     if args.is_image:
-        from tllm.generate import ImageGenerator
-
-        generator = ImageGenerator(rpc_manager, model, tok)
+        generator = ImageGenerator(rpc_manager, model)
     else:
-        from tllm.generate import LLMGenerator
-
-        generator = LLMGenerator(rpc_manager, model, tok)
+        generator = LLMGenerator(rpc_manager, model)
     engine = AsyncEngine(generator)
-    logger.info(f"Engine init Cost Time: {time.time() - s1:.4f}s. Total Layers: {total_layers}")
     await engine.start()
     return engine
 
@@ -215,12 +210,7 @@ async def init_engine(args):
 
 async def run_server(args) -> None:
     SingletonLogger.set_level("DEBUG" if args.is_debug else "INFO")
-
-    if args.grpc_port is None:
-        args.grpc_port = get_free_port()
-
-    if args.config:
-        args = load_master_config(args.config, args)
+    args = update_master_args(args)
 
     engine = await init_engine(args)
     app = await init_app(engine, args)
diff --git a/tllm/entrypoints/handler/handler.py b/tllm/entrypoints/handler/handler.py
index eb1f252..58a70ee 100644
--- a/tllm/entrypoints/handler/handler.py
+++ b/tllm/entrypoints/handler/handler.py
@@ -9,8 +9,7 @@
 from tllm import GRPC_OPTIONS
 from tllm.commons.communicator import BaseCommunicator, Communicator
 from tllm.commons.convert import Convertor
-from tllm.entrypoints.utils import load_handler_config, parse_handler_args
-from tllm.network.helper import get_free_port, get_ips
+from tllm.entrypoints.utils import parse_handler_args, update_handler_args
 from tllm.network.http_client import HTTPClient
 from tllm.network.manager import MasterRPCManager
 from tllm.rpc import schemas_pb2, schemas_pb2_grpc
@@ -112,10 +111,10 @@ async def Forward(
         """
         @param request: ForwardRequest
             hidden_states: bytes
-            uuid: str
-            seq_len: int
+            uuid: List[str]
+            seq_len: List[int]
         """
-        if not hasattr(self.http_client, "model") and self.http_client is None:
+        if hasattr(self.http_client, "model") is None:
             return schemas_pb2.ForwardResponse(msg="Model not initialized", status=400)
         if hasattr(self.manager, "master_stub") is None:
             return schemas_pb2.ForwardResponse(msg="Manager not initialized", status=400)
@@ -167,27 +166,13 @@ async def Health(self, request, context):
 
 
 async def run(args):
+    SingletonLogger.set_level("DEBUG" if args.is_debug else "INFO")
+    args, ip_addr_list = update_handler_args(args)
     comm = Communicator()
-    if args.grpc_port is None:
-        args.grpc_port = get_free_port()
-    if args.config is not None:
-        if args.client_idx is None:
-            raise ValueError("client_idx is required when config is provided")
-        args = load_handler_config(args.config, args, args.client_idx)
-
-    ip_addr_list = get_ips()
-    # if a hostname is specified, only use that hostname
-    if args.hostname is not None and isinstance(args.hostname, str):
-        ip_addr_list = [args.hostname]
-    if len(ip_addr_list) == 0:
-        raise ValueError("No available ip address")
-
-    SingletonLogger.set_level("DEBUG" if args.is_debug else "INFO")
     logger = SingletonLogger.setup_handler_logger(f"handler-{args.grpc_port}")
     rpc_servicer = RPCHandler(comm, logger, args.master_addr)
-
     try:
         if comm.rank == 0:
             await rpc_servicer.start(ip_addr_list, args.grpc_port)
diff --git a/tllm/entrypoints/handler/master_handler.py b/tllm/entrypoints/handler/master_handler.py
index f112ce8..927a0cc 100644
--- a/tllm/entrypoints/handler/master_handler.py
+++ b/tllm/entrypoints/handler/master_handler.py
@@ -27,8 +27,6 @@ def update(self, count: int, result: Tuple[int, float]):
 
 
 class PendingRequests:
-    """Manages pending requests"""
-
     def __init__(self):
         self._forward_requests: Dict[str, asyncio.Future] = {}
         self._status_requests: Dict[str, StatusTracker] = {}
@@ -97,7 +95,6 @@ async def Forward(
     ) -> schemas_pb2.ForwardResponse:
         """Handle the result returned from the last node"""
         request_id = "-".join(x for x in list(request.uuid))
-        # self.logger.debug(f"Received result request id: {request_id}")
 
         try:
             self.pending_requests.complete_forward_request(request_id, request.hidden_states)
@@ -113,7 +110,6 @@ async def ImageForward(
     ) -> schemas_pb2.ForwardResponse:
         """Handle the result returned from the last node"""
         request_id = "-".join(x for x in list(request.uuid))
-        # self.logger.debug(f"Received result request id: {request_id}")
 
         try:
             self.pending_requests.complete_forward_request(request_id, request.hidden_states)
diff --git a/tllm/entrypoints/utils.py b/tllm/entrypoints/utils.py
index 4d97ab8..95910b3 100644
--- a/tllm/entrypoints/utils.py
+++ b/tllm/entrypoints/utils.py
@@ -7,13 +7,14 @@
 from fastapi import FastAPI
 import uvicorn
 
+from tllm.network.helper import get_free_port, get_ips
 from tllm.singleton_logger import SingletonLogger
 
 
 def parse_master_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--hostname", type=str, required=True)
     parser.add_argument("--model_path", type=str, required=True)
+    parser.add_argument("--hostname", type=str, required=False)
     parser.add_argument("--grpc_port", type=int, default=None)
     parser.add_argument("--http_port", type=int, default=8022)
     parser.add_argument("--config", type=str, default=None, help="config file path")
@@ -42,23 +43,43 @@ def parse_handler_args():
     return parser.parse_args()
 
 
-def load_master_config(config_path: str, args):
-    with open(config_path, "r") as f:
-        config = json.load(f)
-    args.hostname = config["server"]["hostname"]
-    args.http_port = config["server"]["http_port"]
-    args.grpc_port = config["server"]["grpc_port"]
-    args.client_size = len(config["client"])
+def update_master_args(args):
+    if args.grpc_port is None:
+        args.grpc_port = get_free_port()
+
+    if args.config is not None:
+        with open(args.config, "r") as f:
+            config = json.load(f)
+        args.hostname = config["server"]["hostname"]
+        args.http_port = config["server"]["http_port"]
+        args.grpc_port = config["server"]["grpc_port"]
+        args.client_size = len(config["client"])
     return args
 
 
-def load_handler_config(config_path: str, args, idx: int):
-    with open(config_path, "r") as f:
-        config = json.load(f)
-    args.grpc_port = config["client"][idx]["grpc_port"]
-    args.hostname = config["client"][idx]["hostname"]
-    args.master_addr = f'http://{config["server"]["hostname"]}:{config["server"]["http_port"]}'
-    return args
+def update_handler_args(args):
+    if args.grpc_port is None:
+        args.grpc_port = get_free_port()
+
+    if args.config is not None:
+        if args.client_idx is None:
+            raise ValueError("client_idx is required when config is provided")
+        with open(args.config_path, "r") as f:
+            config = json.load(f)
+        args.grpc_port = config["client"][args.client_idx]["grpc_port"]
+        args.hostname = config["client"][args.client_idx]["hostname"]
+        args.master_addr = f'http://{config["server"]["hostname"]}:{config["server"]["http_port"]}'
+
+    # if a hostname is specified, only use that hostname
+    if args.hostname is not None and isinstance(args.hostname, str):
+        ip_addr_list = [args.hostname]
+    else:
+        ip_addr_list = get_ips()
+
+    if len(ip_addr_list) == 0:
+        raise ValueError("No available ip address")
+
+    return args, ip_addr_list
 
 
 async def serve_http(app: FastAPI, **uvicorn_kwargs: Dict):
diff --git a/tllm/generate/image_generator.py b/tllm/generate/image_generator.py
index e24bd1d..6149e8b 100644
--- a/tllm/generate/image_generator.py
+++ b/tllm/generate/image_generator.py
@@ -2,16 +2,17 @@
 from typing import List
 
 from tllm.img_helper import pil_image_to_base64
+from tllm.network.manager.rpc_manager import RPCManager
 from tllm.schemas import ForwardResult, ImageRequestData
 from tllm.singleton_logger import SingletonLogger
 
 
 class ImageGenerator:
-    def __init__(self, manager: "RPCManager", model, tok=None) -> None:
+    def __init__(self, manager: RPCManager, model) -> None:
         self.manager = manager
         self.model = model
         self.logger = SingletonLogger.setup_master_logger()
-        self.tok = tok
+        self.tok = None
 
     async def forward(self, image_request: ImageRequestData) -> ForwardResult:
         height, width = image_request.runtime_config.height, image_request.runtime_config.width
diff --git a/tllm/generate/llm_generator.py b/tllm/generate/llm_generator.py
index b7837a5..1f9448d 100644
--- a/tllm/generate/llm_generator.py
+++ b/tllm/generate/llm_generator.py
@@ -4,8 +4,10 @@
 import numpy as np
 from transformers import AutoImageProcessor, AutoProcessor
 
+from tllm.generate.token_utils import TokenizerUtils
 from tllm.models.register import sampling_func
 from tllm.models.utils import is_generate_end
+from tllm.network.manager import RPCManager
 from tllm.schemas import MIX_TENSOR, ForwardResult, SeqInput, SequenceRequestData
 from tllm.singleton_logger import SingletonLogger
 
@@ -77,11 +79,11 @@ def process_mm_input(
 
 
 class LLMGenerator:
-    def __init__(self, manager: "RPCManager", model, tok: "TokenizerUtils") -> None:
+    def __init__(self, manager: RPCManager, model) -> None:
         self.manager = manager
         self.logger = SingletonLogger.setup_master_logger()
         self.model = model
-        self.tok = tok
+        self.tok: TokenizerUtils = model.tok
         self.processor = getattr(model, "processor", None)
         self.mm_config = getattr(model, "mm_config", None)
         if self.processor is not None:
diff --git a/tllm/models/register.py b/tllm/models/register.py
index 7387ccf..12a7b65 100644
--- a/tllm/models/register.py
+++ b/tllm/models/register.py
@@ -38,7 +38,8 @@
     from tllm.models.torch.qwen import HFQwen2ForCausalLM, HFQwen2Model
     from tllm.models.torch.qwen_vl import HFQwen2VLForConditionalGeneration
 
-    sampling_func = greedy_decode
     MODEL_REGISTER.update({"LlamaForCausalLM": (HFLlamaForCausalLM, HFLlamaModel)})
     MODEL_REGISTER.update({"Qwen2ForCausalLM": (HFQwen2ForCausalLM, HFQwen2Model)})
     MODEL_REGISTER.update({"Qwen2VLForConditionalGeneration": (HFQwen2VLForConditionalGeneration, HFQwen2Model)})
+
+    sampling_func = greedy_decode
diff --git a/tllm/network/manager/master_manager.py b/tllm/network/manager/master_manager.py
index 0fbf2d4..30bb77e 100644
--- a/tllm/network/manager/master_manager.py
+++ b/tllm/network/manager/master_manager.py
@@ -8,6 +8,9 @@
 
 class MasterRPCManager:
     # Send gRPC requests to the Master
+    # Each node must send both its "result" and its "status" after processing
+    # The "result" goes to the next PP stage; the last PP stage sends it to the Master node
+    # The "status" always goes to the Master node
     def __init__(self, grpc_options: List[Tuple[str, int]]):
         self.grpc_options = grpc_options
         self.master_stub = None
@@ -36,7 +39,6 @@ async def rpc_image_func(
             "uuid": request.uuid,
             "hidden_states": hidden_states,
             # "text_embeddings": request.text_embeddings,
-            # "image_rotary_emb": request.image_rotary_emb,
         }
         status_request = {"uuid": request.uuid, "pp_idx": self.pp_idx, "cost_time": cost_time}
         await self.master_stub.Status(schemas_pb2.StatusRequest(**status_request))
diff --git a/tllm/network/manager/rpc_manager.py b/tllm/network/manager/rpc_manager.py
index f648402..59387fb 100644
--- a/tllm/network/manager/rpc_manager.py
+++ b/tllm/network/manager/rpc_manager.py
@@ -81,7 +81,6 @@ async def check_single_client(index: int) -> Tuple[int, bool]:
 
         tasks = [check_single_client(i) for i in range(self.client_size)]
 
-        # wait for all tasks to finish and return the list of results
        results = await asyncio.gather(*tasks, return_exceptions=False)
 
         # check the results; if any health check failed, return its index
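
Usage note: the net effect of this refactor is that load_master_model returns only the model (the tokenizer now rides along as model.tok) and the generators are constructed from (rpc_manager, model) alone. Below is a minimal, illustrative sketch of that master-side flow, mirroring the updated examples/run_engine.py; the AsyncEngine import path is an assumption and not taken from this diff, while the other imports appear in the hunks above.

# Minimal sketch of the refactored master-side flow (illustrative only).
import asyncio

from tllm.commons.manager import load_master_model
from tllm.engine import AsyncEngine  # assumed module path
from tllm.generate import LLMGenerator
from tllm.network.manager import LocalRPCManager


async def main(model_path: str) -> None:
    model = load_master_model(model_path)  # tokenizer is attached as model.tok
    generator = LLMGenerator(LocalRPCManager(model_path), model)  # no separate tok argument
    engine = AsyncEngine(generator)
    await engine.start()


if __name__ == "__main__":
    asyncio.run(main("/path/to/model"))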