From ec38c314da0004bf9c21daf9f52c3e82cd68ecc4 Mon Sep 17 00:00:00 2001 From: pigKiller <22373017@buaa.edu.cn> Date: Tue, 1 Jul 2025 17:29:34 +0800 Subject: [PATCH 01/33] feat: setup router funs --- lightllm/server/httpserver_for_pd_master/manager.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index 05b2d987c..2a02187bb 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -108,12 +108,21 @@ def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwar async def select_p_d_node( self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams ) -> Tuple[PD_Client_Obj, PD_Client_Obj]: + return self._randome_select_p_d_node(prompt, sampling_params, multimodal_params) + + def _randome_select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: import random p_node = random.choice(self.prefill_nodes) d_node = random.choice(self.decode_nodes) return p_node, d_node + def _memory_select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: + pass + + def _radix_select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: + pass + async def generate( self, prompt: Union[str, List[int]], From 5981a2bcae6c081f2c088b314aa6eba9d1be56ee Mon Sep 17 00:00:00 2001 From: pigKiller <22373017@buaa.edu.cn> Date: Mon, 14 Jul 2025 16:46:13 +0800 Subject: [PATCH 02/33] feat: setup new select manner --- .../httpserver_for_pd_master/manager.py | 105 ++++++++++++------ 1 file changed, 72 insertions(+), 33 deletions(-) diff --git 
a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index 2a02187bb..0eeac420d 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -12,7 +12,7 @@ import pickle asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) -from typing import Union, List, Tuple, Dict +from typing import Union, List, Tuple, Dict, Callable from lightllm.server.core.objs import FinishStatus from ..pd_io_struct import PD_Client_Obj, UpKVStatus, ObjType from lightllm.server.core.objs import SamplingParams @@ -29,29 +29,16 @@ logger = init_logger(__name__) -class HttpServerManagerForPDMaster: - def __init__( - self, - args, - metric_port, - ): +class PDManager: + def __init__(self, args): self.args = args - self.metric_client = MetricClient(metric_port) - self.id_gen = ReqIDGenerator() self.prefill_nodes: List[PD_Client_Obj] = [] self.decode_nodes: List[PD_Client_Obj] = [] self.url_to_pd_nodes: Dict[str, PD_Client_Obj] = {} - - self.req_id_to_out_inf: Dict[int, ReqStatus] = {} - self.infos_queues = None # 这个需要延迟初始化,否则使用的loop不对 - - self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code) - - self.first_time_costs = MovingAverage() - self.per_token_costs = MovingAverage() + self.select_p_d_node_func = None return - async def register_pd(self, pd_info_json, websocket): + def register_pd(self, pd_info_json, websocket): pd_client = PD_Client_Obj(**pd_info_json) pd_client.websocket = websocket self.url_to_pd_nodes[pd_client.client_ip_port] = pd_client @@ -67,7 +54,7 @@ async def register_pd(self, pd_info_json, websocket): logger.info(f"mode: {pd_client.mode} url: {pd_client.client_ip_port} registed") return - async def remove_pd(self, pd_info_json): + def remove_pd(self, pd_info_json): pd_client = PD_Client_Obj(**pd_info_json) try: del self.url_to_pd_nodes[pd_client.client_ip_port] @@ -78,6 +65,71 @@ async def 
remove_pd(self, pd_info_json): logger.info(f"mode: {pd_client.mode} url: {pd_client.client_ip_port} removed") return + def _get_select_p_d_node_func(self, select_p_d_node_func_name: str) -> Callable[[Union[str, List[int]], SamplingParams, MultimodalParams], Tuple[PD_Client_Obj, PD_Client_Obj]]: + if select_p_d_node_func_name == "random": + self.prefill_node_index = 0 + self.decode_node_index = 0 + return self._random_select_p_d_node + elif select_p_d_node_func_name == "memory": + return self._memory_select_p_d_node + elif select_p_d_node_func_name == "radix": + return self._radix_select_p_d_node + else: + raise ValueError(f"invalid select_p_d_node_func_name: {select_p_d_node_func_name}") + + def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: + assert self.select_p_d_node_func is not None, "select_p_d_node_func is not set" + return self.select_p_d_node_func(prompt, sampling_params, multimodal_params) + + def _random_select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: + import random + + p_node = random.choice(self.prefill_nodes) + d_node = random.choice(self.decode_nodes) + return p_node, d_node + + def _round_robin_select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: + p_node = self.prefill_nodes[self.prefill_node_index] + d_node = self.decode_nodes[self.decode_node_index] + self.prefill_node_index = (self.prefill_node_index + 1) % len(self.prefill_nodes) + self.decode_node_index = (self.decode_node_index + 1) % len(self.decode_nodes) + return p_node, d_node + + def _memory_select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: 
+ pass + + def _radix_select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: + pass + +class HttpServerManagerForPDMaster: + def __init__( + self, + args, + metric_port, + ): + self.args = args + self.metric_client = MetricClient(metric_port) + self.id_gen = ReqIDGenerator() + + self.pd_manager = PDManager(args) + + self.req_id_to_out_inf: Dict[int, ReqStatus] = {} + self.infos_queues = None # 这个需要延迟初始化,否则使用的loop不对 + + self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code) + + self.first_time_costs = MovingAverage() + self.per_token_costs = MovingAverage() + return + + async def register_pd(self, pd_info_json, websocket): + self.pd_manager.register_pd(pd_info_json, websocket) + return + + async def remove_pd(self, pd_info_json): + self.pd_manager.remove_pd(pd_info_json) + return + async def update_req_status(self, upkv_status: UpKVStatus): try: group_request_id = convert_sub_id_to_group_id(upkv_status.group_request_id) @@ -108,20 +160,7 @@ def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwar async def select_p_d_node( self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams ) -> Tuple[PD_Client_Obj, PD_Client_Obj]: - return self._randome_select_p_d_node(prompt, sampling_params, multimodal_params) - - def _randome_select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: - import random - - p_node = random.choice(self.prefill_nodes) - d_node = random.choice(self.decode_nodes) - return p_node, d_node - - def _memory_select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: - pass - - def _radix_select_p_d_node(self, prompt: 
Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: - pass + return self.pd_manager.select_p_d_node(prompt, sampling_params, multimodal_params) async def generate( self, From 19bbc19c21a2f82a33e2d5555fb06cfecee74976 Mon Sep 17 00:00:00 2001 From: pigKiller <22373017@buaa.edu.cn> Date: Mon, 14 Jul 2025 16:48:29 +0800 Subject: [PATCH 03/33] feat: add cli args support --- lightllm/server/api_cli.py | 7 +++++++ lightllm/server/httpserver_for_pd_master/manager.py | 6 ++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index e9943b05f..55f2e14af 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -42,6 +42,13 @@ def make_argument_parser() -> argparse.ArgumentParser: default=42000, help="p d mode, decode node used for kv move manager rpyc server port", ) + parser.add_argument( + "--select_p_d_node_func", + type=str, + default="random", + choices=["random", "round_robin", "memory", "radix"], + help="select p d node func, can be random, round_robin, memory or radix", + ) parser.add_argument( "--config_server_host", type=str, diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index 0eeac420d..857995353 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -35,7 +35,7 @@ def __init__(self, args): self.prefill_nodes: List[PD_Client_Obj] = [] self.decode_nodes: List[PD_Client_Obj] = [] self.url_to_pd_nodes: Dict[str, PD_Client_Obj] = {} - self.select_p_d_node_func = None + self.select_p_d_node_func = self._get_select_p_d_node_func(args.select_p_d_node_func) return def register_pd(self, pd_info_json, websocket): @@ -67,9 +67,11 @@ def remove_pd(self, pd_info_json): def _get_select_p_d_node_func(self, select_p_d_node_func_name: str) -> Callable[[Union[str, 
List[int]], SamplingParams, MultimodalParams], Tuple[PD_Client_Obj, PD_Client_Obj]]: if select_p_d_node_func_name == "random": + return self._random_select_p_d_node + elif select_p_d_node_func_name == "round_robin": self.prefill_node_index = 0 self.decode_node_index = 0 - return self._random_select_p_d_node + return self._round_robin_select_p_d_node elif select_p_d_node_func_name == "memory": return self._memory_select_p_d_node elif select_p_d_node_func_name == "radix": From 072a709474ea449567cbdad03661cc571e5f784b Mon Sep 17 00:00:00 2001 From: pigKiller <22373017@buaa.edu.cn> Date: Mon, 14 Jul 2025 17:16:23 +0800 Subject: [PATCH 04/33] feat: add memory select --- .../httpserver_for_pd_master/manager.py | 78 ++++++++++++++++--- 1 file changed, 69 insertions(+), 9 deletions(-) diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index 857995353..e951f2ed3 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -65,7 +65,7 @@ def remove_pd(self, pd_info_json): logger.info(f"mode: {pd_client.mode} url: {pd_client.client_ip_port} removed") return - def _get_select_p_d_node_func(self, select_p_d_node_func_name: str) -> Callable[[Union[str, List[int]], SamplingParams, MultimodalParams], Tuple[PD_Client_Obj, PD_Client_Obj]]: + def _get_select_p_d_node_func(self, select_p_d_node_func_name: str): if select_p_d_node_func_name == "random": return self._random_select_p_d_node elif select_p_d_node_func_name == "round_robin": @@ -79,28 +79,88 @@ def _get_select_p_d_node_func(self, select_p_d_node_func_name: str) -> Callable[ else: raise ValueError(f"invalid select_p_d_node_func_name: {select_p_d_node_func_name}") - def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: + async def select_p_d_node(self, prompt: Union[str, List[int]], 
sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: assert self.select_p_d_node_func is not None, "select_p_d_node_func is not set" - return self.select_p_d_node_func(prompt, sampling_params, multimodal_params) + return await self.select_p_d_node_func(prompt, sampling_params, multimodal_params) - def _random_select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: + async def _random_select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: import random p_node = random.choice(self.prefill_nodes) d_node = random.choice(self.decode_nodes) return p_node, d_node - def _round_robin_select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: + async def _round_robin_select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: p_node = self.prefill_nodes[self.prefill_node_index] d_node = self.decode_nodes[self.decode_node_index] self.prefill_node_index = (self.prefill_node_index + 1) % len(self.prefill_nodes) self.decode_node_index = (self.decode_node_index + 1) % len(self.decode_nodes) return p_node, d_node - def _memory_select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: - pass + # 哪个节点内存占用最小,就选哪个 + async def _memory_select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: + import aiohttp + import asyncio + import random + async def get_node_memory_usage(node: PD_Client_Obj) -> float: + 
"""获取节点的内存使用率""" + try: + node_url = f"http://{node.client_ip_port}/token_load" + timeout = aiohttp.ClientTimeout(total=5.0) # 5秒超时 + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.get(node_url) as response: + if response.status == 200: + data = await response.json() + # 获取当前负载,这表示内存使用情况 + # current_load 值越大表示使用的内存越多 + current_load = data.get("current_load", 0.0) + if isinstance(current_load, list): + # 如果是列表,取第一个值(通常dp=1时会返回单个值的列表) + current_load = current_load[0] if current_load else 0.0 + return float(current_load) + else: + logger.warning(f"Failed to get token_load from {node.client_ip_port}, status: {response.status}") + return 0.0 + except Exception as e: + logger.warning(f"Error getting memory usage from {node.client_ip_port}: {str(e)}") + return 0.0 + + # 并行获取所有prefill节点的内存使用情况 + prefill_tasks = [get_node_memory_usage(node) for node in self.prefill_nodes] + prefill_usages = await asyncio.gather(*prefill_tasks, return_exceptions=True) + + # 并行获取所有decode节点的内存使用情况 + decode_tasks = [get_node_memory_usage(node) for node in self.decode_nodes] + decode_usages = await asyncio.gather(*decode_tasks, return_exceptions=True) + + # 处理异常结果,将异常替换为0.0 + prefill_usages = [usage if not isinstance(usage, Exception) else sys.float_info.max for usage in prefill_usages] + decode_usages = [usage if not isinstance(usage, Exception) else sys.float_info.max for usage in decode_usages] + + # 找到内存使用最少的prefill节点 + min_prefill_usage = sys.float_info.max + min_prefill_index = 0 + for i, usage in enumerate(prefill_usages): + if usage < min_prefill_usage: + min_prefill_usage = usage + min_prefill_index = i + p_node = self.prefill_nodes[min_prefill_index] if min_prefill_usage != sys.float_info.max else random.choice(self.prefill_nodes) + + # 找到内存使用最少的decode节点 + min_decode_usage = sys.float_info.max + min_decode_index = 0 + for i, usage in enumerate(decode_usages): + if usage < min_decode_usage: + min_decode_usage = usage + min_decode_index = i + d_node = 
self.decode_nodes[min_decode_index] if min_decode_usage != sys.float_info.max else random.choice(self.decode_nodes) + + logger.debug(f"Selected prefill node {p_node.client_ip_port} with usage {min_prefill_usage}") + logger.debug(f"Selected decode node {d_node.client_ip_port} with usage {min_decode_usage}") + + return p_node, d_node - def _radix_select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: + async def _radix_select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: pass class HttpServerManagerForPDMaster: @@ -162,7 +222,7 @@ def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwar async def select_p_d_node( self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams ) -> Tuple[PD_Client_Obj, PD_Client_Obj]: - return self.pd_manager.select_p_d_node(prompt, sampling_params, multimodal_params) + return await self.pd_manager.select_p_d_node(prompt, sampling_params, multimodal_params) async def generate( self, From d4a9dd737bd2fef2abe793155aeee7aa215bfb15 Mon Sep 17 00:00:00 2001 From: pigKiller <22373017@buaa.edu.cn> Date: Mon, 14 Jul 2025 17:34:07 +0800 Subject: [PATCH 05/33] feat: change memory token_load to 1s search --- .../httpserver_for_pd_master/manager.py | 139 ++++++++++-------- .../pd_selector/__init__.py | 0 .../pd_selector/pd_selector.py | 0 3 files changed, 80 insertions(+), 59 deletions(-) create mode 100644 lightllm/server/httpserver_for_pd_master/pd_selector/__init__.py create mode 100644 lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index e951f2ed3..e98bc43cb 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ 
b/lightllm/server/httpserver_for_pd_master/manager.py @@ -60,6 +60,11 @@ def remove_pd(self, pd_info_json): del self.url_to_pd_nodes[pd_client.client_ip_port] except: pass + + # 从内存缓存中删除该节点的数据 + if hasattr(self, 'node_memory_cache'): + self.node_memory_cache.pop(pd_client.client_ip_port, None) + self.prefill_nodes = [e for e in self.prefill_nodes if e.client_ip_port != pd_client.client_ip_port] self.decode_nodes = [e for e in self.decode_nodes if e.client_ip_port != pd_client.client_ip_port] logger.info(f"mode: {pd_client.mode} url: {pd_client.client_ip_port} removed") @@ -73,12 +78,73 @@ def _get_select_p_d_node_func(self, select_p_d_node_func_name: str): self.decode_node_index = 0 return self._round_robin_select_p_d_node elif select_p_d_node_func_name == "memory": + # 内存使用情况缓存 + self.node_memory_cache: Dict[str, float] = {} # 节点IP:PORT -> 内存使用率 + self.memory_monitor_task = None + self._start_memory_monitor() return self._memory_select_p_d_node elif select_p_d_node_func_name == "radix": return self._radix_select_p_d_node else: raise ValueError(f"invalid select_p_d_node_func_name: {select_p_d_node_func_name}") + def _start_memory_monitor(self): + """启动内存监控后台任务""" + import asyncio + + if self.memory_monitor_task is None or self.memory_monitor_task.done(): + try: + loop = asyncio.get_event_loop() + self.memory_monitor_task = loop.create_task(self._memory_monitor_loop()) + logger.info("Started memory monitoring task") + except RuntimeError: + logger.warning("No event loop running, memory monitoring will start later") + + async def _memory_monitor_loop(self): + """后台内存监控循环,每秒更新一次所有节点的内存使用情况""" + import aiohttp + import asyncio + + async def get_node_memory_usage(node: PD_Client_Obj) -> tuple: + try: + node_url = f"http://{node.client_ip_port}/token_load" + timeout = aiohttp.ClientTimeout(total=3.0) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.get(node_url) as response: + if response.status == 200: + data = await response.json() + 
current_load = data.get("current_load", 0.0) + if isinstance(current_load, list): + current_load = current_load[0] if current_load else 0.0 + return node.client_ip_port, float(current_load) + else: + logger.warning(f"Failed to get token_load from {node.client_ip_port}, status: {response.status}") + return node.client_ip_port, float('inf') + except Exception as e: + logger.warning(f"Error getting memory usage from {node.client_ip_port}: {str(e)}") + return node.client_ip_port, float('inf') + + while True: + try: + all_nodes = self.prefill_nodes + self.decode_nodes + tasks = [get_node_memory_usage(node) for node in all_nodes] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # 更新缓存 + for result in results: + if isinstance(result, tuple) and len(result) == 2: + node_key, usage = result + self.node_memory_cache[node_key] = usage + + logger.debug(f"Updated memory cache: {self.node_memory_cache}") + + # 等待1秒后再次检查 + await asyncio.sleep(1.0) + + except Exception as e: + logger.error(f"Error in memory monitor loop: {str(e)}") + await asyncio.sleep(1.0) + async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: assert self.select_p_d_node_func is not None, "select_p_d_node_func is not set" return await self.select_p_d_node_func(prompt, sampling_params, multimodal_params) @@ -97,66 +163,17 @@ async def _round_robin_select_p_d_node(self, prompt: Union[str, List[int]], samp self.decode_node_index = (self.decode_node_index + 1) % len(self.decode_nodes) return p_node, d_node - # 哪个节点内存占用最小,就选哪个 + # 哪个节点内存占用最少,就选哪个 async def _memory_select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: - import aiohttp - import asyncio + # 获取prefill节点的内存使用情况 + prefill_usages = [self.node_memory_cache.get(node.client_ip_port, float('inf')) for node in 
self.prefill_nodes] + decode_usages = [self.node_memory_cache.get(node.client_ip_port, float('inf')) for node in self.decode_nodes] + import random - async def get_node_memory_usage(node: PD_Client_Obj) -> float: - """获取节点的内存使用率""" - try: - node_url = f"http://{node.client_ip_port}/token_load" - timeout = aiohttp.ClientTimeout(total=5.0) # 5秒超时 - async with aiohttp.ClientSession(timeout=timeout) as session: - async with session.get(node_url) as response: - if response.status == 200: - data = await response.json() - # 获取当前负载,这表示内存使用情况 - # current_load 值越大表示使用的内存越多 - current_load = data.get("current_load", 0.0) - if isinstance(current_load, list): - # 如果是列表,取第一个值(通常dp=1时会返回单个值的列表) - current_load = current_load[0] if current_load else 0.0 - return float(current_load) - else: - logger.warning(f"Failed to get token_load from {node.client_ip_port}, status: {response.status}") - return 0.0 - except Exception as e: - logger.warning(f"Error getting memory usage from {node.client_ip_port}: {str(e)}") - return 0.0 - - # 并行获取所有prefill节点的内存使用情况 - prefill_tasks = [get_node_memory_usage(node) for node in self.prefill_nodes] - prefill_usages = await asyncio.gather(*prefill_tasks, return_exceptions=True) - - # 并行获取所有decode节点的内存使用情况 - decode_tasks = [get_node_memory_usage(node) for node in self.decode_nodes] - decode_usages = await asyncio.gather(*decode_tasks, return_exceptions=True) - - # 处理异常结果,将异常替换为0.0 - prefill_usages = [usage if not isinstance(usage, Exception) else sys.float_info.max for usage in prefill_usages] - decode_usages = [usage if not isinstance(usage, Exception) else sys.float_info.max for usage in decode_usages] - - # 找到内存使用最少的prefill节点 - min_prefill_usage = sys.float_info.max - min_prefill_index = 0 - for i, usage in enumerate(prefill_usages): - if usage < min_prefill_usage: - min_prefill_usage = usage - min_prefill_index = i - p_node = self.prefill_nodes[min_prefill_index] if min_prefill_usage != sys.float_info.max else random.choice(self.prefill_nodes) - - # 
找到内存使用最少的decode节点 - min_decode_usage = sys.float_info.max - min_decode_index = 0 - for i, usage in enumerate(decode_usages): - if usage < min_decode_usage: - min_decode_usage = usage - min_decode_index = i - d_node = self.decode_nodes[min_decode_index] if min_decode_usage != sys.float_info.max else random.choice(self.decode_nodes) - - logger.debug(f"Selected prefill node {p_node.client_ip_port} with usage {min_prefill_usage}") - logger.debug(f"Selected decode node {d_node.client_ip_port} with usage {min_decode_usage}") + min_prefill_usage = min(prefill_usages) + min_decode_usage = min(decode_usages) + p_node = self.prefill_nodes[prefill_usages.index(min_prefill_usage)] if min_prefill_usage != float('inf') else random.choice(self.prefill_nodes) + d_node = self.decode_nodes[decode_usages.index(min_decode_usage)] if min_decode_usage != float('inf') else random.choice(self.decode_nodes) return p_node, d_node @@ -374,7 +391,7 @@ async def _wait_to_token_package( request: Request, ): out_token_counter = 0 - first_token_cost_ms = sys.float_info.max + first_token_cost_ms = float('inf') group_request_id = sampling_params.group_request_id unfinished_count = sampling_params.best_of is_first_token = True @@ -465,6 +482,10 @@ async def handle_loop(self): self.infos_queues = AsyncQueue() asyncio.create_task(self.timer_log()) + # 如果使用memory策略,确保监控任务已启动 + if hasattr(self.pd_manager, 'memory_monitor_task') and self.args.select_p_d_node_func == "memory": + self.pd_manager._start_memory_monitor() + use_config_server = self.args.config_server_host and self.args.config_server_port if use_config_server: diff --git a/lightllm/server/httpserver_for_pd_master/pd_selector/__init__.py b/lightllm/server/httpserver_for_pd_master/pd_selector/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py b/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py new file mode 100644 index 
000000000..e69de29bb From 28ed0dd16755e403c82e819b5a810696c3f23cc1 Mon Sep 17 00:00:00 2001 From: pigKiller <22373017@buaa.edu.cn> Date: Mon, 14 Jul 2025 17:46:06 +0800 Subject: [PATCH 06/33] feat: change select func to select class --- .../httpserver_for_pd_master/manager.py | 138 +++--------------- .../pd_selector/__init__.py | 29 ++++ .../pd_selector/pd_selector.py | 127 ++++++++++++++++ 3 files changed, 180 insertions(+), 114 deletions(-) diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index e98bc43cb..ed67be370 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -25,6 +25,11 @@ from lightllm.utils.statics_utils import MovingAverage from lightllm.server.httpserver.manager import AsyncQueue from lightllm.utils.error_utils import ServerBusyError +from .pd_selector import ( + create_selector, + PDSelector, + MemorySelector, +) logger = init_logger(__name__) @@ -35,9 +40,12 @@ def __init__(self, args): self.prefill_nodes: List[PD_Client_Obj] = [] self.decode_nodes: List[PD_Client_Obj] = [] self.url_to_pd_nodes: Dict[str, PD_Client_Obj] = {} - self.select_p_d_node_func = self._get_select_p_d_node_func(args.select_p_d_node_func) + self.selector: PDSelector = self._create_selector(args.select_p_d_node_func) return + def _create_selector(self, select_p_d_node_func_name: str) -> PDSelector: + return create_selector(select_p_d_node_func_name, self.prefill_nodes, self.decode_nodes) + def register_pd(self, pd_info_json, websocket): pd_client = PD_Client_Obj(**pd_info_json) pd_client.websocket = websocket @@ -51,6 +59,10 @@ def register_pd(self, pd_info_json, websocket): else: assert False + # 更新选择器的节点列表 + self.selector.prefill_nodes = self.prefill_nodes + self.selector.decode_nodes = self.decode_nodes + logger.info(f"mode: {pd_client.mode} url: {pd_client.client_ip_port} registed") return @@ -61,124 +73,22 @@ def remove_pd(self, 
pd_info_json): except: pass - # 从内存缓存中删除该节点的数据 - if hasattr(self, 'node_memory_cache'): - self.node_memory_cache.pop(pd_client.client_ip_port, None) + # 从内存缓存中删除该节点的数据(如果是MemorySelector) + if isinstance(self.selector, MemorySelector): + self.selector.remove_node_cache(pd_client.client_ip_port) self.prefill_nodes = [e for e in self.prefill_nodes if e.client_ip_port != pd_client.client_ip_port] self.decode_nodes = [e for e in self.decode_nodes if e.client_ip_port != pd_client.client_ip_port] - logger.info(f"mode: {pd_client.mode} url: {pd_client.client_ip_port} removed") - return - - def _get_select_p_d_node_func(self, select_p_d_node_func_name: str): - if select_p_d_node_func_name == "random": - return self._random_select_p_d_node - elif select_p_d_node_func_name == "round_robin": - self.prefill_node_index = 0 - self.decode_node_index = 0 - return self._round_robin_select_p_d_node - elif select_p_d_node_func_name == "memory": - # 内存使用情况缓存 - self.node_memory_cache: Dict[str, float] = {} # 节点IP:PORT -> 内存使用率 - self.memory_monitor_task = None - self._start_memory_monitor() - return self._memory_select_p_d_node - elif select_p_d_node_func_name == "radix": - return self._radix_select_p_d_node - else: - raise ValueError(f"invalid select_p_d_node_func_name: {select_p_d_node_func_name}") - - def _start_memory_monitor(self): - """启动内存监控后台任务""" - import asyncio - - if self.memory_monitor_task is None or self.memory_monitor_task.done(): - try: - loop = asyncio.get_event_loop() - self.memory_monitor_task = loop.create_task(self._memory_monitor_loop()) - logger.info("Started memory monitoring task") - except RuntimeError: - logger.warning("No event loop running, memory monitoring will start later") - - async def _memory_monitor_loop(self): - """后台内存监控循环,每秒更新一次所有节点的内存使用情况""" - import aiohttp - import asyncio - - async def get_node_memory_usage(node: PD_Client_Obj) -> tuple: - try: - node_url = f"http://{node.client_ip_port}/token_load" - timeout = aiohttp.ClientTimeout(total=3.0) 
- async with aiohttp.ClientSession(timeout=timeout) as session: - async with session.get(node_url) as response: - if response.status == 200: - data = await response.json() - current_load = data.get("current_load", 0.0) - if isinstance(current_load, list): - current_load = current_load[0] if current_load else 0.0 - return node.client_ip_port, float(current_load) - else: - logger.warning(f"Failed to get token_load from {node.client_ip_port}, status: {response.status}") - return node.client_ip_port, float('inf') - except Exception as e: - logger.warning(f"Error getting memory usage from {node.client_ip_port}: {str(e)}") - return node.client_ip_port, float('inf') - - while True: - try: - all_nodes = self.prefill_nodes + self.decode_nodes - tasks = [get_node_memory_usage(node) for node in all_nodes] - results = await asyncio.gather(*tasks, return_exceptions=True) - - # 更新缓存 - for result in results: - if isinstance(result, tuple) and len(result) == 2: - node_key, usage = result - self.node_memory_cache[node_key] = usage - - logger.debug(f"Updated memory cache: {self.node_memory_cache}") - # 等待1秒后再次检查 - await asyncio.sleep(1.0) + # 更新选择器的节点列表 + self.selector.prefill_nodes = self.prefill_nodes + self.selector.decode_nodes = self.decode_nodes - except Exception as e: - logger.error(f"Error in memory monitor loop: {str(e)}") - await asyncio.sleep(1.0) + logger.info(f"mode: {pd_client.mode} url: {pd_client.client_ip_port} removed") + return async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: - assert self.select_p_d_node_func is not None, "select_p_d_node_func is not set" - return await self.select_p_d_node_func(prompt, sampling_params, multimodal_params) - - async def _random_select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: - import random - - 
p_node = random.choice(self.prefill_nodes) - d_node = random.choice(self.decode_nodes) - return p_node, d_node - - async def _round_robin_select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: - p_node = self.prefill_nodes[self.prefill_node_index] - d_node = self.decode_nodes[self.decode_node_index] - self.prefill_node_index = (self.prefill_node_index + 1) % len(self.prefill_nodes) - self.decode_node_index = (self.decode_node_index + 1) % len(self.decode_nodes) - return p_node, d_node - - # 哪个节点内存占用最少,就选哪个 - async def _memory_select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: - # 获取prefill节点的内存使用情况 - prefill_usages = [self.node_memory_cache.get(node.client_ip_port, float('inf')) for node in self.prefill_nodes] - decode_usages = [self.node_memory_cache.get(node.client_ip_port, float('inf')) for node in self.decode_nodes] - - import random - min_prefill_usage = min(prefill_usages) - min_decode_usage = min(decode_usages) - p_node = self.prefill_nodes[prefill_usages.index(min_prefill_usage)] if min_prefill_usage != float('inf') else random.choice(self.prefill_nodes) - d_node = self.decode_nodes[decode_usages.index(min_decode_usage)] if min_decode_usage != float('inf') else random.choice(self.decode_nodes) - - return p_node, d_node - - async def _radix_select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: - pass + return await self.selector.select_p_d_node(prompt, sampling_params, multimodal_params) class HttpServerManagerForPDMaster: def __init__( @@ -483,8 +393,8 @@ async def handle_loop(self): asyncio.create_task(self.timer_log()) # 如果使用memory策略,确保监控任务已启动 - if hasattr(self.pd_manager, 'memory_monitor_task') and self.args.select_p_d_node_func == 
"memory": - self.pd_manager._start_memory_monitor() + if isinstance(self.pd_manager.selector, MemorySelector): + self.pd_manager.selector._start_memory_monitor() use_config_server = self.args.config_server_host and self.args.config_server_port diff --git a/lightllm/server/httpserver_for_pd_master/pd_selector/__init__.py b/lightllm/server/httpserver_for_pd_master/pd_selector/__init__.py index e69de29bb..6d22490ce 100644 --- a/lightllm/server/httpserver_for_pd_master/pd_selector/__init__.py +++ b/lightllm/server/httpserver_for_pd_master/pd_selector/__init__.py @@ -0,0 +1,29 @@ +from typing import List +from lightllm.server.httpserver_for_pd_master.manager import PD_Client_Obj +from .pd_selector import ( + PDSelector, + RandomSelector, + RoundRobinSelector, + MemorySelector, + RadixSelector +) + +__all__ = [ + "PDSelector", + "RandomSelector", + "RoundRobinSelector", + "MemorySelector", + "RadixSelector" +] + +def create_selector(selector_type: str, prefill_nodes: List[PD_Client_Obj], decode_nodes: List[PD_Client_Obj]) -> PDSelector: + if selector_type == "random": + return RandomSelector(prefill_nodes, decode_nodes) + elif selector_type == "round_robin": + return RoundRobinSelector(prefill_nodes, decode_nodes) + elif selector_type == "memory": + return MemorySelector(prefill_nodes, decode_nodes) + elif selector_type == "radix": + return RadixSelector(prefill_nodes, decode_nodes) + else: + raise ValueError(f"Invalid selector type: {selector_type}") \ No newline at end of file diff --git a/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py b/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py index e69de29bb..48d08bfba 100644 --- a/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py +++ b/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py @@ -0,0 +1,127 @@ +from typing import Union, List, Tuple +from lightllm.server.pd_io_struct import PD_Client_Obj +from lightllm.server.core.objs import SamplingParams 
+from lightllm.server.multimodal_params import MultimodalParams + + +class PDSelector: + def __init__(self, prefill_nodes, decode_nodes): + self.prefill_nodes = prefill_nodes + self.decode_nodes = decode_nodes + + async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: + raise NotImplementedError("Subclass must implement this method") + + +class RandomSelector(PDSelector): + """随机选择器""" + + async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: + import random + + p_node = random.choice(self.prefill_nodes) + d_node = random.choice(self.decode_nodes) + return p_node, d_node + + +class RoundRobinSelector(PDSelector): + """轮询选择器""" + + def __init__(self, prefill_nodes, decode_nodes): + super().__init__(prefill_nodes, decode_nodes) + self.prefill_node_index = 0 + self.decode_node_index = 0 + + async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: + p_node = self.prefill_nodes[self.prefill_node_index] + d_node = self.decode_nodes[self.decode_node_index] + self.prefill_node_index = (self.prefill_node_index + 1) % len(self.prefill_nodes) + self.decode_node_index = (self.decode_node_index + 1) % len(self.decode_nodes) + return p_node, d_node + + +class MemorySelector(PDSelector): + """基于内存使用情况的选择器""" + + def __init__(self, prefill_nodes, decode_nodes): + super().__init__(prefill_nodes, decode_nodes) + # 内存使用情况缓存 + self.node_memory_cache = {} # 节点IP:PORT -> 内存使用率 + self.memory_monitor_task = None + self._start_memory_monitor() + + def _start_memory_monitor(self): + """启动内存监控后台任务""" + import asyncio + + if self.memory_monitor_task is None or self.memory_monitor_task.done(): + try: + loop = asyncio.get_event_loop() + 
self.memory_monitor_task = loop.create_task(self._memory_monitor_loop()) + from lightllm.utils.log_utils import init_logger + logger = init_logger(__name__) + logger.info("Started memory monitoring task") + except RuntimeError: + from lightllm.utils.log_utils import init_logger + logger = init_logger(__name__) + logger.warning("No event loop running, memory monitoring will start later") + + async def _memory_monitor_loop(self): + """后台内存监控循环,每秒更新一次所有节点的内存使用情况""" + import aiohttp + import asyncio + from lightllm.utils.log_utils import init_logger + logger = init_logger(__name__) + + async def get_node_memory_usage(node: PD_Client_Obj) -> tuple: + node_url = f"http://{node.client_ip_port}/token_load" + timeout = aiohttp.ClientTimeout(total=3.0) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.get(node_url) as response: + if response.status == 200: + data = await response.json() + current_load = data.get("current_load", 0.0) + if isinstance(current_load, list): + current_load = current_load[0] if current_load else 0.0 + return node.client_ip_port, float(current_load) + else: + logger.warning(f"Failed to get token_load from {node.client_ip_port}, status: {response.status}") + return node.client_ip_port, float('inf') + + while True: + all_nodes = self.prefill_nodes + self.decode_nodes + tasks = [get_node_memory_usage(node) for node in all_nodes] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # 更新缓存 + for result in results: + if isinstance(result, tuple) and len(result) == 2: + node_key, usage = result + self.node_memory_cache[node_key] = usage + + logger.debug(f"Updated memory cache: {self.node_memory_cache}") + + # 等待1秒后再次检查 + await asyncio.sleep(1.0) + + async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: + # 获取prefill节点的内存使用情况 + prefill_usages = 
[self.node_memory_cache.get(node.client_ip_port, float('inf')) for node in self.prefill_nodes] + decode_usages = [self.node_memory_cache.get(node.client_ip_port, float('inf')) for node in self.decode_nodes] + + import random + min_prefill_usage = min(prefill_usages) + min_decode_usage = min(decode_usages) + p_node = self.prefill_nodes[prefill_usages.index(min_prefill_usage)] if min_prefill_usage != float('inf') else random.choice(self.prefill_nodes) + d_node = self.decode_nodes[decode_usages.index(min_decode_usage)] if min_decode_usage != float('inf') else random.choice(self.decode_nodes) + + return p_node, d_node + + def remove_node_cache(self, node_ip_port: str): + """删除节点的内存缓存""" + self.node_memory_cache.pop(node_ip_port, None) + + +class RadixSelector(PDSelector): + async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: + pass \ No newline at end of file From 21eb993601c2644ae651aebfb2b1a3fdfc2d64b7 Mon Sep 17 00:00:00 2001 From: pigKiller <22373017@buaa.edu.cn> Date: Mon, 14 Jul 2025 17:50:27 +0800 Subject: [PATCH 07/33] fix: typing --- lightllm/server/httpserver_for_pd_master/manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index ed67be370..e9277a284 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -12,7 +12,7 @@ import pickle asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) -from typing import Union, List, Tuple, Dict, Callable +from typing import Union, List, Tuple, Dict from lightllm.server.core.objs import FinishStatus from ..pd_io_struct import PD_Client_Obj, UpKVStatus, ObjType from lightllm.server.core.objs import SamplingParams From 650bcd76982ec5cce0475afb4e59430c72b43edc Mon Sep 17 00:00:00 2001 From: pigKiller 
<22373017@buaa.edu.cn> Date: Mon, 14 Jul 2025 18:25:38 +0800 Subject: [PATCH 08/33] fix: remove unnecessary func --- lightllm/server/httpserver_for_pd_master/manager.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index e9277a284..8bcf8fba5 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -40,12 +40,9 @@ def __init__(self, args): self.prefill_nodes: List[PD_Client_Obj] = [] self.decode_nodes: List[PD_Client_Obj] = [] self.url_to_pd_nodes: Dict[str, PD_Client_Obj] = {} - self.selector: PDSelector = self._create_selector(args.select_p_d_node_func) + self.selector: PDSelector = create_selector(args.select_p_d_node_func_name, self.prefill_nodes, self.decode_nodes) return - def _create_selector(self, select_p_d_node_func_name: str) -> PDSelector: - return create_selector(select_p_d_node_func_name, self.prefill_nodes, self.decode_nodes) - def register_pd(self, pd_info_json, websocket): pd_client = PD_Client_Obj(**pd_info_json) pd_client.websocket = websocket From ed73734127652be0907314eb6cba1c4592ea7a9c Mon Sep 17 00:00:00 2001 From: pigKiller <22373017@buaa.edu.cn> Date: Mon, 14 Jul 2025 18:27:58 +0800 Subject: [PATCH 09/33] feat: mroe clear --- .../server/httpserver_for_pd_master/manager.py | 16 ++++++---------- .../pd_selector/pd_selector.py | 4 ++++ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index 8bcf8fba5..f880f7d04 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -43,7 +43,7 @@ def __init__(self, args): self.selector: PDSelector = create_selector(args.select_p_d_node_func_name, self.prefill_nodes, self.decode_nodes) return - def register_pd(self, pd_info_json, 
websocket): + async def register_pd(self, pd_info_json, websocket): pd_client = PD_Client_Obj(**pd_info_json) pd_client.websocket = websocket self.url_to_pd_nodes[pd_client.client_ip_port] = pd_client @@ -56,14 +56,12 @@ def register_pd(self, pd_info_json, websocket): else: assert False - # 更新选择器的节点列表 - self.selector.prefill_nodes = self.prefill_nodes - self.selector.decode_nodes = self.decode_nodes + await self.selector.update_nodes(self.prefill_nodes, self.decode_nodes) logger.info(f"mode: {pd_client.mode} url: {pd_client.client_ip_port} registed") return - def remove_pd(self, pd_info_json): + async def remove_pd(self, pd_info_json): pd_client = PD_Client_Obj(**pd_info_json) try: del self.url_to_pd_nodes[pd_client.client_ip_port] @@ -77,9 +75,7 @@ def remove_pd(self, pd_info_json): self.prefill_nodes = [e for e in self.prefill_nodes if e.client_ip_port != pd_client.client_ip_port] self.decode_nodes = [e for e in self.decode_nodes if e.client_ip_port != pd_client.client_ip_port] - # 更新选择器的节点列表 - self.selector.prefill_nodes = self.prefill_nodes - self.selector.decode_nodes = self.decode_nodes + await self.selector.update_nodes(self.prefill_nodes, self.decode_nodes) logger.info(f"mode: {pd_client.mode} url: {pd_client.client_ip_port} removed") return @@ -109,11 +105,11 @@ def __init__( return async def register_pd(self, pd_info_json, websocket): - self.pd_manager.register_pd(pd_info_json, websocket) + await self.pd_manager.register_pd(pd_info_json, websocket) return async def remove_pd(self, pd_info_json): - self.pd_manager.remove_pd(pd_info_json) + await self.pd_manager.remove_pd(pd_info_json) return async def update_req_status(self, upkv_status: UpKVStatus): diff --git a/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py b/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py index 48d08bfba..1f101ff99 100644 --- a/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py +++ 
b/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py @@ -9,6 +9,10 @@ def __init__(self, prefill_nodes, decode_nodes): self.prefill_nodes = prefill_nodes self.decode_nodes = decode_nodes + async def update_nodes(self, prefill_nodes, decode_nodes): + self.prefill_nodes = prefill_nodes + self.decode_nodes = decode_nodes + async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: raise NotImplementedError("Subclass must implement this method") From 5e193e13ca6cd6c847e44ae6a9140de5f371a8d5 Mon Sep 17 00:00:00 2001 From: pigKiller <22373017@buaa.edu.cn> Date: Tue, 15 Jul 2025 14:49:01 +0800 Subject: [PATCH 10/33] feat: add load info in pd_master --- lightllm/server/httpserver/pd_loop.py | 20 ++++++- .../httpserver_for_pd_master/manager.py | 52 +++++++++++++++++-- 2 files changed, 66 insertions(+), 6 deletions(-) diff --git a/lightllm/server/httpserver/pd_loop.py b/lightllm/server/httpserver/pd_loop.py index 10a4a8ec5..565e4e633 100644 --- a/lightllm/server/httpserver/pd_loop.py +++ b/lightllm/server/httpserver/pd_loop.py @@ -181,4 +181,22 @@ async def _up_tokens_to_pd_master(forwarding_queue: AsyncQueue, websocket): while True: handle_list = await forwarding_queue.wait_to_get_all_data() if handle_list: - await websocket.send(pickle.dumps((ObjType.TOKEN_PACKS, handle_list))) + # 获取节点负载信息 + from lightllm.server.api_http import g_objs + if g_objs.shared_token_load is not None: + current_load = [ + float(g_objs.shared_token_load.get_current_load(dp_index)) for dp_index in range(g_objs.args.dp) + ] + if g_objs.args.dp == 1: + current_load = current_load[0] + load_info = { + "current_load": current_load, + "client_ip_port": f"{g_objs.httpserver_manager.host_ip}:{g_objs.args.port}" + } + else: + load_info = { + "current_load": 0.0, + "client_ip_port": f"{g_objs.httpserver_manager.host_ip}:{g_objs.args.port}" + } + + await 
websocket.send(pickle.dumps((ObjType.TOKEN_PACKS, handle_list, load_info))) diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index f880f7d04..afd54e95c 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -37,16 +37,23 @@ class PDManager: def __init__(self, args): self.args = args + self.node_info: Dict[str, dict] = {} self.prefill_nodes: List[PD_Client_Obj] = [] self.decode_nodes: List[PD_Client_Obj] = [] - self.url_to_pd_nodes: Dict[str, PD_Client_Obj] = {} self.selector: PDSelector = create_selector(args.select_p_d_node_func_name, self.prefill_nodes, self.decode_nodes) return async def register_pd(self, pd_info_json, websocket): pd_client = PD_Client_Obj(**pd_info_json) pd_client.websocket = websocket - self.url_to_pd_nodes[pd_client.client_ip_port] = pd_client + self.node_info[pd_client.client_ip_port] = { + "node_id": pd_info_json["pd_node_id"], + "client_ip_port": pd_info_json["client_ip_port"], + "mode": pd_info_json["mode"], + "node": pd_client, + "load": 0.0, + } + if pd_client.mode == "prefill": self.prefill_nodes = [e for e in self.prefill_nodes if e.client_ip_port != pd_client.client_ip_port] self.prefill_nodes.append(pd_client) @@ -54,7 +61,7 @@ async def register_pd(self, pd_info_json, websocket): self.decode_nodes = [e for e in self.decode_nodes if e.client_ip_port != pd_client.client_ip_port] self.decode_nodes.append(pd_client) else: - assert False + assert False, f"mode must in ['prefill', 'decode'], but get {pd_client.mode}" await self.selector.update_nodes(self.prefill_nodes, self.decode_nodes) @@ -64,10 +71,13 @@ async def register_pd(self, pd_info_json, websocket): async def remove_pd(self, pd_info_json): pd_client = PD_Client_Obj(**pd_info_json) try: - del self.url_to_pd_nodes[pd_client.client_ip_port] + del self.node_info[pd_client.client_ip_port] except: pass + if pd_client.client_ip_port in 
self.node_info: + del self.node_info[pd_client.client_ip_port] + # 从内存缓存中删除该节点的数据(如果是MemorySelector) if isinstance(self.selector, MemorySelector): self.selector.remove_node_cache(pd_client.client_ip_port) @@ -80,6 +90,21 @@ async def remove_pd(self, pd_info_json): logger.info(f"mode: {pd_client.mode} url: {pd_client.client_ip_port} removed") return + def update_node_load_info(self, load_info: dict): + """更新节点负载信息""" + if load_info is not None and "client_ip_port" in load_info: + ip_port = load_info["client_ip_port"] + self.node_info[ip_port]["load"] = load_info["current_load"] + logger.debug(f"Updated node load info for {ip_port}: {load_info['current_load']}") + + def get_node_load_info(self): + """获取所有节点的负载信息""" + return {k: v.get("load", float("inf")) for k, v in self.node_info.items()} + + def get_node_load_info_by_node(self, client_ip_port: str): + """获取指定节点的负载信息""" + return self.node_info.get(client_ip_port, None).get("load", float("inf")) + async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: return await self.selector.select_p_d_node(prompt, sampling_params, multimodal_params) @@ -381,6 +406,14 @@ async def timer_log(self): async def put_to_handle_queue(self, obj): await self.infos_queues.put(obj) + def get_node_load_info(self): + """获取所有节点的负载信息""" + return self.pd_manager.get_node_load_info() + + def get_node_load_info_by_node(self, client_ip_port: str): + """获取指定节点的负载信息""" + return self.pd_manager.get_node_load_info_by_node(client_ip_port) + async def handle_loop(self): self.infos_queues = AsyncQueue() asyncio.create_task(self.timer_log()) @@ -402,7 +435,16 @@ async def handle_loop(self): try: for obj in objs: if obj[0] == ObjType.TOKEN_PACKS: - for sub_req_id, text, metadata, finish_status in obj[1]: + # 检查是否包含节点信息 + if len(obj) >= 3: + handle_list, load_info = obj[1], obj[2] + # 更新节点负载信息 + 
self.pd_manager.update_node_load_info(load_info) + else: + # 兼容旧格式 + handle_list = obj[1] + + for sub_req_id, text, metadata, finish_status in handle_list: finish_status: FinishStatus = finish_status group_req_id = convert_sub_id_to_group_id(sub_req_id) try: From 5748672b1b488277b5c3cee821c303c4eb2f10a4 Mon Sep 17 00:00:00 2001 From: pigKiller <22373017@buaa.edu.cn> Date: Tue, 15 Jul 2025 14:57:07 +0800 Subject: [PATCH 11/33] feat: change pd mem selector to pd_master load info --- .../httpserver_for_pd_master/manager.py | 32 ++++-- .../pd_selector/pd_selector.py | 104 ++++-------------- 2 files changed, 44 insertions(+), 92 deletions(-) diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index afd54e95c..a4922978f 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -40,7 +40,7 @@ def __init__(self, args): self.node_info: Dict[str, dict] = {} self.prefill_nodes: List[PD_Client_Obj] = [] self.decode_nodes: List[PD_Client_Obj] = [] - self.selector: PDSelector = create_selector(args.select_p_d_node_func_name, self.prefill_nodes, self.decode_nodes) + self.selector: PDSelector = create_selector(args.select_p_d_node_func_name, self.prefill_nodes, self.decode_nodes, self) return async def register_pd(self, pd_info_json, websocket): @@ -78,10 +78,6 @@ async def remove_pd(self, pd_info_json): if pd_client.client_ip_port in self.node_info: del self.node_info[pd_client.client_ip_port] - # 从内存缓存中删除该节点的数据(如果是MemorySelector) - if isinstance(self.selector, MemorySelector): - self.selector.remove_node_cache(pd_client.client_ip_port) - self.prefill_nodes = [e for e in self.prefill_nodes if e.client_ip_port != pd_client.client_ip_port] self.decode_nodes = [e for e in self.decode_nodes if e.client_ip_port != pd_client.client_ip_port] @@ -94,8 +90,20 @@ def update_node_load_info(self, load_info: dict): """更新节点负载信息""" if load_info is not None and 
"client_ip_port" in load_info: ip_port = load_info["client_ip_port"] - self.node_info[ip_port]["load"] = load_info["current_load"] - logger.debug(f"Updated node load info for {ip_port}: {load_info['current_load']}") + if ip_port in self.node_info: + current_load = load_info["current_load"] + if isinstance(current_load, list): + if len(current_load) > 0: + load_value = float(current_load[0]) + else: + load_value = 0.0 + else: + load_value = float(current_load) + + self.node_info[ip_port]["load"] = load_value + logger.debug(f"Updated node load info for {ip_port}: {load_value}") + else: + logger.warning(f"Received load info for unknown node: {ip_port}") def get_node_load_info(self): """获取所有节点的负载信息""" @@ -103,7 +111,11 @@ def get_node_load_info(self): def get_node_load_info_by_node(self, client_ip_port: str): """获取指定节点的负载信息""" - return self.node_info.get(client_ip_port, None).get("load", float("inf")) + node_info = self.node_info.get(client_ip_port, None) + if node_info is not None: + return node_info.get("load", float("inf")) + else: + return float("inf") async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: return await self.selector.select_p_d_node(prompt, sampling_params, multimodal_params) @@ -418,10 +430,6 @@ async def handle_loop(self): self.infos_queues = AsyncQueue() asyncio.create_task(self.timer_log()) - # 如果使用memory策略,确保监控任务已启动 - if isinstance(self.pd_manager.selector, MemorySelector): - self.pd_manager.selector._start_memory_monitor() - use_config_server = self.args.config_server_host and self.args.config_server_port if use_config_server: diff --git a/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py b/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py index 1f101ff99..23d9a2b0a 100644 --- a/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py +++ 
b/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py @@ -2,12 +2,14 @@ from lightllm.server.pd_io_struct import PD_Client_Obj from lightllm.server.core.objs import SamplingParams from lightllm.server.multimodal_params import MultimodalParams +from lightllm.server.httpserver_for_pd_master.manager import PDManager class PDSelector: - def __init__(self, prefill_nodes, decode_nodes): - self.prefill_nodes = prefill_nodes - self.decode_nodes = decode_nodes + def __init__(self, prefill_nodes: List[PD_Client_Obj], decode_nodes: List[PD_Client_Obj], pd_manager: PDManager): + self.prefill_nodes: List[PD_Client_Obj] = prefill_nodes + self.decode_nodes: List[PD_Client_Obj] = decode_nodes + self.pd_manager: PDManager = pd_manager async def update_nodes(self, prefill_nodes, decode_nodes): self.prefill_nodes = prefill_nodes @@ -31,10 +33,10 @@ async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: class RoundRobinSelector(PDSelector): """轮询选择器""" - def __init__(self, prefill_nodes, decode_nodes): - super().__init__(prefill_nodes, decode_nodes) - self.prefill_node_index = 0 - self.decode_node_index = 0 + def __init__(self, prefill_nodes: List[PD_Client_Obj], decode_nodes: List[PD_Client_Obj], pd_manager: PDManager): + super().__init__(prefill_nodes, decode_nodes, pd_manager) + self.prefill_node_index: int = 0 + self.decode_node_index: int = 0 async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: p_node = self.prefill_nodes[self.prefill_node_index] @@ -47,85 +49,27 @@ async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: class MemorySelector(PDSelector): """基于内存使用情况的选择器""" - def __init__(self, prefill_nodes, decode_nodes): - super().__init__(prefill_nodes, decode_nodes) - # 内存使用情况缓存 - self.node_memory_cache = {} # 节点IP:PORT -> 内存使用率 - self.memory_monitor_task = None - self._start_memory_monitor() - 
- def _start_memory_monitor(self): - """启动内存监控后台任务""" - import asyncio - - if self.memory_monitor_task is None or self.memory_monitor_task.done(): - try: - loop = asyncio.get_event_loop() - self.memory_monitor_task = loop.create_task(self._memory_monitor_loop()) - from lightllm.utils.log_utils import init_logger - logger = init_logger(__name__) - logger.info("Started memory monitoring task") - except RuntimeError: - from lightllm.utils.log_utils import init_logger - logger = init_logger(__name__) - logger.warning("No event loop running, memory monitoring will start later") - - async def _memory_monitor_loop(self): - """后台内存监控循环,每秒更新一次所有节点的内存使用情况""" - import aiohttp - import asyncio - from lightllm.utils.log_utils import init_logger - logger = init_logger(__name__) - - async def get_node_memory_usage(node: PD_Client_Obj) -> tuple: - node_url = f"http://{node.client_ip_port}/token_load" - timeout = aiohttp.ClientTimeout(total=3.0) - async with aiohttp.ClientSession(timeout=timeout) as session: - async with session.get(node_url) as response: - if response.status == 200: - data = await response.json() - current_load = data.get("current_load", 0.0) - if isinstance(current_load, list): - current_load = current_load[0] if current_load else 0.0 - return node.client_ip_port, float(current_load) - else: - logger.warning(f"Failed to get token_load from {node.client_ip_port}, status: {response.status}") - return node.client_ip_port, float('inf') - - while True: - all_nodes = self.prefill_nodes + self.decode_nodes - tasks = [get_node_memory_usage(node) for node in all_nodes] - results = await asyncio.gather(*tasks, return_exceptions=True) - - # 更新缓存 - for result in results: - if isinstance(result, tuple) and len(result) == 2: - node_key, usage = result - self.node_memory_cache[node_key] = usage - - logger.debug(f"Updated memory cache: {self.node_memory_cache}") - - # 等待1秒后再次检查 - await asyncio.sleep(1.0) - async def select_p_d_node(self, prompt: Union[str, List[int]], 
sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: - # 获取prefill节点的内存使用情况 - prefill_usages = [self.node_memory_cache.get(node.client_ip_port, float('inf')) for node in self.prefill_nodes] - decode_usages = [self.node_memory_cache.get(node.client_ip_port, float('inf')) for node in self.decode_nodes] + if self.pd_manager is None: + # 如果没有 PDManager 引用,回退到随机选择 + import random + p_node = random.choice(self.prefill_nodes) if self.prefill_nodes else None + d_node = random.choice(self.decode_nodes) if self.decode_nodes else None + return p_node, d_node + + # 获取 prefill 节点的内存使用情况 + prefill_usages = [self.pd_manager.get_node_load_info_by_node(node.client_ip_port) for node in self.prefill_nodes] + decode_usages = [self.pd_manager.get_node_load_info_by_node(node.client_ip_port) for node in self.decode_nodes] import random - min_prefill_usage = min(prefill_usages) - min_decode_usage = min(decode_usages) - p_node = self.prefill_nodes[prefill_usages.index(min_prefill_usage)] if min_prefill_usage != float('inf') else random.choice(self.prefill_nodes) - d_node = self.decode_nodes[decode_usages.index(min_decode_usage)] if min_decode_usage != float('inf') else random.choice(self.decode_nodes) + min_prefill_usage = min(prefill_usages) if prefill_usages else float('inf') + min_decode_usage = min(decode_usages) if decode_usages else float('inf') + + p_node = self.prefill_nodes[prefill_usages.index(min_prefill_usage)] if min_prefill_usage != float('inf') and prefill_usages else random.choice(self.prefill_nodes) + d_node = self.decode_nodes[decode_usages.index(min_decode_usage)] if min_decode_usage != float('inf') and decode_usages else random.choice(self.decode_nodes) return p_node, d_node - def remove_node_cache(self, node_ip_port: str): - """删除节点的内存缓存""" - self.node_memory_cache.pop(node_ip_port, None) - - class RadixSelector(PDSelector): async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: 
SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: pass \ No newline at end of file From d9b53ff9ac4046d38331c8b93ca8ed8894b8f9d0 Mon Sep 17 00:00:00 2001 From: pigKiller <22373017@buaa.edu.cn> Date: Tue, 15 Jul 2025 14:58:02 +0800 Subject: [PATCH 12/33] fix: typo --- .../server/httpserver_for_pd_master/pd_selector/pd_selector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py b/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py index 23d9a2b0a..ff2854a5c 100644 --- a/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py +++ b/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py @@ -72,4 +72,4 @@ async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: class RadixSelector(PDSelector): async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: - pass \ No newline at end of file + pass From f7315322fa1b7d2d1f2fc232d4c10160b5cb0200 Mon Sep 17 00:00:00 2001 From: pigKiller <22373017@buaa.edu.cn> Date: Wed, 16 Jul 2025 11:08:37 +0800 Subject: [PATCH 13/33] =?UTF-8?q?feat:=20=E6=9B=B4=E6=96=B0=E4=B8=BA?= =?UTF-8?q?=E5=8F=AA=E6=9C=89=E5=9C=A8=E6=9C=89=E8=AF=B7=E6=B1=82=E7=BB=93?= =?UTF-8?q?=E6=9D=9F=E5=90=8E=E6=89=8D=E6=9B=B4=E6=96=B0=E8=B4=9F=E8=BD=BD?= =?UTF-8?q?=E4=BF=A1=E6=81=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lightllm/server/httpserver/pd_loop.py | 40 ++++++++++--------- .../httpserver_for_pd_master/manager.py | 5 ++- 2 files changed, 26 insertions(+), 19 deletions(-) diff --git a/lightllm/server/httpserver/pd_loop.py b/lightllm/server/httpserver/pd_loop.py index 565e4e633..dd4459d0e 100644 --- a/lightllm/server/httpserver/pd_loop.py +++ b/lightllm/server/httpserver/pd_loop.py @@ -180,23 +180,27 @@ 
async def _pd_process_generate( async def _up_tokens_to_pd_master(forwarding_queue: AsyncQueue, websocket): while True: handle_list = await forwarding_queue.wait_to_get_all_data() + if handle_list: - # 获取节点负载信息 - from lightllm.server.api_http import g_objs - if g_objs.shared_token_load is not None: - current_load = [ - float(g_objs.shared_token_load.get_current_load(dp_index)) for dp_index in range(g_objs.args.dp) - ] - if g_objs.args.dp == 1: - current_load = current_load[0] - load_info = { - "current_load": current_load, - "client_ip_port": f"{g_objs.httpserver_manager.host_ip}:{g_objs.args.port}" - } + has_finished_req = any(finish_status.is_finished() for _, _, _, finish_status in handle_list) + if has_finished_req: + # 获取节点负载信息 + from lightllm.server.api_http import g_objs + if g_objs.shared_token_load is not None: + current_load = [ + float(g_objs.shared_token_load.get_current_load(dp_index)) for dp_index in range(g_objs.args.dp) + ] + if g_objs.args.dp == 1: + current_load = current_load[0] + load_info = { + "current_load": current_load, + "client_ip_port": f"{g_objs.httpserver_manager.host_ip}:{g_objs.args.port}" + } + else: + load_info = { + "current_load": 0.0, + "client_ip_port": f"{g_objs.httpserver_manager.host_ip}:{g_objs.args.port}" + } + await websocket.send(pickle.dumps((ObjType.TOKEN_PACKS, handle_list, load_info))) else: - load_info = { - "current_load": 0.0, - "client_ip_port": f"{g_objs.httpserver_manager.host_ip}:{g_objs.args.port}" - } - - await websocket.send(pickle.dumps((ObjType.TOKEN_PACKS, handle_list, load_info))) + await websocket.send(pickle.dumps((ObjType.TOKEN_PACKS, handle_list, None))) diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index a4922978f..199cb3c4f 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -88,7 +88,10 @@ async def remove_pd(self, pd_info_json): def 
update_node_load_info(self, load_info: dict): """更新节点负载信息""" - if load_info is not None and "client_ip_port" in load_info: + if load_info is None: + return + + if "client_ip_port" in load_info: ip_port = load_info["client_ip_port"] if ip_port in self.node_info: current_load = load_info["current_load"] From 6922a6616600b300006873373a87126c2e1b061b Mon Sep 17 00:00:00 2001 From: pigKiller <22373017@buaa.edu.cn> Date: Wed, 16 Jul 2025 11:30:30 +0800 Subject: [PATCH 14/33] fix: create func para --- .../httpserver_for_pd_master/pd_selector/__init__.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/lightllm/server/httpserver_for_pd_master/pd_selector/__init__.py b/lightllm/server/httpserver_for_pd_master/pd_selector/__init__.py index 6d22490ce..363846020 100644 --- a/lightllm/server/httpserver_for_pd_master/pd_selector/__init__.py +++ b/lightllm/server/httpserver_for_pd_master/pd_selector/__init__.py @@ -1,5 +1,5 @@ from typing import List -from lightllm.server.httpserver_for_pd_master.manager import PD_Client_Obj +from lightllm.server.httpserver_for_pd_master.manager import PD_Client_Obj, PDManager from .pd_selector import ( PDSelector, RandomSelector, @@ -16,14 +16,14 @@ "RadixSelector" ] -def create_selector(selector_type: str, prefill_nodes: List[PD_Client_Obj], decode_nodes: List[PD_Client_Obj]) -> PDSelector: +def create_selector(selector_type: str, prefill_nodes: List[PD_Client_Obj], decode_nodes: List[PD_Client_Obj], pd_manager: PDManager) -> PDSelector: if selector_type == "random": - return RandomSelector(prefill_nodes, decode_nodes) + return RandomSelector(prefill_nodes, decode_nodes, pd_manager) elif selector_type == "round_robin": - return RoundRobinSelector(prefill_nodes, decode_nodes) + return RoundRobinSelector(prefill_nodes, decode_nodes, pd_manager) elif selector_type == "memory": - return MemorySelector(prefill_nodes, decode_nodes) + return MemorySelector(prefill_nodes, decode_nodes, pd_manager) elif selector_type == 
"radix": - return RadixSelector(prefill_nodes, decode_nodes) + return RadixSelector(prefill_nodes, decode_nodes, pd_manager) else: raise ValueError(f"Invalid selector type: {selector_type}") \ No newline at end of file From 917e30af2b28ccaa3bdd18fd85898275a87eeeac Mon Sep 17 00:00:00 2001 From: pigKiller <22373017@buaa.edu.cn> Date: Wed, 16 Jul 2025 14:16:56 +0800 Subject: [PATCH 15/33] fix: typo --- .../server/httpserver_for_pd_master/pd_selector/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/server/httpserver_for_pd_master/pd_selector/__init__.py b/lightllm/server/httpserver_for_pd_master/pd_selector/__init__.py index 363846020..7463a21c1 100644 --- a/lightllm/server/httpserver_for_pd_master/pd_selector/__init__.py +++ b/lightllm/server/httpserver_for_pd_master/pd_selector/__init__.py @@ -26,4 +26,4 @@ def create_selector(selector_type: str, prefill_nodes: List[PD_Client_Obj], deco elif selector_type == "radix": return RadixSelector(prefill_nodes, decode_nodes, pd_manager) else: - raise ValueError(f"Invalid selector type: {selector_type}") \ No newline at end of file + raise ValueError(f"Invalid selector type: {selector_type}") From 3b250be0ba55c58ae44b4e13c649d2c6397f0217 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 17 Jul 2025 16:01:24 +0800 Subject: [PATCH 16/33] fix: import error --- build_and_upload_docker.sh | 6 +++++- env.sh | 3 +++ .../httpserver_for_pd_master/pd_selector/__init__.py | 4 ++-- .../pd_selector/pd_selector.py | 7 +++---- requirements.txt | 2 +- server_d.sh | 11 +++++++++++ server_master.sh | 6 ++++++ server_p.sh | 11 +++++++++++ test.sh | 7 +++++++ 9 files changed, 49 insertions(+), 8 deletions(-) create mode 100755 env.sh create mode 100644 server_d.sh create mode 100644 server_master.sh create mode 100644 server_p.sh create mode 100644 test.sh diff --git a/build_and_upload_docker.sh b/build_and_upload_docker.sh index 0b1897316..fc7fd871f 100755 --- a/build_and_upload_docker.sh +++ 
b/build_and_upload_docker.sh @@ -17,5 +17,9 @@ fi IMAGE_TAG=$2 ACCOUNT=$1 aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com -DOCKER_BUILDKIT=1 docker build -t $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/lightllm:$IMAGE_TAG . +DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/lightllm:$IMAGE_TAG . docker push $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/lightllm:$IMAGE_TAG + +#deepep +DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile.deepep -t $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/lightllm:$IMAGE_TAG-deepep . +docker push $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/lightllm:$IMAGE_TAG-deepep \ No newline at end of file diff --git a/env.sh b/env.sh new file mode 100755 index 000000000..151b27c47 --- /dev/null +++ b/env.sh @@ -0,0 +1,3 @@ +source /mtc/bianzhuohang/miniconda3/bin/activate +conda activate lightllm_router +clear \ No newline at end of file diff --git a/lightllm/server/httpserver_for_pd_master/pd_selector/__init__.py b/lightllm/server/httpserver_for_pd_master/pd_selector/__init__.py index 7463a21c1..392b6ea00 100644 --- a/lightllm/server/httpserver_for_pd_master/pd_selector/__init__.py +++ b/lightllm/server/httpserver_for_pd_master/pd_selector/__init__.py @@ -1,5 +1,5 @@ from typing import List -from lightllm.server.httpserver_for_pd_master.manager import PD_Client_Obj, PDManager +from lightllm.server.httpserver_for_pd_master.manager import PD_Client_Obj from .pd_selector import ( PDSelector, RandomSelector, @@ -16,7 +16,7 @@ "RadixSelector" ] -def create_selector(selector_type: str, prefill_nodes: List[PD_Client_Obj], decode_nodes: List[PD_Client_Obj], pd_manager: PDManager) -> PDSelector: +def create_selector(selector_type: str, prefill_nodes: List[PD_Client_Obj], decode_nodes: List[PD_Client_Obj], pd_manager) -> PDSelector: if selector_type == "random": return RandomSelector(prefill_nodes, decode_nodes, pd_manager) elif 
selector_type == "round_robin": diff --git a/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py b/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py index ff2854a5c..669eae2e9 100644 --- a/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py +++ b/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py @@ -2,14 +2,13 @@ from lightllm.server.pd_io_struct import PD_Client_Obj from lightllm.server.core.objs import SamplingParams from lightllm.server.multimodal_params import MultimodalParams -from lightllm.server.httpserver_for_pd_master.manager import PDManager class PDSelector: - def __init__(self, prefill_nodes: List[PD_Client_Obj], decode_nodes: List[PD_Client_Obj], pd_manager: PDManager): + def __init__(self, prefill_nodes: List[PD_Client_Obj], decode_nodes: List[PD_Client_Obj], pd_manager): self.prefill_nodes: List[PD_Client_Obj] = prefill_nodes self.decode_nodes: List[PD_Client_Obj] = decode_nodes - self.pd_manager: PDManager = pd_manager + self.pd_manager = pd_manager async def update_nodes(self, prefill_nodes, decode_nodes): self.prefill_nodes = prefill_nodes @@ -33,7 +32,7 @@ async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: class RoundRobinSelector(PDSelector): """轮询选择器""" - def __init__(self, prefill_nodes: List[PD_Client_Obj], decode_nodes: List[PD_Client_Obj], pd_manager: PDManager): + def __init__(self, prefill_nodes: List[PD_Client_Obj], decode_nodes: List[PD_Client_Obj], pd_manager): super().__init__(prefill_nodes, decode_nodes, pd_manager) self.prefill_node_index: int = 0 self.decode_node_index: int = 0 diff --git a/requirements.txt b/requirements.txt index 1febb64f1..f5378b7f9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -83,7 +83,7 @@ frozendict==2.4.6 atomics==1.0.3 easydict==1.13 gunicorn==23.0.0 -vllm==0.8.5 +# vllm==0.8.5 flashinfer-python==0.2.4 sgl-kernel==0.1.4 httpx==0.28.1 diff --git a/server_d.sh b/server_d.sh new file mode 100644 index 
000000000..a90ff22c6 --- /dev/null +++ b/server_d.sh @@ -0,0 +1,11 @@ +CUDA_VISIBLE_DEVICES=1 KV_TRANS_USE_P2P=1 LOADWORKER=1 python3 -m lightllm.server.api_server \ + --model_dir /mtc/bianzhuohang/models/Qwen/Qwen2.5-14B \ + --run_mode "decode" \ + --host 127.0.1.1 \ + --port 8118 \ + --nccl_port 12322 \ + --tp 1 \ + --tokenizer_mode fast \ + --pd_master_ip 127.0.0.1 \ + --pd_master_port 60011 \ + --pd_decode_rpyc_port 42020 diff --git a/server_master.sh b/server_master.sh new file mode 100644 index 000000000..b6088d71a --- /dev/null +++ b/server_master.sh @@ -0,0 +1,6 @@ +python3 -m lightllm.server.api_server \ + --model_dir /mtc/bianzhuohang/models/Qwen/Qwen2.5-14B \ + --run_mode "pd_master" \ + --host 127.0.0.1 \ + --port 60011 \ + --select_p_d_node_func random \ No newline at end of file diff --git a/server_p.sh b/server_p.sh new file mode 100644 index 000000000..652d0e57e --- /dev/null +++ b/server_p.sh @@ -0,0 +1,11 @@ +CUDA_VISIBLE_DEVICES=0 KV_TRANS_USE_P2P=1 LOADWORKER=1 python3 -m lightllm.server.api_server \ + --model_dir /mtc/bianzhuohang/models/Qwen/Qwen2.5-14B \ + --run_mode "prefill" \ + --host 127.0.1.1 \ + --port 8079 \ + --tp 1 \ + --nccl_port 2769 \ + --tokenizer_mode fast \ + --pd_master_ip 127.0.0.1 \ + --pd_master_port 60011 \ + --disable_cudagraph \ No newline at end of file diff --git a/test.sh b/test.sh new file mode 100644 index 000000000..cc2e49ae3 --- /dev/null +++ b/test.sh @@ -0,0 +1,7 @@ +python3 test/benchmark/service/benchmark_client.py \ + --url http://127.0.0.1:60011/generate \ + --tokenizer_path /mtc/bianzhuohang/models/Qwen/Qwen2.5-14B \ + --server_api lightllm \ + --dump_file result.json \ + --seed 42 \ + --dump_file pd_random.json \ No newline at end of file From fb053ee4ce698f73b8c426e5bdee798bd994e980 Mon Sep 17 00:00:00 2001 From: pigKiller <22373017@buaa.edu.cn> Date: Thu, 17 Jul 2025 16:04:47 +0800 Subject: [PATCH 17/33] fix: arg name --- lightllm/server/httpserver_for_pd_master/manager.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index 199cb3c4f..340aa3453 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -40,7 +40,7 @@ def __init__(self, args): self.node_info: Dict[str, dict] = {} self.prefill_nodes: List[PD_Client_Obj] = [] self.decode_nodes: List[PD_Client_Obj] = [] - self.selector: PDSelector = create_selector(args.select_p_d_node_func_name, self.prefill_nodes, self.decode_nodes, self) + self.selector: PDSelector = create_selector(args.select_p_d_node_func, self.prefill_nodes, self.decode_nodes, self) return async def register_pd(self, pd_info_json, websocket): From da45d06871cbbd2895a3aaab3814258f9e3635f6 Mon Sep 17 00:00:00 2001 From: pigKiller <22373017@buaa.edu.cn> Date: Fri, 18 Jul 2025 11:00:22 +0800 Subject: [PATCH 18/33] fix: dict key --- lightllm/server/httpserver_for_pd_master/manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index 340aa3453..5436e7a5d 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -47,7 +47,7 @@ async def register_pd(self, pd_info_json, websocket): pd_client = PD_Client_Obj(**pd_info_json) pd_client.websocket = websocket self.node_info[pd_client.client_ip_port] = { - "node_id": pd_info_json["pd_node_id"], + "node_id": pd_info_json["node_id"], "client_ip_port": pd_info_json["client_ip_port"], "mode": pd_info_json["mode"], "node": pd_client, From e983ce8303004a0c1502d0b32d09c36f962b76ec Mon Sep 17 00:00:00 2001 From: pigKiller <22373017@buaa.edu.cn> Date: Mon, 28 Jul 2025 17:20:43 +0800 Subject: [PATCH 19/33] remove unnecessary files --- requirements.txt | 2 +- server_d.sh | 11 ----------- server_master.sh | 6 ------ server_p.sh | 11 
----------- test.sh | 7 ------- 5 files changed, 1 insertion(+), 36 deletions(-) delete mode 100644 server_d.sh delete mode 100644 server_master.sh delete mode 100644 server_p.sh delete mode 100644 test.sh diff --git a/requirements.txt b/requirements.txt index f5378b7f9..1febb64f1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -83,7 +83,7 @@ frozendict==2.4.6 atomics==1.0.3 easydict==1.13 gunicorn==23.0.0 -# vllm==0.8.5 +vllm==0.8.5 flashinfer-python==0.2.4 sgl-kernel==0.1.4 httpx==0.28.1 diff --git a/server_d.sh b/server_d.sh deleted file mode 100644 index a90ff22c6..000000000 --- a/server_d.sh +++ /dev/null @@ -1,11 +0,0 @@ -CUDA_VISIBLE_DEVICES=1 KV_TRANS_USE_P2P=1 LOADWORKER=1 python3 -m lightllm.server.api_server \ - --model_dir /mtc/bianzhuohang/models/Qwen/Qwen2.5-14B \ - --run_mode "decode" \ - --host 127.0.1.1 \ - --port 8118 \ - --nccl_port 12322 \ - --tp 1 \ - --tokenizer_mode fast \ - --pd_master_ip 127.0.0.1 \ - --pd_master_port 60011 \ - --pd_decode_rpyc_port 42020 diff --git a/server_master.sh b/server_master.sh deleted file mode 100644 index b6088d71a..000000000 --- a/server_master.sh +++ /dev/null @@ -1,6 +0,0 @@ -python3 -m lightllm.server.api_server \ - --model_dir /mtc/bianzhuohang/models/Qwen/Qwen2.5-14B \ - --run_mode "pd_master" \ - --host 127.0.0.1 \ - --port 60011 \ - --select_p_d_node_func random \ No newline at end of file diff --git a/server_p.sh b/server_p.sh deleted file mode 100644 index 652d0e57e..000000000 --- a/server_p.sh +++ /dev/null @@ -1,11 +0,0 @@ -CUDA_VISIBLE_DEVICES=0 KV_TRANS_USE_P2P=1 LOADWORKER=1 python3 -m lightllm.server.api_server \ - --model_dir /mtc/bianzhuohang/models/Qwen/Qwen2.5-14B \ - --run_mode "prefill" \ - --host 127.0.1.1 \ - --port 8079 \ - --tp 1 \ - --nccl_port 2769 \ - --tokenizer_mode fast \ - --pd_master_ip 127.0.0.1 \ - --pd_master_port 60011 \ - --disable_cudagraph \ No newline at end of file diff --git a/test.sh b/test.sh deleted file mode 100644 index cc2e49ae3..000000000 --- a/test.sh 
+++ /dev/null @@ -1,7 +0,0 @@ -python3 test/benchmark/service/benchmark_client.py \ - --url http://127.0.0.1:60011/generate \ - --tokenizer_path /mtc/bianzhuohang/models/Qwen/Qwen2.5-14B \ - --server_api lightllm \ - --dump_file result.json \ - --seed 42 \ - --dump_file pd_random.json \ No newline at end of file From e68aca3988a80b49a1f1b95933437436eb241736 Mon Sep 17 00:00:00 2001 From: pigKiller <22373017@buaa.edu.cn> Date: Mon, 28 Jul 2025 17:21:35 +0800 Subject: [PATCH 20/33] docker --- build_and_upload_docker.sh | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/build_and_upload_docker.sh b/build_and_upload_docker.sh index fc7fd871f..281e33c43 100755 --- a/build_and_upload_docker.sh +++ b/build_and_upload_docker.sh @@ -17,9 +17,5 @@ fi IMAGE_TAG=$2 ACCOUNT=$1 aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com -DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/lightllm:$IMAGE_TAG . -docker push $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/lightllm:$IMAGE_TAG - -#deepep -DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile.deepep -t $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/lightllm:$IMAGE_TAG-deepep . -docker push $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/lightllm:$IMAGE_TAG-deepep \ No newline at end of file +DOCKER_BUILDKIT=1 docker build -t $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/lightllm:$IMAGE_TAG . 
+docker push $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/lightllm:$IMAGE_TAG \ No newline at end of file From ef6b024174a9f1f84d75ce5c9d70038951ea6623 Mon Sep 17 00:00:00 2001 From: pigKiller <22373017@buaa.edu.cn> Date: Mon, 28 Jul 2025 17:23:10 +0800 Subject: [PATCH 21/33] docker main --- build_and_upload_docker.sh | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/build_and_upload_docker.sh b/build_and_upload_docker.sh index 281e33c43..fc7fd871f 100755 --- a/build_and_upload_docker.sh +++ b/build_and_upload_docker.sh @@ -17,5 +17,9 @@ fi IMAGE_TAG=$2 ACCOUNT=$1 aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com -DOCKER_BUILDKIT=1 docker build -t $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/lightllm:$IMAGE_TAG . -docker push $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/lightllm:$IMAGE_TAG \ No newline at end of file +DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/lightllm:$IMAGE_TAG . +docker push $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/lightllm:$IMAGE_TAG + +#deepep +DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile.deepep -t $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/lightllm:$IMAGE_TAG-deepep . 
+docker push $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/lightllm:$IMAGE_TAG-deepep \ No newline at end of file From b95ac63a68aaa9357688c21bbac1ee4c463ca0ed Mon Sep 17 00:00:00 2001 From: pigKiller <22373017@buaa.edu.cn> Date: Mon, 28 Jul 2025 17:24:59 +0800 Subject: [PATCH 22/33] remove env.sh --- env.sh | 3 --- 1 file changed, 3 deletions(-) delete mode 100755 env.sh diff --git a/env.sh b/env.sh deleted file mode 100755 index 151b27c47..000000000 --- a/env.sh +++ /dev/null @@ -1,3 +0,0 @@ -source /mtc/bianzhuohang/miniconda3/bin/activate -conda activate lightllm_router -clear \ No newline at end of file From f6060f1e2855050fba8219df56ef0d89427ea961 Mon Sep 17 00:00:00 2001 From: pigKiller <22373017@buaa.edu.cn> Date: Tue, 29 Jul 2025 17:12:37 +0800 Subject: [PATCH 23/33] =?UTF-8?q?feat:=20=E4=BD=BF=E7=94=A8=E9=A2=84?= =?UTF-8?q?=E6=B5=8B=E4=BF=A1=E6=81=AF=E8=BF=9B=E8=A1=8C=EF=BC=8C=E9=81=BF?= =?UTF-8?q?=E5=85=8D=E6=8B=A5=E5=A1=9E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lightllm/server/httpserver/pd_loop.py | 6 +- .../httpserver_for_pd_master/manager.py | 54 ++++-------- .../node_info_recorder.py | 82 +++++++++++++++++++ .../pd_selector/pd_selector.py | 27 ++++-- 4 files changed, 117 insertions(+), 52 deletions(-) create mode 100644 lightllm/server/httpserver_for_pd_master/node_info_recorder.py diff --git a/lightllm/server/httpserver/pd_loop.py b/lightllm/server/httpserver/pd_loop.py index dd4459d0e..a57e92b3a 100644 --- a/lightllm/server/httpserver/pd_loop.py +++ b/lightllm/server/httpserver/pd_loop.py @@ -188,17 +188,17 @@ async def _up_tokens_to_pd_master(forwarding_queue: AsyncQueue, websocket): from lightllm.server.api_http import g_objs if g_objs.shared_token_load is not None: current_load = [ - float(g_objs.shared_token_load.get_current_load(dp_index)) for dp_index in range(g_objs.args.dp) + float(g_objs.shared_token_load.get_frozened_token_count(dp_index)) for dp_index in range(g_objs.args.dp) ] if 
g_objs.args.dp == 1: current_load = current_load[0] load_info = { - "current_load": current_load, + "mem_len": current_load, "client_ip_port": f"{g_objs.httpserver_manager.host_ip}:{g_objs.args.port}" } else: load_info = { - "current_load": 0.0, + "mem_len": 0, "client_ip_port": f"{g_objs.httpserver_manager.host_ip}:{g_objs.args.port}" } await websocket.send(pickle.dumps((ObjType.TOKEN_PACKS, handle_list, load_info))) diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index 5436e7a5d..d30ce92a7 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -25,10 +25,10 @@ from lightllm.utils.statics_utils import MovingAverage from lightllm.server.httpserver.manager import AsyncQueue from lightllm.utils.error_utils import ServerBusyError +from .node_info_recorder import PredictNodeInfoRecorder from .pd_selector import ( create_selector, PDSelector, - MemorySelector, ) logger = init_logger(__name__) @@ -37,22 +37,16 @@ class PDManager: def __init__(self, args): self.args = args - self.node_info: Dict[str, dict] = {} self.prefill_nodes: List[PD_Client_Obj] = [] self.decode_nodes: List[PD_Client_Obj] = [] + self.node_info_recorder: PredictNodeInfoRecorder = PredictNodeInfoRecorder() self.selector: PDSelector = create_selector(args.select_p_d_node_func, self.prefill_nodes, self.decode_nodes, self) return async def register_pd(self, pd_info_json, websocket): pd_client = PD_Client_Obj(**pd_info_json) pd_client.websocket = websocket - self.node_info[pd_client.client_ip_port] = { - "node_id": pd_info_json["node_id"], - "client_ip_port": pd_info_json["client_ip_port"], - "mode": pd_info_json["mode"], - "node": pd_client, - "load": 0.0, - } + self.node_info_recorder.register_node(pd_client) if pd_client.mode == "prefill": self.prefill_nodes = [e for e in self.prefill_nodes if e.client_ip_port != pd_client.client_ip_port] @@ -70,13 +64,7 @@ async 
def register_pd(self, pd_info_json, websocket): async def remove_pd(self, pd_info_json): pd_client = PD_Client_Obj(**pd_info_json) - try: - del self.node_info[pd_client.client_ip_port] - except: - pass - - if pd_client.client_ip_port in self.node_info: - del self.node_info[pd_client.client_ip_port] + self.node_info_recorder.remove_node(pd_client) self.prefill_nodes = [e for e in self.prefill_nodes if e.client_ip_port != pd_client.client_ip_port] self.decode_nodes = [e for e in self.decode_nodes if e.client_ip_port != pd_client.client_ip_port] @@ -90,38 +78,24 @@ def update_node_load_info(self, load_info: dict): """更新节点负载信息""" if load_info is None: return - - if "client_ip_port" in load_info: - ip_port = load_info["client_ip_port"] - if ip_port in self.node_info: - current_load = load_info["current_load"] - if isinstance(current_load, list): - if len(current_load) > 0: - load_value = float(current_load[0]) - else: - load_value = 0.0 - else: - load_value = float(current_load) - - self.node_info[ip_port]["load"] = load_value - logger.debug(f"Updated node load info for {ip_port}: {load_value}") - else: - logger.warning(f"Received load info for unknown node: {ip_port}") + self.node_info_recorder.update_node_load_info(load_info) def get_node_load_info(self): """获取所有节点的负载信息""" - return {k: v.get("load", float("inf")) for k, v in self.node_info.items()} + return self.node_info_recorder.get_node_infos() + + def get_predict_node_infos(self): + """获取所有节点的预测负载信息""" + return self.node_info_recorder.get_predict_node_infos() def get_node_load_info_by_node(self, client_ip_port: str): """获取指定节点的负载信息""" - node_info = self.node_info.get(client_ip_port, None) - if node_info is not None: - return node_info.get("load", float("inf")) - else: - return float("inf") + return self.node_info_recorder.get_node_info(client_ip_port) async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, 
PD_Client_Obj]: - return await self.selector.select_p_d_node(prompt, sampling_params, multimodal_params) + p_node, d_node = await self.selector.select_p_d_node(prompt, sampling_params, multimodal_params) + self.node_info_recorder.update_predict_node_info(p_node, d_node, prompt, sampling_params, multimodal_params) + return p_node, d_node class HttpServerManagerForPDMaster: def __init__( diff --git a/lightllm/server/httpserver_for_pd_master/node_info_recorder.py b/lightllm/server/httpserver_for_pd_master/node_info_recorder.py new file mode 100644 index 000000000..68561699c --- /dev/null +++ b/lightllm/server/httpserver_for_pd_master/node_info_recorder.py @@ -0,0 +1,82 @@ +import copy + +from ..pd_io_struct import PD_Client_Obj +from lightllm.server.httpserver.manager import SamplingParams, MultimodalParams +from typing import Union, List +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class NodeInfoRecorder: + def __init__(self): + self.node_info = {} + + def register_node(self, pd_client: PD_Client_Obj): + self.node_info[pd_client.client_ip_port] = { + "node_id": pd_client.node_id, + "client_ip_port": pd_client.client_ip_port, + "mode": pd_client.mode, + "node": pd_client, + "mem_len": 0, + # "batch_size": 0, + } + + def remove_node(self, pd_client: PD_Client_Obj): + del self.node_info[pd_client.client_ip_port] + + def update_node_load_info(self, load_info: dict): + if "client_ip_port" in load_info: + ip_port = load_info["client_ip_port"] + if ip_port in self.node_info: + mem_len = load_info["mem_len"] + # batch_size = load_info["batch_size"] + self.node_info[ip_port]["mem_len"] = mem_len + # self.node_info[ip_port]["batch_size"] = batch_size + logger.debug(f"Updated node load info for {ip_port}: {mem_len}") + # logger.debug(f"Updated node load info for {ip_port}: {mem_len}, {batch_size}") + else: + logger.warning(f"Received load info for unknown node: {ip_port}") + else: + logger.warning(f"Received load info without 
client_ip_port") + + def get_node_infos(self): + return {k: { + "mem_len": v.get("mem_len", float("inf")), + # "batch_size": v.get("batch_size", float("inf")), + } for k, v in self.node_info.items()} + + def get_node_info(self, client_ip_port: str): + return self.node_info.get(client_ip_port, None) + +class PredictNodeInfoRecorder(NodeInfoRecorder): + def __init__(self): + super().__init__() + self.predict_node_info = {} + + def register_node(self, pd_client: PD_Client_Obj): + super().register_node(pd_client) + self.predict_node_info[pd_client.client_ip_port] = copy.copy(self.node_info[pd_client.client_ip_port]) + + def remove_node(self, pd_client: PD_Client_Obj): + super().remove_node(pd_client) + del self.predict_node_info[pd_client.client_ip_port] + + def update_node_load_info(self, load_info: dict): + super().update_node_load_info(load_info) + self.predict_node_info[load_info["client_ip_port"]] = copy.copy(self.node_info[load_info["client_ip_port"]]) + + def update_predict_node_info(self, p_node: PD_Client_Obj, d_node: PD_Client_Obj, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams): + self.predict_node_info[p_node.client_ip_port]["mem_len"] += len(prompt) + # self.predict_node_info[p_node.client_ip_port]["batch_size"] += 1 + self.predict_node_info[d_node.client_ip_port]["mem_len"] += sampling_params.max_new_tokens + # self.predict_node_info[d_node.client_ip_port]["batch_size"] += 1 + + def get_predict_node_infos(self): + return {k: { + "mem_len": v.get("mem_len", float("inf")), + # "batch_size": v.get("batch_size", float("inf")), + } for k, v in self.predict_node_info.items()} + + def get_predict_node_info(self, client_ip_port: str): + return self.predict_node_info.get(client_ip_port, None) diff --git a/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py b/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py index 669eae2e9..1571ffe5d 100644 ---
a/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py +++ b/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py @@ -49,6 +49,15 @@ class MemorySelector(PDSelector): """基于内存使用情况的选择器""" async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: + def _get_min_node(node_infos: dict): + min_node, min_node_len = None, float("inf") + for node_ip, node_info in node_infos.items(): + if node_info["mem_len"] < float("inf"): + if node_info["mem_len"] < min_node_len: + min_node_len = node_info["mem_len"] + min_node = node_ip + return min_node + if self.pd_manager is None: # 如果没有 PDManager 引用,回退到随机选择 import random @@ -56,16 +65,16 @@ async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: d_node = random.choice(self.decode_nodes) if self.decode_nodes else None return p_node, d_node - # 获取 prefill 节点的内存使用情况 - prefill_usages = [self.pd_manager.get_node_load_info_by_node(node.client_ip_port) for node in self.prefill_nodes] - decode_usages = [self.pd_manager.get_node_load_info_by_node(node.client_ip_port) for node in self.decode_nodes] + node_infos = self.pd_manager.get_predict_node_infos() + node_infos = {k: v for k, v in node_infos.items() if v["mem_len"] < float("inf")} + if len(node_infos) == 0: + return random.choice(self.prefill_nodes), random.choice(self.decode_nodes) - import random - min_prefill_usage = min(prefill_usages) if prefill_usages else float('inf') - min_decode_usage = min(decode_usages) if decode_usages else float('inf') - - p_node = self.prefill_nodes[prefill_usages.index(min_prefill_usage)] if min_prefill_usage != float('inf') and prefill_usages else random.choice(self.prefill_nodes) - d_node = self.decode_nodes[decode_usages.index(min_decode_usage)] if min_decode_usage != float('inf') and decode_usages else random.choice(self.decode_nodes) + # 获取负载最小的节点 + p_node_infos = {k: v for k, v in 
node_infos.items() if k in self.prefill_nodes} + d_node_infos = {k: v for k, v in node_infos.items() if k in self.decode_nodes} + p_node = _get_min_node(p_node_infos) or random.choice(self.prefill_nodes) + d_node = _get_min_node(d_node_infos) or random.choice(self.decode_nodes) return p_node, d_node From 5155920496332b6c4c637d671b7edecfc72435d7 Mon Sep 17 00:00:00 2001 From: pigKiller <22373017@buaa.edu.cn> Date: Tue, 29 Jul 2025 17:14:27 +0800 Subject: [PATCH 24/33] fix: better style --- .../pd_selector/pd_selector.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py b/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py index 1571ffe5d..ef551dbb3 100644 --- a/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py +++ b/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py @@ -49,12 +49,12 @@ class MemorySelector(PDSelector): """基于内存使用情况的选择器""" async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: - def _get_min_node(node_infos: dict): + def _get_min_node(node_infos: dict, key: str): min_node, min_node_len = None, float("inf") for node_ip, node_info in node_infos.items(): - if node_info["mem_len"] < float("inf"): - if node_info["mem_len"] < min_node_len: - min_node_len = node_info["mem_len"] + if node_info[key] < float("inf"): + if node_info[key] < min_node_len: + min_node_len = node_info[key] min_node = node_ip return min_node @@ -73,8 +73,8 @@ def _get_min_node(node_infos: dict): # 获取负载最小的节点 p_node_infos = {k: v for k, v in node_infos.items() if k in self.prefill_nodes} d_node_infos = {k: v for k, v in node_infos.items() if k in self.decode_nodes} - p_node = _get_min_node(p_node_infos) or random.choice(self.prefill_nodes) - d_node = _get_min_node(d_node_infos) or random.choice(self.decode_nodes) + p_node = 
_get_min_node(p_node_infos, "mem_len") or random.choice(self.prefill_nodes) + d_node = _get_min_node(d_node_infos, "mem_len") or random.choice(self.decode_nodes) return p_node, d_node From d05c0a0d7d0aa1bf52d3be11a58f18b2094fda5a Mon Sep 17 00:00:00 2001 From: pigKiller <22373017@buaa.edu.cn> Date: Tue, 29 Jul 2025 17:29:02 +0800 Subject: [PATCH 25/33] fix: import --- .../httpserver_for_pd_master/pd_selector/pd_selector.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py b/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py index ef551dbb3..0337ad53a 100644 --- a/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py +++ b/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py @@ -1,3 +1,5 @@ +import random + from typing import Union, List, Tuple from lightllm.server.pd_io_struct import PD_Client_Obj from lightllm.server.core.objs import SamplingParams @@ -22,8 +24,6 @@ class RandomSelector(PDSelector): """随机选择器""" async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: - import random - p_node = random.choice(self.prefill_nodes) d_node = random.choice(self.decode_nodes) return p_node, d_node @@ -60,7 +60,6 @@ def _get_min_node(node_infos: dict, key: str): if self.pd_manager is None: # 如果没有 PDManager 引用,回退到随机选择 - import random p_node = random.choice(self.prefill_nodes) if self.prefill_nodes else None d_node = random.choice(self.decode_nodes) if self.decode_nodes else None return p_node, d_node From 2c187a199fe1f814d39f22c1a1de4cfa0d7ea076 Mon Sep 17 00:00:00 2001 From: pigKiller <22373017@buaa.edu.cn> Date: Thu, 31 Jul 2025 12:55:14 +0800 Subject: [PATCH 26/33] fix: lint --- lightllm/server/httpserver_for_pd_master/node_info_recorder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/lightllm/server/httpserver_for_pd_master/node_info_recorder.py b/lightllm/server/httpserver_for_pd_master/node_info_recorder.py index 68561699c..0781de2f2 100644 --- a/lightllm/server/httpserver_for_pd_master/node_info_recorder.py +++ b/lightllm/server/httpserver_for_pd_master/node_info_recorder.py @@ -38,7 +38,7 @@ def update_node_load_info(self, load_info: dict): else: logger.warning(f"Received load info for unknown node: {ip_port}") else: - logger.warning(f"Received load info without client_ip_port") + logger.warning("Received load info without client_ip_port") def get_node_infos(self): return {k: { From bcbaa715b170be13e7ef28db287d3512571ba7fd Mon Sep 17 00:00:00 2001 From: pigKiller <22373017@buaa.edu.cn> Date: Thu, 31 Jul 2025 13:03:20 +0800 Subject: [PATCH 27/33] =?UTF-8?q?=E6=8B=86=E5=88=86=E8=8E=B7=E5=8F=96?= =?UTF-8?q?=E8=8A=82=E7=82=B9=E8=B4=9F=E8=BD=BD=E4=BF=A1=E6=81=AF=E7=9A=84?= =?UTF-8?q?=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lightllm/server/httpserver/pd_loop.py | 39 +++++++++++++++------------ 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/lightllm/server/httpserver/pd_loop.py b/lightllm/server/httpserver/pd_loop.py index a57e92b3a..b2322de8e 100644 --- a/lightllm/server/httpserver/pd_loop.py +++ b/lightllm/server/httpserver/pd_loop.py @@ -176,6 +176,27 @@ async def _pd_process_generate( logger.error(str(e)) +# 获取节点负载信息 +def _get_load_info(): + from lightllm.server.api_http import g_objs + if g_objs.shared_token_load is not None: + current_load = [ + float(g_objs.shared_token_load.get_frozened_token_count(dp_index)) for dp_index in range(g_objs.args.dp) + ] + if g_objs.args.dp == 1: + current_load = current_load[0] + load_info = { + "mem_len": current_load, + "client_ip_port": f"{g_objs.httpserver_manager.host_ip}:{g_objs.args.port}" + } + else: + load_info = { + "mem_len": 0, + "client_ip_port": f"{g_objs.httpserver_manager.host_ip}:{g_objs.args.port}" + 
} + return load_info + + # 转发token的task async def _up_tokens_to_pd_master(forwarding_queue: AsyncQueue, websocket): while True: @@ -184,23 +205,7 @@ async def _up_tokens_to_pd_master(forwarding_queue: AsyncQueue, websocket): if handle_list: has_finished_req = any(finish_status.is_finished() for _, _, _, finish_status in handle_list) if has_finished_req: - # 获取节点负载信息 - from lightllm.server.api_http import g_objs - if g_objs.shared_token_load is not None: - current_load = [ - float(g_objs.shared_token_load.get_frozened_token_count(dp_index)) for dp_index in range(g_objs.args.dp) - ] - if g_objs.args.dp == 1: - current_load = current_load[0] - load_info = { - "mem_len": current_load, - "client_ip_port": f"{g_objs.httpserver_manager.host_ip}:{g_objs.args.port}" - } - else: - load_info = { - "mem_len": 0, - "client_ip_port": f"{g_objs.httpserver_manager.host_ip}:{g_objs.args.port}" - } + load_info = _get_load_info() await websocket.send(pickle.dumps((ObjType.TOKEN_PACKS, handle_list, load_info))) else: await websocket.send(pickle.dumps((ObjType.TOKEN_PACKS, handle_list, None))) From 2674fd56a9d0437f2f3999a847647453020a2b9a Mon Sep 17 00:00:00 2001 From: pigKiller <22373017@buaa.edu.cn> Date: Thu, 31 Jul 2025 13:05:37 +0800 Subject: [PATCH 28/33] remove pd radix selector --- lightllm/server/api_cli.py | 6 +++--- .../httpserver_for_pd_master/pd_selector/__init__.py | 8 ++------ .../httpserver_for_pd_master/pd_selector/pd_selector.py | 4 ---- 3 files changed, 5 insertions(+), 13 deletions(-) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 55f2e14af..c0b75ea30 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -45,9 +45,9 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--select_p_d_node_func", type=str, - default="random", - choices=["random", "round_robin", "memory", "radix"], - help="select p d node func, can be random, round_robin, memory or radix", + default="round_robin", + 
choices=["random", "round_robin", "memory"], + help="select p d node func, can be round_robin, random or memory", ) parser.add_argument( "--config_server_host", diff --git a/lightllm/server/httpserver_for_pd_master/pd_selector/__init__.py b/lightllm/server/httpserver_for_pd_master/pd_selector/__init__.py index 392b6ea00..dae927341 100644 --- a/lightllm/server/httpserver_for_pd_master/pd_selector/__init__.py +++ b/lightllm/server/httpserver_for_pd_master/pd_selector/__init__.py @@ -4,16 +4,14 @@ PDSelector, RandomSelector, RoundRobinSelector, - MemorySelector, - RadixSelector + MemorySelector ) __all__ = [ "PDSelector", "RandomSelector", "RoundRobinSelector", - "MemorySelector", - "RadixSelector" + "MemorySelector" ] def create_selector(selector_type: str, prefill_nodes: List[PD_Client_Obj], decode_nodes: List[PD_Client_Obj], pd_manager) -> PDSelector: @@ -23,7 +21,5 @@ def create_selector(selector_type: str, prefill_nodes: List[PD_Client_Obj], deco return RoundRobinSelector(prefill_nodes, decode_nodes, pd_manager) elif selector_type == "memory": return MemorySelector(prefill_nodes, decode_nodes, pd_manager) - elif selector_type == "radix": - return RadixSelector(prefill_nodes, decode_nodes, pd_manager) else: raise ValueError(f"Invalid selector type: {selector_type}") diff --git a/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py b/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py index 0337ad53a..2a315af35 100644 --- a/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py +++ b/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py @@ -76,7 +76,3 @@ def _get_min_node(node_infos: dict, key: str): d_node = _get_min_node(d_node_infos, "mem_len") or random.choice(self.decode_nodes) return p_node, d_node - -class RadixSelector(PDSelector): - async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, 
PD_Client_Obj]: - pass From 25459a39e7d912f329e4d7cba6911f48ca5f4b63 Mon Sep 17 00:00:00 2001 From: pigKiller <22373017@buaa.edu.cn> Date: Thu, 31 Jul 2025 13:08:58 +0800 Subject: [PATCH 29/33] better typing --- lightllm/server/httpserver/pd_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightllm/server/httpserver/pd_loop.py b/lightllm/server/httpserver/pd_loop.py index b2322de8e..c300b75e1 100644 --- a/lightllm/server/httpserver/pd_loop.py +++ b/lightllm/server/httpserver/pd_loop.py @@ -177,7 +177,7 @@ async def _pd_process_generate( # 获取节点负载信息 -def _get_load_info(): +def _get_load_info() -> dict: from lightllm.server.api_http import g_objs if g_objs.shared_token_load is not None: current_load = [ @@ -205,7 +205,7 @@ async def _up_tokens_to_pd_master(forwarding_queue: AsyncQueue, websocket): if handle_list: has_finished_req = any(finish_status.is_finished() for _, _, _, finish_status in handle_list) if has_finished_req: - load_info = _get_load_info() + load_info: dict = _get_load_info() await websocket.send(pickle.dumps((ObjType.TOKEN_PACKS, handle_list, load_info))) else: await websocket.send(pickle.dumps((ObjType.TOKEN_PACKS, handle_list, None))) From a5e14ce44ea85765dea7a0a2880735512318358b Mon Sep 17 00:00:00 2001 From: pigKiller <22373017@buaa.edu.cn> Date: Thu, 31 Jul 2025 13:14:29 +0800 Subject: [PATCH 30/33] remove unnecessary if --- lightllm/server/httpserver/pd_loop.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/lightllm/server/httpserver/pd_loop.py b/lightllm/server/httpserver/pd_loop.py index c300b75e1..b020d9d59 100644 --- a/lightllm/server/httpserver/pd_loop.py +++ b/lightllm/server/httpserver/pd_loop.py @@ -177,7 +177,10 @@ async def _pd_process_generate( # 获取节点负载信息 -def _get_load_info() -> dict: +def _get_load_info(have_finished_req: bool) -> dict: + if not have_finished_req: + return None + from lightllm.server.api_http import g_objs if g_objs.shared_token_load is not None: 
current_load = [ @@ -203,9 +206,6 @@ async def _up_tokens_to_pd_master(forwarding_queue: AsyncQueue, websocket): handle_list = await forwarding_queue.wait_to_get_all_data() if handle_list: - has_finished_req = any(finish_status.is_finished() for _, _, _, finish_status in handle_list) - if has_finished_req: - load_info: dict = _get_load_info() - await websocket.send(pickle.dumps((ObjType.TOKEN_PACKS, handle_list, load_info))) - else: - await websocket.send(pickle.dumps((ObjType.TOKEN_PACKS, handle_list, None))) + have_finished_req = any(finish_status.is_finished() for _, _, _, finish_status in handle_list) + load_info: dict = _get_load_info(have_finished_req) + await websocket.send(pickle.dumps((ObjType.TOKEN_PACKS, handle_list, load_info))) From ed77be1d960bc189cb0cb6193e335c9dd593f16e Mon Sep 17 00:00:00 2001 From: pigKiller <22373017@buaa.edu.cn> Date: Thu, 31 Jul 2025 14:21:36 +0800 Subject: [PATCH 31/33] feat: better load info select --- .../httpserver_for_pd_master/manager.py | 16 ---- .../node_info_recorder.py | 79 +++++++++++-------- .../pd_selector/pd_selector.py | 14 ++-- 3 files changed, 49 insertions(+), 60 deletions(-) diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index d30ce92a7..e0311efaf 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -80,18 +80,10 @@ def update_node_load_info(self, load_info: dict): return self.node_info_recorder.update_node_load_info(load_info) - def get_node_load_info(self): - """获取所有节点的负载信息""" - return self.node_info_recorder.get_node_infos() - def get_predict_node_infos(self): """获取所有节点的预测负载信息""" return self.node_info_recorder.get_predict_node_infos() - def get_node_load_info_by_node(self, client_ip_port: str): - """获取指定节点的负载信息""" - return self.node_info_recorder.get_node_info(client_ip_port) - async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: 
SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: p_node, d_node = await self.selector.select_p_d_node(prompt, sampling_params, multimodal_params) self.node_info_recorder.update_predict_node_info(p_node, d_node, prompt, sampling_params, multimodal_params) @@ -395,14 +387,6 @@ async def timer_log(self): async def put_to_handle_queue(self, obj): await self.infos_queues.put(obj) - def get_node_load_info(self): - """获取所有节点的负载信息""" - return self.pd_manager.get_node_load_info() - - def get_node_load_info_by_node(self, client_ip_port: str): - """获取指定节点的负载信息""" - return self.pd_manager.get_node_load_info_by_node(client_ip_port) - async def handle_loop(self): self.infos_queues = AsyncQueue() asyncio.create_task(self.timer_log()) diff --git a/lightllm/server/httpserver_for_pd_master/node_info_recorder.py b/lightllm/server/httpserver_for_pd_master/node_info_recorder.py index 0781de2f2..e7c2e3eb0 100644 --- a/lightllm/server/httpserver_for_pd_master/node_info_recorder.py +++ b/lightllm/server/httpserver_for_pd_master/node_info_recorder.py @@ -2,7 +2,7 @@ from ..pd_io_struct import PD_Client_Obj from lightllm.server.httpserver.manager import SamplingParams, MultimodalParams -from typing import Union, List +from typing import Union, List, Dict from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -10,10 +10,11 @@ class NodeInfoRecorder: def __init__(self): - self.node_info = {} + self.prefill_node_info: dict = {} + self.decode_node_info: dict = {} def register_node(self, pd_client: PD_Client_Obj): - self.node_info[pd_client.client_ip_port] = { + node_info = { "node_id": pd_client.node_id, "client_ip_port": pd_client.client_ip_port, "mode": pd_client.mode, @@ -21,62 +22,70 @@ def register_node(self, pd_client: PD_Client_Obj): "mem_len": 0, # "batch_size": 0, } + if pd_client.mode == "prefill": + self.prefill_node_info[pd_client.client_ip_port] = node_info + elif pd_client.mode == "decode": + 
self.decode_node_info[pd_client.client_ip_port] = node_info + else: + assert False, f"mode must in ['prefill', 'decode'], but get {pd_client.mode}" def remove_node(self, pd_client: PD_Client_Obj): - del self.node_info[pd_client.client_ip_port] + if pd_client.mode == "prefill": + del self.prefill_node_info[pd_client.client_ip_port] + elif pd_client.mode == "decode": + del self.decode_node_info[pd_client.client_ip_port] + else: + assert False, f"mode must in ['prefill', 'decode'], but get {pd_client.mode}" def update_node_load_info(self, load_info: dict): if "client_ip_port" in load_info: ip_port = load_info["client_ip_port"] - if ip_port in self.node_info: - mem_len = load_info["mem_len"] - # batch_size = load_info["batch_size"] - self.node_info[ip_port]["mem_len"] = mem_len - # self.node_info[ip_port]["batch_size"] = batch_size - logger.debug(f"Updated node load info for {ip_port}: {mem_len}") - # logger.debug(f"Updated node load info for {ip_port}: {mem_len}, {batch_size}") + if ip_port in self.prefill_node_info: + self.prefill_node_info[ip_port]["mem_len"] = load_info["mem_len"] + elif ip_port in self.decode_node_info: + self.decode_node_info[ip_port]["mem_len"] = load_info["mem_len"] else: logger.warning(f"Received load info for unknown node: {ip_port}") else: logger.warning("Received load info without client_ip_port") - def get_node_infos(self): - return {k: { - "mem_len": v.get("mem_len", int("inf")), - # "batch_size": v.get("batch_size", float("inf")), - } for k, v in self.node_info.items()} - - def get_node_info(self, client_ip_port: str): - return self.node_info.get(client_ip_port, None) class PredictNodeInfoRecorder(NodeInfoRecorder): def __init__(self): super().__init__() - self.predict_node_info = {} + self.prefill_predict_node_info: dict = {} + self.decode_predict_node_info: dict = {} def register_node(self, pd_client: PD_Client_Obj): super().register_node(pd_client) - self.predict_node_info[pd_client.client_ip_port] = 
copy.copy(self.node_info[pd_client.client_ip_port]) + if pd_client.mode == "prefill": + self.prefill_predict_node_info[pd_client.client_ip_port] = copy.copy(self.prefill_node_info[pd_client.client_ip_port]) + elif pd_client.mode == "decode": + self.decode_predict_node_info[pd_client.client_ip_port] = copy.copy(self.decode_node_info[pd_client.client_ip_port]) def remove_node(self, pd_client: PD_Client_Obj): super().remove_node(pd_client) - del self.predict_node_info[pd_client.client_ip_port] + if pd_client.mode == "prefill": + del self.prefill_predict_node_info[pd_client.client_ip_port] + elif pd_client.mode == "decode": + del self.decode_predict_node_info[pd_client.client_ip_port] def update_node_load_info(self, load_info: dict): super().update_node_load_info(load_info) - self.predict_node_info[load_info["client_ip_port"]] = copy.copy(self.node_info[load_info["client_ip_port"]]) + ip_port = load_info["client_ip_port"] + if ip_port in self.prefill_node_info: + self.prefill_predict_node_info[ip_port] = copy.copy(self.prefill_node_info[ip_port]) + elif ip_port in self.decode_node_info: + self.decode_predict_node_info[ip_port] = copy.copy(self.decode_node_info[ip_port]) + else: + logger.warning(f"Received load info for unknown node: {ip_port}") def update_predict_node_info(self, p_node: PD_Client_Obj, d_node: PD_Client_Obj, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams): - self.predict_node_info[p_node.client_ip_port]["mem_len"] += len(prompt) - # self.predict_node_info[p_node.client_ip_port]["batch_size"] += 1 - self.predict_node_info[d_node.client_ip_port]["mem_len"] += sampling_params.max_new_tokens - # self.predict_node_info[d_node.client_ip_port]["batch_size"] += 1 - - def get_predict_node_infos(self): - return {k: { - "mem_len": v.get("mem_len", float("inf")), - # "batch_size": v.get("batch_size", float("inf")), - } for k, v in self.predict_node_info.items()} + 
self.prefill_predict_node_info[p_node.client_ip_port]["mem_len"] += len(prompt) + self.decode_predict_node_info[d_node.client_ip_port]["mem_len"] += sampling_params.max_new_tokens - def get_predict_node_info(self, client_ip_port: str): - return self.predict_node_info.get(client_ip_port, None) + def get_predict_node_infos(self) -> Dict[str, dict]: + return { + "prefill": self.prefill_predict_node_info, + "decode": self.decode_predict_node_info, + } diff --git a/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py b/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py index 2a315af35..c655aa279 100644 --- a/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py +++ b/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py @@ -52,10 +52,9 @@ async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: def _get_min_node(node_infos: dict, key: str): min_node, min_node_len = None, float("inf") for node_ip, node_info in node_infos.items(): - if node_info[key] < float("inf"): - if node_info[key] < min_node_len: - min_node_len = node_info[key] - min_node = node_ip + if node_info[key] < min_node_len: + min_node_len = node_info[key] + min_node = node_ip return min_node if self.pd_manager is None: @@ -65,13 +64,10 @@ def _get_min_node(node_infos: dict, key: str): return p_node, d_node node_infos = self.pd_manager.get_predict_node_infos() - node_infos = {k: v for k, v in node_infos.items() if v["mem_len"] < float("inf")} - if len(node_infos) == 0: - return random.choice(self.prefill_nodes), random.choice(self.decode_nodes) # 获取负载最小的节点 - p_node_infos = {k: v for k, v in node_infos.items() if k in self.prefill_nodes} - d_node_infos = {k: v for k, v in node_infos.items() if k in self.decode_nodes} + p_node_infos = node_infos["prefill"] + d_node_infos = node_infos["decode"] p_node = _get_min_node(p_node_infos, "mem_len") or random.choice(self.prefill_nodes) d_node = _get_min_node(d_node_infos, "mem_len") or 
random.choice(self.decode_nodes) From cedbb50c624fb79b72e428787c726a2a17a6ba64 Mon Sep 17 00:00:00 2001 From: pigKiller <22373017@buaa.edu.cn> Date: Thu, 31 Jul 2025 14:25:21 +0800 Subject: [PATCH 32/33] feat: update get load info --- lightllm/server/httpserver/pd_loop.py | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/lightllm/server/httpserver/pd_loop.py b/lightllm/server/httpserver/pd_loop.py index b020d9d59..bd9e9ca3d 100644 --- a/lightllm/server/httpserver/pd_loop.py +++ b/lightllm/server/httpserver/pd_loop.py @@ -182,21 +182,14 @@ def _get_load_info(have_finished_req: bool) -> dict: return None from lightllm.server.api_http import g_objs - if g_objs.shared_token_load is not None: - current_load = [ - float(g_objs.shared_token_load.get_frozened_token_count(dp_index)) for dp_index in range(g_objs.args.dp) - ] - if g_objs.args.dp == 1: - current_load = current_load[0] - load_info = { - "mem_len": current_load, - "client_ip_port": f"{g_objs.httpserver_manager.host_ip}:{g_objs.args.port}" - } - else: - load_info = { - "mem_len": 0, - "client_ip_port": f"{g_objs.httpserver_manager.host_ip}:{g_objs.args.port}" - } + assert g_objs.shared_token_load is not None, "shared_token_load is not initialized" + current_load = [ + float(g_objs.shared_token_load.get_dynamic_max_load(dp_index)) for dp_index in range(g_objs.args.dp) + ] + load_info = { + "mem_len": min(current_load), + "client_ip_port": f"{g_objs.httpserver_manager.host_ip}:{g_objs.args.port}" + } return load_info From e55cbb5cc529938b3a718b8df2ac3d48a2431576 Mon Sep 17 00:00:00 2001 From: pigKiller <22373017@buaa.edu.cn> Date: Thu, 31 Jul 2025 14:42:32 +0800 Subject: [PATCH 33/33] fix: memory node select --- .../pd_selector/pd_selector.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py b/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py index 
c655aa279..bae2555db 100644 --- a/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py +++ b/lightllm/server/httpserver_for_pd_master/pd_selector/pd_selector.py @@ -1,6 +1,6 @@ import random -from typing import Union, List, Tuple +from typing import Union, List, Tuple, Dict from lightllm.server.pd_io_struct import PD_Client_Obj from lightllm.server.core.objs import SamplingParams from lightllm.server.multimodal_params import MultimodalParams @@ -49,13 +49,14 @@ class MemorySelector(PDSelector): """基于内存使用情况的选择器""" async def select_p_d_node(self, prompt: Union[str, List[int]], sampling_params: SamplingParams, multimodal_params: MultimodalParams) -> Tuple[PD_Client_Obj, PD_Client_Obj]: - def _get_min_node(node_infos: dict, key: str): - min_node, min_node_len = None, float("inf") - for node_ip, node_info in node_infos.items(): - if node_info[key] < min_node_len: - min_node_len = node_info[key] - min_node = node_ip - return min_node + def _get_min_node(nodes: List[PD_Client_Obj], node_infos: Dict[str, dict], key: str) -> PD_Client_Obj: + min_node, min_node_value = None, float("inf") + for node in nodes: + if node.client_ip_port in node_infos: + if node_infos[node.client_ip_port][key] < min_node_value: + min_node_value = node_infos[node.client_ip_port][key] + min_node = node + return min_node if min_node is not None else random.choice(nodes) if self.pd_manager is None: # 如果没有 PDManager 引用,回退到随机选择 @@ -68,7 +69,7 @@ def _get_min_node(node_infos: dict, key: str): # 获取负载最小的节点 p_node_infos = node_infos["prefill"] d_node_infos = node_infos["decode"] - p_node = _get_min_node(p_node_infos, "mem_len") or random.choice(self.prefill_nodes) - d_node = _get_min_node(d_node_infos, "mem_len") or random.choice(self.decode_nodes) + p_node = _get_min_node(self.prefill_nodes, p_node_infos, "mem_len") + d_node = _get_min_node(self.decode_nodes, d_node_infos, "mem_len") return p_node, d_node