From d490c3f63a300fbcfa98d1318a53e713d3e780b4 Mon Sep 17 00:00:00 2001 From: Amir Klein <203507526+amirkl94@users.noreply.github.com> Date: Thu, 23 Oct 2025 14:30:52 +0300 Subject: [PATCH 1/7] Feature: Add support for L40 FusedMoE in cutlass path Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com> --- .../cutlass_kernels/cutlass_heuristic.cpp | 9 +++++- .../moe_gemm/moe_gemm_template_dispatch.h | 31 +++++++++---------- .../moe_tma_warp_specialized_traits.h | 6 ++-- flashinfer/fused_moe/core.py | 3 ++ flashinfer/jit/core.py | 4 +++ flashinfer/jit/fused_moe.py | 18 ++++++++++- 6 files changed, 50 insertions(+), 21 deletions(-) diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp index 80b1928437..089bf8c3e2 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp @@ -158,7 +158,14 @@ std::vector get_candidate_tiles( CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64}; case CutlassGemmType::Fp8: if (config_type_param & CutlassGemmConfig::GROUPED_GEMM) { - if (sm == 89 || sm >= 120) { + if (sm == 89) { + return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64, + CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64}; + } else if (sm >= 120) { return {CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128, CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h index 110c517605..815d8f1388 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h @@ -689,27 +689,24 @@ void MoeGemmRunner::dispatchToArch( #endif } else if (sm_ >= 80 && sm_ < 90) { #ifdef ENABLE_FP4 - if constexpr (!std::is_same_v) { - if constexpr (use_fp8 || use_w4afp8) { + if constexpr (std::is_same_v) { + TLLM_THROW("FP4 data type is not supported on SM < 90"); + } +#endif + if constexpr (use_fp8 || use_w4afp8) { #if defined(ENABLE_FP8) - static_assert(!std::is_same_v && - !std::is_same_v, - "FP8 GEMM Output not supported"); + static_assert( + !std::is_same_v && !std::is_same_v, + "FP8 GEMM Output not supported"); #endif - TLLM_CHECK_WITH_INFO(sm_ == 89, - "For sm >= 80 and < 90, fp8 is only supported with sm == 89"); - dispatchMoeGemmToCutlass( - inputs, multi_processor_count_); - } else { - dispatchMoeGemmToCutlass( - inputs, multi_processor_count_); - } + + TLLM_CHECK_WITH_INFO(sm_ == 89, "For sm >= 80 and < 90, fp8 is only supported with sm == 89"); + dispatchMoeGemmToCutlass( + inputs, multi_processor_count_); } else { - TLLM_THROW("FP4 data type is not supported on SM < 90"); + dispatchMoeGemmToCutlass( + inputs, multi_processor_count_); } -#else - TLLM_THROW("FP4 data type is not supported on SM < 90"); -#endif } else if (sm_ >= 90) { // For SM120+ FP8 MoE, redirect to SM89 (Ada) FP8 kernel implementations. 
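      // Note (descriptive, not from the original patch text): this redirect reuses the same
      // ampere-style FP8 grouped-GEMM kernels that the new sm_ == 89 (L40 / Ada) branch
      // above dispatches to, so SM120+ FP8 MoE and Ada share one kernel set here.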
if constexpr (use_fp8) { diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h index d2f742d3d0..c879599c3e 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h @@ -32,7 +32,8 @@ template constexpr bool isValidSM120MOESpecialisation() { -#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) // TODO Is there a better choice +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) && \ + defined(ENABLED_FP4) // TODO Is there a better choice return cutlass::platform::is_same::value && cutlass::platform::is_same::value && cutlass::platform::is_same::value && @@ -47,7 +48,8 @@ template constexpr bool isValidBlackwellMOESpecialisation() { -#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) // TODO Is there a better choice +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && \ + defined(ENABLED_FP4) // TODO Is there a better choice return (cutlass::platform::is_same::value || (cutlass::platform::is_same::value && cutlass::platform::is_same::value)) && diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 2ce2a8b6d0..5f0e33ccf9 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -37,6 +37,7 @@ gen_cutlass_fused_moe_sm120_module, gen_cutlass_fused_moe_sm100_module, gen_cutlass_fused_moe_sm90_module, + gen_cutlass_fused_moe_sm89_module, gen_trtllm_gen_fused_moe_sm100_module, ) from ..utils import ( @@ -285,6 +286,8 @@ def get_cutlass_fused_moe_module(backend: str = "100", use_fast_build: bool = Fa module = gen_cutlass_fused_moe_sm100_module(use_fast_build).build_and_load() elif backend == "90": module = gen_cutlass_fused_moe_sm90_module(use_fast_build).build_and_load() + elif backend == "89": + module = gen_cutlass_fused_moe_sm89_module(use_fast_build).build_and_load() else: raise ValueError(f"Invalid backend: {backend}") diff --git a/flashinfer/jit/core.py b/flashinfer/jit/core.py index 2eec7ac2ce..e7dec73723 100644 --- a/flashinfer/jit/core.py +++ b/flashinfer/jit/core.py @@ -90,6 +90,10 @@ def clear_cache_dir(): "-DFLASHINFER_ENABLE_FP8_E8M0", "-DFLASHINFER_ENABLE_FP4_E2M1", ] +sm89_nvcc_flags = [ + "-gencode=arch=compute_89,code=sm_89", + "-DFLASHINFER_ENABLE_FP8_E8M0", +] sm90a_nvcc_flags = ["-gencode=arch=compute_90a,code=sm_90a"] + common_nvcc_flags sm100a_nvcc_flags = ["-gencode=arch=compute_100a,code=sm_100a"] + common_nvcc_flags sm103a_nvcc_flags = ["-gencode=arch=compute_103a,code=sm_103a"] + common_nvcc_flags diff --git a/flashinfer/jit/fused_moe.py b/flashinfer/jit/fused_moe.py index f0f781ad05..11398fabd9 100644 --- a/flashinfer/jit/fused_moe.py +++ b/flashinfer/jit/fused_moe.py @@ -18,7 +18,13 @@ from . 
import env as jit_env from ..artifacts import ArtifactPath, CheckSumHash -from .core import JitSpec, gen_jit_spec, current_compilation_context, sm90a_nvcc_flags +from .core import ( + JitSpec, + gen_jit_spec, + current_compilation_context, + sm90a_nvcc_flags, + sm89_nvcc_flags, +) from .cpp_ext import is_cuda_version_at_least from .cubin_loader import get_cubin, get_meta_hash from .gemm.cutlass.generate_kernels import generate_gemm_operations @@ -71,6 +77,16 @@ def gen_cutlass_fused_moe_sm90_module(use_fast_build: bool = False) -> JitSpec: return gen_cutlass_fused_moe_module(nvcc_flags, "90", use_fast_build) +def gen_cutlass_fused_moe_sm89_module(use_fast_build: bool = False) -> JitSpec: + nvcc_flags = sm89_nvcc_flags + [ + "-DENABLE_BF16", + "-DENABLE_FP8", + "-DENABLE_FP8_BLOCK_SCALE" if is_cuda_version_at_least("12.8") else "", + "-DUSING_OSS_CUTLASS_MOE_GEMM", + ] + return gen_cutlass_fused_moe_module(nvcc_flags, "89", use_fast_build) + + def gen_cutlass_fused_moe_module( nvcc_flags: List[str], device_arch: str, use_fast_build: bool = False ) -> JitSpec: From 50a1bb09e02393c0afd8acb02c153312558edc6e Mon Sep 17 00:00:00 2001 From: Amir Klein <203507526+amirkl94@users.noreply.github.com> Date: Thu, 23 Oct 2025 21:32:59 +0300 Subject: [PATCH 2/7] Fix wrong env Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com> --- .../moe_gemm/moe_tma_warp_specialized_traits.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h index c879599c3e..4b803cdad6 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h @@ -33,7 +33,7 @@ template constexpr bool isValidSM120MOESpecialisation() { #if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) && \ - defined(ENABLED_FP4) // TODO Is there a better choice + defined(ENABLE_FP4) // TODO Is there a better choice return cutlass::platform::is_same::value && cutlass::platform::is_same::value && cutlass::platform::is_same::value && @@ -49,7 +49,7 @@ template constexpr bool isValidBlackwellMOESpecialisation() { #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && \ - defined(ENABLED_FP4) // TODO Is there a better choice + defined(ENABLE_FP4) // TODO Is there a better choice return (cutlass::platform::is_same::value || (cutlass::platform::is_same::value && cutlass::platform::is_same::value)) && From 2302b53c4ca169e096e2354627f6bce0ab928de3 Mon Sep 17 00:00:00 2001 From: Amir Klein <203507526+amirkl94@users.noreply.github.com> Date: Mon, 27 Oct 2025 16:36:47 +0200 Subject: [PATCH 3/7] Fix compilation issue Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com> --- .../moe_gemm/moe_gemm_template_dispatch.h | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h index 815d8f1388..16b39246cb 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h @@ -688,24 +688,27 @@ void MoeGemmRunner::dispatchToArch( TLLM_THROW("FP4 data type is 
not supported on SM < 90"); #endif } else if (sm_ >= 80 && sm_ < 90) { -#ifdef ENABLE_FP4 - if constexpr (std::is_same_v) { - TLLM_THROW("FP4 data type is not supported on SM < 90"); - } -#endif if constexpr (use_fp8 || use_w4afp8) { #if defined(ENABLE_FP8) static_assert( !std::is_same_v && !std::is_same_v, "FP8 GEMM Output not supported"); #endif - TLLM_CHECK_WITH_INFO(sm_ == 89, "For sm >= 80 and < 90, fp8 is only supported with sm == 89"); dispatchMoeGemmToCutlass( inputs, multi_processor_count_); } else { +#ifdef ENABLE_FP4 + if constexpr (std::is_same_v) { + TLLM_THROW("FP4 data type is not supported on SM < 90"); + } else { + dispatchMoeGemmToCutlass( + inputs, multi_processor_count_); + } +#else dispatchMoeGemmToCutlass( inputs, multi_processor_count_); +#endif } } else if (sm_ >= 90) { // For SM120+ FP8 MoE, redirect to SM89 (Ada) FP8 kernel implementations. From 19785b9e9135b883079879af6c9fdb8439d0942f Mon Sep 17 00:00:00 2001 From: Amir Klein <203507526+amirkl94@users.noreply.github.com> Date: Tue, 28 Oct 2025 10:19:25 +0200 Subject: [PATCH 4/7] Fine grain ifdefs Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com> --- .../moe_gemm/moe_tma_warp_specialized_traits.h | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h index 4b803cdad6..2846249485 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h @@ -48,12 +48,13 @@ template constexpr bool isValidBlackwellMOESpecialisation() { -#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && \ - defined(ENABLE_FP4) // TODO Is there a better choice - return (cutlass::platform::is_same::value || - (cutlass::platform::is_same::value && - cutlass::platform::is_same::value)) && - cutlass::platform::is_same::value && +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) // TODO Is there a better choice + return (cutlass::platform::is_same::value +#if defined(ENABLE_FP4) + || (cutlass::platform::is_same::value && + cutlass::platform::is_same::value)) +#endif + && cutlass::platform::is_same::value && Fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE; #else return false; // CUTLASS_ARCH_MMA_SM100_SUPPORTED is set when Blackwell kernels are enabled From 6462feb4341004a0e31272cec1af2304c2697e2e Mon Sep 17 00:00:00 2001 From: Amir Klein <203507526+amirkl94@users.noreply.github.com> Date: Tue, 28 Oct 2025 10:19:42 +0200 Subject: [PATCH 5/7] Change tile config order instead of removing so that default case passes on L40 Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com> --- .../kernels/cutlass_kernels/cutlass_heuristic.cpp | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp index 089bf8c3e2..8fc256ba31 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp @@ -158,21 +158,14 @@ std::vector get_candidate_tiles( CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64}; case CutlassGemmType::Fp8: if (config_type_param & 
CutlassGemmConfig::GROUPED_GEMM) { - if (sm == 89) { + if (sm == 89 || sm == 120) { return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64, CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64, CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, - CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64}; - } else if (sm >= 120) { - return {CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128, - CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, - CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, - CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64, - CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64, - CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, - CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64}; + CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128}; } else { // no valid ampere style fp8 configs for sm90 return {}; From 31c11e00fa8f7f9b871d66a869725074c9df8432 Mon Sep 17 00:00:00 2001 From: Amir Klein <203507526+amirkl94@users.noreply.github.com> Date: Tue, 28 Oct 2025 10:23:26 +0200 Subject: [PATCH 6/7] Change autotuner logging Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com> --- flashinfer/autotuner.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/flashinfer/autotuner.py b/flashinfer/autotuner.py index 902f659a1d..f8af220916 100644 --- a/flashinfer/autotuner.py +++ b/flashinfer/autotuner.py @@ -482,8 +482,12 @@ def choose_one( ) except Exception as e: shapes = self._get_input_sizes(tensors) + logger.warning( + f"[Autotuner]: Skipping tactic {r} {tac}, due to failure while profiling." + ) - logger.error( + # Log stacktrace as debug to not spam log + logger.debug( f"[Autotuner]: Failed when profiling {r} {tac}, shapes={shapes}. Error occurred: {e}" ) From ddc2ac26a9938ff097614c551df8062eb1488582 Mon Sep 17 00:00:00 2001 From: Amir Klein <203507526+amirkl94@users.noreply.github.com> Date: Tue, 28 Oct 2025 10:32:58 +0200 Subject: [PATCH 7/7] ifdefs Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com> --- .../moe_gemm/moe_tma_warp_specialized_traits.h | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h index 2846249485..d0bcbb978d 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h @@ -49,12 +49,15 @@ template constexpr bool isValidBlackwellMOESpecialisation() { #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) // TODO Is there a better choice - return (cutlass::platform::is_same::value + return (cutlass::platform::is_same::value || #if defined(ENABLE_FP4) - || (cutlass::platform::is_same::value && - cutlass::platform::is_same::value)) + (cutlass::platform::is_same::value && + cutlass::platform::is_same::value) +#else + false #endif - && cutlass::platform::is_same::value && + ) && + cutlass::platform::is_same::value && Fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE; #else return false; // CUTLASS_ARCH_MMA_SM100_SUPPORTED is set when Blackwell kernels are enabled
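  // (Descriptive note: with the ENABLE_FP4 #if/#else above, builds without FP4 support see the
  //  guarded leg of the disjunction as `false`, keeping the parenthesised expression
  //  well-formed while the non-FP4 Blackwell specialisation can still be selected.)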