Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,14 @@ std::vector<CutlassTileConfig> get_candidate_tiles(
CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64};
case CutlassGemmType::Fp8:
if (config_type_param & CutlassGemmConfig::GROUPED_GEMM) {
if (sm == 89 || sm >= 120) {
return {CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128,
CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
if (sm == 89 || sm == 120) {
return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64,
CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64,
CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64};
CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64,
CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128};
} else {
// no valid ampere style fp8 configs for sm90
return {};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -688,28 +688,28 @@ void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::dispatchToArch(
TLLM_THROW("FP4 data type is not supported on SM < 90");
#endif
} else if (sm_ >= 80 && sm_ < 90) {
#ifdef ENABLE_FP4
if constexpr (!std::is_same_v<WeightType, __nv_fp4_e2m1>) {
if constexpr (use_fp8 || use_w4afp8) {
if constexpr (use_fp8 || use_w4afp8) {
#if defined(ENABLE_FP8)
static_assert(!std::is_same_v<OutputType, __nv_fp8_e4m3> &&
!std::is_same_v<OutputType, __nv_fp8_e5m2>,
"FP8 GEMM Output not supported");
static_assert(
!std::is_same_v<OutputType, __nv_fp8_e4m3> && !std::is_same_v<OutputType, __nv_fp8_e5m2>,
"FP8 GEMM Output not supported");
#endif
TLLM_CHECK_WITH_INFO(sm_ == 89,
"For sm >= 80 and < 90, fp8 is only supported with sm == 89");
dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm89, EpilogueTag>(
inputs, multi_processor_count_);
TLLM_CHECK_WITH_INFO(sm_ == 89, "For sm >= 80 and < 90, fp8 is only supported with sm == 89");
dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm89, EpilogueTag>(
inputs, multi_processor_count_);
} else {
#ifdef ENABLE_FP4
if constexpr (std::is_same_v<WeightType, __nv_fp4_e2m1>) {
TLLM_THROW("FP4 data type is not supported on SM < 90");
} else {
dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm80, EpilogueTag>(
inputs, multi_processor_count_);
}
} else {
TLLM_THROW("FP4 data type is not supported on SM < 90");
}
#else
TLLM_THROW("FP4 data type is not supported on SM < 90");
dispatchMoeGemmToCutlass<T, WeightType, ScaleBiasType, cutlass::arch::Sm80, EpilogueTag>(
inputs, multi_processor_count_);
#endif
}
} else if (sm_ >= 90) {
// For SM120+ FP8 MoE, redirect to SM89 (Ada) FP8 kernel implementations.
if constexpr (use_fp8) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ template <typename T, typename WeightType,
TmaWarpSpecializedGroupedGemmInput::EpilogueFusion Fusion =
TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE>
constexpr bool isValidSM120MOESpecialisation() {
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) // TODO Is there a better choice
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) && \
defined(ENABLE_FP4) // TODO Is there a better choice
return cutlass::platform::is_same<T, __nv_fp4_e2m1>::value &&
cutlass::platform::is_same<T, WeightType>::value &&
cutlass::platform::is_same<EpilogueTag, cutlass_extensions::EpilogueOpDefault>::value &&
Expand All @@ -49,8 +50,13 @@ template <typename T, typename WeightType,
constexpr bool isValidBlackwellMOESpecialisation() {
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) // TODO Is there a better choice
return (cutlass::platform::is_same<T, WeightType>::value ||
#if defined(ENABLE_FP4)
(cutlass::platform::is_same<T, __nv_fp8_e4m3>::value &&
cutlass::platform::is_same<WeightType, __nv_fp4_e2m1>::value)) &&
cutlass::platform::is_same<WeightType, __nv_fp4_e2m1>::value)
#else
false
#endif
) &&
cutlass::platform::is_same<EpilogueTag, cutlass_extensions::EpilogueOpDefault>::value &&
Fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE;
#else
Expand Down
6 changes: 5 additions & 1 deletion flashinfer/autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,8 +482,12 @@ def choose_one(
)
except Exception as e:
shapes = self._get_input_sizes(tensors)
logger.warning(
f"[Autotuner]: Skipping tactic {r} {tac}, due to failure while profiling."
)

logger.error(
# Log stacktrace as debug to not spam log
logger.debug(
f"[Autotuner]: Failed when profiling {r} {tac}, shapes={shapes}. Error occurred: {e}"
)

Expand Down
3 changes: 3 additions & 0 deletions flashinfer/fused_moe/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
gen_cutlass_fused_moe_sm120_module,
gen_cutlass_fused_moe_sm100_module,
gen_cutlass_fused_moe_sm90_module,
gen_cutlass_fused_moe_sm89_module,
gen_trtllm_gen_fused_moe_sm100_module,
)
from ..utils import (
Expand Down Expand Up @@ -285,6 +286,8 @@ def get_cutlass_fused_moe_module(backend: str = "100", use_fast_build: bool = Fa
module = gen_cutlass_fused_moe_sm100_module(use_fast_build).build_and_load()
elif backend == "90":
module = gen_cutlass_fused_moe_sm90_module(use_fast_build).build_and_load()
elif backend == "89":
module = gen_cutlass_fused_moe_sm89_module(use_fast_build).build_and_load()
else:
raise ValueError(f"Invalid backend: {backend}")

Expand Down
4 changes: 4 additions & 0 deletions flashinfer/jit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ def clear_cache_dir():
"-DFLASHINFER_ENABLE_FP8_E8M0",
"-DFLASHINFER_ENABLE_FP4_E2M1",
]
# nvcc flags for building kernels for compute capability 8.9 (sm_89).
# Unlike the sm90a/sm100a/sm103a flag lists defined after it, this list is
# standalone: it does not concatenate `common_nvcc_flags`, and it defines only
# FLASHINFER_ENABLE_FP8_E8M0 (no FLASHINFER_ENABLE_FP4_E2M1).
# NOTE(review): presumably FP4 is left out because sm_89 lacks FP4 support — confirm.
sm89_nvcc_flags = [
"-gencode=arch=compute_89,code=sm_89",
"-DFLASHINFER_ENABLE_FP8_E8M0",
]
sm90a_nvcc_flags = ["-gencode=arch=compute_90a,code=sm_90a"] + common_nvcc_flags
sm100a_nvcc_flags = ["-gencode=arch=compute_100a,code=sm_100a"] + common_nvcc_flags
sm103a_nvcc_flags = ["-gencode=arch=compute_103a,code=sm_103a"] + common_nvcc_flags
Expand Down
18 changes: 17 additions & 1 deletion flashinfer/jit/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@

from . import env as jit_env
from ..artifacts import ArtifactPath, CheckSumHash
from .core import JitSpec, gen_jit_spec, current_compilation_context, sm90a_nvcc_flags
from .core import (
JitSpec,
gen_jit_spec,
current_compilation_context,
sm90a_nvcc_flags,
sm89_nvcc_flags,
)
from .cpp_ext import is_cuda_version_at_least
from .cubin_loader import get_cubin, get_meta_hash
from .gemm.cutlass.generate_kernels import generate_gemm_operations
Expand Down Expand Up @@ -71,6 +77,16 @@ def gen_cutlass_fused_moe_sm90_module(use_fast_build: bool = False) -> JitSpec:
return gen_cutlass_fused_moe_module(nvcc_flags, "90", use_fast_build)


def gen_cutlass_fused_moe_sm89_module(use_fast_build: bool = False) -> JitSpec:
    """Build the JIT spec for the CUTLASS fused-MoE module targeting SM89.

    Starts from ``sm89_nvcc_flags`` and adds the BF16/FP8 feature defines plus
    ``USING_OSS_CUTLASS_MOE_GEMM``. ``ENABLE_FP8_BLOCK_SCALE`` is defined only
    when ``is_cuda_version_at_least("12.8")`` is true.

    Args:
        use_fast_build: Forwarded unchanged to ``gen_cutlass_fused_moe_module``.

    Returns:
        The ``JitSpec`` produced by ``gen_cutlass_fused_moe_module`` for
        device architecture ``"89"``.
    """
    # List concatenation creates a new list, so the module-level
    # ``sm89_nvcc_flags`` is never mutated by the appends below.
    nvcc_flags = sm89_nvcc_flags + [
        "-DENABLE_BF16",
        "-DENABLE_FP8",
    ]
    # Append the optional define only when it applies, so that no empty-string
    # entry ends up in the flag list on toolkits older than CUDA 12.8.
    if is_cuda_version_at_least("12.8"):
        nvcc_flags.append("-DENABLE_FP8_BLOCK_SCALE")
    nvcc_flags.append("-DUSING_OSS_CUTLASS_MOE_GEMM")
    return gen_cutlass_fused_moe_module(nvcc_flags, "89", use_fast_build)


def gen_cutlass_fused_moe_module(
nvcc_flags: List[str], device_arch: str, use_fast_build: bool = False
) -> JitSpec:
Expand Down