Update docsrc (#1339)
* Update docsrc.

* Fix comments that were parsed wrongly in docs.

---------

Co-authored-by: reuvenp <[email protected]>
reuvenperetz and reuvenp authored Jan 21, 2025
1 parent ab507f9 commit 109fdbe
Showing 8 changed files with 77 additions and 20 deletions.
Binary file added docsrc/images/tpc_diagram.png
21 changes: 20 additions & 1 deletion docsrc/source/api/api_docs/classes/GradientPTQConfig.rst
@@ -8,7 +8,7 @@ GradientPTQConfig Class
=================================


**The following API can be used to create a GradientPTQConfig instance which can be used for post training quantization using knowledge distillation from a teacher (float Keras model) to a student (the quantized Keras model)**
**The following API can be used to create a GradientPTQConfig instance which can be used for post training quantization using knowledge distillation from a teacher (float model) to a student (the quantized model)**

.. autoclass:: model_compression_toolkit.gptq.GradientPTQConfig
:members:
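
A minimal usage sketch (assuming `model` and `representative_data_gen` are already defined): a GradientPTQConfig is typically obtained from a framework-specific helper and passed to the GPTQ facade:

>>> import model_compression_toolkit as mct
>>> gptq_config = mct.gptq.get_pytorch_gptq_config(n_epochs=5)
>>> quantized_model, quantization_info = mct.gptq.pytorch_gradient_post_training_quantization(model, representative_data_gen, gptq_config=gptq_config)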
@@ -30,3 +30,22 @@ RoundingType

.. autoclass:: model_compression_toolkit.gptq.RoundingType
:members:


=====================================
GradualActivationQuantizationConfig
=====================================

**The following API can be used to configure gradual activation quantization when using GPTQ.**

.. autoclass:: model_compression_toolkit.gptq.GradualActivationQuantizationConfig
:members:


=====================================
QFractionLinearAnnealingConfig
=====================================

.. autoclass:: model_compression_toolkit.gptq.QFractionLinearAnnealingConfig
:members:
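
An illustrative sketch (field names assumed from the two classes above; values are arbitrary) of plugging a linear annealing schedule into the gradual activation quantization config:

>>> from model_compression_toolkit.gptq import GradualActivationQuantizationConfig, QFractionLinearAnnealingConfig
>>> annealing_cfg = QFractionLinearAnnealingConfig(initial_q_fraction=0.0, target_q_fraction=1.0, start_step=0, end_step=1000)
>>> gradual_cfg = GradualActivationQuantizationConfig(q_fraction_scheduler_policy=annealing_cfg)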

@@ -27,7 +27,7 @@ Models for IMX500, TFLite and qnnpack can be observed `here <https://github.com/
The object MCT expects is called TargetPlatformCapabilities (or TPC for short).
This diagram demonstrates the main components:

.. image:: ../../../../images/tpc.jpg
.. image:: ../../../../images/tpc_diagram.png
:scale: 80%

Now, we will go over the different components in detail. As a short usage sketch, a predefined TPC can be fetched by framework and platform name (the 'tensorflow'/'imx500' pair below is illustrative) and passed to an MCT quantization call:
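
>>> import model_compression_toolkit as mct
>>> tpc = mct.get_target_platform_capabilities('tensorflow', 'imx500')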
46 changes: 46 additions & 0 deletions docsrc/source/diagrams/tpc.mermaid
@@ -0,0 +1,46 @@
flowchart TB
subgraph TPC["Target Platform Capabilities"]
subgraph QCO["Quantization Config Options"]
Default["Default QCO (8-bit)"]
Mixed["Mixed Precision QCO (8, 4, 2 bits)"]
NoQuant["No Quantization QCO"]

subgraph OQC["Op Quantization Config"]
OQC1["8-bit"]
OQC2["4-bit"]
OQC3["2-bit"]
AQC["AttributeQuantizationConfig:<br>Kernel, Bias"]

OQC1 -->|contains| AQC
OQC2 -->|contains| AQC
OQC3 -->|contains| AQC
end

Mixed -->|contains| OQC1
Mixed -->|contains| OQC2
Mixed -->|contains| OQC3
end

subgraph OPS["Operators Sets"]
Conv["Conv, Conv Transpose,<br>Depthwise Conv"]
Act["ReLU, ReLU6,<br>Leaky ReLU, etc."]
NoQuantOps["Dropout, Flatten,<br>Reshape, etc."]
end

subgraph FP["Fusing Patterns"]
FP1["Conv + Activation"]
end

Mixed -->|attached to| Conv
Default -->|attached to| Act
NoQuant -->|attached to| NoQuantOps

FP1 -.-> Conv
FP1 -.-> Act
end

style TPC fill:#e6f3ff,stroke:#333
style QCO fill:#e6ffe6,stroke:#333
style OQC fill:#fff9e6,stroke:#333
style OPS fill:#ffe6e6,stroke:#333
style FP fill:#ffe6f0,stroke:#333
@@ -78,9 +78,6 @@ class QuantizationConfig:
>>> qc = mct.core.QuantizationConfig(activation_error_method=mct.core.QuantizationErrorMethod.NOCLIPPING, weights_error_method=mct.core.QuantizationErrorMethod.MSE, relu_bound_to_power_of_2=True, weights_bias_correction=True)
The QuantizationConfig instance can then be used in the quantization workflow,
such as with Keras in the function: :func:~model_compression_toolkit.ptq.keras_post_training_quantization`.
"""

activation_error_method: QuantizationErrorMethod = QuantizationErrorMethod.MSE
@@ -77,6 +77,7 @@ def get_pytorch_gptq_config(n_epochs: int,
regularization_factor (float): A floating point number that defines the regularization factor.
hessian_batch_size (int): Batch size for Hessian computation in Hessian-based weights GPTQ.
use_hessian_sample_attention (bool): whether to use Sample-Layer Attention score for weighted loss.
gradual_activation_quantization (bool, GradualActivationQuantizationConfig): If False, GradualActivationQuantization is disabled. If True, GradualActivationQuantization is enabled with the default settings. GradualActivationQuantizationConfig object can be passed to use non-default settings.
returns:
a GradientPTQConfig object to use when fine-tuning the quantized model using gptq.
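
A brief sketch of the new argument (values here are arbitrary): it accepts either a boolean or a config object:

>>> import model_compression_toolkit as mct
>>> gptq_config = mct.gptq.get_pytorch_gptq_config(n_epochs=5, gradual_activation_quantization=True)
>>> # Or, assuming non-default settings are needed, pass a config object instead:
>>> gptq_config = mct.gptq.get_pytorch_gptq_config(n_epochs=5, gradual_activation_quantization=mct.gptq.GradualActivationQuantizationConfig())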
@@ -16,12 +16,9 @@

from model_compression_toolkit.gptq import GradientPTQConfig
from model_compression_toolkit.core import common
from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.node_to_quantizer import \
get_activation_inferable_quantizer_kwargs
from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
BasePytorchGPTQTrainableQuantizer
from mct_quantizers import QuantizationTarget
from mct_quantizers.common.get_quantizers import get_inferable_quantizer_class
from mct_quantizers.pytorch.quantizers import BasePyTorchInferableQuantizer

from model_compression_toolkit.logger import Logger
21 changes: 9 additions & 12 deletions model_compression_toolkit/qat/pytorch/quantization_facade.py
@@ -110,7 +110,6 @@ def pytorch_quantization_aware_training_init_experimental(in_model: Module,
User information that may be needed to handle the quantized model.
Examples:
Import MCT:
>>> import model_compression_toolkit as mct
@@ -120,21 +119,19 @@
>>> from torchvision.models import mobilenet_v2
>>> model = mobilenet_v2(pretrained=True)
Create a random dataset generator, for required number of calibration iterations (num_calibration_batches). In this example, a random dataset of 10 batches each containing 4 images is used:
>>> import numpy as np
>>> num_calibration_batches = 10
>>> def repr_datagen():
>>>     for _ in range(num_calibration_batches):
>>>         yield [np.random.random((4, 3, 224, 224))]
Create a MCT core config, containing the quantization configuration:
>>> config = mct.core.CoreConfig()
Pass the model, the representative dataset generator, the configuration and the target resource utilization to get a quantized model. Now the model contains quantizer wrappers for fine-tuning the weights:
>>> quantized_model, quantization_info = mct.qat.pytorch_quantization_aware_training_init_experimental(model, repr_datagen, core_config=config)
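
Once fine-tuning is done, the quantizer wrappers are typically removed with the matching finalize call (a sketch, assuming the experimental finalize API in the same module):

>>> quantized_model = mct.qat.pytorch_quantization_aware_training_finalize_experimental(quantized_model)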
@@ -149,8 +146,8 @@
if core_config.is_mixed_precision_enabled:
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
Logger.critical("Given quantization config to mixed-precision facade is not of type "
"MixedPrecisionQuantizationConfig. Please use pytorch_post_training_quantization API,"
"or pass a valid mixed precision configuration.")
"MixedPrecisionQuantizationConfig. Please use pytorch_post_training_quantization API,"
"or pass a valid mixed precision configuration.")

tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO)
fw_impl = PytorchImplementation()
