Merged

Commits (79)
bc281e2  qwen3_moe mtp  (shihaobai, Dec 1, 2025)
f4f8415  fix weight name  (shihaobai, Dec 1, 2025)
652dd7d  fix qwen3 fa3 mtp  (shihaobai, Dec 1, 2025)
11dc305  fix  (shihaobai, Dec 1, 2025)
5c2ae24  fix  (shihaobai, Dec 1, 2025)
f09c9bb  fix  (shihaobai, Dec 1, 2025)
c9de6f6  fix rebase  (Dec 30, 2025)
60b3113  mtp dense  (shihaobai, Dec 9, 2025)
bfa1cfa  mtp dense weight  (shihaobai, Dec 9, 2025)
323761e  fix  (shihaobai, Dec 9, 2025)
e973b5e  fix  (shihaobai, Dec 9, 2025)
aeb4609  fix  (shihaobai, Dec 9, 2025)
224c398  remove mtp norm  (shihaobai, Dec 30, 2025)
71bcd72  mtp dense  (shihaobai, Dec 12, 2025)
5046c53  update  (shihaobai, Dec 25, 2025)
043799b  fix  (Dec 30, 2025)
72616ef  fix  (Dec 30, 2025)
7955d77  fix mtp mistral model  (Dec 30, 2025)
47f768a  mistral mtp pre layer infer  (Dec 30, 2025)
f63a725  fix pre layer mtp  (Dec 30, 2025)
acdd94e  fix mistral mtp weight load  (Dec 30, 2025)
253a60c  fix  (Dec 30, 2025)
979cd27  fix  (Dec 30, 2025)
1fd4c92  fix mistral support fa3  (Dec 31, 2025)
7036a27  fix weight  (Dec 31, 2025)
4f8765b  fix mtp_avg_token_per_step calcu  (Dec 31, 2025)
f336449  diverse_mode support mtp  (Dec 31, 2025)
de918c7  fix init weights and init layers  (Dec 31, 2025)
2013d6f  fix init weights and init layers  (Dec 31, 2025)
b4872c0  rename mtp  (Dec 31, 2025)
37c04a6  fix cpu cache kv layer num  (Dec 31, 2025)
a4416b3  fix mem layer  (Dec 31, 2025)
085f9db  fix bloom  (Dec 31, 2025)
1a1dd87  fix added_mtp_kv_layer_num  (Dec 31, 2025)
5da2594  fix token decode kernel for int32 overflow  (Dec 31, 2025)
74e0f75  fix mtp mode support  (Dec 31, 2025)
594fa94  fix  (hiworldwzj, Jan 1, 2026)
5488e4a  fix is_egale_mtp  (hiworldwzj, Jan 1, 2026)
616c37f  fix inferstate input_ids  (hiworldwzj, Jan 1, 2026)
419e0fe  fix get input len  (hiworldwzj, Jan 1, 2026)
ea06919  fix norm  (hiworldwzj, Jan 2, 2026)
35ab3bb  fix norm head tp  (hiworldwzj, Jan 2, 2026)
aedb85d  fix norm  (hiworldwzj, Jan 2, 2026)
82112f8  fix bloom  (hiworldwzj, Jan 2, 2026)
89f8fda  fix rmsnorm llama call  (hiworldwzj, Jan 2, 2026)
0aae4f6  fix pos embdiing  (hiworldwzj, Jan 2, 2026)
edd3f35  fix starcoder  (hiworldwzj, Jan 2, 2026)
d97cc40  fix rmsnorm call  (hiworldwzj, Jan 2, 2026)
41e82e0  fix chatglm  (hiworldwzj, Jan 2, 2026)
55666de  fix stablelm  (hiworldwzj, Jan 2, 2026)
4355437  fix wte name  (hiworldwzj, Jan 2, 2026)
0fad6c6  fix rmsnorm  (hiworldwzj, Jan 2, 2026)
e4be216  fix  (hiworldwzj, Jan 2, 2026)
741532e  fix  (hiworldwzj, Jan 2, 2026)
4e9941a  fix  (hiworldwzj, Jan 2, 2026)
c53cdf9  fix mtp deepseek  (hiworldwzj, Jan 2, 2026)
0b74a3c  fix mtp mistral  (hiworldwzj, Jan 2, 2026)
bfe3f98  fix  (hiworldwzj, Jan 2, 2026)
a5f0069  fix mtp deepseek  (hiworldwzj, Jan 2, 2026)
af9f50d  fix  (hiworldwzj, Jan 2, 2026)
977a038  fix all weights  (hiworldwzj, Jan 2, 2026)
87a214f  fix att sink weight  (Jan 3, 2026)
f36d861  fix embeding weight  (hiworldwzj, Jan 3, 2026)
d43ced9  fix tpnorm params  (hiworldwzj, Jan 3, 2026)
59085b6  fix  (hiworldwzj, Jan 3, 2026)
1281301  fix diverset mtp only support no att mtp mode  (hiworldwzj, Jan 3, 2026)
5986d33  fix cohere  (hiworldwzj, Jan 3, 2026)
68724db  fix  (hiworldwzj, Jan 3, 2026)
e7496bf  review fix all  (hiworldwzj, Jan 3, 2026)
9766c38  fix prefill dp banlance feature  (Jan 4, 2026)
e126354  add test model acc sh  (Jan 4, 2026)
7d310e5  fix  (Jan 4, 2026)
e5335f1  fix sh  (Jan 4, 2026)
e3a0b4a  add test_qwen2.sh  (Jan 4, 2026)
8bee7ea  fix sh  (Jan 4, 2026)
76ff894  fix sh  (Jan 4, 2026)
00ad3b5  fix unittest  (Jan 4, 2026)
99fdb4c  fix vitnorm params  (Jan 4, 2026)
38e92dd  fix vl weight  (Jan 4, 2026)
8 changes: 5 additions & 3 deletions docs/CN/source/tutorial/api_server_args_zh.rst
@@ -447,10 +447,12 @@ MTP Multi-Prediction Parameters

 .. option:: --mtp_mode
 
-    Supported mtp modes; deepseekv3_eagle is recommended for a better performance experience. Optional values:
+    Supported mtp modes; eagle_with_att is recommended for a better performance experience. Optional values:
 
-    * ``deepseekv3_vanilla``
-    * ``deepseekv3_eagle``
+    * ``vanilla_with_att``
+    * ``eagle_with_att``
+    * ``vanilla_no_att``
+    * ``eagle_no_att``
     * ``None``: do not enable mtp (default)
 
 .. option:: --mtp_draft_model_dir
8 changes: 5 additions & 3 deletions docs/EN/source/tutorial/api_server_args_zh.rst
@@ -444,10 +444,12 @@ MTP Multi-Prediction Parameters

 .. option:: --mtp_mode
 
-    Supported mtp modes, it is recommended to use deepseekv3_eagle for better performance, optional values:
+    Supported mtp modes, it is recommended to use eagle_with_att for better performance, optional values:
 
-    * ``deepseekv3_vanilla``
-    * ``deepseekv3_eagle``
+    * ``vanilla_with_att``
+    * ``eagle_with_att``
+    * ``vanilla_no_att``
+    * ``eagle_no_att``
     * ``None``: Do not enable mtp (default)
 
 .. option:: --mtp_draft_model_dir
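For reference, a minimal sketch of how the documented choices could map onto a command-line flag. The choice names come from the docs above, but the actual argument definition in lightllm may differ.

```python
# Hedged sketch only: mirrors the documented --mtp_mode choices,
# not necessarily lightllm's real argparse wiring.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--mtp_mode",
    type=str,
    default=None,
    choices=["vanilla_with_att", "eagle_with_att", "vanilla_no_att", "eagle_no_att"],
    help="MTP mode; eagle_with_att is recommended. Omit (None) to disable MTP.",
)

args = parser.parse_args(["--mtp_mode", "eagle_with_att"])
assert args.mtp_mode == "eagle_with_att"
```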
136 changes: 68 additions & 68 deletions lightllm/common/basemodel/basemodel.py

Large diffs are not rendered by default.

14 changes: 7 additions & 7 deletions lightllm/common/basemodel/batch_objs.py
@@ -46,9 +46,9 @@ class ModelInput:
     # Dedicated fields used by a few special models in special modes to pass
     # extra input variables. They only take effect for those specific model modes.
 
-    # deepseekv3_mtp_draft_input_hiddens is used in the deepseekv3 model's mtp mode
+    # mtp_draft_input_hiddens is used in mtp mode
     # as the input of the draft model
-    deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None
+    mtp_draft_input_hiddens: Optional[torch.Tensor] = None
 
     def to_cuda(self):
         if self.input_ids is not None:
@@ -90,12 +90,12 @@ class ModelOutput:
     # Dedicated fields used by a few special models in special modes to pass
     # extra output variables. They only take effect for those specific model modes.
 
-    # deepseekv3_mtp_main_output_hiddens: in mtp mode, the llm main model
-    # exports its last-layer hidden states here to feed the draft model's deepseekv3_mtp_draft_input_hiddens
+    # mtp_main_output_hiddens: in mtp mode, the llm main model
+    # exports its last-layer hidden states here to feed the draft model's mtp_draft_input_hiddens
     # input
-    deepseekv3_mtp_main_output_hiddens: Optional[torch.Tensor] = None
+    mtp_main_output_hiddens: Optional[torch.Tensor] = None
 
     def to_no_ref_tensor(self):
         self.logits = tensor_to_no_ref_tensor(self.logits)
-        if self.deepseekv3_mtp_main_output_hiddens is not None:
-            self.deepseekv3_mtp_main_output_hiddens = tensor_to_no_ref_tensor(self.deepseekv3_mtp_main_output_hiddens)
+        if self.mtp_main_output_hiddens is not None:
+            self.mtp_main_output_hiddens = tensor_to_no_ref_tensor(self.mtp_main_output_hiddens)
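To make the rename concrete, here is a minimal, self-contained sketch of the intended data flow, using hypothetical stand-in classes rather than lightllm's actual ModelInput/ModelOutput: the main model exports its last-layer hidden states as mtp_main_output_hiddens, and the MTP draft model receives them as mtp_draft_input_hiddens.

```python
# Illustrative sketch of the MTP hidden-state handoff (not lightllm's real classes).
from dataclasses import dataclass
from typing import Callable, Optional
import torch


@dataclass
class ToyModelInput:
    input_ids: torch.Tensor
    # hidden states handed to the MTP draft model (None for the main model)
    mtp_draft_input_hiddens: Optional[torch.Tensor] = None


@dataclass
class ToyModelOutput:
    logits: torch.Tensor
    # last-layer hidden states exported by the main model for the draft model
    mtp_main_output_hiddens: Optional[torch.Tensor] = None


def mtp_step(
    main_model: Callable[[ToyModelInput], ToyModelOutput],
    draft_model: Callable[[ToyModelInput], ToyModelOutput],
    input_ids: torch.Tensor,
) -> ToyModelOutput:
    # 1) main-model forward also returns its last hidden states
    main_out = main_model(ToyModelInput(input_ids=input_ids))
    # 2) draft-model forward consumes those hidden states as an extra input
    draft_in = ToyModelInput(
        input_ids=input_ids,
        mtp_draft_input_hiddens=main_out.mtp_main_output_hiddens,
    )
    return draft_model(draft_in)
```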
51 changes: 20 additions & 31 deletions lightllm/common/basemodel/cuda_graph.py
@@ -62,9 +62,10 @@ def find_closest_graph_batch_size(self, batch_size):
         else:
             return None
 
-    def _capture_decode(self, decode_func, input_ids: torch.Tensor, infer_state: InferStateInfo):
+    def _capture_decode(self, decode_func, infer_state: InferStateInfo):
         dist_group: CustomProcessGroup = infer_state.dist_group
         graph_obj = torch.cuda.CUDAGraph()
+        input_ids = infer_state.input_ids
         batch_size = input_ids.shape[0]
         infer_state.max_len_in_batch = self.graph_max_len_in_batch
         infer_state.total_token_num = self.graph_max_len_in_batch * batch_size
@@ -78,27 +79,26 @@ def _capture_decode(self, decode_func, input_ids: torch.Tensor, infer_state: Inf
         # 中的 tensor。
         for _ in range(1):
             torch.cuda.synchronize()
-            decode_func(input_ids, copy.copy(infer_state))
+            decode_func(copy.copy(infer_state))
             torch.cuda.synchronize()
 
         with lightllm_capture_graph(dist_group):
             with torch.cuda.graph(graph_obj, pool=self.mempool):
-                model_output = decode_func(input_ids, infer_state)
-        self.graph[batch_size] = (graph_obj, input_ids, infer_state, model_output)
+                model_output = decode_func(infer_state)
+        self.graph[batch_size] = (graph_obj, infer_state, model_output)
         graph_obj.replay()
         return model_output
 
     def _capture_decode_overlap(
         self,
         decode_func,
-        input_ids: torch.Tensor,
         infer_state: InferStateInfo,
-        input_ids1: torch.Tensor,
         infer_state1: InferStateInfo,
     ):
         dist_group: CustomProcessGroup = infer_state.dist_group
         dist_group1 = infer_state1.dist_group
         graph_obj = torch.cuda.CUDAGraph()
+        input_ids = infer_state.input_ids
         batch_size = input_ids.shape[0]
         infer_state.max_len_in_batch = self.graph_max_len_in_batch
         infer_state.total_token_num = self.graph_max_len_in_batch * batch_size
@@ -107,17 +107,15 @@ def _capture_decode_overlap(
         # warmup
         for _ in range(1):
             torch.cuda.synchronize()
-            decode_func(input_ids, copy.copy(infer_state), input_ids1, copy.copy(infer_state1))
+            decode_func(copy.copy(infer_state), copy.copy(infer_state1))
             torch.cuda.synchronize()
         with lightllm_capture_graph(dist_group1):
             with lightllm_capture_graph(dist_group):
                 with torch.cuda.graph(graph_obj, pool=self.mempool):
-                    model_output, model_output1 = decode_func(input_ids, infer_state, input_ids1, infer_state1)
+                    model_output, model_output1 = decode_func(infer_state, infer_state1)
                     self.graph[batch_size] = (
                         graph_obj,
-                        input_ids,
                         infer_state,
-                        input_ids1,
                         infer_state1,
                         model_output,
                         model_output1,
@@ -128,59 +126,50 @@ def capture_decode(
     def capture_decode(
         self,
         decode_func,
-        input_ids: torch.Tensor,
         infer_state: InferStateInfo,
-        input_ids1: Optional[torch.Tensor] = None,
-        infer_state1: Optional[torch.Tensor] = None,
+        infer_state1: Optional[InferStateInfo] = None,
     ):
         """
         Capture the cuda graph for the decoding stage.
         input_ids1 and infer_state1 is used for the overlap.
         """
         if self.enable_decode_microbatch_overlap:
-            return self._capture_decode_overlap(decode_func, input_ids, infer_state, input_ids1, infer_state1)
+            return self._capture_decode_overlap(decode_func, infer_state, infer_state1)
         else:
-            assert input_ids1 is None and infer_state1 is None
-            return self._capture_decode(decode_func, input_ids, infer_state)
+            assert infer_state1 is None
+            return self._capture_decode(decode_func, infer_state)
 
-    def _replay(self, input_ids: torch.Tensor, infer_state: InferStateInfo):
-        batch_size = input_ids.shape[0]
-        graph_obj, graph_input_ids, graph_infer_state, graph_output = self.graph[batch_size]
-        graph_input_ids.copy_(input_ids)
+    def _replay(self, infer_state: InferStateInfo):
+        batch_size = infer_state.input_ids.shape[0]
+        graph_obj, graph_infer_state, graph_output = self.graph[batch_size]
         graph_infer_state.copy_for_cuda_graph(infer_state)
         graph_obj.replay()
         return graph_output
 
     def _replay_overlap(
         self,
-        input_ids: torch.Tensor,
         infer_state: InferStateInfo,
-        input_ids1: torch.Tensor,
         infer_state1: InferStateInfo,
     ):
-        batch_size = input_ids.shape[0]
+        batch_size = infer_state.input_ids.shape[0]
         (
             graph_obj,
-            graph_input_ids,
             graph_infer_state,
-            graph_input_ids1,
             graph_infer_state1,
             graph_model_output,
             graph_model_output1,
         ) = self.graph[batch_size]
-        graph_input_ids.copy_(input_ids)
         graph_infer_state.copy_for_cuda_graph(infer_state)
-        graph_input_ids1.copy_(input_ids1)
         graph_infer_state1.copy_for_cuda_graph(infer_state1)
         graph_obj.replay()
         return graph_model_output, graph_model_output1
 
-    def replay(self, input_ids, infer_state, input_ids1=None, infer_state1=None):
+    def replay(self, infer_state, infer_state1=None):
         if self.enable_decode_microbatch_overlap:
-            return self._replay_overlap(input_ids, infer_state, input_ids1, infer_state1)
+            return self._replay_overlap(infer_state, infer_state1)
         else:
-            assert input_ids1 is None and infer_state1 is None
-            return self._replay(input_ids, infer_state)
+            assert infer_state1 is None
+            return self._replay(infer_state)
 
     @torch.no_grad()
     def warmup(self, model):
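The refactor above changes the capture/replay contract: input_ids is no longer a separate argument but lives on the infer state, and replay refreshes the captured buffers by copying the new state in. A toy sketch of that pattern follows; the names are illustrative (not lightllm's API), and actually running the capture requires a CUDA device.

```python
# Toy sketch of capture/replay where the input tensor travels on a state object.
import torch


class ToyState:
    def __init__(self, input_ids: torch.Tensor):
        self.input_ids = input_ids

    def copy_for_cuda_graph(self, other: "ToyState"):
        # copy fresh data into the tensors that were captured in the graph
        self.input_ids.copy_(other.input_ids)


def capture(decode_func, state: ToyState):
    graph = torch.cuda.CUDAGraph()
    decode_func(state)              # warmup run outside the capture
    torch.cuda.synchronize()
    with torch.cuda.graph(graph):   # record one decode step into the graph
        out = decode_func(state)
    return graph, state, out


def replay(graph, captured_state: ToyState, out, new_state: ToyState):
    captured_state.copy_for_cuda_graph(new_state)  # refresh captured buffers
    graph.replay()                                  # results land in `out`
    return out
```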
15 changes: 10 additions & 5 deletions lightllm/common/basemodel/infer_struct.py
@@ -19,6 +19,7 @@ class InferStateInfo:
"""

def __init__(self):
self.input_ids: torch.Tensor = None
self.batch_size: int = None
self.total_token_num: int = None
self.b_req_idx: torch.Tensor = None
@@ -71,10 +72,10 @@ def __init__(self):
         # in each model's own inferstate subclass, but for code simplicity and convenience they are all kept
         # in the base class and managed there. Note that these members only take effect for specific models and modes.
 
-        # Extra input parameter used by the deepseekv3 mtp draft model;
-        # it is consumed as the mtp draft model's input when mtp_mode == deepseekv3
+        # Extra input parameter used by the mtp draft model;
+        # it is consumed as the mtp draft model's input when mtp_mode is enabled
         # and is not used by any other model or scenario
-        self.deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None
+        self.mtp_draft_input_hiddens: Optional[torch.Tensor] = None
 
         # In single-node multi-dp mode, if the data becomes unbalanced across dp ranks during prefill,
         # the inference data can be redistributed to the dp ranks: before the att step it is all-to-all'ed
@@ -88,7 +89,8 @@ def __init__(self):
         self.dp_output_split_sizes: List[List[int]] = None
         self.dp_input_split_sizes: List[List[int]] = None
 
-    def init_some_extra_state(self, model, input_ids: torch.Tensor):
+    def init_some_extra_state(self, model):
+
         if self.is_prefill:
             (
                 self.b_q_seq_len,
@@ -97,7 +99,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                 self.b1_cu_kv_seq_len,
                 self.position_ids,
             ) = gen_prefill_params(
-                input_token_num=input_ids.shape[0],
+                input_token_num=self.input_ids.shape[0],
                 b_ready_cache_len=self.b_ready_cache_len,
                 b_seq_len=self.b_seq_len,
             )
@@ -211,6 +213,9 @@ def prefill_dp_balance(self, input_ids: torch.Tensor):

         self.position_sin = self._all_to_all_balance_get(self.position_sin)
 
+        self._unbalance_input_ids = self.input_ids
+        self.input_ids = new_input_ids
+
         return new_input_ids
 
     def _all_to_all_balance_get(self, data: torch.Tensor):
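The two added lines in prefill_dp_balance keep a handle to the pre-balance input_ids while exposing the rebalanced ones through self.input_ids, so downstream code that now reads the ids from the state sees the balanced tensor. A tiny standalone sketch of that bookkeeping (toy class, not the real InferStateInfo):

```python
# Minimal sketch of the input_ids bookkeeping around dp rebalancing.
import torch


class ToyInferState:
    def __init__(self, input_ids: torch.Tensor):
        self.input_ids = input_ids
        self._unbalance_input_ids = None

    def prefill_dp_balance(self, new_input_ids: torch.Tensor) -> torch.Tensor:
        # new_input_ids is what this dp rank owns after the all-to-all exchange
        self._unbalance_input_ids = self.input_ids
        self.input_ids = new_input_ids
        return new_input_ids


state = ToyInferState(torch.tensor([1, 2, 3, 4, 5, 6]))
balanced = state.prefill_dp_balance(torch.tensor([1, 2, 3]))
assert state.input_ids is balanced and state._unbalance_input_ids.numel() == 6
```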
@@ -8,8 +8,6 @@ class PreLayerInferTpl(PreLayerInfer):
     def __init__(self, network_config, mode):
         super().__init__(network_config, mode)
         self.eps_ = 1e-5
-        self.vob_start_id_ = -1
-        self.vob_end_id_ = -1
         return
 
     def _norm(self, input, infer_state, layer_weight) -> torch.Tensor:
@@ -6,6 +6,8 @@
COLMMWeight,
ROWBMMWeight,
)
from .norm_weight import NormWeight, GEMMANormWeight, TpNormWeight
from .norm_weight import NoTpGEMMANormWeight, TpVitPadNormWeight, NoTpNormWeight, TpHeadNormWeight
from .fused_moe_weight_tp import create_tp_moe_wegiht_obj
from .fused_moe_weight_ep import FusedMoeWeightEP
from .embedding_weight import EmbeddingWeight, LMHeadWeight, NoTpPosEmbeddingWeight
from .att_sink_weight import TpAttSinkWeight
@@ -0,0 +1,23 @@
+import torch
+from typing import Dict
+from .base_weight import BaseWeightTpl
+from lightllm.utils.dist_utils import get_current_device_id
+
+
+class TpAttSinkWeight(BaseWeightTpl):
+    def __init__(self, weight_name: str, data_type):
+        super().__init__()
+        self.weight_name = weight_name
+        self.data_type_ = data_type
+        self.weight: torch.Tensor = None
+
+    def load_hf_weights(self, weights: Dict[str, torch.Tensor]):
+        if self.weight_name not in weights or self.weight is not None:
+            return
+
+        t_weight = weights[self.weight_name]
+        start_head_index, end_head_index = self._get_head_tp_split_params(weight=t_weight)
+        self.weight = t_weight[start_head_index:end_head_index].to(self.data_type_).cuda(get_current_device_id())
+
+    def verify_load(self):
+        return self.weight is not None
@@ -1,6 +1,6 @@
 import torch
 from abc import ABC, abstractmethod
-from typing import Dict
+from typing import Dict, Tuple
 from lightllm.utils.dist_utils import get_dp_world_size, get_current_rank_in_dp, get_current_device_id
 
 
@@ -29,3 +29,30 @@ def load_hf_weights(self, weights):

     def verify_load(self) -> bool:
         raise NotImplementedError("verify_load must implement this method")
+
+    def _get_head_tp_split_params(self, weight: torch.Tensor) -> Tuple[int, int]:
+        """
+        Docstring for _get_head_tp_split_params,
+        a common helper that splits heads across tp ranks and returns the head_index range; some subclasses may use it.
+        :param self: Description
+        :param weight: Description
+        :type weight: torch.Tensor
+        :return: Description
+        :rtype: Tuple[int, int]
+        """
+        assert weight.ndim == 2
+
+        all_head_num = weight.shape[0]
+        tp_head_num = all_head_num // self.tp_world_size_
+
+        if tp_head_num > 0:
+            start_head_index = self.tp_rank_ * tp_head_num
+            end_head_index = (self.tp_rank_ + 1) * tp_head_num
+        else:
+            # special handling for when tp_world_size is greater than all_head_num
+            scale_size = self.tp_world_size_ // all_head_num
+            assert self.tp_world_size_ % all_head_num == 0
+            start_head_index = self.tp_rank_ // scale_size
+            end_head_index = start_head_index + 1
+
+        return start_head_index, end_head_index
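The same split arithmetic as _get_head_tp_split_params, rewritten as a standalone function with a couple of worked checks covering both branches (pure illustration, outside the weight class):

```python
# Standalone illustration of the head-TP split logic added above.
from typing import Tuple


def head_tp_split(all_head_num: int, tp_rank: int, tp_world_size: int) -> Tuple[int, int]:
    tp_head_num = all_head_num // tp_world_size
    if tp_head_num > 0:
        # each rank owns a contiguous block of heads
        return tp_rank * tp_head_num, (tp_rank + 1) * tp_head_num
    # more ranks than heads: several ranks share (replicate) one head
    assert tp_world_size % all_head_num == 0
    scale_size = tp_world_size // all_head_num
    start = tp_rank // scale_size
    return start, start + 1


# 8 heads over 4 ranks -> 2 heads per rank; rank 1 owns heads [2, 4)
assert head_tp_split(8, 1, 4) == (2, 4)
# 2 heads over 8 ranks -> ranks 0-3 get head 0, ranks 4-7 get head 1
assert head_tp_split(2, 5, 8) == (1, 2)
```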