cuda graph pool with LRU #964

Open — wants to merge 6 commits into base: main
35 changes: 30 additions & 5 deletions lightllm/common/basemodel/basemodel.py
@@ -7,6 +7,7 @@
import torch.nn.functional as F
from typing import final

import torch.distributed as dist
from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights
from lightllm.common.basemodel.infer_struct import InferStateInfo
from lightllm.common.mem_manager import MemoryManager
@@ -19,7 +20,7 @@
from lightllm.common.quantization import Quantcfg
from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token
from lightllm.utils.log_utils import init_logger
from lightllm.utils.dist_utils import get_dp_world_size
from lightllm.utils.dist_utils import get_dp_world_size, get_global_world_size, get_global_rank
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.distributed.communication_op import dist_group_manager
from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
@@ -356,8 +357,20 @@ def _decode(
model_input.b_mtp_index,
)

if self.graph is not None and self.graph.can_run(model_input.batch_size, model_input.max_len_in_batch):
find_graph_batch_size = self.graph.find_closest_graph_batch_size(model_input.batch_size)
# collect global max batch_size
world_size = get_global_world_size()
rank = get_global_rank()
all_batch_sizes = [None] * world_size
all_batch_sizes[rank] = model_input.batch_size
dist.all_gather_object(all_batch_sizes, model_input.batch_size)
global_max_batch_size = max(all_batch_sizes)

if self.graph is not None and self.graph.can_run(global_max_batch_size, model_input.max_len_in_batch):
find_graph_batch_size = self.graph.find_closest_graph_batch_size(global_max_batch_size)
if find_graph_batch_size is None:
logger.error("No suitable graph batch size found for batch_size={global_max_batch_size}, return None.")
return None

padded_model_input = self._create_padded_decode_model_input(model_input, find_graph_batch_size)
infer_state = self._create_inferstate(padded_model_input)
copy_kv_index_to_req(
@@ -526,8 +539,20 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode
origin_batch_size = model_input0.batch_size
max_len_in_batch = max(model_input0.max_len_in_batch, model_input1.max_len_in_batch)

if self.graph is not None and self.graph.can_run(origin_batch_size, max_len_in_batch):
find_graph_batch_size = self.graph.find_closest_graph_batch_size(origin_batch_size)
# collect global max batch_size
world_size = get_global_world_size()
rank = get_global_rank()
all_batch_sizes = [None] * world_size
all_batch_sizes[rank] = origin_batch_size
dist.all_gather_object(all_batch_sizes, origin_batch_size)
global_max_batch_size = max(all_batch_sizes)

if self.graph is not None and self.graph.can_run(global_max_batch_size, max_len_in_batch):
find_graph_batch_size = self.graph.find_closest_graph_batch_size(global_max_batch_size)
if find_graph_batch_size is None:
logger.error("No suitable graph batch size found for batch_size={global_max_batch_size}, return None.")
return None

padded_model_input0 = self._create_padded_decode_model_input(model_input0, find_graph_batch_size)
padded_model_input1 = self._create_padded_decode_model_input(model_input1, find_graph_batch_size)
infer_state0 = self._create_inferstate(padded_model_input0, 0)
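The gather added above is what keeps ranks in lock-step: every rank must replay a graph captured for the same padded batch size, so each rank contributes its local decode batch size and all of them take the maximum. A minimal, self-contained sketch of that pattern (the function name is ours, not the PR's; it assumes torch.distributed has already been initialized):

```python
import torch.distributed as dist

def gather_global_max_batch_size(local_batch_size: int, world_size: int) -> int:
    # all_gather_object fills the list with every rank's value, so all ranks
    # end up computing the same maximum.
    all_batch_sizes = [None] * world_size
    dist.all_gather_object(all_batch_sizes, local_batch_size)
    # Every rank pads its decode input to this shared batch size, so they all
    # select (and replay) the same captured CUDA graph.
    return max(all_batch_sizes)
```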
35 changes: 27 additions & 8 deletions lightllm/common/basemodel/cuda_graph.py
@@ -2,27 +2,29 @@
import torch
import copy
import bisect
from collections import OrderedDict
from typing import Optional
from lightllm.utils.log_utils import init_logger
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.distributed import dist_group_manager, lightllm_capture_graph, CustomProcessGroup
from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
from .infer_struct import InferStateInfo


logger = init_logger(__name__)


class CudaGraph:
# CudaGraph forward pass for the decoding stage.

def __init__(self, max_batch_size=8, max_len_in_batch=8192):
self.graph = {}
self.graph = OrderedDict() # for LRU

self.mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
self.max_batch_size = max_batch_size
self.graph_max_len_in_batch = max_len_in_batch
self.args = get_env_start_args()
self.enable_decode_microbatch_overlap = self.args.enable_decode_microbatch_overlap
self.max_graph_pool_size = self.args.max_graph_pool_size

# gen cuda graph batch_sizes
# cuda graph gen for batch size = [1, 2, 3, ..., graph_split_batch_size]
@@ -48,11 +50,20 @@ def can_run(self, batch_size, max_len_in_batch):
return batch_size <= self.max_batch_size and max_len_in_batch <= self.graph_max_len_in_batch

def need_capture(self, batch_size):
find_batch_size = self.find_closest_graph_batch_size(batch_size)
if find_batch_size is not None:
return find_batch_size not in self.graph
# We assume batch_size has already been adjusted to the closest supported graph batch size
# If the graph already exists, get it and move it to the most recently used position.
if batch_size in self.graph:
find_graph = self.graph.pop(batch_size) # Dequeue the graph
self.graph[batch_size] = find_graph # Enqueue the graph for LRU
return False
else:
assert False, "dead code"
return True

def evict_oldest_graph(self):
if self.graph:
oldest_batch_size, oldest_graph = self.graph.popitem(last=False)
del oldest_graph
logger.info(f"Evicted CUDA graph for batch size: {oldest_batch_size}")

def find_closest_graph_batch_size(self, batch_size):
index = bisect.bisect_left(self.cuda_graph_batch_sizes, batch_size)
@@ -64,6 +75,9 @@ def find_closest_graph_batch_size(self, batch_size):

def _capture_decode(self, decode_func, input_ids: torch.Tensor, infer_state: InferStateInfo):
dist_group: CustomProcessGroup = infer_state.dist_group
if len(self.graph) >= self.max_graph_pool_size:
self.evict_oldest_graph()

graph_obj = torch.cuda.CUDAGraph()
batch_size = input_ids.shape[0]
infer_state.max_len_in_batch = self.graph_max_len_in_batch
@@ -84,6 +98,7 @@ def _capture_decode(self, decode_func, input_ids: torch.Tensor, infer_state: Inf
with lightllm_capture_graph(dist_group):
with torch.cuda.graph(graph_obj, pool=self.mempool):
model_output = decode_func(input_ids, infer_state)
# We assume batch_size has already been adjusted to the closest supported graph batch size
self.graph[batch_size] = (graph_obj, input_ids, infer_state, model_output)
graph_obj.replay()
return model_output
@@ -97,6 +112,9 @@ def _capture_decode_overlap(
infer_state1: InferStateInfo,
):
dist_group: CustomProcessGroup = infer_state.dist_group
if len(self.graph) >= self.max_graph_pool_size:
self.evict_oldest_graph()

dist_group1 = infer_state1.dist_group
graph_obj = torch.cuda.CUDAGraph()
batch_size = input_ids.shape[0]
@@ -113,6 +131,7 @@
with lightllm_capture_graph(dist_group):
with torch.cuda.graph(graph_obj, pool=self.mempool):
model_output, model_output1 = decode_func(input_ids, infer_state, input_ids1, infer_state1)
# We assume batch_size has already been adjusted to the closest supported graph batch size
self.graph[batch_size] = (
graph_obj,
input_ids,
@@ -191,7 +210,7 @@ def warmup(self, model):
model: TpPartBaseModel = model

# decode cuda graph init
for batch_size in self.cuda_graph_batch_sizes[::-1]:
for batch_size in (self.cuda_graph_batch_sizes[-1],):
seq_len = 2
total_token_num = batch_size * seq_len
max_len_in_batch = self.graph_max_len_in_batch
@@ -244,7 +263,7 @@ def warmup_overlap(self, model):

model: TpPartBaseModel = model

for batch_size in self.cuda_graph_batch_sizes[::-1]:
for batch_size in (self.cuda_graph_batch_sizes[-1],):
decode_batches = []
for micro_batch_index in [0, 1]:
# dummy decoding, capture the cudagraph
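Taken together, need_capture, evict_oldest_graph, and the capture paths turn self.graph into an LRU cache keyed by graph batch size: a hit is moved to the most-recently-used end, and a capture that would exceed max_graph_pool_size first evicts the least-recently-used entry. A stripped-down sketch of that behaviour (illustrative only — in the PR each entry also carries the captured CUDAGraph and its static input/output tensors):

```python
from collections import OrderedDict

class GraphLRUPool:
    def __init__(self, max_size: int = 16):
        self.max_size = max_size
        self.pool = OrderedDict()  # batch_size -> captured-graph entry

    def get(self, batch_size):
        entry = self.pool.pop(batch_size, None)
        if entry is not None:
            self.pool[batch_size] = entry  # re-insert as most recently used
        return entry

    def put(self, batch_size, entry):
        if len(self.pool) >= self.max_size:
            self.pool.popitem(last=False)  # evict the least recently used graph
        self.pool[batch_size] = entry
```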
7 changes: 7 additions & 0 deletions lightllm/server/api_cli.py
@@ -338,6 +338,13 @@ def make_argument_parser() -> argparse.ArgumentParser:
)
parser.add_argument("--disable_cudagraph", action="store_true", help="Disable the cudagraph of the decoding stage")

parser.add_argument(
"--max_graph_pool_size",
type=int,
default=16,
help="""Maximum cuda graph pool size for decoding stage.""",
)

parser.add_argument(
"--graph_max_batch_size",
type=int,
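The new flag bounds how many captured decode graphs stay resident at once (default 16); CudaGraph reads it via get_env_start_args(). A self-contained illustration of the option as added above (the parser here is a stand-in, not the full lightllm argument parser):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--max_graph_pool_size",
    type=int,
    default=16,
    help="Maximum cuda graph pool size for decoding stage.",
)

# e.g. `--max_graph_pool_size 32` keeps at most 32 captured decode graphs alive.
args = parser.parse_args(["--max_graph_pool_size", "32"])
print(args.max_graph_pool_size)  # -> 32
```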
1 change: 1 addition & 0 deletions lightllm/server/core/objs/start_args_type.py
@@ -76,6 +76,7 @@ class StartArgs:
visual_nccl_ports: List[int] = field(default_factory=lambda: [29500])
enable_monitor_auth: bool = field(default=False)
disable_cudagraph: bool = field(default=False)
max_graph_pool_size: int = field(default=16)
graph_max_batch_size: int = field(default=256)
graph_split_batch_size: int = field(default=32)
graph_grow_step_size: int = field(default=16)