Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions open_source/deps/requirements_lock_torch_gpu_cuda12_9.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#
# bazel run //open_source/deps:requirements_torch_gpu_cuda12_9.update
#
--index-url https://artlab.alibaba-inc.com/1/PYPI/simple/
--extra-index-url https://mirrors.aliyun.com/pypi/simple/

accelerate==0.25.0 \
Expand Down Expand Up @@ -464,6 +465,10 @@ click==8.3.0 \
# flashinfer-python
# typer
# uvicorn
cloudpickle==3.1.2 \
--hash=sha256:7fda9eb655c9c230dab534f1983763de5835249750e85fbcef43aaa30a9a2414 \
--hash=sha256:9acb47f6afd73f60dc1df93bb801b472f05ff42fa6c84167d25cb206be1fbf4a
# via tilelang
concurrent-log-handler==0.9.28 \
--hash=sha256:4cc27969b3420239bd153779266f40d9713ece814e312b7aa753ce62c6eacdb8 \
--hash=sha256:65db25d05506651a61573937880789fc51c7555e7452303042b5a402fd78939c
Expand Down Expand Up @@ -602,6 +607,47 @@ cycler==0.12.1 \
--hash=sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30 \
--hash=sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c
# via matplotlib
cython==3.2.4 \
--hash=sha256:02cb0cc0f23b9874ad262d7d2b9560aed9c7e2df07b49b920bda6f2cc9cb505e \
--hash=sha256:03893c88299a2c868bb741ba6513357acd104e7c42265809fd58dce1456a36fc \
--hash=sha256:14dae483ca2838b287085ff98bc206abd7a597b7bb16939a092f8e84d9062842 \
--hash=sha256:1a64a112a34ec719b47c01395647e54fb4cf088a511613f9a3a5196694e8e382 \
--hash=sha256:28b1e363b024c4b8dcf52ff68125e635cb9cb4b0ba997d628f25e32543a71103 \
--hash=sha256:28e8075087a59756f2d059273184b8b639fe0f16cf17470bd91c39921bc154e0 \
--hash=sha256:2b1f12c0e4798293d2754e73cd6f35fa5bbdf072bdc14bc6fc442c059ef2d290 \
--hash=sha256:31a90b4a2c47bb6d56baeb926948348ec968e932c1ae2c53239164e3e8880ccf \
--hash=sha256:35ab0632186057406ec729374c737c37051d2eacad9d515d94e5a3b3e58a9b02 \
--hash=sha256:36bf3f5eb56d5281aafabecbaa6ed288bc11db87547bba4e1e52943ae6961ccf \
--hash=sha256:3b6e58f73a69230218d5381817850ce6d0da5bb7e87eb7d528c7027cbba40b06 \
--hash=sha256:3b8e62049afef9da931d55de82d8f46c9a147313b69d5ff6af6e9121d545ce7a \
--hash=sha256:55b6c44cd30821f0b25220ceba6fe636ede48981d2a41b9bbfe3c7902ce44ea7 \
--hash=sha256:55eb425c0baf1c8a46aa4424bc35b709db22f3c8a1de33adb3ecb8a3d54ea42a \
--hash=sha256:64d7f71be3dd6d6d4a4c575bb3a4674ea06d1e1e5e4cd1b9882a2bc40ed3c4c9 \
--hash=sha256:67922c9de058a0bfb72d2e75222c52d09395614108c68a76d9800f150296ddb3 \
--hash=sha256:6d5267f22b6451eb1e2e1b88f6f78a2c9c8733a6ddefd4520d3968d26b824581 \
--hash=sha256:72e6c0bbd978e2678b45351395f6825b9b8466095402eae293f4f7a73e9a3e85 \
--hash=sha256:732fc93bc33ae4b14f6afaca663b916c2fdd5dcbfad7114e17fb2434eeaea45c \
--hash=sha256:767b143704bdd08a563153448955935844e53b852e54afdc552b43902ed1e235 \
--hash=sha256:83266c356c13c68ffe658b4905279c993d8a5337bb0160fa90c8a3e297ea9a2e \
--hash=sha256:84226ecd313b233da27dc2eb3601b4f222b8209c3a7216d8733b031da1dc64e6 \
--hash=sha256:869487ea41d004f8b92171f42271fbfadb1ec03bede3158705d16cd570d6b891 \
--hash=sha256:90f43be4eaa6afd58ce20d970bb1657a3627c44e1760630b82aa256ba74b4acb \
--hash=sha256:983f9d2bb8a896e16fa68f2b37866ded35fa980195eefe62f764ddc5f9f5ef8e \
--hash=sha256:b362819d155fff1482575e804e43e3a8825332d32baa15245f4642022664a3f4 \
--hash=sha256:b84d4e3c875915545f77c88dba65ad3741afd2431e5cdee6c9a20cefe6905647 \
--hash=sha256:ca2399dc75796b785f74fb85c938254fa10c80272004d573c455f9123eceed86 \
--hash=sha256:ca578c9cb872c7ecffbe14815dc4590a003bc13339e90b2633540c7e1a252839 \
--hash=sha256:d4b4fd5332ab093131fa6172e8362f16adef3eac3179fd24bbdc392531cb82fa \
--hash=sha256:e3b5ac54e95f034bc7fb07313996d27cbf71abc17b229b186c1540942d2dc28e \
--hash=sha256:e65e4773021f8dc8532010b4fbebe782c77f9a0817e93886e518c93bd6a44e9d \
--hash=sha256:e71efb20048358a6b8ec604a0532961c50c067b5e63e345e2e359fff72feaee8 \
--hash=sha256:f136f379a4a54246facd0eb6f1ee15c3837cb314ce87b677582ec014db4c6845 \
--hash=sha256:f583cad7a7eed109f0babb5035e92d0c1260598f53add626a8568b57246b62c3 \
--hash=sha256:f81eda419b5ada7b197bbc3c5f4494090e3884521ffd75a3876c93fbf66c9ca8 \
--hash=sha256:f8d685a70bce39acc1d62ec3916d9b724b5ef665b0ce25ae55e1c85ee09747fc \
--hash=sha256:fdfdd753ad7e18e5092b413e9f542e8d28b8a08203126090e1c15f7783b7fe57 \
--hash=sha256:ff9af2134c05e3734064808db95b4dd7341a39af06e8945d05ea358e1741aaed
# via tilelang
dacite==1.9.2 \
--hash=sha256:053f7c3f5128ca2e9aceb66892b1a3c8936d02c686e707bee96e19deef4bc4a0 \
--hash=sha256:6ccc3b299727c7aa17582f0021f6ae14d5de47c7227932c47fec4cdfefd26f09
Expand Down Expand Up @@ -1708,6 +1754,47 @@ mdurl==0.1.2 \
--hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \
--hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba
# via markdown-it-py
ml-dtypes==0.5.4 \
--hash=sha256:0d2ffd05a2575b1519dc928c0b93c06339eb67173ff53acb00724502cda231cf \
--hash=sha256:11942cbf2cf92157db91e5022633c0d9474d4dfd813a909383bd23ce828a4b7d \
--hash=sha256:14a4fd3228af936461db66faccef6e4f41c1d82fcc30e9f8d58a08916b1d811f \
--hash=sha256:19b9a53598f21e453ea2fbda8aa783c20faff8e1eeb0d7ab899309a0053f1483 \
--hash=sha256:2314892cdc3fcf05e373d76d72aaa15fda9fb98625effa73c1d646f331fcecb7 \
--hash=sha256:2b857d3af6ac0d39db1de7c706e69c7f9791627209c3d6dedbfca8c7e5faec22 \
--hash=sha256:304ad47faa395415b9ccbcc06a0350800bc50eda70f0e45326796e27c62f18b6 \
--hash=sha256:35f29491a3e478407f7047b8a4834e4640a77d2737e0b294d049746507af5175 \
--hash=sha256:388d399a2152dd79a3f0456a952284a99ee5c93d3e2f8dfe25977511e0515270 \
--hash=sha256:3bbbe120b915090d9dd1375e4684dd17a20a2491ef25d640a908281da85e73f1 \
--hash=sha256:3d277bf3637f2a62176f4575512e9ff9ef51d00e39626d9fe4a161992f355af2 \
--hash=sha256:4381fe2f2452a2d7589689693d3162e876b3ddb0a832cde7a414f8e1adf7eab1 \
--hash=sha256:4ff7f3e7ca2972e7de850e7b8fcbb355304271e2933dd90814c1cb847414d6e2 \
--hash=sha256:531eff30e4d368cb6255bc2328d070e35836aa4f282a0fb5f3a0cd7260257298 \
--hash=sha256:533ce891ba774eabf607172254f2e7260ba5f57bdd64030c9a4fcfbd99815d0d \
--hash=sha256:557a31a390b7e9439056644cb80ed0735a6e3e3bb09d67fd5687e4b04238d1de \
--hash=sha256:5a0f68ca8fd8d16583dfa7793973feb86f2fbb56ce3966daf9c9f748f52a2049 \
--hash=sha256:6a0df4223b514d799b8a1629c65ddc351b3efa833ccf7f8ea0cf654a61d1e35d \
--hash=sha256:6c7ecb74c4bd71db68a6bea1edf8da8c34f3d9fe218f038814fd1d310ac76c90 \
--hash=sha256:7c23c54a00ae43edf48d44066a7ec31e05fdc2eee0be2b8b50dd1903a1db94bb \
--hash=sha256:805cef3a38f4eafae3a5bf9ebdcdb741d0bcfd9e1bd90eb54abd24f928cd2465 \
--hash=sha256:88c982aac7cb1cbe8cbb4e7f253072b1df872701fcaf48d84ffbb433b6568f24 \
--hash=sha256:8ab06a50fb9bf9666dd0fe5dfb4676fa2b0ac0f31ecff72a6c3af8e22c063453 \
--hash=sha256:8c6a2dcebd6f3903e05d51960a8058d6e131fe69f952a5397e5dbabc841b6d56 \
--hash=sha256:8c760d85a2f82e2bed75867079188c9d18dae2ee77c25a54d60e9cc79be1bc48 \
--hash=sha256:9ad459e99793fa6e13bd5b7e6792c8f9190b4e5a1b45c63aba14a4d0a7f1d5ff \
--hash=sha256:9bad06436568442575beb2d03389aa7456c690a5b05892c471215bfd8cf39460 \
--hash=sha256:a174837a64f5b16cab6f368171a1a03a27936b31699d167684073ff1c4237dac \
--hash=sha256:a7f7c643e8b1320fd958bf098aa7ecf70623a42ec5154e3be3be673f4c34d900 \
--hash=sha256:a9b61c19040397970d18d7737375cffd83b1f36a11dd4ad19f83a016f736c3ef \
--hash=sha256:b4b801ebe0b477be666696bda493a9be8356f1f0057a57f1e35cd26928823e5a \
--hash=sha256:b95e97e470fe60ed493fd9ae3911d8da4ebac16bd21f87ffa2b7c588bf22ea2c \
--hash=sha256:bc11d7e8c44a65115d05e2ab9989d1e045125d7be8e05a071a48bc76eb6d6040 \
--hash=sha256:bfc534409c5d4b0bf945af29e5d0ab075eae9eecbb549ff8a29280db822f34f9 \
--hash=sha256:c1a953995cccb9e25a4ae19e34316671e4e2edaebe4cf538229b1fc7109087b7 \
--hash=sha256:cb73dccfc991691c444acc8c0012bee8f2470da826a92e3a20bb333b1a7894e6 \
--hash=sha256:ce756d3a10d0c4067172804c9cc276ba9cc0ff47af9078ad439b075d1abdc29b \
--hash=sha256:d81fdb088defa30eb37bf390bb7dde35d3a83ec112ac8e33d75ab28cc29dd8b0 \
--hash=sha256:f21c9219ef48ca5ee78402d5cc831bd58ea27ce89beda894428bc67a52da5328
# via tilelang
mpmath==1.3.0 \
--hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \
--hash=sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c
Expand Down Expand Up @@ -2072,6 +2159,7 @@ numpy==1.26.4 \
# gekko
# librosa
# matplotlib
# ml-dtypes
# numba
# nvidia-cutlass-dsl-libs-base
# onnx
Expand All @@ -2082,6 +2170,7 @@ numpy==1.26.4 \
# sentence-transformers
# soundfile
# soxr
# tilelang
# torchvision
# transformers
nvidia-cublas-cu12==12.9.1.4 \
Expand Down Expand Up @@ -2823,6 +2912,7 @@ psutil==7.1.2 \
# -r open_source/deps/requirements_base.txt
# accelerate
# peft
# tilelang
py-spy==0.4.1 \
--hash=sha256:1fb8bf71ab8df95a95cc387deed6552934c50feef2cf6456bc06692a5508fd0c \
--hash=sha256:4972c21890b6814017e39ac233c22572c4a61fd874524ebc5ccab0f2237aee0a \
Expand Down Expand Up @@ -3772,6 +3862,13 @@ tiktoken==0.7.0 \
--hash=sha256:e54be9a2cd2f6d6ffa3517b064983fb695c9a9d8aa7d574d1ef3c3f931a99225 \
--hash=sha256:fffdcb319b614cf14f04d02a52e26b1d1ae14a570f90e9b55461a72672f7b13d
# via -r open_source/deps/requirements_base.txt
tilelang==0.1.6 \
--hash=sha256:33b4c39faafce7b9f9e113c7bb7bf912287fda70de54be572edda99861c1e42b \
--hash=sha256:3f4c8b86395a9137f055a380c50a99932c1ac310a96e5cd5ff86bcb95ba6df5c \
--hash=sha256:55d591ebd79baeb70d26829ee58eff83703ff3c2487e2c038551823800ff2fb7 \
--hash=sha256:a1539129f938b718e126bab3315975c03ef5fa4922532db48c671a48f41e0a1a \
--hash=sha256:fa3f49c5946fa4f72e4c1e80347d1fed9907c3dc2ea20510b33f73125399c2cb
# via -r open_source/deps/requirements_torch_gpu_cuda12_9.txt
timm==0.9.12 \
--hash=sha256:2a828afac5b710a80ec66d0f85807e171e342faf5c0703b33102d8aa206f19dc \
--hash=sha256:9121d1cf320f7f32490d893340fd33117bda0a0270eb8282dfd52ae5fd3e1af6
Expand Down Expand Up @@ -3809,6 +3906,7 @@ torch @ https://rtp-opensource.oss-cn-hangzhou.aliyuncs.com/rtp_llm/cu129/torch-
# flashinfer-python
# peft
# sentence-transformers
# tilelang
# timm
# torchvision
torchvision @ https://rtp-opensource.oss-cn-hangzhou.aliyuncs.com/rtp_llm/cu129/torchvision-0.23.0%2Bcu129-cp310-cp310-manylinux_2_28_x86_64.whl \
Expand All @@ -3827,6 +3925,7 @@ tqdm==4.67.1 \
# openai
# peft
# sentence-transformers
# tilelang
# transformers
transformers==4.51.2 \
--hash=sha256:5cb8259098b75ff4b5dd04533a318f7c4750d5307d9617e6d0593526432c404d \
Expand Down Expand Up @@ -3869,6 +3968,7 @@ typing-extensions==4.15.0 \
# openai
# pydantic
# pydantic-core
# tilelang
# torch
# uvicorn
tzdata==2025.2 \
Expand Down
1 change: 1 addition & 0 deletions open_source/deps/requirements_torch_gpu_cuda12_9.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ mysql-connector-python
tensorrt==10.3.0
tensorrt-cu12-bindings==10.3.0
tabulate==0.9.0
tilelang==0.1.6
10 changes: 8 additions & 2 deletions rtp_llm/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ flashinfer_with_cache = [
"flashinfer-jit-cache"
]

