Skip to content

feat: support more PD node select func #970

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion build_and_upload_docker.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,9 @@ fi
IMAGE_TAG=$2
ACCOUNT=$1
aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com
DOCKER_BUILDKIT=1 docker build -t $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/lightllm:$IMAGE_TAG .
DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/lightllm:$IMAGE_TAG .
docker push $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/lightllm:$IMAGE_TAG

#deepep
DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile.deepep -t $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/lightllm:$IMAGE_TAG-deepep .
docker push $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/lightllm:$IMAGE_TAG-deepep
3 changes: 3 additions & 0 deletions env.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
source /mtc/bianzhuohang/miniconda3/bin/activate
conda activate lightllm_router
clear
7 changes: 7 additions & 0 deletions lightllm/server/api_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ def make_argument_parser() -> argparse.ArgumentParser:
default=42000,
help="p d mode, decode node used for kv move manager rpyc server port",
)
parser.add_argument(
"--select_p_d_node_func",
type=str,
default="random",
choices=["random", "round_robin", "memory", "radix"],
help="select p d node func, can be random, round_robin, memory or radix",
)
parser.add_argument(
"--config_server_host",
type=str,
Expand Down
24 changes: 23 additions & 1 deletion lightllm/server/httpserver/pd_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,5 +180,27 @@ async def _pd_process_generate(
async def _up_tokens_to_pd_master(forwarding_queue: AsyncQueue, websocket):
while True:
handle_list = await forwarding_queue.wait_to_get_all_data()

if handle_list:
await websocket.send(pickle.dumps((ObjType.TOKEN_PACKS, handle_list)))
has_finished_req = any(finish_status.is_finished() for _, _, _, finish_status in handle_list)
if has_finished_req:
# 获取节点负载信息
from lightllm.server.api_http import g_objs
if g_objs.shared_token_load is not None:
current_load = [
float(g_objs.shared_token_load.get_current_load(dp_index)) for dp_index in range(g_objs.args.dp)
]
if g_objs.args.dp == 1:
current_load = current_load[0]
load_info = {
"current_load": current_load,
"client_ip_port": f"{g_objs.httpserver_manager.host_ip}:{g_objs.args.port}"
}
else:
load_info = {
"current_load": 0.0,
"client_ip_port": f"{g_objs.httpserver_manager.host_ip}:{g_objs.args.port}"
}
await websocket.send(pickle.dumps((ObjType.TOKEN_PACKS, handle_list, load_info)))
else:
await websocket.send(pickle.dumps((ObjType.TOKEN_PACKS, handle_list, None)))
141 changes: 114 additions & 27 deletions lightllm/server/httpserver_for_pd_master/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,59 +25,133 @@
from lightllm.utils.statics_utils import MovingAverage
from lightllm.server.httpserver.manager import AsyncQueue
from lightllm.utils.error_utils import ServerBusyError
from .pd_selector import (
create_selector,
PDSelector,
MemorySelector,
)

logger = init_logger(__name__)


class HttpServerManagerForPDMaster:
def __init__(
self,
args,
metric_port,
):
class PDManager:
def __init__(self, args):
self.args = args
self.metric_client = MetricClient(metric_port)
self.id_gen = ReqIDGenerator()
self.node_info: Dict[str, dict] = {}
self.prefill_nodes: List[PD_Client_Obj] = []
self.decode_nodes: List[PD_Client_Obj] = []
self.url_to_pd_nodes: Dict[str, PD_Client_Obj] = {}

self.req_id_to_out_inf: Dict[int, ReqStatus] = {}
self.infos_queues = None # 这个需要延迟初始化,否则使用的loop不对

self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code)

self.first_time_costs = MovingAverage()
self.per_token_costs = MovingAverage()
self.selector: PDSelector = create_selector(args.select_p_d_node_func, self.prefill_nodes, self.decode_nodes, self)
return

async def register_pd(self, pd_info_json, websocket):
pd_client = PD_Client_Obj(**pd_info_json)
pd_client.websocket = websocket
self.url_to_pd_nodes[pd_client.client_ip_port] = pd_client
self.node_info[pd_client.client_ip_port] = {
"node_id": pd_info_json["node_id"],
"client_ip_port": pd_info_json["client_ip_port"],
"mode": pd_info_json["mode"],
"node": pd_client,
"load": 0.0,
}

if pd_client.mode == "prefill":
self.prefill_nodes = [e for e in self.prefill_nodes if e.client_ip_port != pd_client.client_ip_port]
self.prefill_nodes.append(pd_client)
elif pd_client.mode == "decode":
self.decode_nodes = [e for e in self.decode_nodes if e.client_ip_port != pd_client.client_ip_port]
self.decode_nodes.append(pd_client)
else:
assert False
assert False, f"mode must in ['prefill', 'decode'], but get {pd_client.mode}"

await self.selector.update_nodes(self.prefill_nodes, self.decode_nodes)

logger.info(f"mode: {pd_client.mode} url: {pd_client.client_ip_port} registed")
return

async def remove_pd(self, pd_info_json):
pd_client = PD_Client_Obj(**pd_info_json)
try:
del self.url_to_pd_nodes[pd_client.client_ip_port]
del self.node_info[pd_client.client_ip_port]
except:
pass

if pd_client.client_ip_port in self.node_info:
del self.node_info[pd_client.client_ip_port]

self.prefill_nodes = [e for e in self.prefill_nodes if e.client_ip_port != pd_client.client_ip_port]
self.decode_nodes = [e for e in self.decode_nodes if e.client_ip_port != pd_client.client_ip_port]

await self.selector.update_nodes(self.prefill_nodes, self.decode_nodes)

logger.info(f"mode: {pd_client.mode} url: {pd_client.client_ip_port} removed")
return

def update_node_load_info(self, load_info: dict):
"""更新节点负载信息"""
if load_info is None:
return

if "client_ip_port" in load_info:
ip_port = load_info["client_ip_port"]
if ip_port in self.node_info:
current_load = load_info["current_load"]
if isinstance(current_load, list):
if len(current_load) > 0:
load_value = float(current_load[0])
else:
load_value = 0.0
else:
load_value = float(current_load)

self.node_info[ip_port]["load"] = load_value
logger.debug(f"Updated node load info for {ip_port}: {load_value}")
else:
logger.warning(f"Received load info for unknown node: {ip_port}")

def get_node_load_info(self):
"""获取所有节点的负载信息"""
return {k: v.get("load", float("inf")) for k, v in self.node_info.items()}

def get_node_load_info_by_node(self, client_ip_port: str):
"""获取指定节点的负载信息"""
node_info = self.node_info.get(client_ip_port, None)
if node_info is not None:
return node_info.get("load", float("inf"))
else:
return float("inf")

async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]:
return await self.selector.select_p_d_node(prompt, sampling_params, multimodal_params)

