Skip to content

Commit

Permalink
tests fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ofirgo committed Jan 15, 2025
1 parent a77e905 commit 565be9d
Showing 1 changed file with 25 additions and 27 deletions.
52 changes: 25 additions & 27 deletions tests_pytest/base_test_classes/base_tpc_attach2fw_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import abc
from typing import NamedTuple

import pytest
Expand Down Expand Up @@ -39,7 +40,7 @@
OpSet = NamedTuple("OpSet", [('op_name', str), ('op_list', list)])


class BaseTpcAttach2FrameworkTest:
class BaseTpcAttach2FrameworkTest(abc.ABC):

attach2fw_class = None

Expand All @@ -50,7 +51,8 @@ def setup_method(self):

def test_attach2fw_init(self):
# verify built-in opset to operator mapping structure
assert all([opset in self.attach2fw._opset2layer for opset in schema.OperatorSetNames.get_values()])
assert len(self.attach2fw._opset2layer) == 57 # number of built-in operator sets
assert all(opset in self.attach2fw._opset2layer for opset in schema.OperatorSetNames.get_values())
assert all(isinstance(key, schema.OperatorSetNames) for key in self.attach2fw._opset2layer.keys())
assert all(isinstance(value, list) for value in self.attach2fw._opset2layer.values())

Expand Down Expand Up @@ -82,14 +84,12 @@ def test_attach2fw_attach_without_attributes(self):

all_mapped_ops = fw_quant_capabilities.layer2qco.copy()
all_mapped_ops.update(fw_quant_capabilities.filterlayer2qco)
if len(opset.op_list) == 0:
assert len(all_mapped_ops) == 0
else:
assert len(all_mapped_ops) == len(opset.op_list)
assert len(all_mapped_ops) == len(opset.op_list)

if len(opset.op_list) > 0:
for qco in all_mapped_ops.values():
assert len(qco.quantization_configurations) == 1
assert qco.base_config.activation_n_bits == 42
assert qco.base_config == tested_op_cfg


def test_attach2fw_attach_linear_op_with_attributes(self):
Expand Down Expand Up @@ -128,10 +128,8 @@ def test_attach2fw_attach_linear_op_with_attributes(self):

all_mapped_ops = fw_quant_capabilities.layer2qco.copy()
all_mapped_ops.update(fw_quant_capabilities.filterlayer2qco)
if len(opset.op_list) == 0:
assert len(all_mapped_ops) == 0
else:
assert len(all_mapped_ops) == len(opset.op_list)
assert len(all_mapped_ops) == len(opset.op_list)
if len(opset.op_list) > 0:

for qco in all_mapped_ops.values():
assert len(qco.quantization_configurations) == 1
Expand All @@ -140,7 +138,7 @@ def test_attach2fw_attach_linear_op_with_attributes(self):
for attr_name, fw_layer2attr_mapping in fw_linear_attr_names.items():
assert isinstance(fw_layer2attr_mapping, DefaultDict)
layer_attr_mapping = fw_layer2attr_mapping.get(opset.op_list[0])
assert qco.base_config.attr_weights_configs_mapping.get(layer_attr_mapping) == tested_attr_cfg
assert qco.base_config.attr_weights_configs_mapping[layer_attr_mapping] == tested_attr_cfg


def test_attach2fw_attach_to_default_config(self):
Expand All @@ -149,15 +147,14 @@ def test_attach2fw_attach_to_default_config(self):
operator_set = schema.OperatorsSet(name=opset_name)

tpc = schema.TargetPlatformCapabilities(default_qco=default_qc_options,
operator_set=tuple([operator_set]))
operator_set=(operator_set,))

fw_quant_capabilities = self.attach2fw.attach(tpc)

assert isinstance(fw_quant_capabilities, FrameworkQuantizationCapabilities)
opset2layer = fw_quant_capabilities.op_sets_to_layers.get_layers_by_op(operator_set)
assert len(opset2layer) > 0
opset_cfg = fw_quant_capabilities.layer2qco.get(opset2layer[0])
assert opset_cfg is not None
opset_cfg = fw_quant_capabilities.layer2qco[opset2layer[0]]
assert opset_cfg == default_qc_options

