
Commit 0791776

Fix RTN supported layer checking condition (#1705)
Signed-off-by: Kaihui-intel <[email protected]>
1 parent 14868c0 commit 0791776

File tree

1 file changed (+2, -2)

neural_compressor/adaptor/torch_utils/weight_only.py

Lines changed: 2 additions & 2 deletions
@@ -403,7 +403,7 @@ def rtn_quantize(
         model: fake quantized torch module
     """
     assert isinstance(model, torch.nn.Module), "only support torch module"
-    supported_layers = ["Linear"]
+    supported_layers = (torch.nn.Linear,)
     if return_int:
         compression_dtype = kwargs.get("compression_dtype", torch.int32)
         compression_dim = kwargs.get("compression_dim", 1)
@@ -412,7 +412,7 @@ def rtn_quantize(
         use_optimum_format = kwargs.get("use_optimum_format", True)
     with torch.no_grad():
         for name, m in model.named_modules():
-            if m.__class__.__name__ not in supported_layers:
+            if not isinstance(m, supported_layers):
                 continue
             orig_dtype = next(m.parameters()).dtype
             if orig_dtype != torch.float:
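
For context, the patch replaces a class-name string comparison with an isinstance check against a tuple of module types, so subclasses of torch.nn.Linear are also picked up while unrelated modules that merely share the name "Linear" are skipped. Below is a minimal, hypothetical sketch of that difference; MyLinear and the custom Linear class are illustrative only, not code from the repository.

# Minimal sketch (illustrative classes, not from neural_compressor) comparing
# the old name-based check with the new isinstance-based check.
import torch

supported_layers = (torch.nn.Linear,)  # as in the patched rtn_quantize


class MyLinear(torch.nn.Linear):
    """Hypothetical subclass of torch.nn.Linear."""


class Linear(torch.nn.Module):
    """Hypothetical unrelated module that happens to be named 'Linear'."""

    def forward(self, x):
        return x


model = torch.nn.Sequential(MyLinear(4, 4), Linear(), torch.nn.ReLU())

for name, m in model.named_modules():
    by_name = m.__class__.__name__ in ["Linear"]  # old check
    by_type = isinstance(m, supported_layers)     # new check
    print(f"{name or 'root'}: by_name={by_name}, by_type={by_type}")

# Expected:
#   '0' (MyLinear):      by_name=False (old check would skip it), by_type=True
#   '1' (custom Linear): by_name=True (old check would wrongly match), by_type=False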
