Commit a9c2b70

jiawenliu64 authored and facebook-github-bot committed

Build MXFP4/NVFP4 CUTLASS grouped GEMM (#4128)

Summary:
Pull Request resolved: #4128
X-link: facebookresearch/FBGEMM#1209

Build and optimize MXFP4/NVFP4 CUTLASS grouped GEMM on Nvidia Blackwell GPUs, which provides a large speedup (2.1x on average) over FP8 grouped GEMM on B200.

Reviewed By: q10, jianyuh

Differential Revision: D74431182

fbshipit-source-id: 753acb5a0d4d8b51b8ea371ac89506445bb1102f

1 parent: 1e9425f

File tree: 3 files changed, +1168 -0 lines changed

fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py (214 additions, 0 deletions)
@@ -2079,3 +2079,217 @@ def hip(self) -> bool:
    @property
    def cuda(self) -> bool:
        return True


@register_quantize_op
class MXFP4GroupedGemm(QuantizeOpBase):
    """
    MXFP4 grouped matmul with blockwise scaling.
    """

    def preprocess(self, x, w):
        wq, w_scale = zip(*[scale_mxfp4_quant(i) for i in w])
        return x, wq, w_scale

    def quantize(self, x, wq, w_scale):
        xq, x_scale = zip(*[scale_mxfp4_quant(i) for i in x])
        return xq, wq, x_scale, w_scale

    def compute(self, xq, wq, x_scale, w_scale):
        return torch.ops.fbgemm.f4f4bf16_grouped(
            xq,
            wq,
            x_scale,
            w_scale,
        )

    def quantize_and_compute(self, x, wq, w_scale):
        xq, wq, x_scale, w_scale = self.quantize(x, wq, w_scale)
        return self.compute(xq, wq, x_scale, w_scale)

    @property
    def name(self) -> str:
        return "cutlass_f4f4bf16_grouped"

    @property
    def hip(self) -> bool:
        # F4F4BF16_grouped only supported for cuda.
        return False

    @property
    def cuda(self) -> bool:
        return True

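For context, the op introduced by the class above is exercised as follows. The sketch below is illustrative and not part of the commit: the group count and shapes are made up, and it assumes scale_mxfp4_quant and the fbgemm ops are in scope and loaded exactly as they are in this benchmark file (a Blackwell GPU is required).

import torch

# Hypothetical problem sizes; K is kept a multiple of the MXFP4 block size (32).
G, M, N, K = 2, 128, 256, 512
x = [torch.randn(M, K, device="cuda", dtype=torch.bfloat16) for _ in range(G)]
w = [torch.randn(N, K, device="cuda", dtype=torch.bfloat16) for _ in range(G)]

# Per-group MXFP4 quantization with blockwise scales, as in preprocess()/quantize() above.
wq, w_scale = zip(*[scale_mxfp4_quant(t) for t in w])
xq, x_scale = zip(*[scale_mxfp4_quant(t) for t in x])

# One grouped-GEMM launch covering all G problems, producing BF16 outputs.
out = torch.ops.fbgemm.f4f4bf16_grouped(xq, wq, x_scale, w_scale)
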
@register_quantize_op
class NVFP4GroupedGemm(QuantizeOpBase):
    """
    NVFP4 grouped matmul with blockwise scaling.
    """

    def quantize(self, x, w):
        def get_global_scale(x, w):
            x_global_scale = ((448.0 * 6.0) / torch.amax(x.flatten(), dim=-1)).to(
                torch.float32
            )
            w_global_scale = ((448.0 * 6.0) / torch.amax(w.flatten(), dim=-1)).to(
                torch.float32
            )
            global_scale = 1 / (x_global_scale * w_global_scale)
            return x_global_scale, w_global_scale, global_scale

        # Compute global scale for each group
        G = len(x)
        x_global_scale = []
        w_global_scale = []
        global_scale = []
        for i in range(G):
            x_global_scale_, w_global_scale_, global_scale_ = get_global_scale(
                x[i], w[i]
            )
            x_global_scale.append(x_global_scale_)
            w_global_scale.append(w_global_scale_)
            global_scale.append(global_scale_)

        # Quantize weights and activations
        wq, w_scale = zip(
            *[scale_nvfp4_quant(w[i], w_global_scale[i]) for i in range(G)]
        )
        xq, x_scale = zip(
            *[scale_nvfp4_quant(x[i], x_global_scale[i]) for i in range(G)]
        )
        return xq, wq, x_scale, w_scale, global_scale

    def compute(self, xq, wq, x_scale, w_scale, global_scale):
        return torch.ops.fbgemm.f4f4bf16_grouped(
            xq, wq, x_scale, w_scale, global_scale, use_mx=False
        )

    def quantize_and_compute(self, x, w):
        xq, wq, x_scale, w_scale, global_scale = self.quantize(x, w)
        return self.compute(xq, wq, x_scale, w_scale, global_scale)

    @property
    def name(self) -> str:
        return "cutlass_nv_f4f4bf16_grouped"

    @property
    def hip(self) -> bool:
        return False

    @property
    def cuda(self) -> bool:
        return True

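The 448.0 * 6.0 constant in get_global_scale reflects the NVFP4 recipe: per-block scale factors are stored in FP8 E4M3 (largest finite value 448) and the quantized elements in FP4 E2M1 (largest magnitude 6), so each tensor's amax is mapped onto that combined dynamic range, and global_scale = 1 / (x_global_scale * w_global_scale) is the factor that rescales the accumulator back afterwards. A minimal single-group sketch under the same assumptions as the earlier example (helpers in scope, shapes and values illustrative):

import torch

FP8_E4M3_MAX = 448.0  # largest finite FP8 e4m3 value (dtype of the per-block scales)
FP4_E2M1_MAX = 6.0    # largest FP4 e2m1 magnitude (dtype of the quantized elements)

x = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
w = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)

x_global_scale = ((FP8_E4M3_MAX * FP4_E2M1_MAX) / torch.amax(x.flatten(), dim=-1)).to(torch.float32)
w_global_scale = ((FP8_E4M3_MAX * FP4_E2M1_MAX) / torch.amax(w.flatten(), dim=-1)).to(torch.float32)
global_scale = 1 / (x_global_scale * w_global_scale)

xq, x_scale = scale_nvfp4_quant(x, x_global_scale)
wq, w_scale = scale_nvfp4_quant(w, w_global_scale)

# Single-group call; each list carries one entry per group.
out = torch.ops.fbgemm.f4f4bf16_grouped([xq], [wq], [x_scale], [w_scale], [global_scale], use_mx=False)
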
@register_quantize_op
class MXFP4StackedGroupedGemm(QuantizeOpBase):
    """
    MXFP4 grouped matmul with blockwise scaling and stacked inputs.
    """

    def preprocess(self, x, w):
        m_values = [i.shape[0] for i in x]
        m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
        wq, w_scale = zip(*[scale_mxfp4_quant(i) for i in w])
        wq = torch.stack(wq, dim=0).contiguous()
        w_scale = torch.stack(w_scale, dim=0).contiguous()
        return x, wq, w_scale, m_sizes

    def quantize(self, x, wq, w_scale, m_sizes):
        xq, x_scale = zip(*[scale_mxfp4_quant(i) for i in x])
        xq = torch.stack(xq, dim=0).contiguous()
        x_scale = torch.stack(x_scale, dim=0).contiguous()
        xq = xq.view(-1, xq.shape[-1])
        return xq, wq, x_scale, w_scale, m_sizes

    def compute(self, xq, wq, x_scale, w_scale, m_sizes):
        return torch.ops.fbgemm.f4f4bf16_grouped_stacked(
            xq, wq, x_scale, w_scale, m_sizes
        )

    def quantize_and_compute(self, x, w):
        xq, wq, x_scale, w_scale, m_sizes = self.quantize(x, w)
        return self.compute(xq, wq, x_scale, w_scale, m_sizes)

    @property
    def name(self) -> str:
        return "cutlass_f4f4bf16_grouped_stacked"

    @property
    def hip(self) -> bool:
        return False

    @property
    def cuda(self) -> bool:
        return True

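The stacked variants trade the per-group tensor lists for one flattened activation tensor plus an m_sizes vector: activations are stacked and reshaped so all group rows sit in a single 2D tensor, m_sizes records how many rows belong to each group, and the weights are stacked into a single batched tensor, so one kernel launch can walk all groups. A sketch of that layout, with the same caveats as the earlier examples (shapes hypothetical, helpers assumed in scope, equal rows per group since torch.stack requires it here):

import torch

G, M, N, K = 4, 64, 256, 512
x = [torch.randn(M, K, device="cuda", dtype=torch.bfloat16) for _ in range(G)]
w = [torch.randn(N, K, device="cuda", dtype=torch.bfloat16) for _ in range(G)]

# Row counts per group, used by the kernel to partition the stacked activations.
m_sizes = torch.tensor([t.shape[0] for t in x], dtype=torch.int64, device=x[0].device)

wq, w_scale = zip(*[scale_mxfp4_quant(t) for t in w])
wq = torch.stack(wq, dim=0).contiguous()          # one batched quantized weight tensor
w_scale = torch.stack(w_scale, dim=0).contiguous()

xq, x_scale = zip(*[scale_mxfp4_quant(t) for t in x])
xq = torch.stack(xq, dim=0).contiguous()
x_scale = torch.stack(x_scale, dim=0).contiguous()
xq = xq.view(-1, xq.shape[-1])                    # flatten to (G * M) stacked rows

out = torch.ops.fbgemm.f4f4bf16_grouped_stacked(xq, wq, x_scale, w_scale, m_sizes)
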
@register_quantize_op
class NVFP4StackedGroupedGemm(QuantizeOpBase):
    """
    NVFP4 grouped matmul with blockwise scaling and stacked inputs.
    """

    def quantize(self, x, w):
        def get_global_scale(x, w):
            x_global_scale = ((448.0 * 6.0) / torch.amax(x.flatten(), dim=-1)).to(
                torch.float32
            )
            w_global_scale = ((448.0 * 6.0) / torch.amax(w.flatten(), dim=-1)).to(
                torch.float32
            )
            global_scale = 1 / (x_global_scale * w_global_scale)
            return x_global_scale, w_global_scale, global_scale

        m_values = [i.shape[0] for i in x]
        m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)

        # Compute global scale for each group
        G = len(x)
        x_global_scale = []
        w_global_scale = []
        global_scale = []
        for i in range(G):
            x_global_scale_, w_global_scale_, global_scale_ = get_global_scale(
                x[i], w[i]
            )
            x_global_scale.append(x_global_scale_)
            w_global_scale.append(w_global_scale_)
            global_scale.append(global_scale_)

        wq, w_scale = zip(
            *[scale_nvfp4_quant(w[i], w_global_scale[i]) for i in range(G)]
        )
        wq = torch.stack(wq, dim=0).contiguous()
        w_scale = torch.stack(w_scale, dim=0).contiguous()

        xq, x_scale = zip(
            *[scale_nvfp4_quant(x[i], x_global_scale[i]) for i in range(G)]
        )
        xq = torch.stack(xq, dim=0).contiguous()
        x_scale = torch.stack(x_scale, dim=0).contiguous()
        xq = xq.view(-1, xq.shape[-1])
        global_scale = torch.stack(global_scale, dim=0).contiguous()
        return xq, wq, x_scale, w_scale, m_sizes, global_scale

    def compute(self, xq, wq, x_scale, w_scale, m_sizes, global_scale):
        return torch.ops.fbgemm.f4f4bf16_grouped_stacked(
            xq, wq, x_scale, w_scale, m_sizes, global_scale, use_mx=False
        )

    def quantize_and_compute(self, x, w):
        xq, wq, x_scale, w_scale, m_sizes, global_scale = self.quantize(x, w)
        return self.compute(xq, wq, x_scale, w_scale, m_sizes, global_scale)

    @property
    def name(self) -> str:
        return "cutlass_nv_f4f4bf16_grouped_stacked"

    @property
    def hip(self) -> bool:
        return False

    @property
    def cuda(self) -> bool:
        return True
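
A note on how the MX and NV paths diverge: the NV classes pass a per-group global_scale and set use_mx=False, while the MX classes omit both, which suggests use_mx defaults to True. Since all four ops are registered through register_quantize_op, the simplest way to smoke-test one is through its class. The sketch below assumes QuantizeOpBase can be instantiated with no constructor arguments (as the other registered ops in this file suggest) and uses made-up shapes:

import torch

x = [torch.randn(64, 512, device="cuda", dtype=torch.bfloat16) for _ in range(4)]
w = [torch.randn(256, 512, device="cuda", dtype=torch.bfloat16) for _ in range(4)]

# quantize() computes per-group global scales, quantizes to NVFP4, and stacks the groups;
# compute() dispatches to the CUTLASS stacked grouped GEMM with use_mx=False.
op = NVFP4StackedGroupedGemm()
out = op.quantize_and_compute(x, w)
print(op.name)  # "cutlass_nv_f4f4bf16_grouped_stacked"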
