
Commit

Merge branch 'refs/heads/main' into substitution/scaled_dot_product_attention
yardeny-sony committed Oct 1, 2024
2 parents 5d74d6a + 3b8f1cc commit a82f570
Showing 20 changed files with 1,887 additions and 2,483 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -42,7 +42,7 @@ Explore the Model Compression Toolkit (MCT) through our tutorials,
covering compression techniques for Keras and PyTorch models. Access interactive [notebooks](https://github.com/sony/model_optimization/blob/main/tutorials/README.md)
for hands-on learning. For example:
* [Keras MobileNetV2 post training quantization](https://github.com/sony/model_optimization/blob/main/tutorials/notebooks/imx500_notebooks/keras/example_keras_mobilenetv2_for_imx500.ipynb)
-* [Post training quantization with PyTorch](https://github.com/sony/model_optimization/blob/main/tutorials/notebooks/mct_features_notebooks/pytorch/example_pytorch_ptq_mnist.ipynb)
+* [Post training quantization with PyTorch](https://github.com/sony/model_optimization/blob/main/tutorials/notebooks/mct_features_notebooks/pytorch/example_pytorch_post_training_quantization.ipynb)
* [Data Generation for ResNet18 with PyTorch](https://github.com/sony/model_optimization/blob/main/tutorials/notebooks/mct_features_notebooks/pytorch/example_pytorch_data_generation.ipynb).


@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+import copy
+
import time
from typing import Callable, Any, Tuple, List, Union

@@ -179,8 +181,11 @@ def pytorch_data_generation_experimental(
    # get the model device
    device = get_working_device()

+   # copy model for data generation
+   model_for_data_gen = copy.deepcopy(model)
+
    # get a static graph representation of the model using torch.fx
-   fx_model = symbolic_trace(model)
+   fx_model = symbolic_trace(model_for_data_gen)

    # Get Data Generation functions and classes
    image_pipeline, normalization, bn_layer_weighting_fn, bn_alignment_loss_fn, output_loss_fn, \
@@ -208,23 +213,23 @@
    scheduler = scheduler_get_fn(data_generation_config.n_iter)

    # Set the current model
-   set_model(model)
+   set_model(model_for_data_gen)

    # Create an activation extractor object to extract activations from the model
    activation_extractor = PytorchActivationExtractor(
-       model,
+       model_for_data_gen,
        fx_model,
        data_generation_config.bn_layer_types,
        data_generation_config.last_layer_types)

    # Create an orig_bn_stats_holder object to hold original BatchNorm statistics
-   orig_bn_stats_holder = PytorchOriginalBNStatsHolder(model, data_generation_config.bn_layer_types)
+   orig_bn_stats_holder = PytorchOriginalBNStatsHolder(model_for_data_gen, data_generation_config.bn_layer_types)
    if orig_bn_stats_holder.get_num_bn_layers() == 0:
        Logger.critical(
            f'Data generation requires a model with at least one BatchNorm layer.')  # pragma: no cover

    # Create an ImagesOptimizationHandler object for handling optimization
-   all_imgs_opt_handler = PytorchImagesOptimizationHandler(model=model,
+   all_imgs_opt_handler = PytorchImagesOptimizationHandler(model=model_for_data_gen,
                                                            data_gen_batch_size=data_generation_config.data_gen_batch_size,
                                                            init_dataset=init_dataset,
                                                            optimizer=data_generation_config.optimizer,
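The recurring change in this hunk, replacing `model` with `model_for_data_gen`, keeps data generation from mutating the caller's module. Below is a minimal sketch of the deep-copy-then-trace pattern (the toy model and surrounding code are ours, not MCT's):

```python
import copy

import torch
from torch.fx import symbolic_trace


class TinyBNNet(torch.nn.Module):
    """Toy model with a BatchNorm layer (illustrative only)."""
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)
        self.bn = torch.nn.BatchNorm2d(8)

    def forward(self, x):
        return self.bn(self.conv(x))


model = TinyBNNet()

# Deep-copy first, so tracing and BatchNorm statistics collection
# operate on a private copy of the user's model.
model_for_data_gen = copy.deepcopy(model)
fx_model = symbolic_trace(model_for_data_gen)

# The copy owns its own buffers; the original stays untouched.
assert model.bn.running_mean is not model_for_data_gen.bn.running_mean
```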
@@ -207,7 +207,9 @@ def generate_tp_model(default_config: OpQuantizationConfig,
                                                      base_config=const_config_input16_per_tensor)

    qpreserving_const_config = const_config.clone_and_edit(enable_activation_quantization=False,
-                                                          quantization_preserving=True)
+                                                          quantization_preserving=True,
+                                                          default_weight_attr_config=const_config.default_weight_attr_config.clone_and_edit(
+                                                              weights_per_channel_threshold=False))
    qpreserving_const_config_options = tp.QuantizationConfigOptions([qpreserving_const_config])

    # Create a TargetPlatformModel and set its default quantization config.
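Setting `weights_per_channel_threshold=False` gives the constant a single per-tensor threshold instead of one per output channel. A simplified sketch of the difference for power-of-two thresholds (our illustration, not MCT's implementation):

```python
import numpy as np

w = np.random.randn(4, 8).astype(np.float32)  # e.g. 4 output channels

# Per-tensor: one power-of-two threshold for the whole tensor.
t_tensor = 2.0 ** np.ceil(np.log2(np.abs(w).max()))

# Per-channel: an independent threshold for each output channel.
t_channel = 2.0 ** np.ceil(np.log2(np.abs(w).max(axis=1)))  # shape (4,)
```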
@@ -19,7 +19,7 @@
from torch import add, sub, mul, div, divide, flatten, reshape, split, unsqueeze, dropout, sigmoid, tanh, \
chunk, unbind, topk, gather, equal, transpose, permute, argmax, squeeze, multiply, subtract
from torch.nn import Conv2d, Linear, ConvTranspose2d, MaxPool2d
-from torch.nn import Dropout, Flatten, Hardtanh, Identity
+from torch.nn import Dropout, Flatten, Hardtanh
from torch.nn import ReLU, ReLU6, PReLU, SiLU, Sigmoid, Tanh, Hardswish, LeakyReLU
from torch.nn.functional import relu, relu6, prelu, silu, hardtanh, hardswish, leaky_relu

@@ -87,7 +87,7 @@ def generate_pytorch_tpc(name: str, tp_model: tp.TargetPlatformModel):
                                                             squeeze,
                                                             permute,
                                                             transpose])
-   tp.OperationsSetToLayers(OPSET_DIMENSION_MANIPULATION_OPS_WITH_WEIGHTS, [gather])
+   tp.OperationsSetToLayers(OPSET_DIMENSION_MANIPULATION_OPS_WITH_WEIGHTS, [gather, torch.Tensor.expand])
    tp.OperationsSetToLayers(OPSET_MERGE_OPS,
                             [torch.stack, torch.cat, torch.concat, torch.concatenate])
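For context, `torch.Tensor.expand`, newly mapped to the dimension-manipulation-with-weights opset, broadcasts singleton dimensions and returns a view of the original tensor rather than a copy:

```python
import torch

const = torch.randn(1, 2, 32, 1)
# -1 keeps a dimension; singleton dims are broadcast without copying.
out = const.expand(5, -1, -1, 32)

print(out.shape)                            # torch.Size([5, 2, 32, 32])
print(out.data_ptr() == const.data_ptr())   # True: same underlying storage
```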

@@ -17,11 +17,13 @@
import torch.nn as nn
import numpy as np
import model_compression_toolkit as mct
+from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import Signedness
from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy, set_model
from tests.pytorch_tests.model_tests.base_pytorch_feature_test import BasePytorchFeatureNetworkTest
from tests.common_tests.helpers.tensors_compare import cosine_similarity
from tests.pytorch_tests.utils import get_layers_from_model_by_type
+from tests.common_tests.helpers.generate_test_tp_model import generate_test_attr_configs, DEFAULT_WEIGHT_ATTR_CONFIG
from model_compression_toolkit.target_platform_capabilities.constants import IMX500_TP_MODEL
from model_compression_toolkit.constants import PYTORCH
from mct_quantizers import PytorchQuantizationWrapper
@@ -196,3 +198,77 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
            for qlayer in get_layers_from_model_by_type(quantized_model, op):
                self.unit_test.assertTrue(isinstance(qlayer, PytorchQuantizationWrapper),
                                          msg=f"{op} should be quantized.")


+class ExpandConstQuantizationNet(nn.Module):
+    def __init__(self, batch_size):
+        super().__init__()
+        self.register_buffer('cat_const', to_torch_tensor(np.random.randint(-128, 127, size=(batch_size, 3, 32, 32)).astype(np.float32)))
+        self.register_parameter('expand_const',
+                                nn.Parameter(to_torch_tensor(np.random.randint(-128, 127, size=(1, 2, 32, 1)).astype(np.float32)),
+                                             requires_grad=False))
+
+    def forward(self, x):
+        expanded_const = self.expand_const.expand(x.shape[0], -1, -1, 32)
+        x = torch.cat([expanded_const, self.cat_const, x], dim=1)
+        return x
+
+
+class ConstQuantizationExpandTest(BasePytorchFeatureNetworkTest):
+
+    def __init__(self, unit_test):
+        super().__init__(unit_test=unit_test, input_shape=(16, 32, 32), val_batch_size=5)
+
+    def generate_inputs(self):
+        return [np.random.randint(-128, 127, size=in_shape).astype(np.float32) for in_shape in self.get_input_shapes()]
+
+    def get_tpc(self):
+        tp = mct.target_platform
+        attr_cfg = generate_test_attr_configs()
+        base_cfg = tp.OpQuantizationConfig(activation_quantization_method=tp.QuantizationMethod.POWER_OF_TWO,
+                                           enable_activation_quantization=True,
+                                           activation_n_bits=32,
+                                           supported_input_activation_n_bits=32,
+                                           default_weight_attr_config=attr_cfg[DEFAULT_WEIGHT_ATTR_CONFIG],
+                                           attr_weights_configs_mapping={},
+                                           quantization_preserving=False,
+                                           fixed_scale=1.0,
+                                           fixed_zero_point=0,
+                                           simd_size=32,
+                                           signedness=Signedness.AUTO)
+
+        default_configuration_options = tp.QuantizationConfigOptions([base_cfg])
+
+        const_config = base_cfg.clone_and_edit(enable_activation_quantization=False,
+                                               default_weight_attr_config=base_cfg.default_weight_attr_config.clone_and_edit(
+                                                   enable_weights_quantization=True, weights_per_channel_threshold=False,
+                                                   weights_quantization_method=tp.QuantizationMethod.POWER_OF_TWO))
+        const_configuration_options = tp.QuantizationConfigOptions([const_config])
+
+        tp_model = tp.TargetPlatformModel(default_configuration_options)
+        with tp_model:
+            tp.OperatorsSet("WeightQuant", const_configuration_options)
+
+        tpc = tp.TargetPlatformCapabilities(tp_model)
+        with tpc:
+            tp.OperationsSetToLayers("WeightQuant", [torch.Tensor.expand, torch.cat])
+
+        return tpc
+
+    def create_networks(self):
+        return ExpandConstQuantizationNet(self.val_batch_size)
+
+    def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
+        in_torch_tensor = to_torch_tensor(input_x[0])
+        set_model(float_model)
+        y = float_model(in_torch_tensor)
+        y_hat = quantized_model(in_torch_tensor)
+        self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!')
+        cs = cosine_similarity(torch_tensor_to_numpy(y), torch_tensor_to_numpy(y_hat))
+        self.unit_test.assertTrue(np.isclose(cs, 1, atol=1e-3), msg=f'fail cosine similarity check: {cs}')
+
+        # check quantization layers:
+        for op in [torch.cat, torch.Tensor.expand]:
+            for qlayer in get_layers_from_model_by_type(quantized_model, op):
+                self.unit_test.assertTrue(isinstance(qlayer, PytorchQuantizationWrapper),
+                                          msg=f"{op} should be quantized.")
@@ -105,7 +105,7 @@
    ConstRepresentationCodeTest
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
from tests.pytorch_tests.model_tests.feature_models.const_quantization_test import ConstQuantizationTest, \
-    AdvancedConstQuantizationTest, ConstQuantizationMultiInputTest
+    AdvancedConstQuantizationTest, ConstQuantizationMultiInputTest, ConstQuantizationExpandTest
from tests.pytorch_tests.model_tests.feature_models.remove_identity_test import RemoveIdentityTest
from tests.pytorch_tests.model_tests.feature_models.activation_16bit_test import Activation16BitTest, \
    Activation16BitMixedPrecisionTest
@@ -267,6 +267,7 @@ def test_const_quantization(self):

        AdvancedConstQuantizationTest(self).run_test()
        ConstQuantizationMultiInputTest(self).run_test()
+        ConstQuantizationExpandTest(self).run_test()

    def test_const_representation(self):
        for const_dtype in [np.float32, np.int64, np.int32]:
7 changes: 5 additions & 2 deletions tests/pytorch_tests/utils.py
@@ -38,9 +38,12 @@ def get_layers_from_model_by_type(model: torch.nn.Module,
    Returns:
        List of layers of type layer_type from the model.
    """
+   match_layer_type = lambda _layer: layer_type in [type(_layer), _layer]
+   matched_list = [layer[1] for layer in model.named_children() if match_layer_type(layer[1])]
    if include_wrapped_layers:
-       return [layer[1] for layer in model.named_children() if type(layer[1])==layer_type or (isinstance(layer[1], PytorchQuantizationWrapper) and type(layer[1].layer)==layer_type)]
-   return [layer[1] for layer in model.named_children() if type(layer[1])==layer_type]
+       matched_list.extend([layer[1] for layer in model.named_children()
+                            if (isinstance(layer[1], PytorchQuantizationWrapper) and match_layer_type(layer[1].layer))])
+   return matched_list


def count_model_prunable_params(model: torch.nn.Module) -> int:
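The core of the rewritten helper is the matcher, which accepts either a module class (e.g. `torch.nn.Conv2d`) or a bare callable (e.g. `torch.cat`, `torch.Tensor.expand`) as `layer_type`. A self-contained demonstration of that logic, adapted from the lambda above:

```python
import torch

# A layer matches if its class equals layer_type, or if the layer object
# itself *is* layer_type (the case for wrapped functional ops).
match_layer_type = lambda _layer, layer_type: layer_type in [type(_layer), _layer]

conv = torch.nn.Conv2d(3, 8, kernel_size=3)
assert match_layer_type(conv, torch.nn.Conv2d)       # matched by class
assert match_layer_type(torch.cat, torch.cat)        # matched by identity
assert not match_layer_type(conv, torch.nn.Linear)   # no match
```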
2 changes: 1 addition & 1 deletion tutorials/README.md
@@ -7,7 +7,7 @@ Access interactive Jupyter notebooks for hands-on learning.
## Getting started
Learn how to quickly quantize pre-trained models using MCT's post-training quantization technique for both Keras and PyTorch models.
- [Post training quantization with Keras](notebooks/imx500_notebooks/keras/example_keras_mobilenetv2_for_imx500.ipynb)
-- [Post training quantization with PyTorch](notebooks/mct_features_notebooks/pytorch/example_pytorch_ptq_mnist.ipynb)
+- [Post training quantization with PyTorch](notebooks/mct_features_notebooks/pytorch/example_pytorch_post_training_quantization.ipynb)

## MCT Features
This set of tutorials covers all the quantization tools provided by MCT.
@@ -428,6 +428,7 @@
"source": [
"from tutorials.mct_model_garden.models_pytorch.yolov8.yolov8 import seg_model_predict\n",
"from tutorials.mct_model_garden.evaluation_metrics.coco_evaluation import evaluate_yolov8_segmentation\n",
"from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device\n",
"device = get_working_device()\n",
"model = model.to(device)\n",
"evaluate_yolov8_segmentation(model, seg_model_predict, data_dir='coco', data_type='val2017', img_ids_limit=100, output_file='results.json', iou_thresh=0.7, conf=0.001, max_dets=300,mask_thresh=0.55)"
28 changes: 16 additions & 12 deletions tutorials/notebooks/mct_features_notebooks/README.md
@@ -72,15 +72,11 @@ These techniques are essential for further optimizing models and achieving super
<details id="pytorch-ptq">
<summary>Post-Training Quantization (PTQ)</summary>

-| Tutorial | Included Features |
-|---------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------|
-| [Training & Quantizing Model on MNIST](pytorch/example_pytorch_ptq_mnist.ipynb) | &#x2705; PTQ |
-| [Mixed-Precision MobileNetV2 on Cifar100](pytorch/example_pytorch_mobilenetv2_cifar100_mixed_precision.ipynb) | &#x2705; PTQ <br/> &#x2705; Mixed-Precision |
-| [SSDLite MobileNetV3 Quantization](pytorch/example_pytorch_ssdlite_mobilenetv3_object_detection.ipynb) | &#x2705; PTQ |
-
-</details>
-
-
+| Tutorial | Included Features |
+|-----------------------------------------------------------------------------------------------------------|---------------------------------------------|
+| [Basic Post-Training Quantization (PTQ)](pytorch/example_pytorch_post_training_quantization.ipynb) | &#x2705; PTQ |
+| [Mixed-Precision Post-Training Quantization](pytorch/example_pytorch_mixed_precision_ptq.ipynb) | &#x2705; PTQ <br/> &#x2705; Mixed-Precision |
+| [Advanced Gradient-Based Post-Training Quantization (GPTQ)](pytorch/example_pytorch_mobilenet_gptq.ipynb) | &#x2705; GPTQ |

</details>

@@ -97,9 +93,9 @@ These techniques are essential for further optimizing models and achieving super
<details id="pytorch-data-generation">
<summary>Data Generation</summary>

-| Tutorial | Included Features |
-|-------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------|
-| [Data-Free Quantization using Data Generation](pytorch/example_pytorch_data_generation.ipynb) | &#x2705; PTQ <br/> &#x2705; Data-Free Quantization <br/> &#x2705; Data Generation |
+| Tutorial | Included Features |
+|-------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------|
+| [Zero-Shot Quantization (ZSQ) using Data Generation](pytorch/example_pytorch_data_generation.ipynb) | &#x2705; PTQ <br/> &#x2705; ZSQ <br/> &#x2705; Data-Free Quantization <br/> &#x2705; Data Generation |

</details>

@@ -112,3 +108,11 @@ These techniques are essential for further optimizing models and achieving super
| [Exporter Usage](pytorch/example_pytorch_export.ipynb) | &#x2705; Export |

</details>
<details id="pytorch-xquant">
<summary>Quantization Troubleshooting</summary>

| Tutorial | Included Features |
|------------------------------------------------------------------------------------------------|-------------------|
| [Quantization Troubleshooting using the Xquant Feature](pytorch/example_pytorch_xquant.ipynb) | &#x2705; Debug |

</details>