
Commit 71c2807

Nancheng-11LLLLKKKK authored and committed
feature - adapt deepseek in model py
1 parent 8252afc · commit 71c2807

26 files changed (+1928 −30 lines)

open_source/deps/requirements_lock_torch_gpu_cuda12.txt

Lines changed: 26 additions & 0 deletions
@@ -587,6 +587,9 @@ filelock==3.13.1 \
 flash-attn @ https://rtp-opensource.oss-cn-hangzhou.aliyuncs.com/rtp_llm/flash_attn-2.7.4.post1%2Bcu12torch2.6cxx11abiTRUE-cp310-cp310-linux_x86_64.whl \
     --hash=sha256:bfdb0f290cc3d21d0810ba49a360ef91090f62cdc1345ec6900447e0d12d99af
     # via -r open_source/deps/requirements_torch_gpu_cuda12.txt
+flashinfer-python==0.2.5 \
+    --hash=sha256:990aa090ef781783e76b836696ece4efd23956f72b5696d622fc619a61162aef
+    # via -r open_source/deps/requirements_torch_gpu_cuda12.txt
 fonttools==4.53.1 \
     --hash=sha256:02569e9a810f9d11f4ae82c391ebc6fb5730d95a0657d24d754ed7763fb2d122 \
     --hash=sha256:0679a30b59d74b6242909945429dbddb08496935b82f91ea9bf6ad240ec23397 \
@@ -1549,6 +1552,27 @@ networkx==3.3 \
     --hash=sha256:0c127d8b2f4865f59ae9cb8aafcd60b5c70f3241ebd66f7defad7c4ab90126c9 \
     --hash=sha256:28575580c6ebdaf4505b22c6256a2b9de86b316dc63ba9e93abde3d78dfdbcf2
     # via torch
+ninja==1.13.0 \
+    --hash=sha256:11be2d22027bde06f14c343f01d31446747dbb51e72d00decca2eb99be911e2f \
+    --hash=sha256:1c97223cdda0417f414bf864cfb73b72d8777e57ebb279c5f6de368de0062988 \
+    --hash=sha256:3c0b40b1f0bba764644385319028650087b4c1b18cdfa6f45cb39a3669b81aa9 \
+    --hash=sha256:3d00c692fb717fd511abeb44b8c5d00340c36938c12d6538ba989fe764e79630 \
+    --hash=sha256:3d7d7779d12cb20c6d054c61b702139fd23a7a964ec8f2c823f1ab1b084150db \
+    --hash=sha256:4a40ce995ded54d9dc24f8ea37ff3bf62ad192b547f6c7126e7e25045e76f978 \
+    --hash=sha256:4be9c1b082d244b1ad7ef41eb8ab088aae8c109a9f3f0b3e56a252d3e00f42c1 \
+    --hash=sha256:5f8e1e8a1a30835eeb51db05cf5a67151ad37542f5a4af2a438e9490915e5b72 \
+    --hash=sha256:60056592cf495e9a6a4bea3cd178903056ecb0943e4de45a2ea825edb6dc8d3e \
+    --hash=sha256:6739d3352073341ad284246f81339a384eec091d9851a886dfa5b00a6d48b3e2 \
+    --hash=sha256:8cfbb80b4a53456ae8a39f90ae3d7a2129f45ea164f43fadfa15dc38c4aef1c9 \
+    --hash=sha256:aa45b4037b313c2f698bc13306239b8b93b4680eb47e287773156ac9e9304714 \
+    --hash=sha256:b4f2a072db3c0f944c32793e91532d8948d20d9ab83da9c0c7c15b5768072200 \
+    --hash=sha256:be7f478ff9f96a128b599a964fc60a6a87b9fa332ee1bd44fa243ac88d50291c \
+    --hash=sha256:d741a5e6754e0bda767e3274a0f0deeef4807f1fec6c0d7921a0244018926ae5 \
+    --hash=sha256:e8bad11f8a00b64137e9b315b137d8bb6cbf3086fbdc43bf1f90fd33324d2e96 \
+    --hash=sha256:fa2a8bfc62e31b08f83127d1613d10821775a0eb334197154c4d6067b7068ff1 \
+    --hash=sha256:fb46acf6b93b8dd0322adc3a4945452a4e774b75b91293bafcc7b7f8e6517dfa \
+    --hash=sha256:fb8ee8719f8af47fed145cced4a85f0755dd55d45b2bddaf7431fa89803c5f3e
+    # via flashinfer-python
 numba==0.60.0 \
     --hash=sha256:01ef4cd7d83abe087d644eaa3d95831b777aa21d441a23703d649e06b8e06b74 \
     --hash=sha256:0b983bd6ad82fe868493012487f34eae8bf7dd94654951404114f23c3466d34b \
@@ -1609,6 +1633,7 @@ numpy==1.24.1 \
     # contourpy
     # datasets
     # decord
+    # flashinfer-python
     # gekko
     # librosa
     # matplotlib
@@ -3172,6 +3197,7 @@ torch @ https://mirrors.aliyun.com/pytorch-wheels/cu126/torch-2.6.0%2Bcu126-cp31
     # autoawq-kernels
     # bitsandbytes
     # flash-attn
+    # flashinfer-python
     # peft
     # sentence-transformers
     # timm

open_source/deps/requirements_torch_gpu_cuda12.txt

Lines changed: 2 additions & 1 deletion
@@ -7,4 +7,5 @@ https://mirrors.aliyun.com/pytorch-wheels/cu126/torchvision-0.21.0%2Bcu126-cp310
 https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
 tensorrt==10.3.0
 tensorrt-cu12-bindings==10.3.0
-tensorrt-cu12-libs==10.3.0
+tensorrt-cu12-libs==10.3.0
+flashinfer-python==0.2.5
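
Both requirements files now pin flashinfer-python at 0.2.5. As a quick sanity check that the pinned wheel resolves in a CUDA 12 environment, a minimal sketch (not part of the commit; flashinfer-python installs under the module name `flashinfer`):

    # Minimal sketch, assuming the lock file above has been installed.
    import importlib.metadata

    import flashinfer  # noqa: F401  (import fails if the wheel is missing)

    assert importlib.metadata.version("flashinfer-python") == "0.2.5"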

rtp_llm/BUILD

Lines changed: 7 additions & 2 deletions
@@ -22,6 +22,10 @@ tensorrt = [
     "tensorrt-cu12-libs",
 ]
 
+flashinfer = [
+    "flashinfer-python",
+]
+
 xft_dep = select({
     "@//:using_arm": [],
     "//:xft_use_icx": [
@@ -98,7 +102,7 @@ requirement([
     "concurrent_log_handler",
     "aiter",
     "fastsafetensors",
-] + tensorrt)
+] + tensorrt + flashinfer)
 
 filegroup(
     name = "cutlass_config",
@@ -210,7 +214,7 @@ py_library(
     "//rtp_llm/model_loader:loader",
     "//rtp_llm/models_py:models",
 ] + arch_dep + select({
-    "@//:using_cuda12": tensorrt,
+    "@//:using_cuda12": tensorrt + flashinfer,
     "//conditions:default": []
 }) + select({
     "@//:using_arm": [],
@@ -540,6 +544,7 @@ whl_reqs = [
     "bitsandbytes",
     "portalocker",
     "concurrent_log_handler",
+    "flashinfer-python==0.2.5",
 ] + whl_deps() + platform_deps() + xft_dep
 
 py_wheel(
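
The select() above makes flashinfer a dependency only for "@//:using_cuda12" builds. A hedged runtime analogue of that build-time gate (the helper name is hypothetical, not from this commit):

    # Sketch: import flashinfer only when the torch build targets CUDA 12,
    # mirroring the "@//:using_cuda12" select() in the BUILD file above.
    import torch

    def try_import_flashinfer():
        """Return the flashinfer module on CUDA 12 builds, else None."""
        cuda = torch.version.cuda  # e.g. "12.6", or None on CPU/ROCm builds
        if cuda is None or not cuda.startswith("12"):
            return None
        try:
            import flashinfer
            return flashinfer
        except ImportError:
            return None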

rtp_llm/cpp/pybind/BUILD

Lines changed: 1 addition & 0 deletions
@@ -47,6 +47,7 @@ cc_library(
     deps = [
         "//rtp_llm/cpp/utils:core_utils",
         "//rtp_llm/cpp/devices:devices_base",
+        "//rtp_llm/cpp/devices:device_utils",
         "@havenask//aios/autil:base64",
         "@havenask//aios/autil:zlib",
     ] + torch_deps() + select_py_bindings(),

rtp_llm/models/deepseek_v2.py

Lines changed: 7 additions & 1 deletion
@@ -27,6 +27,8 @@
 from rtp_llm.models.rotary_embedding.deepseek_rotary_embedding import (
     DeepseekV3YarnRotaryEmbedding,
 )
+from rtp_llm.models_py.model_desc.deepseek_v2 import DeepSeekV2Model
+from rtp_llm.models_py.model_desc.module_base import GptModelBase
 from rtp_llm.utils.model_weight import (
     CkptWeightInfo,
     W,
@@ -516,10 +518,14 @@ def _create_config(cls, ckpt_path: str):
             norm_type="rmsnorm",
             has_post_decoder_layernorm=True,
         )
-        config.activation_type = "gated-silu"
+        # config.activation_type = "gated-silu"
+        config.activation_type = "SiGLU"
         DeepSeekV2._from_hf(config, ckpt_path)
         return config
 
+    def _create_python_model(self) -> Optional[GptModelBase]:
+        self.py_model = DeepSeekV2Model(self.config, self.weight)
+
     @staticmethod
     def _from_hf(config: GptInitModelParameters, ckpt_path: str):
         config_path = os.path.join(ckpt_path, "config.json")
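
Note that _create_python_model() stores the PyTorch-native implementation on self.py_model rather than returning it, despite the Optional[GptModelBase] annotation. A hedged sketch of the expected call site (the entry-point name and loading flow are assumptions, not shown in this diff):

    # Hypothetical caller; `from_pretrained` is an assumed entry point.
    model = DeepSeekV2.from_pretrained(ckpt_path)  # loads config + weights
    model._create_python_model()                   # populates model.py_model
    assert isinstance(model.py_model, DeepSeekV2Model)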

rtp_llm/models_py/BUILD

Lines changed: 8 additions & 1 deletion
@@ -33,6 +33,10 @@ py_library(
     visibility = ["//visibility:public"],
 )
 
+# flashinfer-python is only available for CUDA12
+flashinfer = ["flashinfer-python"]
+requirement(flashinfer)
+
 py_library(
     name = "modules",
     srcs = glob([
@@ -43,7 +47,10 @@ py_library(
         ":utils",
         ":kernels",
         ":distributed",
-    ],
+    ] + select({
+        "@//:using_cuda12": flashinfer,
+        "//conditions:default": [],
+    }),
     visibility = ["//visibility:public"],
 )

rtp_llm/models_py/bindings/BUILD

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ cc_library(
     hdrs = [
         "OpDefs.h",
         "OpDefsUtils.h",
         "ParamsBase.h",
+        "MlaParamsBase.h",
     ],
     srcs = [
         "OpDefs.cc",
rtp_llm/models_py/bindings/MlaParamsBase.h

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+#pragma once
+#include <memory>
+#include <torch/extension.h>
+#include "rtp_llm/models_py/bindings/OpDefs.h"
+
+namespace rtp_llm {
+
+class MlaParamsBase {
+public:
+    virtual ~MlaParamsBase() = default;
+    torch_ext::MlaParams fillParams(torch::Tensor t_prefix_lengths,
+                                    torch::Tensor t_sequence_lengths,
+                                    torch::Tensor t_input_lengths,
+                                    torch::Tensor t_kv_cache_block_id_host,
+                                    int           seq_size_per_block);
+};
+
+} // namespace rtp_llm
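
fillParams() turns host-side length and block-id tensors into the paged-KV metadata that an MLA attention backend consumes. A minimal Python sketch of the likely computation for a prefill batch with no reused prefix cache; this is an inference from the MlaParams field names, not the actual C++ implementation:

    import torch

    def fill_mla_params(input_lengths: torch.Tensor,      # [batch] prompt lengths
                        kv_cache_block_id: torch.Tensor,  # [batch, max_blocks] block table
                        seq_size_per_block: int):
        batch = input_lengths.numel()
        kvlen = input_lengths.clone()          # assumption: kv len == query len
        qo_indptr = torch.zeros(batch + 1, dtype=torch.int32)
        qo_indptr[1:] = torch.cumsum(input_lengths, dim=0)
        # Number of KV-cache pages each request touches (ceil division).
        pages = (kvlen + seq_size_per_block - 1) // seq_size_per_block
        page_indptr = torch.zeros(batch + 1, dtype=torch.int32)
        page_indptr[1:] = torch.cumsum(pages, dim=0)
        # Flatten only the pages each request actually uses.
        page_indice = torch.cat(
            [kv_cache_block_id[i, : pages[i]] for i in range(batch)])
        # Tokens occupying the final page of each request.
        paged_kv_last_page_len = kvlen - (pages - 1) * seq_size_per_block
        return qo_indptr, page_indptr, page_indice, paged_kv_last_page_len

These arrays correspond one-to-one to the MlaParams fields bound to Python in OpDefs.cc below.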

rtp_llm/models_py/bindings/OpDefs.cc

Lines changed: 16 additions & 6 deletions
@@ -3,9 +3,19 @@
 namespace torch_ext {
 
 void registerPyOpDefs(pybind11::module& m) {
+    pybind11::class_<MlaParams>(m, "MlaParams")
+        .def(pybind11::init<>())
+        .def_readonly("batch_indice", &MlaParams::batch_indice)
+        .def_readonly("positions", &MlaParams::positions)
+        .def_readonly("paged_kv_last_page_len", &MlaParams::paged_kv_last_page_len)
+        .def_readonly("kvlen", &MlaParams::kvlen)
+        .def_readonly("page_indice", &MlaParams::page_indice)
+        .def_readonly("page_indptr", &MlaParams::page_indptr)
+        .def_readonly("qo_indptr", &MlaParams::qo_indptr);
+
     pybind11::class_<KVCache>(m, "KVCache")
         .def(pybind11::init<>())
-        .def_readonly("k_cache_base", &KVCache::k_cache_base, "Key cache base tensor")
+        .def_readwrite("k_cache_base", &KVCache::k_cache_base, "Key cache base tensor")
         .def_readonly("v_cache_base", &KVCache::v_cache_base, "Value cache base tensor")
         .def_readonly("k_scale_base", &KVCache::k_scale_base, "Key cache scale tensor")
         .def_readonly("v_scale_base", &KVCache::v_scale_base, "Value cache scale tensor")
@@ -43,12 +53,12 @@ void registerPyOpDefs(pybind11::module& m) {
 
     pybind11::class_<PyAttentionInputs>(m, "PyAttentionInputs")
         .def(pybind11::init<>())
-        .def_readonly("is_prefill", &PyAttentionInputs::is_prefill)
-        .def_readonly("prefix_lengths", &PyAttentionInputs::prefix_lengths)
-        .def_readonly("sequence_lengths", &PyAttentionInputs::sequence_lengths)
-        .def_readonly("input_lengths", &PyAttentionInputs::input_lengths)
+        .def_readwrite("is_prefill", &PyAttentionInputs::is_prefill)
+        .def_readwrite("prefix_lengths", &PyAttentionInputs::prefix_lengths)
+        .def_readwrite("sequence_lengths", &PyAttentionInputs::sequence_lengths)
+        .def_readwrite("input_lengths", &PyAttentionInputs::input_lengths)
         .def_readonly("cu_seqlens", &PyAttentionInputs::cu_seqlens)
-        .def_readonly("kv_cache_block_id_host", &PyAttentionInputs::kv_cache_block_id_host)
+        .def_readwrite("kv_cache_block_id_host", &PyAttentionInputs::kv_cache_block_id_host)
         .def_readonly("kv_cache_block_id_device", &PyAttentionInputs::kv_cache_block_id_device)
         .def_readonly("dtype", &PyAttentionInputs::dtype)
         .def_readonly("kv_block_offset", &PyAttentionInputs::kv_block_offset)

rtp_llm/models_py/bindings/OpDefs.h

Lines changed: 9 additions & 0 deletions
@@ -7,6 +7,15 @@
 #include "rtp_llm/models_py/bindings/ParamsBase.h"
 #include "rtp_llm/cpp/utils/Logger.h"
 namespace torch_ext {
+struct MlaParams {
+    torch::Tensor batch_indice;
+    torch::Tensor positions;
+    torch::Tensor paged_kv_last_page_len;
+    torch::Tensor kvlen;
+    torch::Tensor page_indice;
+    torch::Tensor page_indptr;
+    torch::Tensor qo_indptr;
+};
 
 struct KVCache {
     torch::Tensor k_cache_base;
