diff --git a/tests_pytest/base_test_classes/base_tpc_attach2fw_test.py b/tests_pytest/base_test_classes/base_tpc_attach2fw_test.py
index 215ab2024..5345324e2 100644
--- a/tests_pytest/base_test_classes/base_tpc_attach2fw_test.py
+++ b/tests_pytest/base_test_classes/base_tpc_attach2fw_test.py
@@ -52,7 +52,7 @@ def setup_method(self):
     def test_attach2fw_init(self):
         # verify built-in opset to operator mapping structure
         assert len(self.attach2fw._opset2layer) == 57  # number of built-in operator sets
-        assert all(opset in self.attach2fw._opset2layer for opset in schema.OperatorSetNames.get_values())
+        assert all(opset in self.attach2fw._opset2layer for opset in list(schema.OperatorSetNames))
         assert all(isinstance(key, schema.OperatorSetNames) for key in self.attach2fw._opset2layer.keys())
         assert all(isinstance(value, list) for value in self.attach2fw._opset2layer.values())
 
@@ -168,7 +168,7 @@ def test_not_existing_opset_with_layers_to_attach(self):
         with pytest.raises(Exception, match=f'{opset_name} is defined in TargetPlatformCapabilities'):
             _ = self.attach2fw.attach(tpc)
 
-    def _test_attach2fw_attach_with_custom_opset(self, operators_list, filter_op, kernel_attr_name):
+    def _test_attach2fw_attach_with_custom_opset(self, operators_list, filter_op, fw_attr_name):
         test_bit = 42
         opset_name = "Custom"
         attr_name = "CustomAttr"
@@ -183,9 +183,8 @@ def _test_attach2fw_attach_with_custom_opset(self, operators_list, filter_op, ke
                                             quantization_configurations=(default_op_cfg,)),
                                         operator_set=(operator_set,))
 
-        with pytest.raises(Exception) as e_info:
+        with pytest.raises(Exception, match=f'{opset_name} is defined in TargetPlatformCapabilities'):
             _ = self.attach2fw.attach(tpc)
-        assert f'{opset_name} is defined in TargetPlatformCapabilities' in str(e_info)
 
         # Setting a layers mapping for the custom opset with a regular operator and a filter.
         # We also test the option of passing an attributes mapping for the operator to set a specific attribute config.
@@ -193,24 +192,25 @@ def _test_attach2fw_attach_with_custom_opset(self, operators_list, filter_op, ke
         assert len(operators_list) == 1
         non_filter_op = operators_list[0]
+        layers = [non_filter_op, filter_op]
         fw_quant_capabilities = self.attach2fw.attach(
             tpc,
-            custom_opset2layer={opset_name: CustomOpsetLayers(operators=[non_filter_op, filter_op],
+            custom_opset2layer={opset_name: CustomOpsetLayers(operators=layers,
                                                               attr_mapping={attr_name: DefaultDict(
                                                                   {filter_op: fw_custom_attr_name},
-                                                                  default_value=kernel_attr_name)})
+                                                                  default_value=fw_attr_name)})
                                 })
 
         assert isinstance(fw_quant_capabilities, FrameworkQuantizationCapabilities)
 
         opset_to_layers = fw_quant_capabilities.op_sets_to_layers.op_sets_to_layers
         assert len(opset_to_layers) == 1
         assert opset_to_layers[0].name == opset_name
-        assert len(opset_to_layers[0].layers) == 2
+        assert opset_to_layers[0].layers == layers
 
         op_cfg = fw_quant_capabilities.layer2qco[non_filter_op].base_config
         assert op_cfg.activation_n_bits == test_bit
-        assert kernel_attr_name in op_cfg.attr_weights_configs_mapping
-        assert op_cfg.attr_weights_configs_mapping[kernel_attr_name].weights_n_bits == test_bit
+        assert fw_attr_name in op_cfg.attr_weights_configs_mapping
+        assert op_cfg.attr_weights_configs_mapping[fw_attr_name].weights_n_bits == test_bit
 
         op_cfg = fw_quant_capabilities.filterlayer2qco[filter_op].base_config
         assert op_cfg.activation_n_bits == test_bit
diff --git a/tests_pytest/keras/target_platform_capabilities/targetplatform2framework/test_attach2keras.py b/tests_pytest/keras/target_platform_capabilities/targetplatform2framework/test_attach2keras.py
index 54121d23f..934453da4 100644
--- a/tests_pytest/keras/target_platform_capabilities/targetplatform2framework/test_attach2keras.py
+++ b/tests_pytest/keras/target_platform_capabilities/targetplatform2framework/test_attach2keras.py
@@ -23,7 +23,6 @@
 else:
     from keras.layers import Conv2D, Identity
 
-from model_compression_toolkit.target_platform_capabilities.constants import KERAS_KERNEL
 from model_compression_toolkit.target_platform_capabilities import LayerFilterParams
 from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \
     AttachTpcToKeras
@@ -46,12 +45,12 @@ def test_attach2fw_attach_to_default_config(self):
         super().test_attach2fw_attach_to_default_config()
 
     def test_not_existing_opset_with_layers_to_attach(self):
-        super().test_attach2fw_attach_to_default_config()
+        super().test_not_existing_opset_with_layers_to_attach()
 
     def test_attach2pytorch_attach_with_custom_opset(self):
         self._test_attach2fw_attach_with_custom_opset([Identity],
                                                       LayerFilterParams(Conv2D, stride=2),
-                                                      KERAS_KERNEL)
+                                                      "CustomAttr")
 
     def test_attach2pytorch_prioritize_custom_opset(self):
         self._test_attach2fw_prioritize_custom_opset(Conv2D)
diff --git a/tests_pytest/pytorch/target_platform_capabilities/targetplatform2framework/test_attach2pytorch.py b/tests_pytest/pytorch/target_platform_capabilities/targetplatform2framework/test_attach2pytorch.py
index 2eaa4e87f..12bce2956 100644
--- a/tests_pytest/pytorch/target_platform_capabilities/targetplatform2framework/test_attach2pytorch.py
+++ b/tests_pytest/pytorch/target_platform_capabilities/targetplatform2framework/test_attach2pytorch.py
@@ -15,7 +15,6 @@
 
 import torch.nn
 
-from model_compression_toolkit.target_platform_capabilities.constants import PYTORCH_KERNEL
 from model_compression_toolkit.target_platform_capabilities import LayerFilterParams
 from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
     AttachTpcToPytorch
@@ -39,12 +38,12 @@ def test_attach2fw_attach_to_default_config(self):
         super().test_attach2fw_attach_to_default_config()
 
     def test_not_existing_opset_with_layers_to_attach(self):
-        super().test_attach2fw_attach_to_default_config()
+        super().test_not_existing_opset_with_layers_to_attach()
 
     def test_attach2pytorch_attach_with_custom_opset(self):
         self._test_attach2fw_attach_with_custom_opset([torch.nn.Identity],
                                                       LayerFilterParams(torch.nn.Conv2d, stride=2),
-                                                      PYTORCH_KERNEL)
+                                                      "CustomAttr")
 
     def test_attach2pytorch_prioritize_custom_opset(self):
         self._test_attach2fw_prioritize_custom_opset(torch.nn.Conv2d)
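For context, the two pytest idioms this patch converges on (iterating the Enum class directly rather than through a get_values() helper, and pytest.raises(..., match=...) rather than capturing ExceptionInfo and asserting on its string) are shown in isolation below. This is a minimal standalone sketch with no dependency on the MCT code base; OperatorSetNames here is a hypothetical stand-in for schema.OperatorSetNames.

    import enum

    import pytest


    class OperatorSetNames(enum.Enum):
        # Hypothetical stand-in for schema.OperatorSetNames.
        CONV = "Conv"
        FULLY_CONNECTED = "FullyConnected"


    def test_enum_iteration():
        # Iterating an Enum class yields its members directly, so no
        # get_values() helper is needed for membership checks.
        assert list(OperatorSetNames) == [OperatorSetNames.CONV,
                                          OperatorSetNames.FULLY_CONNECTED]


    def test_raises_with_match():
        # match= applies re.search to the stringified exception. The replaced
        # two-step pattern asserted on str(e_info), which stringifies the
        # ExceptionInfo wrapper rather than the exception message itself
        # (that would be str(e_info.value)), so match= is the safer idiom.
        with pytest.raises(Exception, match='is defined in TargetPlatformCapabilities'):
            raise Exception('Custom is defined in TargetPlatformCapabilities')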