
Commit

add env
chenhongmin.will committed Mar 1, 2025
1 parent 6199b0b commit 1eddaa3
Showing 2 changed files with 7 additions and 2 deletions.
5 changes: 4 additions & 1 deletion csrc/flash_api.cpp
@@ -218,9 +218,12 @@ mha_fwd_kvcache_mla(
         run_mha_fwd_splitkv_mla<cutlass::half_t, cutlass::half_t, 576>(params, stream);
     }
 #endif
+#ifndef FLASH_MLA_DISABLE_FP8
     else if (q_dtype == torch::kFloat8_e4m3fn) {
         run_mha_fwd_splitkv_mla<cutlass::float_e4m3_t, cutlass::bfloat16_t, 576>(params, stream);
-    } else {
+    }
+#endif
+    else {
         TORCH_CHECK(false, "Unsupported tensor dtype for query");
     }

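Note on the hunk above: the old "} else {" is split so that the else { TORCH_CHECK(...) } fallback ends up outside the #ifndef FLASH_MLA_DISABLE_FP8 block; otherwise the else would be left dangling and fail to compile whenever the FP8 branch is preprocessed away. For the guard to take effect, FLASH_MLA_DISABLE_FP8 must also be defined as a preprocessor macro at build time. That wiring is not shown in this diff; the sketch below is only an assumption about how setup.py could forward the environment flag to the compilers.

    # Hypothetical sketch, not part of this commit: forward the env flag
    # to the C++/CUDA preprocessor so the #ifndef guard in flash_api.cpp
    # actually strips the FP8 dispatch branch.
    import os

    DISABLE_FP8 = os.getenv("FLASH_MLA_DISABLE_FP8", "FALSE") == "TRUE"

    cxx_flags = ["-O3", "-std=c++17"]
    nvcc_flags = ["-O3", "-std=c++17"]
    if DISABLE_FP8:
        # Defines the macro tested by #ifndef FLASH_MLA_DISABLE_FP8 above.
        cxx_flags.append("-DFLASH_MLA_DISABLE_FP8")
        nvcc_flags.append("-DFLASH_MLA_DISABLE_FP8")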
4 changes: 3 additions & 1 deletion setup.py
@@ -12,6 +12,7 @@
 )
 
 DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") == "TRUE"
+DISABLE_FP8 = os.getenv("FLASH_MLA_DISABLE_FP8", "FALSE") == "TRUE"
 
 
 def append_nvcc_threads(nvcc_extra_args):
@@ -23,12 +24,13 @@ def get_sources():
     sources = [
         "csrc/flash_api.cpp",
         "csrc/flash_fwd_mla_bf16_sm90.cu",
-        "csrc/flash_fwd_mla_fp8_sm90.cu",
         "csrc/flash_fwd_mla_metadata.cu",
     ]
 
     if not DISABLE_FP16:
         sources.append("csrc/flash_fwd_mla_fp16_sm90.cu")
+    if not DISABLE_FP8:
+        sources.append("csrc/flash_fwd_mla_fp8_sm90.cu")
 
     return sources

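With both pieces in place, the FP8 kernel can be excluded at install time, for example with a build along the lines of FLASH_MLA_DISABLE_FP8=TRUE pip install . (the exact build command is an assumption). A minimal illustration of the gating that get_sources() now performs:

    # Illustrative only; mirrors the source gating added above.
    import os

    os.environ["FLASH_MLA_DISABLE_FP8"] = "TRUE"  # as the user would set it at build time

    DISABLE_FP8 = os.getenv("FLASH_MLA_DISABLE_FP8", "FALSE") == "TRUE"

    sources = ["csrc/flash_api.cpp"]
    if not DISABLE_FP8:
        sources.append("csrc/flash_fwd_mla_fp8_sm90.cu")

    print(sources)  # ['csrc/flash_api.cpp'] -- the FP8 kernel is skipped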
