Skip to content

feat: support more PD node select func #970

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 34 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
ec38c31
feat: setup router funs
zhhangBian Jul 1, 2025
5981a2b
feat: setup new select manner
zhhangBian Jul 14, 2025
19bbc19
feat: add cli args support
zhhangBian Jul 14, 2025
072a709
feat: add memory select
zhhangBian Jul 14, 2025
d4a9dd7
feat: change memory token_load to 1s search
zhhangBian Jul 14, 2025
28ed0dd
feat: change select func to select class
zhhangBian Jul 14, 2025
21eb993
fix: typing
zhhangBian Jul 14, 2025
650bcd7
fix: remove unnecessary func
zhhangBian Jul 14, 2025
ed73734
feat: mroe clear
zhhangBian Jul 14, 2025
5e193e1
feat: add load info in pd_master
zhhangBian Jul 15, 2025
5748672
feat: change pd mem selector to pd_master load info
zhhangBian Jul 15, 2025
d9b53ff
fix: typo
zhhangBian Jul 15, 2025
f731532
feat: 更新为只有在有请求结束后才更新负载信息
zhhangBian Jul 16, 2025
6922a66
fix: create func para
zhhangBian Jul 16, 2025
917e30a
fix: typo
zhhangBian Jul 16, 2025
3b250be
fix: import error
hiworldwzj Jul 17, 2025
fb053ee
fix: arg name
zhhangBian Jul 17, 2025
da45d06
fix: dict key
zhhangBian Jul 18, 2025
e983ce8
remove unnecessary files
zhhangBian Jul 28, 2025
e68aca3
docker
zhhangBian Jul 28, 2025
ef6b024
docker main
zhhangBian Jul 28, 2025
b95ac63
remove env.sh
zhhangBian Jul 28, 2025
f6060f1
feat: 使用预测信息进行,避免拥塞
zhhangBian Jul 29, 2025
5155920
fix: better style
zhhangBian Jul 29, 2025
d05c0a0
fix: import
zhhangBian Jul 29, 2025
2c187a1
fix: lint
zhhangBian Jul 31, 2025
bcbaa71
拆分获取节点负载信息的函数
zhhangBian Jul 31, 2025
2674fd5
remove pd radix selector
zhhangBian Jul 31, 2025
25459a3
better typing
zhhangBian Jul 31, 2025
a5e14ce
remove unnecessary if
zhhangBian Jul 31, 2025
ed77be1
feat: better load info select
zhhangBian Jul 31, 2025
cedbb50
feat: update get load info
zhhangBian Jul 31, 2025
e55cbb5
fix: memory node select
zhhangBian Jul 31, 2025
c4e5a64
Merge branch 'ModelTC:main' into pd-router
zhhangBian Aug 1, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions lightllm/server/api_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ def make_argument_parser() -> argparse.ArgumentParser:
default=42000,
help="p d mode, decode node used for kv move manager rpyc server port",
)
parser.add_argument(
"--select_p_d_node_func",
type=str,
default="round_robin",
choices=["random", "round_robin", "memory"],
help="select p d node func, can be round_robin, random or memory",
)
parser.add_argument(
"--config_server_host",
type=str,
Expand Down
22 changes: 21 additions & 1 deletion lightllm/server/httpserver/pd_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,29 @@ async def _pd_process_generate(
logger.error(str(e))


# 获取节点负载信息
def _get_load_info(have_finished_req: bool) -> dict:
if not have_finished_req:
return None

from lightllm.server.api_http import g_objs
assert g_objs.shared_token_load is not None, "shared_token_load is not initialized"
current_load = [
float(g_objs.shared_token_load.get_dynamic_max_load(dp_index)) for dp_index in range(g_objs.args.dp)
]
load_info = {
"mem_len": min(current_load),
"client_ip_port": f"{g_objs.httpserver_manager.host_ip}:{g_objs.args.port}"
}
return load_info


# 转发token的task
async def _up_tokens_to_pd_master(forwarding_queue: AsyncQueue, websocket):
while True:
handle_list = await forwarding_queue.wait_to_get_all_data()

if handle_list:
await websocket.send(pickle.dumps((ObjType.TOKEN_PACKS, handle_list)))
have_finished_req = any(finish_status.is_finished() for _, _, _, finish_status in handle_list)
load_info: dict = _get_load_info(have_finished_req)
await websocket.send(pickle.dumps((ObjType.TOKEN_PACKS, handle_list, load_info)))
105 changes: 75 additions & 30 deletions lightllm/server/httpserver_for_pd_master/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,59 +25,99 @@
from lightllm.utils.statics_utils import MovingAverage
from lightllm.server.httpserver.manager import AsyncQueue
from lightllm.utils.error_utils import ServerBusyError
from .node_info_recorder import PredictNodeInfoRecorder
from .pd_selector import (
create_selector,
PDSelector,
)

logger = init_logger(__name__)


class HttpServerManagerForPDMaster:
def __init__(
self,
args,
metric_port,
):
class PDManager:
def __init__(self, args):
self.args = args
self.metric_client = MetricClient(metric_port)
self.id_gen = ReqIDGenerator()
self.prefill_nodes: List[PD_Client_Obj] = []
self.decode_nodes: List[PD_Client_Obj] = []
self.url_to_pd_nodes: Dict[str, PD_Client_Obj] = {}

self.req_id_to_out_inf: Dict[int, ReqStatus] = {}
self.infos_queues = None # 这个需要延迟初始化,否则使用的loop不对

self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code)

self.first_time_costs = MovingAverage()
self.per_token_costs = MovingAverage()
self.node_info_recorder: PredictNodeInfoRecorder = PredictNodeInfoRecorder()
self.selector: PDSelector = create_selector(args.select_p_d_node_func, self.prefill_nodes, self.decode_nodes, self)
return

async def register_pd(self, pd_info_json, websocket):
pd_client = PD_Client_Obj(**pd_info_json)
pd_client.websocket = websocket
self.url_to_pd_nodes[pd_client.client_ip_port] = pd_client
self.node_info_recorder.register_node(pd_client)

if pd_client.mode == "prefill":
self.prefill_nodes = [e for e in self.prefill_nodes if e.client_ip_port != pd_client.client_ip_port]
self.prefill_nodes.append(pd_client)
elif pd_client.mode == "decode":
self.decode_nodes = [e for e in self.decode_nodes if e.client_ip_port != pd_client.client_ip_port]
self.decode_nodes.append(pd_client)
else:
assert False
assert False, f"mode must in ['prefill', 'decode'], but get {pd_client.mode}"

await self.selector.update_nodes(self.prefill_nodes, self.decode_nodes)

logger.info(f"mode: {pd_client.mode} url: {pd_client.client_ip_port} registed")
return

async def remove_pd(self, pd_info_json):
pd_client = PD_Client_Obj(**pd_info_json)
try:
del self.url_to_pd_nodes[pd_client.client_ip_port]
except:
pass
self.node_info_recorder.remove_node(pd_client)

self.prefill_nodes = [e for e in self.prefill_nodes if e.client_ip_port != pd_client.client_ip_port]
self.decode_nodes = [e for e in self.decode_nodes if e.client_ip_port != pd_client.client_ip_port]

await self.selector.update_nodes(self.prefill_nodes, self.decode_nodes)

logger.info(f"mode: {pd_client.mode} url: {pd_client.client_ip_port} removed")
return

def update_node_load_info(self, load_info: dict):
"""更新节点负载信息"""
if load_info is None:
return
self.node_info_recorder.update_node_load_info(load_info)

def get_predict_node_infos(self):
"""获取所有节点的预测负载信息"""
return self.node_info_recorder.get_predict_node_infos()

async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]:
p_node, d_node = await self.selector.select_p_d_node(prompt, sampling_params, multimodal_params)
self.node_info_recorder.update_predict_node_info(p_node, d_node, prompt, sampling_params, multimodal_params)
return p_node, d_node

