chore: upgrade cutlass moe kernel launcher to match trtllm #1925
base: main
@@ -72,6 +72,8 @@ class DtypeUtils {
       default:
         TVM_FFI_ICHECK(false) << "unsupported data type";
     }
+
+    return nvinfer1::DataType::kFLOAT;  // supress compiler warning
   }

  private:
@@ -111,6 +113,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
       TVM_FFI_ICHECK(false) << "Invalid output type " << DLDataTypeToString(output_type)
                             << " specified for " << DLDataTypeToString(mActivationDtype);
     }
+
+    return nullptr;  // supress compiler warning
   };

   FusedMoeRunner(DLDataType activation_dtype, DLDataType weight_dtype, DLDataType output_dtype,
@@ -219,7 +223,11 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
     }

     mProfiler = std::make_shared<kernels::GemmProfilerBackend>();
-    mAllProfiles = mKernelRunner->getTactics();
+    // Get tactics for both GEMM1 and GEMM2, combine them
+    auto gemm1_tactics = mKernelRunner->getTactics(kernels::MoeGemmId::GEMM_1);
+    auto gemm2_tactics = mKernelRunner->getTactics(kernels::MoeGemmId::GEMM_2);
+    mAllProfiles = gemm1_tactics;
+    mAllProfiles.insert(mAllProfiles.end(), gemm2_tactics.begin(), gemm2_tactics.end());
     TVM_FFI_ICHECK(!mAllProfiles.empty())
         << "No valid tactics available for fused moe op with the requested input combination "
            "Activation: "
@@ -368,25 +376,29 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {

     // TODO: support lora in the future
     ::tensorrt_llm::kernels::LoraParams lora_params{};
+    // HACK Define default values for parameters we don't have good values for
+    bool const swizzled_input_sf = false;              // Assume input_sf is not swizzled by default
+    int64_t const unpadded_hidden_size = hidden_size;  // Assume no padding by default
+    bool const use_lora = false;                       // No lora support yet
Comment on lines +379 to +382:

Document assumptions for default parameter values. The hardcoded defaults for `swizzled_input_sf`, `unpadded_hidden_size`, and `use_lora` are assumptions rather than values derived from the inputs, and the rationale for each should be documented.
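The snippet below is only a sketch of what such documentation could look like; it reuses the variable names from the diff above, but the comments are illustrative additions, not part of the PR:

```cpp
// Sketch only (not part of this PR): the same defaults with the assumption behind
// each one spelled out next to it.
bool const swizzled_input_sf = false;              // input_sf handed in from the caller is
                                                   // assumed to be in linear (non-swizzled) layout
int64_t const unpadded_hidden_size = hidden_size;  // the caller is assumed not to pad
                                                   // hidden_size, so padded == unpadded
bool const use_lora = false;                       // LoRA is not wired through this FFI entry
                                                   // point yet, matching the empty lora_params
```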
 #ifdef USING_OSS_CUTLASS_MOE_GEMM
-    mKernelRunner->runMoe(input->data, input_sf.has_value() ? input_sf.value()->data : nullptr,
-                          reinterpret_cast<int const*>(token_selected_experts->data),
-                          token_final_scales.has_value()
-                              ? reinterpret_cast<float const*>(token_final_scales.value()->data)
-                              : nullptr,
-                          fc1_expert_weights->data,
-                          fc1_expert_biases.has_value() ? fc1_expert_biases.value()->data : nullptr,
-                          activation_params, fc2_expert_weights->data,
-                          fc2_expert_biases.has_value() ? fc2_expert_biases.value()->data : nullptr,
-                          quant_params, num_rows, hidden_size, inter_size, num_experts_total,
-                          static_cast<int>(experts_per_token),
-                          static_cast<char*>(workspace_info.workspace->data), output->data,
-                          static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config,
-                          enable_alltoall, false, lora_params, mUseDeepSeekFP8BlockScaling,
-                          min_latency_mode, min_latency_params, enable_pdl, stream);
+    mKernelRunner->runMoe(
+        input->data, input_sf.has_value() ? input_sf.value()->data : nullptr, swizzled_input_sf,
+        reinterpret_cast<int const*>(token_selected_experts->data),
+        token_final_scales.has_value()
+            ? reinterpret_cast<float const*>(token_final_scales.value()->data)
+            : nullptr,
+        fc1_expert_weights->data,
+        fc1_expert_biases.has_value() ? fc1_expert_biases.value()->data : nullptr,
+        activation_params, fc2_expert_weights->data,
+        fc2_expert_biases.has_value() ? fc2_expert_biases.value()->data : nullptr, quant_params,
+        num_rows, hidden_size, unpadded_hidden_size, inter_size, num_experts_total,
+        static_cast<int>(experts_per_token), static_cast<char*>(workspace_info.workspace->data),
+        output->data, static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config,
+        enable_alltoall, use_lora, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode,
+        min_latency_params, enable_pdl, stream);
 #else
     mKernelRunner->runMoe(
-        input->data, input_sf.has_value() ? input_sf.value()->data : nullptr,
+        input->data, input_sf.has_value() ? input_sf.value()->data : nullptr, swizzled_input_sf,
         reinterpret_cast<int const*>(token_selected_experts->data),
         token_final_scales.has_value()
             ? reinterpret_cast<float const*>(token_final_scales.value()->data)
@@ -395,10 +407,11 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
         fc1_expert_biases.has_value() ? fc1_expert_biases.value()->data : nullptr,
         activation_params, fc2_expert_weights->data,
         fc2_expert_biases.has_value() ? fc2_expert_biases.value()->data : nullptr, quant_params,
-        num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
-        static_cast<char*>(workspace_info.workspace), output->data,
-        static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, false, lora_params,
-        mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, enable_pdl, stream);
+        num_rows, hidden_size, unpadded_hidden_size, inter_size, num_experts_total,
+        static_cast<int>(experts_per_token), static_cast<char*>(workspace_info.workspace),
+        output->data, static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, false,
+        use_lora, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params,
+        enable_pdl, stream);
 #endif
   }

@@ -544,25 +557,29 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {

     // TODO: support lora in the future
     ::tensorrt_llm::kernels::LoraParams lora_params{};
+    // HACK Define default values for parameters we don't have good values for
+    bool const swizzled_input_sf_ml = false;              // Assume input_sf is not swizzled by default
+    int64_t const unpadded_hidden_size_ml = hidden_size;  // Assume no padding by default
+    bool const use_lora_ml = false;                       // No lora support yet
 #ifdef USING_OSS_CUTLASS_MOE_GEMM
-    mKernelRunner->runMoe(input->data, input_sf.has_value() ? input_sf.value()->data : nullptr,
-                          reinterpret_cast<int const*>(token_selected_experts->data),
-                          token_final_scales.has_value()
-                              ? reinterpret_cast<float const*>(token_final_scales.value()->data)
-                              : nullptr,
-                          fc1_expert_weights->data,
-                          fc1_expert_biases.has_value() ? fc1_expert_biases.value()->data : nullptr,
-                          activation_params, fc2_expert_weights->data,
-                          fc2_expert_biases.has_value() ? fc2_expert_biases.value()->data : nullptr,
-                          quant_params, num_rows, hidden_size, inter_size, num_experts_total,
-                          static_cast<int>(experts_per_token),
-                          static_cast<char*>(workspace_info.workspace->data), output->data,
-                          static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config,
-                          enable_alltoall, false, lora_params, mUseDeepSeekFP8BlockScaling,
-                          min_latency_mode, min_latency_params, enable_pdl, stream);
+    mKernelRunner->runMoe(
+        input->data, input_sf.has_value() ? input_sf.value()->data : nullptr, swizzled_input_sf_ml,
+        reinterpret_cast<int const*>(token_selected_experts->data),
+        token_final_scales.has_value()
+            ? reinterpret_cast<float const*>(token_final_scales.value()->data)
+            : nullptr,
+        fc1_expert_weights->data,
+        fc1_expert_biases.has_value() ? fc1_expert_biases.value()->data : nullptr,
+        activation_params, fc2_expert_weights->data,
+        fc2_expert_biases.has_value() ? fc2_expert_biases.value()->data : nullptr, quant_params,
+        num_rows, hidden_size, unpadded_hidden_size_ml, inter_size, num_experts_total,
+        static_cast<int>(experts_per_token), static_cast<char*>(workspace_info.workspace->data),
+        output->data, static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config,
+        enable_alltoall, use_lora_ml, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode,
+        min_latency_params, enable_pdl, stream);
 #else
     mKernelRunner->runMoe(
-        input->data, input_sf.has_value() ? input_sf.value()->data : nullptr,
+        input->data, input_sf.has_value() ? input_sf.value()->data : nullptr, swizzled_input_sf_ml,
         reinterpret_cast<int const*>(token_selected_experts->data),
         token_final_scales.has_value()
             ? reinterpret_cast<float const*>(token_final_scales.value()->data)
@@ -571,10 +588,11 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
         fc1_expert_biases.has_value() ? fc1_expert_biases.value()->data : nullptr,
         activation_params, fc2_expert_weights->data,
         fc2_expert_biases.has_value() ? fc2_expert_biases.value()->data : nullptr, quant_params,
-        num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
-        static_cast<char*>(workspace_info.workspace), output->data,
-        static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, false, lora_params,
-        mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, enable_pdl, stream);
+        num_rows, hidden_size, unpadded_hidden_size_ml, inter_size, num_experts_total,
+        static_cast<int>(experts_per_token), static_cast<char*>(workspace_info.workspace),
+        output->data, static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, false,
+        use_lora_ml, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params,
+        enable_pdl, stream);
 #endif
   }

@@ -636,19 +654,20 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
     auto activation_dtype =
         (mUseW4GroupScaling && !isWFP4A16Quant()) ? dl_float8_e4m3fn : mActivationDtype;
     activation_dtype = isNvfp4Quant() ? dl_int64 : activation_dtype;
+    int64_t const unpadded_hidden_size_profiler = hidden_size;  // HACK no padding by default
 #ifdef USING_OSS_CUTLASS_MOE_GEMM
     mProfiler->init(*mKernelRunner.get(), mProfiler->mGemmToProfile,
                     DtypeUtils::dataType(activation_dtype), DtypeUtils::dataType(mWeightDtype),
                     DtypeUtils::dataType(mOutputDtype), num_experts, static_cast<int>(top_k),
-                    hidden_size, inter_size, group_size, activation_type, USE_BIAS, USE_LORA,
-                    min_latency_mode,
+                    hidden_size, unpadded_hidden_size_profiler, inter_size, group_size,
+                    activation_type, USE_BIAS, USE_LORA, min_latency_mode,
                     /*need_weights*/ false, parallelism_config, enable_alltoall);
 #else
     mProfiler->init(*mKernelRunner.get(), mProfiler->mGemmToProfile,
                     DtypeUtils::dataType(activation_dtype), DtypeUtils::dataType(mWeightDtype),
                     DtypeUtils::dataType(mOutputDtype), num_experts, static_cast<int>(top_k),
-                    hidden_size, inter_size, group_size, activation_type, USE_BIAS, USE_LORA,
-                    min_latency_mode,
+                    hidden_size, unpadded_hidden_size_profiler, inter_size, group_size,
+                    activation_type, USE_BIAS, USE_LORA, min_latency_mode,
                     /*need_weights*/ false, parallelism_config);
 #endif

🧩 Analysis chain

Verify tactic selection logic for GEMM1 vs GEMM2.

The tactics from GEMM1 and GEMM2 are now concatenated into a single `mAllProfiles` vector. Ensure that when selecting tactics by `profile_id` (e.g., at line 618), the correct tactics are applied to the corresponding GEMM. The `setRunnerProfiles` method expects 2 profile IDs, so verify the indexing logic correctly distinguishes between GEMM1 and GEMM2 tactics.
Apply offset correction when indexing GEMM2 tactics from the concatenated `mAllProfiles`.

The tactics from GEMM1 and GEMM2 are concatenated into a single `mAllProfiles` vector (lines 229–230), but when selecting tactics by `profile_ids` in `setRunnerProfiles` (lines 772–775), the GEMM2 profile index is not offset. This causes GEMM2 to incorrectly select from GEMM1's tactics range. After concatenation, GEMM2 tactics start at index `gemm1_tactics.size()`, but this offset is not applied when accessing `mAllProfiles.at(profile_ids.value()[1])`.

Fix: apply an offset for GEMM2 tactic selection, where `gemm1_size = mKernelRunner->getTactics(kernels::MoeGemmId::GEMM_1).size()` is stored during initialization.
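A minimal standalone sketch of the suggested offset, assuming the GEMM1 tactic count is recorded when the profiles are concatenated; the `Profile` struct, the function name, and the `-1` sentinel handling are illustrative, not taken from the repository:

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

struct Profile {
  int id;  // stand-in for the real per-GEMM tactic/config type
};

// Map the two user-supplied profile ids onto the concatenated tactic list:
// all_profiles = [GEMM1 tactics ..., GEMM2 tactics ...], so the GEMM2 id must be
// shifted by the number of GEMM1 tactics recorded at initialization time.
std::pair<Profile, Profile> selectProfiles(const std::vector<Profile>& all_profiles,
                                           std::size_t gemm1_size,
                                           const std::vector<int64_t>& profile_ids) {
  assert(profile_ids.size() == 2);  // setRunnerProfiles expects exactly 2 profile ids
  Profile gemm1 = profile_ids[0] < 0 ? all_profiles.front()
                                     : all_profiles.at(profile_ids[0]);
  Profile gemm2 = profile_ids[1] < 0 ? all_profiles.at(gemm1_size)
                                     : all_profiles.at(gemm1_size + profile_ids[1]);
  return {gemm1, gemm2};
}
```

Storing `gemm1_size` once at construction keeps the offset in a single place instead of recomputing `getTactics(GEMM_1).size()` on every call.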