Commit
* add first test
* update tests
* update to use config files
* update test
* update to add int8 tests
* update
* fix condition
* fix typo
* add w8a16
* update
* update to clear session and delete dirs
* conditional import for vllm
* update
* update num samples
* add more test cases; add custom recipe support
* update model
* update recipe modifier
* Update fp8_weight_only.yaml
* add more test cases
* try a larger model
* revert
* add description; save model to hub post testing
Showing 20 changed files with 305 additions and 29 deletions.
Empty file.
@@ -0,0 +1,4 @@
```yaml
cadence: "nightly"
test_type: "regression"
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
scheme: FP8_DYNAMIC
```
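For context, a minimal sketch of how a preset-scheme config like this one is exercised, mirroring the test harness added in this commit; the helper name `quantize_with_preset` and the save path are illustrative, not part of the repo:

```python
# A minimal sketch, assuming the llmcompressor APIs used by the test below;
# `quantize_with_preset` is an illustrative helper, not part of this commit.
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot


def quantize_with_preset(model_id: str, scheme: str, save_dir: str) -> None:
    model = SparseAutoModelForCausalLM.from_pretrained(
        model_id, device_map="cuda:0", torch_dtype="auto"
    )
    # No recipe in the config, so the scheme is treated as a preset from
    # compressed-tensors; FP8_DYNAMIC requires no calibration data.
    recipe = QuantizationModifier(targets="Linear", scheme=scheme, ignore=["lm_head"])
    oneshot(model=model, recipe=recipe, oneshot_device="cuda:0")
    model.save_pretrained(save_dir)


quantize_with_preset(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "FP8_DYNAMIC", "TinyLlama-FP8_DYNAMIC"
)
```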
@@ -0,0 +1,6 @@
```yaml
cadence: "nightly"
test_type: "regression"
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
scheme: FP8
dataset_id: HuggingFaceH4/ultrachat_200k
dataset_split: train_sft
```
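The `dataset_id` and `dataset_split` fields supply calibration data for static schemes. A sketch of that wiring, following the preprocessing used by the test in this commit; the 256-sample count and 1048 sequence length mirror the test's defaults rather than anything mandated by the config:

```python
from datasets import load_dataset
from transformers import AutoTokenizer

from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
model = SparseAutoModelForCausalLM.from_pretrained(
    model_id, device_map="cuda:0", torch_dtype="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Calibration data: render the chat messages to text, then tokenize.
ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
ds = ds.shuffle(seed=42).select(range(256))
ds = ds.map(
    lambda ex: {"text": tokenizer.apply_chat_template(ex["messages"], tokenize=False)}
)
ds = ds.map(
    lambda s: tokenizer(
        s["text"], max_length=1048, truncation=True, add_special_tokens=False
    ),
    remove_columns=ds.column_names,
)

# Static FP8 needs calibration samples, unlike FP8_DYNAMIC above.
oneshot(
    model=model,
    recipe=QuantizationModifier(targets="Linear", scheme="FP8", ignore=["lm_head"]),
    dataset=ds,
    max_seq_length=1048,
    num_calibration_samples=256,
    oneshot_device="cuda:0",
)
```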
@@ -0,0 +1,5 @@
```yaml
cadence: "nightly"
test_type: "regression"
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
recipe: tests/e2e/vLLM/recipes/FP8/recipe_fp8_weight_only_channel.yaml
scheme: FP8A16_channel
```
@@ -0,0 +1,5 @@
```yaml
cadence: "nightly"
test_type: "regression"
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
recipe: tests/e2e/vLLM/recipes/FP8/recipe_fp8_weight_only_per_tensor.yaml
scheme: FP8A16_tensor
```
tests/e2e/vLLM/configs/INT8/int8_channel_weight_static_per_tensor_act.yaml (7 additions, 0 deletions)
@@ -0,0 +1,7 @@
```yaml
cadence: "nightly"
test_type: "regression"
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
recipe: tests/e2e/vLLM/recipes/INT8/recipe_int8_channel_weight_static_per_tensor_act.yaml
dataset_id: HuggingFaceH4/ultrachat_200k
dataset_split: train_sft
scheme: W8A8_channel_weight_static_per_tensor
```
@@ -0,0 +1,6 @@
```yaml
cadence: "nightly"
test_type: "regression"
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
scheme: W8A8
dataset_id: HuggingFaceH4/ultrachat_200k
dataset_split: train_sft
```
tests/e2e/vLLM/configs/INT8/int8_tensor_weight_static_per_tensor_act.yaml (7 additions, 0 deletions)
@@ -0,0 +1,7 @@
```yaml
cadence: "nightly"
test_type: "regression"
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
recipe: tests/e2e/vLLM/recipes/INT8/recipe_int8_tensor_weight_static_per_tensor_act.yaml
dataset_id: HuggingFaceH4/ultrachat_200k
dataset_split: train_sft
scheme: W8A8_tensor_weight_static_per_tensor_act
```
@@ -0,0 +1,7 @@
```yaml
cadence: "nightly"
test_type: "regression"
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
scheme: W4A16_channel
dataset_id: HuggingFaceH4/ultrachat_200k
dataset_split: train_sft
recipe: tests/e2e/vLLM/recipes/WNA16/recipe_w4a16_channel_quant.yaml
```
@@ -0,0 +1,6 @@
```yaml
cadence: "nightly"
test_type: "regression"
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
scheme: W4A16
dataset_id: HuggingFaceH4/ultrachat_200k
dataset_split: train_sft
```
@@ -0,0 +1,7 @@
```yaml
cadence: "nightly"
test_type: "regression"
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
scheme: W8A16_channel
dataset_id: HuggingFaceH4/ultrachat_200k
dataset_split: train_sft
recipe: tests/e2e/vLLM/recipes/WNA16/recipe_w8a16_channel_quant.yaml
```
@@ -0,0 +1,6 @@
```yaml
cadence: "nightly"
test_type: "regression"
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
scheme: W8A16
dataset_id: HuggingFaceH4/ultrachat_200k
dataset_split: train_sft
```
tests/e2e/vLLM/recipes/FP8/recipe_fp8_weight_only_channel.yaml (9 additions, 0 deletions)
@@ -0,0 +1,9 @@
```yaml
quant_stage:
  quant_modifiers:
    QuantizationModifier:
      sequential_update: false
      ignore: [lm_head]
      config_groups:
        group_0:
          weights: {num_bits: 8, type: float, symmetric: true, strategy: channel, dynamic: false}
          targets: [Linear]
```
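Configs that set a `recipe:` field hand the YAML path directly to `oneshot` instead of constructing a modifier in code, as the test in this commit does. A sketch under that assumption; the save directory name is illustrative:

```python
from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot

model = SparseAutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0", device_map="cuda:0", torch_dtype="auto"
)
# The recipe path from the config is passed straight through to oneshot.
oneshot(
    model=model,
    recipe="tests/e2e/vLLM/recipes/FP8/recipe_fp8_weight_only_channel.yaml",
    oneshot_device="cuda:0",
)
model.save_pretrained("TinyLlama-1.1B-Chat-v1.0-FP8A16_channel")
```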
tests/e2e/vLLM/recipes/FP8/recipe_fp8_weight_only_per_tensor.yaml (9 additions, 0 deletions)
@@ -0,0 +1,9 @@
```yaml
quant_stage:
  quant_modifiers:
    QuantizationModifier:
      sequential_update: false
      ignore: [lm_head]
      config_groups:
        group_0:
          weights: {num_bits: 8, type: float, symmetric: true, strategy: tensor, dynamic: false}
          targets: [Linear]
```
tests/e2e/vLLM/recipes/INT8/recipe_int8_channel_weight_static_per_tensor_act.yaml (10 additions, 0 deletions)
@@ -0,0 +1,10 @@
```yaml
quant_stage:
  quant_modifiers:
    QuantizationModifier:
      sequential_update: false
      ignore: [lm_head]
      config_groups:
        group_0:
          weights: {num_bits: 8, type: int, symmetric: true, strategy: channel}
          input_activations: {num_bits: 8, type: int, symmetric: true, strategy: tensor}
          targets: [Linear]
```
tests/e2e/vLLM/recipes/INT8/recipe_int8_tensor_weight_static_per_tensor_act.yaml (10 additions, 0 deletions)
@@ -0,0 +1,10 @@
```yaml
quant_stage:
  quant_modifiers:
    QuantizationModifier:
      sequential_update: false
      ignore: [lm_head]
      config_groups:
        group_0:
          weights: {num_bits: 8, type: int, symmetric: true, strategy: tensor}
          input_activations: {num_bits: 8, type: int, symmetric: true, strategy: tensor}
          targets: [Linear]
```
@@ -0,0 +1,9 @@
```yaml
quant_stage:
  quant_modifiers:
    QuantizationModifier:
      sequential_update: false
      ignore: [lm_head]
      config_groups:
        group_0:
          weights: {num_bits: 4, type: int, symmetric: true, strategy: channel, dynamic: false}
          targets: [Linear]
```
@@ -0,0 +1,9 @@
```yaml
quant_stage:
  quant_modifiers:
    QuantizationModifier:
      sequential_update: false
      ignore: [lm_head]
      config_groups:
        group_0:
          weights: {num_bits: 8, type: int, symmetric: true, strategy: channel, dynamic: false}
          targets: [Linear]
```
@@ -0,0 +1,144 @@
```python
import shutil
import unittest

import pytest
from datasets import load_dataset
from parameterized import parameterized_class
from transformers import AutoTokenizer

from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot
from tests.testing_utils import parse_params, requires_gpu, requires_torch

try:
    from vllm import LLM, SamplingParams

    vllm_installed = True
except ImportError:
    vllm_installed = False

# Defines the file paths to the directories containing the test configs
# for each of the quantization schemes
WNA16 = "tests/e2e/vLLM/configs/WNA16"
FP8 = "tests/e2e/vLLM/configs/FP8"
INT8 = "tests/e2e/vLLM/configs/INT8"
CONFIGS = [WNA16, FP8, INT8]


@requires_gpu
@requires_torch
@pytest.mark.skipif(not vllm_installed, reason="vLLM is not installed, skipping test")
@parameterized_class(parse_params(CONFIGS))
class TestvLLM(unittest.TestCase):
    """
    The following test quantizes a model using a preset scheme or recipe,
    runs the model using vLLM, and then pushes the model to the hub for
    future use. Each test case is focused on a specific quantization type
    (e.g. W4A16 with grouped quantization, W4A16 with channel quantization).
    To add a new test case, add a config to one of the folders listed in
    `CONFIGS`. If the test case covers a data type not listed in `CONFIGS`,
    create a new folder and add it to the list. The tests run on a cadence
    defined by the `cadence` field. Each config defines the model to quantize.
    Optionally, a dataset id and split can be provided for calibration.
    Finally, all config files must list a scheme. The scheme can be a preset
    scheme from
    https://github.com/neuralmagic/compressed-tensors/blob/main/src/compressed_tensors/quantization/quant_scheme.py  # noqa: E501
    or another identifier for the particular test case. If a recipe is not
    provided, the scheme is assumed to be a preset scheme and is used for
    quantization; otherwise, the recipe is always used when given.
    """

    model = None
    scheme = None
    dataset_id = None
    dataset_split = None
    recipe = None

    def setUp(self):
        print("========== RUNNING ==============")
        print(self.scheme)

        self.save_dir = None
        self.device = "cuda:0"
        self.oneshot_kwargs = {}
        self.num_calibration_samples = 256
        self.max_seq_length = 1048
        self.prompts = [
            "The capital of France is",
            "The president of the US is",
            "My name is",
        ]

    def test_vllm(self):
        # Load model.
        loaded_model = SparseAutoModelForCausalLM.from_pretrained(
            self.model, device_map=self.device, torch_dtype="auto"
        )
        tokenizer = AutoTokenizer.from_pretrained(self.model)

        def preprocess(example):
            return {
                "text": tokenizer.apply_chat_template(
                    example["messages"],
                    tokenize=False,
                )
            }

        def tokenize(sample):
            return tokenizer(
                sample["text"],
                padding=False,
                max_length=self.max_seq_length,
                truncation=True,
                add_special_tokens=False,
            )

        if self.dataset_id:
            ds = load_dataset(self.dataset_id, split=self.dataset_split)
            ds = ds.shuffle(seed=42).select(range(self.num_calibration_samples))
            ds = ds.map(preprocess)
            ds = ds.map(tokenize, remove_columns=ds.column_names)
            self.oneshot_kwargs["dataset"] = ds
            self.oneshot_kwargs["max_seq_length"] = self.max_seq_length
            self.oneshot_kwargs["num_calibration_samples"] = (
                self.num_calibration_samples
            )

        self.save_dir = self.model.split("/")[1] + f"-{self.scheme}"
        self.oneshot_kwargs["model"] = loaded_model
        if self.recipe:
            self.oneshot_kwargs["recipe"] = self.recipe
        else:
            # Test assumes that if a recipe was not provided,
            # a compatible preset scheme was
            self.oneshot_kwargs["recipe"] = QuantizationModifier(
                targets="Linear", scheme=self.scheme, ignore=["lm_head"]
            )

        # Apply quantization.
        print("ONESHOT KWARGS", self.oneshot_kwargs)
        oneshot(
            **self.oneshot_kwargs,
            clear_sparse_session=True,
            oneshot_device=self.device,
        )
        self.oneshot_kwargs["model"].save_pretrained(self.save_dir)
        tokenizer.save_pretrained(self.save_dir)

        # Run vLLM with the saved model.
        print("================= RUNNING vLLM =========================")
        sampling_params = SamplingParams(temperature=0.80, top_p=0.95)
        llm = LLM(model=self.save_dir)
        outputs = llm.generate(self.prompts, sampling_params)
        print("================= vLLM GENERATION ======================")
        for output in outputs:
            assert output
            prompt = output.prompt
            generated_text = output.outputs[0].text
            print("PROMPT", prompt)
            print("GENERATED TEXT", generated_text)

        print("================= UPLOADING TO HUB ======================")
        self.oneshot_kwargs["model"].push_to_hub(f"nm-testing/{self.save_dir}-e2e")
        tokenizer.push_to_hub(f"nm-testing/{self.save_dir}-e2e")

    def tearDown(self):
        shutil.rmtree(self.save_dir)
```
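A model pushed by this test can then be consumed from the hub with vLLM. A sketch, assuming a repo id following the `nm-testing/{save_dir}-e2e` pattern used above; the exact repo name is illustrative:

```python
from vllm import LLM, SamplingParams

# Repo id assumed from the push_to_hub pattern above, not verified to exist.
llm = LLM(model="nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8_DYNAMIC-e2e")
outputs = llm.generate(
    ["The capital of France is"], SamplingParams(temperature=0.8, top_p=0.95)
)
print(outputs[0].outputs[0].text)
```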