
Commit 492bd23

jiawenliu64 authored and facebook-github-bot committed
Support Triton unpacked MXFP4 quantization kernel (#4116)
Summary:
Pull Request resolved: #4116

X-link: facebookresearch/FBGEMM#1198

Support Triton unpacked MXFP4 quantization kernel. The previous MXFP4 quantization kernel only supported the packed layout, with the quantized tensor and scale stored together. Follow-ups to enable a fast Triton quantization kernel for MXFP4 (grouped) GEMM:

1. Enable BF16 -> MXFP4 quantization (the current kernels support FP32 -> MXFP4).
2. Enable the swizzle layout in the Triton quantization kernel, which is required by both the CUTLASS and Triton GEMMs/grouped GEMMs.
3. Enable intrinsic FP4 support, which saves ~20 instructions and is essential to performance, especially in latency-bound cases.

Reviewed By: q10

Differential Revision: D74050324

fbshipit-source-id: 1ac9ba9f7961cdb078a5136ba81001eeb6192787
1 parent e051cad commit 492bd23
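
For context, a minimal sketch of the two quantization entry points involved, using the names, arguments, and shapes exercised by the test added in this commit (the packed layout details are inferred from that test rather than from the kernel source):

import torch

from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import (
    triton_quantize_mx4_unpack,
)
from fbgemm_gpu.quantize_utils import fp32_to_mx4, RoundingMode

x = torch.randn(128, 1024, dtype=torch.bfloat16, device="cuda")

# Unpacked path added here: quantized data and per-group scales come back as
# two separate tensors; per the test, xq is [M, N // 2] uint8 (two FP4 values
# per byte) and x_scale is [M, N // 32] uint8 for group_size=32.
xq, x_scale = triton_quantize_mx4_unpack(
    x, group_size=32, rounding_mode=RoundingMode.even
)

# Existing packed path: per the test's indexing, each group of 32 values is
# stored as 16 data bytes immediately followed by its 1-byte shared scale,
# all interleaved in a single tensor.
xq_packed = fp32_to_mx4(x, group_size=32, rounding_mode=RoundingMode.even)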

File tree: 2 files changed, +480 -0 lines changed

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
import unittest
from typing import Tuple

import torch

from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import (
    triton_quantize_mx4_unpack,
)
from fbgemm_gpu.quantize_utils import fp32_to_mx4, RoundingMode


@unittest.skipIf(
    not torch.cuda.is_available()
    or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
    "Skip when H100 is not available",
)
class TestFp4Quantize(unittest.TestCase):
    def setUp(self) -> None:
        torch.manual_seed(0)

    def test_quantize_fp4(self) -> None:
        def _test_quantize_fp4(
            shape: Tuple[int, int],
            device: str = "cuda",
        ) -> None:
            M, N = shape
            group_size = 32
            rounding_mode = RoundingMode.even
            # Two FP4 values are packed per uint8, so each group of 32 elements
            # takes 16 data bytes plus 1 scale byte in the packed layout.
            packed_group_size = group_size // 2
            groups_per_row = math.ceil(N / group_size)

            x = torch.randn(M, N, dtype=torch.bfloat16, device=device)
            # Unpacked kernel under test: data and scales returned separately.
            xq_ref, x_scale_ref = triton_quantize_mx4_unpack(
                x, group_size=group_size, rounding_mode=rounding_mode
            )
            # Reference packed kernel: data and scales interleaved in one tensor.
            xq_packed = fp32_to_mx4(
                x, group_size=group_size, rounding_mode=rounding_mode
            )

            xq = torch.empty([M, N // 2], device=x.device, dtype=torch.uint8)
            x_scale = torch.empty(
                [M, groups_per_row], device=x.device, dtype=torch.uint8
            )

            # Split the packed output back into (data, scale): each group is
            # 16 data bytes followed by its 1-byte scale.
            for i in range(groups_per_row):
                start_idx = i * (packed_group_size + 1)
                end_idx = start_idx + packed_group_size
                xq[:, i * packed_group_size : (i + 1) * packed_group_size] = xq_packed[
                    :, start_idx:end_idx
                ]
                x_scale[:, i] = xq_packed[:, end_idx]

            self.assertTrue(torch.equal(xq, xq_ref))
            self.assertTrue(torch.equal(x_scale, x_scale_ref))

        _test_quantize_fp4((1, 128))
        _test_quantize_fp4((3, 512))
        _test_quantize_fp4((128, 1024))
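
For reference, the per-group splitting loop in the test can also be expressed as a single reshape over the packed tensor. This is a hypothetical helper, not part of FBGEMM, assuming the 17-bytes-per-group packed layout that the test decodes:

import torch

def split_packed_mx4(xq_packed: torch.Tensor, group_size: int = 32):
    # Packed layout (as decoded by the test): per group of `group_size` values,
    # group_size // 2 data bytes followed by one shared-scale byte.
    packed_group_size = group_size // 2
    M = xq_packed.shape[0]
    groups = xq_packed.shape[1] // (packed_group_size + 1)
    grouped = xq_packed.reshape(M, groups, packed_group_size + 1)
    data = grouped[:, :, :packed_group_size].reshape(M, groups * packed_group_size)
    scale = grouped[:, :, packed_group_size].contiguous()
    return data, scale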

0 commit comments
