DeepSeek MTP #913

Open · wants to merge 61 commits into base: main

Commits (61)
788fe91
mtp
sufubao May 8, 2025
3ae72be
fix
sufubao May 8, 2025
84eced6
pd master health, tokens and server busy error
May 14, 2025
e23692f
improve the copy kv kernel
shihaobai May 15, 2025
0b17ef3
fix mtp
May 16, 2025
aa63420
cudagraph fix and decode batch
May 16, 2025
ade501e
fix
May 19, 2025
d3a35d4
fix for graph
shihaobai May 21, 2025
baa5c53
reformat
shihaobai May 21, 2025
99eee16
decode for mtp
shihaobai May 22, 2025
182fc9c
share mem_index between draft and main
shihaobai May 22, 2025
04a0652
mutli step mtp and dynamic_prompt cache for mtp
shihaobai May 23, 2025
555ec8b
add static test for mtp
shihaobai May 26, 2025
9d1bf5f
fix input/output of the other mode
shihaobai May 27, 2025
e9faeab
fix static dp+ep
shihaobai May 28, 2025
d78b950
fix share_head norm for mtp module
shihaobai May 28, 2025
5717ce7
fix mtp norm and fix chunked
shihaobai May 28, 2025
bf96f4b
update test
shihaobai May 28, 2025
d67a53d
Deepseek MTP for dp backend (#923)
sufubao Jun 9, 2025
a0d6d33
del log.txt
hiworldwzj Jun 9, 2025
933f0de
del test file.
hiworldwzj Jun 9, 2025
18aee9f
add mpt_gen_token_ids.
hiworldwzj Jun 9, 2025
429a2f5
cache tensor manager improve.
hiworldwzj Jun 9, 2025
0f4f306
add padding cuda graph feature
hiworldwzj Jun 9, 2025
17051a2
refactor continous mtp
Jun 9, 2025
699c52d
Merge branch 'bsh-MTP' of https://github.com/ModelTC/lightllm into bs…
Jun 9, 2025
1862551
Merge branch 'main' into bsh-MTP
Jun 9, 2025
ada5f4e
fix graph
Jun 9, 2025
67875db
fix paddded cuda graph
hiworldwzj Jun 10, 2025
90598e7
remove mtp mem manager
Jun 10, 2025
8aeb88f
Merge branch 'bsh-MTP' of https://github.com/ModelTC/lightllm into bs…
Jun 10, 2025
2a8bf98
fix
hiworldwzj Jun 10, 2025
47560da
backup generic_pre_process
Jun 10, 2025
1da5671
remove prepare_draft_main_model_decode_inputs
Jun 10, 2025
e23efd6
simplify continous mtp
Jun 10, 2025
c15a899
backup dp
shihaobai Jun 10, 2025
6fd9bde
update
Jun 10, 2025
276d336
fix
Jun 10, 2025
a3a89a0
merge main
Jun 11, 2025
09be23f
fix prompt cache of mtp
Jun 11, 2025
0f8671a
fix all return of model.forward
Jun 11, 2025
9fdf20d
refactor dp mtp
Jun 11, 2025
6ccae80
fix overlap
Jun 11, 2025
c1c01e4
fix
hiworldwzj Jun 12, 2025
f9b5af8
fix mpt start params name
hiworldwzj Jun 12, 2025
98a355d
fix model_rpc.py
hiworldwzj Jun 12, 2025
1179575
fix init model
hiworldwzj Jun 12, 2025
fbc022a
fix
hiworldwzj Jun 12, 2025
65d0dd7
rename mtp acceptd token num.
hiworldwzj Jun 12, 2025
451c19c
rename mtp accepted token num.
hiworldwzj Jun 12, 2025
508d36f
fix
hiworldwzj Jun 12, 2025
6728200
fix prepare.
hiworldwzj Jun 12, 2025
cd153e9
fix
hiworldwzj Jun 12, 2025
c2c897a
fix preprecess
hiworldwzj Jun 12, 2025
9f0747e
remove overlap objs.
hiworldwzj Jun 12, 2025
0aa191f
add mtp_mode.
hiworldwzj Jun 12, 2025
9ab950f
remove more code.
hiworldwzj Jun 12, 2025
ceb1110
fix
hiworldwzj Jun 12, 2025
8b32aaa
improve post_handle to handle mtp mode.
hiworldwzj Jun 13, 2025
10abd3e
fix mtp input output define
hiworldwzj Jun 13, 2025
c67dad7
fix cuda graph
hiworldwzj Jun 13, 2025
Files changed
577 changes: 319 additions & 258 deletions lightllm/common/basemodel/basemodel.py

Large diffs are not rendered by default.

39 changes: 39 additions & 0 deletions lightllm/common/basemodel/batch_objs.py
@@ -0,0 +1,39 @@
import torch
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ModelInput:
# common fields
batch_size: int
total_token_num: int
max_len_in_batch: int
input_ids: torch.Tensor
mem_indexes: torch.Tensor
b_req_idx: torch.Tensor
b_seq_len: torch.Tensor
is_prefill: bool = False
b_ready_cache_len: torch.Tensor = None
multimodal_params: list = field(default_factory=list)

# Model/mode-specific fields: used by a few special models in special modes
# to pass extra input values. They are only used and take effect in those
# specific model modes.

# deepseekv3_mtp_draft_input_hiddens is the input of the draft model when the
# deepseekv3 model runs in mtp mode.
deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None


@dataclass
class ModelOutput:
# common fields
logits: torch.Tensor

# Model/mode-specific fields: used by a few special models in special modes
# to pass extra output values. They are only used and take effect in those
# specific model modes.

# deepseekv3_mtp_main_output_hiddens: in mtp mode, the main LLM exposes its
# last-layer hidden states here; they are then fed to the draft model as its
# deepseekv3_mtp_draft_input_hiddens input.
deepseekv3_mtp_main_output_hiddens: Optional[torch.Tensor] = None
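A rough sketch of how these two dataclasses hand hidden states from the main model to the MTP draft model during a decode step. The exact forward() signature and the memory-index bookkeeping are simplified assumptions for illustration, not APIs defined by this PR:

```python
import torch
from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput

def mtp_handoff(main_model, draft_model, inp: ModelInput) -> torch.Tensor:
    # Main model step: in mtp mode it also fills deepseekv3_mtp_main_output_hiddens.
    main_out: ModelOutput = main_model.forward(inp)
    next_ids = main_out.logits.argmax(dim=-1)

    # Draft step: the main model's last-layer hiddens become the draft model's extra input.
    inp.input_ids = next_ids
    inp.deepseekv3_mtp_draft_input_hiddens = main_out.deepseekv3_mtp_main_output_hiddens
    draft_out: ModelOutput = draft_model.forward(inp)
    return draft_out.logits.argmax(dim=-1)  # speculative tokens, verified in the next step
```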
273 changes: 183 additions & 90 deletions lightllm/common/basemodel/cuda_graph.py

Large diffs are not rendered by default.

13 changes: 11 additions & 2 deletions lightllm/common/basemodel/infer_struct.py
@@ -2,7 +2,7 @@
from lightllm.common.mem_manager import MemoryManager
from lightllm.common.req_manager import ReqManager
from lightllm.distributed import CustomProcessGroup
from typing import Tuple, Any
from typing import Tuple, Any, Optional
from .triton_kernel.gen_prefill_params import gen_prefill_params
from .triton_kernel.gen_decode_params import gen_decode_params

@@ -54,6 +54,15 @@ def __init__(self):
self.max_q_seq_len: int = None
self.max_kv_seq_len: int = None

# Input members used by a few special models in special modes. Strictly
# speaking these do not belong in the InferStateInfo base class, but they are
# managed here for simplicity and convenience. Note that they only take
# effect for specific models and modes.

# Extra input used by the deepseekv3 mtp draft model: it is only consumed as
# the draft model's input when mtp_mode == deepseekv3, and is unused for all
# other models and scenarios.
self.deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None

def init_some_extra_state(self, model, input_ids: torch.Tensor):
if self.is_prefill:
(
@@ -82,7 +91,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
) = gen_decode_params(b_seq_len=self.b_seq_len)
self.b_start_loc = self.b1_cu_kv_seq_len[0:-1]

def copy_for_cuda_graph(self, new_infer_state):
def copy_for_cuda_graph(self, new_infer_state: "InferStateInfo"):
for attr_name, attr_value in vars(new_infer_state).items():
if isinstance(attr_value, torch.Tensor):
attr_ = getattr(self, attr_name, None)
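For cuda-graph replay, copy_for_cuda_graph copies every tensor attribute of a freshly built InferStateInfo into the buffers captured by the graph, so new members such as deepseekv3_mtp_draft_input_hiddens are picked up without extra plumbing. A toy stand-in showing the intended copy semantics (simplified; the real method lives on InferStateInfo):

```python
import torch

class InferStateSketch:
    # Toy stand-in for InferStateInfo, only to illustrate copy_for_cuda_graph.
    def copy_for_cuda_graph(self, new_infer_state: "InferStateSketch"):
        for attr_name, attr_value in vars(new_infer_state).items():
            if isinstance(attr_value, torch.Tensor):
                captured = getattr(self, attr_name, None)
                if captured is not None:
                    # Write into the captured buffer so graph replay sees the new data.
                    captured.copy_(attr_value, non_blocking=True)
```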
23 changes: 18 additions & 5 deletions lightllm/common/basemodel/layer_infer/cache_tensor_manager.py
@@ -44,7 +44,8 @@ def __del__(self):
class CudaGraphCacheTensorManager:
def __init__(self, cuda_graph_max_batch_size: int):
self.cuda_graph_max_batch_size = cuda_graph_max_batch_size
self.graph_out_tensor_dict: Dict[int, torch.Tensor] = {}
# Dict[graph_out_key, Dict[microbatch_index, tensor_cache]]
self.graph_out_tensor_dict: Dict[int, Dict[int, torch.Tensor]] = collections.defaultdict(dict)
self.managed_total_tensor_bytes = 0
return

@@ -56,6 +57,7 @@ def alloc_tensor_for_cuda_graph(
device: str = "cuda",
is_graph_out: bool = False,
microbatch_index: int = 0,
graph_out_key: int = 0,
) -> torch.Tensor:
assert microbatch_index in [0, 1]
if not is_graph_out:
@@ -66,13 +68,16 @@
max_size = size // cur_batch_size * self.cuda_graph_max_batch_size

# graph out tensor: there is only one, so no reference-count management is needed
if microbatch_index not in self.graph_out_tensor_dict:
microbatch_index_to_tensor_cache = self.graph_out_tensor_dict[graph_out_key]

if microbatch_index not in microbatch_index_to_tensor_cache:
graph_out_tensor = torch.empty((max_size,), dtype=data_type, device=device, requires_grad=False)
logger.info(f"pid {os.getpid()} cuda graph alloc graph out mem {shape} {data_type} {size} {max_size}")
self.managed_total_tensor_bytes += graph_out_tensor.element_size() * graph_out_tensor.numel()
logger.info(f"cuda graph managed_total_tensor_bytes: {self.managed_total_tensor_bytes}")
self.graph_out_tensor_dict[microbatch_index] = graph_out_tensor
return self.graph_out_tensor_dict[microbatch_index][0:size].view(shape)
microbatch_index_to_tensor_cache[microbatch_index] = graph_out_tensor

return self.graph_out_tensor_dict[graph_out_key][microbatch_index][0:size].view(shape)

class CacheTensorManager:
def __init__(self):
@@ -119,14 +124,21 @@ def alloc_tensor(
device: str = "cuda",
is_graph_out: bool = False,
microbatch_index: int = 0,
graph_out_key: int = 0,
) -> torch.Tensor:
# normalize the shape type
if isinstance(shape, list):
shape = torch.Size(shape)
# under cuda graph, allocation is taken over by the cuda graph manager
if self.is_cuda_graph:
return self.inner_cuda_graph_manager.alloc_tensor_for_cuda_graph(
self.cuda_graph_cur_batch_size, shape, data_type, device, is_graph_out, microbatch_index
self.cuda_graph_cur_batch_size,
shape,
data_type,
device,
is_graph_out,
microbatch_index,
graph_out_key,
)

# reclaim tensors that may no longer be alive
@@ -191,6 +203,7 @@ def alloc_tensor(
device: str = "cuda",
is_graph_out: bool = False,
microbatch_index: int = 0,
graph_out_key: int = 0,
) -> torch.Tensor:
return torch.empty(shape, dtype=data_type, device=device, requires_grad=False)

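The cache_tensor_manager change above turns the graph-out tensor cache into a two-level mapping keyed first by graph_out_key and then by microbatch_index. A standalone sketch of the lookup-or-allocate logic (illustrative only; parameter names mirror the diff):

```python
import collections
import torch

graph_out_tensor_dict = collections.defaultdict(dict)  # graph_out_key -> {microbatch_index: tensor}

def get_graph_out_tensor(graph_out_key, microbatch_index, max_size, size, shape, dtype=torch.float16):
    cache = graph_out_tensor_dict[graph_out_key]
    if microbatch_index not in cache:
        # Allocate once at the maximum captured size; smaller batches view into it.
        cache[microbatch_index] = torch.empty((max_size,), dtype=dtype, device="cuda")
    return cache[microbatch_index][0:size].view(shape)
```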
26 changes: 0 additions & 26 deletions lightllm/common/basemodel/microbatch_overlap_objs.py

This file was deleted.

2 changes: 1 addition & 1 deletion lightllm/common/req_manager.py
@@ -175,7 +175,7 @@ def update_reqs_token_counter(

if not self.enable_gpu_buffer_for_out_token_id_counter:
for req_obj, next_token_id in zip(req_objs, next_token_ids):
if req_obj.need_out_token_id_statistics:
if req_obj.need_out_token_id_statistics and req_obj.cur_output_len > 0:
req_obj.out_token_id_count[next_token_id] += 1
else:
b_req_idx = torch.tensor(
1 change: 1 addition & 0 deletions lightllm/models/deepseek2/infer_struct.py
@@ -3,6 +3,7 @@
import numpy as np
import torch.distributed as dist
from lightllm.models.llama.infer_struct import LlamaInferStateInfo
from lightllm.common.spec_info import SpeculativeDecodeAlgorithm


class Deepseek2InferStateInfo(LlamaInferStateInfo):
3 changes: 2 additions & 1 deletion lightllm/models/deepseek2/model.py
@@ -65,6 +65,7 @@ def __init__(self, kvargs):
self.enable_flashinfer = (
get_env_start_args().enable_flashinfer_prefill or get_env_start_args().enable_flashinfer_decode
)
self.mtp_layer_num = get_env_start_args().spec_step
super().__init__(kvargs)
return

@@ -102,7 +103,7 @@ def _init_mem_manager(self):
dtype=self.data_type,
head_num=1,
head_dim=self.config["kv_lora_rank"] + self.config["qk_rope_head_dim"],
layer_num=self.config["num_hidden_layers"],
layer_num=self.config["num_hidden_layers"] + self.mtp_layer_num,
mem_fraction=self.mem_fraction,
)
return
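The extra spec_step layers reserve KV-cache slots for the MTP draft module, which shares the main model's MemoryManager; the draft layers write their KV after the main model's layers. A hypothetical illustration of the resulting layout (layer counts assume DeepSeek-V3 with num_hidden_layers = 61 and spec_step = 1):

```python
num_hidden_layers = 61          # main model decoder layers (DeepSeek-V3)
mtp_layer_num = 1               # get_env_start_args().spec_step

total_kv_layers = num_hidden_layers + mtp_layer_num       # 62 layers in the shared mem manager
main_layers = range(0, num_hidden_layers)                  # slots 0..60: main model KV
draft_layers = range(num_hidden_layers, total_kv_layers)   # slot 61: MTP draft module KV
```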
Empty file.
Empty file.
60 changes: 60 additions & 0 deletions lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py
@@ -0,0 +1,60 @@
import torch

from lightllm.models.deepseek_mtp.layer_weights.pre_and_post_layer_weight import Deepseek3MTPPreAndPostLayerWeight
from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer
from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward


class Deepseek3MTPPreLayerInfer(LlamaPreLayerInfer):
""" """

def __init__(self, network_config, mode):
super().__init__(network_config, mode)
self.eps_ = network_config["rms_norm_eps"]
self.hidden_size = network_config["hidden_size"]
return

def _mtp_context_forward(
self, input_embdings, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight
):
tgt_embdings = infer_state.deepseekv3_mtp_draft_input_hiddens
assert input_embdings.shape[0] == tgt_embdings.shape[0]
rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, out=input_embdings)
rmsnorm_forward(tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, out=tgt_embdings)

cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1)

ans_logics = self.alloc_tensor(
(cat_embdings.shape[0], layer_weight.eh_proj_weight_.shape[1]), dtype=input_embdings.dtype
)
torch.mm(cat_embdings, layer_weight.eh_proj_weight_, out=ans_logics)
return ans_logics

def _mtp_token_forward(
self, input_embdings, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight
):
tgt_embdings = infer_state.deepseekv3_mtp_draft_input_hiddens
assert input_embdings.shape[0] == tgt_embdings.shape[0]
rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, out=input_embdings)
rmsnorm_forward(tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, out=tgt_embdings)

cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1)

ans_logics = self.alloc_tensor(
(cat_embdings.shape[0], layer_weight.eh_proj_weight_.shape[1]), dtype=input_embdings.dtype
)
torch.mm(cat_embdings, layer_weight.eh_proj_weight_, out=ans_logics)
return ans_logics

def context_forward(
self, input_ids, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight
):
input_embdings = super().context_forward(input_ids, infer_state, layer_weight)
return self._mtp_context_forward(input_embdings, infer_state, layer_weight)

def token_forward(
self, input_ids, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight
):
input_embdings = super().token_forward(input_ids, infer_state, layer_weight)
return self._mtp_token_forward(input_embdings, infer_state, layer_weight)
Empty file.
29 changes: 29 additions & 0 deletions lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py
@@ -0,0 +1,29 @@
import numpy as np
from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight


class Deepseek3MTPPreAndPostLayerWeight(LlamaPreAndPostLayerWeight):
def __init__(self, data_type, network_config, mode):
super().__init__(data_type, network_config, mode)
# shared with the DeepseekV3 main model
self.wte_weight_ = None
self.lm_head_weight_ = None
return

def load_hf_weights(self, weights):
if "model.layers.0.eh_proj.weight" in weights:
self.eh_proj_weight_ = self._cuda(weights["model.layers.0.eh_proj.weight"]).t()
if "model.layers.0.enorm.weight" in weights:
self.enorm_weight_ = self._cuda(weights["model.layers.0.enorm.weight"])
if "model.layers.0.hnorm.weight" in weights:
self.hnorm_weight_ = self._cuda(weights["model.layers.0.hnorm.weight"])
if "model.layers.0.shared_head.norm.weight" in weights:
self.final_norm_weight_ = self._cuda(weights["model.layers.0.shared_head.norm.weight"])
return

def verify_load(self):
errors = "weights load not ok"
weights = [self.eh_proj_weight_, self.enorm_weight_, self.hnorm_weight_, self.final_norm_weight_]
for i in range(len(weights)):
assert weights[i] is not None, "index:" + str(i) + " " + errors
return
46 changes: 46 additions & 0 deletions lightllm/models/deepseek_mtp/model.py
@@ -0,0 +1,46 @@
from lightllm.models.deepseek2.model import Deepseek2TpPartModel
from lightllm.models.deepseek_mtp.layer_infer.pre_layer_infer import Deepseek3MTPPreLayerInfer
from lightllm.models.deepseek_mtp.layer_weights.pre_and_post_layer_weight import Deepseek3MTPPreAndPostLayerWeight
from lightllm.common.basemodel import TpPartBaseModel


class Deepseek3MTPModel(Deepseek2TpPartModel):

pre_and_post_weight_class = Deepseek3MTPPreAndPostLayerWeight
pre_layer_infer_class = Deepseek3MTPPreLayerInfer

def __init__(self, kvargs: dict):
self._pre_init(kvargs)
super().__init__(kvargs)
return

def _pre_init(self, kvargs: dict):
self.main_model: TpPartBaseModel = kvargs.pop("main_model")
self.mem_layer_start = kvargs.pop("mem_layer_start", 0)
return

def _init_custom(self):
self._cos_cached = self.main_model._cos_cached
self._sin_cached = self.main_model._sin_cached
return

def _init_req_manager(self):
self.req_manager = self.main_model.req_manager
return

def _init_mem_manager(self):
self.mem_manager = self.main_model.mem_manager
return

def _init_weights(self):
super()._init_weights()
self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_
self.pre_post_weight.lm_head_weight_ = self.main_model.pre_post_weight.lm_head_weight_
return

def _init_infer_layer(self):
super()._init_infer_layer()
# reset the layer_num_ of the self.layers_infer
for layer in self.layers_infer:
layer.layer_num_ = layer.layer_num_ + self.mem_layer_start
return
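Since the draft model reuses the main model's ReqManager and MemoryManager, its decoder layers are re-indexed by mem_layer_start so that their KV writes land in the extra layer slots allocated above. A hedged construction sketch (the kvargs contents and the use of config["num_hidden_layers"] as the offset are assumptions for illustration, not code from this PR):

```python
from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel

def build_mtp_draft(main_model, draft_kvargs: dict) -> Deepseek3MTPModel:
    draft_kvargs = dict(draft_kvargs)                      # copy so pop() does not mutate the caller's dict
    draft_kvargs["main_model"] = main_model                # share req/mem managers and embeddings
    draft_kvargs["mem_layer_start"] = main_model.config["num_hidden_layers"]  # KV slots after main layers
    return Deepseek3MTPModel(draft_kvargs)
```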
2 changes: 1 addition & 1 deletion lightllm/models/llama/layer_infer/post_layer_infer.py
@@ -65,8 +65,8 @@ def _slice_get_last_input(self, input_embdings, infer_state: LlamaInferStateInfo
assert False, "Error State"

def token_forward(self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight):
last_input, token_num = self._slice_get_last_input(input_embdings, infer_state)
input_embdings_dtype = input_embdings.dtype
last_input, token_num = self._slice_get_last_input(input_embdings, infer_state)
input_embdings = None
last_input = self._norm(last_input, infer_state, layer_weight)
last_input = last_input.permute(1, 0).view(-1, token_num)