
Commit 08ad962

amd-yilizhao authored and committed

feat: enable fp8 hip_PA
1 parent 1f3cfd0 commit 08ad962

File tree

3 files changed: +77 -33 lines changed


rtp_llm/cpp/devices/rocm_impl/ROCmAttentionOp.cc

Lines changed: 3 additions & 12 deletions
@@ -1098,16 +1098,8 @@ AttentionModuleOutput ROCmDevice::decoderSelfAttention(const AttentionModulePara
 
     if (init_params_.use_aiter_pa) {
         PrefixPromptBatchWeightsParam prefix_prompt_param;
-        if (init_params_.use_asm_pa) {
-            KVBlockArray kv_block_array = getKVBlockArray(params, *kv_cache_offset, batch_size, params.common.kv_cache->k_cache_buffer->type() == DataType::TYPE_FP8_E4M3, false);
-            prefix_prompt_param.kv_block_array = kv_block_array;
-        }
-        else {
-            KVBlockArray kv_block_array = getKVBlockArray(params, *kv_cache_offset, batch_size, false, true);
-            //PrefixPromptBatchWeightsParam prefix_prompt_param;
-            auto offset_kv_block_array = OffsetIndexedKVBlockArray(kv_block_array,(rtp_llm::KVBlockArrayForContextFMHA::DataType*)params.common.kv_cache->kv_cache_block_id->data(), params.common.kv_cache->k_cache_buffer->shape()[0] * params.common.kv_cache->layer_num);
-            prefix_prompt_param.offset_kv_block_array = offset_kv_block_array;
-        }
+        KVBlockArray kv_block_array = getKVBlockArray(params, *kv_cache_offset, batch_size, params.common.kv_cache->k_cache_buffer->type() == DataType::TYPE_FP8_E4M3, false);
+        prefix_prompt_param.kv_block_array = kv_block_array;
 
         auto token_num = params.input.shape()[0];
         auto decoder_batch_size = params.common.decoder_batch_size;
@@ -1208,8 +1200,7 @@ AttentionModuleOutput ROCmDevice::decoderSelfAttention(const AttentionModulePara
                 store_q,
                 store_kv,
                 store_cache,
-                nullptr,
-                //params.rotary_embedding_coefficient_cache ? params.rotary_embedding_coefficient_cache->data() : nullptr,
+                use_rope_cache && rope_cache.defined() ? static_cast<float2*>(rope_cache.data_ptr()) : nullptr,
                 stream_);
         }
     check_cuda_error();
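
Note: the net effect of this hunk is that the asm and non-asm aiter paged-attention paths no longer diverge; both now build a plain KVBlockArray, and the fp8 flag handed to getKVBlockArray is derived from the k-cache buffer's element type instead of being hard-coded per path. A minimal stand-alone sketch of that dtype-driven dispatch (the enum values and helper name below are stand-ins; only DataType::TYPE_FP8_E4M3 appears in the diff):

```cpp
#include <cstdio>

// Stand-in for illustration; rtp_llm's real DataType enum is larger.
enum class DataType { TYPE_FP16, TYPE_BF16, TYPE_FP8_E4M3 };

// The single remaining call site asks "is the k-cache stored as fp8?" and
// forwards the answer as getKVBlockArray's fp8 flag.
bool kvCacheIsFp8(DataType k_cache_type) {
    return k_cache_type == DataType::TYPE_FP8_E4M3;
}

int main() {
    printf("fp8 cache:  %d\n", kvCacheIsFp8(DataType::TYPE_FP8_E4M3));  // 1
    printf("fp16 cache: %d\n", kvCacheIsFp8(DataType::TYPE_FP16));      // 0
    return 0;
}
```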

rtp_llm/cpp/devices/rocm_impl/aiterPA.cc

Lines changed: 1 addition & 1 deletion
@@ -93,7 +93,7 @@ void runAiterPA(const AttentionModuleParams& params, rtp_llm::DeviceBase* device
 
     int64_t block_size = params.configs.tokens_per_block;
 
-    std::string kv_cache_dtype = "auto";
+    std::string kv_cache_dtype = key_cache.dtype() == at::kFloat8_e4m3fnuz ? "fp8" : "auto";
 
     double k_scale = 1.0;
     double v_scale = 1.0;
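
Note: on the aiter side, the cache dtype string is now picked from the torch dtype of the key cache: an e4m3fnuz cache selects "fp8", everything else keeps "auto" (cache matches the activation dtype). A stand-alone version of the same mapping (the ScalarType enum below is a stand-in for at::ScalarType):

```cpp
#include <string>

// Stand-in for the at::ScalarType check in the diff.
enum class ScalarType { Half, BFloat16, Float8_e4m3fnuz };

// Map the key-cache element type to the kv_cache_dtype string the aiter
// paged-attention kernel expects: "fp8" selects its dequantizing path,
// "auto" means the cache matches the activation dtype.
std::string kvCacheDtypeString(ScalarType key_cache_dtype) {
    return key_cache_dtype == ScalarType::Float8_e4m3fnuz ? "fp8" : "auto";
}
```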

rtp_llm/cpp/kernels/unfused_attention_kernels.cu

Lines changed: 73 additions & 20 deletions
@@ -2588,16 +2588,42 @@ __global__ void add_fusedQKV_bias_transpose_prefill_kernel_v1(T*
             KVBlockArray kv_block_array = param.kv_block_array;
             Tcache* k_cache = reinterpret_cast<Tcache*>(kv_block_array.getKBlockPtr(batch_idx, dst_kv_seq_idx));
             Tcache* v_cache = reinterpret_cast<Tcache*>(kv_block_array.getVBlockPtr(batch_idx, dst_kv_seq_idx));
+            if constexpr (std::is_same<Tcache, __nv_fp8_e4m3>::value) {
+                float* k_scale_ptr = reinterpret_cast<float*>(kv_block_array.getKScalePtr(batch_idx, dst_kv_seq_idx));
+                float* v_scale_ptr = reinterpret_cast<float*>(kv_block_array.getVScalePtr(batch_idx, dst_kv_seq_idx));
+                const int inScaleIdx = kv_block_array.getKVScaleLocalIdx(dst_kv_seq_idx, head_idx);
+
+                __shared__ float s_max[2];
+                s_max[0] = float(1 << (8 - 1));
+                s_max[1] = float(1 << (8 - 1));
 
 #pragma unroll
-            for (int vec_i = 0; vec_i < vec_size; vec_i++) {
-                const int inKBlockIdx = kv_block_array.getKLocalIdx<KvCacheDataType::BASE>(
-                    dst_kv_seq_idx, head_idx, size_per_head, tidx * vec_size + vec_i);
-                k_cache[inKBlockIdx] = reinterpret_cast<T*>(&k)[vec_i];
+                for (int vec_i = 0; vec_i < vec_size; vec_i++) {
+                    const int inKBlockIdx = kv_block_array.getKLocalIdx<KvCacheDataType::FP8>(
+                        dst_kv_seq_idx, head_idx, size_per_head, tidx * vec_size + vec_i);
 
-                const int inVBlockIdx = kv_block_array.getVLocalIdx(
+                    const int inVBlockIdx = kv_block_array.getVLocalIdx(
                         dst_kv_seq_idx, head_idx, size_per_head, tidx * vec_size + vec_i);
-                v_cache[inVBlockIdx] = reinterpret_cast<T*>(&v)[vec_i];
+
+                    k_cache[inKBlockIdx] = Tcache(float(reinterpret_cast<T*>(&k)[vec_i]) * (float(1 << (8 - 1)) / s_max[0]));
+                    v_cache[inVBlockIdx] = Tcache(float(reinterpret_cast<T*>(&v)[vec_i]) * (float(1 << (8 - 1)) / s_max[1]));
+                }
+
+                if (tidx == 0) {
+                    *reinterpret_cast<float*>(&k_scale_ptr[inScaleIdx]) = s_max[0] / float(1 << (8 - 1));
+                    *reinterpret_cast<float*>(&v_scale_ptr[inScaleIdx]) = s_max[1] / float(1 << (8 - 1));
+                }
+            } else {
+#pragma unroll
+                for (int vec_i = 0; vec_i < vec_size; vec_i++) {
+                    const int inKBlockIdx = kv_block_array.getKLocalIdx<KvCacheDataType::BASE>(
+                        dst_kv_seq_idx, head_idx, size_per_head, tidx * vec_size + vec_i);
+                    k_cache[inKBlockIdx] = reinterpret_cast<T*>(&k)[vec_i];
+
+                    const int inVBlockIdx = kv_block_array.getVLocalIdx(
+                        dst_kv_seq_idx, head_idx, size_per_head, tidx * vec_size + vec_i);
+                    v_cache[inVBlockIdx] = reinterpret_cast<T*>(&v)[vec_i];
+                }
             }
         }
     }
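
Note on the quantization scheme in the new fp8 branch: each element is scaled by 128 / s_max before the cast to __nv_fp8_e4m3, and thread 0 stores the inverse factor s_max / 128 as the per-token k/v scale, so a consumer can dequantize with a single multiply. Since s_max is initialized to 128 and never updated here, the written scale is exactly 1.0 in this version; the structure leaves room for a real running max later. A host-side sketch of the round trip (fp8_cast below is a placeholder for the device-side Tcache() conversion):

```cpp
#include <cstdio>

// Placeholder for the device-side __nv_fp8_e4m3 rounding; a real fp8 cast
// would also clamp and quantize.
float fp8_cast(float x) { return x; }

int main() {
    const float kFp8Range = float(1 << (8 - 1));  // 128, the assumed fp8 full scale
    float s_max = kFp8Range;   // the diff initializes s_max to 128 => scale 1.0
    float x     = 3.25f;

    float stored = fp8_cast(x * (kFp8Range / s_max));  // value written to the k/v cache
    float scale  = s_max / kFp8Range;                  // per-token scale written by tidx == 0
    printf("dequantized: %f\n", stored * scale);       // recovers ~x (exactly x here)
    return 0;
}
```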
@@ -2993,7 +3019,8 @@ __global__ void add_fusedQKV_bias_transpose_decode_kernel_v1(T*
                                                              bool store_qkv,
                                                              bool store_q,
                                                              bool store_kv,
-                                                             bool store_cache) {
+                                                             bool store_cache,
+                                                             const float2* cos_sin_cache) {
     extern __shared__ __align__(sizeof(float2)) char smem_[];
 
     constexpr int vec_size = Vec_t<T>::size;
@@ -3068,7 +3095,8 @@ __global__ void add_fusedQKV_bias_transpose_decode_kernel_v1(T*
                           input_len,
                           prefix_prompt_length,
                           true /*count_prefix_length*/,
-                          true /*HANDLE_KV*/);
+                          true /*HANDLE_KV*/,
+                          cos_sin_cache);
 
     if (use_logn_attn) {
         logn_attention(q, tlength, rope_config.max_pos);
@@ -3084,19 +3112,43 @@ __global__ void add_fusedQKV_bias_transpose_decode_kernel_v1(T*
 
     if (store_cache) {
         if (head_idx < head_num_kv) {
-            OffsetIndexedKVBlockArray offset_kv_block_array = param.offset_kv_block_array;
-            Tcache* k_cache = reinterpret_cast<Tcache*>(offset_kv_block_array.getKBlockPtr(batch_idx, dst_kv_seq_idx));
-            Tcache* v_cache = reinterpret_cast<Tcache*>(offset_kv_block_array.getVBlockPtr(batch_idx, dst_kv_seq_idx));
+            KVBlockArray kv_block_array = param.kv_block_array;
+            Tcache* k_cache = reinterpret_cast<Tcache*>(kv_block_array.getKBlockPtr(batch_idx, dst_kv_seq_idx));
+            Tcache* v_cache = reinterpret_cast<Tcache*>(kv_block_array.getVBlockPtr(batch_idx, dst_kv_seq_idx));
+            if constexpr (std::is_same<Tcache, __nv_fp8_e4m3>::value) {
+                float* k_scale_ptr = reinterpret_cast<float*>(kv_block_array.getKScalePtr(batch_idx, dst_kv_seq_idx));
+                float* v_scale_ptr = reinterpret_cast<float*>(kv_block_array.getVScalePtr(batch_idx, dst_kv_seq_idx));
+                const int inScaleIdx = kv_block_array.getKVScaleLocalIdx(dst_kv_seq_idx, head_idx);
 
+                __shared__ float s_max[2];
+                s_max[0] = float(1 << (8 - 1));
+                s_max[1] = float(1 << (8 - 1));
 #pragma unroll
-            for (int vec_i = 0; vec_i < vec_size; vec_i++) {
-                const int inKBlockIdx = offset_kv_block_array.getKLocalIdx<KvCacheDataType::BASE>(
-                    dst_kv_seq_idx, head_idx, size_per_head, tidx * vec_size + vec_i);
-                k_cache[inKBlockIdx] = reinterpret_cast<T*>(&k)[vec_i];
-
-                const int inVBlockIdx = offset_kv_block_array.getVLocalIdx(
-                    dst_kv_seq_idx, head_idx, size_per_head, tidx * vec_size + vec_i);
-                v_cache[inVBlockIdx] = reinterpret_cast<T*>(&v)[vec_i];
+                for (int vec_i = 0; vec_i < vec_size; vec_i++) {
+                    const int inKBlockIdx = kv_block_array.getKLocalIdx<KvCacheDataType::FP8>(
+                        dst_kv_seq_idx, head_idx, size_per_head, tidx * vec_size + vec_i);
+
+                    const int inVBlockIdx = kv_block_array.getVLocalIdx(
+                        dst_kv_seq_idx, head_idx, size_per_head, tidx * vec_size + vec_i);
+
+                    k_cache[inKBlockIdx] = Tcache(float(reinterpret_cast<T*>(&k)[vec_i]) * (float(1 << (8 - 1)) / s_max[0]));
+                    v_cache[inVBlockIdx] = Tcache(float(reinterpret_cast<T*>(&v)[vec_i]) * (float(1 << (8 - 1)) / s_max[1]));
+                }
+                if (tidx == 0) {
+                    *reinterpret_cast<float*>(&k_scale_ptr[inScaleIdx]) = s_max[0] / float(1 << (8 - 1));
+                    *reinterpret_cast<float*>(&v_scale_ptr[inScaleIdx]) = s_max[1] / float(1 << (8 - 1));
+                }
+            } else {
+#pragma unroll
+                for (int vec_i = 0; vec_i < vec_size; vec_i++) {
+                    const int inKBlockIdx = kv_block_array.getKLocalIdx<KvCacheDataType::BASE>(
+                        dst_kv_seq_idx, head_idx, size_per_head, tidx * vec_size + vec_i);
+                    k_cache[inKBlockIdx] = reinterpret_cast<T*>(&k)[vec_i];
+
+                    const int inVBlockIdx = kv_block_array.getVLocalIdx(
+                        dst_kv_seq_idx, head_idx, size_per_head, tidx * vec_size + vec_i);
+                    v_cache[inVBlockIdx] = reinterpret_cast<T*>(&v)[vec_i];
+                }
             }
         }
     }
@@ -3324,7 +3376,8 @@ void invokeAddFusedQKVBiasTransposeDecodeV1(T* q_buf
                                                 store_qkv,
                                                 store_q,
                                                 store_kv,
-                                                store_cache);
+                                                store_cache,
+                                                cos_sin_cache);
             });
         });
     });
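
Note: taken together, the decode-path hunks thread a precomputed rotary-embedding table through the call chain: invokeAddFusedQKVBiasTransposeDecodeV1 now forwards cos_sin_cache into the kernel, the kernel hands it to the rope application alongside HANDLE_KV, and ROCmAttentionOp.cc supplies it from rope_cache instead of the old nullptr placeholder. The diff does not show the table's layout; the builder below is an assumption-labeled sketch only (one cos/sin pair per position and rotary-dim pair, flattened row-major):

```cpp
#include <cmath>
#include <vector>

// Host stand-in for CUDA/HIP float2.
struct float2h { float x, y; };

// Hypothetical builder for a rope cos/sin cache; the diff only shows the
// table being passed as `const float2* cos_sin_cache`, so the base, layout,
// and indexing here are assumptions.
std::vector<float2h> buildRopeCache(int max_pos, int rot_dim, float base = 10000.f) {
    std::vector<float2h> cache(size_t(max_pos) * (rot_dim / 2));
    for (int pos = 0; pos < max_pos; ++pos) {
        for (int i = 0; i < rot_dim / 2; ++i) {
            float inv_freq = std::pow(base, -2.0f * i / rot_dim);
            float angle    = pos * inv_freq;
            cache[size_t(pos) * (rot_dim / 2) + i] = {std::cos(angle), std::sin(angle)};
        }
    }
    return cache;
}
```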
