From d73b4a7304fcd7d997945f40f894f1accb676a4e Mon Sep 17 00:00:00 2001 From: Ofir Gordon Date: Thu, 16 Jan 2025 09:08:33 +0200 Subject: [PATCH] use load_target_platform_capabilities in all facades (#1330) --- .../keras/resource_utilization_data_facade.py | 8 +++--- .../resource_utilization_data_facade.py | 8 +++--- .../gptq/keras/quantization_facade.py | 7 +++-- .../gptq/pytorch/quantization_facade.py | 6 +++-- .../pruning/keras/pruning_facade.py | 9 ++++--- .../pruning/pytorch/pruning_facade.py | 9 ++++--- .../ptq/keras/quantization_facade.py | 4 ++- .../ptq/pytorch/quantization_facade.py | 8 +++--- .../qat/keras/quantization_facade.py | 9 ++++--- .../qat/pytorch/quantization_facade.py | 11 +++++--- .../tpc_io_handler.py | 6 ++--- tests/common_tests/test_tpc.py | 26 +++++++++---------- 12 files changed, 67 insertions(+), 44 deletions(-) diff --git a/model_compression_toolkit/core/keras/resource_utilization_data_facade.py b/model_compression_toolkit/core/keras/resource_utilization_data_facade.py index c1c127094..68233b308 100644 --- a/model_compression_toolkit/core/keras/resource_utilization_data_facade.py +++ b/model_compression_toolkit/core/keras/resource_utilization_data_facade.py @@ -13,13 +13,14 @@ # limitations under the License. # ============================================================================== -from typing import Callable +from typing import Callable, Union from model_compression_toolkit.core import MixedPrecisionQuantizationConfig, CoreConfig from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization from model_compression_toolkit.logger import Logger from model_compression_toolkit.constants import TENSORFLOW from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_data import compute_resource_utilization_data +from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_capabilities from model_compression_toolkit.verify_packages import FOUND_TF if FOUND_TF: @@ -38,7 +39,7 @@ def keras_resource_utilization_data(in_model: Model, representative_data_gen: Callable, core_config: CoreConfig = CoreConfig( mixed_precision_config=MixedPrecisionQuantizationConfig()), - target_platform_capabilities: TargetPlatformCapabilities = KERAS_DEFAULT_TPC + target_platform_capabilities: Union[TargetPlatformCapabilities, str] = KERAS_DEFAULT_TPC ) -> ResourceUtilization: """ Computes resource utilization data that can be used to calculate the desired target resource utilization @@ -50,7 +51,7 @@ def keras_resource_utilization_data(in_model: Model, in_model (Model): Keras model to quantize. representative_data_gen (Callable): Dataset used for calibration. core_config (CoreConfig): CoreConfig containing parameters for quantization and mixed precision of how the model should be quantized. - target_platform_capabilities (FrameworkQuantizationCapabilities): FrameworkQuantizationCapabilities to optimize the Keras model according to. + target_platform_capabilities (Union[TargetPlatformCapabilities, str]): FrameworkQuantizationCapabilities to optimize the Keras model according to. Returns: @@ -81,6 +82,7 @@ def keras_resource_utilization_data(in_model: Model, fw_impl = KerasImplementation() + target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities) # Attach tpc model to framework attach2keras = AttachTpcToKeras() target_platform_capabilities = attach2keras.attach( diff --git a/model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py b/model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py index a2f7c93f2..1e0d7b4f0 100644 --- a/model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +++ b/model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== -from typing import Callable +from typing import Callable, Union from model_compression_toolkit.logger import Logger from model_compression_toolkit.constants import PYTORCH @@ -23,6 +23,7 @@ from model_compression_toolkit.core.common.quantization.core_config import CoreConfig from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL +from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_capabilities from model_compression_toolkit.verify_packages import FOUND_TORCH if FOUND_TORCH: @@ -40,7 +41,7 @@ def pytorch_resource_utilization_data(in_model: Module, representative_data_gen: Callable, core_config: CoreConfig = CoreConfig(), - target_platform_capabilities: TargetPlatformCapabilities= PYTORCH_DEFAULT_TPC + target_platform_capabilities: Union[TargetPlatformCapabilities, str] = PYTORCH_DEFAULT_TPC ) -> ResourceUtilization: """ Computes resource utilization data that can be used to calculate the desired target resource utilization for mixed-precision quantization. @@ -50,7 +51,7 @@ def pytorch_resource_utilization_data(in_model: Module, in_model (Model): PyTorch model to quantize. representative_data_gen (Callable): Dataset used for calibration. core_config (CoreConfig): CoreConfig containing parameters for quantization and mixed precision - target_platform_capabilities (FrameworkQuantizationCapabilities): FrameworkQuantizationCapabilities to optimize the PyTorch model according to. + target_platform_capabilities (Union[TargetPlatformCapabilities, str]): FrameworkQuantizationCapabilities to optimize the PyTorch model according to. Returns: @@ -81,6 +82,7 @@ def pytorch_resource_utilization_data(in_model: Module, fw_impl = PytorchImplementation() + target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities) # Attach tpc model to framework attach2pytorch = AttachTpcToPytorch() target_platform_capabilities = ( diff --git a/model_compression_toolkit/gptq/keras/quantization_facade.py b/model_compression_toolkit/gptq/keras/quantization_facade.py index a32cf6556..a5e0e4429 100644 --- a/model_compression_toolkit/gptq/keras/quantization_facade.py +++ b/model_compression_toolkit/gptq/keras/quantization_facade.py @@ -25,6 +25,7 @@ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \ AttachTpcToKeras +from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_capabilities from model_compression_toolkit.verify_packages import FOUND_TF from model_compression_toolkit.core.common.user_info import UserInformation from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, GPTQHessianScoresConfig, \ @@ -156,7 +157,8 @@ def keras_gradient_post_training_quantization(in_model: Model, representative_da gptq_representative_data_gen: Callable = None, target_resource_utilization: ResourceUtilization = None, core_config: CoreConfig = CoreConfig(), - target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC) -> Tuple[Model, UserInformation]: + target_platform_capabilities: Union[TargetPlatformCapabilities, str] + = DEFAULT_KERAS_TPC) -> Tuple[Model, UserInformation]: """ Quantize a trained Keras model using post-training quantization. The model is quantized using a symmetric constraint quantization thresholds (power of two). @@ -180,7 +182,7 @@ def keras_gradient_post_training_quantization(in_model: Model, representative_da gptq_representative_data_gen (Callable): Dataset used for GPTQ training. If None defaults to representative_data_gen target_resource_utilization (ResourceUtilization): ResourceUtilization object to limit the search of the mixed-precision configuration as desired. core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters. - target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to. + target_platform_capabilities (Union[TargetPlatformCapabilities, str]): TargetPlatformCapabilities to optimize the Keras model according to. Returns: @@ -241,6 +243,7 @@ def keras_gradient_post_training_quantization(in_model: Model, representative_da fw_impl = GPTQKerasImplemantation() + target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities) # Attach tpc model to framework attach2keras = AttachTpcToKeras() framework_platform_capabilities = attach2keras.attach( diff --git a/model_compression_toolkit/gptq/pytorch/quantization_facade.py b/model_compression_toolkit/gptq/pytorch/quantization_facade.py index 1bc455c34..c16ec0ab6 100644 --- a/model_compression_toolkit/gptq/pytorch/quantization_facade.py +++ b/model_compression_toolkit/gptq/pytorch/quantization_facade.py @@ -32,6 +32,7 @@ from model_compression_toolkit.logger import Logger from model_compression_toolkit.metadata import create_model_metadata from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities +from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_capabilities from model_compression_toolkit.verify_packages import FOUND_TORCH @@ -145,7 +146,7 @@ def pytorch_gradient_post_training_quantization(model: Module, core_config: CoreConfig = CoreConfig(), gptq_config: GradientPTQConfig = None, gptq_representative_data_gen: Callable = None, - target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC): + target_platform_capabilities: Union[TargetPlatformCapabilities, str] = DEFAULT_PYTORCH_TPC): """ Quantize a trained Pytorch module using post-training quantization. By default, the module is quantized using a symmetric constraint quantization thresholds @@ -169,7 +170,7 @@ def pytorch_gradient_post_training_quantization(model: Module, core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters. gptq_config (GradientPTQConfig): Configuration for using gptq (e.g. optimizer). gptq_representative_data_gen (Callable): Dataset used for GPTQ training. If None defaults to representative_data_gen - target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to. + target_platform_capabilities (Union[TargetPlatformCapabilities, str]): TargetPlatformCapabilities to optimize the PyTorch model according to. Returns: A quantized module and information the user may need to handle the quantized module. @@ -214,6 +215,7 @@ def pytorch_gradient_post_training_quantization(model: Module, fw_impl = GPTQPytorchImplemantation() + target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities) # Attach tpc model to framework attach2pytorch = AttachTpcToPytorch() framework_quantization_capabilities = attach2pytorch.attach(target_platform_capabilities, diff --git a/model_compression_toolkit/pruning/keras/pruning_facade.py b/model_compression_toolkit/pruning/keras/pruning_facade.py index 70162b9c9..31a0bcd15 100644 --- a/model_compression_toolkit/pruning/keras/pruning_facade.py +++ b/model_compression_toolkit/pruning/keras/pruning_facade.py @@ -13,11 +13,12 @@ # limitations under the License. # ============================================================================== -from typing import Callable, Tuple +from typing import Callable, Tuple, Union from model_compression_toolkit import get_target_platform_capabilities from model_compression_toolkit.constants import TENSORFLOW from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities +from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_capabilities from model_compression_toolkit.verify_packages import FOUND_TF from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization from model_compression_toolkit.core.common.pruning.pruner import Pruner @@ -43,7 +44,8 @@ def keras_pruning_experimental(model: Model, target_resource_utilization: ResourceUtilization, representative_data_gen: Callable, pruning_config: PruningConfig = PruningConfig(), - target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC) -> Tuple[Model, PruningInfo]: + target_platform_capabilities: Union[TargetPlatformCapabilities, str] + = DEFAULT_KERAS_TPC) -> Tuple[Model, PruningInfo]: """ Perform structured pruning on a Keras model to meet a specified target resource utilization. This function prunes the provided model according to the target resource utilization by grouping and pruning @@ -61,7 +63,7 @@ def keras_pruning_experimental(model: Model, target_resource_utilization (ResourceUtilization): The target Key Performance Indicators to be achieved through pruning. representative_data_gen (Callable): A function to generate representative data for pruning analysis. pruning_config (PruningConfig): Configuration settings for the pruning process. Defaults to standard config. - target_platform_capabilities (FrameworkQuantizationCapabilities): Platform-specific constraints and capabilities. Defaults to DEFAULT_KERAS_TPC. + target_platform_capabilities (Union[TargetPlatformCapabilities, str]): Platform-specific constraints and capabilities. Defaults to DEFAULT_KERAS_TPC. Returns: Tuple[Model, PruningInfo]: A tuple containing the pruned Keras model and associated pruning information. @@ -112,6 +114,7 @@ def keras_pruning_experimental(model: Model, # Instantiate the Keras framework implementation. fw_impl = PruningKerasImplementation() + target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities) # Attach tpc model to framework attach2keras = AttachTpcToKeras() target_platform_capabilities = attach2keras.attach(target_platform_capabilities) diff --git a/model_compression_toolkit/pruning/pytorch/pruning_facade.py b/model_compression_toolkit/pruning/pytorch/pruning_facade.py index 4b97cc7bd..8a92b3345 100644 --- a/model_compression_toolkit/pruning/pytorch/pruning_facade.py +++ b/model_compression_toolkit/pruning/pytorch/pruning_facade.py @@ -13,10 +13,11 @@ # limitations under the License. # ============================================================================== -from typing import Callable, Tuple +from typing import Callable, Tuple, Union from model_compression_toolkit import get_target_platform_capabilities from model_compression_toolkit.constants import PYTORCH from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities +from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_capabilities from model_compression_toolkit.verify_packages import FOUND_TORCH from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization from model_compression_toolkit.core.common.pruning.pruner import Pruner @@ -47,7 +48,8 @@ def pytorch_pruning_experimental(model: Module, target_resource_utilization: ResourceUtilization, representative_data_gen: Callable, pruning_config: PruningConfig = PruningConfig(), - target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYOTRCH_TPC) -> \ + target_platform_capabilities: Union[TargetPlatformCapabilities, str] + = DEFAULT_PYOTRCH_TPC) -> \ Tuple[Module, PruningInfo]: """ Perform structured pruning on a Pytorch model to meet a specified target resource utilization. @@ -66,7 +68,7 @@ def pytorch_pruning_experimental(model: Module, target_resource_utilization (ResourceUtilization): Key Performance Indicators specifying the pruning targets. representative_data_gen (Callable): A function to generate representative data for pruning analysis. pruning_config (PruningConfig): Configuration settings for the pruning process. Defaults to standard config. - target_platform_capabilities (TargetPlatformCapabilities): Platform-specific constraints and capabilities. + target_platform_capabilities (Union[TargetPlatformCapabilities, str]): Platform-specific constraints and capabilities. Defaults to DEFAULT_PYTORCH_TPC. Returns: @@ -118,6 +120,7 @@ def pytorch_pruning_experimental(model: Module, # Instantiate the Pytorch framework implementation. fw_impl = PruningPytorchImplementation() + target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities) # Attach TPC to framework attach2pytorch = AttachTpcToPytorch() framework_platform_capabilities = attach2pytorch.attach(target_platform_capabilities) diff --git a/model_compression_toolkit/ptq/keras/quantization_facade.py b/model_compression_toolkit/ptq/keras/quantization_facade.py index 1a65525a5..883220d23 100644 --- a/model_compression_toolkit/ptq/keras/quantization_facade.py +++ b/model_compression_toolkit/ptq/keras/quantization_facade.py @@ -23,6 +23,7 @@ from model_compression_toolkit.logger import Logger from model_compression_toolkit.constants import TENSORFLOW from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities +from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_capabilities from model_compression_toolkit.verify_packages import FOUND_TF from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \ @@ -70,7 +71,7 @@ def keras_post_training_quantization(in_model: Model, representative_data_gen (Callable): Dataset used for calibration. target_resource_utilization (ResourceUtilization): ResourceUtilization object to limit the search of the mixed-precision configuration as desired. core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters. - target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to. + target_platform_capabilities (Union[TargetPlatformCapabilities, str]): TargetPlatformCapabilities to optimize the Keras model according to. Returns: @@ -137,6 +138,7 @@ def keras_post_training_quantization(in_model: Model, fw_impl = KerasImplementation() + target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities) attach2keras = AttachTpcToKeras() framework_platform_capabilities = attach2keras.attach( target_platform_capabilities, diff --git a/model_compression_toolkit/ptq/pytorch/quantization_facade.py b/model_compression_toolkit/ptq/pytorch/quantization_facade.py index fdba01ecb..efebab82a 100644 --- a/model_compression_toolkit/ptq/pytorch/quantization_facade.py +++ b/model_compression_toolkit/ptq/pytorch/quantization_facade.py @@ -14,12 +14,13 @@ # ============================================================================== import copy -from typing import Callable +from typing import Callable, Union from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer from model_compression_toolkit.logger import Logger from model_compression_toolkit.constants import PYTORCH from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities +from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_capabilities from model_compression_toolkit.verify_packages import FOUND_TORCH from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization from model_compression_toolkit.core import CoreConfig @@ -48,7 +49,7 @@ def pytorch_post_training_quantization(in_module: Module, representative_data_gen: Callable, target_resource_utilization: ResourceUtilization = None, core_config: CoreConfig = CoreConfig(), - target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC): + target_platform_capabilities: Union[TargetPlatformCapabilities, str] = DEFAULT_PYTORCH_TPC): """ Quantize a trained Pytorch module using post-training quantization. By default, the module is quantized using a symmetric constraint quantization thresholds @@ -67,7 +68,7 @@ def pytorch_post_training_quantization(in_module: Module, representative_data_gen (Callable): Dataset used for calibration. target_resource_utilization (ResourceUtilization): ResourceUtilization object to limit the search of the mixed-precision configuration as desired. core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters. - target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to. + target_platform_capabilities (Union[TargetPlatformCapabilities, str]): TargetPlatformCapabilities to optimize the PyTorch model according to. Returns: A quantized module and information the user may need to handle the quantized module. @@ -109,6 +110,7 @@ def pytorch_post_training_quantization(in_module: Module, fw_impl = PytorchImplementation() + target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities) # Attach tpc model to framework attach2pytorch = AttachTpcToPytorch() framework_platform_capabilities = attach2pytorch.attach(target_platform_capabilities, diff --git a/model_compression_toolkit/qat/keras/quantization_facade.py b/model_compression_toolkit/qat/keras/quantization_facade.py index 9480c018f..a9304ec80 100644 --- a/model_compression_toolkit/qat/keras/quantization_facade.py +++ b/model_compression_toolkit/qat/keras/quantization_facade.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== -from typing import Callable +from typing import Callable, Union from functools import partial from model_compression_toolkit.core import CoreConfig @@ -22,6 +22,7 @@ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \ AttachTpcToKeras +from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_capabilities from model_compression_toolkit.verify_packages import FOUND_TF from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \ @@ -49,7 +50,6 @@ from model_compression_toolkit.core import common from model_compression_toolkit.core.common import BaseNode from model_compression_toolkit.constants import TENSORFLOW - from model_compression_toolkit.core.common.framework_info import FrameworkInfo from model_compression_toolkit.qat.common.qat_config import is_qat_applicable from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO @@ -92,7 +92,7 @@ def keras_quantization_aware_training_init_experimental(in_model: Model, target_resource_utilization: ResourceUtilization = None, core_config: CoreConfig = CoreConfig(), qat_config: QATConfig = QATConfig(), - target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC): + target_platform_capabilities: Union[TargetPlatformCapabilities, str] = DEFAULT_KERAS_TPC): """ Prepare a trained Keras model for quantization aware training. First the model quantization is optimized with post-training quantization, then the model layers are wrapped with QuantizeWrappers. The model is @@ -114,7 +114,7 @@ def keras_quantization_aware_training_init_experimental(in_model: Model, target_resource_utilization (ResourceUtilization): ResourceUtilization object to limit the search of the mixed-precision configuration as desired. core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters. qat_config (QATConfig): QAT configuration - target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to. + target_platform_capabilities (Union[TargetPlatformCapabilities, str]): TargetPlatformCapabilities to optimize the Keras model according to. Returns: @@ -188,6 +188,7 @@ def keras_quantization_aware_training_init_experimental(in_model: Model, fw_impl = KerasImplementation() + target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities) attach2keras = AttachTpcToKeras() target_platform_capabilities = attach2keras.attach( target_platform_capabilities, diff --git a/model_compression_toolkit/qat/pytorch/quantization_facade.py b/model_compression_toolkit/qat/pytorch/quantization_facade.py index 55537eadb..938471034 100644 --- a/model_compression_toolkit/qat/pytorch/quantization_facade.py +++ b/model_compression_toolkit/qat/pytorch/quantization_facade.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -from typing import Callable +from typing import Callable, Union from functools import partial from model_compression_toolkit.constants import PYTORCH from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \ AttachTpcToPytorch +from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_capabilities from model_compression_toolkit.verify_packages import FOUND_TORCH from model_compression_toolkit.core import CoreConfig @@ -78,7 +79,8 @@ def pytorch_quantization_aware_training_init_experimental(in_model: Module, target_resource_utilization: ResourceUtilization = None, core_config: CoreConfig = CoreConfig(), qat_config: QATConfig = QATConfig(), - target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC): + target_platform_capabilities: Union[TargetPlatformCapabilities, str] + = DEFAULT_PYTORCH_TPC): """ Prepare a trained Pytorch model for quantization aware training. First the model quantization is optimized with post-training quantization, then the model layers are wrapped with QuantizeWrappers. The model is @@ -100,7 +102,7 @@ def pytorch_quantization_aware_training_init_experimental(in_model: Module, target_resource_utilization (ResourceUtilization): ResourceUtilization object to limit the search of the mixed-precision configuration as desired. core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters. qat_config (QATConfig): QAT configuration - target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Pytorch model according to. + target_platform_capabilities (Union[TargetPlatformCapabilities, str]): TargetPlatformCapabilities to optimize the Pytorch model according to. Returns: @@ -153,10 +155,11 @@ def pytorch_quantization_aware_training_init_experimental(in_model: Module, tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO) fw_impl = PytorchImplementation() + target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities) # Attach tpc model to framework attach2pytorch = AttachTpcToPytorch() framework_platform_capabilities = attach2pytorch.attach(target_platform_capabilities, - core_config.quantization_config.custom_tpc_opset_to_layer) + core_config.quantization_config.custom_tpc_opset_to_layer) # Ignore hessian scores service as we do not use it here tg, bit_widths_config, _, _ = core_runner(in_model=in_model, diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py b/model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py index 0a844d035..49ad7a28c 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py @@ -20,13 +20,13 @@ import json -def load_target_platform_model(tpc_obj_or_path: Union[TargetPlatformCapabilities, str]) -> TargetPlatformCapabilities: +def load_target_platform_capabilities(tpc_obj_or_path: Union[TargetPlatformCapabilities, str]) -> TargetPlatformCapabilities: """ Parses the tpc input, which can be either a TargetPlatformCapabilities object or a string path to a JSON file. Parameters: - tpc_obj_or_path (Union[TargetPlatformModel, str]): Input target platform model or path to .JSON file. + tpc_obj_or_path (Union[TargetPlatformCapabilities, str]): Input target platform model or path to .JSON file. Returns: TargetPlatformCapabilities: The parsed TargetPlatformCapabilities. @@ -66,7 +66,7 @@ def load_target_platform_model(tpc_obj_or_path: Union[TargetPlatformCapabilities ) -def export_target_platform_model(model: TargetPlatformCapabilities, export_path: Union[str, Path]) -> None: +def export_target_platform_capabilities(model: TargetPlatformCapabilities, export_path: Union[str, Path]) -> None: """ Exports a TargetPlatformCapabilities instance to a JSON file. diff --git a/tests/common_tests/test_tpc.py b/tests/common_tests/test_tpc.py index e4a5645d1..db1800019 100644 --- a/tests/common_tests/test_tpc.py +++ b/tests/common_tests/test_tpc.py @@ -22,8 +22,8 @@ from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import \ get_config_options_by_operators_set, is_opset_in_model -from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_model, \ - export_target_platform_model +from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_capabilities, \ + export_target_platform_capabilities from tests.common_tests.helpers.generate_test_tpc import generate_test_attr_configs, generate_test_op_qc @@ -65,19 +65,19 @@ def tearDown(self): def test_valid_model_object(self): """Test that a valid TargetPlatformCapabilities object is returned unchanged.""" - result = load_target_platform_model(self.tpc) + result = load_target_platform_capabilities(self.tpc) self.assertEqual(self.tpc, result) def test_invalid_json_parsing(self): """Test that invalid JSON content raises a ValueError.""" with self.assertRaises(ValueError) as context: - load_target_platform_model(self.invalid_json_file) + load_target_platform_capabilities(self.invalid_json_file) self.assertIn("Invalid JSON for loading TargetPlatformCapabilities in", str(context.exception)) def test_nonexistent_file(self): """Test that a nonexistent file raises FileNotFoundError.""" with self.assertRaises(FileNotFoundError) as context: - load_target_platform_model(self.nonexistent_file) + load_target_platform_capabilities(self.nonexistent_file) self.assertIn("is not a valid file", str(context.exception)) def test_non_json_extension(self): @@ -87,7 +87,7 @@ def test_non_json_extension(self): with open(non_json_file, "w") as file: file.write(self.invalid_json_content) with self.assertRaises(ValueError) as context: - load_target_platform_model(non_json_file) + load_target_platform_capabilities(non_json_file) self.assertIn("does not have a '.json' extension", str(context.exception)) finally: os.remove(non_json_file) @@ -96,12 +96,12 @@ def test_invalid_input_type(self): """Test that an unsupported input type raises TypeError.""" invalid_input = 123 # Not a string or TargetPlatformCapabilities with self.assertRaises(TypeError) as context: - load_target_platform_model(invalid_input) + load_target_platform_capabilities(invalid_input) self.assertIn("must be either a TargetPlatformCapabilities instance or a string path", str(context.exception)) def test_valid_export(self): """Test exporting a valid TargetPlatformCapabilities instance to a file.""" - export_target_platform_model(self.tpc, self.valid_export_path) + export_target_platform_capabilities(self.tpc, self.valid_export_path) # Verify the file exists self.assertTrue(os.path.exists(self.valid_export_path)) @@ -113,20 +113,20 @@ def test_valid_export(self): def test_export_with_invalid_model(self): """Test that exporting an invalid model raises a ValueError.""" with self.assertRaises(ValueError) as context: - export_target_platform_model("not_a_model", self.valid_export_path) + export_target_platform_capabilities("not_a_model", self.valid_export_path) self.assertIn("not a valid TargetPlatformCapabilities instance", str(context.exception)) def test_export_with_invalid_path(self): """Test that exporting to an invalid path raises an OSError.""" with self.assertRaises(OSError) as context: - export_target_platform_model(self.tpc, self.invalid_export_path) + export_target_platform_capabilities(self.tpc, self.invalid_export_path) self.assertIn("Failed to write to file", str(context.exception)) def test_export_creates_parent_directories(self): """Test that exporting creates missing parent directories.""" nested_path = "nested/directory/exported_model.json" try: - export_target_platform_model(self.tpc, nested_path) + export_target_platform_capabilities(self.tpc, nested_path) # Verify the file exists self.assertTrue(os.path.exists(nested_path)) @@ -145,8 +145,8 @@ def test_export_creates_parent_directories(self): def test_export_then_import(self): """Test that a model exported and then imported is identical.""" - export_target_platform_model(self.tpc, self.valid_export_path) - imported_model = load_target_platform_model(self.valid_export_path) + export_target_platform_capabilities(self.tpc, self.valid_export_path) + imported_model = load_target_platform_capabilities(self.valid_export_path) self.assertEqual(self.tpc, imported_model) class TargetPlatformModelingTest(unittest.TestCase):