|
11 | 11 | "**Author:** [Jyotinder Singh](https://x.com/Jyotinder_Singh)<br>\n", |
12 | 12 | "**Date created:** 2025/10/14<br>\n", |
13 | 13 | "**Last modified:** 2025/10/14<br>\n", |
14 | | - "**Description:** Complete guide to using INT8 quantization in Keras and KerasHub" |
| 14 | + "**Description:** Complete guide to using INT8 quantization in Keras and KerasHub." |
15 | 15 | ] |
16 | 16 | }, |
17 | 17 | { |
|
24 | 24 | "\n", |
25 | 25 | "Quantization lowers the numerical precision of weights and activations to reduce memory use\n", |
26 | 26 | "and often speed up inference, at the cost of a small accuracy drop. Moving from `float32` to\n", |
27 | | - "`float16` halves the memory requirements; `float32` to `int8` is ~4x smaller (and ~2x vs\n", |
28 | | - "`float16`). On hardware with low-precision kernels (e.g., Tensor Cores), this can also\n", |
| 27 | + "`float16` halves the memory requirements; `float32` to INT8 is ~4x smaller (and ~2x vs\n", |
| 28 | + "`float16`). On hardware with low-precision kernels (e.g., NVIDIA Tensor Cores), this can also\n", |
29 | 29 | "improve throughput and latency. Actual gains depend on your backend and device.\n", |
30 | 30 | "\n", |
31 | 31 | "### How it works\n", |
|
36 | 36 | "* For a tensor (often per-output-channel for weights) with values `w`:\n", |
37 | 37 | " * Compute `a_max = max(abs(w))`.\n", |
38 | 38 | " * Set scale `s = (2 * a_max) / 256`.\n", |
39 | | - " * Quantize: `q = clip(round(w / s), -128, 127)` (stored as int8) and keep `s`.\n", |
| 39 | + " * Quantize: `q = clip(round(w / s), -128, 127)` (stored as INT8) and keep `s`.\n", |
40 | 40 | "* Inference uses `q` and `s` to reconstruct effective weights on the fly\n", |
41 | | - "(`w ≈ s · q`) or folds `s` into the matmul/conv for efficiency.\n", |
42 | | - "\n", |
43 | | - "### Trade-off\n", |
44 | | - "Wider dynamic range (larger `a_max`) reduces clipping but increases rounding error;\n", |
45 | | - "tighter range reduces rounding error but risks more clipping. Per-channel scaling\n", |
46 | | - "for weights can be used to recover accuracy when compared to per-tensor scaling.\n", |
| 41 | + " (`w ≈ s · q`) or folds `s` into the matmul/conv for efficiency.\n", |
47 | 42 | "\n", |
48 | 43 | "### Benefits\n", |
49 | 44 | "\n", |
50 | 45 | "* Memory / bandwidth bound models: When implementation spends most of its time on memory I/O,\n", |
51 | | - "reducing the computation time does not reduce their overall runtime. `int8` reduces bytes\n", |
52 | | - "moved by ~4x vs `float32`, improving cache behavior and reducing memory stalls;\n", |
53 | | - "this often helps more than increasing raw FLOPs.\n", |
54 | | - "* Compute bound layers on supported hardware: On NVIDIA GPUs, int8\n", |
55 | | - "[Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/) speed up matmul/conv,\n", |
56 | | - "boosting throughput on compute-limited layers.\n", |
57 | | - "* Accuracy: Many models retain near-FP accuracy with `float16`; `int8` may introduce a modest\n", |
58 | | - "drop (often ~1-5% depending on task/model/data). Always validate on your own dataset.\n", |
59 | | - "\n", |
60 | | - "### What Keras does in `int8` mode\n", |
61 | | - "\n", |
62 | | - "* **Mapping**: Symmetric, linear quantization with `int8` plus a floating-point scale.\n", |
| 46 | + " reducing the computation time does not reduce their overall runtime. INT8 reduces bytes\n", |
| 47 | + " moved by ~4x vs `float32`, improving cache behavior and reducing memory stalls;\n", |
| 48 | + " this often helps more than increasing raw FLOPs.\n", |
| 49 | + "* Compute bound layers on supported hardware: On NVIDIA GPUs, INT8\n", |
| 50 | + " [Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/) speed up matmul/conv,\n", |
| 51 | + " boosting throughput on compute-limited layers.\n", |
| 52 | + "* Accuracy: Many models retain near-FP accuracy with `float16`; INT8 may introduce a modest\n", |
| 53 | + " drop (often ~1-5% depending on task/model/data). Always validate on your own dataset.\n", |
| 54 | + "\n", |
| 55 | + "### What Keras does in INT8 mode\n", |
| 56 | + "\n", |
| 57 | + "* **Mapping**: Symmetric, linear quantization with INT8 plus a floating-point scale.\n", |
63 | 58 | "* **Weights**: per-output-channel scales to preserve accuracy.\n", |
64 | 59 | "* **Activations**: **dynamic AbsMax** scaling computed at runtime.\n", |
65 | 60 | "* **Graph rewrite**: Quantization is applied after weights are trained and built; the graph\n", |
66 | | - "is rewritten so you can run or save immediately." |
| 61 | + " is rewritten so you can run or save immediately." |
67 | 62 | ] |
68 | 63 | }, |
69 | 64 | { |
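As a concrete illustration of the AbsMax recipe described in this cell, here is a minimal NumPy sketch of per-output-channel INT8 quantization and reconstruction. It follows the formulas above directly; it is an editorial example, not the Keras internals:

```python
import numpy as np

# Toy weight matrix: 10 inputs x 4 output channels.
rng = np.random.default_rng(0)
w = rng.normal(size=(10, 4)).astype("float32")

# Per-output-channel AbsMax: one scale per column.
a_max = np.max(np.abs(w), axis=0, keepdims=True)
s = (2.0 * a_max) / 256.0
s = np.maximum(s, 1e-12)  # guard against all-zero channels

# Quantize: q = clip(round(w / s), -128, 127), stored as INT8.
q = np.clip(np.round(w / s), -128, 127).astype("int8")

# Dequantize on the fly: w ≈ s * q.
w_hat = s * q.astype("float32")
print("max abs reconstruction error:", float(np.max(np.abs(w - w_hat))))
```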
|
101 | 96 | "\n", |
102 | 97 | "\n", |
103 | 98 | "# Create a random number generator.\n", |
104 | | - "rng = np.random.default_rng(7)\n", |
| 99 | + "rng = np.random.default_rng()\n", |
105 | 100 | "\n", |
106 | 101 | "# Create a simple functional model.\n", |
107 | 102 | "inputs = keras.Input(shape=(10,))\n", |
|
155 | 150 | }, |
156 | 151 | "outputs": [], |
157 | 152 | "source": [ |
158 | | - "from keras import saving\n", |
159 | | - "\n", |
160 | | - "# Build a functional model.\n", |
161 | | - "inputs = keras.Input(shape=(10,))\n", |
162 | | - "x = layers.Dense(32, activation=\"relu\")(inputs)\n", |
163 | | - "outputs = layers.Dense(1, name=\"target\")(x)\n", |
164 | | - "model = keras.Model(inputs, outputs)\n", |
165 | | - "model.build((None, 10))\n", |
166 | | - "\n", |
167 | | - "# Sample inputs for evaluation.\n", |
168 | | - "x_eval = rng.random((32, 10)).astype(\"float32\")\n", |
169 | | - "\n", |
170 | | - "# Quantize the model in-place to INT8.\n", |
171 | | - "model.quantize(\"int8\")\n", |
172 | | - "\n", |
173 | | - "# INT8 outputs after quantization.\n", |
174 | | - "y_int8 = model(x_eval)\n", |
175 | | - "\n", |
176 | 153 | "# Save the quantized model and reload to verify round-trip.\n", |
177 | 154 | "model.save(\"int8.keras\")\n", |
178 | | - "int8_reloaded = saving.load_model(\"int8.keras\")\n", |
| 155 | + "int8_reloaded = keras.saving.load_model(\"int8.keras\")\n", |
179 | 156 | "y_int8_reloaded = int8_reloaded(x_eval)\n", |
180 | 157 | "roundtrip_mse = keras.ops.mean(keras.ops.square(y_int8 - y_int8_reloaded))\n", |
181 | 158 | "print(\"MSE (INT8 vs reloaded-INT8):\", float(roundtrip_mse))" |
|
195 | 172 | "In this example, we will:\n", |
196 | 173 | "\n", |
197 | 174 | "1. Load the [gemma3_1b](https://www.kaggle.com/models/keras/gemma3/keras/gemma3_1b)\n", |
198 | | - "preset from KerasHub\n", |
| 175 | + " preset from KerasHub\n", |
199 | 176 | "2. Generate text using both the full-precision and quantized models, and compare outputs.\n", |
200 | 177 | "3. Save both models to disk and compute storage savings.\n", |
201 | 178 | "4. Reload the INT8 model and verify output consistency with the original quantized model." |
|
215 | 192 | "gemma3 = Gemma3CausalLM.from_preset(\"gemma3_1b\")\n", |
216 | 193 | "\n", |
217 | 194 | "# Generate text for a single prompt\n", |
218 | | - "output = gemma3.generate(\"Keras is a\", max_length=30)\n", |
| 195 | + "output = gemma3.generate(\"Keras is a\", max_length=50)\n", |
219 | 196 | "print(\"Full-precision output:\", output)\n", |
220 | 197 | "\n", |
221 | | - "# Save FP32 Gemma3 model\n", |
| 198 | + "# Save FP32 Gemma3 model for size comparison.\n", |
222 | 199 | "gemma3.save_to_preset(\"gemma3_fp32\")\n", |
223 | 200 | "\n", |
224 | 201 | "# Quantize in-place to INT8 and generate again\n", |
225 | 202 | "gemma3.quantize(\"int8\")\n", |
226 | 203 | "\n", |
227 | | - "output = gemma3.generate(\"Keras is a\", max_length=30)\n", |
| 204 | + "output = gemma3.generate(\"Keras is a\", max_length=50)\n", |
228 | 205 | "print(\"Quantized output:\", output)\n", |
229 | 206 | "\n", |
230 | 207 | "# Save INT8 Gemma3 model\n", |
|
233 | 210 | "# Reload and compare outputs\n", |
234 | 211 | "gemma3_int8 = Gemma3CausalLM.from_preset(\"gemma3_int8\")\n", |
235 | 212 | "\n", |
236 | | - "output = gemma3_int8.generate(\"Keras is a\", max_length=30)\n", |
| 213 | + "output = gemma3_int8.generate(\"Keras is a\", max_length=50)\n", |
237 | 214 | "print(\"Quantized reloaded output:\", output)\n", |
238 | 215 | "\n", |
239 | 216 | "\n", |
|