-
Notifications
You must be signed in to change notification settings - Fork 259
DeepSeek MTP #913
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
shihaobai
wants to merge
27
commits into
main
Choose a base branch
from
bsh-MTP
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+1,945
−487
Open
DeepSeek MTP #913
Changes from all commits
Commits
Show all changes
27 commits
Select commit
Hold shift + click to select a range
788fe91
mtp
sufubao 3ae72be
fix
sufubao 84eced6
pd master health, tokens and server busy error
e23692f
improve the copy kv kernel
shihaobai 0b17ef3
fix mtp
aa63420
cudagraph fix and decode batch
ade501e
fix
d3a35d4
fix for graph
shihaobai baa5c53
reformat
shihaobai 99eee16
decode for mtp
shihaobai 182fc9c
share mem_index between draft and main
shihaobai 04a0652
mutli step mtp and dynamic_prompt cache for mtp
shihaobai 555ec8b
add static test for mtp
shihaobai 9d1bf5f
fix input/output of the other mode
shihaobai e9faeab
fix static dp+ep
shihaobai d78b950
fix share_head norm for mtp module
shihaobai 5717ce7
fix mtp norm and fix chunked
shihaobai bf96f4b
update test
shihaobai d67a53d
Deepseek MTP for dp backend (#923)
sufubao a0d6d33
del log.txt
hiworldwzj 933f0de
del test file.
hiworldwzj 18aee9f
add mpt_gen_token_ids.
hiworldwzj 429a2f5
cache tensor manager improve.
hiworldwzj 0f4f306
add padding cuda graph feature
hiworldwzj 17051a2
refactor continous mtp
699c52d
Merge branch 'bsh-MTP' of https://github.com/ModelTC/lightllm into bs…
1862551
Merge branch 'main' into bsh-MTP
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import torch | ||
from dataclasses import dataclass, field | ||
|
||
|
||
@dataclass | ||
class ModelInput: | ||
batch_size: int | ||
total_token_num: int | ||
max_len_in_batch: int | ||
input_ids: torch.Tensor | ||
mem_indexes: torch.Tensor | ||
b_req_idx: torch.Tensor | ||
b_seq_len: torch.Tensor | ||
is_prefill: bool = False | ||
b_ready_cache_len: torch.Tensor = None | ||
multimodal_params: list = field(default_factory=list) | ||
hidden_states: torch.Tensor = None | ||
|
||
|
||
@dataclass | ||
class ModelOutput: | ||
logits: torch.tensor | ||
hidden_states: torch.tensor | ||
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from enum import IntEnum, auto | ||
|
||
|
||
class SpeculativeDecodeAlgorithm(IntEnum): | ||
NONE = auto() | ||
MTP = auto() | ||
MTP_MOUDLE = auto() | ||
|
||
def is_none(self): | ||
return self == SpeculativeDecodeAlgorithm.NONE | ||
|
||
def is_mtp(self): | ||
return self == SpeculativeDecodeAlgorithm.MTP | ||
|
||
def is_mtp_module(self): | ||
return self == SpeculativeDecodeAlgorithm.MTP_MOUDLE | ||
|
||
@staticmethod | ||
def from_string(name: str): | ||
name_map = { | ||
"MTP": SpeculativeDecodeAlgorithm.MTP, | ||
"MTP_MOUDLE": SpeculativeDecodeAlgorithm.MTP_MOUDLE, | ||
"NONE": SpeculativeDecodeAlgorithm.NONE, | ||
} | ||
if name is not None: | ||
name = name.upper() | ||
return name_map[name] | ||
|
||
def decode_len(self): | ||
if self == SpeculativeDecodeAlgorithm.NONE: | ||
return 1 | ||
if self == SpeculativeDecodeAlgorithm.MTP: | ||
return 2 | ||
if self == SpeculativeDecodeAlgorithm.MTP_MOUDLE: | ||
return 2 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import torch | ||
from lightllm.common.deepseek2_mem_manager import Deepseek2MemoryManager | ||
from lightllm.utils.log_utils import init_logger | ||
from lightllm.utils.dist_utils import get_current_rank_in_node | ||
from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt | ||
|
||
logger = init_logger(__name__) | ||
|
||
|
||
class Deepseek3MTPMemoryManager(Deepseek2MemoryManager): | ||
def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): | ||
self.size = size | ||
self.head_num = head_num | ||
self.head_dim = head_dim | ||
self.layer_num = layer_num | ||
self.always_copy = always_copy | ||
self.dtype = dtype | ||
# profile the max total token num if the size is None | ||
self.profile_size(mem_fraction) | ||
|
||
self.mem_state = torch.arange( | ||
0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True | ||
) | ||
self.mark_start = 0 | ||
self.mark_end = self.size | ||
|
||
self.can_use_mem_size = self.size | ||
|
||
self._init_buffers( | ||
self.size, | ||
dtype, | ||
head_num, | ||
head_dim, | ||
layer_num, | ||
) | ||
self.HOLD_TOKEN_MEMINDEX = self.size |
Empty file.
70 changes: 70 additions & 0 deletions
70
lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import os | ||
import torch | ||
import torch.distributed as dist | ||
import numpy as np | ||
|
||
from lightllm.models.deepseek_mtp.layer_weights.pre_and_post_layer_weight import Deepseek3MTPPreAndPostLayerWeight | ||
from lightllm.models.llama.infer_struct import LlamaInferStateInfo | ||
from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer | ||
from lightllm.utils.infer_utils import mark_cost_time | ||
from lightllm.models.llama.triton_kernel.embedding import embedding | ||
from lightllm.distributed.communication_op import all_reduce | ||
from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward | ||
from lightllm.distributed.communication_op import all_gather | ||
|
||
|
||
class Deepseek3MTPPreLayerInfer(LlamaPreLayerInfer): | ||
""" """ | ||
|
||
def __init__(self, network_config, mode): | ||
super().__init__(network_config, mode) | ||
self.eps_ = network_config["rms_norm_eps"] | ||
self.hidden_size = network_config["hidden_size"] | ||
return | ||
|
||
def mtp_context_forward( | ||
self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight | ||
): | ||
assert infer_state.spec_info is not None, "need spec info for mtp model." | ||
tgt_embdings = infer_state.spec_info | ||
assert input_embdings.shape[0] == tgt_embdings.shape[0] | ||
rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, out=input_embdings) | ||
rmsnorm_forward(tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, out=tgt_embdings) | ||
|
||
cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) | ||
infer_state.spec_info = None | ||
|
||
ans_logics = self.alloc_tensor( | ||
(cat_embdings.shape[0], layer_weight.eh_proj_weight_.shape[1]), dtype=input_embdings.dtype | ||
) | ||
torch.mm(cat_embdings, layer_weight.eh_proj_weight_, out=ans_logics) | ||
return ans_logics | ||
|
||
def mtp_token_forward( | ||
self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight | ||
): | ||
assert infer_state.spec_info is not None, "need spec info for mtp model." | ||
tgt_embdings = infer_state.spec_info | ||
assert input_embdings.shape[0] == tgt_embdings.shape[0] | ||
rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, out=input_embdings) | ||
rmsnorm_forward(tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, out=tgt_embdings) | ||
|
||
cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) | ||
|
||
ans_logics = self.alloc_tensor( | ||
(cat_embdings.shape[0], layer_weight.eh_proj_weight_.shape[1]), dtype=input_embdings.dtype | ||
) | ||
torch.mm(cat_embdings, layer_weight.eh_proj_weight_, out=ans_logics) | ||
return ans_logics | ||
|
||
def context_forward( | ||
self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight | ||
): | ||
input_embdings = super().context_forward(input_ids, infer_state, layer_weight) | ||
return self.mtp_context_forward(input_embdings, infer_state, layer_weight) | ||
|
||
def token_forward( | ||
self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight | ||
): | ||
input_embdings = super().token_forward(input_ids, infer_state, layer_weight) | ||
return self.mtp_token_forward(input_embdings, infer_state, layer_weight) |
Empty file.
29 changes: 29 additions & 0 deletions
29
lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import numpy as np | ||
from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight | ||
|
||
|
||
class Deepseek3MTPPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): | ||
def __init__(self, data_type, network_config, mode): | ||
super().__init__(data_type, network_config, mode) | ||
# 与DeepseekV3模型共享 | ||
self.wte_weight_ = None | ||
self.lm_head_weight_ = None | ||
return | ||
|
||
def load_hf_weights(self, weights): | ||
if "model.layers.0.eh_proj.weight" in weights: | ||
self.eh_proj_weight_ = self._cuda(weights["model.layers.0.eh_proj.weight"]).t() | ||
if "model.layers.0.enorm.weight" in weights: | ||
self.enorm_weight_ = self._cuda(weights["model.layers.0.enorm.weight"]) | ||
if "model.layers.0.hnorm.weight" in weights: | ||
self.hnorm_weight_ = self._cuda(weights["model.layers.0.hnorm.weight"]) | ||
if "model.layers.0.shared_head.norm.weight" in weights: | ||
self.final_norm_weight_ = self._cuda(weights["model.layers.0.shared_head.norm.weight"]) | ||
return | ||
|
||
def verify_load(self): | ||
errors = "weights load not ok" | ||
weights = [self.eh_proj_weight_, self.enorm_weight_, self.hnorm_weight_, self.final_norm_weight_] | ||
for i in range(len(weights)): | ||
assert weights[i] is not None, "index:" + str(i) + " " + errors | ||
return |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider using
torch.Tensor
for type hints instead oftorch.tensor
for consistency with PyTorch's official type hinting. Whiletorch.tensor
is a function to create tensors,torch.Tensor
is the type. This is a minor point but improves consistency.