class HttpServerManagerForPDMaster:
def __init__(
self,
args,
metric_port,
):
self.args = args
self.metric_client = MetricClient(metric_port)
self.id_gen = ReqIDGenerator()

self.pd_manager = PDManager(args)

self.req_id_to_out_inf: Dict[int, ReqStatus] = {}
self.infos_queues = None # 这个需要延迟初始化,否则使用的loop不对

self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code)

self.first_time_costs = MovingAverage()
self.per_token_costs = MovingAverage()
return

async def register_pd(self, pd_info_json, websocket):
await self.pd_manager.register_pd(pd_info_json, websocket)
return

async def remove_pd(self, pd_info_json):
await self.pd_manager.remove_pd(pd_info_json)
return

async def update_req_status(self, upkv_status: UpKVStatus):
try:
group_request_id = convert_sub_id_to_group_id(upkv_status.group_request_id)
Expand Down Expand Up @@ -108,11 +148,7 @@ def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwar
async def select_p_d_node(
self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams
) -> Tuple[PD_Client_Obj, PD_Client_Obj]:
import random

p_node = random.choice(self.prefill_nodes)
d_node = random.choice(self.decode_nodes)
return p_node, d_node
return await self.pd_manager.select_p_d_node(prompt, sampling_params, multimodal_params)

