Merged
33 commits
f109a2b  Fix aot failures (yongwww, Aug 27, 2025)
3cfba8e  >launcher.inl (aleozlx, Oct 8, 2025)
9047135  >generate_kernels.py (aleozlx, Oct 8, 2025)
d9d7723  >generate_kernels.py (aleozlx, Oct 8, 2025)
3166245  >launcher.inl (aleozlx, Oct 8, 2025)
4dce85b  >moe_gemm_kernels.h (aleozlx, Oct 8, 2025)
d47c865  cutlass_fused_moe_kernels.cuh is troublesome... (aleozlx, Oct 11, 2025)
307fe30  fix compilation errors in cutlass_fused_moe_kernels.cuh (aleozlx, Oct 13, 2025)
76a9220  >gather_tensor.hpp (aleozlx, Oct 13, 2025)
96c0ed4  fix compilation errors (aleozlx, Oct 13, 2025)
af4036d  fix compilation error for sm120 (yongwww, Aug 27, 2025)
a49d1fd  Add #if defined(ENABLE_FP4) guards (yongwww, Oct 17, 2025)
13d8664  fix: use FLASHINFER_ENABLE_FP8_E8M0 guard for __nv_fp8_e8m0 (yongwww, Oct 17, 2025)
4f94bf0  fix build (yongwww, Oct 17, 2025)
eddb10b  fix aot errors (yongwww, Oct 17, 2025)
ddb1345  Merge remote-tracking branch 'origin/main' into feature/cutlass_moe_u… (yongwww, Oct 18, 2025)
2563556  fix stale sm100 configs (aleozlx, Oct 21, 2025)
bfe2852  Merge branch 'main' of https://github.com/flashinfer-ai/flashinfer in… (aleozlx, Oct 22, 2025)
da54367  debug.. (aleozlx, Oct 22, 2025)
a81fbd1  merge (nv-yunzheq, Oct 30, 2025)
0bbab20  remove debug stdout (nv-yunzheq, Oct 31, 2025)
0ad03f9  update incorrect comment (nv-yunzheq, Nov 3, 2025)
39691b5  update default layout to use swizzled (nv-yunzheq, Nov 3, 2025)
4a8c8cf  update format (nv-yunzheq, Nov 3, 2025)
f42af7c  update autotunner (nv-yunzheq, Nov 4, 2025)
42de94e  update precomiit (nv-yunzheq, Nov 4, 2025)
b09efb0  update sm121 failure (nv-yunzheq, Nov 4, 2025)
22b97b0  fix H100 unit test error (nv-yunzheq, Nov 5, 2025)
a30f033  address comments (nv-yunzheq, Nov 5, 2025)
31b0df0  fix compilation error (nv-yunzheq, Nov 6, 2025)
91a85ad  update compilation error (nv-yunzheq, Nov 6, 2025)
e15a96c  update to address comment (nv-yunzheq, Nov 6, 2025)
33aec35  fix compilation error (nv-yunzheq, Nov 6, 2025)
@@ -18,7 +18,6 @@
#include "moe_kernels.h"

namespace tensorrt_llm::kernels::cutlass_kernels {
// ==================== Variable batched GEMM specializations ==================================
template class CutlassMoeFCRunner<float, float>;

#ifdef ENABLE_BF16
@@ -38,6 +37,7 @@ template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, half, half>;
template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>;
template class CutlassMoeFCRunner<__nv_bfloat16, __nv_fp8_e4m3, __nv_bfloat16>;
template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16>;
template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_fp8_e4m3>;
#endif
#endif
#ifdef ENABLE_FP4
@@ -54,4 +54,12 @@ template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp4_e2m1, __nv_bfloat16, _
template class CutlassMoeFCRunner<__nv_bfloat16, __nv_fp4_e2m1>;
#endif
#endif
}; // namespace tensorrt_llm::kernels::cutlass_kernels

// Explicit instantiations for finalizeMoeRoutingKernelLauncher to ensure
// symbols are emitted in the JIT library for common data types.
INSTANTIATE_FINALIZE_MOE_ROUTING(half, half, half);
INSTANTIATE_FINALIZE_MOE_ROUTING(float, float, float);
#ifdef ENABLE_BF16
INSTANTIATE_FINALIZE_MOE_ROUTING(__nv_bfloat16, __nv_bfloat16, __nv_bfloat16);
#endif
} // namespace tensorrt_llm::kernels::cutlass_kernels
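
The INSTANTIATE_FINALIZE_MOE_ROUTING lines above work through explicit template instantiation: they force the compiler to emit concrete symbols for finalizeMoeRoutingKernelLauncher into this translation unit so the JIT-built library can link against them. A minimal standalone sketch of the pattern, with a hypothetical stand-in launcher (not the real signature from moe_kernels.h):

```cpp
#include <cstdio>

// Hypothetical stand-in for a templated launcher declared in a header (the real
// finalizeMoeRoutingKernelLauncher takes kernel arguments and a CUDA stream).
template <typename OutputT, typename GemmOutputT, typename ScaleT>
void finalize_routing_stub(OutputT* out, GemmOutputT const* gemm_out, ScaleT scale, int n) {
  for (int i = 0; i < n; ++i) {
    out[i] = static_cast<OutputT>(static_cast<float>(gemm_out[i]) * static_cast<float>(scale));
  }
}

// Explicit instantiation: forces the compiler to emit the <float, float, float>
// symbol into this translation unit even if nothing here calls it, which is what
// the INSTANTIATE_FINALIZE_MOE_ROUTING macros accomplish for the real launcher.
template void finalize_routing_stub<float, float, float>(float*, float const*, float, int);

int main() {
  float gemm_out[2] = {1.0f, 2.0f};
  float out[2] = {0.0f, 0.0f};
  finalize_routing_stub(out, gemm_out, 0.5f, 2);
  std::printf("%f %f\n", out[0], out[1]);
  return 0;
}
```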
871 changes: 431 additions & 440 deletions csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh

Large diffs are not rendered by default.

@@ -72,6 +72,8 @@ class DtypeUtils {
default:
TVM_FFI_ICHECK(false) << "unsupported data type";
}

return nvinfer1::DataType::kFLOAT; // suppress compiler warning
}

private:
@@ -111,6 +113,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
TVM_FFI_ICHECK(false) << "Invalid output type " << DLDataTypeToString(output_type)
<< " specified for " << DLDataTypeToString(mActivationDtype);
}

return nullptr; // suppress compiler warning
};

FusedMoeRunner(DLDataType activation_dtype, DLDataType weight_dtype, DLDataType output_dtype,
@@ -219,7 +223,13 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
}

mProfiler = std::make_shared<kernels::GemmProfilerBackend>();
mAllProfiles = mKernelRunner->getTactics();
// Get the tactics for GEMM1 and GEMM2 and concatenate them
auto gemm1_tactics = mKernelRunner->getTactics(kernels::MoeGemmId::GEMM_1);
auto gemm2_tactics = mKernelRunner->getTactics(kernels::MoeGemmId::GEMM_2);
mGemm1TacticCount = static_cast<int64_t>(gemm1_tactics.size());
mGemm2TacticCount = static_cast<int64_t>(gemm2_tactics.size());
mAllProfiles = gemm1_tactics;
mAllProfiles.insert(mAllProfiles.end(), gemm2_tactics.begin(), gemm2_tactics.end());
TVM_FFI_ICHECK(!mAllProfiles.empty())
<< "No valid tactics available for fused moe op with the requested input combination "
"Activation: "
@@ -367,27 +377,31 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {

// TODO: support lora in the future
::tensorrt_llm::kernels::LoraParams lora_params{};
// HACK Define default values for parameters we don't have good values for
bool const swizzled_input_sf = true; // Assume input_sf is swizzled by default
int64_t const unpadded_hidden_size = hidden_size; // Assume no padding by default
bool const use_lora = false; // No lora support yet
#ifdef USING_OSS_CUTLASS_MOE_GEMM
mKernelRunner->runMoe(
input.data_ptr(), input_sf.has_value() ? input_sf.value().data_ptr() : nullptr,
reinterpret_cast<int const*>(token_selected_experts.data_ptr()),
swizzled_input_sf, reinterpret_cast<int const*>(token_selected_experts.data_ptr()),
token_final_scales.has_value()
? reinterpret_cast<float const*>(token_final_scales.value().data_ptr())
: nullptr,
fc1_expert_weights.data_ptr(),
fc1_expert_biases.has_value() ? fc1_expert_biases.value().data_ptr() : nullptr,
activation_params, fc2_expert_weights.data_ptr(),
fc2_expert_biases.has_value() ? fc2_expert_biases.value().data_ptr() : nullptr,
quant_params, num_rows, hidden_size, inter_size, num_experts_total,
quant_params, num_rows, hidden_size, unpadded_hidden_size, inter_size, num_experts_total,
static_cast<int>(experts_per_token),
static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(),
static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall,
false, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params,
use_lora, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params,
enable_pdl, stream);
#else
mKernelRunner->runMoe(
input.data_ptr(), input_sf.has_value() ? input_sf.value().data_ptr() : nullptr,
reinterpret_cast<int const*>(token_selected_experts.data_ptr()),
swizzled_input_sf, reinterpret_cast<int const*>(token_selected_experts.data_ptr()),
token_final_scales.has_value()
? reinterpret_cast<float const*>(token_final_scales.value().data_ptr())
: nullptr,
@@ -396,10 +410,10 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
activation_params, fc2_expert_weights.data_ptr(),
fc2_expert_biases.has_value() ? fc2_expert_biases.value().data_ptr() : nullptr,
quant_params, num_rows, hidden_size, inter_size, num_experts_total,
static_cast<int>(experts_per_token), static_cast<char*>(workspace_info.workspace),
output.data_ptr(), static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config,
false, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params,
enable_pdl, stream);
static_cast<int>(experts_per_token),
static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(),
static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, false, lora_params,
mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, enable_pdl, stream);
#endif
}
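
The extra values threaded into runMoe here (swizzled_input_sf, unpadded_hidden_size, use_lora) are named const locals rather than bare literals, which keeps a very long argument list self-documenting. A trivial sketch of the pattern with a stand-in function; the real runMoe signature is not reproduced here:

```cpp
#include <cstdint>
#include <cstdio>

// Stand-in for a long-argument-list kernel entry point (not the real runMoe).
void run_moe_stub(bool swizzled_input_sf, std::int64_t hidden_size,
                  std::int64_t unpadded_hidden_size, bool use_lora) {
  std::printf("swizzled=%d hidden=%lld unpadded=%lld lora=%d\n",
              static_cast<int>(swizzled_input_sf), static_cast<long long>(hidden_size),
              static_cast<long long>(unpadded_hidden_size), static_cast<int>(use_lora));
}

int main() {
  std::int64_t const hidden_size = 4096;
  // Conservative defaults matching the diff: scale factors assumed swizzled,
  // hidden size assumed unpadded, LoRA not supported yet.
  bool const swizzled_input_sf = true;
  std::int64_t const unpadded_hidden_size = hidden_size;
  bool const use_lora = false;
  run_moe_stub(swizzled_input_sf, hidden_size, unpadded_hidden_size, use_lora);
  return 0;
}
```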

@@ -547,39 +561,44 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {

// TODO: support lora in the future
::tensorrt_llm::kernels::LoraParams lora_params{};
// HACK Define default values for parameters we don't have good values for
bool const swizzled_input_sf_ml = true; // Assume input_sf is swizzled by default
int64_t const unpadded_hidden_size_ml = hidden_size; // Assume no padding by default
bool const use_lora_ml = false; // No lora support yet
#ifdef USING_OSS_CUTLASS_MOE_GEMM
mKernelRunner->runMoe(
input.data_ptr(), input_sf.has_value() ? input_sf.value().data_ptr() : nullptr,
reinterpret_cast<int const*>(token_selected_experts.data_ptr()),
swizzled_input_sf_ml, reinterpret_cast<int const*>(token_selected_experts.data_ptr()),
token_final_scales.has_value()
? reinterpret_cast<float const*>(token_final_scales.value().data_ptr())
: nullptr,
fc1_expert_weights.data_ptr(),
fc1_expert_biases.has_value() ? fc1_expert_biases.value().data_ptr() : nullptr,
activation_params, fc2_expert_weights.data_ptr(),
fc2_expert_biases.has_value() ? fc2_expert_biases.value().data_ptr() : nullptr,
quant_params, num_rows, hidden_size, inter_size, num_experts_total,
quant_params, num_rows, hidden_size, unpadded_hidden_size_ml, inter_size, num_experts_total,
static_cast<int>(experts_per_token),
static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(),
static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall,
false, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params,
use_lora_ml, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params,
enable_pdl, stream);
#else
mKernelRunner->runMoe(
input.data_ptr(), input_sf.has_value() ? input_sf.value().data_ptr() : nullptr,
reinterpret_cast<int const*>(token_selected_experts.data_ptr()),
swizzled_input_sf_ml, reinterpret_cast<int const*>(token_selected_experts.data_ptr()),
token_final_scales.has_value()
? reinterpret_cast<float const*>(token_final_scales.value().data_ptr())
: nullptr,
fc1_expert_weights.data_ptr(),
fc1_expert_biases.has_value() ? fc1_expert_biases.value().data_ptr() : nullptr,
activation_params, fc2_expert_weights.data_ptr(),
fc2_expert_biases.has_value() ? fc2_expert_biases.value().data_ptr() : nullptr,
quant_params, num_rows, hidden_size, inter_size, num_experts_total,
static_cast<int>(experts_per_token), static_cast<char*>(workspace_info.workspace),
output.data_ptr(), static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config,
false, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params,
enable_pdl, stream);
quant_params, num_rows, hidden_size, unpadded_hidden_size_ml, inter_size, num_experts_total,
static_cast<int>(experts_per_token),
static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(),
static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, false, use_lora_ml,
lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, enable_pdl,
stream);
#endif
}

@@ -641,19 +660,20 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
auto activation_dtype =
(mUseW4GroupScaling && !isWFP4A16Quant()) ? dl_float8_e4m3fn : mActivationDtype;
activation_dtype = isNvfp4Quant() ? dl_int64 : activation_dtype;
int64_t const unpadded_hidden_size_profiler = hidden_size; // HACK no padding by default
#ifdef USING_OSS_CUTLASS_MOE_GEMM
mProfiler->init(*mKernelRunner.get(), mProfiler->mGemmToProfile,
DtypeUtils::dataType(activation_dtype), DtypeUtils::dataType(mWeightDtype),
DtypeUtils::dataType(mOutputDtype), num_experts, static_cast<int>(top_k),
hidden_size, inter_size, group_size, activation_type, USE_BIAS, USE_LORA,
min_latency_mode,
hidden_size, unpadded_hidden_size_profiler, inter_size, group_size,
activation_type, USE_BIAS, USE_LORA, min_latency_mode,
/*need_weights*/ false, parallelism_config, enable_alltoall);
#else
mProfiler->init(*mKernelRunner.get(), mProfiler->mGemmToProfile,
DtypeUtils::dataType(activation_dtype), DtypeUtils::dataType(mWeightDtype),
DtypeUtils::dataType(mOutputDtype), num_experts, static_cast<int>(top_k),
hidden_size, inter_size, group_size, activation_type, USE_BIAS, USE_LORA,
min_latency_mode,
hidden_size, unpadded_hidden_size_profiler, inter_size, group_size,
activation_type, USE_BIAS, USE_LORA, min_latency_mode,
/*need_weights*/ false, parallelism_config);
#endif

@@ -691,6 +711,10 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
});
} else if (name == "get_tactic_num") {
return Function::FromTyped([this]() -> int64_t { return getTacticNum(); });
} else if (name == "get_gemm1_tactic_count") {
return Function::FromTyped([this]() -> int64_t { return mGemm1TacticCount; });
} else if (name == "get_gemm2_tactic_count") {
return Function::FromTyped([this]() -> int64_t { return mGemm2TacticCount; });
} else if (name == "run_moe") {
return Function::FromTyped(
[this](TensorView output, TensorView input, TensorView token_selected_experts,
@@ -758,6 +782,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {

using Profile = tensorrt_llm::cutlass_extensions::CutlassGemmConfig;
std::vector<Profile> mAllProfiles;
int64_t mGemm1TacticCount{0};
int64_t mGemm2TacticCount{0};

void setRunnerProfiles(Optional<Array<int64_t>> profile_ids) {
if (mUseDeepSeekFP8BlockScaling) {
@@ -771,13 +797,34 @@ }
}

auto best_gemm1_profile = mAllProfiles.front();
auto best_gemm2_profile = mAllProfiles.front();
// Default GEMM2 profile should come from the GEMM2 subrange if present
auto best_gemm2_profile =
(mGemm2TacticCount > 0 && mAllProfiles.size() > static_cast<size_t>(mGemm1TacticCount))
? mAllProfiles.at(mGemm1TacticCount)
: mAllProfiles.front();
if (profile_ids.has_value()) {
TVM_FFI_ICHECK_EQ(profile_ids.value().size(), 2) << "Expecting 2 profile ids";
best_gemm1_profile = profile_ids.value()[0] == -1 ? best_gemm1_profile
: mAllProfiles.at(profile_ids.value()[0]);
best_gemm2_profile = profile_ids.value()[1] == -1 ? best_gemm2_profile
: mAllProfiles.at(profile_ids.value()[1]);
// GEMM1 index: -1 keeps the default; otherwise it must fall within the GEMM1 subrange
auto id1 = profile_ids.value()[0];
if (id1 != -1) {
TVM_FFI_ICHECK(id1 >= 0 && id1 < mGemm1TacticCount) << "Invalid gemm1 profile id: " << id1;
best_gemm1_profile = mAllProfiles.at(id1);
}

// GEMM2 index: support both absolute (combined) and relative (within GEMM2 subrange) ids
auto id2 = profile_ids.value()[1];
if (id2 != -1) {
int64_t absolute_id2 = id2;
// If id2 appears relative to GEMM2 subrange, offset it
if (id2 >= 0 && id2 < mGemm2TacticCount) {
absolute_id2 = mGemm1TacticCount + id2;
}
TVM_FFI_ICHECK(absolute_id2 >= 0 &&
absolute_id2 < static_cast<int64_t>(mAllProfiles.size()))
<< "Invalid gemm2 profile id: " << id2;
best_gemm2_profile = mAllProfiles.at(absolute_id2);
}
}
mKernelRunner->setTactic(best_gemm1_profile, best_gemm2_profile);
}
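
The profile-id handling above accepts -1 (keep the default), an absolute index into the combined tactic list, or a GEMM2-relative index that gets offset by the GEMM1 count. A standalone sketch of that resolution logic, simplified and using assert in place of TVM_FFI_ICHECK:

```cpp
#include <cassert>
#include <cstdint>

// Resolve a user-supplied gemm2 profile id into an index of the combined
// [gemm1 tactics | gemm2 tactics] list. Ids inside [0, gemm2_count) are treated
// as relative to the GEMM2 subrange; anything else is taken as an absolute
// index into the combined list.
std::int64_t resolveGemm2Id(std::int64_t id2, std::int64_t gemm1_count,
                            std::int64_t gemm2_count) {
  std::int64_t absolute = id2;
  if (id2 >= 0 && id2 < gemm2_count) {
    absolute = gemm1_count + id2;  // relative -> absolute
  }
  assert(absolute >= 0 && absolute < gemm1_count + gemm2_count);
  return absolute;
}

int main() {
  // With 4 GEMM1 tactics and 3 GEMM2 tactics:
  assert(resolveGemm2Id(0, 4, 3) == 4);  // relative id 0 -> first GEMM2 tactic
  assert(resolveGemm2Id(2, 4, 3) == 6);  // relative id 2 -> last GEMM2 tactic
  assert(resolveGemm2Id(5, 4, 3) == 5);  // absolute id passes through unchanged
  return 0;
}
```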
3 changes: 3 additions & 0 deletions csrc/nv_internal/include/tensorrt_llm/common/cudaUtils.h
@@ -1181,6 +1181,9 @@ using Int = ConstExprWrapper<int, VALUE>;
template <bool VALUE>
using Bool = ConstExprWrapper<bool, VALUE>;

template <bool VALUE>
using ConstBool = ConstExprWrapper<bool, VALUE>;

template <typename T>
struct TmaDescType;
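
The ConstBool alias added above follows the same pattern as the existing Int and Bool aliases: a ConstExprWrapper carries a compile-time constant in its type so it can drive overload selection or if constexpr branches. A hypothetical, self-contained sketch of how such a wrapper is typically defined and used (the real ConstExprWrapper in cudaUtils.h may differ):

```cpp
#include <iostream>

// Hypothetical equivalent of ConstExprWrapper: the wrapped value lives in the
// type, so passing an instance communicates a compile-time constant.
template <typename T, T VALUE>
struct ConstExprWrapperSketch {
  static constexpr T value = VALUE;
};

template <bool VALUE>
using ConstBoolSketch = ConstExprWrapperSketch<bool, VALUE>;

// Example: dispatch at compile time on the wrapped boolean.
template <bool kUseBias>
void run(ConstBoolSketch<kUseBias>) {
  if constexpr (kUseBias) {
    std::cout << "bias path\n";
  } else {
    std::cout << "no-bias path\n";
  }
}

int main() {
  run(ConstBoolSketch<true>{});
  run(ConstBoolSketch<false>{});
  return 0;
}
```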