def test_not_existing_opset_with_layers_to_attach(self):
Expand All @@ -166,40 +163,41 @@ def test_not_existing_opset_with_layers_to_attach(self):

tpc = schema.TargetPlatformCapabilities(default_qco=schema.QuantizationConfigOptions(
quantization_configurations=(default_op_cfg,)),
operator_set=tuple([operator_set]))
operator_set=(operator_set,))

with pytest.raises(Exception) as e_info:
with pytest.raises(Exception, match=f'{opset_name} is defined in TargetPlatformCapabilities'):
_ = self.attach2fw.attach(tpc)
assert f'{opset_name} is defined in TargetPlatformCapabilities' in str(e_info)

def _test_attach2fw_attach_with_custom_opset(self, operators_list, filter_op, kernel_attr_name):
test_bit = 42
opset_name = "Custom"
attr_name = "CustomAttr"
cfg = default_op_cfg.clone_and_edit(attr_weights_configs_mapping=
{KERNEL_ATTR: schema.AttributeQuantizationConfig(weights_n_bits=test_bit)})
{attr_name: schema.AttributeQuantizationConfig(weights_n_bits=test_bit)})
qc_options = schema.QuantizationConfigOptions(quantization_configurations=(cfg,))
opset_name = "Custom"

operator_set = schema.OperatorsSet(name=opset_name,
qc_options=qc_options.clone_and_edit(activation_n_bits=test_bit))

tpc = schema.TargetPlatformCapabilities(default_qco=schema.QuantizationConfigOptions(
quantization_configurations=(default_op_cfg,)),
operator_set=tuple([operator_set]))
operator_set=(operator_set,))

with pytest.raises(Exception) as e_info:
_ = self.attach2fw.attach(tpc)
assert f'{opset_name} is defined in TargetPlatformCapabilities' in str(e_info)

# Setting a layers mapping for the custom opset with a regular operator and a filter.
# We also test the option of passing an attributes mapping for the operator to set a specific attribute config.
custom_attr_name = 'CustomAttr'
fw_custom_attr_name = 'CustomAttrFW'

assert len(operators_list) == 1
non_filter_op = operators_list[0]
fw_quant_capabilities = self.attach2fw.attach(
tpc,
custom_opset2layer={opset_name: CustomOpsetLayers(operators=[non_filter_op, filter_op],
attr_mapping={KERNEL_ATTR: DefaultDict(
{filter_op: custom_attr_name},
attr_mapping={attr_name: DefaultDict(
{filter_op: fw_custom_attr_name},
default_value=kernel_attr_name)})
})

Expand All @@ -216,8 +214,8 @@ def _test_attach2fw_attach_with_custom_opset(self, operators_list, filter_op, ke

op_cfg = fw_quant_capabilities.filterlayer2qco[filter_op].base_config
assert op_cfg.activation_n_bits == test_bit
assert custom_attr_name in op_cfg.attr_weights_configs_mapping
assert op_cfg.attr_weights_configs_mapping[custom_attr_name].weights_n_bits == test_bit
assert fw_custom_attr_name in op_cfg.attr_weights_configs_mapping
assert op_cfg.attr_weights_configs_mapping[fw_custom_attr_name].weights_n_bits == test_bit


def _test_attach2fw_prioritize_custom_opset(self, op):
Expand All @@ -230,7 +228,7 @@ def _test_attach2fw_prioritize_custom_opset(self, op):

tpc = schema.TargetPlatformCapabilities(default_qco=schema.QuantizationConfigOptions(
quantization_configurations=(default_op_cfg,)),
operator_set=tuple([operator_set]))
operator_set=(operator_set,))

filter_op = LayerFilterParams(op, kernel_size=1)

Expand Down

0 comments on commit 565be9d

Please sign in to comment.