class HttpServerManagerForPDMaster:
def __init__(
self,
args,
metric_port,
):
self.args = args
self.metric_client = MetricClient(metric_port)
self.id_gen = ReqIDGenerator()

self.pd_manager = PDManager(args)

self.req_id_to_out_inf: Dict[int, ReqStatus] = {}
self.infos_queues = None # 这个需要延迟初始化,否则使用的loop不对

self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code)

self.first_time_costs = MovingAverage()
self.per_token_costs = MovingAverage()
return

async def register_pd(self, pd_info_json, websocket):
await self.pd_manager.register_pd(pd_info_json, websocket)
return

async def remove_pd(self, pd_info_json):
await self.pd_manager.remove_pd(pd_info_json)
return

async def update_req_status(self, upkv_status: UpKVStatus):
try:
group_request_id = convert_sub_id_to_group_id(upkv_status.group_request_id)
Expand Down Expand Up @@ -108,11 +182,7 @@ def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwar
async def select_p_d_node(
self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams
) -> Tuple[PD_Client_Obj, PD_Client_Obj]:
import random

p_node = random.choice(self.prefill_nodes)
d_node = random.choice(self.decode_nodes)
return p_node, d_node
return await self.pd_manager.select_p_d_node(prompt, sampling_params, multimodal_params)

