@@ -158,7 +158,14 @@ std::vector<CutlassTileConfig> get_candidate_tiles(
CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64};
case CutlassGemmType::Fp8:
if (config_type_param & CutlassGemmConfig::GROUPED_GEMM) {
if (sm == 89 || sm >= 120) {
if (sm == 89) {
return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64,
CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64,
CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64};

Are there any SM89 GPUs that can support the CtaShape16x256x128_WarpShape16x64x128, or is this an SM120 addition?
Having a small M value like this helps with low-latency cases, so I'd want to understand why it's not supported before disabling it.

At the very least, can you leave a comment saying what the difference between the two lists is, so people don't have to compare the items manually?

Contributor Author

> Are there any SM89 GPUs that can support the CtaShape16x256x128_WarpShape16x64x128, or is this an SM120 addition? Having a small M value like this helps with low-latency cases, so I'd want to understand why it's not supported before disabling it.

This is a removal from the sm89 path. When I tested it on an L40 GPU I got `Assertion failed: GPU lacks the shared memory resources to run GroupedGEMM kernel`.
It might be that on other sm89 GPUs it would pass; the main issue is that this was the default tactic chosen when trying to use FusedMoE. I believe that moving it to be the last option would also fix my issue.

Contributor Author (@amirkl94, Oct 27, 2025)

I tried moving this tile config to be the last one, and now the default tactic won't fail on L40. The issue is that if the autotuner is on, the tactics that use this tile config will report an error with a stack trace, which looks bad.
@yzh119 do you think it would be OK to change the errors that happen in the autotuner to debug logs? Otherwise users will get spammed with error messages when they run autotuning on L40 FusedMoE.

I think the autotuner should still output warnings, but just make them say "Skipping tactic x due to error. This tactic may not be supported by the current GPU architecture."
That said, I know there is a difference of opinion on whether we should proactively filter tactics as you have done here; the argument is that we should do the due diligence to determine which tactics are supported, so that we can raise an error when a tactic fails that shouldn't. So I can see either side.
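A minimal sketch of the suggested logging behavior (the `Tactic` struct, `profile_tactic`, and `run_autotuning` below are hypothetical stand-ins, not the actual flashinfer autotuner API): catch the per-tactic failure, emit a short warning, and keep the full failure details at debug verbosity.

```cpp
#include <exception>
#include <iostream>
#include <string>
#include <vector>

// Hypothetical sketch only; the real autotuner plumbing differs.
struct Tactic {
  std::string name;
};

// Assumed to throw (e.g. "GPU lacks the shared memory resources to run
// GroupedGEMM kernel") when the current GPU cannot run the tactic.
void profile_tactic(const Tactic& tactic);

void run_autotuning(const std::vector<Tactic>& candidates, bool debug_enabled) {
  for (const auto& tactic : candidates) {
    try {
      profile_tactic(tactic);
    } catch (const std::exception& e) {
      // Short, user-facing message at warning level instead of an error.
      std::cerr << "[WARNING] Skipping tactic " << tactic.name
                << " due to error. This tactic may not be supported by the "
                   "current GPU architecture.\n";
      // Full failure details (assertion text, stack trace) only at debug level.
      if (debug_enabled) {
        std::cerr << "[DEBUG] " << e.what() << "\n";
      }
    }
  }
}
```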

Contributor Author

I agree. Since this PR is somewhat time critical, I'd rather move the faulty tile config to be the last one so as not to remove it, change the log to a warning as you suggested, and maybe keep the stack trace as a debug log.
I'll also open an issue to take a deeper look into whether this specific tile config is relevant for sm89 or not, as I think it might take me some time.
@djns99, @yzh119 does that sound OK to you?

I'm happy with that resolution
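For readers comparing against the diff above: a rough sketch of the resolution agreed on here (keep the tile config on the sm89 path but try it last). This is illustrative only, not the change as committed in this revision.

```cpp
// Illustrative sketch of the agreed direction: keep
// CtaShape16x256x128_WarpShape16x64x128 on sm89, but order it last so GPUs
// without enough shared memory fall back to the earlier tactics first.
if (sm == 89) {
  return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
          CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
          CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64,
          CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64,
          CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64,
          CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64,
          // Tried last: may still fail on some sm89 parts (e.g. L40) due to
          // shared memory limits, in which case the autotuner skips it.
          CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128};
}
```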

} else if (sm >= 120) {
return {CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128,
CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
@@ -689,27 +689,24 @@ void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::dispatchToArch(
#endif
} else if (sm_ >= 80 && sm_ < 90) {
#ifdef ENABLE_FP4
if constexpr (!std::is_same_v<WeightType, __nv_fp4_e2m1>) {
if constexpr (use_fp8 || use_w4afp8) {
if constexpr (std::is_same_v<WeightType, __nv_fp4_e2m1>) {
TLLM_THROW("FP4 data type is not supported on SM < 90");
}
#endif
if constexpr (use_fp8 || use_w4afp8) {
#if defined(ENABLE_FP8)
static_assert(!std::is_same_v<OutputType, __nv_fp8_e4m3> &&
!std::is_same_v<OutputType, __nv_fp8_e5m2>,
"FP8 GEMM Output not supported");
static_assert(
!std::is_same_v<OutputType, __nv_fp8_e4m3> && !std::is_same_v<OutputType, __nv_fp8_e5m2>,
"FP8 GEMM Output not supported");
#endif
TLLM_CHECK_WITH_INFO(sm_ == 89,
"For sm >= 80 and < 90, fp8 is only supported with sm == 89");
dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm89, EpilogueTag>(
inputs, multi_processor_count_);
} else {
dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm80, EpilogueTag>(
inputs, multi_processor_count_);
}

TLLM_CHECK_WITH_INFO(sm_ == 89, "For sm >= 80 and < 90, fp8 is only supported with sm == 89");
dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm89, EpilogueTag>(
inputs, multi_processor_count_);
} else {
TLLM_THROW("FP4 data type is not supported on SM < 90");
dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm80, EpilogueTag>(
inputs, multi_processor_count_);
}
#else
TLLM_THROW("FP4 data type is not supported on SM < 90");
#endif
} else if (sm_ >= 90) {
// For SM120+ FP8 MoE, redirect to SM89 (Ada) FP8 kernel implementations.
if constexpr (use_fp8) {
@@ -32,7 +32,8 @@ template <typename T, typename WeightType,
TmaWarpSpecializedGroupedGemmInput::EpilogueFusion Fusion =
TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE>
constexpr bool isValidSM120MOESpecialisation() {
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) // TODO Is there a better choice
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) && \
defined(ENABLE_FP4) // TODO Is there a better choice
return cutlass::platform::is_same<T, __nv_fp4_e2m1>::value &&
cutlass::platform::is_same<T, WeightType>::value &&
cutlass::platform::is_same<EpilogueTag, cutlass_extensions::EpilogueOpDefault>::value &&
@@ -47,7 +48,8 @@ template <typename T, typename WeightType,
TmaWarpSpecializedGroupedGemmInput::EpilogueFusion Fusion =
TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE>
constexpr bool isValidBlackwellMOESpecialisation() {
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) // TODO Is there a better choice
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && \
defined(ENABLE_FP4) // TODO Is there a better choice
return (cutlass::platform::is_same<T, WeightType>::value ||
(cutlass::platform::is_same<T, __nv_fp8_e4m3>::value &&
cutlass::platform::is_same<WeightType, __nv_fp4_e2m1>::value)) &&
3 changes: 3 additions & 0 deletions flashinfer/fused_moe/core.py
@@ -37,6 +37,7 @@
gen_cutlass_fused_moe_sm120_module,
gen_cutlass_fused_moe_sm100_module,
gen_cutlass_fused_moe_sm90_module,
gen_cutlass_fused_moe_sm89_module,
gen_trtllm_gen_fused_moe_sm100_module,
)
from ..utils import (
@@ -285,6 +286,8 @@ def get_cutlass_fused_moe_module(backend: str = "100", use_fast_build: bool = Fa
module = gen_cutlass_fused_moe_sm100_module(use_fast_build).build_and_load()
elif backend == "90":
module = gen_cutlass_fused_moe_sm90_module(use_fast_build).build_and_load()
elif backend == "89":
module = gen_cutlass_fused_moe_sm89_module(use_fast_build).build_and_load()
else:
raise ValueError(f"Invalid backend: {backend}")

4 changes: 4 additions & 0 deletions flashinfer/jit/core.py
@@ -90,6 +90,10 @@ def clear_cache_dir():
"-DFLASHINFER_ENABLE_FP8_E8M0",
"-DFLASHINFER_ENABLE_FP4_E2M1",
]
sm89_nvcc_flags = [
"-gencode=arch=compute_89,code=sm_89",
"-DFLASHINFER_ENABLE_FP8_E8M0",
]
sm90a_nvcc_flags = ["-gencode=arch=compute_90a,code=sm_90a"] + common_nvcc_flags
sm100a_nvcc_flags = ["-gencode=arch=compute_100a,code=sm_100a"] + common_nvcc_flags
sm103a_nvcc_flags = ["-gencode=arch=compute_103a,code=sm_103a"] + common_nvcc_flags
18 changes: 17 additions & 1 deletion flashinfer/jit/fused_moe.py
@@ -18,7 +18,13 @@

from . import env as jit_env
from ..artifacts import ArtifactPath, CheckSumHash
from .core import JitSpec, gen_jit_spec, current_compilation_context, sm90a_nvcc_flags
from .core import (
JitSpec,
gen_jit_spec,
current_compilation_context,
sm90a_nvcc_flags,
sm89_nvcc_flags,
)
from .cpp_ext import is_cuda_version_at_least
from .cubin_loader import get_cubin, get_meta_hash
from .gemm.cutlass.generate_kernels import generate_gemm_operations
@@ -71,6 +77,16 @@ def gen_cutlass_fused_moe_sm90_module(use_fast_build: bool = False) -> JitSpec:
return gen_cutlass_fused_moe_module(nvcc_flags, "90", use_fast_build)


def gen_cutlass_fused_moe_sm89_module(use_fast_build: bool = False) -> JitSpec:
nvcc_flags = sm89_nvcc_flags + [
"-DENABLE_BF16",
"-DENABLE_FP8",
"-DENABLE_FP8_BLOCK_SCALE" if is_cuda_version_at_least("12.8") else "",
"-DUSING_OSS_CUTLASS_MOE_GEMM",
]
return gen_cutlass_fused_moe_module(nvcc_flags, "89", use_fast_build)


def gen_cutlass_fused_moe_module(
nvcc_flags: List[str], device_arch: str, use_fast_build: bool = False
) -> JitSpec: