@@ -18,7 +18,6 @@
#include "moe_kernels.h"

namespace tensorrt_llm::kernels::cutlass_kernels {
// ==================== Variable batched GEMM specializations ==================================
template class CutlassMoeFCRunner<float, float>;

#ifdef ENABLE_BF16
@@ -38,6 +37,7 @@ template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, half, half>;
template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>;
template class CutlassMoeFCRunner<__nv_bfloat16, __nv_fp8_e4m3, __nv_bfloat16>;
template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16>;
template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_fp8_e4m3>;
#endif
#endif
#ifdef ENABLE_FP4
@@ -54,4 +54,4 @@ template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp4_e2m1, __nv_bfloat16, _
template class CutlassMoeFCRunner<__nv_bfloat16, __nv_fp4_e2m1>;
#endif
#endif
}; // namespace tensorrt_llm::kernels::cutlass_kernels
} // namespace tensorrt_llm::kernels::cutlass_kernels
853 changes: 416 additions & 437 deletions csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh

Large diffs are not rendered by default.

csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu
@@ -72,6 +72,8 @@ class DtypeUtils {
default:
TVM_FFI_ICHECK(false) << "unsupported data type";
}

return nvinfer1::DataType::kFLOAT; // suppress compiler warning
}

private:
@@ -111,6 +113,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
TVM_FFI_ICHECK(false) << "Invalid output type " << DLDataTypeToString(output_type)
<< " specified for " << DLDataTypeToString(mActivationDtype);
}

return nullptr; // suppress compiler warning
};

FusedMoeRunner(DLDataType activation_dtype, DLDataType weight_dtype, DLDataType output_dtype,
@@ -219,7 +223,11 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
}

mProfiler = std::make_shared<kernels::GemmProfilerBackend>();
mAllProfiles = mKernelRunner->getTactics();
// Get tactics for both GEMM1 and GEMM2, combine them
auto gemm1_tactics = mKernelRunner->getTactics(kernels::MoeGemmId::GEMM_1);
auto gemm2_tactics = mKernelRunner->getTactics(kernels::MoeGemmId::GEMM_2);
mAllProfiles = gemm1_tactics;
mAllProfiles.insert(mAllProfiles.end(), gemm2_tactics.begin(), gemm2_tactics.end());
Comment on lines +226 to +230

⚠️ Potential issue | πŸ”΄ Critical

🧩 Analysis chain

Verify tactic selection logic for GEMM1 vs GEMM2.

The tactics from GEMM1 and GEMM2 are now concatenated into a single mAllProfiles vector. Ensure that when selecting tactics by profile_id (e.g., at Line 618), the correct tactics are applied to the corresponding GEMM. The setRunnerProfiles method expects 2 profile IDs, so verify the indexing logic correctly distinguishes between GEMM1 and GEMM2 tactics.


🏁 Script executed:

#!/bin/bash
# Verify how profile_id is used to select tactics and ensure GEMM1/GEMM2 separation

# Check setRunnerProfiles implementation and how it maps profile IDs
rg -n -A 10 -B 3 'setRunnerProfiles' --type cpp

# Look for getTactics usage and how tactics are distinguished
rg -n -A 5 'getTactics.*GEMM' --type cpp

Length of output: 0


🏁 Script executed:

#!/bin/bash
# First, verify the file exists and examine the context around the modified lines
echo "=== Examining the target file around lines 226-230 ==="
head -n 240 csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu | tail -n 30

echo ""
echo "=== Search for setRunnerProfiles (broader search) ==="
rg -n 'setRunnerProfiles' csrc/fused_moe/cutlass_backend/

echo ""
echo "=== Search for how mAllProfiles is used ==="
rg -n -C 5 'mAllProfiles' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu

echo ""
echo "=== Search for profile_id usage around line 618 ==="
sed -n '610,630p' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu

echo ""
echo "=== Search for getTactics calls ==="
rg -n 'getTactics' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu

Length of output: 806


🏁 Script executed:

#!/bin/bash
# Examine the target file and trace tactic selection logic

echo "=== Examining lines around 226-230 (mAllProfiles concatenation) ==="
sed -n '220,240p' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu

echo ""
echo "=== Search for mAllProfiles usage throughout the file ==="
rg -n 'mAllProfiles' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu

echo ""
echo "=== Examining setRunnerProfiles method (lines 757-778 mentioned in scratchpad) ==="
sed -n '750,790p' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu

echo ""
echo "=== Search for getTactics calls and GEMM references ==="
rg -n 'getTactics|GEMM_1|GEMM_2' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu

echo ""
echo "=== Examine profile_id usage pattern ==="
sed -n '610,630p' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu

Length of output: 5633


🏁 Script executed:

#!/bin/bash
# Examine the context around line 618 and how profile is used
echo "=== Full context around line 618 (profile selection) ==="
sed -n '600,650p' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu

echo ""
echo "=== Search for where setRunnerProfiles is called ==="
rg -n 'setRunnerProfiles' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu -A 2 -B 2

echo ""
echo "=== Search for mKernelRunner->setTactic calls ==="
rg -n 'setTactic' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu -A 2 -B 2

echo ""
echo "=== Check runMoe signature and how it's called ==="
rg -n 'void runMoe' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu -A 8

Length of output: 5456


🏁 Script executed:

#!/bin/bash
# Search for where profile_ids is generated/populated and if there's any offset logic
echo "=== Search for profile_ids origin and computation ==="
rg -n 'profile_ids' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu | head -30

echo ""
echo "=== Search for getNumProfiles or similar methods ==="
rg -n 'getNumProfiles|getNumTactics|size()' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu

echo ""
echo "=== Check class member getters and how tactics are exposed ==="
sed -n '580,595p' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu

echo ""
echo "=== Look for any comments or logic about profile indexing/offset ==="
rg -n -B 3 -A 3 'mAllProfiles\[' csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu

Length of output: 12101


Apply offset correction when indexing GEMM2 tactics from concatenated mAllProfiles.

The tactics from GEMM1 and GEMM2 are concatenated into a single mAllProfiles vector (lines 229–230), but when selecting tactics by profile_ids in setRunnerProfiles (lines 772–775), the GEMM2 profile index is not offset. This causes GEMM2 to incorrectly select from GEMM1's tactics range.

After concatenation, GEMM2 tactics start at index gemm1_tactics.size(), but this offset is not applied when accessing mAllProfiles.at(profile_ids.value()[1]).

Fix: Apply offset for GEMM2 tactic selection:

best_gemm2_profile = profile_ids.value()[1] == -1 ? best_gemm2_profile
                                                  : mAllProfiles.at(profile_ids.value()[1] + gemm1_size);

where gemm1_size = mKernelRunner->getTactics(kernels::MoeGemmId::GEMM_1).size() (stored during initialization).
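
To make the indexing concrete, here is a small standalone sketch of the scheme only, not the real FusedMoeRunner API (Profile, ProfileSelectorSketch, and main() are illustrative stand-ins): after concatenation, a GEMM2 profile id has to be shifted by the number of GEMM1 tactics before indexing the combined vector, while the -1 sentinel keeps the default.

#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Stand-in for a tactic entry; the real type comes from the cutlass kernel runner.
struct Profile {
  std::string name;
};

// Standalone model of the concatenated-profiles indexing; not the real binding class.
struct ProfileSelectorSketch {
  std::vector<Profile> all_profiles;  // GEMM1 tactics followed by GEMM2 tactics
  std::size_t gemm1_count = 0;        // split point captured when the vectors are combined

  void build(std::vector<Profile> gemm1, std::vector<Profile> const& gemm2) {
    gemm1_count = gemm1.size();
    all_profiles = std::move(gemm1);
    all_profiles.insert(all_profiles.end(), gemm2.begin(), gemm2.end());
  }

  // profile_ids[0] picks a GEMM1 tactic, profile_ids[1] a GEMM2 tactic; -1 keeps the default.
  std::pair<Profile, Profile> select(std::optional<std::vector<int64_t>> const& profile_ids) const {
    Profile gemm1 = all_profiles.front();
    Profile gemm2 = all_profiles.at(gemm1_count);  // first GEMM2 entry as the default
    if (profile_ids.has_value()) {
      if ((*profile_ids)[0] != -1) {
        gemm1 = all_profiles.at((*profile_ids)[0]);
      }
      if ((*profile_ids)[1] != -1) {
        // The stored GEMM2 id is relative to the GEMM2 tactic list, so shift it past GEMM1.
        gemm2 = all_profiles.at((*profile_ids)[1] + gemm1_count);
      }
    }
    return {gemm1, gemm2};
  }
};

int main() {
  ProfileSelectorSketch s;
  s.build({{"g1_a"}, {"g1_b"}}, {{"g2_a"}, {"g2_b"}});
  auto [g1, g2] = s.select(std::vector<int64_t>{1, 0});
  assert(g1.name == "g1_b");
  assert(g2.name == "g2_a");  // without the offset this would wrongly pick g1_a
  return 0;
}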

πŸ€– Prompt for AI Agents
In csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu
around lines 226–230 (and referencing setRunnerProfiles around lines 772–775),
the concatenation of GEMM1 and GEMM2 tactics into mAllProfiles does not account
for the GEMM1 length when later indexing GEMM2 tactics, causing wrong
selections; fix by capturing and storing gemm1_size during initialization (e.g.,
gemm1_size = mKernelRunner->getTactics(kernels::MoeGemmId::GEMM_1).size()) and
then when selecting GEMM2 use profile index + gemm1_size (apply this offset
where mAllProfiles.at(profile_ids.value()[1]) is used), keeping the -1 sentinel
handling unchanged so that if profile_ids.value()[1] == -1 you keep the default,
otherwise index mAllProfiles at profile_ids.value()[1] + gemm1_size.

TVM_FFI_ICHECK(!mAllProfiles.empty())
<< "No valid tactics available for fused moe op with the requested input combination "
"Activation: "
@@ -368,25 +376,29 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {

// TODO: support lora in the future
::tensorrt_llm::kernels::LoraParams lora_params{};
// HACK Define default values for parameters we don't have good values for
bool const swizzled_input_sf = false; // Assume input_sf is not swizzled by default
int64_t const unpadded_hidden_size = hidden_size; // Assume no padding by default
bool const use_lora = false; // No lora support yet
Comment on lines +379 to +382

⚠️ Potential issue | 🟑 Minor

Document assumptions for default parameter values.

The hardcoded defaults for swizzled_input_sf, unpadded_hidden_size, and use_lora may not hold in all scenarios. In particular, setting unpadded_hidden_size = hidden_size assumes no padding, which may be incorrect if padding occurs upstream. Consider:

  1. Validating these assumptions with assertions when these features are actually used (see the sketch after this list)
  2. Adding TODO comments linking to future work that will provide proper values
  3. Documenting these limitations in user-facing documentation
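For item 1, a minimal standalone sketch of the kind of checks intended, under the stated assumptions (check_moe_default_assumptions is a hypothetical helper, and the real binding would more likely report failures via TVM_FFI_ICHECK than assert):

#include <cassert>
#include <cstdint>

// Hypothetical helper; names and the TODO placeholder are not from the PR.
inline void check_moe_default_assumptions(bool swizzled_input_sf, int64_t unpadded_hidden_size,
                                          int64_t hidden_size, bool use_lora) {
  // TODO(tracking issue TBD): derive unpadded_hidden_size from upstream tensor metadata
  // or accept it as an explicit parameter instead of assuming hidden_size.
  assert(unpadded_hidden_size <= hidden_size && "upstream padding would break this default");
  // Hard assumptions of the current code path; fail loudly if they are ever violated.
  assert(!swizzled_input_sf && "swizzled input_sf is not handled here");
  assert(!use_lora && "LoRA parameters are defaulted; enabling LoRA here is unsupported");
  (void)swizzled_input_sf;
  (void)unpadded_hidden_size;
  (void)hidden_size;
  (void)use_lora;
}

int main() {
  int64_t const hidden_size = 4096;
  check_moe_default_assumptions(/*swizzled_input_sf=*/false,
                                /*unpadded_hidden_size=*/hidden_size, hidden_size,
                                /*use_lora=*/false);
  return 0;
}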
πŸ€– Prompt for AI Agents
In csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu
around lines 372-375 the code hardcodes defaults for swizzled_input_sf,
unpadded_hidden_size and use_lora which may be incorrect in some configurations;
update the code to (1) add a TODO comment with a link to the tracking issue/PR
for providing proper values, (2) add runtime assertions or checks where these
flags/values are actually used (e.g., assert unpadded_hidden_size <= hidden_size
and verify swizzled_input_sf only when the input layout indicates swizzling),
and (3) if possible derive unpadded_hidden_size from upstream tensor metadata or
add it as an explicit parameter to the caller API and fall back to the current
default only with a clear warning log; also add a brief note in the module’s
user-facing docs describing this limitation and the expectation until full
support is implemented.

#ifdef USING_OSS_CUTLASS_MOE_GEMM
mKernelRunner->runMoe(input->data, input_sf.has_value() ? input_sf.value()->data : nullptr,
reinterpret_cast<int const*>(token_selected_experts->data),
token_final_scales.has_value()
? reinterpret_cast<float const*>(token_final_scales.value()->data)
: nullptr,
fc1_expert_weights->data,
fc1_expert_biases.has_value() ? fc1_expert_biases.value()->data : nullptr,
activation_params, fc2_expert_weights->data,
fc2_expert_biases.has_value() ? fc2_expert_biases.value()->data : nullptr,
quant_params, num_rows, hidden_size, inter_size, num_experts_total,
static_cast<int>(experts_per_token),
static_cast<char*>(workspace_info.workspace->data), output->data,
static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config,
enable_alltoall, false, lora_params, mUseDeepSeekFP8BlockScaling,
min_latency_mode, min_latency_params, enable_pdl, stream);
mKernelRunner->runMoe(
input->data, input_sf.has_value() ? input_sf.value()->data : nullptr, swizzled_input_sf,
reinterpret_cast<int const*>(token_selected_experts->data),
token_final_scales.has_value()
? reinterpret_cast<float const*>(token_final_scales.value()->data)
: nullptr,
fc1_expert_weights->data,
fc1_expert_biases.has_value() ? fc1_expert_biases.value()->data : nullptr,
activation_params, fc2_expert_weights->data,
fc2_expert_biases.has_value() ? fc2_expert_biases.value()->data : nullptr, quant_params,
num_rows, hidden_size, unpadded_hidden_size, inter_size, num_experts_total,
static_cast<int>(experts_per_token), static_cast<char*>(workspace_info.workspace->data),
output->data, static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config,
enable_alltoall, use_lora, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode,
min_latency_params, enable_pdl, stream);
#else
mKernelRunner->runMoe(
input->data, input_sf.has_value() ? input_sf.value()->data : nullptr,
input->data, input_sf.has_value() ? input_sf.value()->data : nullptr, swizzled_input_sf,
reinterpret_cast<int const*>(token_selected_experts->data),
token_final_scales.has_value()
? reinterpret_cast<float const*>(token_final_scales.value()->data)
@@ -395,10 +407,11 @@
fc1_expert_biases.has_value() ? fc1_expert_biases.value()->data : nullptr,
activation_params, fc2_expert_weights->data,
fc2_expert_biases.has_value() ? fc2_expert_biases.value()->data : nullptr, quant_params,
num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
static_cast<char*>(workspace_info.workspace), output->data,
static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, false, lora_params,
mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, enable_pdl, stream);
num_rows, hidden_size, unpadded_hidden_size, inter_size, num_experts_total,
static_cast<int>(experts_per_token), static_cast<char*>(workspace_info.workspace),
output->data, static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, false,
use_lora, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params,
enable_pdl, stream);
#endif
}

@@ -544,25 +557,29 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {

// TODO: support lora in the future
::tensorrt_llm::kernels::LoraParams lora_params{};
// HACK Define default values for parameters we don't have good values for
bool const swizzled_input_sf_ml = false; // Assume input_sf is not swizzled by default
int64_t const unpadded_hidden_size_ml = hidden_size; // Assume no padding by default
bool const use_lora_ml = false; // No lora support yet
#ifdef USING_OSS_CUTLASS_MOE_GEMM
mKernelRunner->runMoe(input->data, input_sf.has_value() ? input_sf.value()->data : nullptr,
reinterpret_cast<int const*>(token_selected_experts->data),
token_final_scales.has_value()
? reinterpret_cast<float const*>(token_final_scales.value()->data)
: nullptr,
fc1_expert_weights->data,
fc1_expert_biases.has_value() ? fc1_expert_biases.value()->data : nullptr,
activation_params, fc2_expert_weights->data,
fc2_expert_biases.has_value() ? fc2_expert_biases.value()->data : nullptr,
quant_params, num_rows, hidden_size, inter_size, num_experts_total,
static_cast<int>(experts_per_token),
static_cast<char*>(workspace_info.workspace->data), output->data,
static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config,
enable_alltoall, false, lora_params, mUseDeepSeekFP8BlockScaling,
min_latency_mode, min_latency_params, enable_pdl, stream);
mKernelRunner->runMoe(
input->data, input_sf.has_value() ? input_sf.value()->data : nullptr, swizzled_input_sf_ml,
reinterpret_cast<int const*>(token_selected_experts->data),
token_final_scales.has_value()
? reinterpret_cast<float const*>(token_final_scales.value()->data)
: nullptr,
fc1_expert_weights->data,
fc1_expert_biases.has_value() ? fc1_expert_biases.value()->data : nullptr,
activation_params, fc2_expert_weights->data,
fc2_expert_biases.has_value() ? fc2_expert_biases.value()->data : nullptr, quant_params,
num_rows, hidden_size, unpadded_hidden_size_ml, inter_size, num_experts_total,
static_cast<int>(experts_per_token), static_cast<char*>(workspace_info.workspace->data),
output->data, static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config,
enable_alltoall, use_lora_ml, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode,
min_latency_params, enable_pdl, stream);
#else
mKernelRunner->runMoe(
input->data, input_sf.has_value() ? input_sf.value()->data : nullptr,
input->data, input_sf.has_value() ? input_sf.value()->data : nullptr, swizzled_input_sf_ml,
reinterpret_cast<int const*>(token_selected_experts->data),
token_final_scales.has_value()
? reinterpret_cast<float const*>(token_final_scales.value()->data)
@@ -571,10 +588,11 @@
fc1_expert_biases.has_value() ? fc1_expert_biases.value()->data : nullptr,
activation_params, fc2_expert_weights->data,
fc2_expert_biases.has_value() ? fc2_expert_biases.value()->data : nullptr, quant_params,
num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
static_cast<char*>(workspace_info.workspace), output->data,
static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, false, lora_params,
mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, enable_pdl, stream);
num_rows, hidden_size, unpadded_hidden_size_ml, inter_size, num_experts_total,
static_cast<int>(experts_per_token), static_cast<char*>(workspace_info.workspace),
output->data, static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, false,
use_lora_ml, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params,
enable_pdl, stream);
#endif
}

@@ -636,19 +654,20 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
auto activation_dtype =
(mUseW4GroupScaling && !isWFP4A16Quant()) ? dl_float8_e4m3fn : mActivationDtype;
activation_dtype = isNvfp4Quant() ? dl_int64 : activation_dtype;
int64_t const unpadded_hidden_size_profiler = hidden_size; // HACK no padding by default
#ifdef USING_OSS_CUTLASS_MOE_GEMM
mProfiler->init(*mKernelRunner.get(), mProfiler->mGemmToProfile,
DtypeUtils::dataType(activation_dtype), DtypeUtils::dataType(mWeightDtype),
DtypeUtils::dataType(mOutputDtype), num_experts, static_cast<int>(top_k),
hidden_size, inter_size, group_size, activation_type, USE_BIAS, USE_LORA,
min_latency_mode,
hidden_size, unpadded_hidden_size_profiler, inter_size, group_size,
activation_type, USE_BIAS, USE_LORA, min_latency_mode,
/*need_weights*/ false, parallelism_config, enable_alltoall);
#else
mProfiler->init(*mKernelRunner.get(), mProfiler->mGemmToProfile,
DtypeUtils::dataType(activation_dtype), DtypeUtils::dataType(mWeightDtype),
DtypeUtils::dataType(mOutputDtype), num_experts, static_cast<int>(top_k),
hidden_size, inter_size, group_size, activation_type, USE_BIAS, USE_LORA,
min_latency_mode,
hidden_size, unpadded_hidden_size_profiler, inter_size, group_size,
activation_type, USE_BIAS, USE_LORA, min_latency_mode,
/*need_weights*/ false, parallelism_config);
#endif

3 changes: 3 additions & 0 deletions csrc/nv_internal/include/tensorrt_llm/common/cudaUtils.h
@@ -1181,6 +1181,9 @@ using Int = ConstExprWrapper<int, VALUE>;
template <bool VALUE>
using Bool = ConstExprWrapper<bool, VALUE>;

template <bool VALUE>
using ConstBool = ConstExprWrapper<bool, VALUE>;

template <typename T>
struct TmaDescType;
