 #include <cutlass/epilogue/collective/collective_builder.hpp> // @manual
 // clang-format on

-#include "cutlass_extensions/include/fp8_blockwise_cutlass_helpers.h"
 #include "cutlass_extensions/include/kernel_mode.h"

 namespace {
@@ -46,10 +45,7 @@ at::Tensor f8f8bf16_blockwise_impl(
     at::Tensor XQ, // FP8
     at::Tensor WQ, // FP8
     at::Tensor x_scale,
-    at::Tensor w_scale,
-    int64_t block_m,
-    int64_t block_n,
-    int64_t block_k) {
+    at::Tensor w_scale) {
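+  // The explicit block_m/block_n/block_k scale-block parameters are gone:
+  // the scale granularity is now fixed by the TB_M/TB_N/TB_K tile shape.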
   // XQ: M x K
   // WQ: N x K
   // output: M x N
@@ -71,19 +67,15 @@ at::Tensor f8f8bf16_blockwise_impl(
   TORCH_CHECK(WQ.stride(0) == K);
   TORCH_CHECK(WQ.stride(1) == 1);

-  TORCH_CHECK(block_m % TB_N == 0);
-  TORCH_CHECK(block_n % TB_M == 0);
-  TORCH_CHECK(block_k % TB_K == 0);
-
   TORCH_CHECK(x_scale.dim() == 2);
   TORCH_CHECK(w_scale.dim() == 2);
-  TORCH_CHECK(x_scale.size(0) == ceil_div(M, block_m));
-  TORCH_CHECK(x_scale.size(1) == ceil_div(K, block_k));
-  TORCH_CHECK(w_scale.size(0) == ceil_div(N, block_n));
-  TORCH_CHECK(w_scale.size(1) == ceil_div(K, block_k));
-  TORCH_CHECK(x_scale.stride(0) == ceil_div(K, block_k));
+  TORCH_CHECK(x_scale.size(0) == ceil_div(M, TB_M));
+  TORCH_CHECK(x_scale.size(1) == ceil_div(K, TB_K));
+  TORCH_CHECK(w_scale.size(0) == ceil_div(N, TB_N));
+  TORCH_CHECK(w_scale.size(1) == ceil_div(K, TB_K));
+  TORCH_CHECK(x_scale.stride(0) == ceil_div(K, TB_K));
   TORCH_CHECK(x_scale.stride(1) == 1);
-  TORCH_CHECK(w_scale.stride(0) == ceil_div(K, block_k));
+  TORCH_CHECK(w_scale.stride(0) == ceil_div(K, TB_K));
   TORCH_CHECK(w_scale.stride(1) == 1);

   TORCH_CHECK(XQ.dtype() == at::kFloat8_e4m3fn);
@@ -109,7 +101,7 @@ at::Tensor f8f8bf16_blockwise_impl(
   constexpr int AlignmentInputB = 16 / sizeof(ElementInputB);

   using ElementOutput = cutlass::bfloat16_t;
-  using LayoutOutput = cutlass::layout::ColumnMajor;
+  using LayoutOutput = cutlass::layout::RowMajor;
   constexpr int AlignmentOutput = 16 / sizeof(ElementOutput);

   using ElementAccumulator = float;
@@ -129,6 +121,15 @@ at::Tensor f8f8bf16_blockwise_impl(
   // threadblocks in a
   // cluster

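+  // Sm90BlockwiseScaleConfig describes one scale value per TB_M x TB_K tile
+  // of A and per TB_N x TB_K tile of B, both K-major; the scale-factor
+  // layouts the mainloop expects are deduced from it below.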
+  using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<
+      TB_M,
+      TB_N,
+      TB_K,
+      cute::GMMA::Major::K,
+      cute::GMMA::Major::K>;
+  using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
+  using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
+
   using CollectiveEpilogue =
       typename cutlass::epilogue::collective::CollectiveBuilder<
           ArchTag,
@@ -147,17 +148,17 @@ at::Tensor f8f8bf16_blockwise_impl(
           cutlass::epilogue::TmaWarpSpecializedCooperative>::CollectiveOp;

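+  // Presumably the upstream CUTLASS blockwise FP8 schedule; it replaces the
+  // custom FP8BlockScaling schedule that came with the removed
+  // fp8_blockwise_cutlass_helpers.h extension header.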
   using MainLoopSchedule =
-      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaling;
+      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8Blockwise;

   using CollectiveMainloop =
       typename cutlass::gemm::collective::CollectiveBuilder<
           ArchTag,
           OperatorClass,
           ElementInputA,
-          LayoutInputA,
+          cute::tuple<LayoutInputA, LayoutSFA>,
           AlignmentInputA,
           ElementInputB,
-          LayoutInputB,
+          cute::tuple<LayoutInputB, LayoutSFB>,
           AlignmentInputB,
           ElementAccumulator,
           TileShape,
@@ -178,24 +179,27 @@ at::Tensor f8f8bf16_blockwise_impl(
   using StrideOutput = typename Gemm::GemmKernel::StrideD;

   StrideInputA stride_a = cutlass::make_cute_packed_stride(
-      StrideInputA{}, cute::make_shape(N, K, 1));
+      StrideInputA{}, cute::make_shape(M, K, 1));
   StrideInputB stride_b = cutlass::make_cute_packed_stride(
-      StrideInputB{}, cute::make_shape(M, K, 1));
+      StrideInputB{}, cute::make_shape(N, K, 1));
   StrideOutput stride_output = cutlass::make_cute_packed_stride(
-      StrideOutput{}, cute::make_shape(N, M, 1));
+      StrideOutput{}, cute::make_shape(M, N, 1));
+  LayoutSFA layout_SFA =
+      ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
+  LayoutSFB layout_SFB =
+      ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));
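+  // The problem is now stated in its natural (M, N, K) orientation with XQ
+  // as A and WQ as B, instead of the old transposed {N, M, K} formulation
+  // with a column-major output. tile_atom_to_shape_SFA/SFB tile the
+  // per-block scale atom over the full (M, N, K, L=1) problem shape.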

   typename Gemm::Arguments arguments{
       cutlass::gemm::GemmUniversalMode::kGemm,
-      {N, M, K},
-      {reinterpret_cast<cutlass::float_e4m3_t*>(WQ.data_ptr()),
+      {M, N, K},
+      {reinterpret_cast<cutlass::float_e4m3_t*>(XQ.data_ptr()),
        stride_a,
-       reinterpret_cast<cutlass::float_e4m3_t*>(XQ.data_ptr()),
+       reinterpret_cast<cutlass::float_e4m3_t*>(WQ.data_ptr()),
        stride_b,
-       w_scale.data_ptr<float>(),
        x_scale.data_ptr<float>(),
-       static_cast<uint8_t>(block_n / TB_M),
-       static_cast<uint8_t>(block_m / TB_N),
-       static_cast<uint8_t>(block_k / TB_K)},
+       layout_SFA,
+       w_scale.data_ptr<float>(),
+       layout_SFB},
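+  // The per-operand scale pointers now travel with their CuTe layouts
+  // (layout_SFA / layout_SFB) rather than the old uint8_t block-multiplier
+  // fields, and the operands swap back to (XQ, WQ) to match {M, N, K}.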
       {{},
        (cutlass::bfloat16_t*)Y.data_ptr<at::BFloat16>(),
        stride_output,
@@ -244,16 +248,19 @@ at::Tensor dispatch_fp8_blockwise_kernel(
     int64_t block_m,
     int64_t block_n,
     int64_t block_k) {
+  TORCH_CHECK(
+      block_m == 128 && block_n == 128 && block_k == 128,
+      "Only 128x128x128 block size is supported");
   KernelMode kernel = get_kernel_mode(XQ, WQ);
   if (kernel == KernelMode::Small) {
     return f8f8bf16_blockwise_impl<128, 128, 128, 2, 1, 1>(
-        XQ, WQ, x_scale, w_scale, block_m, block_n, block_k);
+        XQ, WQ, x_scale, w_scale);
   } else if (kernel == KernelMode::Large) {
     return f8f8bf16_blockwise_impl<128, 128, 128, 2, 1, 1>(
-        XQ, WQ, x_scale, w_scale, block_m, block_n, block_k);
+        XQ, WQ, x_scale, w_scale);
   } else {
     return f8f8bf16_blockwise_impl<128, 128, 128, 1, 2, 1>(
-        XQ, WQ, x_scale, w_scale, block_m, block_n, block_k);
+        XQ, WQ, x_scale, w_scale);
   }
 }

@@ -281,9 +288,9 @@ at::Tensor f8f8bf16_blockwise(
     at::Tensor WQ, // FP8
     at::Tensor x_scale,
     at::Tensor w_scale,
-    int64_t block_m = 256,
-    int64_t block_n = 256,
-    int64_t block_k = 256) {
+    int64_t block_m = 128,
+    int64_t block_n = 128,
+    int64_t block_k = 128) {
   throw std::runtime_error(
       "CUDA version is older than 12.0"); // requires CUDA>=12
 }