Skip to content

Commit

Permalink
split init engine and init app
Browse files Browse the repository at this point in the history
  • Loading branch information
wnma3mz committed Dec 21, 2024
1 parent da216c1 commit e5f169a
Show file tree
Hide file tree
Showing 23 changed files with 747 additions and 272 deletions.
40 changes: 18 additions & 22 deletions examples/run_engine.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
import asyncio
from dataclasses import dataclass
import logging
import os

from PIL import Image

Expand All @@ -13,8 +13,6 @@ def parse_args():


args = parse_args()
import os

os.environ["TLLM_BACKEND"] = args.backend.upper()

from tllm.commons.manager import load_master_model
Expand All @@ -26,42 +24,42 @@ def parse_args():
from tllm.generate import ImageGenerator, LLMGenerator
from tllm.img_helper import base64_to_pil_image
from tllm.network.manager import LocalRPCManager
from tllm.utils import setup_logger


@dataclass
class Args:
model_path: str = "/Users/lujianghu/Documents/Llama-3.2-3B-Instruct"
# model_path: str = "/Users/lujianghu/Documents/Llama-3.2-3B-Instruct"
# model_path: str = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
model_path: str = "mlx-community/Llama-3.2-1B-Instruct-4bit"
# model_path: str = "/Users/lujianghu/Documents/flux/schnell_4bit"
# model_path: str = "Qwen/Qwen2.5-0.5B-Instruct"
# model_path: str = "Qwen/Qwen2-VL-2B-Instruct"
is_debug: bool = False


def init_engine(model_path, logger):
def init_engine(model_path):
model, tok = load_master_model(model_path)
rpc_manager = LocalRPCManager(model_path)
generator = LLMGenerator(rpc_manager, logger, model, tok)
engine = AsyncEngine(logger, generator)
return engine, tok
generator = LLMGenerator(rpc_manager, model, tok)
engine = AsyncEngine(generator)
return engine


def init_image_engine(model_path, logger):
def init_image_engine(model_path):
model, tok = load_master_model(model_path)
rpc_manager = LocalRPCManager(model_path)
generator = ImageGenerator(rpc_manager, logger, model, tok)
engine = AsyncEngine(logger, generator)
generator = ImageGenerator(rpc_manager, model, tok)
engine = AsyncEngine(generator)
return engine


async def llm_generate():
args = Args()

logger = setup_logger("engine", logging.DEBUG if args.is_debug else logging.INFO)
engine, tok = init_engine(args.model_path, logger)
_ = await engine.start()
engine = init_engine(args.model_path)
await engine.start()
messages = [{"role": "user", "content": "Hello, how are you?"}]
openai_serving_chat = OpenAIServing(engine, tok, args)
openai_serving_chat = OpenAIServing(engine, args)

request = ChatCompletionRequest(model="test", messages=messages)
response = await openai_serving_chat.create_chat_completion(request, None)
Expand All @@ -71,9 +69,8 @@ async def llm_generate():
async def mllm_generate():
args = Args()

logger = setup_logger("engine", logging.DEBUG if args.is_debug else logging.INFO)
engine, tok = init_engine(args.model_path, logger)
_ = await engine.start()
engine = init_engine(args.model_path)
await engine.start()
messages = [
{
"role": "user",
Expand All @@ -83,7 +80,7 @@ async def mllm_generate():
],
}
]
openai_serving_chat = OpenAIServing(engine, tok, args)
openai_serving_chat = OpenAIServing(engine, args)

request = ChatCompletionRequest(model="test", messages=messages)
response = await openai_serving_chat.create_chat_completion(request, None)
Expand All @@ -101,8 +98,7 @@ async def image_generate():
"width": 768,
}

logger = setup_logger("engine", logging.DEBUG if args.is_debug else logging.INFO)
engine = init_image_engine(args.model_path, logger)
engine = init_image_engine(args.model_path)
_ = await engine.start()

image_serving = ImageServing(engine, args)
Expand Down
4 changes: 2 additions & 2 deletions examples/run_single_server.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/bash
# MODEL_PATH=/Users/lujianghu/Documents/Llama-3.2-1B-Instruct
MODEL_PATH=Qwen/Qwen2-VL-2B-Instruct
MODEL_PATH=/Users/lujianghu/Documents/Llama-3.2-1B-Instruct
# MODEL_PATH=Qwen/Qwen2-VL-2B-Instruct
# MODEL_PATH=mlx-community/Meta-Llama-3.1-8B-Instruct-4bit
MASTER_HOSTNAME=m3pro

Expand Down
2 changes: 1 addition & 1 deletion flux_examples/run_server.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ MASTER_HOSTNAME=mac-mini

export PYTHONPATH="./":$PYTHONPATH;

