|
3 | 3 | Author: [Jyotinder Singh](https://x.com/Jyotinder_Singh) |
4 | 4 | Date created: 2025/10/14 |
5 | 5 | Last modified: 2025/10/14 |
6 | | -Description: Minimal, end-to-end examples of INT8 post-training quantization in Keras. |
| 6 | +Description: Complete guide to using INT8 quantization in Keras and KerasHub. |
7 | 7 | Accelerator: GPU |
8 | 8 | """ |
9 | 9 |
|
|
12 | 12 |
|
13 | 13 | Quantization lowers the numerical precision of weights and activations to reduce memory use |
14 | 14 | and often speed up inference, at the cost of a small accuracy drop. Moving from `float32` to |
15 | | -`float16` halves the memory requirements; `float32` to `int8` is ~4x smaller (and ~2x vs |
16 | | -`float16`). On hardware with low-precision kernels (e.g., Tensor Cores), this can also |
| 15 | +`float16` halves the memory requirements; `float32` to INT8 is ~4x smaller (and ~2x vs |
| 16 | +`float16`). On hardware with low-precision kernels (e.g., NVIDIA Tensor Cores), this can also |
17 | 17 | improve throughput and latency. Actual gains depend on your backend and device. |
18 | 18 |
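As a rough back-of-the-envelope check of the ratios above, the sketch below computes approximate weight storage for a hypothetical model with one billion parameters. The parameter count and the helper name `weight_mib` are illustrative assumptions, and the figures ignore the small overhead of storing scales.

```python
# Bytes per element for each dtype discussed above.
BYTES_PER_ELEMENT = {"float32": 4, "float16": 2, "int8": 1}


def weight_mib(num_params, dtype):
    # Approximate weight storage in MiB (ignores scales and metadata).
    return num_params * BYTES_PER_ELEMENT[dtype] / (1024**2)


num_params = 1_000_000_000  # hypothetical parameter count
sizes = {dtype: weight_mib(num_params, dtype) for dtype in BYTES_PER_ELEMENT}
for dtype, size in sizes.items():
    print(f"{dtype}: {size:.0f} MiB")

ratio_fp32_to_int8 = sizes["float32"] / sizes["int8"]
ratio_fp16_to_int8 = sizes["float16"] / sizes["int8"]
print("float32 / int8:", ratio_fp32_to_int8)
print("float16 / int8:", ratio_fp16_to_int8)
```
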
|
19 | | -### How it works (symmetric, linear) |
| 19 | +### How it works |
20 | 20 |
|
21 | 21 | Quantization maps real values to 8-bit integers with a scale: |
22 | 22 |
|
23 | 23 | * Integer domain: `[-128, 127]` (256 levels). |
24 | 24 | * For a tensor (often per-output-channel for weights) with values `w`: |
25 | | -
|
26 | 25 | * Compute `a_max = max(abs(w))`. |
27 | 26 | * Set scale `s = (2 * a_max) / 256`. |
28 | | - * Quantize: `q = clip(round(w / s), -128, 127)` (stored as int8) and keep `s`. |
| 27 | + * Quantize: `q = clip(round(w / s), -128, 127)` (stored as INT8) and keep `s`. |
29 | 28 | * Inference uses `q` and `s` to reconstruct effective weights on the fly |
30 | | -(`w ≈ s · q`) or folds `s` into the matmul/conv for efficiency. |
31 | | -
|
32 | | -### Trade-off |
33 | | -Wider dynamic range (larger `a_max`) reduces clipping but increases rounding error; |
34 | | -tighter range reduces rounding error but risks more clipping. Per-channel scaling |
35 | | -for weights typically helps recover accuracy as compared to per-tensor scaling. |
| 29 | + (`w ≈ s · q`) or folds `s` into the matmul/conv for efficiency. |
36 | 30 |
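The steps above can be sketched in plain NumPy. This is a minimal illustration of the formula as written, not the Keras implementation; the weight values are hypothetical and a single per-tensor scale is assumed for brevity.

```python
import numpy as np

# Hypothetical weight values, chosen only for illustration.
w = np.array([-0.82, -0.10, 0.0, 0.35, 1.27], dtype="float32")

# Scale from the absolute maximum, as in the steps above.
a_max = np.max(np.abs(w))
s = (2 * a_max) / 256

# Quantize to the INT8 range and keep the scale.
q = np.clip(np.round(w / s), -128, 127).astype("int8")

# Reconstruct effective weights: w is approximately s * q.
w_hat = s * q.astype("float32")
max_err = float(np.max(np.abs(w - w_hat)))

print("q:", q)
print("w_hat:", w_hat)
print("max abs error:", max_err)
```

Because `w / s` equals 128 for the largest positive value, that entry is clipped to 127 and is off by one full step `s`; every other entry here is off by at most half a step.
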
|
37 | 31 | ### Benefits |
38 | 32 |
|
39 | 33 | * Memory / bandwidth bound models: When implementations spend most of their time on memory I/O, |
40 | | -reducing the computation time does not reduce their overall runtime. `int8` reduces bytes |
41 | | -moved by ~4x vs `float32`, improving cache behavior and reducing memory stalls; |
42 | | -this often helps more than increasing raw FLOPs. |
43 | | -* Compute bound layers on supported hardware: On NVIDIA GPUs, int8 |
44 | | -[Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/) speed up matmul/conv, |
45 | | -boosting throughput on compute-limited layers. |
46 | | -* Accuracy: Many models retain near-FP accuracy with `float16`; `int8` may introduce a modest |
47 | | -drop (often ~1-5% depending on task/model/data). Always validate on your own dataset. |
48 | | -
|
49 | | -### What Keras does in `int8` mode |
50 | | -
|
51 | | -* **Mapping**: Symmetric, linear quantization with `int8` plus a floating-point scale. |
| 34 | + reducing the computation time does not reduce their overall runtime. INT8 reduces bytes |
| 35 | + moved by ~4x vs `float32`, improving cache behavior and reducing memory stalls; |
| 36 | + this often helps more than increasing raw FLOPs. |
| 37 | +* Compute bound layers on supported hardware: On NVIDIA GPUs, INT8 |
| 38 | + [Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/) speed up matmul/conv, |
| 39 | + boosting throughput on compute-limited layers. |
| 40 | +* Accuracy: Many models retain near-FP accuracy with `float16`; INT8 may introduce a modest |
| 41 | + drop (often ~1-5% depending on task/model/data). Always validate on your own dataset. |
| 42 | +
|
| 43 | +### What Keras does in INT8 mode |
| 44 | +
|
| 45 | +* **Mapping**: Symmetric, linear quantization with INT8 plus a floating-point scale. |
52 | 46 | * **Weights**: per-output-channel scales to preserve accuracy. |
53 | 47 | * **Activations**: **dynamic AbsMax** scaling computed at runtime. |
54 | 48 | * **Graph rewrite**: Quantization is applied after weights are trained and built; the graph |
55 | | -is rewritten so you can run or save immediately. |
| 49 | + is rewritten so you can run or save immediately. |
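
To make the mapping, weight, and activation bullets concrete, here is a NumPy simulation of a single dense matmul. It is only a sketch under stated assumptions and does not reproduce Keras internals: the helper `absmax_quantize`, the tensor shapes, and the choice of one activation scale per input row are all illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 10)).astype("float32")  # activations (batch, in)
w = rng.standard_normal((10, 3)).astype("float32")  # kernel (in, out)


def absmax_quantize(t, axis):
    # Symmetric abs-max quantization with one scale per slice along `axis`.
    a_max = np.max(np.abs(t), axis=axis, keepdims=True)
    s = np.maximum(a_max, 1e-7) / 128.0  # same as (2 * a_max) / 256
    q = np.clip(np.round(t / s), -128, 127).astype("int8")
    return q, s


# Weights: one scale per output channel, computed once ahead of time.
w_q, w_s = absmax_quantize(w, axis=0)
# Activations: scales computed from the live batch at call time ("dynamic").
x_q, x_s = absmax_quantize(x, axis=1)

# Integer matmul with int32 accumulation, then rescale back to float.
acc = x_q.astype("int32") @ w_q.astype("int32")
y_sim = acc.astype("float32") * x_s * w_s

y_ref = x @ w
mse = float(np.mean((y_ref - y_sim) ** 2))
print("MSE (FP32 vs simulated INT8):", mse)
```

Folding both scales into a single multiply after the integer matmul is what lets the heavy computation stay in low precision.
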
56 | 50 | """ |
57 | 51 |
|
58 | 52 | """ |
|
76 | 70 | from keras import layers |
77 | 71 |
|
78 | 72 |
|
79 | | -np.random.seed(7) |
| 73 | +# Create a random number generator. |
| 74 | +rng = np.random.default_rng() |
80 | 75 |
|
81 | 76 | # Create a simple functional model. |
82 | 77 | inputs = keras.Input(shape=(10,)) |
|
86 | 81 |
|
87 | 82 | # Compile and train briefly to materialize meaningful weights. |
88 | 83 | model.compile(optimizer="adam", loss="mse") |
89 | | -x_train = np.random.rand(256, 10).astype("float32") |
90 | | -y_train = np.random.rand(256, 1).astype("float32") |
| 84 | +x_train = rng.random((256, 10)).astype("float32") |
| 85 | +y_train = rng.random((256, 1)).astype("float32") |
91 | 86 | model.fit(x_train, y_train, epochs=1, batch_size=32, verbose=0) |
92 | 87 |
|
93 | 88 | # Sample inputs for evaluation. |
94 | | -x_eval = np.random.rand(32, 10).astype("float32") |
| 89 | +x_eval = rng.random((32, 10)).astype("float32") |
95 | 90 |
|
96 | 91 | # Baseline (FP) outputs. |
97 | 92 | y_fp32 = model(x_eval) |
|
111 | 106 | It is evident that the INT8 quantized model produces outputs close to the original FP32 |
112 | 107 | model, as indicated by the low MSE value. |
113 | 108 |
|
114 | | -## Saving and reloading a quantized model. |
| 109 | +## Saving and reloading a quantized model |
115 | 110 |
|
116 | 111 | You can use the standard Keras saving and loading APIs with quantized models. Quantization |
117 | 112 | is preserved when saving to `.keras` and loading back. |
118 | 113 | """ |
119 | 114 |
|
120 | | -from keras import saving |
121 | | - |
122 | | -# Build a functional model. |
123 | | -inputs = keras.Input(shape=(10,)) |
124 | | -x = layers.Dense(32, activation="relu")(inputs) |
125 | | -outputs = layers.Dense(1, name="target")(x) |
126 | | -model = keras.Model(inputs, outputs) |
127 | | -model.build((None, 10)) |
128 | | - |
129 | | -# Quantize the model in-place to INT8. |
130 | | -model.quantize("int8") |
131 | | - |
132 | | -# INT8 outputs after quantization. |
133 | | -y_int8 = model(x_eval) |
134 | | - |
135 | 115 | # Save the quantized model and reload to verify round-trip. |
136 | 116 | model.save("int8.keras") |
137 | | -int8_reloaded = saving.load_model("int8.keras") |
| 117 | +int8_reloaded = keras.saving.load_model("int8.keras") |
138 | 118 | y_int8_reloaded = int8_reloaded(x_eval) |
139 | 119 | roundtrip_mse = keras.ops.mean(keras.ops.square(y_int8 - y_int8_reloaded)) |
140 | 120 | print("MSE (INT8 vs reloaded-INT8):", float(roundtrip_mse)) |
141 | 121 |
|
142 | 122 | """ |
143 | | -## Quantizing a KerasHub model. |
| 123 | +## Quantizing a KerasHub model |
144 | 124 |
|
145 | 125 | All KerasHub models support the `.quantize(...)` API for post-training quantization, |
146 | 126 | and follow the same workflow as above. |
147 | 127 |
|
148 | 128 | In this example, we will: |
149 | 129 |
|
150 | 130 | 1. Load the [gemma3_1b](https://www.kaggle.com/models/keras/gemma3/keras/gemma3_1b) |
151 | | -preset from KerasHub |
152 | | -1. Generate text using both the full-precision and quantized models, and compare outputs. |
153 | | -1. Save both models to disk and compute storage savings. |
154 | | -1. Reload the INT8 model and verify output consistency with the original quantized model. |
| 131 | + preset from KerasHub |
| 132 | +2. Generate text using both the full-precision and quantized models, and compare outputs. |
| 133 | +3. Save both models to disk and compute storage savings. |
| 134 | +4. Reload the INT8 model and verify output consistency with the original quantized model. |
155 | 135 | """ |
156 | 136 |
|
157 | 137 | from keras_hub.models import Gemma3CausalLM |
|
160 | 140 | gemma3 = Gemma3CausalLM.from_preset("gemma3_1b") |
161 | 141 |
|
162 | 142 | # Generate text for a single prompt |
163 | | -output = gemma3.generate("Keras is a", max_length=30) |
| 143 | +output = gemma3.generate("Keras is a", max_length=50) |
164 | 144 | print("Full-precision output:", output) |
165 | 145 |
|
166 | | -# Save FP32 Gemma3 model |
| 146 | +# Save FP32 Gemma3 model for size comparison. |
167 | 147 | gemma3.save_to_preset("gemma3_fp32") |
168 | 148 |
|
169 | 149 | # Quantize in-place to INT8 and generate again |
170 | 150 | gemma3.quantize("int8") |
171 | 151 |
|
172 | | -output = gemma3.generate("Keras is a", max_length=30) |
| 152 | +output = gemma3.generate("Keras is a", max_length=50) |
173 | 153 | print("Quantized output:", output) |
174 | 154 |
|
175 | 155 | # Save INT8 Gemma3 model |
176 | 156 | gemma3.save_to_preset("gemma3_int8") |
177 | 157 |
|
| 158 | +# Reload and compare outputs |
| 159 | +gemma3_int8 = Gemma3CausalLM.from_preset("gemma3_int8") |
| 160 | + |
| 161 | +output = gemma3_int8.generate("Keras is a", max_length=50) |
| 162 | +print("Quantized reloaded output:", output) |
| 163 | + |
178 | 164 |
|
| 165 | +# Compute storage savings |
179 | 166 | def bytes_to_mib(n): |
180 | 167 | return n / (1024**2) |
181 | 168 |
|
182 | 169 |
|
183 | | -gemma_fp32_size = os.path.getsize("gemma3_fp32") |
184 | | -gemma_int8_size = os.path.getsize("gemma3_int8") |
| 170 | +gemma_fp32_size = os.path.getsize("gemma3_fp32/model.weights.h5") |
| 171 | +gemma_int8_size = os.path.getsize("gemma3_int8/model.weights.h5") |
| 172 | + |
185 | 173 | gemma_reduction = 100.0 * (1.0 - (gemma_int8_size / max(gemma_fp32_size, 1))) |
186 | 174 | print(f"Gemma3: FP32 file size: {bytes_to_mib(gemma_fp32_size):.2f} MiB") |
187 | 175 | print(f"Gemma3: INT8 file size: {bytes_to_mib(gemma_int8_size):.2f} MiB") |
188 | 176 | print(f"Gemma3: Size reduction: {gemma_reduction:.1f}%") |
189 | 177 |
|
190 | | -# Reload and compare outputs |
191 | | -gemma3_int8 = Gemma3CausalLM.from_preset("gemma3_int8") |
192 | | - |
193 | | -output = gemma3_int8.generate("Keras is a", max_length=30) |
194 | | -print("Quantized reloaded output:", output) |
195 | | - |
196 | 178 | """ |
197 | 179 | ## Practical tips |
198 | 180 |
|
|