Keras Structured Pruning

model_compression_toolkit.pruning.keras_pruning_experimental(model, target_kpi, representative_data_gen, pruning_config=PruningConfig(), target_platform_capabilities=DEFAULT_KERAS_TPC)

Perform structured pruning on a Keras model to meet a specified target KPI. This function prunes the provided model according to the target KPI by grouping and pruning channels based on each layer's SIMD configuration in the Target Platform Capabilities (TPC). By default, the importance of each channel group is determined using the Label-Free Hessian (LFH) method, which assesses each channel's sensitivity to the Hessian of the loss function. This pruning strategy considers groups of channels together, yielding a more hardware-friendly architecture. The process analyzes the model with a representative dataset to identify groups of channels that can be removed with minimal impact on performance.


Note that the pruned model must be retrained to recover the compressed model's performance (a minimal retraining sketch appears at the end of the Examples section below).

Parameters:
  • model (Model) – The original Keras model to be pruned.

  • target_kpi (KPI) – The target Key Performance Indicators to be achieved through pruning.

  • representative_data_gen (Callable) – A function to generate representative data for pruning analysis.

  • pruning_config (PruningConfig) – Configuration settings for the pruning process. Defaults to standard config.

  • target_platform_capabilities (TargetPlatformCapabilities) – Platform-specific constraints and capabilities. Defaults to DEFAULT_KERAS_TPC.
Returns:

A tuple containing the pruned Keras model and associated pruning information.

Return type:

Tuple[Model, PruningInfo]


Examples


Import MCT:

>>> import model_compression_toolkit as mct

Import a Keras model:

>>> from tensorflow.keras.applications.resnet50 import ResNet50
>>> model = ResNet50()

Create a random dataset generator:

>>> import numpy as np
>>> def repr_datagen(): yield [np.random.random((1, 224, 224, 3))]

Define a target KPI for pruning. Here, we aim to reduce the memory footprint of weights by 50%, assuming the model weights are represented in float32 data type (thus, each parameter is represented using 4 bytes):

>>> dense_nparams = sum([l.count_params() for l in model.layers])
>>> target_kpi = mct.KPI(weights_memory=dense_nparams * 4 * 0.5)

Optionally, define a pruning configuration. num_score_approximations can be passed to configure the number of importance scores that will be calculated for each channel. A higher value for this parameter yields more precise score approximations but also extends the duration of the pruning process:

>>> pruning_config = mct.pruning.PruningConfig(num_score_approximations=1)

Perform pruning:

>>> pruned_model, pruning_info = mct.pruning.keras_pruning_experimental(model=model, target_kpi=target_kpi, representative_data_gen=repr_datagen, pruning_config=pruning_config)
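
As noted above, the pruned model should then be retrained (fine-tuned) to recover accuracy. A minimal sketch using standard Keras training APIs, where train_dataset is a hypothetical labeled dataset (e.g., a tf.data.Dataset yielding images and labels) that is not defined in this example:

>>> pruned_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
>>> pruned_model.fit(train_dataset, epochs=5)  # train_dataset is an assumed, user-provided dataset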

Pruning Configuration

model_compression_toolkit.pruning.PruningConfig(num_score_approximations=32, importance_metric=ImportanceMetric.LFH, channels_filtering_strategy=ChannelsFilteringStrategy.GREEDY)

Configuration class for specifying how a neural network should be pruned.
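
For illustration, a configuration that keeps the default LFH metric and greedy filtering but lowers the number of score approximations might look as follows. This sketch assumes the ImportanceMetric and ChannelsFilteringStrategy enums are exposed under mct.pruning, as the signature above suggests:

>>> import model_compression_toolkit as mct
>>> config = mct.pruning.PruningConfig(
...     num_score_approximations=8,  # fewer approximations: faster, but less precise scores
...     importance_metric=mct.pruning.ImportanceMetric.LFH,
...     channels_filtering_strategy=mct.pruning.ChannelsFilteringStrategy.GREEDY)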

model_compression_toolkit.pruning.num_score_approximations

The number of score approximations to perform when calculating channel importance.

Type:

int

model_compression_toolkit.pruning.importance_metric

The metric used to calculate channel importance.

Type:

ImportanceMetric

model_compression_toolkit.pruning.channels_filtering_strategy

The strategy used to filter out channels.

Type:

ChannelsFilteringStrategy


Pruning Information

model_compression_toolkit.pruning.PruningInfo(pruning_masks, importance_scores)

PruningInfo stores information about a pruned model, including the pruning masks and importance scores for each layer. This class acts as a container for accessing pruning-related metadata.
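
For example, the stored masks can be used to count how many channels survived pruning in each layer. A minimal sketch, assuming pruning_info was returned by keras_pruning_experimental as in the Examples above:

>>> for node, mask in pruning_info.pruning_masks.items():
...     print(node, int(mask.sum()), 'of', mask.size, 'channels kept')  # mask entry 1 = kept, 0 = pruned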

model_compression_toolkit.pruning.pruning_masks

Stores the pruning masks for each layer. A pruning mask is an array where each element indicates whether the corresponding channel or neuron has been pruned (0) or kept (1).

Type:

Dict[BaseNode, np.ndarray]

model_compression_toolkit.pruning.importance_scores

Stores the importance scores for each layer. Importance scores quantify the significance of each channel in the layer.

Type:

Dict[BaseNode, np.ndarray]
