
Commit 159d0a0

Feature: Add support for L40 FusedMoE in cutlass path (#1973)

## 📌 Description

Fixed a few compilation issues for L40, and removed one GEMM tactic for `sm == 89` that crashes due to:

```
Assertion failed: GPU lacks the shared memory resources to run GroupedGEMM kernel
```

## 🧪 Tests

Ran `pytest tests/moe/test_trtllm_cutlass_fused_moe.py` manually on an L40 GPU and verified all tests passed.

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

* **New Features**
  * Official support for the SM89 target: build/JIT flags and a public generation path to target it.
* **Bug Fixes / Compatibility**
  * Clarified FP8/FP4 dispatch: FP8 paths are enabled for SM89; FP4 usage remains gated and now requires explicit enablement.
* **Performance**
  * Adjusted kernel/tile selection order for certain FP8 paths to prefer SM89-optimized options.
* **Chores**
  * Reduced logging severity for failed tactic profiling to warn/debug.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Amir Klein <[email protected]>

1 parent 9ce1af7 · commit 159d0a0
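The tests above were run on an L40. As a quick, illustrative sanity check (standard PyTorch API, not part of this commit) that a device will actually exercise the new SM89 path:

```python
import torch

# L40 / L40S GPUs report compute capability (8, 9), i.e. SM89,
# which is the target this commit adds to the cutlass fused-MoE path.
major, minor = torch.cuda.get_device_capability()
assert (major, minor) == (8, 9), f"Expected an SM89 GPU, got SM{major}{minor}"
```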

File tree

7 files changed: +55 −22 lines changed


csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp

Lines changed: 4 additions & 4 deletions
```diff
@@ -158,14 +158,14 @@ std::vector<CutlassTileConfig> get_candidate_tiles(
               CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64};
     case CutlassGemmType::Fp8:
       if (config_type_param & CutlassGemmConfig::GROUPED_GEMM) {
-        if (sm == 89 || sm >= 120) {
-          return {CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128,
-                  CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
+        if (sm == 89 || sm == 120) {
+          return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
                   CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
                   CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64,
                   CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64,
                   CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64,
-                  CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64};
+                  CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64,
+                  CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128};
       } else {
         // no valid ampere style fp8 configs for sm90
         return {};
```

csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h

Lines changed: 14 additions & 14 deletions
```diff
@@ -688,28 +688,28 @@ void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::dispatchToArch(
     TLLM_THROW("FP4 data type is not supported on SM < 90");
 #endif
   } else if (sm_ >= 80 && sm_ < 90) {
-#ifdef ENABLE_FP4
-    if constexpr (!std::is_same_v<WeightType, __nv_fp4_e2m1>) {
-      if constexpr (use_fp8 || use_w4afp8) {
+    if constexpr (use_fp8 || use_w4afp8) {
 #if defined(ENABLE_FP8)
-        static_assert(!std::is_same_v<OutputType, __nv_fp8_e4m3> &&
-                          !std::is_same_v<OutputType, __nv_fp8_e5m2>,
-                      "FP8 GEMM Output not supported");
+      static_assert(
+          !std::is_same_v<OutputType, __nv_fp8_e4m3> && !std::is_same_v<OutputType, __nv_fp8_e5m2>,
+          "FP8 GEMM Output not supported");
 #endif
-        TLLM_CHECK_WITH_INFO(sm_ == 89,
-                             "For sm >= 80 and < 90, fp8 is only supported with sm == 89");
-        dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm89, EpilogueTag>(
-            inputs, multi_processor_count_);
+      TLLM_CHECK_WITH_INFO(sm_ == 89, "For sm >= 80 and < 90, fp8 is only supported with sm == 89");
+      dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm89, EpilogueTag>(
+          inputs, multi_processor_count_);
+    } else {
+#ifdef ENABLE_FP4
+      if constexpr (std::is_same_v<WeightType, __nv_fp4_e2m1>) {
+        TLLM_THROW("FP4 data type is not supported on SM < 90");
       } else {
         dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm80, EpilogueTag>(
             inputs, multi_processor_count_);
       }
-    } else {
-      TLLM_THROW("FP4 data type is not supported on SM < 90");
-    }
 #else
-    TLLM_THROW("FP4 data type is not supported on SM < 90");
+      dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm80, EpilogueTag>(
+          inputs, multi_processor_count_);
 #endif
+    }
   } else if (sm_ >= 90) {
     // For SM120+ FP8 MoE, redirect to SM89 (Ada) FP8 kernel implementations.
     if constexpr (use_fp8) {
```

csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_tma_warp_specialized_traits.h

Lines changed: 8 additions & 2 deletions
```diff
@@ -32,7 +32,8 @@ template <typename T, typename WeightType,
           TmaWarpSpecializedGroupedGemmInput::EpilogueFusion Fusion =
               TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE>
 constexpr bool isValidSM120MOESpecialisation() {
-#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)  // TODO Is there a better choice
+#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) && \
+    defined(ENABLE_FP4)  // TODO Is there a better choice
   return cutlass::platform::is_same<T, __nv_fp4_e2m1>::value &&
          cutlass::platform::is_same<T, WeightType>::value &&
          cutlass::platform::is_same<EpilogueTag, cutlass_extensions::EpilogueOpDefault>::value &&
@@ -49,8 +50,13 @@ template <typename T, typename WeightType,
 constexpr bool isValidBlackwellMOESpecialisation() {
 #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)  // TODO Is there a better choice
   return (cutlass::platform::is_same<T, WeightType>::value ||
+#if defined(ENABLE_FP4)
          (cutlass::platform::is_same<T, __nv_fp8_e4m3>::value &&
-          cutlass::platform::is_same<WeightType, __nv_fp4_e2m1>::value)) &&
+          cutlass::platform::is_same<WeightType, __nv_fp4_e2m1>::value)
+#else
+          false
+#endif
+              ) &&
          cutlass::platform::is_same<EpilogueTag, cutlass_extensions::EpilogueOpDefault>::value &&
          Fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE;
 #else
```

flashinfer/autotuner.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -482,8 +482,12 @@ def choose_one(
                 )
             except Exception as e:
                 shapes = self._get_input_sizes(tensors)
+                logger.warning(
+                    f"[Autotuner]: Skipping tactic {r} {tac}, due to failure while profiling."
+                )
 
-                logger.error(
+                # Log stacktrace as debug to not spam log
+                logger.debug(
                     f"[Autotuner]: Failed when profiling {r} {tac}, shapes={shapes}. Error occurred: {e}"
                 )
 
```
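For context, a minimal sketch of the logging pattern this diff adopts (illustrative only; `tactics` and `run_tactic` are hypothetical placeholders, not flashinfer APIs): the skip notice is logged at warning level, while the noisy error detail drops to debug level.

```python
import logging

logger = logging.getLogger("autotuner_sketch")


def profile_tactics(tactics, run_tactic):
    """Time each tactic, skipping any that fail to profile."""
    timings = {}
    for tac in tactics:
        try:
            timings[tac] = run_tactic(tac)
        except Exception as e:
            # Short, visible notice that the tactic was skipped.
            logger.warning(f"[Autotuner]: Skipping tactic {tac}, due to failure while profiling.")
            # Full error detail at debug level, so routine failures do not spam the log.
            logger.debug(f"[Autotuner]: Failed when profiling {tac}. Error occurred: {e}")
    return timings
```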

flashinfer/fused_moe/core.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -37,6 +37,7 @@
     gen_cutlass_fused_moe_sm120_module,
     gen_cutlass_fused_moe_sm100_module,
     gen_cutlass_fused_moe_sm90_module,
+    gen_cutlass_fused_moe_sm89_module,
     gen_trtllm_gen_fused_moe_sm100_module,
 )
 from ..utils import (
@@ -285,6 +286,8 @@ def get_cutlass_fused_moe_module(backend: str = "100", use_fast_build: bool = Fa
         module = gen_cutlass_fused_moe_sm100_module(use_fast_build).build_and_load()
     elif backend == "90":
         module = gen_cutlass_fused_moe_sm90_module(use_fast_build).build_and_load()
+    elif backend == "89":
+        module = gen_cutlass_fused_moe_sm89_module(use_fast_build).build_and_load()
     else:
         raise ValueError(f"Invalid backend: {backend}")
 
```
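A hedged usage sketch of the new dispatch branch, assuming the module path shown above (`flashinfer.fused_moe.core`); the capability-to-backend mapping is illustrative and only covers the backends handled by this function.

```python
import torch

from flashinfer.fused_moe.core import get_cutlass_fused_moe_module

# Map the detected compute capability to the backend string expected above;
# on an L40 (compute capability 8.9) this selects the new "89" branch.
major, minor = torch.cuda.get_device_capability()
backend = f"{major}{minor}"

# JIT-compiles on first use and loads the cutlass fused-MoE module for this arch.
# Unsupported capabilities fall through to the ValueError branch shown in the diff.
module = get_cutlass_fused_moe_module(backend=backend)
```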

flashinfer/jit/core.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -90,6 +90,10 @@ def clear_cache_dir():
     "-DFLASHINFER_ENABLE_FP8_E8M0",
     "-DFLASHINFER_ENABLE_FP4_E2M1",
 ]
+sm89_nvcc_flags = [
+    "-gencode=arch=compute_89,code=sm_89",
+    "-DFLASHINFER_ENABLE_FP8_E8M0",
+]
 sm90a_nvcc_flags = ["-gencode=arch=compute_90a,code=sm_90a"] + common_nvcc_flags
 sm100a_nvcc_flags = ["-gencode=arch=compute_100a,code=sm_100a"] + common_nvcc_flags
 sm103a_nvcc_flags = ["-gencode=arch=compute_103a,code=sm_103a"] + common_nvcc_flags
```

flashinfer/jit/fused_moe.py

Lines changed: 17 additions & 1 deletion
```diff
@@ -18,7 +18,13 @@
 
 from . import env as jit_env
 from ..artifacts import ArtifactPath, CheckSumHash
-from .core import JitSpec, gen_jit_spec, current_compilation_context, sm90a_nvcc_flags
+from .core import (
+    JitSpec,
+    gen_jit_spec,
+    current_compilation_context,
+    sm90a_nvcc_flags,
+    sm89_nvcc_flags,
+)
 from .cpp_ext import is_cuda_version_at_least
 from .cubin_loader import get_cubin, get_meta_hash
 from .gemm.cutlass.generate_kernels import generate_gemm_operations
@@ -71,6 +77,16 @@ def gen_cutlass_fused_moe_sm90_module(use_fast_build: bool = False) -> JitSpec:
     return gen_cutlass_fused_moe_module(nvcc_flags, "90", use_fast_build)
 
 
+def gen_cutlass_fused_moe_sm89_module(use_fast_build: bool = False) -> JitSpec:
+    nvcc_flags = sm89_nvcc_flags + [
+        "-DENABLE_BF16",
+        "-DENABLE_FP8",
+        "-DENABLE_FP8_BLOCK_SCALE" if is_cuda_version_at_least("12.8") else "",
+        "-DUSING_OSS_CUTLASS_MOE_GEMM",
+    ]
+    return gen_cutlass_fused_moe_module(nvcc_flags, "89", use_fast_build)
+
+
 def gen_cutlass_fused_moe_module(
     nvcc_flags: List[str], device_arch: str, use_fast_build: bool = False
 ) -> JitSpec:
```
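Alternatively, the new generation path can be exercised directly; a minimal sketch assuming an SM89 device and the names introduced in this diff:

```python
from flashinfer.jit.fused_moe import gen_cutlass_fused_moe_sm89_module

# Builds the JitSpec with the SM89 nvcc flags (BF16/FP8 enabled, FP4 omitted)
# and loads the resulting extension, mirroring what backend="89" does above.
spec = gen_cutlass_fused_moe_sm89_module(use_fast_build=False)
module = spec.build_and_load()
```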
