@@ -1457,7 +1457,7 @@ __host__ __device__ constexpr static U arrayConvert(T const& input) {
14571457// (k-1)*rows_in_input all map to row 0 in the original matrix. Thus, to know where to read in the
14581458// source matrix, we simply take the modulus of the expanded index.
14591459
1460- constexpr static int EXPAND_THREADS_PER_BLOCK = 256 ;
1460+ constexpr static int EXPAND_THREADS_PER_BLOCK = 128 ;
14611461
14621462template <class InputActivationsType , class ExpandedActivationsType ,
14631463 TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType,
@@ -1697,7 +1697,7 @@ void expandInputRowsKernelLauncher(
16971697
16981698 static int64_t const smCount = tensorrt_llm::common::getMultiProcessorCount ();
16991699 // Note: 8 blocks of 256 threads per SM fully leveraged the memory bandwidth (tested on B200); with 128-thread blocks the cap below is 16 blocks per SM.
1700- int64_t const blocks = std::min (smCount * 8 , std::max (num_rows * k, num_padding_tokens));
1700+ int64_t const blocks = std::min (smCount * 16 , std::max (num_rows * k, num_padding_tokens));
17011701 int64_t const threads = EXPAND_THREADS_PER_BLOCK;
17021702
17031703 auto func = [&]() {
@@ -1813,6 +1813,10 @@ void finalizeMoeRoutingKernel(
18131813 ScaleBiasType const * bias, float const * scales, int const * unpermuted_row_to_permuted_row,
18141814 int const * token_selected_experts, int64_t const orig_cols, int64_t const experts_per_token_real_,
18151815 int const num_experts_per_node, int const start_expert_id) {
1816+ if constexpr (not (std::is_same_v<GemmOutputType, __nv_bfloat16> and std::is_same_v<OutputType, __nv_bfloat16>)) {
1817+ printf (" finalizeMoeRoutingKernel see unsupported dtype\n " );
1818+ asm (" trap;" );
1819+ } else {
18161820 constexpr int experts_per_token = 8 ;
18171821 if (experts_per_token != experts_per_token_real_) { asm (" trap;" ); }
18181822
@@ -1847,16 +1851,19 @@ void finalizeMoeRoutingKernel(
18471851 for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) {
18481852 ComputeElem thread_output;
18491853 thread_output.fill (0 );
1854+
1855+ #pragma unroll
18501856 for (int k_idx = 0 ; k_idx < experts_per_token; ++k_idx) {
18511857 int64_t const k_offset = original_row * experts_per_token + k_idx;
18521858 int64_t const expert_id = token_selected_experts[k_offset] - start_expert_id;
1853- if (expert_id < 0 || expert_id >= num_experts_per_node) {
1854- continue ;
1855- }
18561859
18571860 int64_t const expanded_original_row = original_row + k_idx * num_rows;
18581861 int64_t const expanded_permuted_row = unpermuted_row_to_permuted_row[expanded_original_row];
18591862
1863+ if (expert_id < 0 || expert_id >= num_experts_per_node) {
1864+ continue ;
1865+ }
1866+
18601867 int64_t expanded_rows = num_rows * experts_per_token;
18611868 if (expanded_permuted_row < 0 || expanded_permuted_row >= expanded_rows) {
18621869 continue ;
@@ -1884,6 +1891,7 @@ void finalizeMoeRoutingKernel(
18841891 asm volatile (" griddepcontrol.launch_dependents;" );
18851892#endif
18861893}
1894+ }
18871895
18881896// Final kernel to unpermute and scale
18891897// This kernel unpermutes the original data, does the k-way reduction and performs the final skip
0 commit comments