Skip to content

Commit 72c19ac

Browse files
address reviews
1 parent 372dc47 commit 72c19ac

File tree

6 files changed

+157
-205
lines changed

6 files changed

+157
-205
lines changed

guides/int8_quantization_in_keras.py

Lines changed: 44 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
Author: [Jyotinder Singh](https://x.com/Jyotinder_Singh)
44
Date created: 2025/10/14
55
Last modified: 2025/10/14
6-
Description: Minimal, end-to-end examples of INT8 post-training quantization in Keras.
6+
Description: Complete guide to using INT8 quantization in Keras and KerasHub.
77
Accelerator: GPU
88
"""
99

@@ -12,47 +12,41 @@
1212
1313
Quantization lowers the numerical precision of weights and activations to reduce memory use
1414
and often speed up inference, at the cost of a small accuracy drop. Moving from `float32` to
15-
`float16` halves the memory requirements; `float32` to `int8` is ~4x smaller (and ~2x vs
16-
`float16`). On hardware with low-precision kernels (e.g., Tensor Cores), this can also
15+
`float16` halves the memory requirements; `float32` to INT8 is ~4x smaller (and ~2x vs
16+
`float16`). On hardware with low-precision kernels (e.g., NVIDIA Tensor Cores), this can also
1717
improve throughput and latency. Actual gains depend on your backend and device.
1818
19-
### How it works (symmetric, linear)
19+
### How it works
2020
2121
Quantization maps real values to 8-bit integers with a scale:
2222
2323
* Integer domain: `[-128, 127]` (256 levels).
2424
* For a tensor (often per-output-channel for weights) with values `w`:
25-
2625
* Compute `a_max = max(abs(w))`.
2726
* Set scale `s = (2 * a_max) / 256`.
28-
* Quantize: `q = clip(round(w / s), -128, 127)` (stored as int8) and keep `s`.
27+
* Quantize: `q = clip(round(w / s), -128, 127)` (stored as INT8) and keep `s`.
2928
* Inference uses `q` and `s` to reconstruct effective weights on the fly
30-
(`w ≈ s · q`) or folds `s` into the matmul/conv for efficiency.
31-
32-
### Trade-off
33-
Wider dynamic range (larger `a_max`) reduces clipping but increases rounding error;
34-
tighter range reduces rounding error but risks more clipping. Per-channel scaling
35-
for weights typically helps recover accuracy as compared to per-tensor scaling.
29+
(`w ≈ s · q`) or folds `s` into the matmul/conv for efficiency.
3630
3731
### Benefits
3832
3933
* Memory / bandwidth bound models: When implementation spends most of its time on memory I/O,
40-
reducing the computation time does not reduce their overall runtime. `int8` reduces bytes
41-
moved by ~4x vs `float32`, improving cache behavior and reducing memory stalls;
42-
this often helps more than increasing raw FLOPs.
43-
* Compute bound layers on supported hardware: On NVIDIA GPUs, int8
44-
[Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/) speed up matmul/conv,
45-
boosting throughput on compute-limited layers.
46-
* Accuracy: Many models retain near-FP accuracy with `float16`; `int8` may introduce a modest
47-
drop (often ~1-5% depending on task/model/data). Always validate on your own dataset.
48-
49-
### What Keras does in `int8` mode
50-
51-
* **Mapping**: Symmetric, linear quantization with `int8` plus a floating-point scale.
34+
reducing the computation time does not reduce their overall runtime. INT8 reduces bytes
35+
moved by ~4x vs `float32`, improving cache behavior and reducing memory stalls;
36+
this often helps more than increasing raw FLOPs.
37+
* Compute bound layers on supported hardware: On NVIDIA GPUs, INT8
38+
[Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/) speed up matmul/conv,
39+
boosting throughput on compute-limited layers.
40+
* Accuracy: Many models retain near-FP accuracy with `float16`; INT8 may introduce a modest
41+
drop (often ~1-5% depending on task/model/data). Always validate on your own dataset.
42+
43+
### What Keras does in INT8 mode
44+
45+
* **Mapping**: Symmetric, linear quantization with INT8 plus a floating-point scale.
5246
* **Weights**: per-output-channel scales to preserve accuracy.
5347
* **Activations**: **dynamic AbsMax** scaling computed at runtime.
5448
* **Graph rewrite**: Quantization is applied after weights are trained and built; the graph
55-
is rewritten so you can run or save immediately.
49+
is rewritten so you can run or save immediately.
5650
"""
5751

5852
"""
@@ -76,7 +70,8 @@
7670
from keras import layers
7771

7872

79-
np.random.seed(7)
73+
# Create a random number generator.
74+
rng = np.random.default_rng()
8075

8176
# Create a simple functional model.
8277
inputs = keras.Input(shape=(10,))
@@ -86,12 +81,12 @@
8681

8782
# Compile and train briefly to materialize meaningful weights.
8883
model.compile(optimizer="adam", loss="mse")
89-
x_train = np.random.rand(256, 10).astype("float32")
90-
y_train = np.random.rand(256, 1).astype("float32")
84+
x_train = rng.random((256, 10)).astype("float32")
85+
y_train = rng.random((256, 1)).astype("float32")
9186
model.fit(x_train, y_train, epochs=1, batch_size=32, verbose=0)
9287

9388
# Sample inputs for evaluation.
94-
x_eval = np.random.rand(32, 10).astype("float32")
89+
x_eval = rng.random((32, 10)).astype("float32")
9590

9691
# Baseline (FP) outputs.
9792
y_fp32 = model(x_eval)
@@ -111,47 +106,32 @@
111106
It is evident that the INT8 quantized model produces outputs close to the original FP32
112107
model, as indicated by the low MSE value.
113108
114-
## Saving and reloading a quantized model.
109+
## Saving and reloading a quantized model
115110
116111
You can use the standard Keras saving and loading APIs with quantized models. Quantization
117112
is preserved when saving to `.keras` and loading back.
118113
"""
119114

120-
from keras import saving
121-
122-
# Build a functional model.
123-
inputs = keras.Input(shape=(10,))
124-
x = layers.Dense(32, activation="relu")(inputs)
125-
outputs = layers.Dense(1, name="target")(x)
126-
model = keras.Model(inputs, outputs)
127-
model.build((None, 10))
128-
129-
# Quantize the model in-place to INT8.
130-
model.quantize("int8")
131-
132-
# INT8 outputs after quantization.
133-
y_int8 = model(x_eval)
134-
135115
# Save the quantized model and reload to verify round-trip.
136116
model.save("int8.keras")
137-
int8_reloaded = saving.load_model("int8.keras")
117+
int8_reloaded = keras.saving.load_model("int8.keras")
138118
y_int8_reloaded = int8_reloaded(x_eval)
139119
roundtrip_mse = keras.ops.mean(keras.ops.square(y_int8 - y_int8_reloaded))
140120
print("MSE (INT8 vs reloaded-INT8):", float(roundtrip_mse))
141121

142122
"""
143-
## Quantizing a KerasHub model.
123+
## Quantizing a KerasHub model
144124
145125
All KerasHub models support the `.quantize(...)` API for post-training quantization,
146126
and follow the same workflow as above.
147127
148128
In this example, we will:
149129
150130
1. Load the [gemma3_1b](https://www.kaggle.com/models/keras/gemma3/keras/gemma3_1b)
151-
preset from KerasHub
152-
1. Generate text using both the full-precision and quantized models, and compare outputs.
153-
1. Save both models to disk and compute storage savings.
154-
1. Reload the INT8 model and verify output consistency with the original quantized model.
131+
preset from KerasHub
132+
2. Generate text using both the full-precision and quantized models, and compare outputs.
133+
3. Save both models to disk and compute storage savings.
134+
4. Reload the INT8 model and verify output consistency with the original quantized model.
155135
"""
156136

157137
from keras_hub.models import Gemma3CausalLM
@@ -160,39 +140,41 @@
160140
gemma3 = Gemma3CausalLM.from_preset("gemma3_1b")
161141

162142
# Generate text for a single prompt
163-
output = gemma3.generate("Keras is a", max_length=30)
143+
output = gemma3.generate("Keras is a", max_length=50)
164144
print("Full-precision output:", output)
165145

166-
# Save FP32 Gemma3 model
146+
# Save FP32 Gemma3 model for size comparison.
167147
gemma3.save_to_preset("gemma3_fp32")
168148

169149
# Quantize in-place to INT8 and generate again
170150
gemma3.quantize("int8")
171151

172-
output = gemma3.generate("Keras is a", max_length=30)
152+
output = gemma3.generate("Keras is a", max_length=50)
173153
print("Quantized output:", output)
174154

175155
# Save INT8 Gemma3 model
176156
gemma3.save_to_preset("gemma3_int8")
177157

158+
# Reload and compare outputs
159+
gemma3_int8 = Gemma3CausalLM.from_preset("gemma3_int8")
160+
161+
output = gemma3_int8.generate("Keras is a", max_length=50)
162+
print("Quantized reloaded output:", output)
163+
178164

165+
# Compute storage savings
179166
def bytes_to_mib(n):
180167
return n / (1024**2)
181168

182169

183-
gemma_fp32_size = os.path.getsize("gemma3_fp32")
184-
gemma_int8_size = os.path.getsize("gemma3_int8")
170+
gemma_fp32_size = os.path.getsize("gemma3_fp32/model.weights.h5")
171+
gemma_int8_size = os.path.getsize("gemma3_int8/model.weights.h5")
172+
185173
gemma_reduction = 100.0 * (1.0 - (gemma_int8_size / max(gemma_fp32_size, 1)))
186174
print(f"Gemma3: FP32 file size: {bytes_to_mib(gemma_fp32_size):.2f} MiB")
187175
print(f"Gemma3: INT8 file size: {bytes_to_mib(gemma_int8_size):.2f} MiB")
188176
print(f"Gemma3: Size reduction: {gemma_reduction:.1f}%")
189177

190-
# Reload and compare outputs
191-
gemma3_int8 = Gemma3CausalLM.from_preset("gemma3_int8")
192-
193-
output = gemma3_int8.generate("Keras is a", max_length=30)
194-
print("Quantized reloaded output:", output)
195-
196178
"""
197179
## Practical tips
198180

0 commit comments

Comments
 (0)