
Commit bcadbc7

address reviews
1 parent f65c70e commit bcadbc7

5 files changed: +58 -110 lines changed

guides/int8_quantization_in_keras.py

Lines changed: 2 additions & 8 deletions
@@ -3,7 +3,7 @@
 Author: [Jyotinder Singh](https://x.com/Jyotinder_Singh)
 Date created: 2025/10/14
 Last modified: 2025/10/14
-Description: Complete guide to using INT8 quantization in Keras and KerasHub
+Description: Complete guide to using INT8 quantization in Keras and KerasHub.
 Accelerator: GPU
 """

@@ -28,12 +28,6 @@
 * Inference uses `q` and `s` to reconstruct effective weights on the fly
 (`w ≈ s · q`) or folds `s` into the matmul/conv for efficiency.

-### Trade-off
-
-Wider dynamic range (larger `a_max`) reduces clipping but increases rounding error;
-tighter range reduces rounding error but risks more clipping. Per-channel scaling
-for weights can be used to recover accuracy when compared to per-tensor scaling.
-
 ### Benefits

 * Memory / bandwidth bound models: When implementation spends most of its time on memory I/O,
@@ -77,7 +71,7 @@


 # Create a random number generator.
-rng = np.random.default_rng(7)
+rng = np.random.default_rng()

 # Create a simple functional model.
 inputs = keras.Input(shape=(10,))
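
Aside (not part of this commit): a minimal NumPy sketch of the AbsMax scheme the guide describes — compute `a_max`, derive `s = (2 * a_max) / 256`, round-and-clip to INT8, and reconstruct `w ≈ s · q`:

```python
import numpy as np

# Illustrative per-tensor AbsMax INT8 quantization, not the Keras internals.
rng = np.random.default_rng()
w = rng.standard_normal((64, 32)).astype("float32")

a_max = np.max(np.abs(w))                     # largest magnitude in the tensor
s = (2.0 * a_max) / 256.0                     # scale covering 256 INT8 steps
q = np.clip(np.round(w / s), -128, 127).astype(np.int8)  # stored INT8 values

w_hat = s * q.astype(np.float32)              # dequantized approximation, w ≈ s · q
print("max abs reconstruction error:", float(np.max(np.abs(w - w_hat))))
```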

guides/ipynb/int8_quantization_in_keras.ipynb

Lines changed: 25 additions & 48 deletions
@@ -11,7 +11,7 @@
 "**Author:** [Jyotinder Singh](https://x.com/Jyotinder_Singh)<br>\n",
 "**Date created:** 2025/10/14<br>\n",
 "**Last modified:** 2025/10/14<br>\n",
-"**Description:** Complete guide to using INT8 quantization in Keras and KerasHub"
+"**Description:** Complete guide to using INT8 quantization in Keras and KerasHub."
 ]
 },
 {
@@ -24,8 +24,8 @@
 "\n",
 "Quantization lowers the numerical precision of weights and activations to reduce memory use\n",
 "and often speed up inference, at the cost of a small accuracy drop. Moving from `float32` to\n",
-"`float16` halves the memory requirements; `float32` to `int8` is ~4x smaller (and ~2x vs\n",
-"`float16`). On hardware with low-precision kernels (e.g., Tensor Cores), this can also\n",
+"`float16` halves the memory requirements; `float32` to INT8 is ~4x smaller (and ~2x vs\n",
+"`float16`). On hardware with low-precision kernels (e.g., NVIDIA Tensor Cores), this can also\n",
 "improve throughput and latency. Actual gains depend on your backend and device.\n",
 "\n",
 "### How it works\n",
@@ -36,34 +36,29 @@
 "* For a tensor (often per-output-channel for weights) with values `w`:\n",
 " * Compute `a_max = max(abs(w))`.\n",
 " * Set scale `s = (2 * a_max) / 256`.\n",
-" * Quantize: `q = clip(round(w / s), -128, 127)` (stored as int8) and keep `s`.\n",
+" * Quantize: `q = clip(round(w / s), -128, 127)` (stored as INT8) and keep `s`.\n",
 "* Inference uses `q` and `s` to reconstruct effective weights on the fly\n",
-"(`w ≈ s · q`) or folds `s` into the matmul/conv for efficiency.\n",
-"\n",
-"### Trade-off\n",
-"Wider dynamic range (larger `a_max`) reduces clipping but increases rounding error;\n",
-"tighter range reduces rounding error but risks more clipping. Per-channel scaling\n",
-"for weights can be used to recover accuracy when compared to per-tensor scaling.\n",
+" (`w ≈ s · q`) or folds `s` into the matmul/conv for efficiency.\n",
 "\n",
 "### Benefits\n",
 "\n",
 "* Memory / bandwidth bound models: When implementation spends most of its time on memory I/O,\n",
-"reducing the computation time does not reduce their overall runtime. `int8` reduces bytes\n",
-"moved by ~4x vs `float32`, improving cache behavior and reducing memory stalls;\n",
-"this often helps more than increasing raw FLOPs.\n",
-"* Compute bound layers on supported hardware: On NVIDIA GPUs, int8\n",
-"[Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/) speed up matmul/conv,\n",
-"boosting throughput on compute-limited layers.\n",
-"* Accuracy: Many models retain near-FP accuracy with `float16`; `int8` may introduce a modest\n",
-"drop (often ~1-5% depending on task/model/data). Always validate on your own dataset.\n",
-"\n",
-"### What Keras does in `int8` mode\n",
-"\n",
-"* **Mapping**: Symmetric, linear quantization with `int8` plus a floating-point scale.\n",
+" reducing the computation time does not reduce their overall runtime. INT8 reduces bytes\n",
+" moved by ~4x vs `float32`, improving cache behavior and reducing memory stalls;\n",
+" this often helps more than increasing raw FLOPs.\n",
+"* Compute bound layers on supported hardware: On NVIDIA GPUs, INT8\n",
+" [Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/) speed up matmul/conv,\n",
+" boosting throughput on compute-limited layers.\n",
+"* Accuracy: Many models retain near-FP accuracy with `float16`; INT8 may introduce a modest\n",
+" drop (often ~1-5% depending on task/model/data). Always validate on your own dataset.\n",
+"\n",
+"### What Keras does in INT8 mode\n",
+"\n",
+"* **Mapping**: Symmetric, linear quantization with INT8 plus a floating-point scale.\n",
 "* **Weights**: per-output-channel scales to preserve accuracy.\n",
 "* **Activations**: **dynamic AbsMax** scaling computed at runtime.\n",
 "* **Graph rewrite**: Quantization is applied after weights are trained and built; the graph\n",
-"is rewritten so you can run or save immediately."
+" is rewritten so you can run or save immediately."
 ]
 },
 {
@@ -101,7 +96,7 @@
 "\n",
 "\n",
 "# Create a random number generator.\n",
-"rng = np.random.default_rng(7)\n",
+"rng = np.random.default_rng()\n",
 "\n",
 "# Create a simple functional model.\n",
 "inputs = keras.Input(shape=(10,))\n",
@@ -155,27 +150,9 @@
 },
 "outputs": [],
 "source": [
-"from keras import saving\n",
-"\n",
-"# Build a functional model.\n",
-"inputs = keras.Input(shape=(10,))\n",
-"x = layers.Dense(32, activation=\"relu\")(inputs)\n",
-"outputs = layers.Dense(1, name=\"target\")(x)\n",
-"model = keras.Model(inputs, outputs)\n",
-"model.build((None, 10))\n",
-"\n",
-"# Sample inputs for evaluation.\n",
-"x_eval = rng.random((32, 10)).astype(\"float32\")\n",
-"\n",
-"# Quantize the model in-place to INT8.\n",
-"model.quantize(\"int8\")\n",
-"\n",
-"# INT8 outputs after quantization.\n",
-"y_int8 = model(x_eval)\n",
-"\n",
 "# Save the quantized model and reload to verify round-trip.\n",
 "model.save(\"int8.keras\")\n",
-"int8_reloaded = saving.load_model(\"int8.keras\")\n",
+"int8_reloaded = keras.saving.load_model(\"int8.keras\")\n",
 "y_int8_reloaded = int8_reloaded(x_eval)\n",
 "roundtrip_mse = keras.ops.mean(keras.ops.square(y_int8 - y_int8_reloaded))\n",
 "print(\"MSE (INT8 vs reloaded-INT8):\", float(roundtrip_mse))"
@@ -195,7 +172,7 @@
 "In this example, we will:\n",
 "\n",
 "1. Load the [gemma3_1b](https://www.kaggle.com/models/keras/gemma3/keras/gemma3_1b)\n",
-"preset from KerasHub\n",
+" preset from KerasHub\n",
 "2. Generate text using both the full-precision and quantized models, and compare outputs.\n",
 "3. Save both models to disk and compute storage savings.\n",
 "4. Reload the INT8 model and verify output consistency with the original quantized model."
@@ -215,16 +192,16 @@
 "gemma3 = Gemma3CausalLM.from_preset(\"gemma3_1b\")\n",
 "\n",
 "# Generate text for a single prompt\n",
-"output = gemma3.generate(\"Keras is a\", max_length=30)\n",
+"output = gemma3.generate(\"Keras is a\", max_length=50)\n",
 "print(\"Full-precision output:\", output)\n",
 "\n",
-"# Save FP32 Gemma3 model\n",
+"# Save FP32 Gemma3 model for size comparison.\n",
 "gemma3.save_to_preset(\"gemma3_fp32\")\n",
 "\n",
 "# Quantize in-place to INT8 and generate again\n",
 "gemma3.quantize(\"int8\")\n",
 "\n",
-"output = gemma3.generate(\"Keras is a\", max_length=30)\n",
+"output = gemma3.generate(\"Keras is a\", max_length=50)\n",
 "print(\"Quantized output:\", output)\n",
 "\n",
 "# Save INT8 Gemma3 model\n",
@@ -233,7 +210,7 @@
 "# Reload and compare outputs\n",
 "gemma3_int8 = Gemma3CausalLM.from_preset(\"gemma3_int8\")\n",
 "\n",
-"output = gemma3_int8.generate(\"Keras is a\", max_length=30)\n",
+"output = gemma3_int8.generate(\"Keras is a\", max_length=50)\n",
 "print(\"Quantized reloaded output:\", output)\n",
 "\n",
 "\n",

guides/ipynb/quantization_overview.ipynb

Lines changed: 1 addition & 1 deletion
@@ -116,7 +116,7 @@
 "import numpy as np\n",
 "\n",
 "# Create a random number generator.\n",
-"rng = np.random.default_rng(7)\n",
+"rng = np.random.default_rng()\n",
 "\n",
 "# Sample training data.\n",
 "x_train = rng.random((100, 10)).astype(\"float32\")\n",

guides/md/int8_quantization_in_keras.md

Lines changed: 29 additions & 52 deletions
@@ -3,7 +3,7 @@
 **Author:** [Jyotinder Singh](https://x.com/Jyotinder_Singh)<br>
 **Date created:** 2025/10/14<br>
 **Last modified:** 2025/10/14<br>
-**Description:** Complete guide to using INT8 quantization in Keras and KerasHub
+**Description:** Complete guide to using INT8 quantization in Keras and KerasHub.


 <img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/guides/ipynb/int8_quantization_in_keras.ipynb) <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/guides/int8_quantization_in_keras.py)
@@ -15,8 +15,8 @@

 Quantization lowers the numerical precision of weights and activations to reduce memory use
 and often speed up inference, at the cost of a small accuracy drop. Moving from `float32` to
-`float16` halves the memory requirements; `float32` to `int8` is ~4x smaller (and ~2x vs
-`float16`). On hardware with low-precision kernels (e.g., Tensor Cores), this can also
+`float16` halves the memory requirements; `float32` to INT8 is ~4x smaller (and ~2x vs
+`float16`). On hardware with low-precision kernels (e.g., NVIDIA Tensor Cores), this can also
 improve throughput and latency. Actual gains depend on your backend and device.

 ### How it works
@@ -27,34 +27,29 @@ Quantization maps real values to 8-bit integers with a scale:
 * For a tensor (often per-output-channel for weights) with values `w`:
   * Compute `a_max = max(abs(w))`.
   * Set scale `s = (2 * a_max) / 256`.
-  * Quantize: `q = clip(round(w / s), -128, 127)` (stored as int8) and keep `s`.
+  * Quantize: `q = clip(round(w / s), -128, 127)` (stored as INT8) and keep `s`.
 * Inference uses `q` and `s` to reconstruct effective weights on the fly
-(`w ≈ s · q`) or folds `s` into the matmul/conv for efficiency.
-
-### Trade-off
-Wider dynamic range (larger `a_max`) reduces clipping but increases rounding error;
-tighter range reduces rounding error but risks more clipping. Per-channel scaling
-for weights can be used to recover accuracy when compared to per-tensor scaling.
+  (`w ≈ s · q`) or folds `s` into the matmul/conv for efficiency.

 ### Benefits

 * Memory / bandwidth bound models: When implementation spends most of its time on memory I/O,
-reducing the computation time does not reduce their overall runtime. `int8` reduces bytes
-moved by ~4x vs `float32`, improving cache behavior and reducing memory stalls;
-this often helps more than increasing raw FLOPs.
-* Compute bound layers on supported hardware: On NVIDIA GPUs, int8
-[Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/) speed up matmul/conv,
-boosting throughput on compute-limited layers.
-* Accuracy: Many models retain near-FP accuracy with `float16`; `int8` may introduce a modest
-drop (often ~1-5% depending on task/model/data). Always validate on your own dataset.
-
-### What Keras does in `int8` mode
-
-* **Mapping**: Symmetric, linear quantization with `int8` plus a floating-point scale.
+  reducing the computation time does not reduce their overall runtime. INT8 reduces bytes
+  moved by ~4x vs `float32`, improving cache behavior and reducing memory stalls;
+  this often helps more than increasing raw FLOPs.
+* Compute bound layers on supported hardware: On NVIDIA GPUs, INT8
+  [Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/) speed up matmul/conv,
+  boosting throughput on compute-limited layers.
+* Accuracy: Many models retain near-FP accuracy with `float16`; INT8 may introduce a modest
+  drop (often ~1-5% depending on task/model/data). Always validate on your own dataset.
+
+### What Keras does in INT8 mode
+
+* **Mapping**: Symmetric, linear quantization with INT8 plus a floating-point scale.
 * **Weights**: per-output-channel scales to preserve accuracy.
 * **Activations**: **dynamic AbsMax** scaling computed at runtime.
 * **Graph rewrite**: Quantization is applied after weights are trained and built; the graph
-is rewritten so you can run or save immediately.
+  is rewritten so you can run or save immediately.

 ---
 ## Overview
@@ -80,7 +75,7 @@ from keras import layers


 # Create a random number generator.
-rng = np.random.default_rng(7)
+rng = np.random.default_rng()

 # Create a simple functional model.
 inputs = keras.Input(shape=(10,))
@@ -114,7 +109,7 @@ print("Full-Precision vs INT8 MSE:", float(mse))

 <div class="k-default-codeblock">
 ```
-Full-Precision vs INT8 MSE: 7.132767677830998e-06
+Full-Precision vs INT8 MSE: 4.982496648153756e-06
 ```
 </div>

@@ -129,27 +124,9 @@ is preserved when saving to `.keras` and loading back.


 ```python
-from keras import saving
-
-# Build a functional model.
-inputs = keras.Input(shape=(10,))
-x = layers.Dense(32, activation="relu")(inputs)
-outputs = layers.Dense(1, name="target")(x)
-model = keras.Model(inputs, outputs)
-model.build((None, 10))
-
-# Sample inputs for evaluation.
-x_eval = rng.random((32, 10)).astype("float32")
-
-# Quantize the model in-place to INT8.
-model.quantize("int8")
-
-# INT8 outputs after quantization.
-y_int8 = model(x_eval)
-
 # Save the quantized model and reload to verify round-trip.
 model.save("int8.keras")
-int8_reloaded = saving.load_model("int8.keras")
+int8_reloaded = keras.saving.load_model("int8.keras")
 y_int8_reloaded = int8_reloaded(x_eval)
 roundtrip_mse = keras.ops.mean(keras.ops.square(y_int8 - y_int8_reloaded))
 print("MSE (INT8 vs reloaded-INT8):", float(roundtrip_mse))
@@ -170,7 +147,7 @@ and follow the same workflow as above.
 In this example, we will:

 1. Load the [gemma3_1b](https://www.kaggle.com/models/keras/gemma3/keras/gemma3_1b)
-preset from KerasHub
+   preset from KerasHub
 2. Generate text using both the full-precision and quantized models, and compare outputs.
 3. Save both models to disk and compute storage savings.
 4. Reload the INT8 model and verify output consistency with the original quantized model.
@@ -183,16 +160,16 @@ from keras_hub.models import Gemma3CausalLM
 gemma3 = Gemma3CausalLM.from_preset("gemma3_1b")

 # Generate text for a single prompt
-output = gemma3.generate("Keras is a", max_length=30)
+output = gemma3.generate("Keras is a", max_length=50)
 print("Full-precision output:", output)

-# Save FP32 Gemma3 model
+# Save FP32 Gemma3 model for size comparison.
 gemma3.save_to_preset("gemma3_fp32")

 # Quantize in-place to INT8 and generate again
 gemma3.quantize("int8")

-output = gemma3.generate("Keras is a", max_length=30)
+output = gemma3.generate("Keras is a", max_length=50)
 print("Quantized output:", output)

 # Save INT8 Gemma3 model
@@ -201,7 +178,7 @@ gemma3.save_to_preset("gemma3_int8")
 # Reload and compare outputs
 gemma3_int8 = Gemma3CausalLM.from_preset("gemma3_int8")

-output = gemma3_int8.generate("Keras is a", max_length=30)
+output = gemma3_int8.generate("Keras is a", max_length=50)
 print("Quantized reloaded output:", output)


@@ -221,11 +198,11 @@ print(f"Gemma3: Size reduction: {gemma_reduction:.1f}%")

 <div class="k-default-codeblock">
 ```
-Full-precision output: Keras is a deep learning library for Python. It is a high-level API for neural networks. It is a Python library for deep learning
+Full-precision output: Keras is a deep learning library for Python. It is a high-level API for neural networks. It is a Python library for deep learning. It is a library for deep learning. It is a library for deep learning. It is a

-Quantized output: Keras is a deep learning library for Python. It is a high-level API for neural networks. It is a Python library for deep learning
+Quantized output: Keras is a deep learning library for Python. It is a high-level API for neural networks. It is a Python library for deep learning. It is a library for deep learning. It is a library for deep learning. It is a

-Quantized reloaded output: Keras is a deep learning library for Python. It is a high-level API for neural networks. It is a Python library for deep learning
+Quantized reloaded output: Keras is a Python library for deep learning. It is a library that is used to train and deploy deep learning models. It is a library that is used to train and deploy deep learning models. It is a library that is used to train

 Gemma3: FP32 file size: 3815.32 MiB
 Gemma3: INT8 file size: 957.81 MiB
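
Aside (not part of this commit): the printed file sizes above come from the guide's size-comparison step. A hedged sketch of how such a comparison could be computed is below; the `gemma3_fp32` / `gemma3_int8` preset directory names follow the `save_to_preset` calls shown in the diff, while the directory-walking helper is an illustrative assumption, not the guide's exact code.

```python
import os


def dir_size_mib(path):
    # Sum the sizes of all files under a preset directory, in MiB.
    total = 0
    for root, _, files in os.walk(path):
        for name in files:
            total += os.path.getsize(os.path.join(root, name))
    return total / (1024**2)


fp32_mib = dir_size_mib("gemma3_fp32")
int8_mib = dir_size_mib("gemma3_int8")
gemma_reduction = (1 - int8_mib / fp32_mib) * 100

print(f"Gemma3: FP32 file size: {fp32_mib:.2f} MiB")
print(f"Gemma3: INT8 file size: {int8_mib:.2f} MiB")
print(f"Gemma3: Size reduction: {gemma_reduction:.1f}%")
```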

guides/md/quantization_overview.md

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ import keras
 import numpy as np

 # Create a random number generator.
-rng = np.random.default_rng(7)
+rng = np.random.default_rng()

 # Sample training data.
 x_train = rng.random((100, 10)).astype("float32")
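
Aside (not part of this commit): the commit drops the fixed seed in several snippets. `np.random.default_rng()` seeds itself from fresh OS entropy, so sample data and printed MSE values will vary between runs, while passing a seed keeps them reproducible:

```python
import numpy as np

# Unseeded: fresh entropy each run, so these draws differ between runs.
print(np.random.default_rng().random(3))

# Seeded: identical draws every run, useful when outputs must be reproducible.
print(np.random.default_rng(7).random(3))
```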
