Skip to content

Commit

Permalink
merge main
Browse files Browse the repository at this point in the history
  • Loading branch information
ofirgo committed Jan 1, 2025
2 parents 417a3d3 + afed6e3 commit 1c909b9
Show file tree
Hide file tree
Showing 10 changed files with 200 additions and 243 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(self,
fw_info: FrameworkInfo,
fw_impl: FrameworkImplementation,
sensitivity_evaluator: SensitivityEvaluation,
ru_functions: Dict[RUTarget, RuFunctions[MpRuMetric, MpRuAggregation]],
ru_functions: Dict[RUTarget, RuFunctions],
target_resource_utilization: ResourceUtilization,
original_graph: Graph = None):
"""
Expand Down
301 changes: 129 additions & 172 deletions model_compression_toolkit/target_platform_capabilities/schema/v1.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
from common.constants import TENSORFLOW, PYTORCH
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL, IMX500_TP_MODEL, \
TFLITE_TP_MODEL, QNNPACK_TP_MODEL
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel

from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tp_model import get_tp_model
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tp_model import get_tp_model as get_tp_model_imx500_v1
from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tp_model import get_tp_model as get_tp_model_tflite_v1
from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tp_model import get_tp_model as get_tp_model_qnnpack_v1


# TODO: These methods need to be replaced once modifying the TPC API.


def get_target_platform_capabilities(fw_name: str,
target_platform_name: str,
target_platform_version: str = None) -> TargetPlatformModel:
Expand All @@ -37,9 +40,22 @@ def get_target_platform_capabilities(fw_name: str,
A default TargetPlatformModel object.
"""

assert fw_name == DEFAULT_TP_MODEL or fw_name == 'v1', \
assert fw_name == TENSORFLOW or fw_name == PYTORCH, f"Unsupported framework {fw_name}."

if target_platform_name == DEFAULT_TP_MODEL:
return get_tp_model_imx500_v1()

assert target_platform_version == 'v1', \
"The usage of get_target_platform_capabilities API is supported only with the default TPC ('v1')."
return get_tp_model()

if target_platform_name == IMX500_TP_MODEL:
return get_tp_model_imx500_v1()
elif target_platform_name == TFLITE_TP_MODEL:
return get_tp_model_tflite_v1()
elif target_platform_name == QNNPACK_TP_MODEL:
return get_tp_model_qnnpack_v1()

raise ValueError(f"Unsupported target platform name {target_platform_name}.")


