diff --git a/guides/int8_quantization_in_keras.py b/guides/int8_quantization_in_keras.py
new file mode 100644
index 0000000000..b162e14cf9
--- /dev/null
+++ b/guides/int8_quantization_in_keras.py
@@ -0,0 +1,191 @@
+"""
+Title: 8-bit Integer Quantization in Keras
+Author: [Jyotinder Singh](https://x.com/Jyotinder_Singh)
+Date created: 2025/10/14
+Last modified: 2025/10/14
+Description: Complete guide to using INT8 quantization in Keras and KerasHub.
+Accelerator: GPU
+"""
+
+"""
+## What is INT8 quantization?
+
+Quantization lowers the numerical precision of weights and activations to reduce memory use
+and often speed up inference, at the cost of a small accuracy drop. Moving from `float32` to
+`float16` halves the memory requirements; `float32` to INT8 is ~4x smaller (and ~2x vs
+`float16`). On hardware with low-precision kernels (e.g., NVIDIA Tensor Cores), this can also
+improve throughput and latency. Actual gains depend on your backend and device.
+
+### How it works
+
+Quantization maps real values to 8-bit integers with a scale:
+
+* Integer domain: `[-128, 127]` (256 levels).
+* For a tensor (often per-output-channel for weights) with values `w`:
+  * Compute `a_max = max(abs(w))`.
+  * Set scale `s = (2 * a_max) / 256`.
+  * Quantize: `q = clip(round(w / s), -128, 127)` (stored as INT8) and keep `s`.
+* Inference uses `q` and `s` to reconstruct effective weights on the fly
+  (`w ≈ s · q`) or folds `s` into the matmul/conv for efficiency.
+
+### Benefits
+
+* Memory / bandwidth-bound models: When a model spends most of its time on memory I/O,
+  making the arithmetic faster does little to reduce overall runtime. INT8 reduces bytes
+  moved by ~4x vs `float32`, improving cache behavior and reducing memory stalls;
+  this often helps more than increasing raw FLOPs.
+* Compute-bound layers on supported hardware: On NVIDIA GPUs, INT8
+  [Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/) speed up matmul/conv,
+  boosting throughput on compute-limited layers.
+* Accuracy: Many models retain near-FP accuracy with `float16`; INT8 may introduce a modest
+  drop (often ~1-5% depending on task/model/data). Always validate on your own dataset.
+
+### What Keras does in INT8 mode
+
+* **Mapping**: Symmetric, linear quantization with INT8 plus a floating-point scale.
+* **Weights**: Per-output-channel scales to preserve accuracy.
+* **Activations**: **Dynamic AbsMax** scaling computed at runtime.
+* **Graph rewrite**: Quantization is applied after weights are trained and built; the graph
+  is rewritten so you can run or save immediately.
+"""
+
+"""
+## Overview
+
+This guide shows how to use 8-bit integer post-training quantization (PTQ) in Keras:
+
+1. [Quantizing a minimal functional model](#quantizing-a-minimal-functional-model)
+2. [Saving and reloading a quantized model](#saving-and-reloading-a-quantized-model)
+3. [Quantizing a KerasHub model](#quantizing-a-kerashub-model)
+
+## Quantizing a minimal functional model
+
+We build a small functional model, capture a baseline output, quantize to INT8 in-place,
+and then compare outputs with an MSE metric.
+"""
+
+import os
+import numpy as np
+import keras
+from keras import layers
+
+
+# Create a random number generator.
+rng = np.random.default_rng()
+
+# Create a simple functional model.
+inputs = keras.Input(shape=(10,))
+x = layers.Dense(32, activation="relu")(inputs)
+outputs = layers.Dense(1, name="target")(x)
+model = keras.Model(inputs, outputs)
+
+# Compile and train briefly to materialize meaningful weights.
+model.compile(optimizer="adam", loss="mse") +x_train = rng.random((256, 10)).astype("float32") +y_train = rng.random((256, 1)).astype("float32") +model.fit(x_train, y_train, epochs=1, batch_size=32, verbose=0) + +# Sample inputs for evaluation. +x_eval = rng.random((32, 10)).astype("float32") + +# Baseline (FP) outputs. +y_fp32 = model(x_eval) + +# Quantize the model in-place to INT8. +model.quantize("int8") + +# INT8 outputs after quantization. +y_int8 = model(x_eval) + +# Compute a simple MSE between FP and INT8 outputs. +mse = keras.ops.mean(keras.ops.square(y_fp32 - y_int8)) +print("Full-Precision vs INT8 MSE:", float(mse)) + + +""" +It is evident that the INT8 quantized model produces outputs close to the original FP32 +model, as indicated by the low MSE value. + +## Saving and reloading a quantized model + +You can use the standard Keras saving and loading APIs with quantized models. Quantization +is preserved when saving to `.keras` and loading back. +""" + +# Save the quantized model and reload to verify round-trip. +model.save("int8.keras") +int8_reloaded = keras.saving.load_model("int8.keras") +y_int8_reloaded = int8_reloaded(x_eval) +roundtrip_mse = keras.ops.mean(keras.ops.square(y_int8 - y_int8_reloaded)) +print("MSE (INT8 vs reloaded-INT8):", float(roundtrip_mse)) + +""" +## Quantizing a KerasHub model + +All KerasHub models support the `.quantize(...)` API for post-training quantization, +and follow the same workflow as above. + +In this example, we will: + +1. Load the [gemma3_1b](https://www.kaggle.com/models/keras/gemma3/keras/gemma3_1b) + preset from KerasHub +2. Generate text using both the full-precision and quantized models, and compare outputs. +3. Save both models to disk and compute storage savings. +4. Reload the INT8 model and verify output consistency with the original quantized model. +""" + +from keras_hub.models import Gemma3CausalLM + +# Load from Gemma3 preset +gemma3 = Gemma3CausalLM.from_preset("gemma3_1b") + +# Generate text for a single prompt +output = gemma3.generate("Keras is a", max_length=50) +print("Full-precision output:", output) + +# Save FP32 Gemma3 model for size comparison. +gemma3.save_to_preset("gemma3_fp32") + +# Quantize in-place to INT8 and generate again +gemma3.quantize("int8") + +output = gemma3.generate("Keras is a", max_length=50) +print("Quantized output:", output) + +# Save INT8 Gemma3 model +gemma3.save_to_preset("gemma3_int8") + +# Reload and compare outputs +gemma3_int8 = Gemma3CausalLM.from_preset("gemma3_int8") + +output = gemma3_int8.generate("Keras is a", max_length=50) +print("Quantized reloaded output:", output) + + +# Compute storage savings +def bytes_to_mib(n): + return n / (1024**2) + + +gemma_fp32_size = os.path.getsize("gemma3_fp32/model.weights.h5") +gemma_int8_size = os.path.getsize("gemma3_int8/model.weights.h5") + +gemma_reduction = 100.0 * (1.0 - (gemma_int8_size / max(gemma_fp32_size, 1))) +print(f"Gemma3: FP32 file size: {bytes_to_mib(gemma_fp32_size):.2f} MiB") +print(f"Gemma3: INT8 file size: {bytes_to_mib(gemma_int8_size):.2f} MiB") +print(f"Gemma3: Size reduction: {gemma_reduction:.1f}%") + +""" +## Practical tips + +* Post-training quantization (PTQ) is a one-time operation; you cannot train a model + after quantizing it to INT8. +* Always materialize weights before quantization (e.g., `build()` or a forward pass). +* Expect small numerical deltas; quantify with a metric like MSE on a validation batch. +* Storage savings are immediate; speedups depend on backend/device kernels. 
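+
+As a quick illustration of the second tip, the sketch below builds a model's variables before
+calling `quantize`. This is a toy example (the tiny `Sequential` model and the names are
+illustrative only, not part of the guide's workflow): a deferred-build model has no weights
+until it is built or called on data, so there is nothing for the quantizer to rewrite.
+
+```python
+import numpy as np
+import keras
+
+# A deferred-build model: no variables exist yet.
+toy = keras.Sequential([keras.layers.Dense(4)])
+
+# Materialize the variables first, either by building explicitly...
+toy.build((None, 8))
+# ...or by running a forward pass: toy(np.zeros((1, 8), dtype="float32"))
+
+# Now there are float variables for the quantizer to rewrite in-place.
+toy.quantize("int8")
+print([(w.name, w.dtype) for w in toy.weights])
+```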
+ +## References + +* [Milvus: How does 8-bit quantization or float16 affect the accuracy and speed of Sentence Transformer embeddings and similarity calculations?](https://milvus.io/ai-quick-reference/how-does-quantization-such-as-int8-quantization-or-using-float16-affect-the-accuracy-and-speed-of-sentence-transformer-embeddings-and-similarity-calculations) +* [NVIDIA Developer Blog: Achieving FP32 accuracy for INT8 inference using quantization-aware training with TensorRT](https://developer.nvidia.com/blog/achieving-fp32-accuracy-for-int8-inference-using-quantization-aware-training-with-tensorrt/) +""" diff --git a/guides/ipynb/int8_quantization_in_keras.ipynb b/guides/ipynb/int8_quantization_in_keras.ipynb new file mode 100644 index 0000000000..bc409ae23d --- /dev/null +++ b/guides/ipynb/int8_quantization_in_keras.ipynb @@ -0,0 +1,281 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "# 8-bit Integer Quantization in Keras\n", + "\n", + "**Author:** [Jyotinder Singh](https://x.com/Jyotinder_Singh)
\n", + "**Date created:** 2025/10/14
\n", + "**Last modified:** 2025/10/14
\n", + "**Description:** Complete guide to using INT8 quantization in Keras and KerasHub." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## What is INT8 quantization?\n", + "\n", + "Quantization lowers the numerical precision of weights and activations to reduce memory use\n", + "and often speed up inference, at the cost of a small accuracy drop. Moving from `float32` to\n", + "`float16` halves the memory requirements; `float32` to INT8 is ~4x smaller (and ~2x vs\n", + "`float16`). On hardware with low-precision kernels (e.g., NVIDIA Tensor Cores), this can also\n", + "improve throughput and latency. Actual gains depend on your backend and device.\n", + "\n", + "### How it works\n", + "\n", + "Quantization maps real values to 8-bit integers with a scale:\n", + "\n", + "* Integer domain: `[-128, 127]` (256 levels).\n", + "* For a tensor (often per-output-channel for weights) with values `w`:\n", + " * Compute `a_max = max(abs(w))`.\n", + " * Set scale `s = (2 * a_max) / 256`.\n", + " * Quantize: `q = clip(round(w / s), -128, 127)` (stored as INT8) and keep `s`.\n", + "* Inference uses `q` and `s` to reconstruct effective weights on the fly\n", + " (`w ≈ s · q`) or folds `s` into the matmul/conv for efficiency.\n", + "\n", + "### Benefits\n", + "\n", + "* Memory / bandwidth bound models: When implementation spends most of its time on memory I/O,\n", + " reducing the computation time does not reduce their overall runtime. INT8 reduces bytes\n", + " moved by ~4x vs `float32`, improving cache behavior and reducing memory stalls;\n", + " this often helps more than increasing raw FLOPs.\n", + "* Compute bound layers on supported hardware: On NVIDIA GPUs, INT8\n", + " [Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/) speed up matmul/conv,\n", + " boosting throughput on compute-limited layers.\n", + "* Accuracy: Many models retain near-FP accuracy with `float16`; INT8 may introduce a modest\n", + " drop (often ~1-5% depending on task/model/data). Always validate on your own dataset.\n", + "\n", + "### What Keras does in INT8 mode\n", + "\n", + "* **Mapping**: Symmetric, linear quantization with INT8 plus a floating-point scale.\n", + "* **Weights**: per-output-channel scales to preserve accuracy.\n", + "* **Activations**: **dynamic AbsMax** scaling computed at runtime.\n", + "* **Graph rewrite**: Quantization is applied after weights are trained and built; the graph\n", + " is rewritten so you can run or save immediately." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Overview\n", + "\n", + "This guide shows how to use 8-bit integer post-training quantization (PTQ) in Keras:\n", + "\n", + "1. [Quantizing a minimal functional model](#quantizing-a-minimal-functional-model)\n", + "2. [Saving and reloading a quantized model](#saving-and-reloading-a-quantized-model)\n", + "3. [Quantizing a KerasHub model](#quantizing-a-kerashub-model)\n", + "\n", + "## Quantizing a minimal functional model.\n", + "\n", + "We build a small functional model, capture a baseline output, quantize to INT8 in-place,\n", + "and then compare outputs with an MSE metric." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "import os\n", + "import numpy as np\n", + "import keras\n", + "from keras import layers\n", + "\n", + "\n", + "# Create a random number generator.\n", + "rng = np.random.default_rng()\n", + "\n", + "# Create a simple functional model.\n", + "inputs = keras.Input(shape=(10,))\n", + "x = layers.Dense(32, activation=\"relu\")(inputs)\n", + "outputs = layers.Dense(1, name=\"target\")(x)\n", + "model = keras.Model(inputs, outputs)\n", + "\n", + "# Compile and train briefly to materialize meaningful weights.\n", + "model.compile(optimizer=\"adam\", loss=\"mse\")\n", + "x_train = rng.random((256, 10)).astype(\"float32\")\n", + "y_train = rng.random((256, 1)).astype(\"float32\")\n", + "model.fit(x_train, y_train, epochs=1, batch_size=32, verbose=0)\n", + "\n", + "# Sample inputs for evaluation.\n", + "x_eval = rng.random((32, 10)).astype(\"float32\")\n", + "\n", + "# Baseline (FP) outputs.\n", + "y_fp32 = model(x_eval)\n", + "\n", + "# Quantize the model in-place to INT8.\n", + "model.quantize(\"int8\")\n", + "\n", + "# INT8 outputs after quantization.\n", + "y_int8 = model(x_eval)\n", + "\n", + "# Compute a simple MSE between FP and INT8 outputs.\n", + "mse = keras.ops.mean(keras.ops.square(y_fp32 - y_int8))\n", + "print(\"Full-Precision vs INT8 MSE:\", float(mse))\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "It is evident that the INT8 quantized model produces outputs close to the original FP32\n", + "model, as indicated by the low MSE value.\n", + "\n", + "## Saving and reloading a quantized model\n", + "\n", + "You can use the standard Keras saving and loading APIs with quantized models. Quantization\n", + "is preserved when saving to `.keras` and loading back." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "# Save the quantized model and reload to verify round-trip.\n", + "model.save(\"int8.keras\")\n", + "int8_reloaded = keras.saving.load_model(\"int8.keras\")\n", + "y_int8_reloaded = int8_reloaded(x_eval)\n", + "roundtrip_mse = keras.ops.mean(keras.ops.square(y_int8 - y_int8_reloaded))\n", + "print(\"MSE (INT8 vs reloaded-INT8):\", float(roundtrip_mse))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Quantizing a KerasHub model\n", + "\n", + "All KerasHub models support the `.quantize(...)` API for post-training quantization,\n", + "and follow the same workflow as above.\n", + "\n", + "In this example, we will:\n", + "\n", + "1. Load the [gemma3_1b](https://www.kaggle.com/models/keras/gemma3/keras/gemma3_1b)\n", + " preset from KerasHub\n", + "2. Generate text using both the full-precision and quantized models, and compare outputs.\n", + "3. Save both models to disk and compute storage savings.\n", + "4. Reload the INT8 model and verify output consistency with the original quantized model." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "from keras_hub.models import Gemma3CausalLM\n", + "\n", + "# Load from Gemma3 preset\n", + "gemma3 = Gemma3CausalLM.from_preset(\"gemma3_1b\")\n", + "\n", + "# Generate text for a single prompt\n", + "output = gemma3.generate(\"Keras is a\", max_length=50)\n", + "print(\"Full-precision output:\", output)\n", + "\n", + "# Save FP32 Gemma3 model for size comparison.\n", + "gemma3.save_to_preset(\"gemma3_fp32\")\n", + "\n", + "# Quantize in-place to INT8 and generate again\n", + "gemma3.quantize(\"int8\")\n", + "\n", + "output = gemma3.generate(\"Keras is a\", max_length=50)\n", + "print(\"Quantized output:\", output)\n", + "\n", + "# Save INT8 Gemma3 model\n", + "gemma3.save_to_preset(\"gemma3_int8\")\n", + "\n", + "# Reload and compare outputs\n", + "gemma3_int8 = Gemma3CausalLM.from_preset(\"gemma3_int8\")\n", + "\n", + "output = gemma3_int8.generate(\"Keras is a\", max_length=50)\n", + "print(\"Quantized reloaded output:\", output)\n", + "\n", + "\n", + "# Compute storage savings\n", + "def bytes_to_mib(n):\n", + " return n / (1024**2)\n", + "\n", + "\n", + "gemma_fp32_size = os.path.getsize(\"gemma3_fp32/model.weights.h5\")\n", + "gemma_int8_size = os.path.getsize(\"gemma3_int8/model.weights.h5\")\n", + "\n", + "gemma_reduction = 100.0 * (1.0 - (gemma_int8_size / max(gemma_fp32_size, 1)))\n", + "print(f\"Gemma3: FP32 file size: {bytes_to_mib(gemma_fp32_size):.2f} MiB\")\n", + "print(f\"Gemma3: INT8 file size: {bytes_to_mib(gemma_int8_size):.2f} MiB\")\n", + "print(f\"Gemma3: Size reduction: {gemma_reduction:.1f}%\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Practical tips\n", + "\n", + "* Post-training quantization (PTQ) is a one-time operation; you cannot train a model\n", + " after quantizing it to INT8.\n", + "* Always materialize weights before quantization (e.g., `build()` or a forward pass).\n", + "* Expect small numerical deltas; quantify with a metric like MSE on a validation batch.\n", + "* Storage savings are immediate; speedups depend on backend/device kernels.\n", + "\n", + "## References\n", + "\n", + "* [Milvus: How does 8-bit quantization or float16 affect the accuracy and speed of Sentence Transformer embeddings and similarity calculations?](https://milvus.io/ai-quick-reference/how-does-quantization-such-as-int8-quantization-or-using-float16-affect-the-accuracy-and-speed-of-sentence-transformer-embeddings-and-similarity-calculations)\n", + "* [NVIDIA Developer Blog: Achieving FP32 accuracy for INT8 inference using quantization-aware training with TensorRT](https://developer.nvidia.com/blog/achieving-fp32-accuracy-for-int8-inference-using-quantization-aware-training-with-tensorrt/)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "int8_quantization_in_keras", + "private_outputs": false, + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.0" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/guides/ipynb/quantization_overview.ipynb b/guides/ipynb/quantization_overview.ipynb 
index 5d3f43da80..5ed373771e 100644 --- a/guides/ipynb/quantization_overview.ipynb +++ b/guides/ipynb/quantization_overview.ipynb @@ -115,9 +115,12 @@ "import keras\n", "import numpy as np\n", "\n", + "# Create a random number generator.\n", + "rng = np.random.default_rng()\n", + "\n", "# Sample training data.\n", - "x_train = keras.ops.array(np.random.rand(100, 10))\n", - "y_train = keras.ops.array(np.random.rand(100, 1))\n", + "x_train = rng.random((100, 10)).astype(\"float32\")\n", + "y_train = rng.random((100, 1)).astype(\"float32\")\n", "\n", "# Build the model.\n", "model = keras.Sequential(\n", @@ -198,7 +201,7 @@ "\n", "Any composite layers that are built from the above (for example, `MultiHeadAttention`, `GroupedQueryAttention`, feed-forward blocks in Transformers) inherit quantization support by construction. This covers the majority of modern encoder-only and decoder-only Transformer architectures.\n", "\n", - "Since all KerasHub models subclass `keras.Model`, they automatically support the `model.quantize(...)` API. In practice, this means you can take a popular LLM preset, call a single function to obtain an int8/int4/GPTQ-quantized variant, and then save or serve it\u2014without changing your training code.\n", + "Since all KerasHub models subclass `keras.Model`, they automatically support the `model.quantize(...)` API. In practice, this means you can take a popular LLM preset, call a single function to obtain an int8/int4/GPTQ-quantized variant, and then save or serve it—without changing your training code.\n", "\n", "## Practical guidance\n", "\n", @@ -236,4 +239,4 @@ }, "nbformat": 4, "nbformat_minor": 0 -} \ No newline at end of file +} diff --git a/guides/md/int8_quantization_in_keras.md b/guides/md/int8_quantization_in_keras.md new file mode 100644 index 0000000000..c221a806a7 --- /dev/null +++ b/guides/md/int8_quantization_in_keras.md @@ -0,0 +1,223 @@ +# 8-bit Integer Quantization in Keras + +**Author:** [Jyotinder Singh](https://x.com/Jyotinder_Singh)
+**Date created:** 2025/10/14
+**Last modified:** 2025/10/14
+**Description:** Complete guide to using INT8 quantization in Keras and KerasHub.
+
+
+ [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/guides/ipynb/int8_quantization_in_keras.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/guides/int8_quantization_in_keras.py)
+
+
+
+---
+## What is INT8 quantization?
+
+Quantization lowers the numerical precision of weights and activations to reduce memory use
+and often speed up inference, at the cost of a small accuracy drop. Moving from `float32` to
+`float16` halves the memory requirements; `float32` to INT8 is ~4x smaller (and ~2x vs
+`float16`). On hardware with low-precision kernels (e.g., NVIDIA Tensor Cores), this can also
+improve throughput and latency. Actual gains depend on your backend and device.
+
+### How it works
+
+Quantization maps real values to 8-bit integers with a scale:
+
+* Integer domain: `[-128, 127]` (256 levels).
+* For a tensor (often per-output-channel for weights) with values `w`:
+  * Compute `a_max = max(abs(w))`.
+  * Set scale `s = (2 * a_max) / 256`.
+  * Quantize: `q = clip(round(w / s), -128, 127)` (stored as INT8) and keep `s`.
+* Inference uses `q` and `s` to reconstruct effective weights on the fly
+  (`w ≈ s · q`) or folds `s` into the matmul/conv for efficiency.
+
+### Benefits
+
+* Memory / bandwidth-bound models: When a model spends most of its time on memory I/O,
+  making the arithmetic faster does little to reduce overall runtime. INT8 reduces bytes
+  moved by ~4x vs `float32`, improving cache behavior and reducing memory stalls;
+  this often helps more than increasing raw FLOPs.
+* Compute-bound layers on supported hardware: On NVIDIA GPUs, INT8
+  [Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/) speed up matmul/conv,
+  boosting throughput on compute-limited layers.
+* Accuracy: Many models retain near-FP accuracy with `float16`; INT8 may introduce a modest
+  drop (often ~1-5% depending on task/model/data). Always validate on your own dataset.
+
+### What Keras does in INT8 mode
+
+* **Mapping**: Symmetric, linear quantization with INT8 plus a floating-point scale.
+* **Weights**: Per-output-channel scales to preserve accuracy.
+* **Activations**: **Dynamic AbsMax** scaling computed at runtime.
+* **Graph rewrite**: Quantization is applied after weights are trained and built; the graph
+  is rewritten so you can run or save immediately.
+
+---
+## Overview
+
+This guide shows how to use 8-bit integer post-training quantization (PTQ) in Keras:
+
+1. [Quantizing a minimal functional model](#quantizing-a-minimal-functional-model)
+2. [Saving and reloading a quantized model](#saving-and-reloading-a-quantized-model)
+3. [Quantizing a KerasHub model](#quantizing-a-kerashub-model)
+
+---
+## Quantizing a minimal functional model
+
+We build a small functional model, capture a baseline output, quantize to INT8 in-place,
+and then compare outputs with an MSE metric.
+
+
+```python
+import os
+import numpy as np
+import keras
+from keras import layers
+
+
+# Create a random number generator.
+rng = np.random.default_rng()
+
+# Create a simple functional model.
+inputs = keras.Input(shape=(10,))
+x = layers.Dense(32, activation="relu")(inputs)
+outputs = layers.Dense(1, name="target")(x)
+model = keras.Model(inputs, outputs)
+
+# Compile and train briefly to materialize meaningful weights.
+model.compile(optimizer="adam", loss="mse") +x_train = rng.random((256, 10)).astype("float32") +y_train = rng.random((256, 1)).astype("float32") +model.fit(x_train, y_train, epochs=1, batch_size=32, verbose=0) + +# Sample inputs for evaluation. +x_eval = rng.random((32, 10)).astype("float32") + +# Baseline (FP) outputs. +y_fp32 = model(x_eval) + +# Quantize the model in-place to INT8. +model.quantize("int8") + +# INT8 outputs after quantization. +y_int8 = model(x_eval) + +# Compute a simple MSE between FP and INT8 outputs. +mse = keras.ops.mean(keras.ops.square(y_fp32 - y_int8)) +print("Full-Precision vs INT8 MSE:", float(mse)) + +``` + +
+``` +Full-Precision vs INT8 MSE: 4.982496648153756e-06 +``` +
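+
+For intuition about where this small error comes from, the sketch below applies the AbsMax
+mapping from the [How it works](#how-it-works) section to a single toy tensor. It illustrates
+the arithmetic only; the actual Keras kernels keep `q` in INT8 and handle the scales
+internally, as described above.
+
+```python
+import numpy as np
+
+rng = np.random.default_rng(0)
+w = rng.normal(size=(10, 32)).astype("float32")  # a stand-in for a Dense kernel
+
+# Per-output-channel abs-max and scale, as described above.
+a_max = np.max(np.abs(w), axis=0, keepdims=True)
+s = (2.0 * a_max) / 256.0
+
+# Quantize to INT8, then dequantize to measure the rounding error.
+q = np.clip(np.round(w / s), -128, 127).astype(np.int8)
+w_hat = q.astype("float32") * s
+print("max abs rounding error:", float(np.max(np.abs(w - w_hat))))
+```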
+ +It is evident that the INT8 quantized model produces outputs close to the original FP32 +model, as indicated by the low MSE value. + +--- +## Saving and reloading a quantized model + +You can use the standard Keras saving and loading APIs with quantized models. Quantization +is preserved when saving to `.keras` and loading back. + + +```python +# Save the quantized model and reload to verify round-trip. +model.save("int8.keras") +int8_reloaded = keras.saving.load_model("int8.keras") +y_int8_reloaded = int8_reloaded(x_eval) +roundtrip_mse = keras.ops.mean(keras.ops.square(y_int8 - y_int8_reloaded)) +print("MSE (INT8 vs reloaded-INT8):", float(roundtrip_mse)) +``` + +
+``` +MSE (INT8 vs reloaded-INT8): 0.0 +``` +
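+
+The zero MSE shows that the reloaded model reproduces the quantized model's outputs exactly on
+this batch. If you want to see what was serialized, you can list the reloaded model's
+variables. The exact variable names are an implementation detail and may differ across Keras
+versions, but you should find INT8 kernels stored alongside floating-point scales.
+
+```python
+# Inspect the variables of the reloaded, quantized model.
+for layer in int8_reloaded.layers:
+    for w in layer.weights:
+        print(f"{layer.name}: {w.name} dtype={w.dtype} shape={w.shape}")
+```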
+ +--- +## Quantizing a KerasHub model + +All KerasHub models support the `.quantize(...)` API for post-training quantization, +and follow the same workflow as above. + +In this example, we will: + +1. Load the [gemma3_1b](https://www.kaggle.com/models/keras/gemma3/keras/gemma3_1b) + preset from KerasHub +2. Generate text using both the full-precision and quantized models, and compare outputs. +3. Save both models to disk and compute storage savings. +4. Reload the INT8 model and verify output consistency with the original quantized model. + + +```python +from keras_hub.models import Gemma3CausalLM + +# Load from Gemma3 preset +gemma3 = Gemma3CausalLM.from_preset("gemma3_1b") + +# Generate text for a single prompt +output = gemma3.generate("Keras is a", max_length=50) +print("Full-precision output:", output) + +# Save FP32 Gemma3 model for size comparison. +gemma3.save_to_preset("gemma3_fp32") + +# Quantize in-place to INT8 and generate again +gemma3.quantize("int8") + +output = gemma3.generate("Keras is a", max_length=50) +print("Quantized output:", output) + +# Save INT8 Gemma3 model +gemma3.save_to_preset("gemma3_int8") + +# Reload and compare outputs +gemma3_int8 = Gemma3CausalLM.from_preset("gemma3_int8") + +output = gemma3_int8.generate("Keras is a", max_length=50) +print("Quantized reloaded output:", output) + + +# Compute storage savings +def bytes_to_mib(n): + return n / (1024**2) + + +gemma_fp32_size = os.path.getsize("gemma3_fp32/model.weights.h5") +gemma_int8_size = os.path.getsize("gemma3_int8/model.weights.h5") + +gemma_reduction = 100.0 * (1.0 - (gemma_int8_size / max(gemma_fp32_size, 1))) +print(f"Gemma3: FP32 file size: {bytes_to_mib(gemma_fp32_size):.2f} MiB") +print(f"Gemma3: INT8 file size: {bytes_to_mib(gemma_int8_size):.2f} MiB") +print(f"Gemma3: Size reduction: {gemma_reduction:.1f}%") +``` + +
+``` +Full-precision output: Keras is a deep learning library for Python. It is a high-level API for neural networks. It is a Python library for deep learning. It is a library for deep learning. It is a library for deep learning. It is a +Quantized output: Keras is a Python library for deep learning. It is a high-level API for building deep learning models. It is designed to be easy +Quantized reloaded output: Keras is a Python library for deep learning. It is a high-level API for building deep learning models. It is designed to be easy +Gemma3: FP32 file size: 3815.32 MiB +Gemma3: INT8 file size: 957.81 MiB +Gemma3: Size reduction: 74.9% +``` +
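+
+The ~75% size reduction is in line with the ~4x weight compression you would expect from INT8.
+Whether generation also gets faster depends on the backend and device kernels, so it is worth
+measuring rather than assuming. The sketch below times the INT8 model; for a fair comparison,
+reload the saved `gemma3_fp32` preset in a separate run, since keeping both billion-parameter
+models in memory at once may not fit on smaller GPUs.
+
+```python
+import time
+
+
+def avg_generate_seconds(model, prompt, max_length=50, warmup=1, iters=3):
+    # Warm up so one-time compilation/tracing does not dominate the measurement.
+    for _ in range(warmup):
+        model.generate(prompt, max_length=max_length)
+    start = time.perf_counter()
+    for _ in range(iters):
+        model.generate(prompt, max_length=max_length)
+    return (time.perf_counter() - start) / iters
+
+
+print("INT8 avg generate() latency:", avg_generate_seconds(gemma3_int8, "Keras is a"), "s")
+```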
+ +--- +## Practical tips + +* Post-training quantization (PTQ) is a one-time operation; you cannot train a model + after quantizing it to INT8. +* Always materialize weights before quantization (e.g., `build()` or a forward pass). +* Expect small numerical deltas; quantify with a metric like MSE on a validation batch. +* Storage savings are immediate; speedups depend on backend/device kernels. + +--- +## References + +* [Milvus: How does 8-bit quantization or float16 affect the accuracy and speed of Sentence Transformer embeddings and similarity calculations?](https://milvus.io/ai-quick-reference/how-does-quantization-such-as-int8-quantization-or-using-float16-affect-the-accuracy-and-speed-of-sentence-transformer-embeddings-and-similarity-calculations) +* [NVIDIA Developer Blog: Achieving FP32 accuracy for INT8 inference using quantization-aware training with TensorRT](https://developer.nvidia.com/blog/achieving-fp32-accuracy-for-int8-inference-using-quantization-aware-training-with-tensorrt/) diff --git a/guides/md/quantization_overview.md b/guides/md/quantization_overview.md index a210273ea5..1fa9c128d3 100644 --- a/guides/md/quantization_overview.md +++ b/guides/md/quantization_overview.md @@ -87,9 +87,12 @@ Typical workflow: import keras import numpy as np +# Create a random number generator. +rng = np.random.default_rng() + # Sample training data. -x_train = keras.ops.array(np.random.rand(100, 10)) -y_train = keras.ops.array(np.random.rand(100, 1)) +x_train = rng.random((100, 10)).astype("float32") +y_train = rng.random((100, 1)).astype("float32") # Build the model. model = keras.Sequential( diff --git a/guides/quantization_overview.py b/guides/quantization_overview.py index 2e95dff7f9..e5daa42637 100644 --- a/guides/quantization_overview.py +++ b/guides/quantization_overview.py @@ -85,9 +85,12 @@ import keras import numpy as np +# Create a random number generator. +rng = np.random.default_rng() + # Sample training data. -x_train = keras.ops.array(np.random.rand(100, 10)) -y_train = keras.ops.array(np.random.rand(100, 1)) +x_train = rng.random((100, 10)).astype("float32") +y_train = rng.random((100, 1)).astype("float32") # Build the model. model = keras.Sequential( diff --git a/scripts/guides_master.py b/scripts/guides_master.py index 1b82d10d92..9f7d645370 100644 --- a/scripts/guides_master.py +++ b/scripts/guides_master.py @@ -127,6 +127,10 @@ "path": "quantization_overview", "title": "Quantization in Keras", }, + { + "path": "int8_quantization_in_keras", + "title": "8-bit integer quantization in Keras", + } # { # "path": "preprocessing_layers", # "title": "Working with preprocessing layers",