Keras data generation (#794)
* Add feature Keras data generation
lapid92 authored Sep 12, 2023
1 parent ed871c7 commit 109c362
Showing 31 changed files with 2,242 additions and 221 deletions.
52 changes: 50 additions & 2 deletions model_compression_toolkit/data_generation/README.md
Expand Up @@ -17,14 +17,15 @@ pip install model-compression-toolkit
```

## Usage
### PyTorch

```python
from model_compression_toolkit.data_generation import get_pytorch_data_generation_config, pytorch_data_generation_experimental

# Set the configuration parameters for data generation
data_gen_config = get_pytorch_data_generation_config(n_iter=500, # Number of iterations
data_gen_batch_size=32, # Batch size for data generation
image_padding=32, # image manipulation when generating data
extra_pixels=32, # image manipulation when generating data
# ... (other configuration parameters)
)

Expand All @@ -36,9 +37,29 @@ generated_images = pytorch_data_generation_experimental(model=my_model, # PyTor
)
```

### Keras

```python
from model_compression_toolkit.data_generation import get_tensorflow_data_generation_config, tensorflow_data_generation_experimental

# Set the configuration parameters for data generation
data_gen_config = get_tensorflow_data_generation_config(n_iter=500,  # Number of iterations
                                                        data_gen_batch_size=32,  # Batch size for data generation
                                                        extra_pixels=32,  # Image manipulation when generating data
                                                        # ... (other configuration parameters)
                                                        )

# Call the data generation function to generate images
generated_images = tensorflow_data_generation_experimental(model=my_model,  # Keras model to generate data for
                                                           n_images=1024,  # Number of images to generate
                                                           output_image_size=224,  # Size of the output images
                                                           data_generation_config=data_gen_config  # Configuration for data generation
                                                           )
```

## Configuration Parameters

The `get_pytorch_data_generation_config()` function allows you to customize various configuration parameters for data generation. Here are the essential parameters that can be tailored to your specific needs:
The `get_pytorch_data_generation_config()` and `get_tensorflow_data_generation_config()` functions allow you to customize various configuration parameters for data generation. Here are the essential parameters that can be tailored to your specific needs:
- **'n_iter'** (int): The number of iterations for the data generation optimization process. Controls the number of iterations used during the optimization process for generating data. Higher values may improve data quality at the cost of increased computation time.
- **'optimizer'** (Optimizer): The optimizer used during data generation to update the generated images. Specifies the optimization algorithm used to update the generated images during the data generation process. Common optimizers include RAdam, Adam, SGD, etc.
- **'data_gen_batch_size'** (int): The batch size used during data generation optimization. Determines the number of images processed in each optimization step. A larger batch size may speed up the optimization but requires more memory.
Expand All @@ -58,6 +79,7 @@ The `get_pytorch_data_generation_config()` function allows you to customize vari
- **'reflection'** (bool): Indicates whether reflection is used during image clipping. Determines whether reflection is applied to the images during clipping. Reflection can help maintain image realism and continuity in certain cases.

## Results Using Generated Data
## PyTorch
### Experimental setup
##### Quantization Algorithms
Four quantization algorithms were utilized to evaluate the generated data:
Expand Down Expand Up @@ -90,3 +112,29 @@ Please note that the choice of quantization algorithms and data generation param
| Real Data | 69.49 | 58.48 | 66.24 | 69.30 | 71.168 | 64.52 | 64.4 | 70.6 | 36.23 | 29.79 | 25.82 |
| Random Noise | 8.3 | 43.58 | 12.6 | 11.13 | 7.9 | 30.02 | 7.14 | 11.30 | 27.45 | 4.15 | 2.68 |
| Image Generation | 69.51 | 58.57 | 65.70 | 69.07 | 70.155 | 62.82 | 62.49 | 69.59 | 35.12 | 27.77 | 25.02 |


## Keras
### Experimental setup
##### Quantization Algorithms
The Post Training Quantization (PTQ) algorithm was used to evaluate the generated data.

All experiments were tested with symmetric weights and Power-Of-Two activation quantizers.

To ensure reliable results, all experiments were averaged over 5 different random seeds (0-4).

##### Data Generation Parameters
The evaluation was performed on the following neural network models:

- MobileNet and MobileNetV2 from the Keras Applications library.

The quantization algorithms were tested using three different data types as input: real data, random noise, and generated data.
The generated data was produced using the default data generation configuration with 500 iterations (better results may be achieved with a larger iteration budget).
- 1024 images were generated using a data generation batch size of 32 and a resolution of 224x224.

| Model (float) | Mobilenet (70.558) | Mobilenetv2 (71.812) |
|:---------------------------------------------------:|:------------------:|:--------------------:|
| Data type (rows) \ Quantization algorithm (columns) | PTQ W8A8 | PTQ W8A8 |
| Real Data | 70.427 | 71.599 |
| Random Noise | 58.938 | 70.932 |
| Image Generation | 70.39 | 71.574 |
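
To read the Keras table as degradation relative to the float model, the gap for each data type can be computed directly from the numbers above. This is a minimal sketch: the dictionary layout and the helper name `accuracy_drop` are illustrative and not part of the toolkit.

```python
# Top-1 accuracies copied from the Keras results table above (PTQ W8A8).
FLOAT_ACC = {"Mobilenet": 70.558, "Mobilenetv2": 71.812}
QUANT_ACC = {
    "Real Data": {"Mobilenet": 70.427, "Mobilenetv2": 71.599},
    "Random Noise": {"Mobilenet": 58.938, "Mobilenetv2": 70.932},
    "Image Generation": {"Mobilenet": 70.39, "Mobilenetv2": 71.574},
}


def accuracy_drop(data_type: str, model: str) -> float:
    """Return float accuracy minus quantized accuracy, in percentage points."""
    return round(FLOAT_ACC[model] - QUANT_ACC[data_type][model], 3)


for data_type in QUANT_ACC:
    drops = {model: accuracy_drop(data_type, model) for model in FLOAT_ACC}
    print(f"{data_type}: {drops}")
```

With these numbers, generated images land within roughly 0.04 points of real data on both models, while random noise costs Mobilenet more than 11 points.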
11 changes: 9 additions & 2 deletions model_compression_toolkit/data_generation/__init__.py
Expand Up @@ -12,5 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from model_compression_toolkit.data_generation.pytorch.pytorch_data_generation import pytorch_data_generation_experimental
from model_compression_toolkit.data_generation.pytorch.pytorch_data_generation import get_pytorch_data_generation_config
from model_compression_toolkit.constants import FOUND_TORCH, FOUND_TF

if FOUND_TF:
from model_compression_toolkit.data_generation.keras.keras_data_generation import (
tensorflow_data_generation_experimental, get_tensorflow_data_generation_config)

if FOUND_TORCH:
from model_compression_toolkit.data_generation.pytorch.pytorch_data_generation import (
pytorch_data_generation_experimental, get_pytorch_data_generation_config)
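
The guarded imports above expose each backend only when its framework is installed. The sketch below shows one common way such flags can be derived; it is an assumption for illustration, since the toolkit's own definitions of `FOUND_TF` and `FOUND_TORCH` in `constants.py` are not part of this diff. The helper `available_backends` is hypothetical.

```python
import importlib.util

# Hypothetical stand-ins for the FOUND_TF / FOUND_TORCH flags imported above.
# find_spec returns None when a top-level package cannot be located.
FOUND_TF = importlib.util.find_spec("tensorflow") is not None
FOUND_TORCH = importlib.util.find_spec("torch") is not None


def available_backends() -> list:
    """List the data generation backends whose framework is installed."""
    backends = []
    if FOUND_TF:
        backends.append("keras")
    if FOUND_TORCH:
        backends.append("pytorch")
    return backends


print(available_backends())
```

Checking for the package without importing it keeps the check cheap and avoids pulling a heavy framework into memory just to set a flag.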
27 changes: 27 additions & 0 deletions model_compression_toolkit/data_generation/common/constants.py
@@ -0,0 +1,27 @@
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Common constants for Data Generation

# Define a constant for the image input key.
IMAGE_INPUT = 'image_input'

# Define a constant for the number of channels in input image.
NUM_INPUT_CHANNELS = 3

# Default batch size for data generator.
DEFAULT_DATA_GEN_BS = 32

# Default number of iterations.
DEFAULT_N_ITER = 500
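
As a rough guide to what these defaults imply, the sketch below estimates the optimization budget for a request. It assumes every batch is visited once per iteration (as in the generation loop this commit moves out of `common/data_generation.py`) and that a partial last batch is kept; the helper `optimization_budget` is hypothetical.

```python
import math

# Defaults copied from common/constants.py above.
DEFAULT_DATA_GEN_BS = 32
DEFAULT_N_ITER = 500


def optimization_budget(n_images: int,
                        batch_size: int = DEFAULT_DATA_GEN_BS,
                        n_iter: int = DEFAULT_N_ITER) -> tuple:
    """Return (number of batches, total optimization steps)."""
    # Assumption: a partial last batch still counts as one batch.
    n_batches = math.ceil(n_images / batch_size)
    return n_batches, n_batches * n_iter


print(optimization_budget(1024))  # (32, 16000)
```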
188 changes: 92 additions & 96 deletions model_compression_toolkit/data_generation/common/data_generation.py
Expand Up @@ -13,113 +13,109 @@
# limitations under the License.
# ==============================================================================
# Import required modules and classes
import time
from typing import Callable, Any, List
from typing import Any, Tuple, Dict, Callable, List

import torch
from tqdm import tqdm

from model_compression_toolkit.core.pytorch.utils import get_working_device
from model_compression_toolkit.data_generation.common.data_generation_config import DataGenerationConfig
from model_compression_toolkit.data_generation.common.enums import ImagePipelineType, ImageNormalizationType, \
BNLayerWeightingType, DataInitType, BatchNormAlignemntLossType, OutputLossType
from model_compression_toolkit.data_generation.common.image_pipeline import BaseImagePipeline
from model_compression_toolkit.data_generation.common.model_info_exctractors import ActivationExtractor, \
OriginalBNStatsHolder
from model_compression_toolkit.data_generation.common.optimization_utils import ImagesOptimizationHandler
from model_compression_toolkit.logger import Logger


def data_generation(
def get_data_generation_classes(
data_generation_config: DataGenerationConfig,
activation_extractor: ActivationExtractor,
orig_bn_stats_holder: OriginalBNStatsHolder,
all_imgs_opt_handler: ImagesOptimizationHandler,
image_pipeline: BaseImagePipeline,
bn_layer_weighting_fn: Callable,
bn_alignment_loss_fn: Callable,
output_loss_fn: Callable,
output_loss_multiplier: float
) -> List[Any]:
output_image_size: Tuple,
n_images: int,
image_pipeline_dict: Dict,
image_normalization_dict: Dict,
bn_layer_weighting_function_dict: Dict,
image_initialization_function_dict: Dict,
bn_alignment_loss_function_dict: Dict,
output_loss_function_dict: Dict) \
-> Tuple[BaseImagePipeline, List[List[float]], Callable, Callable, Callable, Any]:
"""
Function to perform data generation using the provided model and data generation configuration.
Function to resolve the image pipeline, normalization values, loss functions and initial dataset specified by a DataGenerationConfig.
Args:
data_generation_config (DataGenerationConfig): Configuration for data generation.
activation_extractor (ActivationExtractor): The activation extractor for the model.
orig_bn_stats_holder (OriginalBNStatsHolder): Object to hold original BatchNorm statistics.
all_imgs_opt_handler (ImagesOptimizationHandler): Handles the images optimization process.
image_pipeline (Callable): Callable image pipeline for image manipulation.
bn_layer_weighting_fn (Callable): Function to compute layer weighting for the BatchNorm alignment loss .
bn_alignment_loss_fn (Callable): Function to compute BatchNorm alignment loss.
output_loss_fn (Callable): Function to compute output loss.
output_loss_multiplier (float): Multiplier for the output loss.
output_image_size (Tuple): The desired output image size.
n_images (int): The number of random samples.
image_pipeline_dict (Dict): Dictionary mapping ImagePipelineType to corresponding image pipeline classes.
image_normalization_dict (Dict): Dictionary mapping ImageNormalizationType to corresponding
normalization values.
bn_layer_weighting_function_dict (Dict): Dictionary of layer weighting functions.
image_initialization_function_dict (Dict): Dictionary of image initialization functions.
bn_alignment_loss_function_dict (Dict): Dictionary of batch normalization alignment loss functions.
output_loss_function_dict (Dict): Dictionary of output loss functions.
Returns:
List: Finalized list containing generated images.
image_pipeline (BaseImagePipeline): The image pipeline for processing images during optimization.
normalization (List[List[float]]): The image normalization values for processing images during optimization.
bn_layer_weighting_fn (Callable): Function to compute layer weighting for the BatchNorm alignment loss.
bn_alignment_loss_fn (Callable): Function to compute BatchNorm alignment loss.
output_loss_fn (Callable): Function to compute output loss.
init_dataset (Any): The initial dataset used for image generation.
"""

# Compute the layer weights based on orig_bn_stats_holder
bn_layer_weights = bn_layer_weighting_fn(orig_bn_stats_holder)

# Get the current time to measure the total time taken
total_time = time.time()

# Create a tqdm progress bar for iterating over data_generation_config.n_iter iterations
ibar = tqdm(range(data_generation_config.n_iter))

# Perform data generation iterations
for i_ter in ibar:

# Randomly reorder the batches
all_imgs_opt_handler.random_batch_reorder()

# Iterate over each batch
for i_batch in range(all_imgs_opt_handler.n_batches):
# Get the random batch index
random_batch_index = all_imgs_opt_handler.get_random_batch_index(i_batch)

# Get the images to optimize and the optimizer for the batch
imgs_to_optimize = all_imgs_opt_handler.get_images_by_batch_index(random_batch_index)

# Zero gradients
all_imgs_opt_handler.zero_grad(random_batch_index)

# Perform image input manipulation
input_imgs = image_pipeline.image_input_manipulation(imgs_to_optimize)

# Forward pass to extract activations
output = activation_extractor.run_model(input_imgs)

# Compute BatchNorm alignment loss
bn_loss = all_imgs_opt_handler.compute_bn_loss(input_imgs=input_imgs,
batch_index=random_batch_index,
activation_extractor=activation_extractor,
orig_bn_stats_holder=orig_bn_stats_holder,
bn_alignment_loss_fn=bn_alignment_loss_fn,
bn_layer_weights=bn_layer_weights)


# Compute output loss
output_loss = output_loss_fn(output_imgs=output) if output_loss_multiplier > 0 else torch.zeros(1).to(get_working_device())

# Compute total loss
total_loss = bn_loss + output_loss_multiplier * output_loss

# Perform optimiztion step
all_imgs_opt_handler.optimization_step(random_batch_index, total_loss, i_ter)

# Update the statistics based on the updated images
if all_imgs_opt_handler.use_all_data_stats:
final_imgs = image_pipeline.image_output_finalize(imgs_to_optimize)
all_imgs_opt_handler.update_statistics(input_imgs=final_imgs,
batch_index=random_batch_index,
activation_extractor=activation_extractor)

ibar.set_description(f"Total Loss: {total_loss.item():.5f}, "
f"BN Loss: {bn_loss.item():.5f}, "
f"Output Loss: {output_loss.item():.5f}")


# Return a list containing the finalized generated images
finalized_imgs = all_imgs_opt_handler.get_finalized_images()
Logger.info(f'Total time to generate {len(finalized_imgs)} images (seconds): {int(time.time() - total_time)}')
return finalized_imgs
    # Get the image pipeline class corresponding to the specified type
    image_pipeline_cls = image_pipeline_dict.get(data_generation_config.image_pipeline_type)

    # Check if the image pipeline type is valid before instantiating it
    # (calling the result of .get() directly would fail on None before this check runs)
    if image_pipeline_cls is None:
        Logger.exception(
            f'Invalid image_pipeline_type {data_generation_config.image_pipeline_type}. '
            f'Please choose one of {ImagePipelineType.get_values()}')

    # Instantiate the image pipeline
    image_pipeline = image_pipeline_cls(
        output_image_size=output_image_size,
        extra_pixels=data_generation_config.extra_pixels)

    # Get the normalization values corresponding to the specified type
    normalization = image_normalization_dict.get(data_generation_config.image_normalization_type)

    # Check if the image normalization type is valid
    if normalization is None:
        Logger.exception(
            f'Invalid image_normalization_type {data_generation_config.image_normalization_type}. '
            f'Please choose one of {ImageNormalizationType.get_values()}')

    # Get the layer weighting function corresponding to the specified type
    bn_layer_weighting_fn = bn_layer_weighting_function_dict.get(data_generation_config.layer_weighting_type)

    # Check if the layer weighting type is valid
    if bn_layer_weighting_fn is None:
        Logger.exception(
            f'Invalid layer_weighting_type {data_generation_config.layer_weighting_type}. '
            f'Please choose one of {BNLayerWeightingType.get_values()}')

    # Get the image initialization function corresponding to the specified type
    image_initialization_fn = image_initialization_function_dict.get(data_generation_config.data_init_type)

    # Check if the data initialization type is valid
    if image_initialization_fn is None:
        Logger.exception(
            f'Invalid data_init_type {data_generation_config.data_init_type}. '
            f'Please choose one of {DataInitType.get_values()}')

    # Get the BatchNorm alignment loss function corresponding to the specified type
    bn_alignment_loss_fn = bn_alignment_loss_function_dict.get(data_generation_config.bn_alignment_loss_type)

    # Check if the BatchNorm alignment loss type is valid
    if bn_alignment_loss_fn is None:
        Logger.exception(
            f'Invalid bn_alignment_loss_type {data_generation_config.bn_alignment_loss_type}. '
            f'Please choose one of {BatchNormAlignemntLossType.get_values()}')

    # Get the output loss function corresponding to the specified type
    output_loss_fn = output_loss_function_dict.get(data_generation_config.output_loss_type)

    # Check if the output loss type is valid
    if output_loss_fn is None:
        Logger.exception(
            f'Invalid output_loss_type {data_generation_config.output_loss_type}. '
            f'Please choose one of {OutputLossType.get_values()}')

    # Initialize the dataset for data generation
    init_dataset = image_initialization_fn(
        n_images=n_images,
        size=image_pipeline.get_image_input_size(),
        batch_size=data_generation_config.data_gen_batch_size)

    return image_pipeline, normalization, bn_layer_weighting_fn, bn_alignment_loss_fn, output_loss_fn, init_dataset
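
The function above follows a registry pattern: each framework passes dictionaries that map an enum value to a class or function, and the common code looks the entry up and validates it. A minimal, framework-free sketch of that pattern follows. The enum `OutputLossKind`, the helper `resolve`, and the registry contents are hypothetical stand-ins, and a plain `ValueError` replaces the toolkit's `Logger.exception`.

```python
from enum import Enum
from typing import Callable, Dict


class OutputLossKind(Enum):  # hypothetical enum, stands in for OutputLossType
    NONE = 0
    MIN_MAX_DIFF = 1


def resolve(kind: Enum, registry: Dict[Enum, Callable], name: str) -> Callable:
    """Look up `kind` in `registry`, failing with a readable error if absent."""
    fn = registry.get(kind)
    if fn is None:
        valid = [k.name for k in registry]
        raise ValueError(f"Invalid {name} {kind}. Please choose one of {valid}")
    return fn


# Framework-specific registry: each backend would supply its own mapping.
registry = {OutputLossKind.NONE: lambda outputs: 0.0}

loss_fn = resolve(OutputLossKind.NONE, registry, "output_loss_type")
print(loss_fn([1.0, 2.0]))

try:
    resolve(OutputLossKind.MIN_MAX_DIFF, registry, "output_loss_type")
except ValueError as err:
    print(err)
```

Validating the lookup result before calling it matters: calling the result of `.get()` directly raises a `TypeError` on a missing key, which hides the more helpful "invalid type" message.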