def get_tpc_model(name: str, tp_model: TargetPlatformModel):
Expand All @@ -55,4 +71,4 @@ def get_tpc_model(name: str, tp_model: TargetPlatformModel):
"""

return tp_model
return get_tp_model_imx500_v1
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ matplotlib<3.10.0
scipy
protobuf
mct-quantizers==1.5.2
pydantic
pydantic<2.0
3 changes: 1 addition & 2 deletions tests/common_tests/helpers/generate_test_tp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ def generate_test_tp_model(edit_params_dict, name=""):
base_config, op_cfg_list, default_config = get_op_quantization_configs()

# separate weights attribute parameters from the requested param to edit
weights_params_names = [name for name in schema.AttributeQuantizationConfig.model_fields.keys() if
name != 'self']
weights_params_names = base_config.default_weight_attr_config.field_names
weights_params = {k: v for k, v in edit_params_dict.items() if k in weights_params_names}
rest_params = {k: v for k, v in edit_params_dict.items() if k not in list(weights_params.keys())}

Expand Down
12 changes: 5 additions & 7 deletions tests/common_tests/test_tp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,11 @@
import os

import unittest
from pydantic_core import from_json

import model_compression_toolkit as mct
import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema
from model_compression_toolkit.constants import FLOAT_BITWIDTH
from model_compression_toolkit.core.common import BaseNode
from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR
from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR
from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import \
get_config_options_by_operators_set, is_opset_in_model
from tests.common_tests.helpers.generate_test_tp_model import generate_test_attr_configs, generate_test_op_qc
Expand Down Expand Up @@ -51,7 +49,7 @@ def test_dump_to_json(self):
tpc_patch_version=0,
tpc_platform_type="dump_to_json",
add_metadata=False)
json_str = model.model_dump_json()
json_str = model.json()
# Define the output file path
file_path = "target_platform_model.json"
# Register cleanup to delete the file if it exists
Expand All @@ -64,7 +62,7 @@ def test_dump_to_json(self):
with open(file_path, "r") as f:
json_content = f.read()

loaded_target_model = schema.TargetPlatformModel.model_validate_json(json_content)
loaded_target_model = schema.TargetPlatformModel.parse_raw(json_content)
self.assertEqual(model, loaded_target_model)


Expand All @@ -78,7 +76,7 @@ def test_immutable_tp(self):
tpc_platform_type=None,
add_metadata=False)
model.operator_set = tuple()
self.assertEqual("1 validation error for TargetPlatformModel\noperator_set\n Instance is frozen", str(e.exception)[:76])
self.assertEqual('"TargetPlatformModel" is immutable and does not support item assignment', str(e.exception))

def test_default_options_more_than_single_qc(self):
test_qco = schema.QuantizationConfigOptions(quantization_configurations=tuple([TEST_QC, TEST_QC]), base_config=TEST_QC)
Expand Down Expand Up @@ -167,7 +165,7 @@ def test_list_of_no_qc(self):
with self.assertRaises(Exception) as e:
schema.QuantizationConfigOptions(quantization_configurations=tuple([TEST_QC, 3]), base_config=TEST_QC)
self.assertTrue(
'1 validation error for QuantizationConfigOptions\nquantization_configurations.1\n Input should be a valid dictionary or instance of OpQuantizationConfig [type=model_type, input_value=3, input_type=int]\n' in str(
"1 validation error for QuantizationConfigOptions\nquantization_configurations -> 1\n value is not a valid dict (type=type_error.dict)" in str(
e.exception))

def test_clone_and_edit_options(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@
# limitations under the License.
# ==============================================================================
import unittest
import model_compression_toolkit as mct
from model_compression_toolkit.constants import THRESHOLD, TENSORFLOW
from model_compression_toolkit.target_platform_capabilities.constants import IMX500_TP_MODEL
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
from model_compression_toolkit.core.keras.constants import FUNCTION

Expand Down
44 changes: 21 additions & 23 deletions tests/keras_tests/non_parallel_tests/test_keras_tp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,23 @@ def test_qco_by_keras_layer(self):
self.assertEqual(tanh_qco, sevenbit_qco)
self.assertEqual(relu_qco, default_qco)

# TODO: need to test as part of attach to fw
def test_opset_not_in_tp(self):
default_qco = schema.QuantizationConfigOptions(quantization_configurations=tuple([TEST_QC]))
hm = schema.TargetPlatformModel(default_qco=default_qco,
tpc_minor_version=None,
tpc_patch_version=None,
tpc_platform_type=None,
operator_set=tuple([schema.OperatorsSet(name="opA")]),
add_metadata=False)
hm_keras = tp.TargetPlatformCapabilities(hm)
with self.assertRaises(Exception) as e:
with hm_keras:
tp.OperationsSetToLayers("conv", [Conv2D])
self.assertEqual(
'conv is not defined in the target platform model that is associated with the target platform capabilities.',
str(e.exception))

def test_keras_fusing_patterns(self):
default_qco = schema.QuantizationConfigOptions(quantization_configurations=tuple([TEST_QC]))
a = schema.OperatorsSet(name="opA")
Expand Down Expand Up @@ -306,35 +323,16 @@ def rep_data():

def test_get_keras_supported_version(self):
tpc = mct.get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL) # Latest
self.assertTrue(tpc.tp_model.tpc_minor_version == 1)

tpc = mct.get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL, 'v1_pot')
self.assertTrue(tpc.tp_model.tpc_minor_version == 1)
tpc = mct.get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL, 'v1_lut')
self.assertTrue(tpc.tp_model.tpc_minor_version == 1)
tpc = mct.get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL, 'v1')
self.assertTrue(tpc.tp_model.tpc_minor_version == 1)
tpc = mct.get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL, 'v2_lut')
self.assertTrue(tpc.tp_model.tpc_minor_version == 2)
tpc = mct.get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL, 'v2')
self.assertTrue(tpc.tp_model.tpc_minor_version == 2)
self.assertTrue(tpc.tpc_minor_version == 1)

tpc = mct.get_target_platform_capabilities(TENSORFLOW, IMX500_TP_MODEL, "v1")
self.assertTrue(tpc.tp_model.tpc_minor_version == 1)

tpc = mct.get_target_platform_capabilities(TENSORFLOW, IMX500_TP_MODEL, "v1_lut")
self.assertTrue(tpc.tp_model.tpc_minor_version == 1)
tpc = mct.get_target_platform_capabilities(TENSORFLOW, IMX500_TP_MODEL, "v2_lut")
self.assertTrue(tpc.tp_model.tpc_minor_version == 2)

tpc = mct.get_target_platform_capabilities(TENSORFLOW, IMX500_TP_MODEL, "v1_pot")
self.assertTrue(tpc.tp_model.tpc_minor_version == 1)
self.assertTrue(tpc.tpc_minor_version == 1)

tpc = mct.get_target_platform_capabilities(TENSORFLOW, TFLITE_TP_MODEL, "v1")
self.assertTrue(tpc.tp_model.tpc_minor_version == 1)
self.assertTrue(tpc.tpc_minor_version == 1)

tpc = mct.get_target_platform_capabilities(TENSORFLOW, QNNPACK_TP_MODEL, "v1")
self.assertTrue(tpc.tp_model.tpc_minor_version == 1)
self.assertTrue(tpc.tpc_minor_version == 1)

def test_get_keras_not_supported_platform(self):
with self.assertRaises(Exception) as e:
Expand Down
45 changes: 20 additions & 25 deletions tests/pytorch_tests/function_tests/test_pytorch_tp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,22 +219,22 @@ def test_filter_layer_attached_to_multiple_opsets(self):
tp.OperationsSetToLayers('opsetB', [LayerFilterParams(torch.nn.Softmax, dim=2)])
self.assertEqual('Found layer Softmax(dim=2) in more than one OperatorsSet', str(e.exception))

# TODO: bring back the test if we decide that this needs to be enforced by the TPC during initialization
# def test_opset_not_in_tp(self):
# default_qco = schema.QuantizationConfigOptions(quantization_configurations=tuple([TEST_QC]))
# hm = schema.TargetPlatformModel(default_qco=default_qco,
# tpc_minor_version=None,
# tpc_patch_version=None,
# tpc_platform_type=None,
# operator_set=tuple([schema.OperatorsSet(name="opA")]),
# add_metadata=False)
# hm_pytorch = tp.TargetPlatformCapabilities(hm)
# with self.assertRaises(Exception) as e:
# with hm_pytorch:
# tp.OperationsSetToLayers("conv", [torch.nn.Conv2d])
# self.assertEqual(
# 'conv is not defined in the target platform model that is associated with the target platform capabilities.',
# str(e.exception))
# TODO: need to test as part of attach to fw
def test_opset_not_in_tp(self):
default_qco = schema.QuantizationConfigOptions(quantization_configurations=tuple([TEST_QC]))
hm = schema.TargetPlatformModel(default_qco=default_qco,
tpc_minor_version=None,
tpc_patch_version=None,
tpc_platform_type=None,
operator_set=tuple([schema.OperatorsSet(name="opA")]),
add_metadata=False)
hm_pytorch = tp.TargetPlatformCapabilities(hm)
with self.assertRaises(Exception) as e:
with hm_pytorch:
tp.OperationsSetToLayers("conv", [torch.nn.Conv2d])
self.assertEqual(
'conv is not defined in the target platform model that is associated with the target platform capabilities.',
str(e.exception))

def test_pytorch_fusing_patterns(self):
default_qco = schema.QuantizationConfigOptions(quantization_configurations=tuple(
Expand Down Expand Up @@ -299,21 +299,16 @@ def rep_data():

def test_get_pytorch_supported_version(self):
tpc = mct.get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL) # Latest
self.assertTrue(tpc.tp_model.tpc_minor_version == 1)

tpc = mct.get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL, 'v1')
self.assertTrue(tpc.tp_model.tpc_minor_version == 1)
tpc = mct.get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL, 'v2')
self.assertTrue(tpc.tp_model.tpc_minor_version == 2)
self.assertTrue(tpc.tpc_minor_version == 1)

tpc = mct.get_target_platform_capabilities(PYTORCH, IMX500_TP_MODEL, "v1")
self.assertTrue(tpc.tp_model.tpc_minor_version == 1)
self.assertTrue(tpc.tpc_minor_version == 1)

tpc = mct.get_target_platform_capabilities(PYTORCH, TFLITE_TP_MODEL, "v1")
self.assertTrue(tpc.tp_model.tpc_minor_version == 1)
self.assertTrue(tpc.tpc_minor_version == 1)

tpc = mct.get_target_platform_capabilities(PYTORCH, QNNPACK_TP_MODEL, "v1")
self.assertTrue(tpc.tp_model.tpc_minor_version == 1)
self.assertTrue(tpc.tpc_minor_version == 1)

def test_get_pytorch_not_supported_platform(self):
with self.assertRaises(Exception) as e:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@
# limitations under the License.
# ==============================================================================
import unittest
import model_compression_toolkit as mct
from model_compression_toolkit.constants import PYTORCH
from model_compression_toolkit.target_platform_capabilities.constants import IMX500_TP_MODEL
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
from model_compression_toolkit.core.keras.constants import FUNCTION

Expand Down

0 comments on commit 1c909b9

Please sign in to comment.