@@ -2079,3 +2079,217 @@ def hip(self) -> bool:
     @property
     def cuda(self) -> bool:
         return True
+
+
+@register_quantize_op
+class MXFP4GroupedGemm(QuantizeOpBase):
+    """
+    MXFP4 grouped matmul with blockwise scaling.
+    """
+
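+    # Weights are quantized once here in preprocess, which the benchmark
+    # harness is assumed to run outside the timed region, so steady-state
+    # timing covers only activation quantization and the grouped GEMM.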
+    def preprocess(self, x, w):
+        wq, w_scale = zip(*[scale_mxfp4_quant(i) for i in w])
+        return x, wq, w_scale
+
+    def quantize(self, x, wq, w_scale):
+        xq, x_scale = zip(*[scale_mxfp4_quant(i) for i in x])
+        return xq, wq, x_scale, w_scale
+
+    def compute(self, xq, wq, x_scale, w_scale):
+        return torch.ops.fbgemm.f4f4bf16_grouped(
+            xq,
+            wq,
+            x_scale,
+            w_scale,
+        )
+
+    def quantize_and_compute(self, x, wq, w_scale):
+        xq, wq, x_scale, w_scale = self.quantize(x, wq, w_scale)
+        return self.compute(xq, wq, x_scale, w_scale)
+
+    @property
+    def name(self) -> str:
+        return "cutlass_f4f4bf16_grouped"
+
+    @property
+    def hip(self) -> bool:
+        # F4F4BF16_grouped is only supported on CUDA.
+        return False
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
+@register_quantize_op
+class NVFP4GroupedGemm(QuantizeOpBase):
+    """
+    NVFP4 grouped matmul with blockwise scaling.
+    """
+
+    def quantize(self, x, w):
+        def get_global_scale(x, w):
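+            # 448.0 is the largest float8_e4m3 value (the block-scale dtype)
+            # and 6.0 the largest e2m1 value, so each global scale maps the
+            # tensor's absmax onto the largest representable scale-times-
+            # element product.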
+            x_global_scale = ((448.0 * 6.0) / torch.amax(x.flatten(), dim=-1)).to(
+                torch.float32
+            )
+            w_global_scale = ((448.0 * 6.0) / torch.amax(w.flatten(), dim=-1)).to(
+                torch.float32
+            )
+            global_scale = 1 / (x_global_scale * w_global_scale)
+            return x_global_scale, w_global_scale, global_scale
+
+        # Compute global scale for each group
+        G = len(x)
+        x_global_scale = []
+        w_global_scale = []
+        global_scale = []
+        for i in range(G):
+            x_global_scale_, w_global_scale_, global_scale_ = get_global_scale(
+                x[i], w[i]
+            )
+            x_global_scale.append(x_global_scale_)
+            w_global_scale.append(w_global_scale_)
+            global_scale.append(global_scale_)
+
+        # Quantize weights and activations
+        wq, w_scale = zip(
+            *[scale_nvfp4_quant(w[i], w_global_scale[i]) for i in range(G)]
+        )
+        xq, x_scale = zip(
+            *[scale_nvfp4_quant(x[i], x_global_scale[i]) for i in range(G)]
+        )
+        return xq, wq, x_scale, w_scale, global_scale
+
+    def compute(self, xq, wq, x_scale, w_scale, global_scale):
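+        # use_mx=False selects the NVFP4 interpretation (e4m3 block scales
+        # plus global_scale) instead of the MXFP4 default used above.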
+        return torch.ops.fbgemm.f4f4bf16_grouped(
+            xq, wq, x_scale, w_scale, global_scale, use_mx=False
+        )
+
+    def quantize_and_compute(self, x, w):
+        xq, wq, x_scale, w_scale, global_scale = self.quantize(x, w)
+        return self.compute(xq, wq, x_scale, w_scale, global_scale)
+
+    @property
+    def name(self) -> str:
+        return "cutlass_nv_f4f4bf16_grouped"
+
+    @property
+    def hip(self) -> bool:
+        return False
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
+@register_quantize_op
+class MXFP4StackedGroupedGemm(QuantizeOpBase):
+    """
+    MXFP4 grouped matmul with blockwise scaling and stacked inputs.
+    """
+
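+    # Editorial note: in the stacked layout all groups are packed into one
+    # activation tensor and m_sizes records each group's row count, so a
+    # single kernel launch can cover every group.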
+    def preprocess(self, x, w):
+        m_values = [i.shape[0] for i in x]
+        m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
+        wq, w_scale = zip(*[scale_mxfp4_quant(i) for i in w])
+        wq = torch.stack(wq, dim=0).contiguous()
+        w_scale = torch.stack(w_scale, dim=0).contiguous()
+        return x, wq, w_scale, m_sizes
+
+    def quantize(self, x, wq, w_scale, m_sizes):
+        xq, x_scale = zip(*[scale_mxfp4_quant(i) for i in x])
+        xq = torch.stack(xq, dim=0).contiguous()
+        x_scale = torch.stack(x_scale, dim=0).contiguous()
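+        # Collapse the (G, M, K//2) stack of packed fp4 activations into one
+        # (G*M, K//2) tensor, the flattened layout described by m_sizes.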
+        xq = xq.view(-1, xq.shape[-1])
+        return xq, wq, x_scale, w_scale, m_sizes
+
+    def compute(self, xq, wq, x_scale, w_scale, m_sizes):
+        return torch.ops.fbgemm.f4f4bf16_grouped_stacked(
+            xq, wq, x_scale, w_scale, m_sizes
+        )
+
+    def quantize_and_compute(self, x, wq, w_scale, m_sizes):
+        xq, wq, x_scale, w_scale, m_sizes = self.quantize(x, wq, w_scale, m_sizes)
+        return self.compute(xq, wq, x_scale, w_scale, m_sizes)
+
+    @property
+    def name(self) -> str:
+        return "cutlass_f4f4bf16_grouped_stacked"
+
+    @property
+    def hip(self) -> bool:
+        return False
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
+@register_quantize_op
+class NVFP4StackedGroupedGemm(QuantizeOpBase):
+    """
+    NVFP4 grouped matmul with blockwise scaling and stacked inputs.
+    """
+
+    def quantize(self, x, w):
+        def get_global_scale(x, w):
+            x_global_scale = ((448.0 * 6.0) / torch.amax(x.flatten(), dim=-1)).to(
+                torch.float32
+            )
+            w_global_scale = ((448.0 * 6.0) / torch.amax(w.flatten(), dim=-1)).to(
+                torch.float32
+            )
+            global_scale = 1 / (x_global_scale * w_global_scale)
+            return x_global_scale, w_global_scale, global_scale
+
+        m_values = [i.shape[0] for i in x]
+        m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
+
+        # Compute global scale for each group
+        G = len(x)
+        x_global_scale = []
+        w_global_scale = []
+        global_scale = []
+        for i in range(G):
+            x_global_scale_, w_global_scale_, global_scale_ = get_global_scale(
+                x[i], w[i]
+            )
+            x_global_scale.append(x_global_scale_)
+            w_global_scale.append(w_global_scale_)
+            global_scale.append(global_scale_)
+
+        wq, w_scale = zip(
+            *[scale_nvfp4_quant(w[i], w_global_scale[i]) for i in range(G)]
+        )
+        wq = torch.stack(wq, dim=0).contiguous()
+        w_scale = torch.stack(w_scale, dim=0).contiguous()
+
+        xq, x_scale = zip(
+            *[scale_nvfp4_quant(x[i], x_global_scale[i]) for i in range(G)]
+        )
+        xq = torch.stack(xq, dim=0).contiguous()
+        x_scale = torch.stack(x_scale, dim=0).contiguous()
+        xq = xq.view(-1, xq.shape[-1])
+        global_scale = torch.stack(global_scale, dim=0).contiguous()
+        return xq, wq, x_scale, w_scale, m_sizes, global_scale
+
+    def compute(self, xq, wq, x_scale, w_scale, m_sizes, global_scale):
+        return torch.ops.fbgemm.f4f4bf16_grouped_stacked(
+            xq, wq, x_scale, w_scale, m_sizes, global_scale, use_mx=False
+        )
+
+    def quantize_and_compute(self, x, w):
+        xq, wq, x_scale, w_scale, m_sizes, global_scale = self.quantize(x, w)
+        return self.compute(xq, wq, x_scale, w_scale, m_sizes, global_scale)
+
+    @property
+    def name(self) -> str:
+        return "cutlass_nv_f4f4bf16_grouped_stacked"
+
+    @property
+    def hip(self) -> bool:
+        return False
+
+    @property
+    def cuda(self) -> bool:
+        return True