 #include <numeric>
 #include <random>
 #include <sstream>
+#include <type_traits>
 
 #include "tensorrt_llm/common/memoryUtils.h"
 #include "tensorrt_llm/common/workspace.h"
@@ -865,7 +866,7 @@ void threeStepBuildExpertMapsSortFirstToken(
 // ============================== Infer GEMM sizes =================================
 // TODO Could linear search be better for small # experts
 template <class T>
-__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices,
+__device__ inline int64_t findTotalEltsLessThanTarget_v1(T const* sorted_indices,
                                                       int64_t const arr_length, T const target) {
   int64_t low = 0, high = arr_length - 1, target_location = -1;
   while (low <= high) {
@@ -881,6 +882,48 @@ __device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices,
   return target_location + 1;
 }
 
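+// Warp-parallel replacement for the binary search above. In a sorted array the
+// number of elements strictly less than `target` equals the insertion index the
+// binary search computes, so each lane counts the elements of its strided slice
+// that compare below target, a shuffle tree-reduction sums the 32 lane counts
+// into lane 0, and the result is broadcast back to every lane. Assumes
+// ARR_LENGTH_CONST is a multiple of 32 and that all lanes of a warp reach this
+// call together with identical arguments; traps if the runtime length differs.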
+template <int ARR_LENGTH_CONST, class T>
+__device__ inline int64_t findTotalEltsLessThanTarget_v2(T const* sorted_indices, int64_t const arr_length, T const target) {
+  if (arr_length != ARR_LENGTH_CONST) {
+    asm("trap;");
+  }
+
+  constexpr unsigned full_mask = 0xffffffffu;
+  constexpr int WARP_SZ = 32;
+  const int lane_id = threadIdx.x & (WARP_SZ - 1);
+
+  int local_count = 0;
+#pragma unroll
+  for (int k = 0; k < ARR_LENGTH_CONST / WARP_SZ; ++k) {
+    const int idx = lane_id + k * WARP_SZ;
+    T v = sorted_indices[idx];
+    local_count += (v < target) ? 1 : 0;
+  }
+
+#pragma unroll
+  for (int offset = 16; offset > 0; offset >>= 1) {
+    local_count += __shfl_down_sync(full_mask, local_count, offset);
+  }
+  int total = __shfl_sync(full_mask, local_count, 0);
+
+  return (int64_t)total;
+}
+
+template <int ARR_LENGTH_CONST, class T>
+__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target) {
+  // return findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);
+
+  return findTotalEltsLessThanTarget_v2<ARR_LENGTH_CONST>(sorted_indices, arr_length, target);
+
+  // int64_t out_v1 = findTotalEltsLessThanTarget_v1(sorted_indices, arr_length, target);
+  // int64_t out_v2 = findTotalEltsLessThanTarget_v2<ARR_LENGTH_CONST>(sorted_indices, arr_length, target);
+  // if (out_v1 != out_v2) {
+  //   printf("different output! v1=%lld v2=%lld\n", (long long)out_v1, (long long)out_v2);
+  //   asm("trap;");
+  // }
+  // return out_v1;
+}
+
 template <class T>
 using sizeof_bits = cutlass::sizeof_bits<
     typename cutlass_kernels::TllmToCutlassTypeAdapter<std::remove_cv_t<T>>::type>;
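For intuition, here is a minimal host-side sketch (illustrative only, not part of the commit) of the invariant the `_v2` rewrite relies on: in a sorted, non-decreasing array, the count of elements strictly less than `target` equals the insertion index that the `_v1` binary search (equivalently `std::lower_bound`) computes.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Shaped like expert_first_token_offset: a sorted prefix-sum array.
  std::vector<int64_t> offsets = {0, 3, 3, 7, 12};
  for (int64_t target = 0; target <= 13; ++target) {
    int64_t count = 0;  // what _v2 computes, lane-parallel on the GPU
    for (int64_t v : offsets) count += (v < target) ? 1 : 0;
    // what _v1 computes via binary search
    int64_t idx = std::lower_bound(offsets.begin(), offsets.end(), target) - offsets.begin();
    assert(count == idx);
  }
  return 0;
}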
@@ -1418,16 +1461,19 @@ constexpr static int EXPAND_THREADS_PER_BLOCK = 256;
 
 template <class InputActivationsType, class ExpandedActivationsType,
           TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType,
-          bool PRE_QUANT_AWQ>
+          bool PRE_QUANT_AWQ, int NUM_EXPERTS_PER_NODE_CONST = 128>
 __global__ void expandInputRowsKernel(
     InputActivationsType const* unpermuted_input, ExpandedActivationsType* permuted_output,
     float const* unpermuted_scales, float* permuted_scales,
-    int const* permuted_row_to_unpermuted_row, int64_t const num_tokens, int64_t const hidden_size,
+    int const* permuted_row_to_unpermuted_row, int64_t const num_tokens, int64_t const hidden_size_real_,
     int64_t const k, float const* fc1_act_global_scale, bool use_per_expert_act_scale,
     int64_t const* expert_first_token_offset,
     TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat,
     TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf,
     int64_t const num_experts_per_node, InputActivationsType const* prequant_scales = nullptr) {
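+  // Hard-code the hidden size for the target model shape so the per-row loop
+  // bounds and strides below become compile-time constants; trap if any other
+  // shape reaches this specialized kernel.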
+  constexpr int hidden_size = 7168;
+  if (hidden_size != hidden_size_real_) { asm("trap;"); }
+
   static_assert(BlockScalingType == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE ||
                     !PRE_QUANT_AWQ,
                 "AWQ and Block Scaling are mutually exclusive");
@@ -1503,14 +1549,14 @@ __global__ void expandInputRowsKernel(
                          permuted_row * hidden_size / ELEM_PER_THREAD;
 
     int64_t const start_offset = threadIdx.x;
-    int64_t const stride = EXPAND_THREADS_PER_BLOCK;
-    int64_t const num_elems_in_col = hidden_size / ELEM_PER_THREAD;
+    constexpr int64_t stride = EXPAND_THREADS_PER_BLOCK;
+    constexpr int64_t num_elems_in_col = hidden_size / ELEM_PER_THREAD;
     assert(hidden_size % ELEM_PER_THREAD == 0);
     assert(hidden_size % VecSize == 0);
 
     if constexpr (is_nvfp4 || is_mxfp8) {
       static_assert(ELEM_PER_THREAD == 8, "Expecting 8 elements per thread for quantized types");
-      int64_t expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node,
+      int64_t expert = findTotalEltsLessThanTarget<NUM_EXPERTS_PER_NODE_CONST>(expert_first_token_offset, num_experts_per_node,
                                                    (int64_t)permuted_row + 1) -
                        1;
@@ -1519,6 +1565,7 @@ __global__ void expandInputRowsKernel(
       float global_scale_val = fc1_act_global_scale ? fc1_act_global_scale[act_scale_idx] : 1.0f;
       int64_t num_tokens_before_expert = expert_first_token_offset[expert];
 
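+      // num_elems_in_col and stride are compile-time constants now, so the
+      // copy/quantize loop below can be fully unrolled.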
+#pragma unroll
       for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) {
         auto in_vec = source_row_ptr[elem_index];
         if constexpr (need_nvfp4_quant || need_mxfp8_quant) {
@@ -1687,9 +1734,20 @@ void expandInputRowsKernelLauncher(
       TLLM_CHECK_WITH_INFO(quant_params.fp4.fc1.weight_block_scale,
                            "NVFP4 block scaling is expected for FP4xFP4");
       TLLM_CHECK_WITH_INFO(!prequant_scales, "NVFP4 is not supported for AWQ");
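+      // Host-side dispatch: instantiate the kernel per supported expert count
+      // so the warp-count search sees a compile-time array length.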
-      return &expandInputRowsKernel<InputActivationsType, ExpandedActivationsType,
+      if (num_experts_per_node == 128) {
+        constexpr int NUM_EXPERTS_PER_NODE_CONST = 128;
+        return &expandInputRowsKernel<InputActivationsType, ExpandedActivationsType,
                                     TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4,
-                                     false>;
+                                     false, NUM_EXPERTS_PER_NODE_CONST>;
+      }
+      if (num_experts_per_node == 64) {
+        constexpr int NUM_EXPERTS_PER_NODE_CONST = 64;
+        return &expandInputRowsKernel<InputActivationsType, ExpandedActivationsType,
+                                      TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4,
+                                      false, NUM_EXPERTS_PER_NODE_CONST>;
+      }
+      printf("unsupported num_experts_per_node\n");
+      exit(1);
     } else
 #endif
     {
@@ -1748,11 +1806,16 @@ constexpr static int FINALIZE_THREADS_PER_BLOCK = 256;
 // This kernel unpermutes the original data, does the k-way reduction and performs the final skip
 // connection.
 template <typename OutputType, class GemmOutputType, class ScaleBiasType, ScaleMode SCALE_MODE>
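+// __maxnreg__ caps this kernel at 64 registers per thread; the reduction is
+// bandwidth-bound, so the cap trades potential spills for higher occupancy.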
-__global__ void finalizeMoeRoutingKernel(
+__global__
+__maxnreg__(64)
+void finalizeMoeRoutingKernel(
     GemmOutputType const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
     ScaleBiasType const* bias, float const* scales, int const* unpermuted_row_to_permuted_row,
-    int const* token_selected_experts, int64_t const orig_cols, int64_t const experts_per_token,
+    int const* token_selected_experts, int64_t const orig_cols, int64_t const experts_per_token_real_,
     int const num_experts_per_node, int const start_expert_id) {
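+  // Hard-code top-k = 8 so the k-way reduction has a compile-time trip count;
+  // trap if a different experts_per_token reaches this specialized kernel.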
+  constexpr int experts_per_token = 8;
+  if (experts_per_token != experts_per_token_real_) { asm("trap;"); }
+
   int64_t const original_row = blockIdx.x;
   int64_t const num_rows = gridDim.x;
   auto const offset = original_row * orig_cols;
@@ -2078,7 +2141,7 @@ __global__ void doGatedActivationKernel(ActivationOutputType* output,
   float gate_bias = 0.0f;
   float gate_limit = std::numeric_limits<float>::infinity();
   if (activation_type.swiglu_alpha || activation_type.swiglu_beta || activation_type.swiglu_limit) {
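+    // NOTE: hard-coded to 128 experts; this path has no 64-expert dispatch, so
+    // the warp-count search traps at runtime if num_experts_per_node differs.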
-    int expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node,
+    int expert = findTotalEltsLessThanTarget<128>(expert_first_token_offset, num_experts_per_node,
                                              (int64_t)token + 1) -
                  1;
     gate_alpha = activation_type.swiglu_alpha ? activation_type.swiglu_alpha[expert] : 1.0f;
@@ -2126,14 +2189,17 @@ void doGatedActivation(ActivationOutputType* output, GemmOutputType const* gemm_
 // ============================== Activation =================================
 
 template <class T, class GemmOutputType, class ScaleBiasType, class ActFn,
-          TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType>
+          TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType, int NUM_EXPERTS_PER_NODE_CONST = 128>
 __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
                                    float const* fp8_quant, ScaleBiasType const* bias_ptr,
                                    bool bias_is_broadcast, int64_t const* expert_first_token_offset,
-                                   int num_experts_per_node, int64_t inter_size,
+                                   int num_experts_per_node, int64_t inter_size_real_,
                                    float const* fc2_act_global_scale, bool use_per_expert_act_scale,
                                    TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat,
                                    ActivationParams activation_params) {
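+  // Hard-code the FFN intermediate size so the per-thread loop bounds below
+  // are compile-time constants; trap if any other shape reaches this kernel.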
+  constexpr int inter_size = 2048;
+  if (inter_size != inter_size_real_) { asm("trap;"); }
+
 #ifdef ENABLE_FP4
   constexpr bool IsNVFP4 =
       std::is_same_v<T, __nv_fp4_e2m1> &&
@@ -2186,7 +2252,7 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
         activation_params.swiglu_limit) {
       // TODO this is almost certainly faster as a linear scan
       expert =
-          findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, token + 1) -
+          findTotalEltsLessThanTarget<NUM_EXPERTS_PER_NODE_CONST>(expert_first_token_offset, num_experts_per_node, token + 1) -
           1;
       gate_alpha = activation_params.swiglu_alpha ? activation_params.swiglu_alpha[expert] : 1.0f;
       gate_beta = activation_params.swiglu_beta ? activation_params.swiglu_beta[expert] : 0.0f;
@@ -2218,16 +2284,18 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
     auto output_vec = reinterpret_cast<OutputElem*>(safe_inc_ptr(output, output_offset));
     auto bias_ptr_vec = reinterpret_cast<BiasElem const*>(bias_ptr + bias_offset);
     int64_t const start_offset = tid;
-    int64_t const stride = ACTIVATION_THREADS_PER_BLOCK;
+    constexpr int64_t stride = ACTIVATION_THREADS_PER_BLOCK;
     assert(inter_size % ACTIVATION_ELEM_PER_THREAD == 0);
-    int64_t const num_elems_in_col = inter_size / ACTIVATION_ELEM_PER_THREAD;
+    constexpr int64_t num_elems_in_col = inter_size / ACTIVATION_ELEM_PER_THREAD;
     assert(gated_off % ACTIVATION_ELEM_PER_THREAD == 0);
     int64_t const gated_off_vec = gated_off / ACTIVATION_ELEM_PER_THREAD;
 
     ActFn fn{};
     fn.alpha = gate_alpha;
     fn.beta = gate_beta;
     fn.limit = gate_limit;
+
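+    // With constexpr bounds and stride, the activation loop can unroll fully.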
+#pragma unroll
     for (int64_t elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) {
       auto fc1_value =
           arrayConvert<GemmResultElem, ComputeElem>(gemm_result_vec[elem_index + gated_off_vec]);
@@ -2358,30 +2426,62 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8
 
   auto fn = [&]() {
     auto fn = [&](auto block_scaling_type) {
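+      // Same expert-count dispatch as in expandInputRowsKernelLauncher: one
+      // kernel table per supported NUM_EXPERTS_PER_NODE_CONST value.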
-      auto fn_list = std::array{
-          &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                              IdentityAdaptor<cutlass::epilogue::thread::GELU>,
-                              decltype(block_scaling_type)::value>,  // Gelu
-          &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                              IdentityAdaptor<cutlass::epilogue::thread::ReLu>,
-                              decltype(block_scaling_type)::value>,  // Relu
-          &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                              IdentityAdaptor<cutlass::epilogue::thread::SiLu>,
-                              decltype(block_scaling_type)::value>,  // Silu
-          &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                              GLUAdaptor<cutlass::epilogue::thread::SiLu>,
-                              decltype(block_scaling_type)::value>,  // Swiglu
-          &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                              GLUAdaptor<cutlass::epilogue::thread::GELU>,
-                              decltype(block_scaling_type)::value>,  // Geglu
-          &doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
-                              decltype(block_scaling_type)::value>,  // SwigluBias
-          &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                              IdentityAdaptor<cutlass::epilogue::thread::Identity>,
-                              decltype(block_scaling_type)::value>  // Identity
-
-      };
-      return fn_list[static_cast<int>(activation_type.activation_type)];
+      if (num_experts_per_node == 128) {
+        constexpr int NUM_EXPERTS_PER_NODE_CONST = 128;
+        auto fn_list = std::array{
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::GELU>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Gelu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::ReLu>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Relu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::SiLu>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Silu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                GLUAdaptor<cutlass::epilogue::thread::SiLu>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Swiglu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                GLUAdaptor<cutlass::epilogue::thread::GELU>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Geglu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // SwigluBias
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::Identity>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>  // Identity
+        };
+        return fn_list[static_cast<int>(activation_type.activation_type)];
+      }
+      if (num_experts_per_node == 64) {
+        constexpr int NUM_EXPERTS_PER_NODE_CONST = 64;
+        auto fn_list = std::array{
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::GELU>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Gelu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::ReLu>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Relu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::SiLu>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Silu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                GLUAdaptor<cutlass::epilogue::thread::SiLu>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Swiglu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                GLUAdaptor<cutlass::epilogue::thread::GELU>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // Geglu
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>,  // SwigluBias
+            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                                IdentityAdaptor<cutlass::epilogue::thread::Identity>,
+                                decltype(block_scaling_type)::value, NUM_EXPERTS_PER_NODE_CONST>  // Identity
+        };
+        return fn_list[static_cast<int>(activation_type.activation_type)];
+      }
+      printf("unsupported num_experts_per_node\n");
+      exit(1);
     };
     auto NVFP4 = tensorrt_llm::common::ConstExprWrapper<
         TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType,
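An aside on the duplication above (a sketch under assumptions, not part of the commit): the 128- and 64-expert tables differ only in the trailing template argument, so a helper templated on the expert count could build the table once. `makeActivationKernelTable` is a hypothetical name; it reuses the adaptor types already defined in this file and takes the block-scaling enum value directly instead of `decltype(block_scaling_type)::value`.

template <class T, class GemmOutputType, class ScaleBiasType,
          TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType,
          int NUM_EXPERTS_PER_NODE_CONST>
auto makeActivationKernelTable() {
  // Index order must match ActivationType: Gelu, Relu, Silu, Swiglu, Geglu,
  // SwigluBias, Identity.
  return std::array{
      &doActivationKernel<T, GemmOutputType, ScaleBiasType,
                          IdentityAdaptor<cutlass::epilogue::thread::GELU>,
                          BlockScalingType, NUM_EXPERTS_PER_NODE_CONST>,
      &doActivationKernel<T, GemmOutputType, ScaleBiasType,
                          IdentityAdaptor<cutlass::epilogue::thread::ReLu>,
                          BlockScalingType, NUM_EXPERTS_PER_NODE_CONST>,
      &doActivationKernel<T, GemmOutputType, ScaleBiasType,
                          IdentityAdaptor<cutlass::epilogue::thread::SiLu>,
                          BlockScalingType, NUM_EXPERTS_PER_NODE_CONST>,
      &doActivationKernel<T, GemmOutputType, ScaleBiasType,
                          GLUAdaptor<cutlass::epilogue::thread::SiLu>,
                          BlockScalingType, NUM_EXPERTS_PER_NODE_CONST>,
      &doActivationKernel<T, GemmOutputType, ScaleBiasType,
                          GLUAdaptor<cutlass::epilogue::thread::GELU>,
                          BlockScalingType, NUM_EXPERTS_PER_NODE_CONST>,
      &doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
                          BlockScalingType, NUM_EXPERTS_PER_NODE_CONST>,
      &doActivationKernel<T, GemmOutputType, ScaleBiasType,
                          IdentityAdaptor<cutlass::epilogue::thread::Identity>,
                          BlockScalingType, NUM_EXPERTS_PER_NODE_CONST>};
}

The dispatch in both launchers would then reduce to two one-line branches selecting `makeActivationKernelTable<..., 128>()` or `makeActivationKernelTable<..., 64>()` and indexing with `static_cast<int>(activation_type.activation_type)`.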