Make moe permute and final as custom op #5358

Open · wants to merge 17 commits into main

879 changes: 879 additions & 0 deletions cpp/tensorrt_llm/kernels/moeUtilOp.cu

Large diffs are not rendered by default.

56 changes: 56 additions & 0 deletions cpp/tensorrt_llm/kernels/moeUtilOp.h
@@ -0,0 +1,56 @@
/*
* Copyright (c) 2019-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "cutlass_kernels/include/moe_kernels.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include <cuda_bf16.h>
#include <cuda_fp16.h>

namespace tensorrt_llm::kernels
{
bool fusedBuildExpertMapsSortFirstToken(int const* token_selected_experts, int* unpermuted_token_selected_experts,
int* permuted_source_token_ids, int64_t* expert_first_token_offset, int64_t const num_tokens,
int const num_experts_per_node, int const experts_per_token, int const start_expert, int const end_expert,
cudaStream_t stream);

void buildExpertMaps(int const* token_selected_experts, int* unpermuted_token_selected_experts,
int* unpermuted_source_token_ids, int64_t const num_tokens, int const num_experts_per_node,
int const experts_per_token, int const start_expert, int const end_expert, cudaStream_t stream);

void generateTokenPermutation(int const* unpermuted_token_selected_experts, int const* unpermuted_source_token_ids,
int* permuted_token_selected_experts, int* permuted_source_token_ids, int64_t* expert_first_token_offset,
int64_t num_rows, int64_t num_experts_per_node, int64_t k, cutlass_kernels::CubKeyValueSorter& sorter,
void* sorter_ws, cudaStream_t stream);

template <class InputActivationsType, class ExpandedActivationsType>
void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
ExpandedActivationsType* permuted_output, float const* unpermuted_scales, float* permuted_scales,
int const* expanded_dest_row_to_expanded_source_row, int* expanded_source_row_to_expanded_dest_row,
int64_t const num_rows, int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
int const num_experts_per_node, float const* fc1_act_global_scale, int64_t* expert_first_token_offset,
cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat,
cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, cudaStream_t stream);

template <class OutputType, class GemmOutputType, class ScaleBiasType>
void finalizeMoeRoutingKernelLauncher(GemmOutputType const* expanded_permuted_rows,
OutputType* reduced_unpermuted_output, ScaleBiasType const* bias, float const* final_scales,
int const* expanded_source_row_to_expanded_dest_row, int const* expert_for_source_row, int64_t const num_rows,
int64_t const cols, int64_t const experts_per_token, int64_t const* num_valid_ptr,
cutlass_kernels::MOEParallelismConfig parallelism_config, cudaStream_t stream);

} // namespace tensorrt_llm::kernels
2 changes: 1 addition & 1 deletion cpp/tensorrt_llm/kernels/quantization.cuh
@@ -275,7 +275,7 @@ __global__ void perTokenQuantization(QuantT* dst, T const* src, int64_t const nu
// FP4 Quantization

constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
// constexpr int CVT_FP4_SF_VEC_SIZE = 16;
constexpr int CVT_FP4_SF_VEC_SIZE = 16;
constexpr int CVT_FP4_THREADS_PER_WARP = 32;
constexpr int CVT_FP8_TO_FP4_ELTS_PER_THREAD = 16;

1 change: 1 addition & 0 deletions cpp/tensorrt_llm/thop/CMakeLists.txt
@@ -65,6 +65,7 @@ add_library(
logitsBitmaskOp.cpp
mambaConv1dOp.cpp
moeOp.cpp
moeUtilOp.cpp
moeCommOp.cpp
moeLoadBalanceOp.cpp
fp8BlockScaleMoe.cpp
447 changes: 447 additions & 0 deletions cpp/tensorrt_llm/thop/moeUtilOp.cpp

Large diffs are not rendered by default.

143 changes: 126 additions & 17 deletions tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
@@ -24,6 +24,7 @@ def _(
trigger_completion_at_end,
):
from tensorrt_llm.functional import AllReduceFusionOp

if op == int(AllReduceFusionOp.NONE):
return [torch.empty_like(input)]
elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM):
@@ -55,7 +56,7 @@ def _(
else:
return [torch.empty_like(input)]

#MNNVL Allreduce
# MNNVL Allreduce
@torch.library.register_fake("trtllm::mnnvl_twoshot_allreduce")
def _(input, buffer, buffer_flags, wait_for_results):
output = input.new_empty(input.shape)
@@ -68,9 +69,18 @@ def _(comm_buf, gamma, eps, residual, buffer_flags):
return [output, residual_out]

@torch.library.register_fake("trtllm::moe_allreduce")
def _(residual, norm_weight, device_num_experts, scale_input,
active_experts_token_input, token_input, workspace, rank, nranks,
eps):
def _(
residual,
norm_weight,
device_num_experts,
scale_input,
active_experts_token_input,
token_input,
workspace,
rank,
nranks,
eps,
):
norm_out = torch.empty_like(token_input)
residual_out = torch.empty_like(residual)
return [norm_out, residual_out]
@@ -175,8 +185,10 @@ def _(
output_shape, scale_shape = fp4_utils.get_fp4_shape(
input.shape, sf_vec_size)

return (input.new_empty(output_shape, dtype=torch.uint8),
global_scale.new_empty(scale_shape, dtype=torch.uint8))
return (
input.new_empty(output_shape, dtype=torch.uint8),
global_scale.new_empty(scale_shape, dtype=torch.uint8),
)

@torch.library.register_fake("trtllm::moe_comm_prepare_indices")
def _(
@@ -210,9 +222,14 @@ def _(
backward_recv_rank_local_indices = gathered_target_rank_ids.new_empty(
backward_recv_rank_local_indices_shape, dtype=torch.int32)

return (local_gather_indices, send_rank_count_cum_sum,
send_rank_local_indices, recv_rank_count_cum_sum,
recv_rank_local_indices, backward_recv_rank_local_indices)
return (
local_gather_indices,
send_rank_count_cum_sum,
send_rank_local_indices,
recv_rank_count_cum_sum,
recv_rank_local_indices,
backward_recv_rank_local_indices,
)

@torch.library.register_fake("trtllm::moe_local_gather")
def _(
@@ -263,14 +280,21 @@ def _(single_layer_load_balancer_ptr: int):
pass

@torch.library.register_fake("trtllm::moe_load_balance_statistic")
def _(single_layer_load_balancer_ptr: int,
gathered_raw_expert_ids: torch.Tensor, enabled: torch.Tensor,
is_first_stage: bool, is_last_stage: bool):
def _(
single_layer_load_balancer_ptr: int,
gathered_raw_expert_ids: torch.Tensor,
enabled: torch.Tensor,
is_first_stage: bool,
is_last_stage: bool,
):
pass

@torch.library.register_fake("trtllm::moe_load_balance_routing")
def _(single_layer_load_balancer_ptr: int,
token_selected_experts: torch.Tensor, offset_by_ep_rank: bool):
def _(
single_layer_load_balancer_ptr: int,
token_selected_experts: torch.Tensor,
offset_by_ep_rank: bool,
):
return torch.empty_like(token_selected_experts)

@torch.library.custom_op("trtllm::group_rms_norm_base",
@@ -339,9 +363,15 @@ def _(

@torch.library.register_fake(
"trtllm::mtp_sampling_and_accepted_draft_tokens_op")
def _(logits: torch.Tensor, draft_tokens: torch.Tensor,
target_tokens: torch.Tensor, num_mtp_modules: int, batch_size: int,
num_context_request: int, vocab_size: int):
def _(
logits: torch.Tensor,
draft_tokens: torch.Tensor,
target_tokens: torch.Tensor,
num_mtp_modules: int,
batch_size: int,
num_context_request: int,
vocab_size: int,
):
return logits.new_empty((batch_size, num_mtp_modules + 1),
dtype=torch.int32), logits.new_empty(
(batch_size, ), dtype=torch.int32)
@@ -384,3 +414,82 @@ def _(
pad_slot_id: int,
) -> None:
pass

@torch.library.register_fake("trtllm::moe_permute_op")
def _(
input: torch.Tensor,
token_selected_experts: torch.Tensor,
token_final_scales: torch.Tensor,
fc1_expert_weights: torch.Tensor,
fc2_expert_weights: torch.Tensor,
quant_scales: List[torch.Tensor],
input_sf: Optional[torch.Tensor],
num_experts_per_node: int,
tp_size: int,
tp_rank: int,
ep_size: int,
ep_rank: int,
cluster_size: int,
cluster_rank: int,
min_latency_mode: bool,
use_fp8_block_scaling: bool,
):

experts_per_token = token_selected_experts.shape[1]
num_rows = input.shape[0]
hidden_size = input.shape[1]

num_moe_inputs = experts_per_token * num_rows
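# Each token is replicated once per selected expert, so the per-replica index,
# id and scale tensors below all have num_moe_inputs = k * num_rows entries;
# expert_first_token_offset holds num_experts_per_node + 1 per-expert row
# offsets (int64), matching the outputs of the real op.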

unpermuted_token_selected_experts_tensor = token_selected_experts.new_empty(
(num_moe_inputs, ), dtype=torch.int32)
unpermuted_source_token_ids_tensor = token_selected_experts.new_empty(
(num_moe_inputs, ), dtype=torch.int32)
permuted_source_token_ids_tensor = token_selected_experts.new_empty(
(num_moe_inputs, ), dtype=torch.int32)
permuted_token_selected_experts_tensor = token_selected_experts.new_empty(
(num_moe_inputs, ), dtype=torch.int32)
permuted_data_tensor = input.new_empty((num_moe_inputs, hidden_size),
dtype=torch.float32)
expert_first_token_offset_tensor = token_selected_experts.new_empty(
(num_experts_per_node + 1, ), dtype=torch.int64)
permuted_token_final_scales_tensor = token_selected_experts.new_empty(
(num_moe_inputs, ), dtype=torch.float32)
src_to_dest_map_tensor = token_selected_experts.new_empty(
(num_moe_inputs, ), dtype=torch.int32)

return (
unpermuted_token_selected_experts_tensor,
unpermuted_source_token_ids_tensor,
permuted_source_token_ids_tensor,
permuted_token_selected_experts_tensor,
permuted_data_tensor,
expert_first_token_offset_tensor,
permuted_token_final_scales_tensor,
src_to_dest_map_tensor,
)

@torch.library.register_fake("trtllm::moe_finalize_scale_op")
def _(
gemm2_output: torch.Tensor,
fc2_expert_biases: torch.Tensor,
unpermuted_final_scales: torch.Tensor,
expanded_source_row_to_expanded_dest_row: torch.Tensor,
expert_for_source_row: torch.Tensor,
expert_first_token_offset_tensor: torch.Tensor,
num_rows: int,
hidden_size: int,
experts_per_token: int,
num_experts_per_node: int,
tp_size: int,
tp_rank: int,
ep_size: int,
ep_rank: int,
):

return gemm2_output.new_empty((num_rows, hidden_size),
dtype=gemm2_output.dtype)


def fp8_quantize_1x128(input: torch.Tensor):
return torch.ops.trtllm.fp8_quantize_1x128(input)
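
The two fake registrations above shadow the real moe_permute_op and moe_finalize_scale_op custom ops (bound in cpp/tensorrt_llm/thop/moeUtilOp.cpp on top of the launchers declared in moeUtilOp.h): permute replicates each token once per selected expert and groups the replicas by expert, and finalize scatters the expert outputs back to their source tokens while applying the router scales. The following is a pure-PyTorch sketch of that permute → expert compute → finalize pattern, for illustration only; it is not the CUDA implementation, and details such as the exact row ordering and the scale-factor/quantization handling are simplified.

import torch

def reference_moe(hidden_states, token_selected_experts, token_final_scales, expert_fn):
    """Pure-PyTorch sketch of the permute -> expert compute -> finalize pattern."""
    num_rows, _ = hidden_states.shape
    experts_per_token = token_selected_experts.shape[1]

    # "Permute": replicate each token once per selected expert, then order the
    # replicas by expert id so every expert sees a contiguous block of rows.
    flat_experts = token_selected_experts.reshape(-1)               # [num_rows * k]
    source_rows = torch.arange(num_rows, device=hidden_states.device)
    source_rows = source_rows.repeat_interleave(experts_per_token)  # [num_rows * k]
    order = torch.argsort(flat_experts, stable=True)                # permuted source ids
    permuted_input = hidden_states[source_rows[order]]              # [num_rows * k, hidden]

    # Stand-in for the grouped expert GEMMs that run between the two ops.
    permuted_output = expert_fn(permuted_input, flat_experts[order])

    # "Finalize": scatter each expert output back to its source token, reducing
    # with the router's final scales.
    out = torch.zeros_like(hidden_states)
    scales = token_final_scales.reshape(-1)[order].unsqueeze(-1).to(hidden_states.dtype)
    out.index_add_(0, source_rows[order], permuted_output * scales)
    return out

# Identity "experts" with 0.5/0.5 routing weights reproduce the input exactly.
x = torch.randn(8, 16)
experts = torch.randint(0, 4, (8, 2), dtype=torch.int32)
scales = torch.full((8, 2), 0.5)
assert torch.allclose(reference_moe(x, experts, scales, lambda rows, ids: rows), x, atol=1e-6)
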
5 changes: 5 additions & 0 deletions tensorrt_llm/_torch/modules/fused_moe/__init__.py
@@ -2,6 +2,7 @@
from .fused_moe_cutlass import CutlassFusedMoE
from .fused_moe_trtllm_gen import TRTLLMGenFusedMoE
from .fused_moe_vanilla import VanillaMoE
from .fused_moe_cute_dsl import CuteDslFusedMoE
from .fused_moe_wide_ep import WideEPMoE
from .interface import MoE, MoEWeightLoadingMode
from .moe_load_balancer import MoeLoadBalancer
@@ -14,6 +15,10 @@
SparseMixerMoeRoutingMethod, StaticMoeRoutingMethod)

__all__ = [
"VanillaMoE",
"CutlassFusedMoE",
"CuteDslFusedMoE",
"TRTLLMGenFusedMoE",
"BaseMoeRoutingMethod",
"create_moe",
"CutlassFusedMoE",
37 changes: 30 additions & 7 deletions tensorrt_llm/_torch/modules/fused_moe/create_moe.py
@@ -6,6 +6,7 @@
from tensorrt_llm.models.modeling_utils import QuantConfig

from ...model_config import ModelConfig
from .fused_moe_cute_dsl import CuteDslFusedMoE
from .fused_moe_cutlass import CutlassFusedMoE
from .fused_moe_trtllm_gen import TRTLLMGenFusedMoE
from .fused_moe_vanilla import VanillaMoE
@@ -16,10 +17,11 @@


def get_moe_cls(
model_config: ModelConfig,
routing_method: BaseMoeRoutingMethod,
dtype: Optional[torch.dtype] = None,
override_quant_config: Optional[QuantConfig] = None) -> Type[MoE]:
model_config: ModelConfig,
routing_method: BaseMoeRoutingMethod,
dtype: Optional[torch.dtype] = None,
override_quant_config: Optional[QuantConfig] = None,
) -> Type[MoE]:
moe_backend = model_config.moe_backend
quant_config = model_config.quant_config
if override_quant_config is not None:
@@ -28,6 +30,8 @@ def get_moe_cls(
return CutlassFusedMoE
elif moe_backend.upper() == "VANILLA":
return VanillaMoE
elif moe_backend.upper() == "CUTEDSL":
return CuteDslFusedMoE
elif moe_backend.upper() == "TRTLLM":
if quant_config is not None and (
quant_config.quant_mode.has_fp8_block_scales()
@@ -64,10 +68,13 @@ def create_moe(

moe_load_balancer = get_moe_load_balancer()
if moe_load_balancer is not None:
assert moe_cls == CutlassFusedMoE, "MoE Load Balance is only supported in CutlassFusedMoE now."
assert (moe_cls == CutlassFusedMoE
), "MoE Load Balance is only supported in CutlassFusedMoE now."

if moe_cls == TRTLLMGenFusedMoE:
assert not apply_router_weight_on_input, "apply_router_weight_on_input is not supported in TRTLLMGenFusedMoE."
assert (
not apply_router_weight_on_input
), "apply_router_weight_on_input is not supported in TRTLLMGenFusedMoE."

return moe_cls(
routing_method=routing_method,
@@ -109,7 +116,9 @@ def create_moe(
layer_idx=layer_idx,
)
elif moe_cls == VanillaMoE:
assert not apply_router_weight_on_input, "apply_router_weight_on_input is not supported in VanillaMoE."
assert (
not apply_router_weight_on_input
), "apply_router_weight_on_input is not supported in VanillaMoE."

return moe_cls(
routing_method=routing_method,
@@ -122,5 +131,19 @@
weight_loading_mode=weight_loading_mode,
apply_router_weight_on_input=apply_router_weight_on_input,
)
elif moe_cls == CuteDslFusedMoE:
return moe_cls(
routing_method=routing_method,
num_experts=num_experts,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
dtype=dtype,
reduce_results=reduce_results,
model_config=model_config,
aux_stream=aux_stream,
weight_loading_mode=weight_loading_mode,
apply_router_weight_on_input=apply_router_weight_on_input,
layer_idx=layer_idx,
)
else:
raise ValueError(f"Unsupported moe backend: {moe_cls}")
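
For reference, a minimal sketch of how the new backend is selected, assuming a populated ModelConfig and routing method are available from the surrounding model code (both are placeholders below; only the moe_backend value, the get_moe_cls signature, and the exported names come from this diff):

from tensorrt_llm._torch.modules.fused_moe import CuteDslFusedMoE
from tensorrt_llm._torch.modules.fused_moe.create_moe import get_moe_cls

# model_config (a ModelConfig) and routing_method (a BaseMoeRoutingMethod) are
# assumed to be built by the surrounding model code; they are placeholders here.
model_config.moe_backend = "CUTEDSL"   # matched case-insensitively by get_moe_cls
assert get_moe_cls(model_config, routing_method) is CuteDslFusedMoE

# create_moe() then forwards routing_method, the size/dtype arguments and
# model_config to CuteDslFusedMoE, mirroring the CutlassFusedMoE branch above.
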