
Commit cd605e8

jiawenliu64 authored and facebook-github-bot committed
Add MXFP4 PT reference quantization kernel and refactor CUTLASS FP4 GEMM (#4117)
Summary:
Pull Request resolved: #4117
X-link: facebookresearch/FBGEMM#1199

Refactor the FP4 CUTLASS GEMM to be general across MXFP4 and NVFP4 and easier to extend in the future. Also add an MXFP4 PyTorch reference quantization kernel for MXFP4 GEMM numeric verification.

Reviewed By: q10

Differential Revision: D74270499

fbshipit-source-id: 621a8e97a45b1bfccc7ca05d952b14a4a480e149
1 parent c1abe76 commit cd605e8

30 files changed: +498 additions, -131 deletions
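
For orientation, the sketch below (not part of the commit) shows how the two quantization paths end up feeding the single torch.ops.fbgemm.f4f4bf16 op after this refactor. The tensor shapes, dtypes, and the per-tensor global scales for the NVFP4 path are placeholders, and the op's Python-side defaults are assumed to mirror the C++ signatures changed further down.

import torch

from fbgemm_gpu.experimental.gen_ai.quantize import (
    scale_mxfp4_quant,
    scale_nvfp4_quant,
)

x = torch.randn(128, 256, device="cuda", dtype=torch.float32)
w = torch.randn(512, 256, device="cuda", dtype=torch.float32)

# MXFP4 path: per-32-element E8M0 block scales, no global scale; use_mx now
# defaults to the MX path at the op boundary.
xq_mx, x_scale_mx = scale_mxfp4_quant(x)
wq_mx, w_scale_mx = scale_mxfp4_quant(w)
y_mx = torch.ops.fbgemm.f4f4bf16(xq_mx, wq_mx, x_scale_mx, w_scale_mx)

# NVFP4 path: block scales plus an explicit global rescale factor. The
# per-tensor global scales are placeholder values here; the benchmark
# computes them from the input data (elided in the hunks below).
x_global_scale = torch.tensor(1.0, device="cuda")  # placeholder value
w_global_scale = torch.tensor(1.0, device="cuda")  # placeholder value
global_scale = 1 / (x_global_scale * w_global_scale)
xq_nv, x_scale_nv = scale_nvfp4_quant(x, x_global_scale)
wq_nv, w_scale_nv = scale_nvfp4_quant(w, w_global_scale)
y_nv = torch.ops.fbgemm.f4f4bf16(
    xq_nv, wq_nv, x_scale_nv, w_scale_nv, global_scale=global_scale, use_mx=False
)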

fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 44 additions & 7 deletions
@@ -27,7 +27,8 @@
 )
 from fbgemm_gpu.experimental.gen_ai.quantize import (
     quantize_int4_preshuffle,
-    scaled_fp4_quant,
+    scale_mxfp4_quant,
+    scale_nvfp4_quant,
 )
 
 try:
@@ -2005,9 +2006,9 @@ def cuda(self) -> bool:
 
 
 @register_quantize_op
-class FP4Gemm(QuantizeOpBase):
+class NVFP4Gemm(QuantizeOpBase):
     """
-    FP4 matmul with block-wise scaling.
+    NVFP4 matmul with block-wise scaling.
     """
 
     def quantize(self, x, w):
@@ -2019,16 +2020,52 @@ def quantize(self, x, w):
         )
         global_scale = 1 / (x_global_scale * w_global_scale)
 
-        xq, x_scale = scaled_fp4_quant(x, x_global_scale)
-        wq, w_scale = scaled_fp4_quant(w, w_global_scale)
+        xq, x_scale = scale_nvfp4_quant(x, x_global_scale)
+        wq, w_scale = scale_nvfp4_quant(w, w_global_scale)
         return xq, wq, x_scale, w_scale, global_scale
 
     def compute(self, xq, wq, x_scale, w_scale, global_scale):
-        return torch.ops.fbgemm.f4f4bf16(xq, wq, x_scale, w_scale, global_scale)
+        return torch.ops.fbgemm.f4f4bf16(
+            xq, wq, x_scale, w_scale, global_scale=global_scale, use_mx=False
+        )
 
     def quantize_and_compute(self, x, w):
         xq, wq, x_scale, w_scale, global_scale = self.quantize(x, w)
-        return self.compute(xq, wq, x_scale, w_scale, global_scale)
+        return self.compute(
+            xq, wq, x_scale, w_scale, global_scale=global_scale, use_mx=False
+        )
+
+    @property
+    def name(self) -> str:
+        return "cutlass_nv_f4f4bf16"
+
+    @property
+    def hip(self) -> bool:
+        # F4F4BF16 only supported for cuda.
+        return False
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
+@register_quantize_op
+class MXFP4Gemm(QuantizeOpBase):
+    """
+    MXFP4 matmul with block-wise scaling.
+    """
+
+    def quantize(self, x, w):
+        xq, x_scale = scale_mxfp4_quant(x)
+        wq, w_scale = scale_mxfp4_quant(w)
+        return xq, wq, x_scale, w_scale
+
+    def compute(self, xq, wq, x_scale, w_scale):
+        return torch.ops.fbgemm.f4f4bf16(xq, wq, x_scale, w_scale)
+
+    def quantize_and_compute(self, x, w):
+        xq, wq, x_scale, w_scale = self.quantize(x, w)
+        return self.compute(xq, wq, x_scale, w_scale)
 
     @property
     def name(self) -> str:
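
The commit summary describes the new PyTorch kernel as a reference for MXFP4 GEMM numeric verification. A hedged sketch of what such a check could look like, using only the functions visible in this commit; the shapes and the loose error tolerance are illustrative, not taken from FBGEMM's tests:

import torch

from fbgemm_gpu.experimental.gen_ai.quantize import scale_mxfp4_quant

x = torch.randn(256, 512, device="cuda", dtype=torch.float32)
w = torch.randn(1024, 512, device="cuda", dtype=torch.float32)

# Quantize both operands with the PyTorch reference kernel added in
# gen_ai/quantize.py below.
xq, x_scale = scale_mxfp4_quant(x)
wq, w_scale = scale_mxfp4_quant(w)

# Run the MXFP4 CUTLASS GEMM and compare against a plain float matmul.
y_fp4 = torch.ops.fbgemm.f4f4bf16(xq, wq, x_scale, w_scale)
y_ref = x @ w.t()

# FP4 with 32-element block scales is coarse, so only a loose relative-error
# check is meaningful here.
rel_err = (y_fp4.float() - y_ref).norm() / y_ref.norm()
print(f"MXFP4 vs. reference relative error: {rel_err.item():.3f}")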

fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py

Lines changed: 205 additions & 1 deletion
@@ -164,7 +164,7 @@ def _quantize(
     return wq, scales
 
 
-def scaled_fp4_quant(
+def scale_nvfp4_quant(
     input: torch.Tensor, input_global_scale: torch.Tensor
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
@@ -216,3 +216,207 @@ def round_up(x: int, y: int) -> int:
     torch.ops.fbgemm.scaled_fp4_quant(output, input, output_scale, input_global_scale)
     output_scale = output_scale.view(torch.float8_e4m3fn)
     return output, output_scale
+
+
+def _fp32_to_fp4_unpacked(x: torch.Tensor, ebits: int, mbits: int) -> torch.Tensor:
+    """Converts a float32 tensor to an unpacked float4 tensor.
+    Args:
+        x (torch.Tensor): The input float32 tensor.
+        ebits (int): The number of bits in the exponent.
+        mbits (int): The number of bits in the mantissa.
+    Returns:
+        torch.Tensor: The resulting unpacked float4 tensor.
+    """
+
+    def _n_ones(n: int) -> int:
+        return (1 << n) - 1
+
+    EBITS_F32, MBITS_F32 = 8, 23
+    F32_EXP_BIAS = _n_ones(EBITS_F32 - 1)
+
+    assert x.dtype == torch.float
+    assert 1 + ebits + mbits <= 8
+
+    # calculate constants
+    exp_bias = _n_ones(ebits - 1)
+    max_int = _n_ones(ebits + mbits)
+    sign_mask = 1 << (ebits + mbits)
+
+    magic_adder = _n_ones(MBITS_F32 - mbits - 1)
+
+    # all E bits and M bits are 1s
+    max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2**mbits))
+
+    # E bits = 1, M bits = 0
+    min_normal = 2 ** (1 - exp_bias)
+
+    denorm_exp = (
+        # exp bias conversion between formats
+        (F32_EXP_BIAS - exp_bias)
+        # mantissa length difference between formats
+        + (MBITS_F32 - mbits)
+        # add one to encoded exponent for denormalized numbers
+        + 1
+    )
+    denorm_mask_int = denorm_exp << MBITS_F32
+
+    # reinterpret int32 as float32
+    denorm_mask_float = torch.tensor(denorm_mask_int, dtype=torch.int32).view(
+        torch.float32
+    )
+
+    # save the sign
+    # Note that we have torch.uint32, but some ops like cpu bit shifts
+    # do not work on it. So, we stay in int32.
+    x = x.view(torch.int32)
+    sign = x & 0x80000000
+
+    # set everything to positive, will add sign back at the end
+    x = x ^ sign
+    x = x.view(torch.float)
+
+    # rewrite saturate/denorm/norm branches without explicit data dependent
+    # control flow, to be more compiler friendly
+    saturate_mask = x >= max_normal
+    denormal_mask = torch.logical_and(torch.logical_not(saturate_mask), x < min_normal)
+    normal_mask = torch.logical_not(torch.logical_or(saturate_mask, denormal_mask))
+
+    denormal_x = x + denorm_mask_float
+    denormal_x = denormal_x.view(torch.int32)
+    denormal_x -= denorm_mask_int
+    denormal_x = denormal_x.to(torch.uint8)
+
+    normal_x = x.view(torch.int32)
+    # resulting mantissa is odd
+    mant_odd = (normal_x >> (MBITS_F32 - mbits)) & 1
+    # update exponent, rounding bias part 1
+    val_to_add = ((exp_bias - F32_EXP_BIAS) << MBITS_F32) + magic_adder
+    normal_x += val_to_add
+    # rounding bias part 2
+    normal_x += mant_odd
+    # take the bits!
+    normal_x = normal_x >> (MBITS_F32 - mbits)
+    normal_x = normal_x.to(torch.uint8)
+
+    x = torch.full_like(x, max_int, dtype=torch.uint8)
+    x = torch.where(denormal_mask, denormal_x, x)
+    x = torch.where(normal_mask, normal_x, x)
+
+    # add sign back
+    sign_lp = sign >> (MBITS_F32 + EBITS_F32 - mbits - ebits)
+    sign_lp = sign_lp.to(torch.uint8)
+    # Right shift of a negative signed integer can fill the least significant
+    # bits with either 1s or 0s, depending on the implementation. Since PyTorch
+    # doesn't have an uint32 dtype, we mask out these bits to get just the
+    # f4 sign bit
+    sign_lp = sign_lp & sign_mask
+    x = x | sign_lp
+
+    return x.to(torch.uint8)
+
+
+def _to_blocked(x: torch.Tensor) -> torch.Tensor:
+    """Converts a tensor to the blocked layout.
+    Args:
+        x (torch.Tensor): The input tensor in non-blocked layout.
+    Returns:
+        torch.Tensor: The output tensor in the blocked layout.
+    """
+
+    def ceil_div(a: int, b: int) -> int:
+        return (a + b - 1) // b
+
+    rows, cols = x.shape
+    n_row_blocks = ceil_div(rows, 128)
+    n_col_blocks = ceil_div(cols, 4)
+
+    # Calculate the padded shape
+    padded_rows = n_row_blocks * 128
+    padded_cols = n_col_blocks * 4
+
+    padded = x
+    if (rows, cols) != (padded_rows, padded_cols):
+        padded = torch.zeros(
+            (padded_rows, padded_cols),
+            device=x.device,
+            dtype=x.dtype,
+        )
+        padded[:rows, :cols] = x
+
+    # Rearrange the blocks
+    blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
+    rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
+
+    return rearranged.flatten()
+
+
+# This PyTorch version refers to https://github.com/pytorch/ao/blob/v0.10.0/torchao/prototype/mx_formats/mx_tensor.py#L146
+def scale_mxfp4_quant(
+    x: torch.Tensor, block_size: int = 32
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to FP4 and return quantized tensor and scale.
+    Args:
+        x (torch.Tensor): The input tensor to be quantized to FP4
+        block_size (int): The block size to use for quantization. Default is 32.
+    Returns:
+        xq (torch.Tensor): Quantized FP4 output tensor
+        scale (torch.Tensor): Scale E8M0 tensor
+    """
+
+    F4_E2M1_MAX = 6.0
+    E8M0_EXPONENT_BIAS = 127
+    EBITS_F4_E2M1, MBITS_F4_E2M1 = 2, 1
+
+    # calculate the scale in e8m0 format
+    orig_shape = x.shape
+    x = x.reshape(-1, block_size)
+
+    # find max value of the data
+    # Note: this only implements the `minimally supported` version of
+    # https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+    # section 6.3.
+    max_abs = torch.amax(torch.abs(x), 1)
+    max_pos = F4_E2M1_MAX
+
+    descale = max_abs / max_pos
+    scale = torch.where(
+        torch.isnan(descale),
+        0xFF,  # Handle biased exponent for nan
+        # NOTE: descale < (torch.finfo(torch.float32).smallest_normal / 2) is handled through clamping
+        (
+            torch.clamp(
+                torch.ceil(torch.log2(descale)),
+                min=-E8M0_EXPONENT_BIAS,
+                max=E8M0_EXPONENT_BIAS,
+            )
+            + E8M0_EXPONENT_BIAS
+        ).to(torch.uint8),
+    )
+
+    descale_fp = torch.where(
+        scale == 0,
+        1.0,
+        torch.exp2(E8M0_EXPONENT_BIAS - scale.to(torch.float32)),
+    )
+
+    # scale and saturated cast the data elements to max of target dtype
+    xq = torch.clamp(x * descale_fp.unsqueeze(1), min=-1 * max_pos, max=max_pos)
+
+    xq = xq.reshape(orig_shape)
+    xq = _fp32_to_fp4_unpacked(xq, EBITS_F4_E2M1, MBITS_F4_E2M1)
+    orig_shape = [*orig_shape[:-1], orig_shape[-1] // 2]
+
+    shape = xq.shape
+    assert shape[-1] % 2 == 0
+    xq = xq.contiguous().view(-1)
+    xq = (xq[::2] << 4 | xq[1::2]).view((*shape[:-1], shape[-1] // 2))
+
+    target_numel = scale.numel() * block_size / 2
+    assert target_numel == xq.numel(), f"{target_numel} != {xq.numel()}"
+
+    scale = scale.view(torch.float8_e8m0fnu)
+    scale = scale.view(orig_shape[0], -1)
+    scale = _to_blocked(scale)
+
+    return xq, scale
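
To make the E8M0 block-scale arithmetic in scale_mxfp4_quant concrete, here is a small worked example (not part of the commit) for a single 32-element block whose largest magnitude is 10.0; the numbers follow the clamp/ceil/log2 path above.

import math

F4_E2M1_MAX = 6.0          # largest magnitude representable in FP4 E2M1
E8M0_EXPONENT_BIAS = 127

max_abs = 10.0
descale = max_abs / F4_E2M1_MAX                            # ~1.667
exp = min(max(math.ceil(math.log2(descale)), -127), 127)   # 1
scale_u8 = exp + E8M0_EXPONENT_BIAS                        # 128, the stored E8M0 byte
descale_fp = 2.0 ** (E8M0_EXPONENT_BIAS - scale_u8)        # 0.5, applied before rounding
assert (scale_u8, descale_fp) == (128, 0.5)

# Each element of the block is multiplied by 0.5, clamped to [-6, 6], rounded
# to E2M1, and two 4-bit codes are packed per uint8. So for an input of shape
# (M, K), xq is (M, K // 2) uint8 and scale holds M * K / 32 E8M0 bytes before
# _to_blocked pads and rearranges them into the 128x4 blocked layout.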

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16.cu

Lines changed: 6 additions & 6 deletions
@@ -24,8 +24,8 @@ at::Tensor dispatch_f4f4bf16_kernel(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
-    at::Tensor global_scale,
-    bool use_mx = false) {
+    std::optional<at::Tensor> global_scale,
+    bool use_mx = true) {
   auto M = XQ.size(0);
   auto K = XQ.size(1);
   auto N = WQ.size(0);
@@ -173,8 +173,8 @@ at::Tensor f4f4bf16(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
-    at::Tensor global_scale,
-    bool use_mx = false) {
+    std::optional<at::Tensor> global_scale,
+    bool use_mx = true) {
   return dispatch_f4f4bf16_kernel(
       XQ, WQ, x_scale, w_scale, global_scale, use_mx);
 }
@@ -186,8 +186,8 @@ at::Tensor f4f4bf16(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
-    at::Tensor global_scale,
-    bool use_mx = false) {
+    std::optional<at::Tensor> global_scale,
+    bool use_mx = true) {
   throw std::runtime_error(
       "CUDA version is older than 12.8"); // requires CUDA>=12.8
 }

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_4_1_1_f.cu

Lines changed: 8 additions & 3 deletions
@@ -17,10 +17,15 @@ at::Tensor f4f4bf16_128_128_4_1_1_f(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
-    at::Tensor global_scale) {
+    std::optional<at::Tensor> global_scale = std::nullopt) {
   // Dispatch this kernel to the correct underlying implementation.
-  return _f4f4bf16<128, 128, 4, 1, 1, false>(
-      XQ, WQ, x_scale, w_scale, global_scale);
+  return _f4f4bf16<
+      cutlass::nv_float4_t<cutlass::float_e2m1_t>,
+      128,
+      128,
+      4,
+      1,
+      1>(XQ, WQ, x_scale, w_scale, global_scale);
 }
 
 #endif

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_128_4_1_1_t.cu

Lines changed: 8 additions & 3 deletions
@@ -17,10 +17,15 @@ at::Tensor f4f4bf16_128_128_4_1_1_t(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
-    at::Tensor global_scale) {
+    std::optional<at::Tensor> global_scale = std::nullopt) {
   // Dispatch this kernel to the correct underlying implementation.
-  return _f4f4bf16<128, 128, 4, 1, 1, true>(
-      XQ, WQ, x_scale, w_scale, global_scale);
+  return _f4f4bf16<
+      cutlass::mx_float4_t<cutlass::float_e2m1_t>,
+      128,
+      128,
+      4,
+      1,
+      1>(XQ, WQ, x_scale, w_scale, global_scale);
 }
 
 #endif

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_128_192_2_2_1_f.cu

Lines changed: 8 additions & 3 deletions
@@ -17,10 +17,15 @@ at::Tensor f4f4bf16_128_192_2_2_1_f(
     at::Tensor WQ, // FP4
     at::Tensor x_scale,
     at::Tensor w_scale,
-    at::Tensor global_scale) {
+    std::optional<at::Tensor> global_scale = std::nullopt) {
   // Dispatch this kernel to the correct underlying implementation.
-  return _f4f4bf16<128, 192, 2, 2, 1, false>(
-      XQ, WQ, x_scale, w_scale, global_scale);
+  return _f4f4bf16<
+      cutlass::nv_float4_t<cutlass::float_e2m1_t>,
+      128,
+      192,
+      2,
+      2,
+      1>(XQ, WQ, x_scale, w_scale, global_scale);
 }
 
 #endif
