Skip to content

Commit

Permalink
fix const quantization test tpc
Browse files Browse the repository at this point in the history
  • Loading branch information
ofirgo committed Jan 5, 2025
1 parent 620abd1 commit 7dcedd2
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,13 @@

from tests.common_tests.helpers.generate_test_tp_model import generate_custom_test_tp_model
from tests.common_tests.helpers.tpcs_for_tests.v3.tp_model import get_tp_model as get_tp_v3
from tests.common_tests.helpers.tpcs_for_tests.v4.tp_model import get_tp_model as get_tp_v4, get_tp_model
from tests.common_tests.helpers.tpcs_for_tests.v4.tp_model import get_tp_model as get_tp_v4
from tests.common_tests.helpers.tpcs_for_tests.v4.tp_model import generate_tp_model, get_op_quantization_configs
from tests.keras_tests.feature_networks_tests.base_keras_feature_test import BaseKerasFeatureNetworkTest
from tests.common_tests.helpers.tensors_compare import cosine_similarity
from tests.keras_tests.utils import get_layers_from_model_by_type
from mct_quantizers import KerasQuantizationWrapper

from model_compression_toolkit.constants import TENSORFLOW
from model_compression_toolkit.target_platform_capabilities.constants import IMX500_TP_MODEL

keras = tf.keras
layers = keras.layers
Expand All @@ -56,18 +54,19 @@ def create_const_quant_tpc(qmethod):
const_merge_configuration_options = schema.QuantizationConfigOptions(quantization_configurations=tuple([const_merge_config]))

operator_sets_dict = {}
operator_sets_dict["Add"] = const_configuration_options
operator_sets_dict["Sub"] = const_configuration_options
operator_sets_dict["Mul"] = const_configuration_options
operator_sets_dict["Div"] = const_configuration_options
operator_sets_dict["MergeOps"] = const_merge_configuration_options
operator_sets_dict[schema.OperatorSetNames.OPSET_ADD.value] = const_configuration_options
operator_sets_dict[schema.OperatorSetNames.OPSET_SUB.value] = const_configuration_options
operator_sets_dict[schema.OperatorSetNames.OPSET_MUL.value] = const_configuration_options
operator_sets_dict[schema.OperatorSetNames.OPSET_DIV.value] = const_configuration_options
operator_sets_dict[schema.OperatorSetNames.OPSET_STACK.value] = const_merge_configuration_options
operator_sets_dict[schema.OperatorSetNames.OPSET_CONCATENATE.value] = const_merge_configuration_options

tp_model = generate_custom_test_tp_model(name=name,
base_cfg=base_cfg,
base_tp_model=base_tp_model,
operator_sets_dict=operator_sets_dict)

return generate_keras_tpc(name="const_quant_tpc", tp_model=tp_model)
return tp_model


class ConstQuantizationTest(BaseKerasFeatureNetworkTest):
Expand All @@ -93,7 +92,7 @@ def get_quantization_config(self):
return mct.core.QuantizationConfig(weights_error_method=self.error_method)

def get_tpc(self):
return get_tp_model()
return create_const_quant_tpc(self.qmethod)

def create_networks(self):
inputs = layers.Input(shape=self.get_input_shapes()[0][1:])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@
from tests.common_tests.helpers.tensors_compare import cosine_similarity
from tests.pytorch_tests.utils import get_layers_from_model_by_type
from tests.common_tests.helpers.generate_test_tp_model import generate_test_attr_configs, DEFAULT_WEIGHT_ATTR_CONFIG
from model_compression_toolkit.target_platform_capabilities.constants import IMX500_TP_MODEL
from model_compression_toolkit.constants import PYTORCH
from mct_quantizers import PytorchQuantizationWrapper

tp = mct.target_platform
Expand Down

0 comments on commit 7dcedd2

Please sign in to comment.