|
| 1 | +{
| 2 | + "cells": [
| 3 | +  {
| 4 | +   "cell_type": "markdown",
| 5 | +   "metadata": {
| 6 | +    "colab_type": "text"
| 7 | +   },
| 8 | +   "source": [
| 9 | + "# GPTQ Quantization in Keras\n", |
| 10 | + "\n", |
| 11 | + "**Author:** [Jyotinder Singh](https://x.com/Jyotinder_Singh)<br>\n", |
| 12 | + "**Date created:** 2025/10/16<br>\n", |
| 13 | + "**Last modified:** 2025/10/16<br>\n", |
| 14 | + "**Description:** How to run weight-only GPTQ quantization for Keras & KerasHub models." |
| 15 | + ] |
| 16 | + }, |
| 17 | + { |
| 18 | + "cell_type": "markdown", |
| 19 | + "metadata": { |
| 20 | + "colab_type": "text" |
| 21 | + }, |
| 22 | + "source": [ |
| 23 | + "## What is GPTQ?\n", |
| 24 | + "\n", |
| 25 | +    "GPTQ (\"Generative Pre-trained Transformer Quantization\") is a post-training, weight-only\n",
| 26 | + "quantization method that uses a second-order approximation of the loss (via a\n", |
| 27 | + "Hessian estimate) to minimize the error introduced when compressing weights to\n", |
| 28 | + "lower precision, typically 4-bit integers.\n", |
| 29 | + "\n", |
| 30 | + "Unlike standard post-training techniques, GPTQ keeps activations in\n", |
| 31 | +    "higher precision and quantizes only the weights. This often preserves model\n",
| 32 | + "quality in low bit-width settings while still providing large storage and\n", |
| 33 | + "memory savings.\n", |
| 34 | + "\n", |
| 35 | + "Keras supports GPTQ quantization for KerasHub models via the\n", |
| 36 | + "`keras.quantizers.GPTQConfig` class." |
| 37 | + ] |
| 38 | + }, |
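|   | +  {
|   | +   "cell_type": "markdown",
|   | +   "metadata": {
|   | +    "colab_type": "text"
|   | +   },
|   | +   "source": [
|   | +    "GPTQ works layer by layer: for each layer with weight matrix $W$ and\n",
|   | +    "calibration activations $X$, it seeks quantized weights $\\widehat{W}$ that\n",
|   | +    "minimize the layer-wise reconstruction error (a sketch of the objective from\n",
|   | +    "the GPTQ paper, not of Keras's internal implementation):\n",
|   | +    "\n",
|   | +    "$$\\arg\\min_{\\widehat{W}} \\lVert W X - \\widehat{W} X \\rVert_2^2$$\n",
|   | +    "\n",
|   | +    "The Hessian of this objective, $H = 2 X X^\\top$, depends only on the\n",
|   | +    "activations, which is why a calibration dataset is needed to estimate it."
|   | +   ]
|   | +  },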
| 39 | + { |
| 40 | + "cell_type": "markdown", |
| 41 | + "metadata": { |
| 42 | + "colab_type": "text" |
| 43 | + }, |
| 44 | + "source": [ |
| 45 | + "## Load a KerasHub model\n", |
| 46 | + "\n", |
| 47 | + "This guide uses the `Gemma3CausalLM` model from KerasHub, a small (1B\n", |
| 48 | + "parameter) causal language model." |
| 49 | + ] |
| 50 | + }, |
| 51 | + { |
| 52 | + "cell_type": "code", |
| 53 | + "execution_count": 0, |
| 54 | + "metadata": { |
| 55 | + "colab_type": "code" |
| 56 | + }, |
| 57 | + "outputs": [], |
| 58 | + "source": [ |
| 59 | + "import keras\n", |
| 60 | + "from keras_hub.models import Gemma3CausalLM\n", |
| 61 | + "from datasets import load_dataset\n", |
| 62 | + "\n", |
| 63 | + "\n", |
| 64 | + "prompt = \"Keras is a\"\n", |
| 65 | + "\n", |
| 66 | + "model = Gemma3CausalLM.from_preset(\"gemma3_1b\")\n", |
| 67 | + "\n", |
| 68 | + "outputs = model.generate(prompt, max_length=30)\n", |
| 69 | + "print(outputs)" |
| 70 | + ] |
| 71 | + }, |
| 72 | + { |
| 73 | + "cell_type": "markdown", |
| 74 | + "metadata": { |
| 75 | + "colab_type": "text" |
| 76 | + }, |
| 77 | + "source": [ |
| 78 | + "## Configure & run GPTQ quantization\n", |
| 79 | + "\n", |
| 80 | + "You can configure GPTQ quantization via the `keras.quantizers.GPTQConfig` class.\n", |
| 81 | + "\n", |
| 82 | + "The GPTQ configuration requires a calibration dataset and tokenizer, which it\n", |
| 83 | + "uses to estimate the Hessian and quantization error. Here, we use a small slice\n", |
| 84 | + "of the WikiText-2 dataset for calibration.\n", |
| 85 | + "\n", |
| 86 | + "You can tune several parameters to trade off speed, memory, and accuracy. The\n", |
| 87 | + "most important of these are `weight_bits` (the bit-width to quantize weights to)\n", |
| 88 | + "and `group_size` (the number of weights to quantize together). The group size\n", |
| 89 | + "controls the granularity of quantization: smaller groups typically yield better\n", |
| 90 | + "accuracy but are slower to quantize and may use more memory. A good starting\n", |
| 91 | + "point is `group_size=128` for 4-bit quantization (`weight_bits=4`).\n", |
| 92 | + "\n", |
| 93 | + "In this example, we first prepare a tiny calibration set, and then run GPTQ on\n", |
| 94 | + "the model using the `.quantize(...)` API." |
| 95 | + ] |
| 96 | + }, |
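|   | +  {
|   | +   "cell_type": "markdown",
|   | +   "metadata": {
|   | +    "colab_type": "text"
|   | +   },
|   | +   "source": [
|   | +    "To build intuition for `group_size`, the following NumPy sketch applies plain\n",
|   | +    "group-wise max-abs quantization (an illustration of granularity only, not the\n",
|   | +    "GPTQ algorithm or Keras's implementation). Smaller groups get tighter scales\n",
|   | +    "and therefore lower reconstruction error."
|   | +   ]
|   | +  },
|   | +  {
|   | +   "cell_type": "code",
|   | +   "execution_count": 0,
|   | +   "metadata": {
|   | +    "colab_type": "code"
|   | +   },
|   | +   "outputs": [],
|   | +   "source": [
|   | +    "import numpy as np\n",
|   | +    "\n",
|   | +    "\n",
|   | +    "def groupwise_quantize(w, bits=4, group_size=8):\n",
|   | +    "    # Each group of `group_size` weights shares one max-abs scale.\n",
|   | +    "    qmax = 2 ** (bits - 1) - 1  # 7 for 4-bit signed\n",
|   | +    "    groups = w.reshape(-1, group_size)\n",
|   | +    "    scales = np.maximum(np.abs(groups).max(axis=1, keepdims=True), 1e-12) / qmax\n",
|   | +    "    q = np.round(groups / scales)  # integer codes in [-qmax, qmax]\n",
|   | +    "    return (q * scales).reshape(w.shape)  # dequantized weights\n",
|   | +    "\n",
|   | +    "\n",
|   | +    "rng = np.random.default_rng(0)\n",
|   | +    "w = rng.normal(size=(4096,)).astype(np.float32)\n",
|   | +    "\n",
|   | +    "for gs in (32, 128, 512):\n",
|   | +    "    err = np.abs(groupwise_quantize(w, group_size=gs) - w).mean()\n",
|   | +    "    print(f\"group_size={gs}: mean abs error {err:.4f}\")"
|   | +   ]
|   | +  },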
| 97 | + { |
| 98 | + "cell_type": "code", |
| 99 | + "execution_count": 0, |
| 100 | + "metadata": { |
| 101 | + "colab_type": "code" |
| 102 | + }, |
| 103 | + "outputs": [], |
| 104 | + "source": [ |
| 105 | + "# Calibration slice (use a larger/representative set in practice)\n", |
| 106 | + "texts = load_dataset(\"wikitext\", \"wikitext-2-raw-v1\", split=\"train[:1%]\")[\"text\"]\n", |
| 107 | + "\n", |
| 108 | + "calibration_dataset = [\n", |
| 109 | + " s + \".\" for text in texts for s in map(str.strip, text.split(\".\")) if s\n", |
| 110 | + "]\n", |
| 111 | + "\n", |
| 112 | + "gptq_config = keras.quantizers.GPTQConfig(\n", |
| 113 | + " dataset=calibration_dataset,\n", |
| 114 | + " tokenizer=model.preprocessor.tokenizer,\n", |
| 115 | + " weight_bits=4,\n", |
| 116 | + " group_size=128,\n", |
| 117 | + " num_samples=256,\n", |
| 118 | + " sequence_length=256,\n", |
| 119 | + " hessian_damping=0.01,\n", |
| 120 | + " symmetric=False,\n", |
| 121 | + " activation_order=False,\n", |
| 122 | + ")\n", |
| 123 | + "\n", |
| 124 | + "model.quantize(\"gptq\", config=gptq_config)\n", |
| 125 | + "\n", |
| 126 | + "outputs = model.generate(prompt, max_length=30)\n", |
| 127 | + "print(outputs)" |
| 128 | + ] |
| 129 | + }, |
| 130 | + { |
| 131 | + "cell_type": "markdown", |
| 132 | + "metadata": { |
| 133 | + "colab_type": "text" |
| 134 | + }, |
| 135 | + "source": [ |
| 136 | + "## Model Export\n", |
| 137 | + "\n", |
| 138 | +    "The GPTQ-quantized model can be saved to a preset and reloaded elsewhere, just\n",
| 139 | + "like any other KerasHub model." |
| 140 | + ] |
| 141 | + }, |
| 142 | + { |
| 143 | + "cell_type": "code", |
| 144 | + "execution_count": 0, |
| 145 | + "metadata": { |
| 146 | + "colab_type": "code" |
| 147 | + }, |
| 148 | + "outputs": [], |
| 149 | + "source": [ |
| 150 | + "model.save_to_preset(\"gemma3_gptq_w4gs128_preset\")\n", |
| 151 | + "model_from_preset = Gemma3CausalLM.from_preset(\"gemma3_gptq_w4gs128_preset\")\n", |
| 152 | + "output = model_from_preset.generate(prompt, max_length=30)\n", |
| 153 | + "print(output)" |
| 154 | + ] |
| 155 | + }, |
| 156 | + { |
| 157 | + "cell_type": "markdown", |
| 158 | + "metadata": { |
| 159 | + "colab_type": "text" |
| 160 | + }, |
| 161 | + "source": [ |
| 162 | + "## Performance & Benchmarking\n", |
| 163 | + "\n", |
| 164 | +    "Micro-benchmarks collected on a single NVIDIA GeForce RTX 4070 Ti Super (16 GB).\n",
| 165 | + "Baselines are FP32.\n", |
| 166 | + "\n", |
| 167 | + "Dataset: WikiText-2.\n", |
| 168 | + "\n", |
| 169 | + "\n", |
| 170 | + "| Model (preset) | Perplexity Increase % (↓ better) | Disk Storage Reduction Δ % (↓ better) | VRAM Reduction Δ % (↓ better) | First-token Latency Δ % (↓ better) | Throughput Δ % (↑ better) |\n", |
| 171 | + "| ------------------------------------------- | -------------------------------: | ------------------------------------: | ----------------------------: | ---------------------------------: | ------------------------: |\n", |
| 172 | + "| gpt2_causal_lm (gpt2_base_en_cnn_dailymail) | 1.0% | -50.1% ↓ | -41.1% ↓ | +0.7% ↑ | +20.1% ↑ |\n", |
| 173 | + "| opt_causal_lm (opt_125m_en) | 10.0% | -49.8% ↓ | -47.0% ↓ | +6.7% ↑ | -15.7% ↓ |\n", |
| 174 | + "| bloom_causal_lm (bloom_1.1b_multi) | 7.0% | -47.0% ↓ | -54.0% ↓ | +1.8% ↑ | -15.7% ↓ |\n", |
| 175 | + "| gemma3_causal_lm (gemma3_1b) | 3.0% | -51.5% ↓ | -51.8% ↓ | +39.5% ↑ | +5.7% ↑ |\n", |
| 176 | + "\n", |
| 177 | + "\n", |
| 178 | + "Detailed benchmarking numbers and scripts are available\n", |
| 179 | + "[here](https://github.com/keras-team/keras/pull/21641).\n", |
| 180 | + "\n", |
| 181 | + "### Analysis\n", |
| 182 | + "\n", |
| 183 | +    "There is a notable reduction in disk space and VRAM usage across all models, with\n",
| 184 | + "disk space savings around 50% and VRAM savings ranging from 41% to 54%. The\n", |
| 185 | + "reported disk savings understate the true weight compression because presets\n", |
| 186 | + "also include non-weight assets.\n", |
| 187 | + "\n", |
| 188 | + "Perplexity increases only marginally, indicating model quality is largely\n", |
| 189 | + "preserved after quantization." |
| 190 | + ] |
| 191 | + }, |
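|   | +  {
|   | +   "cell_type": "markdown",
|   | +   "metadata": {
|   | +    "colab_type": "text"
|   | +   },
|   | +   "source": [
|   | +    "Perplexity is the exponential of the mean per-token negative log-likelihood,\n",
|   | +    "so a small shift in average loss maps directly to the percentage increases\n",
|   | +    "reported above. A minimal sketch (the per-token losses are made-up numbers\n",
|   | +    "for illustration, not measured values):"
|   | +   ]
|   | +  },
|   | +  {
|   | +   "cell_type": "code",
|   | +   "execution_count": 0,
|   | +   "metadata": {
|   | +    "colab_type": "code"
|   | +   },
|   | +   "outputs": [],
|   | +   "source": [
|   | +    "import math\n",
|   | +    "\n",
|   | +    "\n",
|   | +    "def perplexity(per_token_nll):\n",
|   | +    "    # exp of the mean negative log-likelihood (in nats)\n",
|   | +    "    return math.exp(sum(per_token_nll) / len(per_token_nll))\n",
|   | +    "\n",
|   | +    "\n",
|   | +    "base = perplexity([2.0, 2.2, 1.9])  # hypothetical FP32 losses\n",
|   | +    "quantized = perplexity([2.05, 2.25, 1.95])  # hypothetical GPTQ losses\n",
|   | +    "print(f\"perplexity increase: {100 * (quantized / base - 1):.2f}%\")"
|   | +   ]
|   | +  },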
| 192 | + { |
| 193 | + "cell_type": "markdown", |
| 194 | + "metadata": { |
| 195 | + "colab_type": "text" |
| 196 | + }, |
| 197 | + "source": [ |
| 198 | + "## Practical tips\n", |
| 199 | + "\n", |
| 200 | + "* GPTQ is weight-only; training after quantization is not supported.\n", |
| 201 | + "* Always use the model's own tokenizer for calibration.\n", |
| 202 | + "* Use a representative calibration set; small slices are only for demos.\n", |
| 203 | +    "* Start with `weight_bits=4` and `group_size=128`; tune per model/task.\n",
| 204 | + "* Save to `.keras` or to a preset for reuse elsewhere." |
| 205 | + ] |
| 206 | + } |
| 207 | + ], |
| 208 | + "metadata": { |
| 209 | + "accelerator": "GPU", |
| 210 | + "colab": { |
| 211 | + "collapsed_sections": [], |
| 212 | + "name": "gptq_quantization_in_keras", |
| 213 | + "private_outputs": false, |
| 214 | + "provenance": [], |
| 215 | + "toc_visible": true |
| 216 | + }, |
| 217 | + "kernelspec": { |
| 218 | + "display_name": "Python 3", |
| 219 | + "language": "python", |
| 220 | + "name": "python3" |
| 221 | + }, |
| 222 | + "language_info": { |
| 223 | + "codemirror_mode": { |
| 224 | + "name": "ipython", |
| 225 | + "version": 3 |
| 226 | + }, |
| 227 | + "file_extension": ".py", |
| 228 | + "mimetype": "text/x-python", |
| 229 | + "name": "python", |
| 230 | + "nbconvert_exporter": "python", |
| 231 | + "pygments_lexer": "ipython3", |
| 232 | + "version": "3.7.0" |
| 233 | + } |
| 234 | + }, |
| 235 | + "nbformat": 4, |
| 236 | + "nbformat_minor": 0 |
| 237 | +} |