Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin' into kylesayrs/better-remote-co…
Browse files Browse the repository at this point in the history
…de-check
  • Loading branch information
kylesayrs committed Mar 11, 2025
2 parents 5eeed1b + 2a59554 commit 3fff13a
Show file tree
Hide file tree
Showing 136 changed files with 1,900 additions and 1,772 deletions.
34 changes: 33 additions & 1 deletion .github/workflows/test-check-transformers.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,41 @@ env:
CLEARML_API_SECRET_KEY: ${{ secrets.CLEARML_API_SECRET_KEY }}

jobs:
detect-changes:
runs-on: ubuntu-latest

outputs:
changes-present: ${{ steps.changed-files.outputs.any_modified }}

steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Get changed files
id: changed-files
uses: tj-actions/changed-files@v45
with:
files: |
**
!examples/**
!tests/e2e/**
!tests/lmeval/**
!tests/examples/**
!**/*.md
!.github/**
.github/workflows/test-check-transformers.yaml
- name: Log relevant output
run: |
echo "changes-present: ${{ steps.changed-files.outputs.any_modified }}"
echo "all modified files: ${{ steps.changed-files.outputs.all_modified_files }}"
shell: bash

transformers-tests:
needs: [detect-changes]
runs-on: gcp-k8s-vllm-l4-solo
if: contains(github.event.pull_request.labels.*.name, 'ready') || github.event_name == 'push'
if: (contains(github.event.pull_request.labels.*.name, 'ready') || github.event_name == 'push') && needs.detect-changes.outputs.changes-present == 'true'
steps:
- uses: actions/setup-python@v5
with:
Expand Down
42 changes: 34 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,32 @@
* SmoothQuant
* SparseGPT

### When to Use Which Optimization

#### PTQ
PTQ is performed to reduce the precision of quantizable weights (e.g., linear layers) to a lower bit-width. Supported formats are:

##### [W4A16](./examples/quantization_w4a16/README.md)
- Uses GPTQ to compress weights to 4 bits. Requires calibration dataset.
- Useful speed ups in low QPS regimes with more weight compression.
- Recommended for any GPUs types.
##### [W8A8-INT8](./examples/quantization_w8a8_int8/README.md)
- Uses channel-wise quantization to compress weights to 8 bits using GPTQ, and uses dynamic per-token quantization to compress activations to 8 bits. Requires calibration dataset for weight quantization. Activation quantization is carried out during inference on vLLM.
- Useful for speed ups in high QPS regimes or offline serving on vLLM.
- Recommended for NVIDIA GPUs with compute capability <8.9 (Ampere, Turing, Volta, Pascal, or older).
##### [W8A8-FP8](./examples/quantization_w8a8_fp8/README.md)
- Uses channel-wise quantization to compress weights to 8 bits, and uses dynamic per-token quantization to compress activations to 8 bits. Does not require calibration dataset. Activation quantization is carried out during inference on vLLM.
- Useful for speed ups in high QPS regimes or offline serving on vLLM.
- Recommended for NVIDIA GPUs with compute capability >8.9 (Hopper and Ada Lovelace).

#### Sparsification
Sparsification reduces model complexity by pruning selected weight values to zero while retaining essential weights in a subset of parameters. Supported formats include:

##### [2:4-Sparsity with FP8 Weight, FP8 Input Activation](./examples/sparse_2of4_quantization_fp8/README.md)
- Uses (1) semi-structured sparsity (SparseGPT), where, for every four contiguous weights in a tensor, two are set to zero. (2) Uses channel-wise quantization to compress weights to 8 bits and dynamic per-token quantization to compress activations to 8 bits.
- Useful for better inference than W8A8-fp8, with almost no drop in its evaluation score [blog](https://neuralmagic.com/blog/24-sparse-llama-fp8-sota-performance-for-nvidia-hopper-gpus/). Note: Small models may experience accuracy drops when the remaining non-zero weights are insufficient to recapitulate the original distribution.
- Recommended for compute capability >8.9 (Hopper and Ada Lovelace).


## Installation

Expand All @@ -35,16 +61,16 @@ pip install llmcompressor
### End-to-End Examples

Applying quantization with `llmcompressor`:
* [Activation quantization to `int8`](examples/quantization_w8a8_int8)
* [Activation quantization to `fp8`](examples/quantization_w8a8_fp8)
* [Weight only quantization to `int4`](examples/quantization_w4a16)
* [Quantizing MoE LLMs](examples/quantizing_moe)
* [Quantizing Vision-Language Models](examples/multimodal_vision)
* [Quantizing Audio-Language Models](examples/multimodal_audio)
* [Activation quantization to `int8`](examples/quantization_w8a8_int8/README.md)
* [Activation quantization to `fp8`](examples/quantization_w8a8_fp8/README.md)
* [Weight only quantization to `int4`](examples/quantization_w4a16/README.md)
* [Quantizing MoE LLMs](examples/quantizing_moe/README.md)
* [Quantizing Vision-Language Models](examples/multimodal_vision/README.md)
* [Quantizing Audio-Language Models](examples/multimodal_audio/README.md)

### User Guides
Deep dives into advanced usage of `llmcompressor`:
* [Quantizing with large models with the help of `accelerate`](examples/big_models_with_accelerate)
* [Quantizing with large models with the help of `accelerate`](examples/big_models_with_accelerate/README.md)


## Quick Tour
Expand All @@ -58,7 +84,7 @@ Quantization is applied by selecting an algorithm and calling the `oneshot` API.
```python
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot
from llmcompressor import oneshot

# Select quantization algorithm. In this case, we:
# * apply SmoothQuant to make the activations easier to quantize
Expand Down
2 changes: 1 addition & 1 deletion examples/big_models_with_accelerate/cpu_offloading_fp8.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import oneshot

MODEL_ID = "meta-llama/Meta-Llama-3-70B-Instruct"
OUTPUT_DIR = MODEL_ID.split("/")[1] + "-FP8-Dynamic"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
from llmcompressor.transformers import oneshot
from llmcompressor.transformers.compression.helpers import calculate_offload_device_map

MODEL_ID = "meta-llama/Meta-Llama-3-70B-Instruct"
Expand Down
2 changes: 1 addition & 1 deletion examples/big_models_with_accelerate/multi_gpu_int8.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot

MODEL_ID = "meta-llama/Meta-Llama-3-70B-Instruct"
SAVE_DIR = MODEL_ID.split("/")[1] + "-W8A8-Dynamic"
Expand Down
2 changes: 1 addition & 1 deletion examples/multimodal_audio/whisper_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from datasets import load_dataset
from transformers import WhisperProcessor

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot
from llmcompressor.transformers.tracing import TraceableWhisperForConditionalGeneration

# Select model and load it.
Expand Down
2 changes: 1 addition & 1 deletion examples/multimodal_vision/idefics3_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from PIL import Image
from transformers import AutoProcessor

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot
from llmcompressor.transformers.tracing import TraceableIdefics3ForConditionalGeneration

# Load model.
Expand Down
2 changes: 1 addition & 1 deletion examples/multimodal_vision/llava_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from PIL import Image
from transformers import AutoProcessor

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot
from llmcompressor.transformers.tracing import TraceableLlavaForConditionalGeneration

# Load model.
Expand Down
2 changes: 1 addition & 1 deletion examples/multimodal_vision/mllama_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from PIL import Image
from transformers import AutoProcessor

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot
from llmcompressor.transformers.tracing import TraceableMllamaForConditionalGeneration

# Load model.
Expand Down
2 changes: 1 addition & 1 deletion examples/multimodal_vision/phi3_vision_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoProcessor

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot

# Load model.
model_id = "microsoft/Phi-3-vision-128k-instruct"
Expand Down
2 changes: 1 addition & 1 deletion examples/multimodal_vision/pixtral_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from PIL import Image
from transformers import AutoProcessor

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot
from llmcompressor.transformers.tracing import TraceableLlavaForConditionalGeneration

# Load model.
Expand Down
2 changes: 1 addition & 1 deletion examples/multimodal_vision/qwen2_vl_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot
from llmcompressor.transformers.tracing import TraceableQwen2VLForConditionalGeneration

# Load model.
Expand Down
10 changes: 6 additions & 4 deletions examples/quantization_2of4_sparse_w4a16/llama7b_sparse_w4a16.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
bf16 = False # using full precision for training
lr_scheduler_type = "cosine"
warmup_ratio = 0.1
preprocessing_num_workers = 8

# this will run the recipe stage by stage:
# oneshot sparsification -> finetuning -> oneshot quantization
Expand All @@ -52,10 +53,11 @@
learning_rate=learning_rate,
lr_scheduler_type=lr_scheduler_type,
warmup_ratio=warmup_ratio,
preprocessing_num_workers=preprocessing_num_workers,
)
logger.info(
"Note: llcompressor does not currently support running ",
"compressed models in the marlin-24 format. The model ",
"produced from this example can be run on vLLM with ",
"dtype=torch.float16",
"llmcompressor does not currently support running compressed models in the marlin24 format." # noqa
)
logger.info(
"The model produced from this example can be run on vLLM with dtype=torch.float16"
)
2 changes: 1 addition & 1 deletion examples/quantization_kv_cache/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ Configure and apply the FP8 quantization for weights, activations, and KV cache.
Notice the new `kv_cache_scheme` section:

```python
from llmcompressor.transformers import oneshot
from llmcompressor import oneshot

recipe = """
quant_stage:
Expand Down
2 changes: 1 addition & 1 deletion examples/quantization_kv_cache/gemma2_fp8_kv_example.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor.transformers import oneshot
from llmcompressor import oneshot

# Select model and load it.
MODEL_ID = "google/gemma-2-9b-it"
Expand Down
2 changes: 1 addition & 1 deletion examples/quantization_kv_cache/llama3_fp8_kv_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from loguru import logger
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor.transformers import oneshot
from llmcompressor import oneshot

# Select model and load it.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
Expand Down
2 changes: 1 addition & 1 deletion examples/quantization_kv_cache/phi3.5_fp8_kv_example.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor.transformers import oneshot
from llmcompressor import oneshot

# Select model and load it.
# Phi-3.5 is a special case for KV cache quantization because it has
Expand Down
2 changes: 1 addition & 1 deletion examples/quantization_w4a16/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ In our case, we will apply the default GPTQ recipe for `int4` (which uses static
> See the `Recipes` documentation for more information on making complex recipes
```python
from llmcompressor.transformers import oneshot
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier

# Configure the quantization algorithm to run.
Expand Down
2 changes: 1 addition & 1 deletion examples/quantization_w8a8_fp8/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ We recommend targeting all `Linear` layers using the `FP8_DYNAMIC` scheme, which
Since simple PTQ does not require data for weight quantization and the activations are quantized dynamically, we do not need any calibration data for this quantization flow.

```python
from llmcompressor.transformers import oneshot
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

# Configure the simple PTQ quantization
Expand Down
2 changes: 1 addition & 1 deletion examples/quantization_w8a8_fp8/gemma2_example.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import oneshot

MODEL_ID = "google/gemma-2-27b-it"

Expand Down
2 changes: 1 addition & 1 deletion examples/quantization_w8a8_fp8/llama3.2_vision_example.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from transformers import AutoProcessor, MllamaForConditionalGeneration

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import oneshot

MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct"

Expand Down
2 changes: 1 addition & 1 deletion examples/quantization_w8a8_fp8/llama3_example.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import oneshot

MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"

Expand Down
2 changes: 1 addition & 1 deletion examples/quantization_w8a8_fp8/llava1.5_example.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from transformers import AutoProcessor, LlavaForConditionalGeneration

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import oneshot

MODEL_ID = "llava-hf/llava-1.5-7b-hf"

Expand Down
2 changes: 1 addition & 1 deletion examples/quantization_w8a8_fp8/qwen2vl_example.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import oneshot

MODEL_ID = "Qwen/Qwen2-VL-7B-Instruct"

Expand Down
6 changes: 3 additions & 3 deletions examples/quantization_w8a8_fp8/whisper_example.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from datasets import load_dataset
from transformers import AutoProcessor, WhisperForConditionalGeneration

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import oneshot

MODEL_ID = "openai/whisper-large-v2"

Expand Down Expand Up @@ -35,8 +35,8 @@
sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt"
).input_features
input_features = input_features.to(model.device)
predicted_ids = model.generate(input_features, language="en", forced_decoder_ids=None)
print(processor.batch_decode(predicted_ids, skip_special_tokens=False)[0])
output_ids = model.generate(input_features, language="en", forced_decoder_ids=None)
print(processor.batch_decode(output_ids, skip_special_tokens=False)[0])
# Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel
print("==========================================")

Expand Down
2 changes: 1 addition & 1 deletion examples/quantization_w8a8_int8/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ We first select the quantization algorithm. For W8A8, we want to:
> See the `Recipes` documentation for more information on recipes
```python
from llmcompressor.transformers import oneshot
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier

Expand Down
2 changes: 1 addition & 1 deletion examples/quantization_w8a8_int8/gemma2_example.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot

# 1) Select model and load it.
MODEL_ID = "google/gemma-2-2b-it"
Expand Down
2 changes: 1 addition & 1 deletion examples/quantization_w8a8_int8/llama3_example.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
from llmcompressor.transformers import oneshot

# Select model and load it.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
Expand Down
Loading

0 comments on commit 3fff13a

Please sign in to comment.