diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py
index 4516e18c3..80b36295e 100755
--- a/lightllm/common/basemodel/basemodel.py
+++ b/lightllm/common/basemodel/basemodel.py
@@ -7,6 +7,7 @@ import torch.nn.functional as F
 from typing import final
+import torch.distributed as dist
 from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights
 from lightllm.common.basemodel.infer_struct import InferStateInfo
 from lightllm.common.mem_manager import MemoryManager
@@ -19,7 +20,7 @@ from lightllm.common.quantization import Quantcfg
 from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token
 from lightllm.utils.log_utils import init_logger
-from lightllm.utils.dist_utils import get_dp_world_size
+from lightllm.utils.dist_utils import get_dp_world_size, get_global_world_size, get_global_rank
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.distributed.communication_op import dist_group_manager
 from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
@@ -356,8 +357,20 @@ def _decode(
             model_input.b_mtp_index,
         )

-        if self.graph is not None and self.graph.can_run(model_input.batch_size, model_input.max_len_in_batch):
-            find_graph_batch_size = self.graph.find_closest_graph_batch_size(model_input.batch_size)
+        # collect the global max batch_size so every rank pads to the same graph batch size
+        world_size = get_global_world_size()
+        rank = get_global_rank()
+        all_batch_sizes = [None] * world_size
+        all_batch_sizes[rank] = model_input.batch_size
+        dist.all_gather_object(all_batch_sizes, model_input.batch_size)
+        global_max_batch_size = max(all_batch_sizes)
+
+        if self.graph is not None and self.graph.can_run(global_max_batch_size, model_input.max_len_in_batch):
+            find_graph_batch_size = self.graph.find_closest_graph_batch_size(global_max_batch_size)
+            if find_graph_batch_size is None:
+                logger.error(f"No suitable graph batch size found for batch_size={global_max_batch_size}, returning None.")
+                return None
+
             padded_model_input = self._create_padded_decode_model_input(model_input, find_graph_batch_size)
             infer_state = self._create_inferstate(padded_model_input)
             copy_kv_index_to_req(
@@ -526,8 +539,20 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode
         origin_batch_size = model_input0.batch_size
         max_len_in_batch = max(model_input0.max_len_in_batch, model_input1.max_len_in_batch)

-        if self.graph is not None and self.graph.can_run(origin_batch_size, max_len_in_batch):
-            find_graph_batch_size = self.graph.find_closest_graph_batch_size(origin_batch_size)
+        # collect the global max batch_size so every rank pads to the same graph batch size
+        world_size = get_global_world_size()
+        rank = get_global_rank()
+        all_batch_sizes = [None] * world_size
+        all_batch_sizes[rank] = origin_batch_size
+        dist.all_gather_object(all_batch_sizes, origin_batch_size)
+        global_max_batch_size = max(all_batch_sizes)
+
+        if self.graph is not None and self.graph.can_run(global_max_batch_size, max_len_in_batch):
+            find_graph_batch_size = self.graph.find_closest_graph_batch_size(global_max_batch_size)
+            if find_graph_batch_size is None:
+                logger.error(f"No suitable graph batch size found for batch_size={global_max_batch_size}, returning None.")
+                return None
+
             padded_model_input0 = self._create_padded_decode_model_input(model_input0, find_graph_batch_size)
             padded_model_input1 = self._create_padded_decode_model_input(model_input1, find_graph_batch_size)
             infer_state0 = self._create_inferstate(padded_model_input0, 0)
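Note: every rank has to land on the same padded graph batch size, otherwise ranks would replay different captured graphs and the collectives recorded inside them would mismatch. Below is a minimal sketch of the gather pattern used in the hunks above, assuming `torch.distributed` is already initialized; `gather_global_max_batch_size` is an illustrative helper, not a lightllm API.

```python
import torch.distributed as dist


def gather_global_max_batch_size(local_batch_size: int) -> int:
    """Return the maximum batch size across all ranks.

    Must be called collectively on every rank, otherwise all_gather_object blocks.
    """
    world_size = dist.get_world_size()
    gathered = [None] * world_size
    # all_gather_object fills `gathered` with every rank's local_batch_size
    dist.all_gather_object(gathered, local_batch_size)
    return max(gathered)
```

With this value in hand, `can_run` and `find_closest_graph_batch_size` are evaluated against the global maximum, so all ranks either take the graph path together or fall back together.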
diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py
index 07792865e..8d436e62f 100644
--- a/lightllm/common/basemodel/cuda_graph.py
+++ b/lightllm/common/basemodel/cuda_graph.py
@@ -2,6 +2,7 @@
 import torch
 import copy
 import bisect
+from collections import OrderedDict
 from typing import Optional
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.envs_utils import get_env_start_args
@@ -9,7 +10,6 @@ from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
 from .infer_struct import InferStateInfo

-
 logger = init_logger(__name__)
@@ -17,12 +17,14 @@ class CudaGraph:
     # CudaGraph forward pass for the decoding stage.

     def __init__(self, max_batch_size=8, max_len_in_batch=8192):
-        self.graph = {}
+        self.graph = OrderedDict()  # for LRU
+        self.mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
         self.max_batch_size = max_batch_size
         self.graph_max_len_in_batch = max_len_in_batch
         self.args = get_env_start_args()
         self.enable_decode_microbatch_overlap = self.args.enable_decode_microbatch_overlap
+        self.max_graph_pool_size = self.args.max_graph_pool_size

         # gen cuda graph batch_sizes
         # cuda graph gen for batch size = [1, 2, 3, ..., graph_split_batch_size]
@@ -48,11 +50,20 @@ def can_run(self, batch_size, max_len_in_batch):
         return batch_size <= self.max_batch_size and max_len_in_batch <= self.graph_max_len_in_batch

     def need_capture(self, batch_size):
-        find_batch_size = self.find_closest_graph_batch_size(batch_size)
-        if find_batch_size is not None:
-            return find_batch_size not in self.graph
+        # We assume batch_size has already been adjusted to the closest supported graph batch size.
+        # If the graph already exists, move it to the most recently used position.
+        if batch_size in self.graph:
+            find_graph = self.graph.pop(batch_size)  # dequeue the graph
+            self.graph[batch_size] = find_graph  # re-enqueue it as most recently used
+            return False
         else:
-            assert False, "dead code"
+            return True
+
+    def evict_oldest_graph(self):
+        if self.graph:
+            oldest_batch_size, oldest_graph = self.graph.popitem(last=False)
+            del oldest_graph
+            logger.info(f"Evicted CUDA graph for batch size: {oldest_batch_size}")

     def find_closest_graph_batch_size(self, batch_size):
         index = bisect.bisect_left(self.cuda_graph_batch_sizes, batch_size)
@@ -64,6 +75,9 @@ def find_closest_graph_batch_size(self, batch_size):

     def _capture_decode(self, decode_func, input_ids: torch.Tensor, infer_state: InferStateInfo):
         dist_group: CustomProcessGroup = infer_state.dist_group
+        if len(self.graph) >= self.max_graph_pool_size:
+            self.evict_oldest_graph()
+
         graph_obj = torch.cuda.CUDAGraph()
         batch_size = input_ids.shape[0]
         infer_state.max_len_in_batch = self.graph_max_len_in_batch
@@ -84,6 +98,7 @@ def _capture_decode(self, decode_func, input_ids: torch.Tensor, infer_state: Inf
         with lightllm_capture_graph(dist_group):
             with torch.cuda.graph(graph_obj, pool=self.mempool):
                 model_output = decode_func(input_ids, infer_state)
+        # We assume batch_size has already been adjusted to the closest supported graph batch size
         self.graph[batch_size] = (graph_obj, input_ids, infer_state, model_output)
         graph_obj.replay()
         return model_output
@@ -97,6 +112,9 @@ def _capture_decode_overlap(
         infer_state1: InferStateInfo,
     ):
         dist_group: CustomProcessGroup = infer_state.dist_group
+        if len(self.graph) >= self.max_graph_pool_size:
+            self.evict_oldest_graph()
+
         dist_group1 = infer_state1.dist_group
         graph_obj = torch.cuda.CUDAGraph()
         batch_size = input_ids.shape[0]
@@ -113,6 +131,7 @@ def _capture_decode_overlap(
         with lightllm_capture_graph(dist_group):
             with torch.cuda.graph(graph_obj, pool=self.mempool):
                 model_output, model_output1 = decode_func(input_ids, infer_state, input_ids1, infer_state1)
+        # We assume batch_size has already been adjusted to the closest supported graph batch size
         self.graph[batch_size] = (
             graph_obj,
             input_ids,
@@ -191,7 +210,7 @@ def warmup(self, model):
         model: TpPartBaseModel = model

         # decode cuda graph init
-        for batch_size in self.cuda_graph_batch_sizes[::-1]:
+        for batch_size in (self.cuda_graph_batch_sizes[-1],):
             seq_len = 2
             total_token_num = batch_size * seq_len
             max_len_in_batch = self.graph_max_len_in_batch
@@ -244,7 +263,7 @@ def warmup_overlap(self, model):
         model: TpPartBaseModel = model

-        for batch_size in self.cuda_graph_batch_sizes[::-1]:
+        for batch_size in (self.cuda_graph_batch_sizes[-1],):
             decode_batches = []
             for micro_batch_index in [0, 1]:
                 # dummy decoding, capture the cudagraph
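Note: the graph cache above is a plain `OrderedDict` used as an LRU: `need_capture` re-enqueues a hit at the back, and `evict_oldest_graph` pops from the front once `max_graph_pool_size` is reached. A standalone sketch of that bookkeeping is below; `GraphLRU` and its methods are illustrative names, not lightllm APIs.

```python
from collections import OrderedDict


class GraphLRU:
    def __init__(self, capacity: int = 16):
        self.capacity = capacity
        self.entries = OrderedDict()  # the least recently used entry sits at the front

    def touch(self, batch_size) -> bool:
        """Mark an existing graph as most recently used (mirrors need_capture)."""
        if batch_size in self.entries:
            self.entries[batch_size] = self.entries.pop(batch_size)  # re-enqueue at the back
            return False  # no capture needed
        return True  # caller should capture, possibly after evict_if_full()

    def evict_if_full(self):
        """Drop the least recently used graph before capturing a new one."""
        if len(self.entries) >= self.capacity:
            oldest_batch_size, _ = self.entries.popitem(last=False)
            print(f"evicted graph for batch size {oldest_batch_size}")

    def put(self, batch_size, graph):
        self.entries[batch_size] = graph


# Usage: touch() before replay, evict_if_full() + put() around capture.
cache = GraphLRU(capacity=2)
for bs in (8, 16, 8, 32):        # 16 is evicted when 32 arrives; 8 stays because it was touched
    if cache.touch(bs):
        cache.evict_if_full()
        cache.put(bs, object())  # stand-in for a captured CUDAGraph
print(list(cache.entries))       # [8, 32]
```

Because warmup now pre-captures only the largest graph batch size, the remaining graphs are captured lazily on first use and their total count stays bounded by `max_graph_pool_size`.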
diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index 465f9cc92..c8a26c62f 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -338,6 +338,13 @@ def make_argument_parser() -> argparse.ArgumentParser:
     )

     parser.add_argument("--disable_cudagraph", action="store_true", help="Disable the cudagraph of the decoding stage")
+    parser.add_argument(
+        "--max_graph_pool_size",
+        type=int,
+        default=16,
+        help="""Maximum number of decode CUDA graphs kept in the graph pool; the least recently used graph is evicted beyond this limit.""",
+    )
+
     parser.add_argument(
         "--graph_max_batch_size",
         type=int,
diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py
index d4a205a15..62dfd0d5c 100644
--- a/lightllm/server/core/objs/start_args_type.py
+++ b/lightllm/server/core/objs/start_args_type.py
@@ -76,6 +76,7 @@ class StartArgs:
     visual_nccl_ports: List[int] = field(default_factory=lambda: [29500])
     enable_monitor_auth: bool = field(default=False)
     disable_cudagraph: bool = field(default=False)
+    max_graph_pool_size: int = field(default=16)
     graph_max_batch_size: int = field(default=256)
     graph_split_batch_size: int = field(default=32)
     graph_grow_step_size: int = field(default=16)
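Note: the new flag defaults to 16 in both the CLI and `StartArgs`, and `CudaGraph.__init__` picks it up through `get_env_start_args()`. The trimmed sketch below shows that wiring with stand-in definitions rather than lightllm's real parser and dataclass.

```python
import argparse
from dataclasses import dataclass, field


@dataclass
class StartArgs:
    # mirrors the field added in start_args_type.py
    max_graph_pool_size: int = field(default=16)


parser = argparse.ArgumentParser()
parser.add_argument(
    "--max_graph_pool_size",
    type=int,
    default=16,
    help="Maximum number of decode CUDA graphs kept before LRU eviction.",
)

# e.g. launching with --max_graph_pool_size 32 overrides the default
args = parser.parse_args(["--max_graph_pool_size", "32"])
print(StartArgs(max_graph_pool_size=args.max_graph_pool_size))  # StartArgs(max_graph_pool_size=32)
```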