DeepSeek MTP #913

Open · wants to merge 27 commits into base: main
394 changes: 169 additions & 225 deletions lightllm/common/basemodel/basemodel.py

Large diffs are not rendered by default.

23 changes: 23 additions & 0 deletions lightllm/common/basemodel/batch_objs.py
@@ -0,0 +1,23 @@
import torch
from dataclasses import dataclass, field


@dataclass
class ModelInput:
batch_size: int
total_token_num: int
max_len_in_batch: int
input_ids: torch.Tensor
mem_indexes: torch.Tensor
b_req_idx: torch.Tensor
b_seq_len: torch.Tensor
is_prefill: bool = False
b_ready_cache_len: torch.Tensor = None
multimodal_params: list = field(default_factory=list)
hidden_states: torch.Tensor = None


@dataclass
class ModelOutput:
logits: torch.tensor
hidden_states: torch.tensor
Comment on lines +22 to +23

Severity: medium

Consider using torch.Tensor for type hints instead of torch.tensor for consistency with PyTorch's official type hinting. While torch.tensor is a function to create tensors, torch.Tensor is the type. This is a minor point but improves consistency.

Suggested change
logits: torch.tensor
hidden_states: torch.tensor
logits: torch.Tensor
hidden_states: torch.Tensor
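
As a quick illustration of the distinction the comment makes (a minimal sketch, not part of this PR's diff): torch.tensor(...) is a factory function that builds a tensor from data, while torch.Tensor is the class, which is what type hints and isinstance checks should refer to.

import torch
from dataclasses import dataclass


@dataclass
class Example:
    logits: torch.Tensor  # annotate with the class torch.Tensor


t = torch.tensor([1.0, 2.0])        # build a tensor with the factory function
assert isinstance(t, torch.Tensor)  # the resulting object is a torch.Tensor instance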

277 changes: 184 additions & 93 deletions lightllm/common/basemodel/cuda_graph.py

Large diffs are not rendered by default.

9 changes: 7 additions & 2 deletions lightllm/common/basemodel/infer_struct.py
@@ -5,6 +5,7 @@
from typing import Tuple, Any
from .triton_kernel.gen_prefill_params import gen_prefill_params
from .triton_kernel.gen_decode_params import gen_decode_params
from lightllm.common.spec_info import SpeculativeDecodeAlgorithm


class InferStateInfo:
@@ -54,6 +55,10 @@ def __init__(self):
self.max_q_seq_len: int = None
self.max_kv_seq_len: int = None

# Speculative decoding
self.spec_algo = SpeculativeDecodeAlgorithm.NONE
self.spec_info = None

def init_some_extra_state(self, model, input_ids: torch.Tensor):
if self.is_prefill:
(
@@ -82,10 +87,10 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
) = gen_decode_params(b_seq_len=self.b_seq_len)
self.b_start_loc = self.b1_cu_kv_seq_len[0:-1]

def copy_for_cuda_graph(self, new_infer_state):
def copy_for_cuda_graph(self, new_infer_state: "InferStateInfo"):
for attr_name, attr_value in vars(new_infer_state).items():
if isinstance(attr_value, torch.Tensor):
attr_ = getattr(self, attr_name, None)
if attr_ is not None and attr_.data_ptr() != attr_value.data_ptr():
attr_.copy_(attr_value, non_blocking=True)
attr_[: new_infer_state.batch_size].copy_(attr_value, non_blocking=True)
return
23 changes: 18 additions & 5 deletions lightllm/common/basemodel/layer_infer/cache_tensor_manager.py
@@ -44,7 +44,8 @@ def __del__(self):
class CudaGraphCacheTensorManager:
def __init__(self, cuda_graph_max_batch_size: int):
self.cuda_graph_max_batch_size = cuda_graph_max_batch_size
self.graph_out_tensor_dict: Dict[int, torch.Tensor] = {}
# Dict[graph_out_key, Dict[microbatch_index, tensor_cache]]
self.graph_out_tensor_dict: Dict[int, Dict[int, torch.Tensor]] = collections.defaultdict(dict)
self.managed_total_tensor_bytes = 0
return

@@ -56,6 +57,7 @@ def alloc_tensor_for_cuda_graph(
device: str = "cuda",
is_graph_out: bool = False,
microbatch_index: int = 0,
graph_out_key: int = 0,
) -> torch.Tensor:
assert microbatch_index in [0, 1]
if not is_graph_out:
@@ -66,13 +68,16 @@
max_size = size // cur_batch_size * self.cuda_graph_max_batch_size

# graph out tensor: there is only one, so no reference-count management is needed
if microbatch_index not in self.graph_out_tensor_dict:
microbatch_index_to_tensor_cache = self.graph_out_tensor_dict[graph_out_key]

if microbatch_index not in microbatch_index_to_tensor_cache:
graph_out_tensor = torch.empty((max_size,), dtype=data_type, device=device, requires_grad=False)
logger.info(f"pid {os.getpid()} cuda graph alloc graph out mem {shape} {data_type} {size} {max_size}")
self.managed_total_tensor_bytes += graph_out_tensor.element_size() * graph_out_tensor.numel()
logger.info(f"cuda graph managed_total_tensor_bytes: {self.managed_total_tensor_bytes}")
self.graph_out_tensor_dict[microbatch_index] = graph_out_tensor
return self.graph_out_tensor_dict[microbatch_index][0:size].view(shape)
microbatch_index_to_tensor_cache[microbatch_index] = graph_out_tensor

return self.graph_out_tensor_dict[graph_out_key][microbatch_index][0:size].view(shape)

class CacheTensorManager:
def __init__(self):
@@ -119,14 +124,21 @@ def alloc_tensor(
device: str = "cuda",
is_graph_out: bool = False,
microbatch_index: int = 0,
graph_out_key: int = 0,
) -> torch.Tensor:
# convert the shape type
if isinstance(shape, list):
shape = torch.Size(shape)
# when running a cuda graph, allocation is handled by the cuda graph manager
if self.is_cuda_graph:
return self.inner_cuda_graph_manager.alloc_tensor_for_cuda_graph(
self.cuda_graph_cur_batch_size, shape, data_type, device, is_graph_out, microbatch_index
self.cuda_graph_cur_batch_size,
shape,
data_type,
device,
is_graph_out,
microbatch_index,
graph_out_key,
)

# reclaim tensors that may no longer be alive
@@ -191,6 +203,7 @@ def alloc_tensor(
device: str = "cuda",
is_graph_out: bool = False,
microbatch_index: int = 0,
graph_out_key: int = 0,
) -> torch.Tensor:
return torch.empty(shape, dtype=data_type, device=device, requires_grad=False)

53 changes: 41 additions & 12 deletions lightllm/common/basemodel/triton_kernel/copy_kv_index_to_req.py
@@ -2,33 +2,62 @@

import triton
import triton.language as tl
import copy


@triton.jit
def _fwd_kernel_copy_kv_index_to_req(
req_to_token_indexs, b_req_idx, b_seq_len, memindex,
stride_req_to_token_b, stride_req_to_token_s
req_to_token_indexs, b_req_idx, b_seq_len, memindex, stride_req_to_token_b, stride_req_to_token_s
):
cur_index = tl.program_id(0)
cur_seq = tl.program_id(0)
cur_index = tl.program_id(1)
batch_size = tl.num_programs(1)
cur_req_idx = tl.load(b_req_idx + cur_index)
cur_token_index = tl.load(memindex + cur_index)
cur_seq_len = tl.load(b_seq_len + cur_index)
cur_token_index = tl.load(memindex + cur_index + batch_size * cur_seq)
cur_seq_len = tl.load(b_seq_len + cur_index) + cur_seq
dest_offset = req_to_token_indexs + cur_req_idx * stride_req_to_token_b + (cur_seq_len - 1) * stride_req_to_token_s
tl.store(dest_offset, cur_token_index)
return


@torch.no_grad()
def copy_kv_index_to_req(req_to_token_indexs, b_req_idx, b_seq_len, memindex):
seq_len = b_seq_len.shape[0]
assert b_seq_len.shape[0] == memindex.shape[0] and b_req_idx.shape[0] == b_seq_len.shape[0]
grid = (seq_len,)
def copy_kv_index_to_req(req_to_token_indexs, b_req_idx, b_seq_len, memindex, decode_len=1):
batch_size = b_seq_len.shape[0]
assert b_seq_len.shape[0] * decode_len == memindex.shape[0] and b_req_idx.shape[0] == b_seq_len.shape[0]
grid = (
decode_len,
batch_size,
)
num_warps = 1

_fwd_kernel_copy_kv_index_to_req[grid](
req_to_token_indexs, b_req_idx, b_seq_len, memindex,
req_to_token_indexs.stride(0), req_to_token_indexs.stride(1),
req_to_token_indexs,
b_req_idx,
b_seq_len,
memindex,
req_to_token_indexs.stride(0),
req_to_token_indexs.stride(1),
num_warps=num_warps,
num_stages=1,
)

return


if __name__ == "__main__":
for decode_len in [1, 2]:
max_request_num = 100
max_sequence_length = 1000
req_to_token_indexs = torch.zeros((max_request_num + 1, max_sequence_length), dtype=torch.int32, device="cuda")
bs = 8
b_req_idx = torch.randint(low=0, high=max_request_num - 1, size=(bs,)).cuda()
b_seq_len = torch.randint(low=1, high=max_sequence_length, size=(bs,)).cuda()
memindex = torch.randint(low=0, high=10000, size=(bs * decode_len,)).cuda()
copy_kv_index_to_req(req_to_token_indexs, b_req_idx, b_seq_len, memindex, decode_len)

for i in range(bs):
for j in range(decode_len):
if req_to_token_indexs[b_req_idx[i]][b_seq_len[i] + j - 1] != memindex[j * bs + i]:
print("ERROR")
exit(1)

print("PASS")
35 changes: 35 additions & 0 deletions lightllm/common/spec_info.py
@@ -0,0 +1,35 @@
from enum import IntEnum, auto


class SpeculativeDecodeAlgorithm(IntEnum):
NONE = auto()
MTP = auto()
MTP_MOUDLE = auto()

def is_none(self):
return self == SpeculativeDecodeAlgorithm.NONE

def is_mtp(self):
return self == SpeculativeDecodeAlgorithm.MTP

def is_mtp_module(self):
return self == SpeculativeDecodeAlgorithm.MTP_MOUDLE

@staticmethod
def from_string(name: str):
name_map = {
"MTP": SpeculativeDecodeAlgorithm.MTP,
"MTP_MOUDLE": SpeculativeDecodeAlgorithm.MTP_MOUDLE,
"NONE": SpeculativeDecodeAlgorithm.NONE,
}
if name is not None:
name = name.upper()
return name_map[name]

def decode_len(self):
if self == SpeculativeDecodeAlgorithm.NONE:
return 1
if self == SpeculativeDecodeAlgorithm.MTP:
return 2
if self == SpeculativeDecodeAlgorithm.MTP_MOUDLE:
return 2
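
A minimal usage sketch for this enum (illustrative only, not part of the diff): from_string is case-insensitive, and decode_len reports how many tokens a decode step produces, the kind of value consumed by the decode_len parameter of copy_kv_index_to_req above.

from lightllm.common.spec_info import SpeculativeDecodeAlgorithm

algo = SpeculativeDecodeAlgorithm.from_string("mtp")    # case-insensitive lookup
assert algo.is_mtp()
assert algo.decode_len() == 2                           # MTP emits two tokens per step
assert SpeculativeDecodeAlgorithm.NONE.decode_len() == 1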
1 change: 1 addition & 0 deletions lightllm/models/deepseek2/infer_struct.py
@@ -3,6 +3,7 @@
import numpy as np
import torch.distributed as dist
from lightllm.models.llama.infer_struct import LlamaInferStateInfo
from lightllm.common.spec_info import SpeculativeDecodeAlgorithm


class Deepseek2InferStateInfo(LlamaInferStateInfo):
Empty file.
36 changes: 36 additions & 0 deletions lightllm/models/deepseek_mtp/deepseek3_mtp_mem_manager.py
@@ -0,0 +1,36 @@
import torch
from lightllm.common.deepseek2_mem_manager import Deepseek2MemoryManager
from lightllm.utils.log_utils import init_logger
from lightllm.utils.dist_utils import get_current_rank_in_node
from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt

logger = init_logger(__name__)


class Deepseek3MTPMemoryManager(Deepseek2MemoryManager):
def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
self.size = size
self.head_num = head_num
self.head_dim = head_dim
self.layer_num = layer_num
self.always_copy = always_copy
self.dtype = dtype
# profile the max total token num if the size is None
self.profile_size(mem_fraction)

self.mem_state = torch.arange(
0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
)
self.mark_start = 0
self.mark_end = self.size

self.can_use_mem_size = self.size

self._init_buffers(
self.size,
dtype,
head_num,
head_dim,
layer_num,
)
self.HOLD_TOKEN_MEMINDEX = self.size
Empty file.
70 changes: 70 additions & 0 deletions lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py
@@ -0,0 +1,70 @@
import os
import torch
import torch.distributed as dist
import numpy as np

from lightllm.models.deepseek_mtp.layer_weights.pre_and_post_layer_weight import Deepseek3MTPPreAndPostLayerWeight
from lightllm.models.llama.infer_struct import LlamaInferStateInfo
from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer
from lightllm.utils.infer_utils import mark_cost_time
from lightllm.models.llama.triton_kernel.embedding import embedding
from lightllm.distributed.communication_op import all_reduce
from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
from lightllm.distributed.communication_op import all_gather


class Deepseek3MTPPreLayerInfer(LlamaPreLayerInfer):
""" """

def __init__(self, network_config, mode):
super().__init__(network_config, mode)
self.eps_ = network_config["rms_norm_eps"]
self.hidden_size = network_config["hidden_size"]
return

def mtp_context_forward(
self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight
):
assert infer_state.spec_info is not None, "need spec info for mtp model."
tgt_embdings = infer_state.spec_info
assert input_embdings.shape[0] == tgt_embdings.shape[0]
rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, out=input_embdings)
rmsnorm_forward(tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, out=tgt_embdings)

cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1)
infer_state.spec_info = None

ans_logics = self.alloc_tensor(
(cat_embdings.shape[0], layer_weight.eh_proj_weight_.shape[1]), dtype=input_embdings.dtype
)
torch.mm(cat_embdings, layer_weight.eh_proj_weight_, out=ans_logics)
return ans_logics

def mtp_token_forward(
self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight
):
assert infer_state.spec_info is not None, "need spec info for mtp model."
tgt_embdings = infer_state.spec_info
assert input_embdings.shape[0] == tgt_embdings.shape[0]
rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, out=input_embdings)
rmsnorm_forward(tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, out=tgt_embdings)

cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1)

ans_logics = self.alloc_tensor(
(cat_embdings.shape[0], layer_weight.eh_proj_weight_.shape[1]), dtype=input_embdings.dtype
)
torch.mm(cat_embdings, layer_weight.eh_proj_weight_, out=ans_logics)
return ans_logics

def context_forward(
self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight
):
input_embdings = super().context_forward(input_ids, infer_state, layer_weight)
return self.mtp_context_forward(input_embdings, infer_state, layer_weight)

def token_forward(
self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: Deepseek3MTPPreAndPostLayerWeight
):
input_embdings = super().token_forward(input_ids, infer_state, layer_weight)
return self.mtp_token_forward(input_embdings, infer_state, layer_weight)
Empty file.
29 changes: 29 additions & 0 deletions lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py
@@ -0,0 +1,29 @@
import numpy as np
from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight


class Deepseek3MTPPreAndPostLayerWeight(LlamaPreAndPostLayerWeight):
def __init__(self, data_type, network_config, mode):
super().__init__(data_type, network_config, mode)
# shared with the DeepseekV3 model
self.wte_weight_ = None
self.lm_head_weight_ = None
return

def load_hf_weights(self, weights):
if "model.layers.0.eh_proj.weight" in weights:
self.eh_proj_weight_ = self._cuda(weights["model.layers.0.eh_proj.weight"]).t()
if "model.layers.0.enorm.weight" in weights:
self.enorm_weight_ = self._cuda(weights["model.layers.0.enorm.weight"])
if "model.layers.0.hnorm.weight" in weights:
self.hnorm_weight_ = self._cuda(weights["model.layers.0.hnorm.weight"])
if "model.layers.0.shared_head.norm.weight" in weights:
self.final_norm_weight_ = self._cuda(weights["model.layers.0.shared_head.norm.weight"])
return

def verify_load(self):
errors = "weights load not ok"
weights = [self.eh_proj_weight_, self.enorm_weight_, self.hnorm_weight_, self.final_norm_weight_]
for i in range(len(weights)):
assert weights[i] is not None, "index:" + str(i) + " " + errors
return