Commit 23587db

[ Docs ] Update FP8 example to use dynamic per token (#75)
* update for fp8 dynamic
* cleanup
* format
* fp8 example
* updated per michael's comments
* update example
* update
* tweak further
* updated
1 parent 2ab6ae5 commit 23587db

2 files changed: +59 lines, -136 lines

examples/quantization_w8a8_fp8/README.md (+44, -75)

@@ -1,6 +1,6 @@
  # `fp8` Weight and Activation Quantization
  
- `llm-compressor` supports quantizing weights and activations to `fp8` for memory savings and inference acceleration with `vLLM`
+ `llmcompressor` supports quantizing weights and activations to `fp8` for memory savings and inference acceleration with `vllm`
  
  > `fp8` computation is supported on Nvidia GPUs with compute capability > 8.9 (Ada Lovelace, Hopper).
  
@@ -9,9 +9,7 @@
  To get started, install:
  
  ```bash
- git clone https://github.com/vllm-project/llm-compressor.git
- cd llm-compressor
- pip install -e .
+ pip install llmcompressor==0.1.0
  ```
  
  ## Quickstart
@@ -22,122 +20,93 @@ The example includes an end-to-end script for applying the quantization algorithm.
  python3 llama3_example.py
  ```
  
- The resulting model `Meta-Llama-3-8B-Instruct-W8A8-FP8` is ready to be loaded into vLLM.
+ The resulting model `Meta-Llama-3-8B-Instruct-FP8-Dynamic` is ready to be loaded into vLLM.
  
  ## Code Walkthrough
  
- Now, we will step through the code in the example. There are four steps:
+ Now, we will step through the code in the example. There are three steps:
  1) Load model
- 2) Prepare calibration data
- 3) Apply quantization
- 4) Evaluate accuracy in vLLM
+ 2) Apply quantization
+ 3) Evaluate accuracy in vLLM
  
  ### 1) Load Model
  
- Load the model using `SparseAutoModelForCausalLM`, which is a wrapper around `AutoModel` for handling quantized saving and loading. Note that `SparseAutoModel` is compatible with `accelerate` so you can load your model onto multiple GPUs if needed.
+ Load the model using `SparseAutoModelForCausalLM`, which wraps `AutoModelForCausalLM` for saving and loading quantized models.
  
  ```python
  from llmcompressor.transformers import SparseAutoModelForCausalLM
  from transformers import AutoTokenizer
  
  MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
+ 
  model = SparseAutoModelForCausalLM.from_pretrained(
-     MODEL_ID, device_map="auto", torch_dtype="auto",
- )
+     MODEL_ID, device_map="auto", torch_dtype="auto")
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
  ```
  
- ### 2) Prepare Calibration Data
- 
- Prepare the calibration data. When quantizing activations of a model to `fp8`, we need some sample data to estimate the activation scales. As a result, it is very useful to use calibration data that closely matches the type of data used in deployment. If you have fine-tuned a model, using a sample of your training data is a good idea.
- 
- In our case, we are quantizing an Instruction tuned generic model, so we will use the `ultrachat` dataset. Some best practices include:
- * 512 samples is a good place to start (increase if accuracy drops)
- * 2048 sequence length is a good place to start
- * Use the chat template or instrucion template that the model is trained with
- 
- ```python
- from datasets import load_dataset
+ ### 2) Apply Quantization
  
- NUM_CALIBRATION_SAMPLES=512
- MAX_SEQUENCE_LENGTH=2048
- 
- # Load dataset.
- ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
- ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))
- 
- # Preprocess the data into the format the model is trained with.
- def preprocess(example):
-     return {"text": tokenizer.apply_chat_template(example["messages"], tokenize=False,)}
- ds = ds.map(preprocess)
- 
- # Tokenize the data (be careful with bos tokens - we need add_special_tokens=False since the chat_template already added it).
- def tokenize(sample):
-     return tokenizer(sample["text"], padding=False, max_length=MAX_SEQUENCE_LENGTH, truncation=True, add_special_tokens=False)
- ds = ds.map(tokenize, remove_columns=ds.column_names)
- ```
+ For `fp8` quantization, we can recover accuracy with simple PTQ quantization.
  
- ### 3) Apply Quantization
+ We recommend targeting all `Linear` layers using the `FP8_DYNAMIC` scheme, which uses:
+ - Static, per-channel quantization on the weights
+ - Dynamic, per-token quantization on the activations
  
- With the dataset ready, we will now apply quantization.
- 
- We first select the quantization algorithm. In our case, we will apply the default recipe for `fp8` (which uses static-per-tensor weights and static-per-tensor activations) to all linear layers.
- > See the `Recipes` documentation for more information on making complex recipes
+ Since simple PTQ does not require data for weight quantization and the activations are quantized dynamically, we do not need any calibration data for this quantization flow.
  
  ```python
  from llmcompressor.transformers import oneshot
  from llmcompressor.modifiers.quantization import QuantizationModifier
  
- # Configure the quantization algorithm to run.
- recipe = QuantizationModifier(targets="Linear", scheme="FP8", ignore=["lm_head"])
- 
- # Apply quantization.
- oneshot(
-     model=model,
-     dataset=ds,
-     recipe=recipe,
-     max_seq_length=MAX_SEQUENCE_LENGTH,
-     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
- )
- 
- # Save to disk compressed.
- SAVE_DIR = MODEL_ID.split("/")[1] + "-W8A8-FP8"
- model.save_pretrained(SAVE_DIR, save_compressed=True)
+ # Configure the simple PTQ quantization
+ recipe = QuantizationModifier(
+     targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"])
+ 
+ # Apply the quantization algorithm.
+ oneshot(model=model, recipe=recipe)
+ 
+ # Save the model.
+ SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-Dynamic"
+ model.save_pretrained(SAVE_DIR)
  tokenizer.save_pretrained(SAVE_DIR)
  ```
  
  We have successfully created an `fp8` model!
  
- ### 4) Evaluate Accuracy
+ ### 3) Evaluate Accuracy
+ 
+ Install `vllm` and `lm-evaluation-harness`:
  
- With the model created, we can now load and run in vLLM (after installing).
+ ```bash
+ pip install vllm lm_eval==0.4.3
+ ```
+ 
+ Load and run the model in `vllm`:
  
  ```python
  from vllm import LLM
- model = LLM("./Meta-Llama-3-8B-Instruct-W8A8-FP8")
+ model = LLM("./Meta-Llama-3-8B-Instruct-FP8-Dynamic")
+ model.generate("Hello my name is")
  ```
  
- We can evaluate accuracy with `lm_eval` (`pip install lm_eval==v0.4.3`):
+ Evaluate accuracy with `lm_eval` (for example on 250 samples of `gsm8k`):
  > Note: quantized models can be sensitive to the presence of the `bos` token. `lm_eval` does not add a `bos` token by default, so make sure to include the `add_bos_token=True` argument when running your evaluations.
  
- Run the following to test accuracy on GSM-8K:
- 
  ```bash
- lm_eval --model vllm \
-   --model_args pretrained="./Meta-Llama-3-8B-Instruct-W8A8-FP8",add_bos_token=true \
-   --tasks gsm8k \
-   --num_fewshot 5 \
-   --limit 250 \
-   --batch_size 'auto'
+ MODEL=$PWD/Meta-Llama-3-8B-Instruct-FP8-Dynamic
+ lm_eval \
+   --model vllm \
+   --model_args pretrained=$MODEL,add_bos_token=True \
+   --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 250
  ```
  
- We can see the resulting scores look good!
+ We can see the resulting scores look good:
  
  ```bash
  |Tasks|Version| Filter |n-shot| Metric | |Value| |Stderr|
  |-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
- |gsm8k| 3|flexible-extract| 5|exact_match||0.776|± |0.0264|
- | | |strict-match | 5|exact_match||0.776|± |0.0264|
+ |gsm8k| 3|flexible-extract| 5|exact_match||0.768|± |0.0268|
+ | | |strict-match | 5|exact_match||0.768|± |0.0268|
  ```
  
  ### Questions or Feature Request?
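
As background for the `FP8_DYNAMIC` scheme the updated README describes (static, per-channel scales on the weights; dynamic, per-token scales on the activations), here is a minimal sketch of the scaling math. It is not `llmcompressor`'s implementation: the helper names are made up for illustration, no real `float8` cast is performed, and the fp8 E4M3 maximum of 448 is an assumption of the sketch.

```python
import torch

FP8_E4M3_MAX = 448.0  # assumed maximum representable value of the fp8 E4M3 format

def quantize_weight_per_channel(w: torch.Tensor):
    # Static scales: computed once per output channel (row) from the weights alone.
    scale = w.abs().amax(dim=1, keepdim=True) / FP8_E4M3_MAX
    return torch.clamp(w / scale, -FP8_E4M3_MAX, FP8_E4M3_MAX), scale

def quantize_activation_per_token(x: torch.Tensor):
    # Dynamic scales: recomputed per token at runtime, which is why this
    # quantization flow needs no calibration data.
    scale = x.abs().amax(dim=-1, keepdim=True) / FP8_E4M3_MAX
    return torch.clamp(x / scale, -FP8_E4M3_MAX, FP8_E4M3_MAX), scale

w = torch.randn(4096, 4096)    # a Linear weight: (out_features, in_features)
x = torch.randn(2, 16, 4096)   # activations: (batch, tokens, hidden)
w_q, w_scale = quantize_weight_per_channel(w)
x_q, x_scale = quantize_activation_per_token(x)
print(w_scale.shape, x_scale.shape)  # torch.Size([4096, 1]) torch.Size([2, 16, 1])
```
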
examples/quantization_w8a8_fp8/llama3_example.py (+15, -61)

@@ -1,81 +1,35 @@
- from datasets import load_dataset
  from transformers import AutoTokenizer
  
  from llmcompressor.modifiers.quantization import QuantizationModifier
  from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot
  
- # Select model and load it.
  MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
+ 
+ # Load model.
  model = SparseAutoModelForCausalLM.from_pretrained(
-     MODEL_ID,
-     device_map="auto",
-     torch_dtype="auto",
+     MODEL_ID, device_map="auto", torch_dtype="auto"
  )
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
  
- # Select calibration dataset.
- DATASET_ID = "HuggingFaceH4/ultrachat_200k"
- DATASET_SPLIT = "train_sft"
- 
- # Select number of samples. 512 samples is a good place to start.
- # Increasing the number of samples can improve accuracy.
- NUM_CALIBRATION_SAMPLES = 512
- MAX_SEQUENCE_LENGTH = 2048
- 
- # Load dataset and preprocess.
- ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
- ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))
- 
- 
- def preprocess(example):
-     return {
-         "text": tokenizer.apply_chat_template(
-             example["messages"],
-             tokenize=False,
-         )
-     }
- 
- 
- ds = ds.map(preprocess)
- 
- 
- # Tokenize inputs.
- def tokenize(sample):
-     return tokenizer(
-         sample["text"],
-         padding=False,
-         max_length=MAX_SEQUENCE_LENGTH,
-         truncation=True,
-         add_special_tokens=False,
-     )
- 
- 
- ds = ds.map(tokenize, remove_columns=ds.column_names)
- 
- # Configure the quantization algorithm to run.
+ # Configure the quantization algorithm and scheme.
  # In this case, we:
- # * quantize the weights to fp8 with simple PTQ (static per tensor)
- # * quantize the activations to fp8 with simple PTQ (static per tensor)
- recipe = QuantizationModifier(targets="Linear", scheme="FP8", ignore=["lm_head"])
+ # * quantize the weights to fp8 with per channel via ptq
+ # * quantize the activations to fp8 with dynamic per token
+ recipe = QuantizationModifier(
+     targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"]
+ )
  
  # Apply quantization.
- oneshot(
-     model=model,
-     dataset=ds,
-     recipe=recipe,
-     max_seq_length=MAX_SEQUENCE_LENGTH,
-     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
- )
+ oneshot(model=model, recipe=recipe)
  
  # Confirm generations of the quantized model look sane.
- print("\n\n")
  print("========== SAMPLE GENERATION ==============")
  input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
- output = model.generate(input_ids, max_new_tokens=100)
+ output = model.generate(input_ids, max_new_tokens=20)
  print(tokenizer.decode(output[0]))
- print("==========================================\n\n")
+ print("==========================================")
  
- # Save to disk compressed.
- SAVE_DIR = MODEL_ID.split("/")[1] + "-W8A8-FP8"
- model.save_pretrained(SAVE_DIR, save_compressed=True)
+ # Save to disk in compressed-tensors format.
+ SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-Dynamic"
+ model.save_pretrained(SAVE_DIR)
  tokenizer.save_pretrained(SAVE_DIR)
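
After the updated `llama3_example.py` above has written the `Meta-Llama-3-8B-Instruct-FP8-Dynamic` directory, the README loads it with vLLM. A minimal end-to-end sketch of that step (the prompt and sampling settings here are illustrative, not part of the commit):

```python
from vllm import LLM, SamplingParams

# Load the fp8-quantized checkpoint saved by the example script.
llm = LLM(model="./Meta-Llama-3-8B-Instruct-FP8-Dynamic")

# Any short generation confirms the quantized model loads and runs.
params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["Hello my name is"], params)
print(outputs[0].outputs[0].text)
```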
