@@ -2588,16 +2588,42 @@ __global__ void add_fusedQKV_bias_transpose_prefill_kernel_v1(T*
             KVBlockArray kv_block_array = param.kv_block_array;
             Tcache* k_cache = reinterpret_cast<Tcache*>(kv_block_array.getKBlockPtr(batch_idx, dst_kv_seq_idx));
             Tcache* v_cache = reinterpret_cast<Tcache*>(kv_block_array.getVBlockPtr(batch_idx, dst_kv_seq_idx));
+            if constexpr (std::is_same<Tcache, __nv_fp8_e4m3>::value) {
+                float* k_scale_ptr = reinterpret_cast<float*>(kv_block_array.getKScalePtr(batch_idx, dst_kv_seq_idx));
+                float* v_scale_ptr = reinterpret_cast<float*>(kv_block_array.getVScalePtr(batch_idx, dst_kv_seq_idx));
+                const int inScaleIdx = kv_block_array.getKVScaleLocalIdx(dst_kv_seq_idx, head_idx);
+
+                __shared__ float s_max[2];
+                s_max[0] = float(1 << (8 - 1));
+                s_max[1] = float(1 << (8 - 1));

 #pragma unroll
-            for (int vec_i = 0; vec_i < vec_size; vec_i++) {
-                const int inKBlockIdx = kv_block_array.getKLocalIdx<KvCacheDataType::BASE>(
-                    dst_kv_seq_idx, head_idx, size_per_head, tidx * vec_size + vec_i);
-                k_cache[inKBlockIdx] = reinterpret_cast<T*>(&k)[vec_i];
+                for (int vec_i = 0; vec_i < vec_size; vec_i++) {
+                    const int inKBlockIdx = kv_block_array.getKLocalIdx<KvCacheDataType::FP8>(
+                        dst_kv_seq_idx, head_idx, size_per_head, tidx * vec_size + vec_i);

-                const int inVBlockIdx = kv_block_array.getVLocalIdx(
+                    const int inVBlockIdx = kv_block_array.getVLocalIdx(
                         dst_kv_seq_idx, head_idx, size_per_head, tidx * vec_size + vec_i);
-                v_cache[inVBlockIdx] = reinterpret_cast<T*>(&v)[vec_i];
+
+                    k_cache[inKBlockIdx] = Tcache(float(reinterpret_cast<T*>(&k)[vec_i]) * (float(1 << (8 - 1)) / s_max[0]));
+                    v_cache[inVBlockIdx] = Tcache(float(reinterpret_cast<T*>(&v)[vec_i]) * (float(1 << (8 - 1)) / s_max[1]));
+                }
+
+                if (tidx == 0) {
+                    *reinterpret_cast<float*>(&k_scale_ptr[inScaleIdx]) = s_max[0] / float(1 << (8 - 1));
+                    *reinterpret_cast<float*>(&v_scale_ptr[inScaleIdx]) = s_max[1] / float(1 << (8 - 1));
+                }
+            } else {
+#pragma unroll
+                for (int vec_i = 0; vec_i < vec_size; vec_i++) {
+                    const int inKBlockIdx = kv_block_array.getKLocalIdx<KvCacheDataType::BASE>(
+                        dst_kv_seq_idx, head_idx, size_per_head, tidx * vec_size + vec_i);
+                    k_cache[inKBlockIdx] = reinterpret_cast<T*>(&k)[vec_i];
+
+                    const int inVBlockIdx = kv_block_array.getVLocalIdx(
+                        dst_kv_seq_idx, head_idx, size_per_head, tidx * vec_size + vec_i);
+                    v_cache[inVBlockIdx] = reinterpret_cast<T*>(&v)[vec_i];
+                }
             }
         }
     }
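
For reference, the FP8 branch above implements a symmetric per-(token, head) scale: each K/V element is multiplied by `128 / s_max` before the cast to `__nv_fp8_e4m3`, and thread 0 records the inverse factor `s_max / 128` through the scale pointers so attention can dequantize later. Below is a minimal standalone round trip of that arithmetic; it is a sketch assuming a CUDA 11.8+ toolkit for `<cuda_fp8.h>`, and `fp8_roundtrip` plus the fixed `s_max` value are illustrative names, not code from this file.

```cpp
#include <cuda_fp8.h>
#include <cstdio>

// Quantize: fp8_val = val * (128 / s_max). The stored scale s_max / 128
// is what the attention kernel later multiplies by to dequantize.
__global__ void fp8_roundtrip(const float* in, float* out, float s_max, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;
    const float quant_scale   = float(1 << (8 - 1)) / s_max;  // 128 / s_max
    const float dequant_scale = s_max / float(1 << (8 - 1));  // s_max / 128
    __nv_fp8_e4m3 q = __nv_fp8_e4m3(in[i] * quant_scale);     // what goes into the KV cache
    out[i] = float(q) * dequant_scale;                        // what attention reads back
}

int main() {
    const int n = 8;
    float h_in[n] = {0.1f, -0.5f, 1.0f, 2.0f, -3.0f, 4.0f, -8.0f, 16.0f};
    float *d_in, *d_out, h_out[n];
    cudaMalloc(&d_in, n * sizeof(float));
    cudaMalloc(&d_out, n * sizeof(float));
    cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);
    fp8_roundtrip<<<1, n>>>(d_in, d_out, /*s_max=*/16.0f, n);
    cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
    for (int i = 0; i < n; i++) printf("%f -> %f\n", h_in[i], h_out[i]);
    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}
```
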
@@ -2993,7 +3019,8 @@ __global__ void add_fusedQKV_bias_transpose_decode_kernel_v1(T*
     bool store_qkv,
     bool store_q,
     bool store_kv,
-    bool store_cache) {
+    bool store_cache,
+    const float2* cos_sin_cache) {
     extern __shared__ __align__(sizeof(float2)) char smem_[];

     constexpr int vec_size = Vec_t<T>::size;
@@ -3068,7 +3095,8 @@ __global__ void add_fusedQKV_bias_transpose_decode_kernel_v1(T*
                                        input_len,
                                        prefix_prompt_length,
                                        true /* count_prefix_length */,
-                                       true /* HANDLE_KV */);
+                                       true /* HANDLE_KV */,
+                                       cos_sin_cache);

     if (use_logn_attn) {
         logn_attention(q, tlength, rope_config.max_pos);
     }
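
The new `cos_sin_cache` argument threads a precomputed rotary table into `apply_rotary_embedding`, letting the kernel read `{cos, sin}` pairs instead of evaluating them per element. The table's layout is not shown in this diff; a plausible host-side builder, where `build_cos_sin_cache` and the row-major `(position, rot_dim / 2)` layout with the usual RoPE frequencies are assumptions, could look like this:

```cpp
#include <cmath>
#include <vector>

struct float2h { float x, y; };  // host stand-in for CUDA's float2

// Hypothetical layout: one {cos, sin} pair per (position, rot_dim / 2) slot,
// with the standard RoPE frequencies theta_i = base^(-2i / rot_dim).
std::vector<float2h> build_cos_sin_cache(int max_pos, int rot_dim, float base = 10000.f) {
    std::vector<float2h> cache(size_t(max_pos) * (rot_dim / 2));
    for (int p = 0; p < max_pos; ++p) {
        for (int i = 0; i < rot_dim / 2; ++i) {
            const float freq  = std::pow(base, -2.0f * i / rot_dim);
            const float angle = p * freq;
            cache[size_t(p) * (rot_dim / 2) + i] = {std::cos(angle), std::sin(angle)};
        }
    }
    return cache;  // copied to the device and passed down as cos_sin_cache
}
```
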
@@ -3084,19 +3112,43 @@ __global__ void add_fusedQKV_bias_transpose_decode_kernel_v1(T*

     if (store_cache) {
         if (head_idx < head_num_kv) {
-            OffsetIndexedKVBlockArray offset_kv_block_array = param.offset_kv_block_array;
-            Tcache* k_cache = reinterpret_cast<Tcache*>(offset_kv_block_array.getKBlockPtr(batch_idx, dst_kv_seq_idx));
-            Tcache* v_cache = reinterpret_cast<Tcache*>(offset_kv_block_array.getVBlockPtr(batch_idx, dst_kv_seq_idx));
+            KVBlockArray kv_block_array = param.kv_block_array;
+            Tcache* k_cache = reinterpret_cast<Tcache*>(kv_block_array.getKBlockPtr(batch_idx, dst_kv_seq_idx));
+            Tcache* v_cache = reinterpret_cast<Tcache*>(kv_block_array.getVBlockPtr(batch_idx, dst_kv_seq_idx));
+            if constexpr (std::is_same<Tcache, __nv_fp8_e4m3>::value) {
+                float* k_scale_ptr = reinterpret_cast<float*>(kv_block_array.getKScalePtr(batch_idx, dst_kv_seq_idx));
+                float* v_scale_ptr = reinterpret_cast<float*>(kv_block_array.getVScalePtr(batch_idx, dst_kv_seq_idx));
+                const int inScaleIdx = kv_block_array.getKVScaleLocalIdx(dst_kv_seq_idx, head_idx);

+                __shared__ float s_max[2];
+                s_max[0] = float(1 << (8 - 1));
+                s_max[1] = float(1 << (8 - 1));
 #pragma unroll
-            for (int vec_i = 0; vec_i < vec_size; vec_i++) {
-                const int inKBlockIdx = offset_kv_block_array.getKLocalIdx<KvCacheDataType::BASE>(
-                    dst_kv_seq_idx, head_idx, size_per_head, tidx * vec_size + vec_i);
-                k_cache[inKBlockIdx] = reinterpret_cast<T*>(&k)[vec_i];
-
-                const int inVBlockIdx = offset_kv_block_array.getVLocalIdx(
-                    dst_kv_seq_idx, head_idx, size_per_head, tidx * vec_size + vec_i);
-                v_cache[inVBlockIdx] = reinterpret_cast<T*>(&v)[vec_i];
+                for (int vec_i = 0; vec_i < vec_size; vec_i++) {
+                    const int inKBlockIdx = kv_block_array.getKLocalIdx<KvCacheDataType::FP8>(
+                        dst_kv_seq_idx, head_idx, size_per_head, tidx * vec_size + vec_i);
+
+                    const int inVBlockIdx = kv_block_array.getVLocalIdx(
+                        dst_kv_seq_idx, head_idx, size_per_head, tidx * vec_size + vec_i);
+
+                    k_cache[inKBlockIdx] = Tcache(float(reinterpret_cast<T*>(&k)[vec_i]) * (float(1 << (8 - 1)) / s_max[0]));
+                    v_cache[inVBlockIdx] = Tcache(float(reinterpret_cast<T*>(&v)[vec_i]) * (float(1 << (8 - 1)) / s_max[1]));
+                }
+                if (tidx == 0) {
+                    *reinterpret_cast<float*>(&k_scale_ptr[inScaleIdx]) = s_max[0] / float(1 << (8 - 1));
+                    *reinterpret_cast<float*>(&v_scale_ptr[inScaleIdx]) = s_max[1] / float(1 << (8 - 1));
+                }
+            } else {
+#pragma unroll
+                for (int vec_i = 0; vec_i < vec_size; vec_i++) {
+                    const int inKBlockIdx = kv_block_array.getKLocalIdx<KvCacheDataType::BASE>(
+                        dst_kv_seq_idx, head_idx, size_per_head, tidx * vec_size + vec_i);
+                    k_cache[inKBlockIdx] = reinterpret_cast<T*>(&k)[vec_i];
+
+                    const int inVBlockIdx = kv_block_array.getVLocalIdx(
+                        dst_kv_seq_idx, head_idx, size_per_head, tidx * vec_size + vec_i);
+                    v_cache[inVBlockIdx] = reinterpret_cast<T*>(&v)[vec_i];
+                }
             }
         }
     }
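
As written, `s_max` in both hunks is seeded to `float(1 << (8 - 1))` (= 128) and never updated from the data, so the quantization multiplier and the stored scale both work out to 1.0f; whether that is intentional headroom or a placeholder for a later reduction is not visible in this diff. For reference, a data-dependent per-token scale is usually obtained with a block-wide absolute-max reduction before the quantized store. A minimal sketch of that pattern follows; `block_abs_max` and `scale_demo` are illustrative names, not code from this commit.

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Block-wide absolute-max reduction: each thread contributes |val|;
// afterwards every thread sees the block's maximum.
__device__ float block_abs_max(float val, float* s_scratch) {
    const int lane = threadIdx.x & 31;
    const int warp = threadIdx.x >> 5;
    float m = fabsf(val);
    for (int offset = 16; offset > 0; offset >>= 1) {  // reduce within each warp
        m = fmaxf(m, __shfl_xor_sync(0xffffffff, m, offset));
    }
    if (lane == 0) s_scratch[warp] = m;                // one partial max per warp
    __syncthreads();
    m = (threadIdx.x < (blockDim.x + 31) / 32) ? s_scratch[threadIdx.x] : 0.f;
    for (int offset = 16; offset > 0; offset >>= 1) {  // first warp combines partials
        m = fmaxf(m, __shfl_xor_sync(0xffffffff, m, offset));
    }
    if (threadIdx.x == 0) s_scratch[0] = m;
    __syncthreads();
    return s_scratch[0];
}

__global__ void scale_demo(const float* in, float* scale_out) {
    __shared__ float s_scratch[32];
    const float s_max = block_abs_max(in[threadIdx.x], s_scratch);
    if (threadIdx.x == 0) {
        *scale_out = s_max / float(1 << (8 - 1));      // the s_max / 128 store
    }
}

int main() {
    float h_in[64], h_scale = 0.f;
    for (int i = 0; i < 64; i++) h_in[i] = (i == 17) ? -6.5f : 0.25f;
    float *d_in, *d_scale;
    cudaMalloc(&d_in, sizeof(h_in));
    cudaMalloc(&d_scale, sizeof(float));
    cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
    scale_demo<<<1, 64>>>(d_in, d_scale);
    cudaMemcpy(&h_scale, d_scale, sizeof(float), cudaMemcpyDeviceToHost);
    printf("scale = %f (abs max 6.5 / 128)\n", h_scale);  // ~0.0508
    cudaFree(d_in);
    cudaFree(d_scale);
    return 0;
}
```
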
@@ -3324,7 +3376,8 @@ void invokeAddFusedQKVBiasTransposeDecodeV1(T* q_buf
             store_qkv,
             store_q,
             store_kv,
-            store_cache);
+            store_cache,
+            cos_sin_cache);
         });
     });
 });