flashmla = [
"fast-hadamard-transform",
"flash-mla",
"tilelang",
]

deep = [
"deep_gemm",
"deep_ep"
Expand Down Expand Up @@ -118,7 +124,7 @@ requirement([
"aiter",
"fastsafetensors",
"pybind11_stubgen",
] + tensorrt + deep + flashinfer_with_cache)
] + tensorrt + deep + flashinfer_with_cache + flashmla)

filegroup(
name = "cutlass_config",
Expand Down Expand Up @@ -242,7 +248,7 @@ py_library(
"@//:cuda_pre_12_9": tensorrt,
"//conditions:default": []
}) + select({
"@//:using_cuda12_9_x86": flashinfer_with_cache,
"@//:using_cuda12_9_x86": flashinfer_with_cache + flashmla,
"@//:using_cuda12_arm": flashinfer_with_cache,
"@//:cuda_pre_12_9": flashinfer,
"//conditions:default": []
Expand Down
2 changes: 1 addition & 1 deletion rtp_llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .ops import *

# check triton version
# if triton.__version__ < "3.5":
# if triton.__version__ < "3.4":
# enable_compile_monitor()


Expand Down
5 changes: 5 additions & 0 deletions rtp_llm/cpp/cache/BlockPoolConfigHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class BlockPoolConfigHelper {

config.memory_layouts.push_back(main_layout);

// Create MTP sub-model layouts
for (size_t i = 0; i < cache_config.mtp_sub_configs.size(); ++i) {
const auto& mtp_sub_config = cache_config.mtp_sub_configs[i];
RTP_LLM_CHECK_WITH_INFO(mtp_sub_config != nullptr, "mtp_sub_configs[%zu] is null", i);
Expand Down Expand Up @@ -124,6 +125,10 @@ class BlockPoolConfigHelper {
cfg.dtype = cache_config.dtype;
cfg.local_head_num_kv = spec->local_head_num_kv;
cfg.enable_hybrid_attention = enable_hybrid_attention;
// Scale 3D layout for MLA and indexer; KV 3D only for MLA (concat_and_cache_mla)
cfg.is_mla = cache_config.use_mla || cache_config.is_sparse;
cfg.use_mla = cache_config.use_mla;
cfg.seq_size_per_block = static_cast<size_t>(cache_config.seq_size_per_block);

cfg.kv_block_pool_size_bytes =
static_cast<size_t>(layer_num) * static_cast<size_t>(cfg.block_num) * cfg.kv_block_stride_bytes;
Expand Down
3 changes: 2 additions & 1 deletion rtp_llm/cpp/cache/CacheConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ struct CacheConfig {
rtp_llm::DataType dtype;
uint32_t layer_num; // the number of main model layers
uint32_t layer_all_num; // the number of all layers including mtp modules
bool use_mla = false;
bool use_mla = false;
bool is_sparse = false;

// Block configuration
uint32_t block_num;
Expand Down
11 changes: 10 additions & 1 deletion rtp_llm/cpp/cache/MLAKVCacheSpec.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,16 @@ struct MLAKVCacheSpec: public KVCacheSpec {
}

size_t block_size() const override {
    // Per-token cache footprint in bytes, multiplied by tokens per block.
    // FP8 cache entries use a packed layout: the NoPE part is quantized to
    // float8_e4m3 (1 byte per value) with one float32 scale (4 bytes) per
    // group of 128 values, while the RoPE part stays unquantized as
    // bfloat16 (2 bytes per value) to preserve accuracy.
    const bool fp8_cache = dtype == DataType::TYPE_FP8_E4M3 || dtype == DataType::TYPE_FP8_E8M0;
    auto per_token = fp8_cache ?
        local_head_num_kv * (kv_lora_rank + kv_lora_rank / 128 * 4 + rope_head_dim * 2) :
        local_head_num_kv * (kv_lora_rank + rope_head_dim);
    return per_token * seq_size_per_block;
}
size_t k_block_size() const override {
return local_head_num_kv * kv_lora_rank * seq_size_per_block;
Expand Down
3 changes: 2 additions & 1 deletion rtp_llm/cpp/cache/MemoryLayoutConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ struct MemoryLayoutConfig {
size_t k_scale_stride_bytes = 0;
size_t v_scale_stride_bytes = 0;

bool is_mla = false;
bool is_mla = false; // true for scale 3D layout (MLA or indexer)
bool use_mla = false; // true for KV 3D layout (concat_and_cache_mla path only)
// TODO(xinfei.sxf) rm head info
size_t local_head_num_kv = 0;
size_t seq_size_per_block = 0;
Expand Down
Loading
Loading