Commit ff7a7ba (parent ade0c30)

Adds GPTQ Quantization Documentation

4 files changed: +553 -0 lines changed

Lines changed: 141 additions & 0 deletions
"""
Title: GPTQ Quantization in Keras
Author: [Jyotinder Singh](https://x.com/Jyotinder_Singh)
Date created: 2025/10/16
Last modified: 2025/10/16
Description: How to run weight-only GPTQ quantization for Keras & KerasHub models.
Accelerator: GPU
"""

"""
## What is GPTQ?

GPTQ ("Generative Pre-Training Quantization") is a post-training, weight-only
quantization method that uses a second-order approximation of the loss (via a
Hessian estimate) to minimize the error introduced when compressing weights to
lower precision, typically 4-bit integers.

Unlike standard post-training techniques, GPTQ keeps activations in higher
precision and quantizes only the weights. This often preserves model quality
at low bit widths while still providing large storage and memory savings.

Keras supports GPTQ quantization for KerasHub models via the
`keras.quantizers.GPTQConfig` class.
"""
"""
## Load a KerasHub model

This guide uses the `Gemma3CausalLM` model from KerasHub, a small (1B
parameter) causal language model.
"""

import keras
from keras_hub.models import Gemma3CausalLM
from datasets import load_dataset


prompt = "Keras is a"

model = Gemma3CausalLM.from_preset("gemma3_1b")

outputs = model.generate(prompt, max_length=30)
print(outputs)
"""
47+
## Configure & run GPTQ quantization
48+
49+
You can configure GPTQ quantization via the `keras.quantizers.GPTQConfig` class.
50+
51+
The GPTQ configuration requires a calibration dataset and tokenizer, which it
52+
uses to estimate the Hessian and quantization error. Here, we use a small slice
53+
of the WikiText-2 dataset for calibration.
54+
55+
You can tune several parameters to trade off speed, memory, and accuracy. The
56+
most important of these are `weight_bits` (the bit-width to quantize weights to)
57+
and `group_size` (the number of weights to quantize together). The group size
58+
controls the granularity of quantization: smaller groups typically yield better
59+
accuracy but are slower to quantize and may use more memory. A good starting
60+
point is `group_size=128` for 4-bit quantization (`weight_bits=4`).
61+
62+
In this example, we first prepare a tiny calibration set, and then run GPTQ on
63+
the model using the `.quantize(...)` API.
64+
"""
# Calibration slice (use a larger/representative set in practice)
texts = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")["text"]

calibration_dataset = [
    s + "." for text in texts for s in map(str.strip, text.split(".")) if s
]

gptq_config = keras.quantizers.GPTQConfig(
    dataset=calibration_dataset,
    tokenizer=model.preprocessor.tokenizer,
    weight_bits=4,
    group_size=128,
    num_samples=256,
    sequence_length=256,
    hessian_damping=0.01,
    symmetric=False,
    activation_order=False,
)

model.quantize("gptq", config=gptq_config)

outputs = model.generate(prompt, max_length=30)
print(outputs)
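As a back-of-envelope check on what `weight_bits=4` with `group_size=128` buys, assume each group stores one float16 scale plus one 4-bit zero-point alongside its packed 4-bit weights; the exact packing Keras uses may differ, so treat the overhead figures as assumptions:

```python
weight_bits, group_size = 4, 128
overhead_bits = 16 + 4  # assumed per-group float16 scale + 4-bit zero-point
bits_per_weight = weight_bits + overhead_bits / group_size
compression_vs_fp32 = 32 / bits_per_weight

print(f"{bits_per_weight:.3f} effective bits per weight")
print(f"~{compression_vs_fp32:.1f}x smaller weight storage than FP32")
```

This theoretical ratio is larger than the roughly 50% on-disk savings measured in the benchmarks below, because presets bundle tokenizer assets and metadata in addition to the packed weights.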
"""
91+
## Model Export
92+
93+
The GPTQ quantized model can be saved to a preset and reloaded elsewhere, just
94+
like any other KerasHub model.
95+
"""
96+
97+
model.save_to_preset("gemma3_gptq_w4gs128_preset")
98+
model_from_preset = Gemma3CausalLM.from_preset("gemma3_gptq_w4gs128_preset")
99+
output = model_from_preset.generate(prompt, max_length=30)
100+
print(output)
101+
102+
"""
103+
## Performance & Benchmarking
104+
105+
Micro-benchmarks collected on a single NVIDIA 4070 Ti Super (16 GB).
106+
Baselines are FP32.
107+
108+
Dataset: WikiText-2.
109+
110+
111+
| Model (preset) | Perplexity Increase % (↓ better) | Disk Storage Reduction Δ % (↓ better) | VRAM Reduction Δ % (↓ better) | First-token Latency Δ % (↓ better) | Throughput Δ % (↑ better) |
112+
| --------------------------------- | -------------------------------: | ------------------------------------: | ----------------------------: | ---------------------------------: | ------------------------: |
113+
| GPT2 (gpt2_base_en_cnn_dailymail) | 1.0% | -50.1% ↓ | -41.1% ↓ | +0.7% ↑ | +20.1% ↑ |
114+
| OPT (opt_125m_en) | 10.0% | -49.8% ↓ | -47.0% ↓ | +6.7% ↑ | -15.7% ↓ |
115+
| Bloom (bloom_1.1b_multi) | 7.0% | -47.0% ↓ | -54.0% ↓ | +1.8% ↑ | -15.7% ↓ |
116+
| Gemma3 (gemma3_1b) | 3.0% | -51.5% ↓ | -51.8% ↓ | +39.5% ↑ | +5.7% ↑ |
117+
118+
119+
Detailed benchmarking numbers and scripts are available
120+
[here](https://github.com/keras-team/keras/pull/21641).
121+
122+
### Analysis
123+
124+
There is notable reduction in disk space and VRAM usage across all models, with
125+
disk space savings around 50% and VRAM savings ranging from 41% to 54%. The
126+
reported disk savings understate the true weight compression because presets
127+
also include non-weight assets.
128+
129+
Perplexity increases only marginally, indicating model quality is largely
130+
preserved after quantization.
131+
"""
"""
## Practical tips

* GPTQ is weight-only; training after quantization is not supported.
* Always use the model's own tokenizer for calibration.
* Use a representative calibration set; small slices are only for demos.
* Start with `weight_bits=4` and `group_size=128`; tune per model and task.
* Save to `.keras` or to a preset for reuse elsewhere.
"""
