align torch and keras gptq tutorials
irenaby committed Jul 3, 2024
1 parent ed0a314 commit 2bb3359
Showing 2 changed files with 67 additions and 63 deletions.
@@ -14,11 +14,13 @@
"## Overview\n",
"\n",
"This tutorial demonstrates a pre-trained model quantization using the **Model Compression Toolkit (MCT)** with **Gradient-based PTQ (GPTQ)**. \n",
"\n",
"As we will see, GPTQ stands as an optimization procedure that markedly enhances the performance of models undergoing post-training quantization.\n",
"GPTQ stands as an optimization procedure that markedly enhances the performance of models undergoing post-training quantization.\n",
"This is achieved through an optimization process applied post-quantization, specifically adjusting the rounding of quantized weights.\n",
"GPTQ is especially effective after mixed precision quantization. \n",
"\n",
"This tutorial's scope is limited to demonstrating GPTQ usage. In this example, we quantize the model and evaluate the accuracy before and after quantization.\n",
"\n",
"In this example, we quantize the model and evaluate the accuracy before and after quantization.\n",
"For an example of a full quantization flow utilizing GPTQ see [full quantization tutorial](https://github.com/sony/model_optimization/blob/main/tutorials/notebooks/imx500_notebooks/keras/keras_yolov8n_for_imx500.ipynb)\n",
"\n",
"## Summary\n",
"\n",
@@ -44,10 +46,24 @@
},
"outputs": [],
"source": [
"!pip install -q tensorflow==2.14\n",
"!pip install -q mct-nightly"
"!pip install -q tensorflow==2.15"
]
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"import importlib.util\n",
"\n",
"if not importlib.util.find_spec('model_compression_toolkit'):\n",
" !pip install -q model_compression_toolkit"
],
"metadata": {
"collapsed": false
},
"id": "2c13aff20d208c51"
},
{
"cell_type": "code",
"execution_count": null,
@@ -59,8 +75,8 @@
"source": [
"import tensorflow as tf\n",
"import keras\n",
"import model_compression_toolkit as mct\n",
"import os"
"\n",
"import model_compression_toolkit as mct\n"
]
},
{
@@ -82,12 +98,15 @@
"execution_count": null,
"outputs": [],
"source": [
"import os\n",
" \n",
"if not os.path.isdir('imagenet'):\n",
" !mkdir imagenet\n",
" !wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_devkit_t12.tar.gz\n",
" !mv ILSVRC2012_devkit_t12.tar.gz imagenet/\n",
" !wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar\n",
" !mv ILSVRC2012_img_val.tar imagenet/"
" !wget -P imagenet https://image-net.org/data/ILSVRC/2012/ILSVRC2012_devkit_t12.tar.gz\n",
" !wget -P imagenet https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar\n",
" \n",
" !cd imagenet && tar -xzf ILSVRC2012_devkit_t12.tar.gz && \\\n",
" mkdir ILSVRC2012_img_val && tar -xf ILSVRC2012_img_val.tar -C ILSVRC2012_img_val"
],
"metadata": {
"collapsed": false
@@ -97,21 +116,40 @@
{
"cell_type": "markdown",
"source": [
"Extract ImageNet validation dataset using torchvision \"datasets\" module"
"Rearrange the extracted data into folders per label "
],
"metadata": {
"collapsed": false
},
"id": "36529e754226cb00"
"id": "c6ac5021ed7a2ac3"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"import torchvision\n",
"if not os.path.isdir('imagenet/val'):\n",
" torchvision.datasets.ImageNet(root='./imagenet', split='val')"
"from pathlib import Path\n",
"import shutil\n",
"\n",
"root = Path('./imagenet')\n",
"imgs_dir = root / 'ILSVRC2012_img_val'\n",
"target_dir = root /'val'\n",
"\n",
"def extract_labels():\n",
" !pip install -q scipy\n",
" import scipy\n",
" mat = scipy.io.loadmat(root / 'ILSVRC2012_devkit_t12/data/meta.mat', squeeze_me=True)\n",
" cls_to_nid = {s[0]: s[1] for i, s in enumerate(mat['synsets']) if s[4] == 0} \n",
" with open(root / 'ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt', 'r') as f:\n",
" return [cls_to_nid[int(cls)] for cls in f.readlines()]\n",
"\n",
"if not target_dir.exists():\n",
" labels = extract_labels()\n",
" for lbl in set(labels):\n",
" os.makedirs(target_dir / lbl)\n",
" \n",
" for img_file, lbl in zip(sorted(os.listdir(imgs_dir)), labels):\n",
" shutil.move(imgs_dir / img_file, target_dir / lbl)\n"
],
"metadata": {
"collapsed": false
@@ -125,16 +163,11 @@
"id": "028112db-3143-4fcb-96ae-e639e6476c31"
},
"source": [
"Define the required preprocessing method for the pretrained model,\n",
"and create a generator for the representative dataset, which is required for post-training quantization.\n",
"Create a generator for the representative dataset. The representative dataset is used for collecting statistics on the inference outputs of all layers in the model. The batch size and the number of iterations determine the size of the representative dataset (n_iter x batch_size). In this example we set batch_size = 50 and n_iter = 10, resulting in a total of 500 representative images.\n",
"\n",
"The representative dataset is used for collecting statistics on the inference outputs of all layers in the model.\n",
" \n",
"In order to decide on the size of the representative dataset, we configure the batch size and the number of calibration iterations.\n",
"This gives us the total number of samples that will be used during PTQ (batch_size x n_iter).\n",
"In this example we set `batch_size = 50` and `n_iter = 10`, resulting in a total of 500 representative images.\n",
"GPTQ is a gradient-based optimization process, which requires representative dataset to perform inference and compute gradients. It is possible to define a separate representative dataset in addition to the one used for the PTQ statistics collection. A possible reason to do so is, for example, to use a larger dataset in the optimization process.\n",
"\n",
"Please ensure that the dataset path has been set correctly."
"In this tutorial we use the same representative dataset for both statistics collection and GPTQ. A complete pass through the representative dataset generator constitutes an epoch (batch_size x n_iter samples). In this example we use the same dataloader iterator for all epochs, i.e. different images are used in different epochs."
]
},
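The notebook's own generator cell appears further down but is truncated in this diff. For reference, a sketch along the lines described above could look as follows; the image size and the exact per-call format expected by MCT are assumptions to verify against the MCT documentation:

```python
import tensorflow as tf

def get_representative_dataset(n_iter=10, batch_size=50):
    dataset = tf.keras.utils.image_dataset_from_directory(
        directory='./imagenet/val',
        batch_size=batch_size,
        image_size=(224, 224))  # assumed MobileNetV2 input resolution
    dataset = dataset.map(
        lambda imgs, labels: (tf.keras.applications.mobilenet_v2.preprocess_input(imgs), labels))
    ds_iter = iter(dataset)  # one shared iterator, so each epoch sees different images

    def representative_dataset():
        for _ in range(n_iter):
            imgs, _ = next(ds_iter)
            yield [imgs.numpy()]  # assumed format: a list with one array per model input

    return representative_dataset
```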
{
Expand All @@ -147,14 +180,6 @@
"outputs": [],
"source": [
"def imagenet_preprocess_input(images, labels):\n",
" \"\"\"\n",
" Use the keras applications preprocess function.\n",
" Args:\n",
" images: input image batch.\n",
" labels: input label batch.\n",
" Returns:\n",
" preprocessed images & labels\n",
" \"\"\"\n",
" return tf.keras.applications.mobilenet_v2.preprocess_input(images), labels"
]
},
@@ -168,12 +193,6 @@
"outputs": [],
"source": [
"def get_representative_dataset(n_iter=10, batch_size=50):\n",
" \"\"\"\n",
" Download the ImageNet validation set locally and create the representative dataset generator.\n",
" Returns:\n",
" representative dataset generator for calibration\n",
" \"\"\"\n",
" print('loading dataset, this may take a few minutes ...')\n",
" dataset = tf.keras.utils.image_dataset_from_directory(\n",
" directory='./imagenet/val',\n",
" batch_size=batch_size,\n",
@@ -216,6 +235,7 @@
"outputs": [],
"source": [
"from keras.applications.mobilenet_v2 import MobileNetV2\n",
"\n",
"float_model = MobileNetV2()"
]
},
@@ -226,32 +246,23 @@
"id": "8a8b486a-ca39-45d9-8699-f7116b0414c9"
},
"source": [
"Next, we create a **GPTQ configuration** with possible GPTQ optimization options (such as the number of epochs for the optimization process). \n",
"MCT will quantize the model and start the GPTQ process to optimize the model's parameters and quantization parameters.\n",
"\n",
"Note that GPTQ is a gradient-based optimization process, which requires representative dataset to perform inference and compute gradients.\n",
"It is possible to define a separate representative dataset than the one used for the PTQ statistics collection.\n",
"A possible reason to do so is, for example, to use a larger dataset in the optimization process.\n",
"In this tutorial we do not create a separate representative dataset, thus, MCT will automatically use the original representative dataset that was passed to the procedure.\n",
"Next, we create a GPTQ configuration with possible GPTQ optimization options (such as the number of epochs for the optimization process). MCT will quantize the model and start the GPTQ process to optimize the model’s parameters and quantization parameters.\n",
"\n",
"In addition, we need to define a `TargetPlatformCapability` object, representing the HW specifications on which we wish to eventually deploy our quantized model."
"In addition, we need to define a TargetPlatformCapability object, representing the HW specifications on which we wish to eventually deploy our quantized model."
]
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT\n",
"# Create a GPTQ quantization configuration and set the number of training iterations. \n",
"# 50 epochs are sufficient for this tutorial. For GPTQ run after mixed precision quantization a higher number of iterations\n",
"# will be required.\n",
"gptq_config = mct.gptq.get_keras_gptq_config(n_epochs=50)\n",
"\n",
"# Specify the target platform capability (TPC)\n",
"tpc = mct.get_target_platform_capabilities(\"tensorflow\", 'imx500', target_platform_version='v1')\n",
"\n",
"# Create a GPTQ quantization configuration and set the number of training iterations. \n",
"# For the sake of running faster, the hessian based weights are disabled in this tutorial.\n",
"gptq_config = mct.gptq.get_keras_gptq_config(n_epochs=5000,\n",
" use_hessian_based_weights=False,\n",
" regularization_factor=REG_DEFAULT)"
"tpc = mct.get_target_platform_capabilities(\"tensorflow\", 'imx500', target_platform_version='v1')\n"
],
"metadata": {
"collapsed": false
@@ -318,11 +329,6 @@
"outputs": [],
"source": [
"def get_validation_dataset():\n",
" \"\"\"\n",
" Generate validation dataset\n",
" Returns:\n",
" the validation dataset\n",
" \"\"\"\n",
" dataset = tf.keras.utils.image_dataset_from_directory(\n",
" directory='./imagenet/val',\n",
" batch_size=50,\n",
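The evaluation cells are likewise cut off by the diff. A minimal sketch of the final comparison, assuming the integer labels produced by `image_dataset_from_directory` align with the model's 1000 ImageNet class indices:

```python
# Hedged sketch of the evaluation step (not shown in the hunks above).
val_dataset = get_validation_dataset()

float_model.compile(loss='sparse_categorical_crossentropy', metrics=['accuracy'])
quantized_model.compile(loss='sparse_categorical_crossentropy', metrics=['accuracy'])

_, float_acc = float_model.evaluate(val_dataset)
_, quant_acc = quantized_model.evaluate(val_dataset)
print(f'float accuracy: {float_acc:.4f}, GPTQ-quantized accuracy: {quant_acc:.4f}')
```

The remaining hunk below belongs to the second changed file, which receives the same ImageNet download cleanup.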
@@ -129,10 +129,8 @@
"\n",
"if not os.path.isdir('imagenet'):\n",
" !mkdir imagenet\n",
" !wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_devkit_t12.tar.gz\n",
" !mv ILSVRC2012_devkit_t12.tar.gz imagenet/\n",
" !wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar\n",
" !mv ILSVRC2012_img_val.tar imagenet/"
" !wget -P imagenet https://image-net.org/data/ILSVRC/2012/ILSVRC2012_devkit_t12.tar.gz\n",
" !wget -P imagenet https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar"
],
"metadata": {
"collapsed": false
