Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FP8 KV Cache quant example #113

Merged
merged 2 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 170 additions & 0 deletions examples/quantization_kv_cache/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
# `fp8` Weight, Activation, and KV Cache Quantization

`llmcompressor` now supports quantizing weights, activations, and KV cache to `fp8` for memory savings and inference acceleration with `vllm`.

> `fp8` computation is supported on NVIDIA GPUs with compute capability > 8.9 (Ada Lovelace, Hopper).

## Installation

To get started, install llmcompressor from source as this feature is new:

```bash
pip install git+https://github.com/vllm-project/llm-compressor.git@cb98f34d4ec9dd175e6995d12fb02dec39c6f27a
```

## Quickstart

The example includes an end-to-end script for applying the quantization algorithm:

```bash
python3 llama3_fp8_kv_example.py
```

The resulting model `Meta-Llama-3-8B-Instruct-FP8-KV` is ready to be loaded into vLLM.

## Code Walkthrough

Let's walk through the main steps of the quantization process:

1. Load model
2. Prepare calibration data
3. Apply quantization
4. Evaluate and save the model

### 1. Load Model

Load the model using `SparseAutoModelForCausalLM`:

```python
from llmcompressor.transformers import SparseAutoModelForCausalLM
from transformers import AutoTokenizer

MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
model = SparseAutoModelForCausalLM.from_pretrained(
MODEL_ID,
device_map="auto",
torch_dtype="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
```

### 2. Prepare Calibration Data

Prepare the calibration data using the `ultrachat` dataset:

```python
from datasets import load_dataset

DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))

def process_and_tokenize(example):
text = tokenizer.apply_chat_template(example["messages"], tokenize=False)
return tokenizer(text, padding=False, max_length=MAX_SEQUENCE_LENGTH, truncation=True, add_special_tokens=False)

ds = ds.map(process_and_tokenize, remove_columns=ds.column_names)
```

### 3. Apply Quantization

Configure and apply the FP8 quantization for weights, activations, and KV cache.
Notice the new `kv_cache_scheme` section:

```python
from llmcompressor.transformers import oneshot

recipe = """
quant_stage:
quant_modifiers:
QuantizationModifier:
ignore: ["lm_head"]
config_groups:
group_0:
weights:
num_bits: 8
type: float
strategy: tensor
dynamic: false
symmetric: true
input_activations:
num_bits: 8
type: float
strategy: tensor
dynamic: false
symmetric: true
targets: ["Linear"]
kv_cache_scheme:
num_bits: 8
type: float
strategy: tensor
dynamic: false
symmetric: true
"""

oneshot(
model=model,
dataset=ds,
recipe=recipe,
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)
```

### 4. Evaluate and Save the Model

Test the quantized model with a sample generation:

```python
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
```

Save the quantized model:

```python
SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-KV"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```

For running the model in vLLM, make sure to specify the `kv_cache_dtype="fp8"` argument to enable quantization of the kv cache, and thus usage of your calibrated scales.

## Evaluating Accuracy

To evaluate the accuracy of your quantized model:

1. Install `vllm` and `lm-evaluation-harness`:

```bash
pip install "vllm>=0.5.5" lm_eval==0.4.3
```

2. Run an evaluation (e.g., on GSM-8K):

```bash
MODEL=$PWD/Meta-Llama-3-8B-Instruct-FP8-KV
lm_eval \
--model vllm \
--model_args pretrained=$MODEL,kv_cache_dtype=fp8,add_bos_token=True \
--tasks gsm8k --num_fewshot 5 --batch_size auto
```

```
vllm (pretrained=Meta-Llama-3-8B-Instruct-FP8-KV,kv_cache_dtype=fp8,add_bos_token=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.7748|± |0.0115|
| | |strict-match | 5|exact_match|↑ |0.7763|± |0.0115|
```

Note: Include `add_bos_token=True` as quantized models can be sensitive to the presence of the `bos` token.

## Questions or Feature Requests?

Please open an issue on `vllm-project/llm-compressor`.
87 changes: 87 additions & 0 deletions examples/quantization_kv_cache/llama3_fp8_kv_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from datasets import load_dataset
from transformers import AutoTokenizer

from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot

# Select model and load it.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
model = SparseAutoModelForCausalLM.from_pretrained(
MODEL_ID,
device_map="auto",
torch_dtype="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))

def process_and_tokenize(example):
text = tokenizer.apply_chat_template(example["messages"], tokenize=False)
return tokenizer(text, padding=False, max_length=MAX_SEQUENCE_LENGTH, truncation=True, add_special_tokens=False)

ds = ds.map(process_and_tokenize, remove_columns=ds.column_names)

# Configure the quantization algorithm and scheme.
# In this case, we:
# * quantize the weights to fp8 with per-tensor scales
# * quantize the activations to fp8 with per-tensor scales
# * quantize the kv cache to fp8 with per-tensor scales
recipe = """
quant_stage:
quant_modifiers:
QuantizationModifier:
ignore: ["lm_head"]
config_groups:
group_0:
weights:
num_bits: 8
type: float
strategy: tensor
dynamic: false
symmetric: true
input_activations:
num_bits: 8
type: float
strategy: tensor
dynamic: false
symmetric: true
targets: ["Linear"]
kv_cache_scheme:
num_bits: 8
type: float
strategy: tensor
dynamic: false
symmetric: true
"""

# Apply algorithms.
oneshot(
model=model,
dataset=ds,
recipe=recipe,
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-KV"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
Loading