@@ -2844,7 +2844,8 @@ std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
 // Function to convert and pack a single component
 DEVICE_INLINE uint32_t
 convertAndPack(float component, float inv_scale, float shift = 0.0) {
-  auto val = (component - shift) * inv_scale;
+  // auto val = (component - shift) * inv_scale;
+  auto val = fmaf(component, inv_scale, -shift * inv_scale);
   val = fmaxf(val, -FP8_E4M3_MAX::value);
   val = fminf(val, FP8_E4M3_MAX::value);
   auto x = __nv_fp8_e4m3(val);
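The fused form above folds the subtract-and-scale into a single fmaf, which computes the same affine map as the line it replaces. Below is a small host-side sketch of that equivalence, compilable on its own; it assumes FP8_E4M3_MAX::value is 448.0f (the E4M3 finite maximum), and the helper names are illustrative, not part of this file.

#include <algorithm>
#include <cmath>
#include <cstdio>

// Assumption: mirrors FP8_E4M3_MAX::value (finite max magnitude of float8 e4m3).
constexpr float kFp8E4m3Max = 448.0f;

// Original formulation: subtract the shift, then scale, then clamp.
float convert_reference(float component, float inv_scale, float shift = 0.0f) {
  float val = (component - shift) * inv_scale;
  return std::min(std::max(val, -kFp8E4m3Max), kFp8E4m3Max);
}

// New formulation: one fused multiply-add, component * inv_scale + (-shift * inv_scale).
float convert_fma(float component, float inv_scale, float shift = 0.0f) {
  float val = std::fmaf(component, inv_scale, -shift * inv_scale);
  return std::min(std::max(val, -kFp8E4m3Max), kFp8E4m3Max);
}

int main() {
  float component = 123.4f, inv_scale = 0.25f, shift = 1.5f;
  std::printf("reference=%f fma=%f\n",
              convert_reference(component, inv_scale, shift),
              convert_fma(component, inv_scale, shift));
  return 0;
}

With shift == 0 (the kernel's usage here) the two forms are bit-identical, since the addend is exactly zero; with a nonzero shift the results can differ slightly because the intermediate products are rounded differently.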
@@ -2876,104 +2877,109 @@ __global__ void quantizeQKVPerHead(
     float* const scale_k,
     float* const scale_v,
     float kv_multiplier = 64.f) {
-  extern __shared__ float
-      shared_max[]; // Shared memory for block-level reduction
   // Launch one warp per token. Each thread handles 4 elements.
-  // warps = B_T * HH
-  auto b_t_hh = blockIdx.x * blockDim.y + threadIdx.y;
+  // warps = B_T
   auto N_KVH = cache_K.size(2);
   auto N_H = XQ_O.size(1);
-
   auto B_T = XQ_O.size(0);
   // TODO: Support N_KVH > 1
-  CUDA_KERNEL_ASSERT(N_KVH == 1);
+  // CUDA_KERNEL_ASSERT(N_KVH == 1);

-  // TODO: Support amax_head of size [B, N_H] for decode case
-  // auto HH = (scale_k == nullptr) ? N_H : N_KVH * 2 + N_H;
   auto HH = N_H + N_KVH * 2;
+  auto maxHH = scale_k ? HH : N_H;

-  auto hh = b_t_hh % HH;
-  auto b_t = b_t_hh / HH;
+  uint2 buffer;

-  // Boundary check
-  if (b_t >= B_T || hh >= HH)
-    return;
-
-  auto seqpos_t = varseq_seqpos[b_t];
-  auto b = varseq_batch ? varseq_batch[b_t] : b_t;
+  // warps_per_block = blockDim.y
+  // warp_id = threadIdx.y
+  // block_id = blockIdx.x

+  // Calculate scaling factor
+  constexpr float min_scaling_factor = 1.0f / (FP8_E4M3_MAX::value * 512.f);
+  int b = 0;
+  int last_b = -1;
   int h = 0;
   float* qparam = nullptr;
-  QKV qkv;
   at::Float8_e4m3fn* dst_row_q = nullptr;
-  if (hh < N_H) {
-    qkv = QKV::Q;
-    h = hh;
-    qparam = scale_q + b * N_KVH + hh / (N_H / N_KVH);
-    dst_row_q = &XQ_O[b_t][h][0];
-  } else if (hh < N_H + N_KVH) {
-    qkv = QKV::K;
-    h = hh - N_H;
-    if (scale_k == nullptr)
-      return;
-    qparam = scale_k + b * N_KVH + h;
-    dst_row_q = &cache_K[b][seqpos_t][h][0];
-  } else {
-    qkv = QKV::V;
-    h = hh - N_H - N_KVH;
-    if (scale_k == nullptr)
-      return;
-    qparam = scale_v + b * N_KVH + h;
-    dst_row_q = &cache_V[b][seqpos_t][h][0];
-  }
-  // Skip quantization if scale is pre-calculated for K/V
-  // as in decode and partial prefill cases
-  bool is_precalculated_qparam_b_t =
-      is_precalculated_qparam ? is_precalculated_qparam[b_t] : true;
-  if (qkv != QKV::Q && is_precalculated_qparam_b_t)
-    return;
-
-  CUDA_KERNEL_ASSERT(uintptr_t(qparam) % 4 == 0);
   float val = 0;
-  if (qkv == QKV::Q) {
-    // assert N_KVH == 1
-    for (int _h = 0; _h < N_H; _h++) {
-      val = fmaxf(val, xqkv_amax_head[b * HH + _h]);
+  float inv_scale = 0;
+
+  uint d = 4 * threadIdx.x;
+
+  auto b_t_start = blockIdx.x * blockDim.y + threadIdx.y;
+  for (int b_t = b_t_start; b_t < B_T; b_t += blockDim.y * gridDim.x) {
+    b = varseq_batch ? varseq_batch[b_t] : b_t;
+    last_b = b_t == 0 ? -1 : varseq_batch[b_t - 1];
+    {
+      // Skip quantization of KV if scale is pre-calculated for K/V
+      // as in decode and partial prefill cases
+      bool is_precalculated_qparam_b_t =
+          is_precalculated_qparam ? is_precalculated_qparam[b_t] : true;
+      if (is_precalculated_qparam_b_t)
+        maxHH = N_H;
+    }
+    val = 0;
+    for (auto hh = 0; hh < N_H; hh++) {
+      val = fmaxf(val, xqkv_amax_head[b * HH + hh]);
+    }
+
+    for (auto hh = 0; hh < maxHH; hh++) {
+      {
+        at::BFloat16* src_row = &xqkv[(b_t * HH + hh + 0) * D_H];
+        buffer = *reinterpret_cast<uint2*>(&src_row[d]);
+        val = (hh < N_H) ? val : xqkv_amax_head[b * HH + hh];
+      }
+
+      {
+        int seqpos_t = varseq_seqpos[b_t];
+        if (hh < N_H) {
+          h = hh;
+          qparam = scale_q + b * N_KVH + hh / (N_H / N_KVH);
+          dst_row_q = &XQ_O[b_t][h][0];
+          val = val * 8;
+        } else if (hh < N_H + N_KVH) {
+          h = hh - N_H;
+
+          qparam = scale_k + b * N_KVH + h;
+          dst_row_q = &cache_K[b][seqpos_t][h][0];
+          val = kv_multiplier * val;
+        } else {
+          h = hh - N_H - N_KVH;
+
+          qparam = scale_v + b * N_KVH + h;
+          dst_row_q = &cache_V[b][seqpos_t][h][0];
+          val = kv_multiplier * val;
+        }
+      }
+      {
+        float scale = 0;
+        val = fminf(val, 12000);
+        scale = fmaxf(val / FP8_E4M3_MAX::value, min_scaling_factor);
+        bool is_first_token = (b_t == 0 || !varseq_batch || last_b != b);
+        if (threadIdx.x == 0 && h == 0 && is_first_token) {
+          *qparam = scale;
+        }
+        inv_scale = 1 / scale;
+      }
+
+      {
+        bfx4 src;
+        fx4 dst;
+        uint32_t x_bits[4];
+        // Convert and pack data
+        // 8 bytes are 4 elements of type bf16
+        *reinterpret_cast<uint2*>(&src) = buffer;
+        dst = bfx4_to_fx4(src);
+        x_bits[0] = convertAndPack(dst.x, inv_scale);
+        x_bits[1] = convertAndPack(dst.y, inv_scale);
+        x_bits[2] = convertAndPack(dst.z, inv_scale);
+        x_bits[3] = convertAndPack(dst.w, inv_scale);
+        uint32_t packed = packComponents(x_bits);
+        // CUDA_KERNEL_ASSERT(uintptr_t(&dst_row_q[d]) % 4 == 0);
+        *reinterpret_cast<uint32_t*>(&dst_row_q[d]) = packed;
+      }
     }
-    val *= 8; // [Experimental] improves accuracy
-  } else {
-    val = kv_multiplier * xqkv_amax_head[b * HH + hh];
   }
-  // Calculate scaling factor
-  constexpr float min_scaling_factor = 1.0f / (FP8_E4M3_MAX::value * 512.f);
-  val = fminf(val, 12000);
-  float scale = fmaxf(val / FP8_E4M3_MAX::value, min_scaling_factor);
-  bool is_first_token =
-      (b_t == 0 || !varseq_batch || varseq_batch[b_t - 1] != b);
-  if (threadIdx.x == 0 && h == 0 && is_first_token) {
-    *qparam = scale;
-  }
-  // Write scaling factor
-  auto inv_scale = 1 / scale;
-  CUDA_KERNEL_ASSERT(uintptr_t(&xqkv[b_t * HH * D_H + hh * D_H]) % 16 == 0);
-  at::BFloat16* src_row = &xqkv[b_t * HH * D_H + hh * D_H];
-
-  // Convert and pack data
-  bfx4 src;
-  fx4 dst;
-  uint32_t x_bits[4];
-  auto d = threadIdx.x * 4;
-  CUDA_KERNEL_ASSERT(uintptr_t(&src_row[d]) % 8 == 0);
-  // 8 bytes are 4 elements of type bf16
-  *reinterpret_cast<uint2*>(&src) = *reinterpret_cast<uint2*>(&src_row[d]);
-  dst = bfx4_to_fx4(src);
-  x_bits[0] = convertAndPack(dst.x, inv_scale);
-  x_bits[1] = convertAndPack(dst.y, inv_scale);
-  x_bits[2] = convertAndPack(dst.z, inv_scale);
-  x_bits[3] = convertAndPack(dst.w, inv_scale);
-  uint32_t packed = packComponents(x_bits);
-  CUDA_KERNEL_ASSERT(uintptr_t(&dst_row_q[d]) % 4 == 0);
-  *reinterpret_cast<uint32_t*>(&dst_row_q[d]) = packed;
 }

 at::Tensor quantize_qkv_per_head(
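For reference, the per-head scale arithmetic inside the loop above reduces to: boost the row amax (x8 for Q, kv_multiplier for K/V), cap it at 12000, divide by the FP8 maximum, and floor the result at min_scaling_factor. The host-side sketch below mirrors that arithmetic under the assumption that FP8_E4M3_MAX::value is 448.0f; the function and variable names are illustrative only.

#include <algorithm>
#include <cstdio>

// Assumptions: kFp8E4m3Max mirrors FP8_E4M3_MAX::value; the other constants
// (512 denominator, 12000 cap, x8 and x64 multipliers) come from the kernel.
constexpr float kFp8E4m3Max = 448.0f;
constexpr float kMinScalingFactor = 1.0f / (kFp8E4m3Max * 512.0f);

// amax: row-wise max magnitude for the head; multiplier: 8 for Q, kv_multiplier for K/V.
float compute_scale(float amax, float multiplier) {
  float val = std::min(amax * multiplier, 12000.0f);      // cap the boosted amax
  return std::max(val / kFp8E4m3Max, kMinScalingFactor);  // keep the scale above the floor
}

int main() {
  float scale_q = compute_scale(3.7f, 8.0f);    // Q path: amax over all N_H heads, x8 boost
  float scale_k = compute_scale(0.02f, 64.0f);  // K/V path: per-head amax, default kv_multiplier
  std::printf("scale_q=%g inv_scale_q=%g\n", scale_q, 1.0f / scale_q);
  std::printf("scale_k=%g inv_scale_k=%g\n", scale_k, 1.0f / scale_k);
  return 0;
}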
@@ -2988,21 +2994,19 @@ at::Tensor quantize_qkv_per_head(
     int64_t B, // Batch size
     std::optional<at::Tensor> qparam_k,
     std::optional<at::Tensor> qparam_v) {
-  auto B_T = XQ_O.size(0);
   auto N_KVH_L = cache_K.size(2);
-  auto N_H_L = XQ_O.size(1);
-  auto HH = N_H_L + N_KVH_L * 2;
+
   float* qparam_k_ptr = nullptr;
   float* qparam_v_ptr = nullptr;
   if (qparam_k.has_value()) {
     // prefill case
     qparam_k_ptr = qparam_k.value().data_ptr<float>();
     qparam_v_ptr = qparam_v.value().data_ptr<float>();
   }
-  auto num_warps = B_T * HH;
-  dim3 block_size(kThreadsPerWarp, kWarpsPerBlock);
-  dim3 grid_size(cuda_calc_xblock_count(num_warps, kWarpsPerBlock));

+  constexpr int32_t kMaxBlocks = 512;
+  dim3 block_size(kThreadsPerWarp, kWarpsPerBlock);
+  dim3 grid_size(kMaxBlocks);
   auto scale_q = at::zeros({B, N_KVH_L}, XQ_O.options().dtype(at::kFloat));
   float* const scale_q_ptr = scale_q.data_ptr<float>();
   // Launch the kernel
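The fixed grid_size(kMaxBlocks) works because the kernel now walks tokens with a grid-stride loop: each warp starts at blockIdx.x * blockDim.y + threadIdx.y and advances by blockDim.y * gridDim.x, so any B_T is covered regardless of the block count. Here is a minimal CUDA sketch of that launch/loop pairing; the names demo_kernel and process_token and the literal block dimensions are illustrative stand-ins for the kernel and for kThreadsPerWarp / kWarpsPerBlock, not part of the source.

#include <cstdio>
#include <cuda_runtime.h>

__device__ void process_token(int b_t, int lane) {
  // Stand-in for the real per-token work (each lane would handle 4 elements).
  (void)b_t;
  (void)lane;
}

__global__ void demo_kernel(int B_T) {
  int lane = threadIdx.x;                                  // lane within the warp
  int warp_start = blockIdx.x * blockDim.y + threadIdx.y;  // first token for this warp
  int warp_stride = blockDim.y * gridDim.x;                // total warps in the grid
  for (int b_t = warp_start; b_t < B_T; b_t += warp_stride) {
    process_token(b_t, lane);
  }
}

int main() {
  dim3 block_size(32, 8);  // 32 lanes per warp, 8 warps per block
  dim3 grid_size(512);     // fixed block count, as with kMaxBlocks above
  demo_kernel<<<grid_size, block_size>>>(100000);
  cudaDeviceSynchronize();
  std::printf("done\n");
  return 0;
}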