python3 -m tllm.entrypoints.image_server.image_api_server --ip_addr $MASTER_HOSTNAME --model_path $MODEL_PATH --is_debug
python3 -m tllm.entrypoints.api_server --ip_addr $MASTER_HOSTNAME --model_path $MODEL_PATH --is_debug
2 changes: 1 addition & 1 deletion flux_examples/run_single_server.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@ MODEL_PATH=/Users/lujianghu/Documents/flux/schnell_4bit
MASTER_HOSTNAME=mac-mini

export PYTHONPATH="./":$PYTHONPATH;
python3 -m tllm.entrypoints.image_server.image_api_server --ip_addr $MASTER_HOSTNAME --model_path $MODEL_PATH --is_local --is_debug
python3 -m tllm.entrypoints.api_server --ip_addr $MASTER_HOSTNAME --model_path $MODEL_PATH --is_local --is_debug --is_image


32 changes: 32 additions & 0 deletions minimized_examples/mp_shared_memory/api_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import asyncio

from fastapi import FastAPI, HTTPException
from shared_memory import RingBuffer

# api_server.py
# FastAPI front end for the shared-memory demo: each HTTP request is
# serialized into one ring buffer and the engine's reply is polled from a
# second one. NOTE(review): a single shared response buffer means replies
# are not correlated to requests — only safe with one in-flight request.
app = FastAPI()
ring_buffer = RingBuffer("engine_buffer")  # API -> engine (requests)
response_buffer = RingBuffer("response_buffer")  # engine -> API (replies)


@app.post("/process")
async def process_request(data: dict):
    """Forward *data* to the engine via shared memory and await its reply.

    Returns ``{"result": <decoded reply>}`` on success; raises 503 when the
    request buffer is full and 408 when no reply arrives in time.
    """
    # Enqueue the request; a full ring buffer means the engine is behind.
    payload = str(data).encode()
    if not ring_buffer.write(payload):
        raise HTTPException(status_code=503, detail="Buffer full")

    # Poll for the reply — up to 1000 rounds of 0.1 ms (~0.1 s total).
    attempts = 1000
    while attempts > 0:
        reply = response_buffer.read()
        if reply:
            return {"result": reply.decode()}
        await asyncio.sleep(0.0001)
        attempts -= 1

    raise HTTPException(status_code=408, detail="Request timeout")


# Release the shared-memory segments when the API server stops.
@app.on_event("shutdown")
async def shutdown_event():
    """Close both ring buffers on server shutdown."""
    for buf in (ring_buffer, response_buffer):
        buf.close()
25 changes: 25 additions & 0 deletions minimized_examples/mp_shared_memory/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import os
from pathlib import Path


class Config:
    """Central configuration for the shared-memory demo.

    Holds filesystem layout, ring-buffer names/sizes, API bind address,
    and the worker-process count.
    """

    # Project layout
    BASE_DIR = Path(__file__).parent
    LOG_DIR = BASE_DIR / "logs"

    # Shared-memory ring buffers
    BUFFER_SIZE = 1024 * 1024  # 1MB
    REQUEST_BUFFER_NAME = "engine_buffer"
    RESPONSE_BUFFER_NAME = "response_buffer"

    # HTTP API endpoint
    API_HOST = "127.0.0.1"
    API_PORT = 8000

    # Worker processes
    ENGINE_PROCESS_COUNT = 1  # number of engine workers to spawn

    @classmethod
    def init(cls):
        """Ensure runtime directories (currently just the log dir) exist."""
        cls.LOG_DIR.mkdir(parents=True, exist_ok=True)
28 changes: 28 additions & 0 deletions minimized_examples/mp_shared_memory/engine_process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import asyncio

from shared_memory import RingBuffer


async def run_engine():
    """Engine worker loop: drain requests from shared memory, publish replies.

    Runs forever; each pass does one non-blocking read, processes any
    payload found, and yields briefly to the event loop.
    """
    request_buf = RingBuffer("engine_buffer")
    reply_buf = RingBuffer("response_buffer")

    while True:
        payload = request_buf.read()  # non-blocking; falsy when idle
        if payload:
            reply_buf.write(await process_data(payload))
        # Tiny sleep keeps the loop cooperative without pegging a core.
        await asyncio.sleep(0.0001)


async def process_data(data: bytes) -> bytes:
    """Simulate engine work on *data* and return the tagged result."""
    # Stand-in latency for real inference.
    await asyncio.sleep(0.001)
    return b"".join((b"Processed: ", data))


if __name__ == "__main__":
    # Entry point: run the engine polling loop until the process is killed.
    asyncio.run(run_engine())
160 changes: 160 additions & 0 deletions minimized_examples/mp_shared_memory/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# 项目结构
# project/
# ├── requirements.txt
# ├── config.py
# ├── shared_memory.py # 上一个回答中的共享内存实现
# ├── engine_process.py # 上一个回答中的引擎实现
# ├── api_server.py # 上一个回答中的API实现
# └── run.py # 主启动脚本

import logging
import signal
import subprocess
import sys
import time

import click
from config import Config
import psutil

