fp8 support #54

Open · wants to merge 31 commits into main
2 changes: 1 addition & 1 deletion README.md
@@ -3,7 +3,7 @@
FlashMLA is an efficient MLA decoding kernel for Hopper GPUs, optimized for variable-length sequences serving.

Currently released:
- BF16, FP16
- BF16, FP16, E4M3
- Paged kvcache with block size of 64

## Quick start
40 changes: 33 additions & 7 deletions csrc/flash_api.cpp
@@ -68,16 +68,19 @@ mha_fwd_kvcache_mla(
const float softmax_scale,
bool is_causal,
const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
const at::Tensor &num_splits // batch_size + 1
const at::Tensor &num_splits, // batch_size + 1
c10::optional<const at::Tensor> &descale_q_, // 1 (per-tensor descale for q)
c10::optional<const at::Tensor> &descale_k_ // 1 (per-tensor descale for k)
) {
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
TORCH_CHECK(is_sm90);

at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;

auto q_dtype = q.dtype();
TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
auto q_dtype = q.scalar_type();
TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf || q_dtype == torch::kFloat8_e4m3fn);
TORCH_CHECK(kcache.scalar_type() == q_dtype, "query and key must have the same dtype");

CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);

@@ -104,6 +107,20 @@ mha_fwd_kvcache_mla(
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

if (q_dtype == torch::kFloat8_e4m3fn) {
TORCH_CHECK(descale_q_.has_value() && descale_k_.has_value(), "descale is required when input dtype is fp8");
auto descale_q = descale_q_.value();
auto descale_k = descale_k_.value();
CHECK_DEVICE(descale_q);
CHECK_DEVICE(descale_k);
TORCH_CHECK(descale_q.stride(-1) == 1);
TORCH_CHECK(descale_k.stride(-1) == 1);
TORCH_CHECK(descale_q.dtype() == torch::kFloat);
TORCH_CHECK(descale_k.dtype() == torch::kFloat);
CHECK_SHAPE(descale_q, 1);
CHECK_SHAPE(descale_k, 1);
}

if (seqlen_q_ori == 1) { is_causal = false; }

const int ngroups = num_heads_ori / num_heads_k;
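Note that the new descale_q / descale_k arguments are one-element fp32 tensors that the caller has to produce itself; this PR does not add a quantization helper. A minimal caller-side sketch of per-tensor E4M3 quantization with ATen might look like the following (the helper name `quantize_e4m3`, the clamp epsilon, and the use of 448.0 as the largest finite e4m3 value are assumptions, not part of this diff; it also requires a PyTorch build with `float8_e4m3fn` support):

```cpp
#include <utility>
#include <ATen/ATen.h>
#include <torch/torch.h>

// Hypothetical caller-side helper (not part of this PR): quantize a tensor to
// float8_e4m3fn with a single per-tensor scale and return the matching descale
// tensor that mha_fwd_kvcache_mla expects (fp32, shape (1,)).
std::pair<at::Tensor, at::Tensor> quantize_e4m3(const at::Tensor &x) {
    constexpr double kE4M3Max = 448.0;  // largest finite e4m3 value
    // Per-tensor absolute maximum, kept in fp32 for the scale math.
    at::Tensor amax = x.abs().amax().to(torch::kFloat);
    at::Tensor scale = amax.clamp_min(1e-12).reciprocal() * kE4M3Max;
    at::Tensor x_fp8 = (x.to(torch::kFloat) * scale).to(torch::kFloat8_e4m3fn);
    // descale = 1 / scale = amax / 448, reshaped to satisfy CHECK_SHAPE(descale, 1).
    at::Tensor descale = scale.reciprocal().reshape({1});
    return {x_fp8, descale};
}

// Usage sketch (names hypothetical):
//   auto [q_fp8, descale_q] = quantize_e4m3(q_bf16);
//   auto [k_fp8, descale_k] = quantize_e4m3(kcache_bf16);
```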
@@ -127,7 +144,8 @@ mha_fwd_kvcache_mla(
at::cuda::CUDAGuard device_guard{(char)q.get_device()};

auto opts = q.options();
at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts);
auto out_type = (q_dtype == torch::kFloat8_e4m3fn) ? torch::kBFloat16 : q_dtype;
at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(out_type));
at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));

Flash_fwd_mla_params params = {};
@@ -167,6 +185,11 @@ mha_fwd_kvcache_mla(
params.block_table_batch_stride = block_table.stride(0);
params.page_block_size = page_block_size;

if (q_dtype == torch::kFloat8_e4m3fn) {
params.descale_q_ptr = reinterpret_cast<float*>(descale_q_.value().data_ptr());
params.descale_k_ptr = reinterpret_cast<float*>(descale_k_.value().data_ptr());
}

TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32");
TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize);
CHECK_DEVICE(tile_scheduler_metadata);
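The `params.descale_q_ptr` / `params.descale_k_ptr` assignments in this hunk imply that `Flash_fwd_mla_params` gained two pointer fields. The struct definition is not part of this diff, but it presumably now carries something along these lines (field names taken from the assignments above; the null defaults are an assumption):

```cpp
// Presumed additions to Flash_fwd_mla_params (struct not shown in this diff).
struct Flash_fwd_mla_params {
    // ... existing fields (tensor pointers, strides, sizes) elided ...
    float *descale_q_ptr = nullptr;  // per-tensor dequantization factor for Q (fp8 path only)
    float *descale_k_ptr = nullptr;  // per-tensor dequantization factor for K (fp8 path only)
};
```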
@@ -186,15 +209,18 @@ mha_fwd_kvcache_mla(
auto stream = at::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK(head_size == 576);


if (q_dtype == torch::kBFloat16) {
run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, stream);
run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, cutlass::bfloat16_t, 576>(params, stream);
}
#ifndef FLASH_MLA_DISABLE_FP16
else if (q_dtype == torch::kHalf) {
run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(params, stream);
run_mha_fwd_splitkv_mla<cutlass::half_t, cutlass::half_t, 576>(params, stream);
}
#endif
else {
else if (q_dtype == torch::kFloat8_e4m3fn) {
run_mha_fwd_splitkv_mla<cutlass::float_e4m3_t, cutlass::bfloat16_t, 576>(params, stream);
} else {
TORCH_CHECK(false, "Unsupported tensor dtype for query");
}

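The dispatch above now selects both an input and an output element type: bf16 and fp16 keep a matching output type, while e4m3 input is paired with bf16 output (matching the bf16 `out` tensor allocated earlier). The updated declaration of `run_mha_fwd_splitkv_mla` in flash_fwd_mla_kernel.h is not visible in this diff; inferred from the explicit instantiations below, it presumably looks roughly like this (template parameter names are illustrative):

```cpp
#include <cuda_runtime.h>  // cudaStream_t

struct Flash_fwd_mla_params;  // defined in the real header

// Presumed shape of the updated kernel launcher declaration, inferred from the
// explicit instantiations in this PR:
//   InputT  - element type of Q and the paged KV cache (bf16 / fp16 / e4m3)
//   OutputT - element type of the attention output (bf16 for the fp8 path)
template <typename InputT, typename OutputT, int HeadDim>
void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params &params, cudaStream_t stream);
```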
2 changes: 1 addition & 1 deletion csrc/flash_fwd_mla_bf16_sm90.cu
@@ -1,3 +1,3 @@
#include "flash_fwd_mla_kernel.h"

template void run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);
template void run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, cutlass::bfloat16_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);
2 changes: 1 addition & 1 deletion csrc/flash_fwd_mla_fp16_sm90.cu
@@ -1,3 +1,3 @@
#include "flash_fwd_mla_kernel.h"

template void run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);
template void run_mha_fwd_splitkv_mla<cutlass::half_t, cutlass::half_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);
3 changes: 3 additions & 0 deletions csrc/flash_fwd_mla_fp8_sm90.cu
@@ -0,0 +1,3 @@
#include "flash_fwd_mla_kernel.h"

template void run_mha_fwd_splitkv_mla<cutlass::float_e4m3_t, cutlass::bfloat16_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);