Skip to content

Commit 6551053

Browse files
committed
fix: rebase main KV cache
1 parent 413e4f9 commit 6551053

6 files changed

Lines changed: 124 additions & 53 deletions

File tree

rtp_llm/cpp/cache/BlockPoolConfigHelper.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,10 @@ class BlockPoolConfigHelper {
125125
cfg.dtype = cache_config.dtype;
126126
cfg.local_head_num_kv = spec->local_head_num_kv;
127127
cfg.enable_hybrid_attention = enable_hybrid_attention;
128+
// Scale 3D layout for MLA and indexer; KV 3D only for MLA (concat_and_cache_mla)
129+
cfg.is_mla = cache_config.use_mla || cache_config.is_sparse;
130+
cfg.use_mla = cache_config.use_mla;
131+
cfg.seq_size_per_block = static_cast<size_t>(cache_config.seq_size_per_block);
128132

129133
cfg.kv_block_pool_size_bytes =
130134
static_cast<size_t>(layer_num) * static_cast<size_t>(cfg.block_num) * cfg.kv_block_stride_bytes;

rtp_llm/cpp/cache/CacheConfig.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ struct CacheConfig {
2626
rtp_llm::DataType dtype;
2727
uint32_t layer_num; // the number of main model layers
2828
uint32_t layer_all_num; // the number of all layers including mtp modules
29-
bool use_mla = false;
29+
bool use_mla = false;
30+
bool is_sparse = false;
3031

3132
// Block configuration
3233
uint32_t block_num;

rtp_llm/cpp/cache/MLAKVCacheSpec.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,16 @@ struct MLAKVCacheSpec: public KVCacheSpec {
2626
}
2727

2828
size_t block_size() const override {
29-
return local_head_num_kv * (kv_lora_rank + rope_head_dim) * seq_size_per_block;
29+
auto is_fp8 = (dtype == DataType::TYPE_FP8_E4M3 || dtype == DataType::TYPE_FP8_E8M0);
30+
auto single_size = local_head_num_kv * (kv_lora_rank + rope_head_dim);
31+
if (is_fp8) {
32+
// First 512 bytes: the "quantized NoPE" part, containing 512 float8_e4m3 values.
33+
// Next 16 bytes: scale factors, containing 4 float32 values; each float32 is the scale
34+
// for one consecutive group of 128 float8_e4m3 values (first float32 for the first 128, and so on).
35+
// Last 128 bytes: the "RoPE" part, containing 64 bfloat16 values; not quantized, for accuracy.
36+
single_size = local_head_num_kv * (kv_lora_rank + kv_lora_rank / 128 * 4 + rope_head_dim * 2);
37+
}
38+
return single_size * seq_size_per_block;
3039
}
3140
size_t k_block_size() const override {
3241
return local_head_num_kv * kv_lora_rank * seq_size_per_block;

rtp_llm/cpp/cache/MemoryLayoutConfig.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ struct MemoryLayoutConfig {
3131
size_t k_scale_stride_bytes = 0;
3232
size_t v_scale_stride_bytes = 0;
3333

34-
bool is_mla = false;
34+
bool is_mla = false; // true for scale 3D layout (MLA or indexer)
35+
bool use_mla = false; // true for KV 3D layout (concat_and_cache_mla path only)
3536
// TODO(xinfei.sxf) rm head info
3637
size_t local_head_num_kv = 0;
3738
size_t seq_size_per_block = 0;

rtp_llm/cpp/cache/MemoryLayoutStrategy.cc

Lines changed: 97 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -48,26 +48,44 @@ void MemoryLayoutStrategy::processKVTensor(torch::Tensor& kv_cache_tensor) {
4848
.dtype(dataTypeToTorchType(data_type_))
4949
.device(kv_cache_tensor.device())
5050
.requires_grad(false);
51-
const int64_t kv_total_bytes = static_cast<int64_t>(kv_cache_tensor.nbytes());
52-
const int64_t kv_typed_numel = static_cast<int64_t>(static_cast<size_t>(kv_total_bytes) / kv_elem_size);
53-
torch::Tensor kv_cache_typed = torch::from_blob(kv_cache_tensor.data_ptr(), {kv_typed_numel}, kv_options);
54-
torch::Tensor reshaped_tensor = kv_cache_typed.reshape({static_cast<int64_t>(config_.layer_num),
55-
static_cast<int64_t>(config_.block_num),
56-
static_cast<int64_t>(kv_block_stride_elems)});
57-
58-
clearKVTensor(reshaped_tensor);
51+
const int64_t kv_total_bytes = static_cast<int64_t>(kv_cache_tensor.nbytes());
52+
const int64_t kv_typed_numel = static_cast<int64_t>(static_cast<size_t>(kv_total_bytes) / kv_elem_size);
53+
torch::Tensor kv_cache_typed = torch::from_blob(kv_cache_tensor.data_ptr(), {kv_typed_numel}, kv_options);
5954

6055
layer_kv_tensors_.clear();
6156
layer_kv_tensors_.reserve(config_.layer_num);
6257

63-
for (uint32_t layer_id = 0; layer_id < config_.layer_num; ++layer_id) {
64-
torch::Tensor layer_tensor = reshaped_tensor[layer_id];
65-
layer_kv_tensors_.push_back(layer_tensor);
66-
67-
RTP_LLM_LOG_DEBUG("Layer %d tensor shape: [%s], elements: %ld",
68-
layer_id,
69-
torch::str(layer_tensor.sizes()).c_str(),
70-
layer_tensor.numel());
58+
if (config_.use_mla && config_.seq_size_per_block > 0) {
59+
// MLA: concat_and_cache_mla expects [num_blocks, block_size, stride] per layer
60+
RTP_LLM_CHECK_WITH_INFO(kv_block_stride_elems % config_.seq_size_per_block == 0,
61+
"kv_block_stride_elems=%zu must be divisible by seq_size_per_block=%zu for MLA",
62+
kv_block_stride_elems,
63+
config_.seq_size_per_block);
64+
const size_t stride_elems = kv_block_stride_elems / config_.seq_size_per_block;
65+
torch::Tensor reshaped_tensor = kv_cache_typed.reshape({static_cast<int64_t>(config_.layer_num),
66+
static_cast<int64_t>(config_.block_num),
67+
static_cast<int64_t>(config_.seq_size_per_block),
68+
static_cast<int64_t>(stride_elems)});
69+
clearKVTensor(reshaped_tensor);
70+
for (uint32_t layer_id = 0; layer_id < config_.layer_num; ++layer_id) {
71+
layer_kv_tensors_.push_back(reshaped_tensor[layer_id]);
72+
RTP_LLM_LOG_DEBUG("Layer %d KV tensor shape: [%s] (MLA 3D)",
73+
layer_id,
74+
torch::str(layer_kv_tensors_[layer_id].sizes()).c_str());
75+
}
76+
} else {
77+
// MHA: [layer_num, block_num, kv_block_stride_elems], per layer 2D
78+
torch::Tensor reshaped_tensor = kv_cache_typed.reshape({static_cast<int64_t>(config_.layer_num),
79+
static_cast<int64_t>(config_.block_num),
80+
static_cast<int64_t>(kv_block_stride_elems)});
81+
clearKVTensor(reshaped_tensor);
82+
for (uint32_t layer_id = 0; layer_id < config_.layer_num; ++layer_id) {
83+
layer_kv_tensors_.push_back(reshaped_tensor[layer_id]);
84+
RTP_LLM_LOG_DEBUG("Layer %d tensor shape: [%s], elements: %ld",
85+
layer_id,
86+
torch::str(layer_kv_tensors_[layer_id].sizes()).c_str(),
87+
layer_kv_tensors_[layer_id].numel());
88+
}
7189
}
7290
}
7391

@@ -76,43 +94,74 @@ bool MemoryLayoutStrategy::processScaleTensor(torch::Tensor& kv_scale_tensor) {
7694
return true;
7795
}
7896

79-
RTP_LLM_CHECK_WITH_INFO(kv_scale_tensor.numel() > 0, "kv cache scale tensor is empty, cannot split by layers");
80-
8197
RTP_LLM_CHECK_WITH_INFO(kv_scale_tensor.defined() && kv_scale_tensor.numel() > 0,
8298
"kv_scale_tensor must be provided when kv scale is enabled");
8399
RTP_LLM_CHECK_WITH_INFO(
84100
kv_scale_tensor.dim() == 1, "kv_scale_tensor must be 1-D, got dim=%ld", kv_scale_tensor.dim());
85-
RTP_LLM_CHECK_WITH_INFO(static_cast<size_t>(kv_scale_tensor.numel()) % sizeof(float) == 0,
86-
"kv_scale_tensor bytes must be divisible by sizeof(float): bytes=%ld",
87-
kv_scale_tensor.numel());
88-
RTP_LLM_CHECK_WITH_INFO(static_cast<size_t>(kv_scale_tensor.numel()) == config_.kv_scale_pool_size_bytes,
89-
"kv_scale_tensor bytes mismatch: got=%ld expect=%zu",
90-
kv_scale_tensor.numel(),
101+
RTP_LLM_CHECK_WITH_INFO(static_cast<size_t>(kv_scale_tensor.nbytes()) == config_.kv_scale_pool_size_bytes,
102+
"kv_scale_tensor bytes mismatch: got=%zu expect=%zu",
103+
static_cast<size_t>(kv_scale_tensor.nbytes()),
91104
config_.kv_scale_pool_size_bytes);
92-
RTP_LLM_CHECK_WITH_INFO(config_.kv_scale_stride_bytes % sizeof(float) == 0,
93-
"kv_scale_stride_bytes must be divisible by sizeof(float): stride_bytes=%zu",
94-
config_.kv_scale_stride_bytes);
95-
96-
const size_t scale_stride_elems = config_.kv_scale_stride_bytes / sizeof(float);
97-
auto scale_options =
98-
torch::TensorOptions().dtype(torch::kFloat32).device(kv_scale_tensor.device()).requires_grad(false);
99-
const int64_t scale_total_bytes = static_cast<int64_t>(kv_scale_tensor.nbytes());
100-
const int64_t scale_typed_numel = static_cast<int64_t>(static_cast<size_t>(scale_total_bytes) / sizeof(float));
101-
torch::Tensor kv_scale_typed = torch::from_blob(kv_scale_tensor.data_ptr(), {scale_typed_numel}, scale_options);
102-
torch::Tensor reshaped_scale_tensor = kv_scale_typed.reshape({static_cast<int64_t>(config_.layer_num),
103-
static_cast<int64_t>(config_.block_num),
104-
static_cast<int64_t>(scale_stride_elems)});
105-
clearScaleTensor(reshaped_scale_tensor);
106-
107-
layer_kv_scale_tensors_.clear();
108-
layer_kv_scale_tensors_.reserve(config_.layer_num);
109-
for (uint32_t layer_id = 0; layer_id < config_.layer_num; ++layer_id) {
110-
layer_kv_scale_tensors_.push_back(reshaped_scale_tensor[layer_id]);
111-
112-
RTP_LLM_LOG_DEBUG("Layer %d scale tensor shape: [%s], elements: %ld",
113-
layer_id,
114-
torch::str(layer_kv_scale_tensors_[layer_id].sizes()).c_str(),
115-
layer_kv_scale_tensors_[layer_id].numel());
105+
106+
if (config_.is_mla) {
107+
// MLA: scale is byte-packed (UINT8), shape [layer_num, block_num, seq_size_per_block, bytes_per_token]
108+
RTP_LLM_CHECK_WITH_INFO(config_.seq_size_per_block > 0, "seq_size_per_block must be > 0 for MLA scale");
109+
RTP_LLM_CHECK_WITH_INFO(config_.kv_scale_stride_bytes % config_.seq_size_per_block == 0,
110+
"kv_scale_stride_bytes=%zu must be divisible by seq_size_per_block=%zu",
111+
config_.kv_scale_stride_bytes,
112+
config_.seq_size_per_block);
113+
114+
const size_t scale_bytes_per_token = config_.kv_scale_stride_bytes / config_.seq_size_per_block;
115+
auto scale_options =
116+
torch::TensorOptions().dtype(torch::kUInt8).device(kv_scale_tensor.device()).requires_grad(false);
117+
torch::Tensor kv_scale_typed = torch::from_blob(
118+
kv_scale_tensor.data_ptr(), {static_cast<int64_t>(config_.kv_scale_pool_size_bytes)}, scale_options);
119+
torch::Tensor reshaped_scale_tensor = kv_scale_typed.reshape({static_cast<int64_t>(config_.layer_num),
120+
static_cast<int64_t>(config_.block_num),
121+
static_cast<int64_t>(config_.seq_size_per_block),
122+
static_cast<int64_t>(scale_bytes_per_token)});
123+
reshaped_scale_tensor.fill_(0);
124+
125+
layer_kv_scale_tensors_.clear();
126+
layer_kv_scale_tensors_.reserve(config_.layer_num);
127+
for (uint32_t layer_id = 0; layer_id < config_.layer_num; ++layer_id) {
128+
layer_kv_scale_tensors_.push_back(reshaped_scale_tensor[layer_id]);
129+
130+
RTP_LLM_LOG_DEBUG("Layer %d scale tensor shape: [%s], elements: %ld (MLA)",
131+
layer_id,
132+
torch::str(layer_kv_scale_tensors_[layer_id].sizes()).c_str(),
133+
layer_kv_scale_tensors_[layer_id].numel());
134+
}
135+
} else {
136+
// MHA: scale is FP32, shape [layer_num, block_num, scale_stride_elems] for kernel/model
137+
RTP_LLM_CHECK_WITH_INFO(static_cast<size_t>(kv_scale_tensor.numel()) % sizeof(float) == 0,
138+
"kv_scale_tensor bytes must be divisible by sizeof(float): bytes=%ld",
139+
kv_scale_tensor.numel());
140+
RTP_LLM_CHECK_WITH_INFO(config_.kv_scale_stride_bytes % sizeof(float) == 0,
141+
"kv_scale_stride_bytes must be divisible by sizeof(float): stride_bytes=%zu",
142+
config_.kv_scale_stride_bytes);
143+
144+
const size_t scale_stride_elems = config_.kv_scale_stride_bytes / sizeof(float);
145+
auto scale_options =
146+
torch::TensorOptions().dtype(torch::kFloat32).device(kv_scale_tensor.device()).requires_grad(false);
147+
const int64_t scale_total_bytes = static_cast<int64_t>(kv_scale_tensor.nbytes());
148+
const int64_t scale_typed_numel = static_cast<int64_t>(static_cast<size_t>(scale_total_bytes) / sizeof(float));
149+
torch::Tensor kv_scale_typed = torch::from_blob(kv_scale_tensor.data_ptr(), {scale_typed_numel}, scale_options);
150+
torch::Tensor reshaped_scale_tensor = kv_scale_typed.reshape({static_cast<int64_t>(config_.layer_num),
151+
static_cast<int64_t>(config_.block_num),
152+
static_cast<int64_t>(scale_stride_elems)});
153+
clearScaleTensor(reshaped_scale_tensor);
154+
155+
layer_kv_scale_tensors_.clear();
156+
layer_kv_scale_tensors_.reserve(config_.layer_num);
157+
for (uint32_t layer_id = 0; layer_id < config_.layer_num; ++layer_id) {
158+
layer_kv_scale_tensors_.push_back(reshaped_scale_tensor[layer_id]);
159+
160+
RTP_LLM_LOG_DEBUG("Layer %d scale tensor shape: [%s], elements: %ld",
161+
layer_id,
162+
torch::str(layer_kv_scale_tensors_[layer_id].sizes()).c_str(),
163+
layer_kv_scale_tensors_[layer_id].numel());
164+
}
116165
}
117166

118167
return true;

rtp_llm/cpp/cache/SingleConfigCreator.cc

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,9 @@ CacheConfig SingleConfigCreator::createSingleConfig(const ModelConfig& mod
2626
config.block_num = 0;
2727
config.seq_size_per_block = static_cast<uint32_t>(model_config.attn_config.tokens_per_block);
2828

29-
config.use_mla = model_config.attn_config.use_mla;
30-
config.dtype = dtype;
29+
config.use_mla = model_config.attn_config.use_mla;
30+
config.dtype = dtype;
31+
config.is_sparse = model_config.attn_config.is_sparse;
3132

3233
KVCacheSpecPtr spec;
3334
if (model_config.attn_config.use_mla && model_config.mla_ops_type != rtp_llm::MlaOpsType::MHA) {
@@ -47,6 +48,12 @@ CacheConfig SingleConfigCreator::createSingleConfig(const ModelConfig& mod
4748
config.kv_scale_stride_bytes = config.cache_specs[0]->scale_block_size_bytes();
4849
config.kv_scale_size_bytes = static_cast<size_t>(config.layer_num) * config.kv_scale_stride_bytes;
4950

51+
if (config.is_sparse) {
52+
auto indexer_dim = model_config.attn_config.indexer_head_dim;
53+
config.kv_scale_stride_bytes = (indexer_dim + indexer_dim / 128 * 4) * spec->seq_size_per_block;
54+
config.kv_scale_size_bytes = static_cast<size_t>(config.layer_num) * config.kv_scale_stride_bytes;
55+
}
56+
5057
config.block_size_bytes = config.kv_block_size_bytes + config.kv_scale_size_bytes;
5158
config.group_layer_num = layer_num; // only 1 group for SingleConfig
5259

0 commit comments

Comments
 (0)