router and infer parallel. #965

Merged
merged 84 commits into from
Jul 21, 2025

Changes from all commits (84 commits)
90b9ab3
add shm_reqs_buffer.py
hiworldwzj Jul 8, 2025
a6d2f24
fix
hiworldwzj Jul 9, 2025
a2fcf3b
fix
hiworldwzj Jul 9, 2025
bf7ce59
fix
hiworldwzj Jul 9, 2025
b0062a5
fix
hiworldwzj Jul 9, 2025
ca45a2c
fix
hiworldwzj Jul 9, 2025
433a442
fix
hiworldwzj Jul 9, 2025
297dcbc
fix
hiworldwzj Jul 9, 2025
c80f351
fix
hiworldwzj Jul 9, 2025
d9a4774
fix
hiworldwzj Jul 9, 2025
a52ca86
first overlap demo.
hiworldwzj Jul 10, 2025
4cb1f7b
fix
hiworldwzj Jul 11, 2025
9cac880
fix
hiworldwzj Jul 11, 2025
445c959
fix
hiworldwzj Jul 11, 2025
992d7b5
fix
hiworldwzj Jul 11, 2025
f885680
fix
hiworldwzj Jul 11, 2025
c01ea7e
fix
hiworldwzj Jul 11, 2025
e4dfa95
fix
hiworldwzj Jul 11, 2025
7e6cbf6
fix all
hiworldwzj Jul 12, 2025
f693e55
fix
hiworldwzj Jul 12, 2025
3de8ab5
fix
hiworldwzj Jul 12, 2025
2a5a9ce
fix
hiworldwzj Jul 14, 2025
df79f8e
fix
hiworldwzj Jul 14, 2025
5834e80
fix
hiworldwzj Jul 14, 2025
dd5d872
fix
hiworldwzj Jul 14, 2025
d3a9f88
fix
hiworldwzj Jul 14, 2025
d4aa1ee
fix
hiworldwzj Jul 14, 2025
8bad40b
fix
hiworldwzj Jul 14, 2025
29cfca1
fix
hiworldwzj Jul 14, 2025
21caa00
fix
hiworldwzj Jul 14, 2025
af6f547
fix
hiworldwzj Jul 14, 2025
112656b
fix
hiworldwzj Jul 15, 2025
2dff9a1
fix
hiworldwzj Jul 15, 2025
4c111a3
fix
hiworldwzj Jul 15, 2025
b172eaf
fix
hiworldwzj Jul 15, 2025
241ec63
fix
hiworldwzj Jul 15, 2025
678bb5f
inference overlap
shihaobai Jul 15, 2025
2d46245
Merge branch 'wzj_router' of https://github.com/ModelTC/lightllm into…
shihaobai Jul 15, 2025
204f6fa
fix
hiworldwzj Jul 16, 2025
dde6f18
fix
hiworldwzj Jul 16, 2025
6cd8c56
fix
hiworldwzj Jul 16, 2025
8cc2325
back the infer_struct
shihaobai Jul 16, 2025
a7fbb15
overlap sample
shihaobai Jul 16, 2025
ef35cf6
add mtp index
shihaobai Jul 16, 2025
2d81be7
fix
hiworldwzj Jul 16, 2025
8705d0a
fix next token ids.
hiworldwzj Jul 16, 2025
35f1bfa
fix
hiworldwzj Jul 16, 2025
d437d7a
fix
hiworldwzj Jul 16, 2025
ac47e1f
mtp overlap (draft)
shihaobai Jul 16, 2025
4cfae7a
fix
hiworldwzj Jul 17, 2025
9dd795d
diverse mode ok
hiworldwzj Jul 17, 2025
517c18e
add penalty_counter mode
hiworldwzj Jul 17, 2025
5fb73f9
fix
hiworldwzj Jul 17, 2025
dcb76ed
fix
hiworldwzj Jul 17, 2025
2efa7f5
fix
hiworldwzj Jul 17, 2025
4157ff4
fix
hiworldwzj Jul 18, 2025
7c1a597
improve pin mem manager
hiworldwzj Jul 18, 2025
965cdae
overlap mtp
shihaobai Jul 18, 2025
65fb8c6
merge latest
shihaobai Jul 18, 2025
1dad212
fix
hiworldwzj Jul 18, 2025
0e3cdb7
fix mtp
hiworldwzj Jul 18, 2025
098ae9d
fix
hiworldwzj Jul 18, 2025
c42e528
fix
hiworldwzj Jul 21, 2025
48fc14c
fix
hiworldwzj Jul 21, 2025
cfa1758
fix dp mode
hiworldwzj Jul 21, 2025
8a9fe69
fix
hiworldwzj Jul 21, 2025
c3a080d
fix
hiworldwzj Jul 21, 2025
1e9b419
fix
hiworldwzj Jul 21, 2025
dff8618
fix
hiworldwzj Jul 21, 2025
8599b45
fix
hiworldwzj Jul 21, 2025
c6efc13
fix
hiworldwzj Jul 21, 2025
9e5171b
fix dp prefill_mtp and decode_mtp
hiworldwzj Jul 21, 2025
c8bd663
fix
hiworldwzj Jul 21, 2025
31a3ed8
fix multinode tp error
hiworldwzj Jul 21, 2025
5cde331
fix
hiworldwzj Jul 21, 2025
8cbca72
fix
hiworldwzj Jul 21, 2025
4e7b34a
fix
hiworldwzj Jul 21, 2025
257050e
add mtp microbatch overlap
shihaobai Jul 21, 2025
f1224be
Merge branch 'wzj_router' of https://github.com/ModelTC/lightllm into…
shihaobai Jul 21, 2025
366f4d4
fix
shihaobai Jul 21, 2025
7fe6038
fix ep scatter overflow
shihaobai Jul 21, 2025
e5af904
fix
hiworldwzj Jul 21, 2025
2019b7a
fix
hiworldwzj Jul 21, 2025
a789ddb
fix
hiworldwzj Jul 21, 2025
35 changes: 34 additions & 1 deletion lightllm/common/basemodel/basemodel.py
@@ -17,10 +17,11 @@
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
from lightllm.common.basemodel.cuda_graph import CudaGraph
from lightllm.common.quantization import Quantcfg
from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token
from lightllm.utils.log_utils import init_logger
from lightllm.utils.dist_utils import get_dp_world_size
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.distributed.communication_op import CustomProcessGroup, dist_group_manager
from lightllm.distributed.communication_op import dist_group_manager
from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
from lightllm.utils.custom_kernel_utis import pad2dim_tensor_to_new_batch
from lightllm.utils.envs_utils import set_model_init_status
@@ -237,6 +238,7 @@ def _init_custom(self):

@torch.no_grad()
def forward(self, model_input: ModelInput):
model_input.to_cuda()
assert model_input.mem_indexes.is_cuda

if model_input.is_prefill:
@@ -346,6 +348,14 @@ def _decode(
self,
model_input: ModelInput,
) -> ModelOutput:
# for overlap mode
if model_input.input_ids is None:
model_input.input_ids = gather_token(
self.req_manager.req_sampling_params_manager.req_to_next_token_ids,
model_input.b_req_idx,
model_input.b_mtp_index,
)

if self.graph is not None and self.graph.can_run(model_input.batch_size, model_input.max_len_in_batch):
find_graph_batch_size = self.graph.find_closest_graph_batch_size(model_input.batch_size)
padded_model_input = self._create_padded_decode_model_input(model_input, find_graph_batch_size)
@@ -453,6 +463,9 @@ def _token_forward(self, input_ids, infer_state: InferStateInfo):

@torch.no_grad()
def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: ModelInput):
model_input0.to_cuda()
model_input1.to_cuda()

assert model_input0.mem_indexes.is_cuda
assert model_input1.mem_indexes.is_cuda
input_ids0, input_ids1 = model_input0.input_ids, model_input1.input_ids
@@ -490,6 +503,22 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: ModelInput):

@torch.no_grad()
def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: ModelInput):
model_input0.to_cuda()
model_input1.to_cuda()

if model_input0.input_ids is None:
model_input0.input_ids = gather_token(
self.req_manager.req_sampling_params_manager.req_to_next_token_ids,
model_input0.b_req_idx,
model_input0.b_mtp_index,
)
if model_input1.input_ids is None:
model_input1.input_ids = gather_token(
self.req_manager.req_sampling_params_manager.req_to_next_token_ids,
model_input1.b_req_idx,
model_input1.b_mtp_index,
)

assert model_input0.batch_size == model_input1.batch_size
assert model_input0.mem_indexes.is_cuda
assert model_input1.mem_indexes.is_cuda
@@ -659,6 +688,7 @@ def _check_max_len_infer(self):
b_seq_len[:] = self.batch_max_tokens
b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
total_token_num = self.batch_max_tokens
b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
model_input = ModelInput(
batch_size=1,
total_token_num=total_token_num,
@@ -667,6 +697,7 @@
mem_indexes=mem_indexes,
b_req_idx=b_req_idx,
b_seq_len=b_seq_len,
b_mtp_index=b_mtp_index,
is_prefill=True,
b_ready_cache_len=b_ready_cache_len,
)
@@ -714,13 +745,15 @@ def _init_padded_req(self):
b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda")
b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
total_token_num = prefill_input_len * batch_size
b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
model_input = ModelInput(
batch_size=batch_size,
total_token_num=total_token_num,
max_len_in_batch=prefill_input_len,
input_ids=dummy_input_ids,
mem_indexes=mem_indexes,
b_req_idx=b_req_idx,
b_mtp_index=b_mtp_index,
b_seq_len=b_seq_len,
b_ready_cache_len=b_ready_cache_len,
is_prefill=True,
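
Note: with this change, the router can hand the decode path a ModelInput whose input_ids is None; the model then gathers the previous step's sampled token ids on-device from req_manager.req_sampling_params_manager.req_to_next_token_ids, keyed by b_req_idx and b_mtp_index. A minimal sketch of that contract (tensor shapes, dtypes, and values here are hypothetical, not taken from this PR):

```python
import torch

from lightllm.common.basemodel.batch_objs import ModelInput

# Sketch only: a decode-step ModelInput in overlap mode. input_ids stays None;
# forward() calls to_cuda() and then gather_token() fills the ids from
# req_manager.req_sampling_params_manager.req_to_next_token_ids.
batch_size = 4
decode_input = ModelInput(
    batch_size=batch_size,
    total_token_num=batch_size * 32,         # hypothetical
    max_len_in_batch=32,                     # hypothetical
    input_ids=None,                          # overlap mode: filled on-device in _decode
    b_req_idx=torch.arange(batch_size, dtype=torch.int32),
    b_mtp_index=torch.zeros(batch_size, dtype=torch.int32),
    b_seq_len=torch.full((batch_size,), 32, dtype=torch.int32),
    mem_indexes_cpu=torch.arange(batch_size, dtype=torch.int64).pin_memory(),
    is_prefill=False,
)
# model.forward(decode_input) would then run the decode path shown above.
```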
21 changes: 20 additions & 1 deletion lightllm/common/basemodel/batch_objs.py
@@ -1,6 +1,7 @@
import torch
from dataclasses import dataclass, field
from typing import Optional
from typing import List


@dataclass
@@ -10,20 +11,38 @@ class ModelInput:
total_token_num: int
max_len_in_batch: int
input_ids: torch.Tensor
mem_indexes: torch.Tensor
b_req_idx: torch.Tensor
b_mtp_index: torch.Tensor
b_seq_len: torch.Tensor
mem_indexes: torch.Tensor = None
is_prefill: bool = False
b_ready_cache_len: torch.Tensor = None
multimodal_params: list = field(default_factory=list)

# CPU-side tensors
mem_indexes_cpu: torch.Tensor = None
# Parameters used in the prefill phase that are not consumed by the inference pass itself,
# but by the resource management logic outside of inference.
b_prefill_has_output_cpu: List[bool] = None  # marks whether each prefill request produces an output

# Model-specific variables: used by certain models in special modes to pass
# extra inputs. They are only used and take effect in those special model modes.

# deepseekv3_mtp_draft_input_hiddens is the input to the draft model when the
# deepseekv3 model runs in MTP mode.
deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None

def to_cuda(self):
if self.input_ids is not None:
self.input_ids = self.input_ids.cuda(non_blocking=True)
if self.mem_indexes is None:
self.mem_indexes = self.mem_indexes_cpu.cuda(non_blocking=True)
self.b_req_idx = self.b_req_idx.cuda(non_blocking=True)
self.b_seq_len = self.b_seq_len.cuda(non_blocking=True)
self.b_mtp_index = self.b_mtp_index.cuda(non_blocking=True)
if self.b_ready_cache_len is not None:
self.b_ready_cache_len = self.b_ready_cache_len.cuda(non_blocking=True)


@dataclass
class ModelOutput:
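
Note on to_cuda(): every host-to-device copy uses non_blocking=True, which only actually overlaps with other work when the source tensor sits in pinned (page-locked) host memory; the "improve pin mem manager" commit above points the same way. A small, general PyTorch illustration (not lightllm code):

```python
import torch

# General PyTorch behavior (not part of this PR): non_blocking H2D copies are
# asynchronous only when the source tensor is pinned; otherwise they fall back
# to an ordinary synchronous copy.
pageable = torch.arange(1 << 20, dtype=torch.int32)             # regular pageable memory
pinned = torch.arange(1 << 20, dtype=torch.int32).pin_memory()  # page-locked memory

if torch.cuda.is_available():
    a = pageable.cuda(non_blocking=True)   # effectively synchronous
    b = pinned.cuda(non_blocking=True)     # can overlap with queued GPU work
    torch.cuda.synchronize()               # wait for both copies before using them
    assert torch.equal(a, b)
```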
4 changes: 4 additions & 0 deletions lightllm/common/basemodel/cuda_graph.py
@@ -202,6 +202,7 @@ def warmup(self, model):
)
b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
b_seq_len.fill_(seq_len)
b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")

model_input = ModelInput(
batch_size=batch_size,
@@ -211,6 +212,7 @@
mem_indexes=mem_indexes,
b_req_idx=b_req_idx,
b_seq_len=b_seq_len,
b_mtp_index=b_mtp_index,
is_prefill=False,
**model._gen_special_model_input(batch_size),
)
@@ -256,13 +258,15 @@ def warmup_overlap(self, model):
)
b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
b_seq_len.fill_(seq_len)
b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")

micro_batch = ModelInput(
is_prefill=False,
batch_size=batch_size,
total_token_num=total_token_num,
max_len_in_batch=max_len_in_batch,
input_ids=input_ids,
b_mtp_index=b_mtp_index,
mem_indexes=mem_indexes,
b_req_idx=b_req_idx,
b_seq_len=b_seq_len,
8 changes: 5 additions & 3 deletions lightllm/common/basemodel/infer_struct.py
@@ -6,6 +6,7 @@
from .triton_kernel.gen_prefill_params import gen_prefill_params
from .triton_kernel.gen_decode_params import gen_decode_params
from .triton_kernel.multimodal_emb import mark_multimodal_obj
from .batch_objs import ModelInput


class InferStateInfo:
@@ -87,9 +88,10 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
self.b_kv_seq_len,
self.b1_cu_kv_seq_len,
self.position_ids,
self.max_q_seq_len,
self.max_kv_seq_len,
) = gen_decode_params(b_seq_len=self.b_seq_len)
) = gen_decode_params(self.b_seq_len)
self.max_q_seq_len = 1
# TODO: check the correctness
self.max_kv_seq_len = self.max_len_in_batch
self.b_start_loc = self.b1_cu_kv_seq_len[0:-1]

def copy_for_cuda_graph(self, new_infer_state: "InferStateInfo"):
165 changes: 165 additions & 0 deletions lightllm/common/basemodel/triton_kernel/gather_token_id.py
@@ -0,0 +1,165 @@
import torch

import triton
import triton.language as tl


@triton.jit
def _fwd_kernel_scatter(
next_token_ids,
req_to_next_token_ids,
b_req_idx,
b_mtp_index,
b_has_out,
req_to_next_token_ids_stride,
req_to_next_token_ids_stride_1,
num_size,
HAS_OUT_IS_NONE: tl.constexpr,
BLOCK: tl.constexpr,
):
block_index = tl.program_id(0)
block_range = block_index * BLOCK + tl.arange(0, BLOCK)
block_mask = block_range < num_size

cur_req_idx = tl.load(b_req_idx + block_range, mask=block_mask)
cur_mtp_index = tl.load(b_mtp_index + block_range, mask=block_mask)
cur_next_token_id = tl.load(next_token_ids + block_range, mask=block_mask)

if not HAS_OUT_IS_NONE:
cur_has_out = tl.load(b_has_out + block_range, mask=block_mask, other=False)
tl.store(
req_to_next_token_ids + cur_req_idx * req_to_next_token_ids_stride + cur_mtp_index,
cur_next_token_id,
mask=cur_has_out & block_mask,
)
else:
tl.store(
req_to_next_token_ids + cur_req_idx * req_to_next_token_ids_stride + cur_mtp_index,
cur_next_token_id,
mask=block_mask,
)

return


@torch.no_grad()
def scatter_token(
next_token_ids: torch.Tensor,
req_to_next_token_ids: torch.Tensor,
b_req_idx: torch.Tensor,
b_mtp_index: torch.Tensor,
b_has_out: torch.Tensor = None,
):
"""
Scatter next_token_ids (a GPU tensor) into req_to_next_token_ids (typically a pinned CPU tensor),
indexed by (b_req_idx, b_mtp_index).
Args:
next_token_ids: (batch_size,)
req_to_next_token_ids: (max_req_num, max_mtp_step)
b_req_idx: (batch_size,)
b_mtp_index: (batch_size,)
b_has_out: optional (batch_size,) bool mask; when given, only masked-in entries are written.
"""
assert next_token_ids.shape[0] == b_req_idx.shape[0]
batch_size = b_req_idx.shape[0]
BLOCK = 256

grid = (triton.cdiv(batch_size, BLOCK),)
num_warps = 1

_fwd_kernel_scatter[grid](
next_token_ids=next_token_ids,
req_to_next_token_ids=req_to_next_token_ids,
b_req_idx=b_req_idx,
b_mtp_index=b_mtp_index,
b_has_out=b_has_out,
req_to_next_token_ids_stride=req_to_next_token_ids.stride(0),
req_to_next_token_ids_stride_1=req_to_next_token_ids.stride(1),
num_size=batch_size,
HAS_OUT_IS_NONE=b_has_out is None,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return


@triton.jit
def _fwd_kernel_gather(
req_to_next_token_ids,
req_to_next_token_ids_stride,
req_to_next_token_ids_stride_1,
output,
b_req_idx,
b_mtp_index,
num_size,
BLOCK: tl.constexpr,
):
block_index = tl.program_id(0)
block_range = block_index * BLOCK + tl.arange(0, BLOCK)
block_mask = block_range < num_size
cur_req_idx = tl.load(b_req_idx + block_range, mask=block_mask)
cur_mtp_index = tl.load(b_mtp_index + block_range, mask=block_mask)
cur_next_token_id = tl.load(
req_to_next_token_ids + cur_req_idx * req_to_next_token_ids_stride + cur_mtp_index, mask=block_mask
)
tl.store(output + block_range, cur_next_token_id, mask=block_mask)
return


def gather_token(req_to_next_token_ids: torch.Tensor, b_req_idx: torch.Tensor, b_mtp_index: torch.Tensor):
"""
Gather each request's next token id from req_to_next_token_ids into a new GPU tensor,
indexed by (b_req_idx, b_mtp_index).
Args:
req_to_next_token_ids: (max_req_num, max_mtp_step)
b_req_idx: (batch_size,)
b_mtp_index: (batch_size,)
Returns:
output: (batch_size,)
"""
batch_size = b_req_idx.shape[0]
output = torch.empty(batch_size, dtype=req_to_next_token_ids.dtype, device="cuda")
BLOCK = 256
grid = (triton.cdiv(batch_size, BLOCK),)
num_warps = 1
_fwd_kernel_gather[grid](
req_to_next_token_ids=req_to_next_token_ids,
req_to_next_token_ids_stride=req_to_next_token_ids.stride(0),
req_to_next_token_ids_stride_1=req_to_next_token_ids.stride(1),
output=output,
b_req_idx=b_req_idx,
b_mtp_index=b_mtp_index,
num_size=batch_size,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return output


def test_scatter_token_to_cpu():
batch_size = 30
req_to_token_info = torch.zeros((1000, 1), dtype=torch.float32, pin_memory=True)
token_info = torch.randn((batch_size,)).cuda()
req_ids = torch.arange(20, 20 + batch_size, dtype=torch.int32).cuda()
mtp_index = torch.zeros((batch_size,), dtype=torch.int32).cuda()
scatter_token(token_info, req_to_token_info, req_ids, mtp_index)
diff = (req_to_token_info[20 : 20 + batch_size].cuda().view(-1) - token_info).abs().max()
assert diff < 1e-6
print("test_scatter_token_to_cpu passed")


def test_gather_token():
batch_size = 30
req_to_token_info = torch.zeros((1000, 1), dtype=torch.float32, pin_memory=True)
token_info = torch.randn((batch_size,)).cuda()
req_ids = torch.arange(20, 20 + batch_size, dtype=torch.int32).cuda()
mtp_index = torch.zeros((batch_size,), dtype=torch.int32).cuda()
scatter_token(token_info, req_to_token_info, req_ids, mtp_index)
output = gather_token(req_to_token_info, req_ids, mtp_index)
diff = (token_info - output).abs().max()
assert diff < 1e-6
print("test_gather_token passed")


if __name__ == "__main__":
test_scatter_token_to_cpu()
test_gather_token()
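
Note: the in-file tests above only cover the unmasked path; the b_has_out mask (the HAS_OUT_IS_NONE=False branch) is what lets a prefill step skip writing next-token ids for requests that produce no output, cf. b_prefill_has_output_cpu in ModelInput. A sketch of that masked path, with hypothetical shapes and dtypes:

```python
import torch
from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token, gather_token

# Sketch only: masked scatter followed by a gather. Rows whose b_has_out entry
# is False keep their previous value in req_to_next_token_ids.
batch_size = 4
req_to_next_token_ids = torch.zeros((16, 1), dtype=torch.int32, device="cuda")
next_token_ids = torch.randint(0, 32000, (batch_size,), dtype=torch.int32, device="cuda")
b_req_idx = torch.arange(batch_size, dtype=torch.int32, device="cuda")
b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
b_has_out = torch.tensor([True, True, False, True], device="cuda")

scatter_token(next_token_ids, req_to_next_token_ids, b_req_idx, b_mtp_index, b_has_out)
gathered = gather_token(req_to_next_token_ids, b_req_idx, b_mtp_index)
# gathered[2] stays 0 (its row was masked out); the other entries equal next_token_ids.
```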
5 changes: 2 additions & 3 deletions lightllm/common/basemodel/triton_kernel/gen_decode_params.py
@@ -10,6 +10,5 @@ def gen_decode_params(b_seq_len: torch.Tensor):
position_ids = b_seq_len - 1
b_q_seq_len = torch.ones_like(b_seq_len)
b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len)
max_q_seq_len = b_q_seq_len.max().item()
max_kv_seq_len = b_kv_seq_len.max().item()
return b_q_seq_len, b1_cu_q_seq_len, b_kv_seq_len, b1_cu_kv_seq_len, position_ids, max_q_seq_len, max_kv_seq_len

return b_q_seq_len, b1_cu_q_seq_len, b_kv_seq_len, b1_cu_kv_seq_len, position_ids
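
Note: the dropped b_q_seq_len.max().item() / b_kv_seq_len.max().item() calls each force a GPU-to-CPU synchronization; presumably that is what this PR avoids so the CPU-side router and scheduling work can overlap with inference (decode now fixes max_q_seq_len = 1 and reuses max_len_in_batch, see the infer_struct.py hunk above). A small, general illustration of the cost being avoided (not lightllm code):

```python
import time
import torch

# General PyTorch behavior (not part of this PR): .item() on a GPU tensor blocks
# the host until every kernel needed to produce that value has finished.
if torch.cuda.is_available():
    x = torch.randn(1 << 22, device="cuda")
    y = (x * 2.0).max()            # enqueued asynchronously; the host keeps going
    t0 = time.perf_counter()
    _ = y.item()                   # host stalls here waiting for the GPU
    print(f"GPU->CPU sync took {time.perf_counter() - t0:.6f} s")
```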