diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh index 468b77d9593..6da2da63407 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh @@ -15,6 +15,7 @@ using c3x::cutlass_gemm_caller; template typename Epilogue> struct sm100_fp8_config_default { + // M in (128, inf) static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; @@ -25,6 +26,34 @@ struct sm100_fp8_config_default { KernelSchedule, EpilogueSchedule>; }; +template typename Epilogue> +struct sm100_fp8_config_M128 { + // M in (64, 128] + static_assert(std::is_same()); + using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; + using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; + using TileShape = Shape<_128, _128, _64>; + using ClusterShape = Shape<_2, _2, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm_sm100; +}; + +template typename Epilogue> +struct sm100_fp8_config_M64 { + // M in [1, 64] + static_assert(std::is_same()); + using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; + using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; + using TileShape = Shape<_64, _64, _256>; + using ClusterShape = Shape<_1, _8, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm_sm100; +}; + template typename Epilogue, typename... EpilogueArgs> @@ -39,8 +68,28 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out, using Cutlass3xGemmDefault = typename sm100_fp8_config_default::Cutlass3xGemm; - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); + using Cutlass3xGemmM64 = + typename sm100_fp8_config_M64::Cutlass3xGemm; + using Cutlass3xGemmM128 = + typename sm100_fp8_config_M128::Cutlass3xGemm; + + uint32_t const m = a.size(0); + uint32_t const mp2 = + std::max(static_cast(64), next_pow_2(m)); // next power of 2 + + if (mp2 <= 64) { + // m in [1, 64] + return cutlass_gemm_caller( + out, a, b, std::forward(args)...); + } else if (mp2 <= 128) { + // m in (64, 128] + return cutlass_gemm_caller( + out, a, b, std::forward(args)...); + } else { + // m in (128, inf) + return cutlass_gemm_caller( + out, a, b, std::forward(args)...); + } } template