@@ -882,9 +882,8 @@ __device__ inline int64_t findTotalEltsLessThanTarget_v1(T const* sorted_indices
   return target_location + 1;
 }
 
-template <class T>
+template <int ARR_LENGTH_CONST, class T>
 __device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices, int64_t const arr_length, T const target) {
-  constexpr int ARR_LENGTH_CONST = 128;
   if (arr_length != ARR_LENGTH_CONST) {
     asm("trap;");
   }
@@ -910,11 +909,11 @@ __device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices
   return (int64_t)total;
 }
 
-template <class T>
+template <int ARR_LENGTH_CONST, class T>
 __device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target) {
   // return findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);
 
-  return findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target);
+  return findTotalEltsLessThanTarget_v2<ARR_LENGTH_CONST>(sorted_indices, arr_length, target);
 
   // int64_t out_v1 = findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);
   // int64_t out_v2 = findTotalEltsLessThanTarget_v2(sorted_indices, arr_length, target);
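
This hunk elides the body of `findTotalEltsLessThanTarget_v2`, so as context only: one way a compile-time array length pays off is that "count the sorted entries below the target" can become a fully unrolled, branch-free linear pass instead of a data-dependent binary search. A minimal sketch, not the code in this patch (the helper name is illustrative):

```cuda
#include <cstdint>

// Illustrative sketch only -- not the body of findTotalEltsLessThanTarget_v2.
// With the length fixed at compile time the loop can be fully unrolled and each
// comparison becomes a predicated add, so every thread takes the same path.
template <int ARR_LENGTH_CONST, class T>
__device__ inline int64_t countEltsLessThanTargetSketch(T const* sorted_indices, T const target) {
  int total = 0;
#pragma unroll
  for (int i = 0; i < ARR_LENGTH_CONST; ++i) {
    total += (sorted_indices[i] < target) ? 1 : 0;
  }
  return static_cast<int64_t>(total);
}
```
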
@@ -1462,7 +1461,7 @@ constexpr static int EXPAND_THREADS_PER_BLOCK = 128;
 
 template <class InputActivationsType, class ExpandedActivationsType,
           TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType,
-          bool PRE_QUANT_AWQ>
+          bool PRE_QUANT_AWQ, int NUM_EXPERTS_PER_NODE_CONST = 128>
 __global__ void expandInputRowsKernel(
     InputActivationsType const* unpermuted_input, ExpandedActivationsType* permuted_output,
     float const* unpermuted_scales, float* permuted_scales,
@@ -1557,7 +1556,7 @@ __global__ void expandInputRowsKernel(
 
   if constexpr (is_nvfp4 || is_mxfp8) {
     static_assert(ELEM_PER_THREAD == 8, "Expecting 8 elements per thread for quantized types");
-    int64_t expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node,
+    int64_t expert = findTotalEltsLessThanTarget<NUM_EXPERTS_PER_NODE_CONST>(expert_first_token_offset, num_experts_per_node,
                                                  (int64_t)permuted_row + 1) -
                      1;
 
@@ -1735,9 +1734,20 @@ void expandInputRowsKernelLauncher(
       TLLM_CHECK_WITH_INFO(quant_params.fp4.fc1.weight_block_scale,
                            "NVFP4 block scaling is expected for FP4xFP4");
       TLLM_CHECK_WITH_INFO(!prequant_scales, "NVFP4 is not supported for AWQ");
-      return &expandInputRowsKernel<InputActivationsType, ExpandedActivationsType,
+      if (num_experts_per_node == 128) {
+        constexpr int NUM_EXPERTS_PER_NODE_CONST = 128;
+        return &expandInputRowsKernel<InputActivationsType, ExpandedActivationsType,
                                     TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4,
-                                    false>;
+                                      false, NUM_EXPERTS_PER_NODE_CONST>;
+      }
+      if (num_experts_per_node == 64) {
+        constexpr int NUM_EXPERTS_PER_NODE_CONST = 64;
+        return &expandInputRowsKernel<InputActivationsType, ExpandedActivationsType,
+                                      TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4,
+                                      false, NUM_EXPERTS_PER_NODE_CONST>;
+      }
+      printf("unsupported num_experts_per_node\n");
+      exit(1);
     } else
 #endif
     {
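
The launcher now selects the kernel instantiation by enumerating the supported expert counts. Purely as an illustration (not part of this patch; `dispatchNumExpertsPerNode` is a hypothetical helper), the same runtime-to-compile-time mapping could be factored into one place instead of repeating the if-chain at every selection site:

```cuda
#include <cstdio>
#include <cstdlib>
#include <type_traits>
#include <utility>

// Hypothetical helper: turn the runtime expert count into a compile-time constant
// and hand it to a callable as std::integral_constant.
template <class Fn>
auto dispatchNumExpertsPerNode(int num_experts_per_node, Fn&& fn) {
  switch (num_experts_per_node) {
    case 128:
      return std::forward<Fn>(fn)(std::integral_constant<int, 128>{});
    case 64:
      return std::forward<Fn>(fn)(std::integral_constant<int, 64>{});
    default:
      std::printf("unsupported num_experts_per_node\n");
      std::exit(1);
  }
}

// Possible use at the selection site above (sketch):
// return dispatchNumExpertsPerNode(num_experts_per_node, [&](auto n) {
//   return &expandInputRowsKernel<InputActivationsType, ExpandedActivationsType,
//                                 TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4,
//                                 false, decltype(n)::value>;
// });
```

Both supported values (64 and 128, the two this diff adds) instantiate kernels with identical signatures, so the deduced return type is the same function-pointer type in every case.
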
@@ -2159,7 +2169,7 @@ __global__ void doGatedActivationKernel(ActivationOutputType* output,
   float gate_bias = 0.0f;
   float gate_limit = std::numeric_limits<float>::infinity();
   if (activation_type.swiglu_alpha || activation_type.swiglu_beta || activation_type.swiglu_limit) {
-    int expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node,
+    int expert = findTotalEltsLessThanTarget<128>(expert_first_token_offset, num_experts_per_node,
                                              (int64_t)token + 1) -
                  1;
     gate_alpha = activation_type.swiglu_alpha ? activation_type.swiglu_alpha[expert] : 1.0f;
@@ -2207,7 +2217,7 @@ void doGatedActivation(ActivationOutputType* output, GemmOutputType const* gemm_
 // ============================== Activation =================================
 
 template <class T, class GemmOutputType, class ScaleBiasType, class ActFn,
-          TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType>
+          TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType, int NUM_EXPERTS_PER_NODE_CONST = 128>
 __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
                                    float const* fp8_quant, ScaleBiasType const* bias_ptr,
                                    bool bias_is_broadcast, int64_t const* expert_first_token_offset,
@@ -2270,7 +2280,7 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
         activation_params.swiglu_limit) {
       // TODO this is almost certainly faster as a linear scan
       expert =
-          findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, token + 1) -
+          findTotalEltsLessThanTarget<NUM_EXPERTS_PER_NODE_CONST>(expert_first_token_offset, num_experts_per_node, token + 1) -
           1;
       gate_alpha = activation_params.swiglu_alpha ? activation_params.swiglu_alpha[expert] : 1.0f;
       gate_beta = activation_params.swiglu_beta ? activation_params.swiglu_beta[expert] : 0.0f;
@@ -2444,30 +2454,62 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8
 
   auto fn = [&]() {
     auto fn = [&](auto block_scaling_type) {
-      auto fn_list = std::array{
-          &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                              IdentityAdaptor<cutlass::epilogue::thread::GELU>,
-                              decltype(block_scaling_type)::value>,  // Gelu
-          &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                              IdentityAdaptor<cutlass::epilogue::thread::ReLu>,
-                              decltype(block_scaling_type)::value>,  // Relu
-          &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                              IdentityAdaptor<cutlass::epilogue::thread::SiLu>,
-                              decltype(block_scaling_type)::value>,  // Silu
-          &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                              GLUAdaptor<cutlass::epilogue::thread::SiLu>,
-                              decltype(block_scaling_type)::value>,  // Swiglu
-          &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                              GLUAdaptor<cutlass::epilogue::thread::GELU>,
-                              decltype(block_scaling_type)::value>,  // Geglu
-          &doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
-                              decltype(block_scaling_type)::value>,  // SwigluBias
-          &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                              IdentityAdaptor<cutlass::epilogue::thread::Identity>,
-                              decltype(block_scaling_type)::value>  // Identity
-
-      };
-      return fn_list[static_cast<int>(activation_type.activation_type)];
+      if (num_experts_per_node == 128) {
+        constexpr int NUM_EXPERTS_PER_NODE_CONST = 128;
+        auto fn_list = std::array{
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::GELU>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Gelu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::ReLu>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Relu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::SiLu>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Silu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                GLUAdaptor<cutlass::epilogue::thread::SiLu>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Swiglu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                GLUAdaptor<cutlass::epilogue::thread::GELU>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Geglu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // SwigluBias
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::Identity>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>  // Identity
+
+        };
+        return fn_list[static_cast<int>(activation_type.activation_type)];
+      }
+      if (num_experts_per_node == 64) {
+        constexpr int NUM_EXPERTS_PER_NODE_CONST = 64;
+        auto fn_list = std::array{
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::GELU>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Gelu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::ReLu>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Relu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::SiLu>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Silu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                GLUAdaptor<cutlass::epilogue::thread::SiLu>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Swiglu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                GLUAdaptor<cutlass::epilogue::thread::GELU>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Geglu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // SwigluBias
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::Identity>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>  // Identity
+
+        };
+        return fn_list[static_cast<int>(activation_type.activation_type)];
+      }
+      printf("unsupported num_experts_per_node\n");
+      exit(1);
     };
     auto NVFP4 = tensorrt_llm::common::ConstExprWrapper<
         TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType,