@@ -202,6 +202,7 @@ def __init__(
         op_types_to_quantize: tuple[str, ...] | None = None,
         quant_axes: tuple[tuple[str, int], ...] | None = None,
         bits: int = 4,
+        channel_wised_quantize: bool = False,
     ):
         """
         This is a class for weight only affine quantization configuration.
@@ -236,6 +237,9 @@ def __init__(
         self.is_symmetric = is_symmetric
         self.bits = bits
         self.accuracy_level = accuracy_level
+        self.channel_wised_quantize = channel_wised_quantize
+        if channel_wised_quantize and quant_format == QuantFormat.QOperator:
+            raise NotImplementedError("channel_wised_quantize is not supported with QuantFormat.QOperator yet")


 class NVAWQWeightOnlyQuantConfig(WeightOnlyQuantConfig):
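As a usage note, a minimal sketch of how the new guard behaves, assuming this constructor is `DefaultWeightOnlyQuantConfig` and the usual onnxruntime import paths:

```python
from onnxruntime.quantization import QuantFormat
from onnxruntime.quantization.matmul_4bits_quantizer import DefaultWeightOnlyQuantConfig

# Combining the new flag with QOperator should raise, per the guard above.
try:
    DefaultWeightOnlyQuantConfig(
        quant_format=QuantFormat.QOperator,
        channel_wised_quantize=True,  # only QDQ is supported so far
    )
except NotImplementedError as err:
    print(err)
```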
@@ -734,6 +738,26 @@ def get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, GraphProto]:
     return None, None

+# Transpose an int4 matrix that is packed two values per uint8 byte.
+def transpose_packed_int4_matrix(packed, rows, cols):
+    # Unpack the uint8 buffer into individual int4 values (low nibble first).
+    total = rows * cols
+    high = (packed >> 4) & 0x0F
+    low = packed & 0x0F
+    int4_vals = np.empty(total, dtype=np.uint8)
+    int4_vals[0::2] = low
+    int4_vals[1::2] = high
+    int4_matrix = int4_vals.reshape((rows, cols))
+
+    # Transpose the unpacked matrix.
+    int4_matrix_transposed = int4_matrix.T
+
+    # Repack two int4 values per uint8 byte.
+    flat = int4_matrix_transposed.reshape(-1)
+    packed = ((flat[1::2] << 4) & 0xF0) | (flat[0::2] & 0x0F)
+    return packed.astype(np.uint8)
+
+
 class DefaultWeightOnlyQuantizer:
     def __init__(self, config: DefaultWeightOnlyQuantConfig):
         self.config = config
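A quick round-trip check of the helper above; a sketch that assumes `rows * cols` is even, matching the low-nibble-first packing convention used here:

```python
import numpy as np

rows, cols = 2, 4
vals = np.arange(rows * cols, dtype=np.uint8) % 16           # int4 values for a 2x4 matrix
packed = ((vals[1::2] << 4) & 0xF0) | (vals[0::2] & 0x0F)    # pack two values per byte

transposed = transpose_packed_int4_matrix(packed, rows, cols)

# Unpack the result and compare with NumPy's transpose of the unpacked matrix.
out = np.empty(rows * cols, dtype=np.uint8)
out[0::2] = transposed & 0x0F
out[1::2] = (transposed >> 4) & 0x0F
assert np.array_equal(out.reshape((cols, rows)), vals.reshape((rows, cols)).T)
```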
@@ -770,6 +794,10 @@ def qbits_block_quant(self, fp32weight: npt.ArrayLike) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
                 packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric
             )
         else:
+            # Block size equals the number of rows (K) when channel-wise quantization is enabled.
+            block_size = rows if self.config.channel_wised_quantize else self.config.block_size
+            k_blocks = (rows + block_size - 1) // block_size
+
             assert qbits == 4, "QDQ format only support 4 bits quantization"
             packed = np.zeros((rows * cols + 1) // 2, dtype="uint8")
             zero_point = np.zeros((cols * k_blocks + 1) // 2, dtype="uint8")
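To make the arithmetic concrete, a small worked example (sizes are illustrative): with channel-wise quantization the block spans all of K, so `k_blocks` collapses to 1 and each column carries exactly one scale.

```python
rows = 4096                                          # K dimension of the weight
block_size = rows                                    # channel-wise: one block covers all of K
assert (rows + block_size - 1) // block_size == 1    # k_blocks == 1, one scale per column

# Default block-wise path for comparison, e.g. block_size = 128:
assert (rows + 128 - 1) // 128 == 32                 # 32 blocks per column
```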
@@ -812,6 +840,16 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]:
             )
             scales_tensor = onnx.numpy_helper.from_array(scales, b_tensor.name + "_DQ_scales")

+            # If QDQ, channel-wise, and symmetric quantization are all enabled, optimize for Intel NPU:
+            # transposing the weight to NHWC layout improves performance.
+            qdq_opt_for_intel_npu_enabled = self.config.quant_format == QuantFormat.QDQ \
+                and self.config.channel_wised_quantize and self.config.is_symmetric
+            if qdq_opt_for_intel_npu_enabled:
+                rows, cols = b_ndarray.shape
+                packed = transpose_packed_int4_matrix(packed, rows, cols)
+                scales = scales.reshape((cols, 1))  # one scale per output channel
+                b_quant = onnx.helper.make_tensor(b_tensor.name + f"_DQ_Q{bits}", qtype, [cols, rows], packed.tobytes(), True)
+                scales_tensor = onnx.numpy_helper.from_array(scales, b_tensor.name + "_DQ_scales")
+
             for input in b_graph.input:
                 if input.name == input_b:
                     b_graph.input.remove(input)
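A sketch of the shape bookkeeping this block performs (toy sizes; tensor names illustrative): the packed byte buffer keeps its length, while the declared int4 shape flips from `[rows, cols]` to `[cols, rows]` and the scales become a `(cols, 1)` column vector.

```python
import numpy as np

rows, cols = 8, 4                                    # toy K x N weight
packed = np.zeros((rows * cols + 1) // 2, dtype=np.uint8)
scales = np.ones(cols, dtype=np.float32)

packed_t = transpose_packed_int4_matrix(packed, rows, cols)
assert packed_t.shape == packed.shape                # same byte count, transposed layout
assert scales.reshape((cols, 1)).shape == (cols, 1)  # one scale per output channel
```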
@@ -849,15 +887,21 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]:
         else:
             dq_input_names = [b_quant.name, scales_tensor.name]
             dq_output_names = [b_quant.name + "_output"]
-            matmul_input_names = [node.input[0], dq_output_names[0]]
+            tp_input_names = [dq_output_names[0]]
+            tp_output_names = [dq_output_names[0] + "_transposed"]
+            matmul_input_names = [node.input[0], tp_output_names[0] if qdq_opt_for_intel_npu_enabled else dq_output_names[0]]
             matmul_output_names = [node.output[0]]
             if not self.config.is_symmetric:
                 zp_tensor = onnx.helper.make_tensor(
                     b_tensor.name + "_DQ_zero_points", qtype, scales.shape, zero_points.tobytes(), True
                 )
                 dq_input_names.append(zp_tensor.name)
                 b_graph.initializer.extend([zp_tensor])
-            dq_kwargs = {"axis": 0, "block_size": self.config.block_size}
+            rows, cols = b_ndarray.shape
+            dq_kwargs = {
+                "axis": 1 if qdq_opt_for_intel_npu_enabled else 0,
+                "block_size": rows if self.config.channel_wised_quantize else self.config.block_size,
+            }
             dq_node = onnx.helper.make_node(
                 "DequantizeLinear",
                 inputs=dq_input_names,
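For intuition, here is what DequantizeLinear computes in this configuration, sketched in NumPy with toy values: with `axis=1` and `block_size` equal to `rows`, every block spans a whole row of the transposed `[cols, rows]` weight, so each output channel is scaled by a single value.

```python
import numpy as np

q = np.array([[1, -2, 3], [4, 5, -6]], dtype=np.int8)   # toy transposed weight, [cols=2, rows=3]
scales = np.array([[0.5], [0.25]], dtype=np.float32)    # shape (cols, 1): one scale per channel

# Symmetric case (no zero point): dequantized = q * scales, broadcast along axis 1.
deq = q.astype(np.float32) * scales
print(deq)   # [[ 0.5  -1.    1.5 ], [ 1.    1.25 -1.5 ]]
```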
@@ -871,7 +915,16 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]:
                 outputs=matmul_output_names,
                 name=node.name + f"_matmul_Q{bits}" if node.name else "",
             )
-            output_nodes.extend([dq_node, matmul_node])
+            if qdq_opt_for_intel_npu_enabled:
+                tp_node = onnx.helper.make_node(
+                    "Transpose",
+                    inputs=tp_input_names,
+                    outputs=tp_output_names,
+                    perm=[1, 0],
+                )
+                output_nodes.extend([dq_node, tp_node, matmul_node])
+            else:
+                output_nodes.extend([dq_node, matmul_node])

         return output_nodes
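The resulting QDQ subgraph for the NPU path is DequantizeLinear → Transpose → MatMul; the Transpose with `perm=[1, 0]` restores the `[rows, cols]` orientation the MatMul expects. A minimal sketch with placeholder tensor names:

```python
import onnx

dq = onnx.helper.make_node(
    "DequantizeLinear", inputs=["B_q", "B_scales"], outputs=["B_dq"],
    axis=1, block_size=4096,   # block_size == rows in the channel-wise case
)
tp = onnx.helper.make_node("Transpose", inputs=["B_dq"], outputs=["B_dq_transposed"], perm=[1, 0])
mm = onnx.helper.make_node("MatMul", inputs=["A", "B_dq_transposed"], outputs=["Y"])
print([n.op_type for n in (dq, tp, mm)])   # ['DequantizeLinear', 'Transpose', 'MatMul']
```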
@@ -1136,6 +1189,7 @@ def __init__(
         quant_format=QuantFormat.QOperator,
         op_types_to_quantize: tuple[str, ...] | None = None,
         quant_axes: tuple[tuple[str, int], ...] | None = None,
+        channel_wised_quantize: bool = False,
         algo_config: WeightOnlyQuantConfig | None = None,
     ):
         if nodes_to_exclude is None:
if nodes_to_exclude is None :
@@ -1158,6 +1212,7 @@ def __init__(
1158
1212
op_types_to_quantize = op_types_to_quantize ,
1159
1213
quant_axes = quant_axes ,
1160
1214
bits = 4 , # default to 4 bits
1215
+ channel_wised_quantize = channel_wised_quantize ,
1161
1216
)
1162
1217
1163
1218
self .algo_config = algo_config
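Putting it together, a hedged end-to-end usage sketch (`"model.onnx"` is a placeholder; `block_size` and `is_symmetric` are assumed to be the constructor's pre-existing parameters): channel-wise 4-bit QDQ quantization via the new flag, which also triggers the Intel NPU transpose optimization because the weights are symmetric.

```python
import onnx
from onnxruntime.quantization import QuantFormat
from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer

model = onnx.load("model.onnx")                 # placeholder path
quant = MatMul4BitsQuantizer(
    model,
    block_size=128,                             # superseded per column once channel-wise is on
    is_symmetric=True,                          # required for the Intel NPU transpose path
    quant_format=QuantFormat.QDQ,               # channel-wise currently requires QDQ
    channel_wised_quantize=True,
)
quant.process()
quant.model.save_model_to_file("model_int4.onnx")
```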