async def generate(
self,
Expand Down Expand Up @@ -264,7 +300,7 @@ async def _wait_to_token_package(
request: Request,
):
out_token_counter = 0
first_token_cost_ms = sys.float_info.max
first_token_cost_ms = float('inf')
group_request_id = sampling_params.group_request_id
unfinished_count = sampling_params.best_of
is_first_token = True
Expand Down Expand Up @@ -368,7 +404,16 @@ async def handle_loop(self):
try:
for obj in objs:
if obj[0] == ObjType.TOKEN_PACKS:
for sub_req_id, text, metadata, finish_status in obj[1]:
# 检查是否包含节点信息
if len(obj) >= 3:
handle_list, load_info = obj[1], obj[2]
# 更新节点负载信息
self.pd_manager.update_node_load_info(load_info)
else:
# 兼容旧格式
handle_list = obj[1]

for sub_req_id, text, metadata, finish_status in handle_list:
finish_status: FinishStatus = finish_status
group_req_id = convert_sub_id_to_group_id(sub_req_id)
try:
Expand Down
91 changes: 91 additions & 0 deletions lightllm/server/httpserver_for_pd_master/node_info_recorder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import copy

from ..pd_io_struct import PD_Client_Obj
from lightllm.server.httpserver.manager import SamplingParams, MultimodalParams
from typing import Union, List, Dict
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


class NodeInfoRecorder:
def __init__(self):
self.prefill_node_info: dict = {}
self.decode_node_info: dict = {}

def register_node(self, pd_client: PD_Client_Obj):
node_info = {
"node_id": pd_client.node_id,
"client_ip_port": pd_client.client_ip_port,
"mode": pd_client.mode,
"node": pd_client,
"mem_len": 0,
# "batch_size": 0,
}
if pd_client.mode == "prefill":
self.prefill_node_info[pd_client.client_ip_port] = node_info
elif pd_client.mode == "decode":
self.decode_node_info[pd_client.client_ip_port] = node_info
else:
assert False, f"mode must in ['prefill', 'decode'], but get {pd_client.mode}"

def remove_node(self, pd_client: PD_Client_Obj):
if pd_client.mode == "prefill":
del self.prefill_node_info[pd_client.client_ip_port]
elif pd_client.mode == "decode":
del self.decode_node_info[pd_client.client_ip_port]
else:
assert False, f"mode must in ['prefill', 'decode'], but get {pd_client.mode}"

def update_node_load_info(self, load_info: dict):
if "client_ip_port" in load_info:
ip_port = load_info["client_ip_port"]
if ip_port in self.prefill_node_info:
self.prefill_node_info[ip_port]["mem_len"] = load_info["mem_len"]
elif ip_port in self.decode_node_info:
self.decode_node_info[ip_port]["mem_len"] = load_info["mem_len"]
else:
logger.warning(f"Received load info for unknown node: {ip_port}")
else:
logger.warning("Received load info without client_ip_port")


class PredictNodeInfoRecorder(NodeInfoRecorder):
def __init__(self):
super().__init__()
self.prefill_predict_node_info: dict = {}
self.decode_predict_node_info: dict = {}

def register_node(self, pd_client: PD_Client_Obj):
super().register_node(pd_client)
if pd_client.mode == "prefill":
self.prefill_predict_node_info[pd_client.client_ip_port] = copy.copy(self.prefill_node_info[pd_client.client_ip_port])
elif pd_client.mode == "decode":
self.decode_predict_node_info[pd_client.client_ip_port] = copy.copy(self.decode_node_info[pd_client.client_ip_port])

def remove_node(self, pd_client: PD_Client_Obj):
super().remove_node(pd_client)
if pd_client.mode == "prefill":
del self.prefill_predict_node_info[pd_client.client_ip_port]
elif pd_client.mode == "decode":
del self.decode_predict_node_info[pd_client.client_ip_port]

def update_node_load_info(self, load_info: dict):
super().update_node_load_info(load_info)
ip_port = load_info["client_ip_port"]
if ip_port in self.prefill_node_info:
self.prefill_predict_node_info[ip_port] = copy.copy(self.prefill_node_info[ip_port])
elif ip_port in self.decode_node_info:
self.decode_predict_node_info[ip_port] = copy.copy(self.decode_node_info[ip_port])
else:
logger.warning(f"Received load info for unknown node: {ip_port}")

def update_predict_node_info(self, p_node: PD_Client_Obj, d_node: PD_Client_Obj, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams):
self.prefill_predict_node_info[p_node.client_ip_port]["mem_len"] += len(prompt)
self.decode_predict_node_info[d_node.client_ip_port]["mem_len"] += sampling_params.max_new_tokens

def get_predict_node_infos(self) -> Dict[str, dict]:
return {
"prefill": self.prefill_predict_node_info,
"decode": self.decode_predict_node_info,
}
25 changes: 25 additions & 0 deletions lightllm/server/httpserver_for_pd_master/pd_selector/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from typing import List
from lightllm.server.httpserver_for_pd_master.manager import PD_Client_Obj
from .pd_selector import (
PDSelector,
RandomSelector,
RoundRobinSelector,
MemorySelector
)

__all__ = [
"PDSelector",
"RandomSelector",
"RoundRobinSelector",
"MemorySelector"
]

def create_selector(selector_type: str, prefill_nodes: List[PD_Client_Obj], decode_nodes: List[PD_Client_Obj], pd_manager) -> PDSelector:
if selector_type == "random":
return RandomSelector(prefill_nodes, decode_nodes, pd_manager)
elif selector_type == "round_robin":
return RoundRobinSelector(prefill_nodes, decode_nodes, pd_manager)
elif selector_type == "memory":
return MemorySelector(prefill_nodes, decode_nodes, pd_manager)
else:
raise ValueError(f"Invalid selector type: {selector_type}")
Loading