Skip to content

Commit 372dc47

Browse files
Adds int8 quantization docs
1 parent 62d2de3 commit 372dc47

File tree

4 files changed

+756
-0
lines changed

4 files changed

+756
-0
lines changed
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
"""
2+
Title: 8-bit Integer Quantization in Keras
3+
Author: [Jyotinder Singh](https://x.com/Jyotinder_Singh)
4+
Date created: 2025/10/14
5+
Last modified: 2025/10/14
6+
Description: Minimal, end-to-end examples of INT8 post-training quantization in Keras.
7+
Accelerator: GPU
8+
"""
9+
10+
"""
11+
## What is INT8 quantization?
12+
13+
Quantization lowers the numerical precision of weights and activations to reduce memory use
14+
and often speed up inference, at the cost of a small accuracy drop. Moving from `float32` to
15+
`float16` halves the memory requirements; `float32` to `int8` is ~4x smaller (and ~2x vs
16+
`float16`). On hardware with low-precision kernels (e.g., Tensor Cores), this can also
17+
improve throughput and latency. Actual gains depend on your backend and device.
18+
19+
### How it works (symmetric, linear)
20+
21+
Quantization maps real values to 8-bit integers with a scale:
22+
23+
* Integer domain: `[-128, 127]` (256 levels).
24+
* For a tensor (often per-output-channel for weights) with values `w`:
25+
26+
* Compute `a_max = max(abs(w))`.
27+
* Set scale `s = (2 * a_max) / 256`.
28+
* Quantize: `q = clip(round(w / s), -128, 127)` (stored as int8) and keep `s`.
29+
* Inference uses `q` and `s` to reconstruct effective weights on the fly
30+
(`w ≈ s · q`) or folds `s` into the matmul/conv for efficiency.
31+
32+
### Trade-off
33+
Wider dynamic range (larger `a_max`) reduces clipping but increases rounding error;
34+
tighter range reduces rounding error but risks more clipping. Per-channel scaling
35+
for weights typically helps recover accuracy as compared to per-tensor scaling.
36+
37+
### Benefits
38+
39+
* Memory / bandwidth bound models: When implementation spends most of its time on memory I/O,
40+
reducing the computation time does not reduce their overall runtime. `int8` reduces bytes
41+
moved by ~4x vs `float32`, improving cache behavior and reducing memory stalls;
42+
this often helps more than increasing raw FLOPs.
43+
* Compute bound layers on supported hardware: On NVIDIA GPUs, int8
44+
[Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/) speed up matmul/conv,
45+
boosting throughput on compute-limited layers.
46+
* Accuracy: Many models retain near-FP accuracy with `float16`; `int8` may introduce a modest
47+
drop (often ~1-5% depending on task/model/data). Always validate on your own dataset.
48+
49+
### What Keras does in `int8` mode
50+
51+
* **Mapping**: Symmetric, linear quantization with `int8` plus a floating-point scale.
52+
* **Weights**: per-output-channel scales to preserve accuracy.
53+
* **Activations**: **dynamic AbsMax** scaling computed at runtime.
54+
* **Graph rewrite**: Quantization is applied after weights are trained and built; the graph
55+
is rewritten so you can run or save immediately.
56+
"""
57+
58+
"""
59+
## Overview
60+
61+
This guide shows how to use 8-bit integer post-training quantization (PTQ) in Keras:
62+
63+
1. [Quantizing a minimal functional model](#quantizing-a-minimal-functional-model)
64+
2. [Saving and reloading a quantized model](#saving-and-reloading-a-quantized-model)
65+
3. [Quantizing a KerasHub model](#quantizing-a-kerashub-model)
66+
67+
## Quantizing a minimal functional model.
68+
69+
We build a small functional model, capture a baseline output, quantize to INT8 in-place,
70+
and then compare outputs with an MSE metric.
71+
"""
72+
73+
import os
74+
import numpy as np
75+
import keras
76+
from keras import layers
77+
78+
79+
np.random.seed(7)
80+
81+
# Create a simple functional model.
82+
inputs = keras.Input(shape=(10,))
83+
x = layers.Dense(32, activation="relu")(inputs)
84+
outputs = layers.Dense(1, name="target")(x)
85+
model = keras.Model(inputs, outputs)
86+
87+
# Compile and train briefly to materialize meaningful weights.
88+
model.compile(optimizer="adam", loss="mse")
89+
x_train = np.random.rand(256, 10).astype("float32")
90+
y_train = np.random.rand(256, 1).astype("float32")
91+
model.fit(x_train, y_train, epochs=1, batch_size=32, verbose=0)
92+
93+
# Sample inputs for evaluation.
94+
x_eval = np.random.rand(32, 10).astype("float32")
95+
96+
# Baseline (FP) outputs.
97+
y_fp32 = model(x_eval)
98+
99+
# Quantize the model in-place to INT8.
100+
model.quantize("int8")
101+
102+
# INT8 outputs after quantization.
103+
y_int8 = model(x_eval)
104+
105+
# Compute a simple MSE between FP and INT8 outputs.
106+
mse = keras.ops.mean(keras.ops.square(y_fp32 - y_int8))
107+
print("Full-Precision vs INT8 MSE:", float(mse))
108+
109+
110+
"""
111+
It is evident that the INT8 quantized model produces outputs close to the original FP32
112+
model, as indicated by the low MSE value.
113+
114+
## Saving and reloading a quantized model.
115+
116+
You can use the standard Keras saving and loading APIs with quantized models. Quantization
117+
is preserved when saving to `.keras` and loading back.
118+
"""
119+
120+
from keras import saving
121+
122+
# Build a functional model.
123+
inputs = keras.Input(shape=(10,))
124+
x = layers.Dense(32, activation="relu")(inputs)
125+
outputs = layers.Dense(1, name="target")(x)
126+
model = keras.Model(inputs, outputs)
127+
model.build((None, 10))
128+
129+
# Quantize the model in-place to INT8.
130+
model.quantize("int8")
131+
132+
# INT8 outputs after quantization.
133+
y_int8 = model(x_eval)
134+
135+
# Save the quantized model and reload to verify round-trip.
136+
model.save("int8.keras")
137+
int8_reloaded = saving.load_model("int8.keras")
138+
y_int8_reloaded = int8_reloaded(x_eval)
139+
roundtrip_mse = keras.ops.mean(keras.ops.square(y_int8 - y_int8_reloaded))
140+
print("MSE (INT8 vs reloaded-INT8):", float(roundtrip_mse))
141+
142+
"""
143+
## Quantizing a KerasHub model.
144+
145+
All KerasHub models support the `.quantize(...)` API for post-training quantization,
146+
and follow the same workflow as above.
147+
148+
In this example, we will:
149+
150+
1. Load the [gemma3_1b](https://www.kaggle.com/models/keras/gemma3/keras/gemma3_1b)
151+
preset from KerasHub
152+
1. Generate text using both the full-precision and quantized models, and compare outputs.
153+
1. Save both models to disk and compute storage savings.
154+
1. Reload the INT8 model and verify output consistency with the original quantized model.
155+
"""
156+
157+
from keras_hub.models import Gemma3CausalLM
158+
159+
# Load from Gemma3 preset
160+
gemma3 = Gemma3CausalLM.from_preset("gemma3_1b")
161+
162+
# Generate text for a single prompt
163+
output = gemma3.generate("Keras is a", max_length=30)
164+
print("Full-precision output:", output)
165+
166+
# Save FP32 Gemma3 model
167+
gemma3.save_to_preset("gemma3_fp32")
168+
169+
# Quantize in-place to INT8 and generate again
170+
gemma3.quantize("int8")
171+
172+
output = gemma3.generate("Keras is a", max_length=30)
173+
print("Quantized output:", output)
174+
175+
# Save INT8 Gemma3 model
176+
gemma3.save_to_preset("gemma3_int8")
177+
178+
179+
def bytes_to_mib(n):
180+
return n / (1024**2)
181+
182+
183+
gemma_fp32_size = os.path.getsize("gemma3_fp32")
184+
gemma_int8_size = os.path.getsize("gemma3_int8")
185+
gemma_reduction = 100.0 * (1.0 - (gemma_int8_size / max(gemma_fp32_size, 1)))
186+
print(f"Gemma3: FP32 file size: {bytes_to_mib(gemma_fp32_size):.2f} MiB")
187+
print(f"Gemma3: INT8 file size: {bytes_to_mib(gemma_int8_size):.2f} MiB")
188+
print(f"Gemma3: Size reduction: {gemma_reduction:.1f}%")
189+
190+
# Reload and compare outputs
191+
gemma3_int8 = Gemma3CausalLM.from_preset("gemma3_int8")
192+
193+
output = gemma3_int8.generate("Keras is a", max_length=30)
194+
print("Quantized reloaded output:", output)
195+
196+
"""
197+
## Practical tips
198+
199+
* Post-training quantization (PTQ) is a one-time operation; you cannot train a model
200+
after quantizing it to INT8.
201+
* Always materialize weights before quantization (e.g., `build()` or a forward pass).
202+
* Expect small numerical deltas; quantify with a metric like MSE on a validation batch.
203+
* Storage savings are immediate; speedups depend on backend/device kernels.
204+
205+
## References
206+
207+
* [Milvus: How does 8-bit quantization or float16 affect the accuracy and speed of Sentence Transformer embeddings and similarity calculations?](https://milvus.io/ai-quick-reference/how-does-quantization-such-as-int8-quantization-or-using-float16-affect-the-accuracy-and-speed-of-sentence-transformer-embeddings-and-similarity-calculations)
208+
* [NVIDIA Developer Blog: Achieving FP32 accuracy for INT8 inference using quantization-aware training with TensorRT](https://developer.nvidia.com/blog/achieving-fp32-accuracy-for-int8-inference-using-quantization-aware-training-with-tensorrt/)
209+
"""

0 commit comments

Comments
 (0)