@@ -48,26 +48,44 @@ void MemoryLayoutStrategy::processKVTensor(torch::Tensor& kv_cache_tensor) {
                            .dtype(dataTypeToTorchType(data_type_))
                            .device(kv_cache_tensor.device())
                            .requires_grad(false);
-    const int64_t kv_total_bytes  = static_cast<int64_t>(kv_cache_tensor.nbytes());
-    const int64_t kv_typed_numel  = static_cast<int64_t>(static_cast<size_t>(kv_total_bytes) / kv_elem_size);
-    torch::Tensor kv_cache_typed  = torch::from_blob(kv_cache_tensor.data_ptr(), {kv_typed_numel}, kv_options);
-    torch::Tensor reshaped_tensor = kv_cache_typed.reshape({static_cast<int64_t>(config_.layer_num),
-                                                            static_cast<int64_t>(config_.block_num),
-                                                            static_cast<int64_t>(kv_block_stride_elems)});
-
-    clearKVTensor(reshaped_tensor);
+    const int64_t kv_total_bytes = static_cast<int64_t>(kv_cache_tensor.nbytes());
+    const int64_t kv_typed_numel = static_cast<int64_t>(static_cast<size_t>(kv_total_bytes) / kv_elem_size);
+    torch::Tensor kv_cache_typed = torch::from_blob(kv_cache_tensor.data_ptr(), {kv_typed_numel}, kv_options);

     layer_kv_tensors_.clear();
     layer_kv_tensors_.reserve(config_.layer_num);

-    for (uint32_t layer_id = 0; layer_id < config_.layer_num; ++layer_id) {
-        torch::Tensor layer_tensor = reshaped_tensor[layer_id];
-        layer_kv_tensors_.push_back(layer_tensor);
-
-        RTP_LLM_LOG_DEBUG("Layer %d tensor shape: [%s], elements: %ld",
-                          layer_id,
-                          torch::str(layer_tensor.sizes()).c_str(),
-                          layer_tensor.numel());
+    if (config_.use_mla && config_.seq_size_per_block > 0) {
+        // MLA: concat_and_cache_mla expects [num_blocks, block_size, stride] per layer
+        RTP_LLM_CHECK_WITH_INFO(kv_block_stride_elems % config_.seq_size_per_block == 0,
+                                "kv_block_stride_elems=%zu must be divisible by seq_size_per_block=%zu for MLA",
+                                kv_block_stride_elems,
+                                config_.seq_size_per_block);
+        const size_t  stride_elems    = kv_block_stride_elems / config_.seq_size_per_block;
+        torch::Tensor reshaped_tensor = kv_cache_typed.reshape({static_cast<int64_t>(config_.layer_num),
+                                                                static_cast<int64_t>(config_.block_num),
+                                                                static_cast<int64_t>(config_.seq_size_per_block),
+                                                                static_cast<int64_t>(stride_elems)});
+        clearKVTensor(reshaped_tensor);
+        for (uint32_t layer_id = 0; layer_id < config_.layer_num; ++layer_id) {
+            layer_kv_tensors_.push_back(reshaped_tensor[layer_id]);
+            RTP_LLM_LOG_DEBUG("Layer %d KV tensor shape: [%s] (MLA 3D)",
+                              layer_id,
+                              torch::str(layer_kv_tensors_[layer_id].sizes()).c_str());
+        }
+    } else {
+        // MHA: [layer_num, block_num, kv_block_stride_elems], per layer 2D
+        torch::Tensor reshaped_tensor = kv_cache_typed.reshape({static_cast<int64_t>(config_.layer_num),
+                                                                static_cast<int64_t>(config_.block_num),
+                                                                static_cast<int64_t>(kv_block_stride_elems)});
+        clearKVTensor(reshaped_tensor);
+        for (uint32_t layer_id = 0; layer_id < config_.layer_num; ++layer_id) {
+            layer_kv_tensors_.push_back(reshaped_tensor[layer_id]);
+            RTP_LLM_LOG_DEBUG("Layer %d tensor shape: [%s], elements: %ld",
+                              layer_id,
+                              torch::str(layer_kv_tensors_[layer_id].sizes()).c_str(),
+                              layer_kv_tensors_[layer_id].numel());
+        }
     }
 }

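For context on what the MLA branch above buys: the KV pool is one flat device allocation, and `torch::from_blob` plus `reshape` only reinterpret it, so each per-layer slice aliases the pool with zero copies. Below is a minimal standalone sketch of that view logic; the dimensions (`layer_num`, `block_num`, `seq_size_per_block`, `stride_elems`) are invented for illustration and are not the project's real config values.

// Minimal sketch of the MLA 4-D view over a flat KV pool. All sizes here are
// hypothetical; only the from_blob + reshape technique mirrors the diff above.
#include <torch/torch.h>
#include <cassert>

int main() {
    const int64_t layer_num          = 2;   // hypothetical
    const int64_t block_num          = 4;   // hypothetical
    const int64_t seq_size_per_block = 16;  // tokens per block, hypothetical
    const int64_t stride_elems       = 64;  // compressed KV elems per token, hypothetical
    const int64_t kv_block_stride_elems = seq_size_per_block * stride_elems;

    // Stand-in for the raw KV pool: one flat, contiguous allocation.
    torch::Tensor pool = torch::empty({layer_num * block_num * kv_block_stride_elems},
                                      torch::TensorOptions().dtype(torch::kHalf));

    // MLA path: 4-D view [layer, block, token, stride]; no copy is made.
    assert(kv_block_stride_elems % seq_size_per_block == 0);
    torch::Tensor mla = torch::from_blob(pool.data_ptr(), {pool.numel()}, pool.options())
                            .reshape({layer_num, block_num, seq_size_per_block, stride_elems});
    mla.zero_();

    // Each layer slice shares storage with the pool, so writes through
    // layer_view land in the original allocation.
    torch::Tensor layer_view = mla[0];  // [block_num, seq_size_per_block, stride_elems]
    assert(layer_view.data_ptr() == pool.data_ptr());
    return 0;
}

The divisibility check in the diff plays the same role as the assert here: if `kv_block_stride_elems` were not a multiple of `seq_size_per_block`, the 4-D reshape could not preserve the total element count and would throw.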
@@ -76,43 +94,74 @@ bool MemoryLayoutStrategy::processScaleTensor(torch::Tensor& kv_scale_tensor) {
         return true;
     }

-    RTP_LLM_CHECK_WITH_INFO(kv_scale_tensor.numel() > 0, "kv cache scale tensor is empty, cannot split by layers");
-
     RTP_LLM_CHECK_WITH_INFO(kv_scale_tensor.defined() && kv_scale_tensor.numel() > 0,
                             "kv_scale_tensor must be provided when kv scale is enabled");
     RTP_LLM_CHECK_WITH_INFO(
         kv_scale_tensor.dim() == 1, "kv_scale_tensor must be 1-D, got dim=%ld", kv_scale_tensor.dim());
-    RTP_LLM_CHECK_WITH_INFO(static_cast<size_t>(kv_scale_tensor.numel()) % sizeof(float) == 0,
-                            "kv_scale_tensor bytes must be divisible by sizeof(float): bytes=%ld",
-                            kv_scale_tensor.numel());
-    RTP_LLM_CHECK_WITH_INFO(static_cast<size_t>(kv_scale_tensor.numel()) == config_.kv_scale_pool_size_bytes,
-                            "kv_scale_tensor bytes mismatch: got=%ld expect=%zu",
-                            kv_scale_tensor.numel(),
+    RTP_LLM_CHECK_WITH_INFO(static_cast<size_t>(kv_scale_tensor.nbytes()) == config_.kv_scale_pool_size_bytes,
+                            "kv_scale_tensor bytes mismatch: got=%zu expect=%zu",
+                            static_cast<size_t>(kv_scale_tensor.nbytes()),
                             config_.kv_scale_pool_size_bytes);
-    RTP_LLM_CHECK_WITH_INFO(config_.kv_scale_stride_bytes % sizeof(float) == 0,
-                            "kv_scale_stride_bytes must be divisible by sizeof(float): stride_bytes=%zu",
-                            config_.kv_scale_stride_bytes);
-
-    const size_t scale_stride_elems = config_.kv_scale_stride_bytes / sizeof(float);
-    auto         scale_options =
-        torch::TensorOptions().dtype(torch::kFloat32).device(kv_scale_tensor.device()).requires_grad(false);
-    const int64_t scale_total_bytes = static_cast<int64_t>(kv_scale_tensor.nbytes());
-    const int64_t scale_typed_numel = static_cast<int64_t>(static_cast<size_t>(scale_total_bytes) / sizeof(float));
-    torch::Tensor kv_scale_typed = torch::from_blob(kv_scale_tensor.data_ptr(), {scale_typed_numel}, scale_options);
-    torch::Tensor reshaped_scale_tensor = kv_scale_typed.reshape({static_cast<int64_t>(config_.layer_num),
-                                                                  static_cast<int64_t>(config_.block_num),
-                                                                  static_cast<int64_t>(scale_stride_elems)});
-    clearScaleTensor(reshaped_scale_tensor);
-
-    layer_kv_scale_tensors_.clear();
-    layer_kv_scale_tensors_.reserve(config_.layer_num);
-    for (uint32_t layer_id = 0; layer_id < config_.layer_num; ++layer_id) {
-        layer_kv_scale_tensors_.push_back(reshaped_scale_tensor[layer_id]);
-
-        RTP_LLM_LOG_DEBUG("Layer %d scale tensor shape: [%s], elements: %ld",
-                          layer_id,
-                          torch::str(layer_kv_scale_tensors_[layer_id].sizes()).c_str(),
-                          layer_kv_scale_tensors_[layer_id].numel());
+
+    if (config_.is_mla) {
+        // MLA: scale is byte-packed (UINT8), shape [layer_num, block_num, seq_size_per_block, bytes_per_token]
+        RTP_LLM_CHECK_WITH_INFO(config_.seq_size_per_block > 0, "seq_size_per_block must be > 0 for MLA scale");
+        RTP_LLM_CHECK_WITH_INFO(config_.kv_scale_stride_bytes % config_.seq_size_per_block == 0,
+                                "kv_scale_stride_bytes=%zu must be divisible by seq_size_per_block=%zu",
+                                config_.kv_scale_stride_bytes,
+                                config_.seq_size_per_block);
+
+        const size_t scale_bytes_per_token = config_.kv_scale_stride_bytes / config_.seq_size_per_block;
+        auto         scale_options =
+            torch::TensorOptions().dtype(torch::kUInt8).device(kv_scale_tensor.device()).requires_grad(false);
+        torch::Tensor kv_scale_typed = torch::from_blob(
+            kv_scale_tensor.data_ptr(), {static_cast<int64_t>(config_.kv_scale_pool_size_bytes)}, scale_options);
+        torch::Tensor reshaped_scale_tensor = kv_scale_typed.reshape({static_cast<int64_t>(config_.layer_num),
+                                                                      static_cast<int64_t>(config_.block_num),
+                                                                      static_cast<int64_t>(config_.seq_size_per_block),
+                                                                      static_cast<int64_t>(scale_bytes_per_token)});
+        reshaped_scale_tensor.fill_(0);
+
+        layer_kv_scale_tensors_.clear();
+        layer_kv_scale_tensors_.reserve(config_.layer_num);
+        for (uint32_t layer_id = 0; layer_id < config_.layer_num; ++layer_id) {
+            layer_kv_scale_tensors_.push_back(reshaped_scale_tensor[layer_id]);
+
+            RTP_LLM_LOG_DEBUG("Layer %d scale tensor shape: [%s], elements: %ld (MLA)",
+                              layer_id,
+                              torch::str(layer_kv_scale_tensors_[layer_id].sizes()).c_str(),
+                              layer_kv_scale_tensors_[layer_id].numel());
+        }
+    } else {
+        // MHA: scale is FP32, shape [layer_num, block_num, scale_stride_elems] for kernel/model
+        RTP_LLM_CHECK_WITH_INFO(static_cast<size_t>(kv_scale_tensor.numel()) % sizeof(float) == 0,
+                                "kv_scale_tensor bytes must be divisible by sizeof(float): bytes=%ld",
+                                kv_scale_tensor.numel());
+        RTP_LLM_CHECK_WITH_INFO(config_.kv_scale_stride_bytes % sizeof(float) == 0,
+                                "kv_scale_stride_bytes must be divisible by sizeof(float): stride_bytes=%zu",
+                                config_.kv_scale_stride_bytes);
+
+        const size_t scale_stride_elems = config_.kv_scale_stride_bytes / sizeof(float);
+        auto         scale_options =
+            torch::TensorOptions().dtype(torch::kFloat32).device(kv_scale_tensor.device()).requires_grad(false);
+        const int64_t scale_total_bytes = static_cast<int64_t>(kv_scale_tensor.nbytes());
+        const int64_t scale_typed_numel = static_cast<int64_t>(static_cast<size_t>(scale_total_bytes) / sizeof(float));
+        torch::Tensor kv_scale_typed = torch::from_blob(kv_scale_tensor.data_ptr(), {scale_typed_numel}, scale_options);
+        torch::Tensor reshaped_scale_tensor = kv_scale_typed.reshape({static_cast<int64_t>(config_.layer_num),
+                                                                      static_cast<int64_t>(config_.block_num),
+                                                                      static_cast<int64_t>(scale_stride_elems)});
+        clearScaleTensor(reshaped_scale_tensor);
+
+        layer_kv_scale_tensors_.clear();
+        layer_kv_scale_tensors_.reserve(config_.layer_num);
+        for (uint32_t layer_id = 0; layer_id < config_.layer_num; ++layer_id) {
+            layer_kv_scale_tensors_.push_back(reshaped_scale_tensor[layer_id]);
+
+            RTP_LLM_LOG_DEBUG("Layer %d scale tensor shape: [%s], elements: %ld",
+                              layer_id,
+                              torch::str(layer_kv_scale_tensors_[layer_id].sizes()).c_str(),
+                              layer_kv_scale_tensors_[layer_id].numel());
+        }
     }

     return true;
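The scale path mirrors the KV path but differs in element type (note the diff gates the KV branch on `config_.use_mla` and the scale branch on `config_.is_mla`): MLA treats the scale pool as raw bytes because entries are packed per token, while MHA reinterprets the same kind of flat pool as fp32 words, one per stride slot. Below is a minimal sketch of the two views with invented sizes; `pool_bytes`, `bytes_per_token`, and the other names are illustrative, not real config fields.

// Minimal sketch contrasting the byte-packed (MLA) and fp32 (MHA) scale views
// over one flat pool. Sizes are hypothetical; only the dtype choice and the
// from_blob + reshape technique follow the diff above.
#include <torch/torch.h>

int main() {
    const int64_t layer_num = 2, block_num = 4, seq_size_per_block = 16;  // hypothetical
    const int64_t bytes_per_token = 4;                                    // hypothetical
    const int64_t pool_bytes = layer_num * block_num * seq_size_per_block * bytes_per_token;

    // Stand-in for the raw scale pool.
    torch::Tensor raw = torch::zeros({pool_bytes}, torch::kUInt8);

    // MLA: byte-packed 4-D view, zeroed directly with fill_(0).
    torch::Tensor mla_scales =
        torch::from_blob(raw.data_ptr(), {pool_bytes}, torch::TensorOptions().dtype(torch::kUInt8))
            .reshape({layer_num, block_num, seq_size_per_block, bytes_per_token});
    mla_scales.fill_(0);

    // MHA: the same bytes reinterpreted as fp32, one scale word per stride slot.
    const int64_t stride_elems = seq_size_per_block * bytes_per_token / static_cast<int64_t>(sizeof(float));
    torch::Tensor mha_scales =
        torch::from_blob(raw.data_ptr(), {pool_bytes / static_cast<int64_t>(sizeof(float))},
                         torch::TensorOptions().dtype(torch::kFloat32))
            .reshape({layer_num, block_num, stride_elems});
    (void)mha_scales;  // silence unused-variable warnings in this sketch
    return 0;
}

Either way the per-layer slices handed to `layer_kv_scale_tensors_` alias the pool, which is why zeroing the reshaped view is enough to initialize the whole allocation.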