
Commit 2674b39

jwfromm authored and meta-codesync[bot] committed
Modernize FP8 Blockwise GEMM (#5002)
Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/2015
Pull Request resolved: #5002

FBGEMM (and MSLK) previously relied on a customized implementation of blockwise scaling that has bitrotted over the past few releases of CUTLASS. Now that CUTLASS fully supports blockwise scaling natively, it is time to switch over to it. This diff updates our blockwise FP8 GEMM to use the standard CUTLASS implementation. This improves performance a bit and resolves some correctness issues.

Reviewed By: jerryzh168

Differential Revision: D84632911

fbshipit-source-id: f68a04041b6a31ac84e4ac6ba72da3e8c203ebd6
1 parent 0d49628 commit 2674b39

File tree

2 files changed: +42 −1639 lines changed

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_blockwise.cu

Lines changed: 42 additions & 35 deletions
@@ -19,7 +19,6 @@
 #include <cutlass/epilogue/collective/collective_builder.hpp> // @manual
 // clang-format on

-#include "cutlass_extensions/include/fp8_blockwise_cutlass_helpers.h"
 #include "cutlass_extensions/include/kernel_mode.h"

 namespace {
@@ -46,10 +45,7 @@ at::Tensor f8f8bf16_blockwise_impl(
     at::Tensor XQ, // FP8
     at::Tensor WQ, // FP8
     at::Tensor x_scale,
-    at::Tensor w_scale,
-    int64_t block_m,
-    int64_t block_n,
-    int64_t block_k) {
+    at::Tensor w_scale) {
   // XQ: M x K
   // WQ: N x K
   // output: M x N
@@ -71,19 +67,15 @@ at::Tensor f8f8bf16_blockwise_impl(
   TORCH_CHECK(WQ.stride(0) == K);
   TORCH_CHECK(WQ.stride(1) == 1);

-  TORCH_CHECK(block_m % TB_N == 0);
-  TORCH_CHECK(block_n % TB_M == 0);
-  TORCH_CHECK(block_k % TB_K == 0);
-
   TORCH_CHECK(x_scale.dim() == 2);
   TORCH_CHECK(w_scale.dim() == 2);
-  TORCH_CHECK(x_scale.size(0) == ceil_div(M, block_m));
-  TORCH_CHECK(x_scale.size(1) == ceil_div(K, block_k));
-  TORCH_CHECK(w_scale.size(0) == ceil_div(N, block_n));
-  TORCH_CHECK(w_scale.size(1) == ceil_div(K, block_k));
-  TORCH_CHECK(x_scale.stride(0) == ceil_div(K, block_k));
+  TORCH_CHECK(x_scale.size(0) == ceil_div(M, TB_M));
+  TORCH_CHECK(x_scale.size(1) == ceil_div(K, TB_K));
+  TORCH_CHECK(w_scale.size(0) == ceil_div(N, TB_N));
+  TORCH_CHECK(w_scale.size(1) == ceil_div(K, TB_K));
+  TORCH_CHECK(x_scale.stride(0) == ceil_div(K, TB_K));
   TORCH_CHECK(x_scale.stride(1) == 1);
-  TORCH_CHECK(w_scale.stride(0) == ceil_div(K, block_k));
+  TORCH_CHECK(w_scale.stride(0) == ceil_div(K, TB_K));
   TORCH_CHECK(w_scale.stride(1) == 1);

   TORCH_CHECK(XQ.dtype() == at::kFloat8_e4m3fn);
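With the block_* arguments gone, the scale granularity is pinned to the tile constants. A small worked example of what these checks now require (the problem sizes are hypothetical, chosen only to make the arithmetic concrete, and ceil_div is assumed to be the usual rounding-up division used above):

    #include <cstdint>

    // Rounding-up division, as the checks above rely on.
    constexpr int64_t ceil_div(int64_t a, int64_t b) {
      return (a + b - 1) / b;
    }

    int main() {
      // Hypothetical problem: Y[M, N] = XQ[M, K] * WQ[N, K]^T.
      constexpr int64_t M = 2048, N = 7168, K = 4096;
      constexpr int64_t TB_M = 128, TB_N = 128, TB_K = 128;

      // x_scale must be [ceil_div(M, TB_M), ceil_div(K, TB_K)] = [16, 32];
      // w_scale must be [ceil_div(N, TB_N), ceil_div(K, TB_K)] = [56, 32].
      static_assert(ceil_div(M, TB_M) == 16 && ceil_div(K, TB_K) == 32, "");
      static_assert(ceil_div(N, TB_N) == 56, "");
      return 0;
    }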
@@ -109,7 +101,7 @@ at::Tensor f8f8bf16_blockwise_impl(
   constexpr int AlignmentInputB = 16 / sizeof(ElementInputB);

   using ElementOutput = cutlass::bfloat16_t;
-  using LayoutOutput = cutlass::layout::ColumnMajor;
+  using LayoutOutput = cutlass::layout::RowMajor;
   constexpr int AlignmentOutput = 16 / sizeof(ElementOutput);

   using ElementAccumulator = float;
@@ -129,6 +121,15 @@ at::Tensor f8f8bf16_blockwise_impl(
                      // threadblocks in a
                      // cluster

+  using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<
+      TB_M,
+      TB_N,
+      TB_K,
+      cute::GMMA::Major::K,
+      cute::GMMA::Major::K>;
+  using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
+  using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
+
   using CollectiveEpilogue =
       typename cutlass::epilogue::collective::CollectiveBuilder<
           ArchTag,
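The three integer parameters of Sm90BlockwiseScaleConfig are the scale granularities along M, N, and K, and the two GMMA majors give the memory order of the scale-factor tensors; with 128/128/128 here, one fp32 scale covers a 128x128 block of each operand. As I read the CUTLASS config (an assumption, not something this diff exercises), the same template also expresses finer recipes, e.g. per-row activation scaling:

    // Hypothetical alternative config, for illustration only: one scale per
    // row of A (granularity 1 along M) combined with 128x128 weight blocks.
    using GroupwiseScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<
        1,    // scale granularity along M: each row of A gets its own scale
        128,  // scale granularity along N
        128,  // scale granularity along K
        cute::GMMA::Major::K,
        cute::GMMA::Major::K>;
    using GroupwiseLayoutSFA =
        decltype(GroupwiseScaleConfig::deduce_layoutSFA());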
@@ -147,17 +148,17 @@ at::Tensor f8f8bf16_blockwise_impl(
           cutlass::epilogue::TmaWarpSpecializedCooperative>::CollectiveOp;

   using MainLoopSchedule =
-      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaling;
+      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8Blockwise;

   using CollectiveMainloop =
       typename cutlass::gemm::collective::CollectiveBuilder<
           ArchTag,
           OperatorClass,
           ElementInputA,
-          LayoutInputA,
+          cute::tuple<LayoutInputA, LayoutSFA>,
           AlignmentInputA,
           ElementInputB,
-          LayoutInputB,
+          cute::tuple<LayoutInputB, LayoutSFB>,
           AlignmentInputB,
           ElementAccumulator,
           TileShape,
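Worth noting (my reading of the CUTLASS 3.x builder convention, not something spelled out in this diff): the operand layout slot doubles as a feature selector, so wrapping the data layout and the scale-factor layout in a cute::tuple is what steers the CollectiveBuilder to the blockwise-scaled mainloop:

    // Plain layout          -> standard FP8 mainloop.
    // (layout, SF layout)   -> blockwise-scaled mainloop.
    using PlainSlotA     = LayoutInputA;
    using BlockwiseSlotA = cute::tuple<LayoutInputA, LayoutSFA>;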
@@ -178,24 +179,27 @@ at::Tensor f8f8bf16_blockwise_impl(
   using StrideOutput = typename Gemm::GemmKernel::StrideD;

   StrideInputA stride_a = cutlass::make_cute_packed_stride(
-      StrideInputA{}, cute::make_shape(N, K, 1));
+      StrideInputA{}, cute::make_shape(M, K, 1));
   StrideInputB stride_b = cutlass::make_cute_packed_stride(
-      StrideInputB{}, cute::make_shape(M, K, 1));
+      StrideInputB{}, cute::make_shape(N, K, 1));
   StrideOutput stride_output = cutlass::make_cute_packed_stride(
-      StrideOutput{}, cute::make_shape(N, M, 1));
+      StrideOutput{}, cute::make_shape(M, N, 1));
+  LayoutSFA layout_SFA =
+      ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
+  LayoutSFB layout_SFB =
+      ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));

   typename Gemm::Arguments arguments{
       cutlass::gemm::GemmUniversalMode::kGemm,
-      {N, M, K},
-      {reinterpret_cast<cutlass::float_e4m3_t*>(WQ.data_ptr()),
+      {M, N, K},
+      {reinterpret_cast<cutlass::float_e4m3_t*>(XQ.data_ptr()),
        stride_a,
-       reinterpret_cast<cutlass::float_e4m3_t*>(XQ.data_ptr()),
+       reinterpret_cast<cutlass::float_e4m3_t*>(WQ.data_ptr()),
        stride_b,
-       w_scale.data_ptr<float>(),
        x_scale.data_ptr<float>(),
-       static_cast<uint8_t>(block_n / TB_M),
-       static_cast<uint8_t>(block_m / TB_N),
-       static_cast<uint8_t>(block_k / TB_K)},
+       layout_SFA,
+       w_scale.data_ptr<float>(),
+       layout_SFB},
       {{},
        (cutlass::bfloat16_t*)Y.data_ptr<at::BFloat16>(),
        stride_output,
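A note on the argument block, since this hunk is dense: the old code launched the transposed problem, passing {N, M, K} with WQ as operand A so that a column-major D held the M x N result; the native blockwise kernel takes the canonical {M, N, K} ordering with XQ as A and a row-major D (hence the LayoutOutput change above). Likewise, the raw uint8_t block-to-tile ratios are replaced by scale pointers paired with their deduced layouts: for the hypothetical sizes used earlier (M = 2048, N = 7168, K = 4096), layout_SFA addresses the 16 x 32 grid of x_scale values and layout_SFB the 56 x 32 grid of w_scale values validated at the top of the function.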
@@ -244,16 +248,19 @@ at::Tensor dispatch_fp8_blockwise_kernel(
     int64_t block_m,
     int64_t block_n,
     int64_t block_k) {
+  TORCH_CHECK(
+      block_m == 128 && block_n == 128 && block_k == 128,
+      "Only 128x128x128 block size is supported");
   KernelMode kernel = get_kernel_mode(XQ, WQ);
   if (kernel == KernelMode::Small) {
     return f8f8bf16_blockwise_impl<128, 128, 128, 2, 1, 1>(
-        XQ, WQ, x_scale, w_scale, block_m, block_n, block_k);
+        XQ, WQ, x_scale, w_scale);
   } else if (kernel == KernelMode::Large) {
     return f8f8bf16_blockwise_impl<128, 128, 128, 2, 1, 1>(
-        XQ, WQ, x_scale, w_scale, block_m, block_n, block_k);
+        XQ, WQ, x_scale, w_scale);
   } else {
     return f8f8bf16_blockwise_impl<128, 128, 128, 1, 2, 1>(
-        XQ, WQ, x_scale, w_scale, block_m, block_n, block_k);
+        XQ, WQ, x_scale, w_scale);
   }
 }


@@ -281,9 +288,9 @@ at::Tensor f8f8bf16_blockwise(
     at::Tensor WQ, // FP8
     at::Tensor x_scale,
    at::Tensor w_scale,
-    int64_t block_m = 256,
-    int64_t block_n = 256,
-    int64_t block_k = 256) {
+    int64_t block_m = 128,
+    int64_t block_n = 128,
+    int64_t block_k = 128) {
   throw std::runtime_error(
       "CUDA version is older than 12.0"); // requires CUDA>=12
 }
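Putting it together, a host-side sketch of the new contract. This is illustrative only: the tensor sizes are hypothetical, and I am assuming the public entry point keeps the signature shown above. The block_* arguments survive for API compatibility, but anything other than 128 now trips the dispatcher's TORCH_CHECK, since the scale granularity is baked into ScaleConfig at compile time.

    #include <ATen/ATen.h>

    // Declared in this file; multiplies row-major XQ[M, K] by WQ[N, K]
    // (K-contiguous) into a BF16 Y[M, N].
    at::Tensor f8f8bf16_blockwise(
        at::Tensor XQ,
        at::Tensor WQ,
        at::Tensor x_scale,
        at::Tensor w_scale,
        int64_t block_m,
        int64_t block_n,
        int64_t block_k);

    at::Tensor run_blockwise_example() {
      // Hypothetical, 128-divisible problem size.
      const int64_t M = 2048, N = 7168, K = 4096;
      auto fp8 =
          at::TensorOptions().device(at::kCUDA).dtype(at::kFloat8_e4m3fn);
      auto f32 = at::TensorOptions().device(at::kCUDA).dtype(at::kFloat);
      at::Tensor XQ = at::empty({M, K}, fp8); // quantized activations
      at::Tensor WQ = at::empty({N, K}, fp8); // quantized weights
      // One fp32 scale per 128x128 block, contiguous, as the checks require.
      at::Tensor x_scale = at::ones({M / 128, K / 128}, f32);
      at::Tensor w_scale = at::ones({N / 128, K / 128}, f32);
      return f8f8bf16_blockwise(XQ, WQ, x_scale, w_scale, 128, 128, 128);
    }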