# Supervisor logging: mirror messages to a file and to the console.
# The log directory must exist before logging.FileHandler opens
# supervisor.log — FileHandler creates the file but NOT missing parent
# directories, so a fresh checkout would crash at import time without
# Config.init() running first.
Config.init()
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.FileHandler(Config.LOG_DIR / "supervisor.log"), logging.StreamHandler()],
)
logger = logging.getLogger("supervisor")


class ProcessManager:
    """Launches child processes and tears them down gracefully on signals."""

    def __init__(self):
        # name -> subprocess.Popen for every child we spawned
        self.processes = {}
        self.running = True
        # Route termination requests and Ctrl-C through one handler.
        for sig in (signal.SIGTERM, signal.SIGINT):
            signal.signal(sig, self.handle_signal)

    def handle_signal(self, signum, frame):
        """Signal callback: log the signal and begin an orderly shutdown."""
        logger.info(f"Received signal {signum}")
        self.shutdown()

    def start_process(self, name, cmd):
        """Spawn *cmd* as a child called *name*; return the Popen or None."""
        try:
            child = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True)
        except Exception as e:
            logger.error(f"Failed to start {name}: {e}")
            return None
        self.processes[name] = child
        logger.info(f"Started {name} with PID {child.pid}")
        return child

    def monitor_process(self, name, process):
        """Block until *process* exits; log failures and return its exit code."""
        try:
            stdout, stderr = process.communicate()
        except Exception as e:
            logger.error(f"Error monitoring {name}: {e}")
            return -1
        if process.returncode != 0:
            logger.error(f"{name} exited with code {process.returncode}")
            logger.error(f"stderr: {stderr}")
        return process.returncode

    def shutdown(self):
        """Stop all children: SIGTERM first, SIGKILL after a ~5 s grace period."""
        self.running = False
        logger.info("Shutting down all processes...")

        # Ask politely first.
        for name, child in self.processes.items():
            if child.poll() is None:
                logger.info(f"Sending SIGTERM to {name}")
                child.terminate()

        # Give the children up to five seconds to exit on their own.
        grace_rounds = 5
        while grace_rounds > 0:
            if all(child.poll() is not None for child in self.processes.values()):
                break
            time.sleep(1)
            grace_rounds -= 1

        # Anything still alive gets killed outright.
        for name, child in self.processes.items():
            if child.poll() is None:
                logger.info(f"Sending SIGKILL to {name}")
                child.kill()


@click.group()
def cli():
    """Process-management CLI (start / status / stop)."""


@cli.command()
def start():
    """Bring up the engine workers and the API server, then supervise them."""
    Config.init()
    manager = ProcessManager()

    # One shared-memory engine worker per configured slot.
    for idx in range(Config.ENGINE_PROCESS_COUNT):
        manager.start_process(f"engine_{idx}", [sys.executable, "engine_process.py"])

    # Uvicorn serves the FastAPI app in its own process.
    api_cmd = [
        sys.executable,
        "-m",
        "uvicorn",
        "api_server:app",
        "--host",
        Config.API_HOST,
        "--port",
        str(Config.API_PORT),
        "--reload",
    ]
    manager.start_process("api", api_cmd)

    # Idle until a signal handler (or Ctrl-C) flips manager.running.
    try:
        while manager.running:
            time.sleep(1)
    except KeyboardInterrupt:
        logger.info("Received keyboard interrupt")
    finally:
        manager.shutdown()


@cli.command()
def status():
    """Print the PID and command line of each running engine/API process."""
    for proc in psutil.process_iter(["pid", "name", "cmdline"]):
        try:
            # name can be None on some platforms; guard before .lower().
            if "python" in (proc.info["name"] or "").lower():
                # cmdline may be None or [] for zombie / access-restricted
                # processes — " ".join(None) would raise TypeError.
                cmdline = " ".join(proc.info["cmdline"] or [])
                if any(x in cmdline for x in ["engine_process.py", "api_server:app"]):
                    print(f"PID: {proc.info['pid']}, Command: {cmdline}")
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            pass


@cli.command()
def stop():
    """Terminate every running engine or API process found on the system."""
    for proc in psutil.process_iter(["pid", "name", "cmdline"]):
        try:
            # name can be None on some platforms; guard before .lower().
            if "python" in (proc.info["name"] or "").lower():
                # cmdline may be None or [] for zombie / access-restricted
                # processes — " ".join(None) would raise TypeError.
                cmdline = " ".join(proc.info["cmdline"] or [])
                if any(x in cmdline for x in ["engine_process.py", "api_server:app"]):
                    psutil.Process(proc.info["pid"]).terminate()
                    print(f"Terminated process {proc.info['pid']}")
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            pass


# CLI entry point: `python run.py start|status|stop`.
if __name__ == "__main__":
    cli()

# requirements.txt
# fastapi==0.68.0
# uvicorn==0.15.0
# click==8.0.3
# psutil==5.8.0
# numpy==1.21.2
Loading

0 comments on commit e5f169a

Please sign in to comment.