Fix TPC.v4 quantization preserving for weights to be per tensor, and add torch.Tensor.expand.
elad-c committed Sep 25, 2024
1 parent c90c5f2 commit 16a47bf
Showing 5 changed files with 88 additions and 6 deletions.
```diff
@@ -207,7 +207,9 @@ def generate_tp_model(default_config: OpQuantizationConfig,
         base_config=const_config_input16_per_tensor)
 
     qpreserving_const_config = const_config.clone_and_edit(enable_activation_quantization=False,
-                                                           quantization_preserving=True)
+                                                           quantization_preserving=True,
+                                                           default_weight_attr_config=const_config.default_weight_attr_config.clone_and_edit(
+                                                               weights_per_channel_threshold=False))
     qpreserving_const_config_options = tp.QuantizationConfigOptions([qpreserving_const_config])
 
     # Create a TargetPlatformModel and set its default quantization config.
```
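In this hunk, the quantization-preserving constant config clones the default weight attribute config with `weights_per_channel_threshold=False`, i.e. a single per-tensor threshold instead of one threshold per output channel. A minimal sketch of the distinction, assuming power-of-two thresholds as used elsewhere in this TPC (illustrative only, not MCT's internal code):

```python
import numpy as np

def pow2_threshold(t: np.ndarray) -> float:
    # Smallest power-of-two threshold covering the max absolute value.
    return float(2.0 ** np.ceil(np.log2(np.max(np.abs(t)) + 1e-12)))

w = np.random.randn(8, 3, 3, 3).astype(np.float32)  # e.g. a constant weight

# weights_per_channel_threshold=True: one threshold per output channel.
per_channel = [pow2_threshold(w[c]) for c in range(w.shape[0])]

# weights_per_channel_threshold=False (this commit): one threshold for the
# whole tensor, independent of the tensor's layout.
per_tensor = pow2_threshold(w)
```

The likely rationale: a quantization-preserving op such as expand or gather only rearranges or broadcasts its input, so a single layout-independent grid passes through unchanged, while per-channel thresholds would be scrambled by the dimension manipulation.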
```diff
@@ -19,7 +19,7 @@
 from torch import add, sub, mul, div, divide, flatten, reshape, split, unsqueeze, dropout, sigmoid, tanh, \
     chunk, unbind, topk, gather, equal, transpose, permute, argmax, squeeze, multiply, subtract
 from torch.nn import Conv2d, Linear, ConvTranspose2d, MaxPool2d
-from torch.nn import Dropout, Flatten, Hardtanh, Identity
+from torch.nn import Dropout, Flatten, Hardtanh
 from torch.nn import ReLU, ReLU6, PReLU, SiLU, Sigmoid, Tanh, Hardswish, LeakyReLU
 from torch.nn.functional import relu, relu6, prelu, silu, hardtanh, hardswish, leaky_relu
 
@@ -87,7 +87,7 @@ def generate_pytorch_tpc(name: str, tp_model: tp.TargetPlatformModel):
                                                       squeeze,
                                                       permute,
                                                       transpose])
-    tp.OperationsSetToLayers(OPSET_DIMENSION_MANIPULATION_OPS_WITH_WEIGHTS, [gather])
+    tp.OperationsSetToLayers(OPSET_DIMENSION_MANIPULATION_OPS_WITH_WEIGHTS, [gather, torch.Tensor.expand])
     tp.OperationsSetToLayers(OPSET_MERGE_OPS,
                              [torch.stack, torch.cat, torch.concat, torch.concatenate])
```
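`torch.Tensor.expand` joins `gather` in `OPSET_DIMENSION_MANIPULATION_OPS_WITH_WEIGHTS` because it broadcasts a tensor to a larger shape without copying or altering values. A quick standalone check of that behavior:

```python
import torch

t = torch.arange(4.).reshape(1, 4)   # shape (1, 4)
e = t.expand(3, -1)                  # shape (3, 4); -1 keeps that dim's size

# expand() returns a broadcast view: no data is copied and no values change,
# which is why the op can preserve its input's quantization parameters.
assert e.shape == (3, 4)
assert e.data_ptr() == t.data_ptr()
```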
```diff
@@ -17,11 +17,13 @@
 import torch.nn as nn
 import numpy as np
 import model_compression_toolkit as mct
+from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import Signedness
 from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
 from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy, set_model
 from tests.pytorch_tests.model_tests.base_pytorch_feature_test import BasePytorchFeatureNetworkTest
 from tests.common_tests.helpers.tensors_compare import cosine_similarity
 from tests.pytorch_tests.utils import get_layers_from_model_by_type
+from tests.common_tests.helpers.generate_test_tp_model import generate_test_attr_configs, DEFAULT_WEIGHT_ATTR_CONFIG
 from model_compression_toolkit.target_platform_capabilities.constants import IMX500_TP_MODEL
 from model_compression_toolkit.constants import PYTORCH
 from mct_quantizers import PytorchQuantizationWrapper
 
@@ -196,3 +198,78 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
         for qlayer in get_layers_from_model_by_type(quantized_model, op):
             self.unit_test.assertTrue(isinstance(qlayer, PytorchQuantizationWrapper),
                                       msg=f"{op} should be quantized.")
+
+
+class ExpandConstQuantizationNet(nn.Module):
+    def __init__(self, batch_size):
+        super().__init__()
+        self.register_buffer('cat_const', to_torch_tensor(np.random.randint(-128, 127, size=(batch_size, 3, 32, 32)).astype(np.float32)))
+        self.register_parameter('expand_const',
+                                nn.Parameter(to_torch_tensor(np.random.randint(-128, 127, size=(1, 2, 32, 1)).astype(np.float32)),
+                                             requires_grad=False))
+
+    def forward(self, x):
+        expanded_const = self.expand_const.expand(x.shape[0], -1, -1, 32)
+        x = torch.cat([expanded_const, self.cat_const, x], dim=1)
+        return x
+
+
+class ConstQuantizationExpandTest(BasePytorchFeatureNetworkTest):
+
+    def __init__(self, unit_test):
+        super().__init__(unit_test=unit_test, input_shape=(16, 32, 32), val_batch_size=5)
+
+    def generate_inputs(self):
+        return [np.random.randint(-128, 127, size=in_shape).astype(np.float32) for in_shape in self.get_input_shapes()]
+
+    def get_tpc(self):
+        tp = mct.target_platform
+        attr_cfg = generate_test_attr_configs()
+        base_cfg = tp.OpQuantizationConfig(activation_quantization_method=tp.QuantizationMethod.POWER_OF_TWO,
+                                           enable_activation_quantization=True,
+                                           activation_n_bits=32,
+                                           supported_input_activation_n_bits=32,
+                                           default_weight_attr_config=attr_cfg[DEFAULT_WEIGHT_ATTR_CONFIG],
+                                           attr_weights_configs_mapping={},
+                                           quantization_preserving=False,
+                                           fixed_scale=1.0,
+                                           fixed_zero_point=0,
+                                           simd_size=32,
+                                           signedness=Signedness.AUTO)
+
+        default_configuration_options = tp.QuantizationConfigOptions([base_cfg])
+
+        const_config = base_cfg.clone_and_edit(enable_activation_quantization=False,
+                                               default_weight_attr_config=base_cfg.default_weight_attr_config.clone_and_edit(
+                                                   enable_weights_quantization=True, weights_per_channel_threshold=False,
+                                                   weights_quantization_method=tp.QuantizationMethod.POWER_OF_TWO))
+        const_configuration_options = tp.QuantizationConfigOptions([const_config])
+
+        tp_model = tp.TargetPlatformModel(default_configuration_options)
+        with tp_model:
+            tp.OperatorsSet("WeightQuant", const_configuration_options)
+
+        tpc = tp.TargetPlatformCapabilities(tp_model)
+        with tpc:
+            # Map the WeightQuant operator set to the ops whose constant inputs should be quantized.
+            tp.OperationsSetToLayers("WeightQuant", [torch.Tensor.expand, torch.cat])
+
+        return tpc
+
+    def create_networks(self):
+        return ExpandConstQuantizationNet(self.val_batch_size)
+
+    def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
+        in_torch_tensor = to_torch_tensor(input_x[0])
+        set_model(float_model)
+        y = float_model(in_torch_tensor)
+        y_hat = quantized_model(in_torch_tensor)
+        self.unit_test.assertTrue(y.shape == y_hat.shape, msg='out shape is not as expected!')
+        cs = cosine_similarity(torch_tensor_to_numpy(y), torch_tensor_to_numpy(y_hat))
+        self.unit_test.assertTrue(np.isclose(cs, 1, atol=1e-3), msg=f'fail cosine similarity check: {cs}')
+
+        # check quantization layers:
+        for op in [torch.cat, torch.Tensor.expand]:
+            for qlayer in get_layers_from_model_by_type(quantized_model, op):
+                self.unit_test.assertTrue(isinstance(qlayer, PytorchQuantizationWrapper),
+                                          msg=f"{op} should be quantized.")
```
```diff
@@ -103,7 +103,7 @@
     ConstRepresentationCodeTest
 from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
 from tests.pytorch_tests.model_tests.feature_models.const_quantization_test import ConstQuantizationTest, \
-    AdvancedConstQuantizationTest, ConstQuantizationMultiInputTest
+    AdvancedConstQuantizationTest, ConstQuantizationMultiInputTest, ConstQuantizationExpandTest
 from tests.pytorch_tests.model_tests.feature_models.remove_identity_test import RemoveIdentityTest
 from tests.pytorch_tests.model_tests.feature_models.activation_16bit_test import Activation16BitTest, \
     Activation16BitMixedPrecisionTest
 
@@ -264,6 +264,7 @@ def test_const_quantization(self):
 
         AdvancedConstQuantizationTest(self).run_test()
         ConstQuantizationMultiInputTest(self).run_test()
+        ConstQuantizationExpandTest(self).run_test()
 
     def test_const_representation(self):
         for const_dtype in [np.float32, np.int64, np.int32]:
```
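For orientation, this is roughly the flow that `run_test()` drives for the new test. The wiring below is an assumption (the real harness lives in `BasePytorchFeatureNetworkTest`, and `unit_test` stands in for the unittest case); `mct.ptq.pytorch_post_training_quantization` is MCT's public PTQ entry point:

```python
import model_compression_toolkit as mct

# Assumed harness flow, sketched for illustration only.
test = ConstQuantizationExpandTest(unit_test)   # `unit_test`: a unittest.TestCase
float_model = test.create_networks()

def representative_data_gen():
    yield test.generate_inputs()                # list with one (5, 16, 32, 32) array

quantized_model, quant_info = mct.ptq.pytorch_post_training_quantization(
    float_model,
    representative_data_gen,
    target_platform_capabilities=test.get_tpc())

test.compare(quantized_model, float_model,
             input_x=test.generate_inputs(), quantization_info=quant_info)
```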
tests/pytorch_tests/utils.py (4 additions & 2 deletions)
```diff
@@ -38,9 +38,11 @@ def get_layers_from_model_by_type(model: torch.nn.Module,
     Returns:
         List of layers of type layer_type from the model.
     """
+    match_layer_type = lambda _layer: layer_type in [type(_layer), _layer]
     if include_wrapped_layers:
-        return [layer[1] for layer in model.named_children() if type(layer[1])==layer_type or (isinstance(layer[1], PytorchQuantizationWrapper) and type(layer[1].layer)==layer_type)]
-    return [layer[1] for layer in model.named_children() if type(layer[1])==layer_type]
+        return [layer[1] for layer in model.named_children() if match_layer_type(layer[1]) or
+                (isinstance(layer[1], PytorchQuantizationWrapper) and match_layer_type(layer[1].layer))]
+    return [layer[1] for layer in model.named_children() if match_layer_type(layer[1])]
 
 
 def count_model_prunable_params(model: torch.nn.Module) -> int:
```
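The new `match_layer_type` predicate also matches when the "layer" recorded in the quantized model is a function object (e.g. `torch.cat` or `torch.Tensor.expand`) rather than an `nn.Module` subclass, which the old `type(layer) == layer_type` check could not do. A standalone illustration, with `layer_type` made an explicit parameter:

```python
import torch
import torch.nn as nn

def match_layer_type(_layer, layer_type):
    # Same predicate as the lambda above: match by class or by identity.
    return layer_type in [type(_layer), _layer]

assert match_layer_type(nn.Conv2d(3, 8, 3), nn.Conv2d)   # match by class
assert match_layer_type(torch.cat, torch.cat)            # match by identity (functional op)
assert not match_layer_type(nn.ReLU(), nn.Conv2d)
```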
