File tree — 1 file changed: +2 −2 lines changed
neural_compressor/adaptor/torch_utils — 1 file changed: +2 −2 lines changed

@@ -403,7 +403,7 @@ def rtn_quantize(
403 403         model: fake quantized torch module
404 404     """
405 405     assert isinstance(model, torch.nn.Module), "only support torch module"
406     -   supported_layers = ["Linear"]
    406 +   supported_layers = (torch.nn.Linear,)
407 407     if return_int:
408 408         compression_dtype = kwargs.get("compression_dtype", torch.int32)
409 409         compression_dim = kwargs.get("compression_dim", 1)
@@ -412,7 +412,7 @@ def rtn_quantize(
412 412         use_optimum_format = kwargs.get("use_optimum_format", True)
413 413     with torch.no_grad():
414 414         for name, m in model.named_modules():
415     -           if m.__class__.__name__ not in supported_layers:
    415 +           if not isinstance(m, supported_layers):
416 416                 continue
417 417             orig_dtype = next(m.parameters()).dtype
418 418             if orig_dtype != torch.float:
You can’t perform that action at this time.
0 commit comments