async def generate(
self,
Expand Down Expand Up @@ -264,7 +334,7 @@ async def _wait_to_token_package(
request: Request,
):
out_token_counter = 0
first_token_cost_ms = sys.float_info.max
first_token_cost_ms = float('inf')
group_request_id = sampling_params.group_request_id
unfinished_count = sampling_params.best_of
is_first_token = True
Expand Down Expand Up @@ -351,6 +421,14 @@ async def timer_log(self):
async def put_to_handle_queue(self, obj):
await self.infos_queues.put(obj)

def get_node_load_info(self):
"""获取所有节点的负载信息"""
return self.pd_manager.get_node_load_info()

def get_node_load_info_by_node(self, client_ip_port: str):
"""获取指定节点的负载信息"""
return self.pd_manager.get_node_load_info_by_node(client_ip_port)

async def handle_loop(self):
self.infos_queues = AsyncQueue()
asyncio.create_task(self.timer_log())
Expand All @@ -368,7 +446,16 @@ async def handle_loop(self):
try:
for obj in objs:
if obj[0] == ObjType.TOKEN_PACKS:
for sub_req_id, text, metadata, finish_status in obj[1]:
# 检查是否包含节点信息
if len(obj) >= 3:
handle_list, load_info = obj[1], obj[2]
# 更新节点负载信息
self.pd_manager.update_node_load_info(load_info)
else:
# 兼容旧格式
handle_list = obj[1]

for sub_req_id, text, metadata, finish_status in handle_list:
finish_status: FinishStatus = finish_status
group_req_id = convert_sub_id_to_group_id(sub_req_id)
try:
Expand Down
29 changes: 29 additions & 0 deletions lightllm/server/httpserver_for_pd_master/pd_selector/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import List
from lightllm.server.httpserver_for_pd_master.manager import PD_Client_Obj
from .pd_selector import (
PDSelector,
RandomSelector,
RoundRobinSelector,
MemorySelector,
RadixSelector
)

__all__ = [
"PDSelector",
"RandomSelector",
"RoundRobinSelector",
"MemorySelector",
"RadixSelector"
]

def create_selector(selector_type: str, prefill_nodes: List[PD_Client_Obj], decode_nodes: List[PD_Client_Obj], pd_manager) -> PDSelector:
if selector_type == "random":
return RandomSelector(prefill_nodes, decode_nodes, pd_manager)
elif selector_type == "round_robin":
return RoundRobinSelector(prefill_nodes, decode_nodes, pd_manager)
elif selector_type == "memory":
return MemorySelector(prefill_nodes, decode_nodes, pd_manager)
elif selector_type == "radix":
return RadixSelector(prefill_nodes, decode_nodes, pd_manager)
else:
raise ValueError(f"Invalid selector type: {selector_type}")
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from typing import Union, List, Tuple
from lightllm.server.pd_io_struct import PD_Client_Obj
from lightllm.server.core.objs import SamplingParams
from lightllm.server.multimodal_params import MultimodalParams


class PDSelector:
def __init__(self, prefill_nodes: List[PD_Client_Obj], decode_nodes: List[PD_Client_Obj], pd_manager):
self.prefill_nodes: List[PD_Client_Obj] = prefill_nodes
self.decode_nodes: List[PD_Client_Obj] = decode_nodes
self.pd_manager = pd_manager

async def update_nodes(self, prefill_nodes, decode_nodes):
self.prefill_nodes = prefill_nodes
self.decode_nodes = decode_nodes

async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]:
raise NotImplementedError("Subclass must implement this method")


class RandomSelector(PDSelector):
"""随机选择器"""

async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]:
import random

p_node = random.choice(self.prefill_nodes)
d_node = random.choice(self.decode_nodes)
return p_node, d_node


class RoundRobinSelector(PDSelector):
"""轮询选择器"""

def __init__(self, prefill_nodes: List[PD_Client_Obj], decode_nodes: List[PD_Client_Obj], pd_manager):
super().__init__(prefill_nodes, decode_nodes, pd_manager)
self.prefill_node_index: int = 0
self.decode_node_index: int = 0

async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]:
p_node = self.prefill_nodes[self.prefill_node_index]
d_node = self.decode_nodes[self.decode_node_index]
self.prefill_node_index = (self.prefill_node_index + 1) % len(self.prefill_nodes)
self.decode_node_index = (self.decode_node_index + 1) % len(self.decode_nodes)
return p_node, d_node


class MemorySelector(PDSelector):
"""基于内存使用情况的选择器"""

async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]:
if self.pd_manager is None:
# 如果没有 PDManager 引用,回退到随机选择
import random
p_node = random.choice(self.prefill_nodes) if self.prefill_nodes else None
d_node = random.choice(self.decode_nodes) if self.decode_nodes else None
return p_node, d_node

# 获取 prefill 节点的内存使用情况
prefill_usages = [self.pd_manager.get_node_load_info_by_node(node.client_ip_port) for node in self.prefill_nodes]
decode_usages = [self.pd_manager.get_node_load_info_by_node(node.client_ip_port) for node in self.decode_nodes]

import random
min_prefill_usage = min(prefill_usages) if prefill_usages else float('inf')
min_decode_usage = min(decode_usages) if decode_usages else float('inf')

p_node = self.prefill_nodes[prefill_usages.index(min_prefill_usage)] if min_prefill_usage != float('inf') and prefill_usages else random.choice(self.prefill_nodes)
d_node = self.decode_nodes[decode_usages.index(min_decode_usage)] if min_decode_usage != float('inf') and decode_usages else random.choice(self.decode_nodes)

return p_node, d_node

class RadixSelector(PDSelector):
async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]:
pass
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ frozendict==2.4.6
atomics==1.0.3
easydict==1.13
gunicorn==23.0.0
vllm==0.8.5
# vllm==0.8.5
flashinfer-python==0.2.4
sgl-kernel==0.1.4
httpx==0.28.1
Expand Down
11 changes: 11 additions & 0 deletions server_d.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
CUDA_VISIBLE_DEVICES=1 KV_TRANS_USE_P2P=1 LOADWORKER=1 python3 -m lightllm.server.api_server \
--model_dir /mtc/bianzhuohang/models/Qwen/Qwen2.5-14B \
--run_mode "decode" \
--host 127.0.1.1 \
--port 8118 \
--nccl_port 12322 \
--tp 1 \
--tokenizer_mode fast \
--pd_master_ip 127.0.0.1 \
--pd_master_port 60011 \
--pd_decode_rpyc_port 42020
Loading