
Commit aa2fe3d

Aya-ZIbra authored and facebook-github-bot committed
optimization of perKVhead quantization (#4161)
Summary:
Pull Request resolved: #4161

X-link: facebookresearch/FBGEMM#1241

y-sq noticed that for a prefill chunk of 64k, the improvement in attention kernel runtime for local layers is cancelled out by roughly 0.4 ms of overhead from the quantization kernel.
https://docs.google.com/document/d/193GL7o5GMlpVlwEDVxqoDB6O85zDuS8A5-PUOSsZc1s/edit?tab=t.0#bookmark=id.zh92spta1uxw

Before (BS = 1, Seqlen = 64k):
-----------------------  -----------  ------------
Elapsed Cycles           cycle        530,268
Memory Throughput        %            17.24
Duration                 us           392.93
-----------------------  -----------  ------------

After:
-----------------------  -----------  ------------
DRAM Frequency           Ghz          1.59
SM Frequency             Ghz          1.34
Elapsed Cycles           cycle        192,884
Memory Throughput        %            46.01
DRAM Throughput          %            46.01
Duration                 us           143.23
L1/TEX Cache Throughput  %            15.15
L2 Cache Throughput      %            39.31
SM Active Cycles         cycle        181,953.16
Compute (SM) Throughput  %            71.92
-----------------------  -----------  ------------

Reviewed By: y-sq

Differential Revision: D74924275

fbshipit-source-id: 94d3d0e6e2eac1106f31c63de4115962a7502894
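For context, the per-head FP8 (E4M3) scale that this kernel produces reduces to a few lines of scalar math. A minimal sketch follows, with illustrative names; the constants (448 for the E4M3 max, the 1/(448*512) floor, and the 12000 clamp) are the ones visible in the diff below.

```cuda
// Scalar sketch of the per-head scale math the kernel applies; names are
// illustrative, constants are taken from the diff below.
#include <algorithm>

constexpr float kE4M3Max = 448.0f;                      // FP8_E4M3_MAX::value
constexpr float kMinScale = 1.0f / (kE4M3Max * 512.0f); // min_scaling_factor

inline float head_scale_sketch(float amax, float multiplier) {
  // The row-wise amax is boosted (x8 for Q, kv_multiplier for K/V), clamped,
  // and turned into the scale whose reciprocal multiplies every element
  // before the E4M3 cast.
  float val = std::min(amax * multiplier, 12000.0f);
  return std::max(val / kE4M3Max, kMinScale);
}
```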
1 parent 0e8d00b commit aa2fe3d

File tree

1 file changed: +93, -89 lines


fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu

Lines changed: 93 additions & 89 deletions
@@ -2844,7 +2844,8 @@ std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
 // Function to convert and pack a single component
 DEVICE_INLINE uint32_t
 convertAndPack(float component, float inv_scale, float shift = 0.0) {
-  auto val = (component - shift) * inv_scale;
+  // auto val = (component - shift) * inv_scale;
+  auto val = fmaf(component, inv_scale, -shift * inv_scale);
   val = fmaxf(val, -FP8_E4M3_MAX::value);
   val = fminf(val, FP8_E4M3_MAX::value);
   auto x = __nv_fp8_e4m3(val);
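The hunk above folds the shift and scale into one fused multiply-add. A standalone sketch of the same rewrite is shown here with a hypothetical helper name; since (component - shift) * inv_scale equals component * inv_scale + (-shift * inv_scale), fmaf maps to a single FFMA, and with the default shift of 0 the addend is exactly zero.

```cuda
// Hedged sketch of the rewrite, outside the real kernel (hypothetical name).
__device__ __forceinline__ float scale_and_clamp_sketch(
    float component, float inv_scale, float shift = 0.0f) {
  // One fused multiply-add instead of a subtract followed by a multiply.
  float val = fmaf(component, inv_scale, -shift * inv_scale);
  val = fmaxf(val, -448.0f); // 448 is the FP8 E4M3 max finite value
  val = fminf(val, 448.0f);
  return val;
}
```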
@@ -2876,104 +2877,109 @@ __global__ void quantizeQKVPerHead(
     float* const scale_k,
     float* const scale_v,
     float kv_multiplier = 64.f) {
-  extern __shared__ float
-      shared_max[]; // Shared memory for block-level reduction
   // Launch one warp per token. Each thread handles 4 elements.
-  // warps = B_T * HH
-  auto b_t_hh = blockIdx.x * blockDim.y + threadIdx.y;
+  // warps = B_T
   auto N_KVH = cache_K.size(2);
   auto N_H = XQ_O.size(1);
-
   auto B_T = XQ_O.size(0);
   // TODO: Support N_KVH > 1
-  CUDA_KERNEL_ASSERT(N_KVH == 1);
+  // CUDA_KERNEL_ASSERT(N_KVH == 1);
 
-  // TODO: Support amax_head of size [B, N_H] for decode case
-  // auto HH = (scale_k == nullptr) ? N_H : N_KVH * 2 + N_H;
   auto HH = N_H + N_KVH * 2;
+  auto maxHH = scale_k ? HH : N_H;
 
-  auto hh = b_t_hh % HH;
-  auto b_t = b_t_hh / HH;
+  uint2 buffer;
 
-  // Boundary check
-  if (b_t >= B_T || hh >= HH)
-    return;
-
-  auto seqpos_t = varseq_seqpos[b_t];
-  auto b = varseq_batch ? varseq_batch[b_t] : b_t;
+  // warps_per_block = blockDim.y
+  // warp_id = threadIdx.y
+  // block_id = blockIdx.x
 
+  // Calculate scaling factor
+  constexpr float min_scaling_factor = 1.0f / (FP8_E4M3_MAX::value * 512.f);
+  int b = 0;
+  int last_b = -1;
   int h = 0;
   float* qparam = nullptr;
-  QKV qkv;
   at::Float8_e4m3fn* dst_row_q = nullptr;
-  if (hh < N_H) {
-    qkv = QKV::Q;
-    h = hh;
-    qparam = scale_q + b * N_KVH + hh / (N_H / N_KVH);
-    dst_row_q = &XQ_O[b_t][h][0];
-  } else if (hh < N_H + N_KVH) {
-    qkv = QKV::K;
-    h = hh - N_H;
-    if (scale_k == nullptr)
-      return;
-    qparam = scale_k + b * N_KVH + h;
-    dst_row_q = &cache_K[b][seqpos_t][h][0];
-  } else {
-    qkv = QKV::V;
-    h = hh - N_H - N_KVH;
-    if (scale_k == nullptr)
-      return;
-    qparam = scale_v + b * N_KVH + h;
-    dst_row_q = &cache_V[b][seqpos_t][h][0];
-  }
-  // Skip quantization if scale is pre-calculated for K/V
-  // as in decode and partial prefill cases
-  bool is_precalculated_qparam_b_t =
-      is_precalculated_qparam ? is_precalculated_qparam[b_t] : true;
-  if (qkv != QKV::Q && is_precalculated_qparam_b_t)
-    return;
-
-  CUDA_KERNEL_ASSERT(uintptr_t(qparam) % 4 == 0);
   float val = 0;
-  if (qkv == QKV::Q) {
-    // assert N_KVH == 1
-    for (int _h = 0; _h < N_H; _h++) {
-      val = fmaxf(val, xqkv_amax_head[b * HH + _h]);
+  float inv_scale = 0;
+
+  uint d = 4 * threadIdx.x;
+
+  auto b_t_start = blockIdx.x * blockDim.y + threadIdx.y;
+  for (int b_t = b_t_start; b_t < B_T; b_t += blockDim.y * gridDim.x) {
+    b = varseq_batch ? varseq_batch[b_t] : b_t;
+    last_b = b_t == 0 ? -1 : varseq_batch[b_t - 1];
+    {
+      // Skip quantization of KV if scale is pre-calculated for K/V
+      // as in decode and partial prefill cases
+      bool is_precalculated_qparam_b_t =
+          is_precalculated_qparam ? is_precalculated_qparam[b_t] : true;
+      if (is_precalculated_qparam_b_t)
+        maxHH = N_H;
+    }
+    val = 0;
+    for (auto hh = 0; hh < N_H; hh++) {
+      val = fmaxf(val, xqkv_amax_head[b * HH + hh]);
+    }
+
+    for (auto hh = 0; hh < maxHH; hh++) {
+      {
+        at::BFloat16* src_row = &xqkv[(b_t * HH + hh + 0) * D_H];
+        buffer = *reinterpret_cast<uint2*>(&src_row[d]);
+        val = (hh < N_H) ? val : xqkv_amax_head[b * HH + hh];
+      }
+
+      {
+        int seqpos_t = varseq_seqpos[b_t];
+        if (hh < N_H) {
+          h = hh;
+          qparam = scale_q + b * N_KVH + hh / (N_H / N_KVH);
+          dst_row_q = &XQ_O[b_t][h][0];
+          val = val * 8;
+        } else if (hh < N_H + N_KVH) {
+          h = hh - N_H;
+
+          qparam = scale_k + b * N_KVH + h;
+          dst_row_q = &cache_K[b][seqpos_t][h][0];
+          val = kv_multiplier * val;
+        } else {
+          h = hh - N_H - N_KVH;
+
+          qparam = scale_v + b * N_KVH + h;
+          dst_row_q = &cache_V[b][seqpos_t][h][0];
+          val = kv_multiplier * val;
+        }
+      }
+      {
+        float scale = 0;
+        val = fminf(val, 12000);
+        scale = fmaxf(val / FP8_E4M3_MAX::value, min_scaling_factor);
+        bool is_first_token = (b_t == 0 || !varseq_batch || last_b != b);
+        if (threadIdx.x == 0 && h == 0 && is_first_token) {
+          *qparam = scale;
+        }
+        inv_scale = 1 / scale;
+      }
+
+      {
+        bfx4 src;
+        fx4 dst;
+        uint32_t x_bits[4];
+        // Convert and pack data
+        // 8 bytes are 4 elements of type bf16
+        *reinterpret_cast<uint2*>(&src) = buffer;
+        dst = bfx4_to_fx4(src);
+        x_bits[0] = convertAndPack(dst.x, inv_scale);
+        x_bits[1] = convertAndPack(dst.y, inv_scale);
+        x_bits[2] = convertAndPack(dst.z, inv_scale);
+        x_bits[3] = convertAndPack(dst.w, inv_scale);
+        uint32_t packed = packComponents(x_bits);
+        // CUDA_KERNEL_ASSERT(uintptr_t(&dst_row_q[d]) % 4 == 0);
+        *reinterpret_cast<uint32_t*>(&dst_row_q[d]) = packed;
+      }
     }
-    val *= 8; // [Experimental] improves accuracy
-  } else {
-    val = kv_multiplier * xqkv_amax_head[b * HH + hh];
   }
-  // Calculate scaling factor
-  constexpr float min_scaling_factor = 1.0f / (FP8_E4M3_MAX::value * 512.f);
-  val = fminf(val, 12000);
-  float scale = fmaxf(val / FP8_E4M3_MAX::value, min_scaling_factor);
-  bool is_first_token =
-      (b_t == 0 || !varseq_batch || varseq_batch[b_t - 1] != b);
-  if (threadIdx.x == 0 && h == 0 && is_first_token) {
-    *qparam = scale;
-  }
-  // Write scaling factor
-  auto inv_scale = 1 / scale;
-  CUDA_KERNEL_ASSERT(uintptr_t(&xqkv[b_t * HH * D_H + hh * D_H]) % 16 == 0);
-  at::BFloat16* src_row = &xqkv[b_t * HH * D_H + hh * D_H];
-
-  // Convert and pack data
-  bfx4 src;
-  fx4 dst;
-  uint32_t x_bits[4];
-  auto d = threadIdx.x * 4;
-  CUDA_KERNEL_ASSERT(uintptr_t(&src_row[d]) % 8 == 0);
-  // 8 bytes are 4 elements of type bf16
-  *reinterpret_cast<uint2*>(&src) = *reinterpret_cast<uint2*>(&src_row[d]);
-  dst = bfx4_to_fx4(src);
-  x_bits[0] = convertAndPack(dst.x, inv_scale);
-  x_bits[1] = convertAndPack(dst.y, inv_scale);
-  x_bits[2] = convertAndPack(dst.z, inv_scale);
-  x_bits[3] = convertAndPack(dst.w, inv_scale);
-  uint32_t packed = packComponents(x_bits);
-  CUDA_KERNEL_ASSERT(uintptr_t(&dst_row_q[d]) % 4 == 0);
-  *reinterpret_cast<uint32_t*>(&dst_row_q[d]) = packed;
 }
 
 at::Tensor quantize_qkv_per_head(
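The hunk above replaces the one-warp-per-(token, head) mapping with a grid-stride loop: each warp claims a token, walks every Q/K/V head for it, and then strides forward by the total number of warps in the grid. A stripped-down sketch of just that looping structure follows; the kernel name, parameters, and body are hypothetical and omit the actual scale computation and FP8 packing.

```cuda
// Hedged sketch of the grid-stride structure only (hypothetical names).
__global__ void per_token_grid_stride_sketch(
    const float* amax_per_head, float* out_scales, int B_T, int HH) {
  // One warp per token: warps are indexed by threadIdx.y within a block.
  int b_t_start = blockIdx.x * blockDim.y + threadIdx.y;
  int stride = blockDim.y * gridDim.x; // total warps launched
  for (int b_t = b_t_start; b_t < B_T; b_t += stride) {
    for (int hh = 0; hh < HH; ++hh) {
      // Per-(token, head) work would go here; lane 0 writes the per-head
      // result so each value is stored once per warp.
      if (threadIdx.x == 0) {
        out_scales[b_t * HH + hh] = amax_per_head[b_t * HH + hh];
      }
    }
  }
}
```

The real kernel additionally vectorizes the row loads as uint2 (4 bf16 values per thread) and writes the qparam only from lane 0 of the first token of each batch entry, as seen in the diff.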
@@ -2988,21 +2994,19 @@ at::Tensor quantize_qkv_per_head(
     int64_t B, // Batch size
     std::optional<at::Tensor> qparam_k,
     std::optional<at::Tensor> qparam_v) {
-  auto B_T = XQ_O.size(0);
   auto N_KVH_L = cache_K.size(2);
-  auto N_H_L = XQ_O.size(1);
-  auto HH = N_H_L + N_KVH_L * 2;
+
   float* qparam_k_ptr = nullptr;
   float* qparam_v_ptr = nullptr;
   if (qparam_k.has_value()) {
     // prefill case
     qparam_k_ptr = qparam_k.value().data_ptr<float>();
     qparam_v_ptr = qparam_v.value().data_ptr<float>();
   }
-  auto num_warps = B_T * HH;
-  dim3 block_size(kThreadsPerWarp, kWarpsPerBlock);
-  dim3 grid_size(cuda_calc_xblock_count(num_warps, kWarpsPerBlock));
 
+  constexpr int32_t kMaxBlocks = 512;
+  dim3 block_size(kThreadsPerWarp, kWarpsPerBlock);
+  dim3 grid_size(kMaxBlocks);
   auto scale_q = at::zeros({B, N_KVH_L}, XQ_O.options().dtype(at::kFloat));
   float* const scale_q_ptr = scale_q.data_ptr<float>();
   // Launch the kernel
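Host-side, the launch no longer scales the grid with B_T * HH; it caps the grid at kMaxBlocks and relies on the kernel's grid-stride loop to cover the remaining tokens. A hedged sketch of that sizing choice follows, with kWarpsPerBlock assumed to be 8 purely for illustration.

```cuda
// Hedged sketch of the capped launch configuration (kWarpsPerBlock = 8 is an
// assumption for this sketch; kMaxBlocks = 512 matches the diff above).
#include <cuda_runtime.h>
#include <cstdint>
#include <utility>

inline std::pair<dim3, dim3> launch_config_sketch() {
  constexpr int32_t kThreadsPerWarp = 32;
  constexpr int32_t kWarpsPerBlock = 8; // assumed value
  constexpr int32_t kMaxBlocks = 512;
  // Block: one warp per token slot, kWarpsPerBlock tokens processed per block
  // per loop iteration; grid: fixed cap instead of one warp per (token, head).
  dim3 block_size(kThreadsPerWarp, kWarpsPerBlock);
  dim3 grid_size(kMaxBlocks);
  return {block_size, grid_size};
}
```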
