diff --git a/.github/workflows/linkcheck.yml b/.github/workflows/linkcheck.yml index 8d02a43c7..7573b91ff 100644 --- a/.github/workflows/linkcheck.yml +++ b/.github/workflows/linkcheck.yml @@ -16,7 +16,9 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - uses: gaurav-nelson/github-action-markdown-link-check@v1 + - uses: umbrelladocs/action-linkspector@v1 with: - use-quiet-mode: 'yes' - config-file: '.github/workflows/mlc_config.json' + github_token: ${{ secrets.github_token }} + reporter: github-pr-review + fail_on_error: true + config_file: '.github/workflows/linkspector/linkspector.yml' diff --git a/.github/workflows/linkspector/linkspector.yml b/.github/workflows/linkspector/linkspector.yml new file mode 100644 index 000000000..e553e9752 --- /dev/null +++ b/.github/workflows/linkspector/linkspector.yml @@ -0,0 +1,10 @@ +aliveStatusCodes: + - 0 + - 200 +ignorePatterns: + - pattern: '.*localhost.*' + - pattern: '.*127\\.0\\.0\\.1.*' + - pattern: '.*0\\.0\\.0\\.0.*' +dirs: + - . +useGitIgnore: true \ No newline at end of file diff --git a/.github/workflows/mlc-config.json b/.github/workflows/mlc-config.json deleted file mode 100644 index 8cdd2fa1a..000000000 --- a/.github/workflows/mlc-config.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - "aliveStatusCodes": [ - 0, - 200 - ], - "ignorePatterns": [ - { - "pattern": ".*localhost.*" - }, - { - "pattern": ".*127\\.0\\.0\\.1.*" - }, - { - "pattern": ".*0\\.0\\.0\\.0.*" - } - ] -} \ No newline at end of file diff --git a/.github/workflows/set-comment.yaml b/.github/workflows/set-comment.yaml new file mode 100644 index 000000000..47edb6fec --- /dev/null +++ b/.github/workflows/set-comment.yaml @@ -0,0 +1,23 @@ +name: PR Reminder Comment Bot +on: + pull_request: + branches: + - main + types: [opened] + +jobs: + pr_reminder: + runs-on: ubuntu-latest + steps: + - name: Remind to add ready label + uses: actions/github-script@v6 + with: + script: | + github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + body: '👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review.' 
+ }) + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} \ No newline at end of file diff --git a/.github/workflows/test-check.yaml b/.github/workflows/test-check.yaml index b29a971f3..5eeb8c74d 100644 --- a/.github/workflows/test-check.yaml +++ b/.github/workflows/test-check.yaml @@ -8,6 +8,7 @@ on: branches: - main - 'release/*' + types: [labeled, opened, synchronize] env: CADENCE: "commit" @@ -46,6 +47,15 @@ jobs: with: python-version: '3.11' - uses: actions/checkout@v2 + - uses: actions/checkout@v2 + with: + repository: "neuralmagic/compressed-tensors" + path: "compressed-tensors" + ref: ${{needs.test-setup.outputs.branch}} + - name: "⚙️ Install compressed-tensors dependencies" + run: pip3 install -U pip && pip3 install setuptools compressed-tensors/ + - name: "Clean compressed-tensors directory" + run: rm -r compressed-tensors/ - name: "⚙️ Install dependencies" run: pip3 install .[dev] - name: "🔬 Running base tests" @@ -95,6 +105,7 @@ jobs: run: | pytest tests/llmcompressor/pytorch -v transformers-tests: + if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'ready') }} runs-on: ubuntu-22.04 needs: test-setup steps: diff --git a/examples/quantization_kv_cache/README.md b/examples/quantization_kv_cache/README.md new file mode 100644 index 000000000..11d78ab18 --- /dev/null +++ b/examples/quantization_kv_cache/README.md @@ -0,0 +1,170 @@ +# `fp8` Weight, Activation, and KV Cache Quantization + +`llmcompressor` now supports quantizing weights, activations, and KV cache to `fp8` for memory savings and inference acceleration with `vllm`. + +> `fp8` computation is supported on NVIDIA GPUs with compute capability > 8.9 (Ada Lovelace, Hopper). + +## Installation + +To get started, install llmcompressor from source as this feature is new: + +```bash +pip install git+https://github.com/vllm-project/llm-compressor.git@cb98f34d4ec9dd175e6995d12fb02dec39c6f27a +``` + +## Quickstart + +The example includes an end-to-end script for applying the quantization algorithm: + +```bash +python3 llama3_fp8_kv_example.py +``` + +The resulting model `Meta-Llama-3-8B-Instruct-FP8-KV` is ready to be loaded into vLLM. + +## Code Walkthrough + +Let's walk through the main steps of the quantization process: + +1. Load model +2. Prepare calibration data +3. Apply quantization +4. Evaluate and save the model + +### 1. Load Model + +Load the model using `SparseAutoModelForCausalLM`: + +```python +from llmcompressor.transformers import SparseAutoModelForCausalLM +from transformers import AutoTokenizer + +MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" +model = SparseAutoModelForCausalLM.from_pretrained( + MODEL_ID, + device_map="auto", + torch_dtype="auto", +) +tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) +``` + +### 2. Prepare Calibration Data + +Prepare the calibration data using the `ultrachat` dataset: + +```python +from datasets import load_dataset + +DATASET_ID = "HuggingFaceH4/ultrachat_200k" +DATASET_SPLIT = "train_sft" +NUM_CALIBRATION_SAMPLES = 512 +MAX_SEQUENCE_LENGTH = 2048 + +ds = load_dataset(DATASET_ID, split=DATASET_SPLIT) +ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES)) + +def process_and_tokenize(example): + text = tokenizer.apply_chat_template(example["messages"], tokenize=False) + return tokenizer(text, padding=False, max_length=MAX_SEQUENCE_LENGTH, truncation=True, add_special_tokens=False) + +ds = ds.map(process_and_tokenize, remove_columns=ds.column_names) +``` + +### 3. 
Apply Quantization + +Configure and apply the FP8 quantization for weights, activations, and KV cache. +Notice the new `kv_cache_scheme` section: + +```python +from llmcompressor.transformers import oneshot + +recipe = """ +quant_stage: + quant_modifiers: + QuantizationModifier: + ignore: ["lm_head"] + config_groups: + group_0: + weights: + num_bits: 8 + type: float + strategy: tensor + dynamic: false + symmetric: true + input_activations: + num_bits: 8 + type: float + strategy: tensor + dynamic: false + symmetric: true + targets: ["Linear"] + kv_cache_scheme: + num_bits: 8 + type: float + strategy: tensor + dynamic: false + symmetric: true +""" + +oneshot( + model=model, + dataset=ds, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, +) +``` + +### 4. Evaluate and Save the Model + +Test the quantized model with a sample generation: + +```python +input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids, max_new_tokens=100) +print(tokenizer.decode(output[0])) +``` + +Save the quantized model: + +```python +SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-KV" +model.save_pretrained(SAVE_DIR, save_compressed=True) +tokenizer.save_pretrained(SAVE_DIR) +``` + +For running the model in vLLM, make sure to specify the `kv_cache_dtype="fp8"` argument to enable quantization of the kv cache, and thus usage of your calibrated scales. + +## Evaluating Accuracy + +To evaluate the accuracy of your quantized model: + +1. Install `vllm` and `lm-evaluation-harness`: + +```bash +pip install "vllm>=0.5.5" lm_eval==0.4.3 +``` + +2. Run an evaluation (e.g., on GSM-8K): + +```bash +MODEL=$PWD/Meta-Llama-3-8B-Instruct-FP8-KV +lm_eval \ + --model vllm \ + --model_args pretrained=$MODEL,kv_cache_dtype=fp8,add_bos_token=True \ + --tasks gsm8k --num_fewshot 5 --batch_size auto +``` + +``` +vllm (pretrained=Meta-Llama-3-8B-Instruct-FP8-KV,kv_cache_dtype=fp8,add_bos_token=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto +|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr| +|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:| +|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.7748|± |0.0115| +| | |strict-match | 5|exact_match|↑ |0.7763|± |0.0115| +``` + +Note: Include `add_bos_token=True` as quantized models can be sensitive to the presence of the `bos` token. + +## Questions or Feature Requests? + +Please open an issue on `vllm-project/llm-compressor`. \ No newline at end of file diff --git a/examples/quantization_kv_cache/llama3_fp8_kv_example.py b/examples/quantization_kv_cache/llama3_fp8_kv_example.py new file mode 100644 index 000000000..a3e2a8e95 --- /dev/null +++ b/examples/quantization_kv_cache/llama3_fp8_kv_example.py @@ -0,0 +1,95 @@ +from datasets import load_dataset +from transformers import AutoTokenizer + +from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot + +# Select model and load it. +MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" +model = SparseAutoModelForCausalLM.from_pretrained( + MODEL_ID, + device_map="auto", + torch_dtype="auto", +) +tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + +# Select calibration dataset. +DATASET_ID = "HuggingFaceH4/ultrachat_200k" +DATASET_SPLIT = "train_sft" + +# Select number of samples. 512 samples is a good place to start. +# Increasing the number of samples can improve accuracy. +NUM_CALIBRATION_SAMPLES = 512 +MAX_SEQUENCE_LENGTH = 2048 + +# Load dataset and preprocess. 
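+# Note: the chat template is rendered to plain text first (tokenize=False) and
+# tokenized in a separate step, so truncation applies to the fully formatted
+# prompt rather than to the raw message fields.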
+ds = load_dataset(DATASET_ID, split=DATASET_SPLIT) +ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES)) + + +def process_and_tokenize(example): + text = tokenizer.apply_chat_template(example["messages"], tokenize=False) + return tokenizer( + text, + padding=False, + max_length=MAX_SEQUENCE_LENGTH, + truncation=True, + add_special_tokens=False, + ) + + +ds = ds.map(process_and_tokenize, remove_columns=ds.column_names) + +# Configure the quantization algorithm and scheme. +# In this case, we: +# * quantize the weights to fp8 with per-tensor scales +# * quantize the activations to fp8 with per-tensor scales +# * quantize the kv cache to fp8 with per-tensor scales +recipe = """ +quant_stage: + quant_modifiers: + QuantizationModifier: + ignore: ["lm_head"] + config_groups: + group_0: + weights: + num_bits: 8 + type: float + strategy: tensor + dynamic: false + symmetric: true + input_activations: + num_bits: 8 + type: float + strategy: tensor + dynamic: false + symmetric: true + targets: ["Linear"] + kv_cache_scheme: + num_bits: 8 + type: float + strategy: tensor + dynamic: false + symmetric: true +""" + +# Apply algorithms. +oneshot( + model=model, + dataset=ds, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, +) + +# Confirm generations of the quantized model look sane. +print("\n\n") +print("========== SAMPLE GENERATION ==============") +input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids, max_new_tokens=100) +print(tokenizer.decode(output[0])) +print("==========================================\n\n") + +# Save to disk compressed. +SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-KV" +model.save_pretrained(SAVE_DIR, save_compressed=True) +tokenizer.save_pretrained(SAVE_DIR) diff --git a/examples/quantization_w8a8_fp8/gemma2_example.py b/examples/quantization_w8a8_fp8/gemma2_example.py index 494676b73..20700da53 100644 --- a/examples/quantization_w8a8_fp8/gemma2_example.py +++ b/examples/quantization_w8a8_fp8/gemma2_example.py @@ -7,7 +7,8 @@ # 1) Load model. model = SparseAutoModelForCausalLM.from_pretrained( - MODEL_ID, device_map="auto", torch_dtype="auto") + MODEL_ID, device_map="auto", torch_dtype="auto" +) tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) # 2) Configure the quantization algorithm and scheme. @@ -15,14 +16,12 @@ # * quantize the weights to fp8 with per channel via ptq # * quantize the activations to fp8 with dynamic per token recipe = QuantizationModifier( - targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"]) + targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"] +) # 3) Apply quantization and save in compressed-tensors format. OUTPUT_DIR = MODEL_ID.split("/")[1] + "-FP8-Dynamic" -oneshot(model=model, - recipe=recipe, - output_dir=OUTPUT_DIR, - tokenizer=tokenizer) +oneshot(model=model, recipe=recipe, output_dir=OUTPUT_DIR, tokenizer=tokenizer) # Confirm generations of the quantized model look sane. print("========== SAMPLE GENERATION ==============") diff --git a/examples/quantization_w8a8_int8/gemma2_example.py b/examples/quantization_w8a8_int8/gemma2_example.py index bf87f5714..976ac5473 100644 --- a/examples/quantization_w8a8_int8/gemma2_example.py +++ b/examples/quantization_w8a8_int8/gemma2_example.py @@ -7,7 +7,10 @@ # 1) Select model and load it. 
MODEL_ID = "google/gemma-2-2b-it" model = SparseAutoModelForCausalLM.from_pretrained( - MODEL_ID, device_map="auto", torch_dtype="auto",) + MODEL_ID, + device_map="auto", + torch_dtype="auto", +) tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) # 2) Prepare calibration dataset. @@ -62,7 +65,7 @@ def tokenize(sample): recipe=recipe, max_seq_length=MAX_SEQUENCE_LENGTH, num_calibration_samples=NUM_CALIBRATION_SAMPLES, - output_dir=MODEL_ID.split("/")[1] + "-INT8" + output_dir=MODEL_ID.split("/")[1] + "-INT8", ) # Confirm generations of the quantized model look sane. diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 4ca01fa3b..ebe826768 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -59,6 +59,7 @@ class GPTQModifier(Modifier): | symmetric: true | strategy: "tensor" | group_size: 128 + | actorder: False :param sequential_update: Whether or not to update weights sequentially by layer, @@ -169,9 +170,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool: if not self.initialized_structure_: self.on_initialize_structure(state, **kwargs) if self.quantization_modifier_: - self.quantization_modifier_.initialize( - state, freeze_quantization=False, **kwargs - ) + self.quantization_modifier_.initialize(state, **kwargs) if not self.quantize: raise ValueError("To use the GPTQModifier, quantization must be enabled.") diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py index 82b6f6802..3d2a13f55 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py @@ -1,7 +1,12 @@ import time +from compressed_tensors.quantization import QuantizationStrategy +from compressed_tensors.quantization.lifecycle.forward import fake_quantize +from compressed_tensors.quantization.observers import MemorylessObserver + from llmcompressor.modifiers.utils import SPARSITY_THRESHOLD from llmcompressor.modifiers.utils.compression_wrapper import ModuleCompressionWrapper +from llmcompressor.utils import getattr_chain from llmcompressor.utils.metric_logging import ( get_GPU_memory_usage, get_layer_size_bytes, @@ -21,6 +26,7 @@ from compressed_tensors.utils import ( get_offloaded_device, is_module_offloaded, + update_parameter_data, update_prefix_dict, ) from loguru import logger @@ -83,6 +89,12 @@ def compress( :param percdamp: Amount of dampening to apply to H, as a fraction of the diagonal norm """ + weight_quant_args = getattr_chain( + self.layer, "quantization_scheme.weights", None + ) + if weight_quant_args is None: + logger.debug(f"Skipping unquantized layer {self.name}...") + return if is_module_offloaded(self.layer): self.layer._hf_hook.pre_forward(self.layer) @@ -92,12 +104,14 @@ def compress( W = self.layer.weight.data.clone() from llmcompressor.pytorch.utils.helpers import tensor_sparsity + # standardize shape and dtype if isinstance(self.layer, nn.Conv2d): W = W.flatten(1) - if isinstance(self.layer, transformers.Conv1D): - W = W.t() + elif isinstance(self.layer, transformers.Conv1D): + W.transpose_(0, 1) W = W.float() + # sparsity mask sparsity = tensor_sparsity(W) preserve_zeros = sparsity >= SPARSITY_THRESHOLD W_nz_mask = ( @@ -108,23 +122,32 @@ def compress( tick = time.time() - if hasattr(self.layer, "quantization_scheme"): - quant_scheme = 
self.layer.quantization_scheme - if quant_scheme.weights is not None: - # fetch latest correct scale and ZP relevant for any changes - # such as activation reordering - from compressed_tensors.quantization import ( - update_layer_weight_quant_params, - ) + # consider activation ordering + if weight_quant_args.actorder: + # use hessian to create a permutation of weights + perm = torch.argsort(torch.diag(self.H), descending=True) + + # permute weight and hessian + W = W[:, perm] + self.H = self.H[perm][:, perm] + + # update quantization parameters for activation ordering + observer = MemorylessObserver(weight_quant_args) + _scale, _zero_point = observer(W) + update_parameter_data(self.layer, _scale, "weight_scale") + update_parameter_data(self.layer, _zero_point, "weight_zero_point") - update_layer_weight_quant_params(self.layer) + scale = self.layer.weight_scale + zero_point = self.layer.weight_zero_point + # mask dead hessian values dead = torch.diag(self.H) == 0 self.H[dead, dead] = 1 W[:, dead] = 0 Losses = torch.zeros(self.rows, device=self.dev) + # compute inverse hessian in place to save memory damp = percdamp * torch.mean(torch.diag(self.H)) diag = torch.arange(self.columns, device=self.dev) self.H[diag, diag] += damp @@ -152,61 +175,44 @@ def compress( d = Hinv1[i, i] q = w.clone() - if hasattr(self.layer, "weight_fake_quant"): - scale = self.layer.weight_fake_quant.scale - zero_point = self.layer.weight_fake_quant.zero_point - dtype = self.layer.weight_fake_quant.dtype - qscheme = self.layer.weight_fake_quant.qscheme - if qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]: - q = torch.quantize_per_tensor(q, scale, zero_point, dtype) - else: - q = torch.quantize_per_channel(q, scale, zero_point, 0, dtype) - q = torch.dequantize(q) - elif hasattr(self.layer, "quantization_scheme"): - quant_scheme = self.layer.quantization_scheme - if quant_scheme.weights is not None: - scale = self.layer.weight_scale - zero_point = self.layer.weight_zero_point - from compressed_tensors.quantization import QuantizationStrategy - from compressed_tensors.quantization.lifecycle.forward import ( - fake_quantize, - ) - - strategy = quant_scheme.weights.strategy - - if strategy == QuantizationStrategy.TENSOR: - q = fake_quantize( - q, - scale, - zero_point, - self.layer.quantization_scheme.weights, - ) - elif strategy == QuantizationStrategy.CHANNEL: - # TODO: for channelwise why isn't this just a 1d tensor? 
- q = fake_quantize( - q, - scale[:, 0], - zero_point[:, 0], - quant_scheme.weights, - ) - else: # strategy == QuantizationStrategy.GROUP - # get the group index for the current column - column_idx = i1 + i - input_dim_group = ( - column_idx // quant_scheme.weights.group_size - ) - - # Since we're only applying quantization to a slice, this - # ends up being a channelwise application - altered_qargs = copy(quant_scheme.weights) - altered_qargs.strategy = QuantizationStrategy.CHANNEL - q = fake_quantize( - q, - scale[:, input_dim_group], - zero_point[:, input_dim_group], - altered_qargs, - ) + # quantize column + strategy = weight_quant_args.strategy + if strategy == QuantizationStrategy.TENSOR: + q = fake_quantize( + q, + scale, + zero_point, + self.layer.quantization_scheme.weights, + ) + elif strategy == QuantizationStrategy.CHANNEL: + q = fake_quantize( + q, + scale[:, 0], + zero_point[:, 0], + weight_quant_args, + ) + elif strategy == QuantizationStrategy.GROUP: + # get the group index for the current column + column_idx = i1 + i + input_dim_group = column_idx // weight_quant_args.group_size + + # Since we're only applying quantization to a slice, this + # ends up being a channelwise application + altered_qargs = copy(weight_quant_args) + altered_qargs.strategy = QuantizationStrategy.CHANNEL + q = fake_quantize( + q, + scale[:, input_dim_group], + zero_point[:, input_dim_group], + altered_qargs, + ) + else: + raise ValueError( + "Quantization strategy is not supported for GPTQ: " + f"{strategy}" + ) + # propagate column error Q1[:, i] = q Losses1[:, i] = (w - q) ** 2 / d**2 @@ -218,6 +224,7 @@ def compress( W1[:, i:] -= w1_err Err1[:, i] = err1 + # propagate block error W[:, i1:i2] = Q1 Losses += torch.sum(Losses1, 1) / 2 @@ -228,33 +235,28 @@ def compress( W[:, i2:] -= w_err if "METRIC" in logger._core.levels.keys(): - logger.log("METRIC", "time %.2f" % (time.time() - tick)) - logger.log("METRIC", "error %.2f" % torch.sum(Losses).item()) - - gpu_usage = get_GPU_memory_usage() - if len(gpu_usage) > 0: - for i in range(len(gpu_usage)): - perc = gpu_usage[i][0] * 100 - total_memory = int(gpu_usage[i][1]) # GB - logger.log( - "METRIC", - ( - f"GPU {i} | usage: {perc:.2f}%" - f" | total memory: {total_memory} GB" - ), - ) + self.log_metrics(tick, Losses) + + if weight_quant_args.actorder: + # restore original permutation + invperm = torch.argsort(perm) + W = W[:, invperm] + + # g_idx describes the group index of the permuted weight + g_idx = torch.tensor( + [i // weight_quant_args.group_size for i in range(self.columns)], + dtype=torch.int, + ).to(device=invperm.device) - logger.log( - "METRIC", - f"Compressed layer size: {get_layer_size_bytes(self.layer)} MB", - ) + # invert to get the group index of the unpermuted weight + update_parameter_data(self.layer, g_idx[invperm], "weight_g_idx") if isinstance(self.layer, transformers.Conv1D): - W = W.t() + W.transpose_(0, 1) W = W.reshape(final_shape).to(final_dtype) - # This is a bit hacky, but FSDP updates only work if we change the weight in - # place, clone() or direct assignment won't work + # This is a bit hacky, but FSDP updates only work if we change + # the weight in place, clone() or direct assignment won't work self.layer.weight -= self.layer.weight self.layer.weight += W @@ -263,13 +265,37 @@ def compress( update_prefix_dict(self.layer, "weight", self.layer.weight.to(device)) self.layer._hf_hook.post_forward(self.layer, None) - del W - del Losses - del diag - def free(self): """ Free the Hessian memory after the layer is complete """ 
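        # self.H is the Hessian accumulated from calibration batches (add_batch);
        # it is no longer needed once compress() has finished for this layer, so
        # drop it here to limit peak memory while iterating over layers.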
delattr(self, "H") super().free() + + def log_metrics(self, start_tick: float, losses: torch.Tensor): + """ + Log metrics related to compression algorithm + + :param start_tick: time when algorithm started" + :param losses: loss as result of algorithm + """ + logger.log("METRIC", "time %.2f" % (time.time() - start_tick)) + logger.log("METRIC", "error %.2f" % torch.sum(losses).item()) + + gpu_usage = get_GPU_memory_usage() + if len(gpu_usage) > 0: + for i in range(len(gpu_usage)): + perc = gpu_usage[i][0] * 100 + total_memory = int(gpu_usage[i][1]) # GB + logger.log( + "METRIC", + ( + f"GPU {i} | usage: {perc:.2f}%" + f" | total memory: {total_memory} GB" + ), + ) + + logger.log( + "METRIC", + f"Compressed layer size: {get_layer_size_bytes(self.layer)} MB", + ) diff --git a/src/llmcompressor/modifiers/quantization/quantization/base.py b/src/llmcompressor/modifiers/quantization/quantization/base.py index 3d7649a4a..ebbb48f4f 100644 --- a/src/llmcompressor/modifiers/quantization/quantization/base.py +++ b/src/llmcompressor/modifiers/quantization/quantization/base.py @@ -70,13 +70,9 @@ class QuantizationModifier(Modifier): calibration_function_: Any = None def on_initialize_structure(self, state: State, **kwargs): - module = state.model - self._apply_modifier_to_model(module) - module.apply(freeze_module_quantization) + pass - def on_initialize( - self, state: State, freeze_quantization: bool = True, **kwargs - ) -> bool: + def on_initialize(self, state: State, **kwargs) -> bool: if self.end and self.end != -1: raise ValueError( "end_epoch is disabled for QuantizationModifier and can only be set to" @@ -96,8 +92,7 @@ def on_initialize( self._check_token_distribution( module, threshold=kwargs.get("min_tokens_per_module") ) - if freeze_quantization: - module.apply(freeze_module_quantization) + module.apply(freeze_module_quantization) return True diff --git a/src/llmcompressor/recipe/recipe.py b/src/llmcompressor/recipe/recipe.py index 169cef170..62bdbfa6b 100644 --- a/src/llmcompressor/recipe/recipe.py +++ b/src/llmcompressor/recipe/recipe.py @@ -118,9 +118,9 @@ def create_instance( if not os.path.isfile(path_or_modifiers): # not a local file # assume it's a string - logger.warning( - "Could not process input as a file path or zoo stub, " - "attempting to process it as a string." + logger.debug( + "Could not initialize recipe as a file path or zoo stub, " + "attempting to process as a string." 
) logger.debug(f"Input string: {path_or_modifiers}") obj = _load_json_or_yaml_string(path_or_modifiers) diff --git a/src/llmcompressor/transformers/sparsification/sparse_model.py b/src/llmcompressor/transformers/sparsification/sparse_model.py index a49c0f7fe..42530f3e1 100644 --- a/src/llmcompressor/transformers/sparsification/sparse_model.py +++ b/src/llmcompressor/transformers/sparsification/sparse_model.py @@ -4,19 +4,20 @@ from typing import Optional, Union import torch +from accelerate import load_checkpoint_and_dispatch from compressed_tensors.compressors import ModelCompressor +from compressed_tensors.quantization import ( + QuantizationStatus, + apply_quantization_config, +) from loguru import logger from torch.nn import Module from transformers import AutoModelForCausalLM, PreTrainedModel -from llmcompressor.pytorch.model_load.helpers import initialize_recipe from llmcompressor.transformers.sparsification.compressed_tensors_utils import ( modify_save_pretrained, ) -from llmcompressor.transformers.utils.helpers import ( - download_model_directory, - resolve_recipe, -) +from llmcompressor.transformers.utils.helpers import download_model_directory __all__ = ["SparseAutoModel", "SparseAutoModelForCausalLM", "get_shared_tokenizer_src"] @@ -40,6 +41,7 @@ class SparseAutoModelForCausalLM(AutoModelForCausalLM): def from_pretrained( cls, pretrained_model_name_or_path, + run_compressed: bool = False, recipe: Optional[Union[str, Path]] = None, *model_args, **kwargs, @@ -109,18 +111,37 @@ def skip(*args, **kwargs): # restore transformers logging level now that model shell is loaded transformers_logger.setLevel(level=restore_log_level) + # HfQuantizer Quantization + if hasattr(model.config, "quantization_config"): + return model + # override the PreTrainedModel instance with compression save function modify_save_pretrained(model) # If model is quantized or compressed on disk, initialize quantization # structure and run decompression if compressor is not None: - # initialize quantization and decompress weights - compressor.decompress(model_path=pretrained_model_name_or_path, model=model) - - recipe = resolve_recipe(recipe=recipe, model_path=pretrained_model_name_or_path) - if recipe: - initialize_recipe(model=model, recipe_path=recipe) + quantization_config = compressor.quantization_config + is_compressed = ( + quantization_config is not None + and quantization_config.quantization_status + == QuantizationStatus.COMPRESSED + ) + if run_compressed and is_compressed: + # initialize quantization, don't decompress + apply_quantization_config( + model, quantization_config, run_compressed=True + ) + model = load_checkpoint_and_dispatch( + model, pretrained_model_name_or_path + ) + else: + # initialize quantization and decompress weights + if quantization_config is not None: + quantization_config.quantization_status = QuantizationStatus.FROZEN + compressor.decompress( + model_path=pretrained_model_name_or_path, model=model + ) return model diff --git a/src/llmcompressor/utils/helpers.py b/src/llmcompressor/utils/helpers.py index a93d4614a..266acf973 100644 --- a/src/llmcompressor/utils/helpers.py +++ b/src/llmcompressor/utils/helpers.py @@ -58,6 +58,7 @@ "parse_kwarg_tuples", "is_package_available", "import_from_path", + "getattr_chain", ] @@ -1008,3 +1009,35 @@ def import_from_path(path: str) -> str: return getattr(module, class_name) except AttributeError: raise AttributeError(f"Cannot find {class_name} in {_path}") + + +def getattr_chain(obj: Any, chain_str: str, *args, **kwargs) -> Any: + """ + 
Chain multiple getattr calls, separated by `.` + + :param obj: base object whose attributes are being retrieved + :param chain_str: attribute names separated by `.` + :param default: default value, throw error otherwise + + """ + if len(args) >= 1: + has_default = True + default = args[0] + elif "default" in kwargs: + has_default = True + default = kwargs["default"] + else: + has_default = False + + attr_names = chain_str.split(".") + + res = obj + for attr_name in attr_names: + if not hasattr(res, attr_name): + if has_default: + return default + else: + raise AttributeError(f"{res} object has no attribute {attr_name}") + res = getattr(res, attr_name) + + return res diff --git a/tests/e2e/vLLM/__init__.py b/tests/e2e/vLLM/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/e2e/vLLM/configs/FP8/fp8_dynamic_per_token.yaml b/tests/e2e/vLLM/configs/FP8/fp8_dynamic_per_token.yaml new file mode 100644 index 000000000..b37bbde09 --- /dev/null +++ b/tests/e2e/vLLM/configs/FP8/fp8_dynamic_per_token.yaml @@ -0,0 +1,4 @@ +cadence: "nightly" +test_type: "regression" +model: TinyLlama/TinyLlama-1.1B-Chat-v1.0 +scheme: FP8_DYNAMIC \ No newline at end of file diff --git a/tests/e2e/vLLM/configs/FP8/fp8_static_per_tensor.yaml b/tests/e2e/vLLM/configs/FP8/fp8_static_per_tensor.yaml new file mode 100644 index 000000000..9d0e3c1a1 --- /dev/null +++ b/tests/e2e/vLLM/configs/FP8/fp8_static_per_tensor.yaml @@ -0,0 +1,6 @@ +cadence: "nightly" +test_type: "regression" +model: TinyLlama/TinyLlama-1.1B-Chat-v1.0 +scheme: FP8 +dataset_id: HuggingFaceH4/ultrachat_200k +dataset_split: train_sft \ No newline at end of file diff --git a/tests/e2e/vLLM/configs/FP8/fp8_weight_only_channel.yaml b/tests/e2e/vLLM/configs/FP8/fp8_weight_only_channel.yaml new file mode 100644 index 000000000..89f845279 --- /dev/null +++ b/tests/e2e/vLLM/configs/FP8/fp8_weight_only_channel.yaml @@ -0,0 +1,5 @@ +cadence: "nightly" +test_type: "regression" +model: TinyLlama/TinyLlama-1.1B-Chat-v1.0 +recipe: tests/e2e/vLLM/recipes/FP8/recipe_fp8_weight_only_channel.yaml +scheme: FP8A16_channel \ No newline at end of file diff --git a/tests/e2e/vLLM/configs/FP8/fp8_weight_only_tensor.yaml b/tests/e2e/vLLM/configs/FP8/fp8_weight_only_tensor.yaml new file mode 100644 index 000000000..1239287f2 --- /dev/null +++ b/tests/e2e/vLLM/configs/FP8/fp8_weight_only_tensor.yaml @@ -0,0 +1,5 @@ +cadence: "nightly" +test_type: "regression" +model: TinyLlama/TinyLlama-1.1B-Chat-v1.0 +recipe: tests/e2e/vLLM/recipes/FP8/recipe_fp8_weight_only_per_tensor.yaml +scheme: FP8A16_tensor \ No newline at end of file diff --git a/tests/e2e/vLLM/configs/INT8/int8_channel_weight_static_per_tensor_act.yaml b/tests/e2e/vLLM/configs/INT8/int8_channel_weight_static_per_tensor_act.yaml new file mode 100644 index 000000000..ecdd84938 --- /dev/null +++ b/tests/e2e/vLLM/configs/INT8/int8_channel_weight_static_per_tensor_act.yaml @@ -0,0 +1,7 @@ +cadence: "nightly" +test_type: "regression" +model: TinyLlama/TinyLlama-1.1B-Chat-v1.0 +recipe: tests/e2e/vLLM/recipes/INT8/recipe_int8_channel_weight_static_per_tensor_act.yaml +dataset_id: HuggingFaceH4/ultrachat_200k +dataset_split: train_sft +scheme: W8A8_channel_weight_static_per_tensor \ No newline at end of file diff --git a/tests/e2e/vLLM/configs/INT8/int8_dynamic_per_token.yaml b/tests/e2e/vLLM/configs/INT8/int8_dynamic_per_token.yaml new file mode 100644 index 000000000..befa14beb --- /dev/null +++ b/tests/e2e/vLLM/configs/INT8/int8_dynamic_per_token.yaml @@ -0,0 +1,6 @@ +cadence: "nightly" +test_type: 
"regression" +model: TinyLlama/TinyLlama-1.1B-Chat-v1.0 +scheme: W8A8 +dataset_id: HuggingFaceH4/ultrachat_200k +dataset_split: train_sft \ No newline at end of file diff --git a/tests/e2e/vLLM/configs/INT8/int8_tensor_weight_static_per_tensor_act.yaml b/tests/e2e/vLLM/configs/INT8/int8_tensor_weight_static_per_tensor_act.yaml new file mode 100644 index 000000000..4af8e65ad --- /dev/null +++ b/tests/e2e/vLLM/configs/INT8/int8_tensor_weight_static_per_tensor_act.yaml @@ -0,0 +1,7 @@ +cadence: "nightly" +test_type: "regression" +model: TinyLlama/TinyLlama-1.1B-Chat-v1.0 +recipe: tests/e2e/vLLM/recipes/INT8/recipe_int8_tensor_weight_static_per_tensor_act.yaml +dataset_id: HuggingFaceH4/ultrachat_200k +dataset_split: train_sft +scheme: W8A8_tensor_weight_static_per_tensor_act diff --git a/tests/e2e/vLLM/configs/WNA16/w4a16_channel_quant.yaml b/tests/e2e/vLLM/configs/WNA16/w4a16_channel_quant.yaml new file mode 100644 index 000000000..f08a64159 --- /dev/null +++ b/tests/e2e/vLLM/configs/WNA16/w4a16_channel_quant.yaml @@ -0,0 +1,7 @@ +cadence: "nightly" +test_type: "regression" +model: TinyLlama/TinyLlama-1.1B-Chat-v1.0 +scheme: W4A16_channel +dataset_id: HuggingFaceH4/ultrachat_200k +dataset_split: train_sft +recipe: tests/e2e/vLLM/recipes/WNA16/recipe_w4a16_channel_quant.yaml \ No newline at end of file diff --git a/tests/e2e/vLLM/configs/WNA16/w4a16_grouped_quant.yaml b/tests/e2e/vLLM/configs/WNA16/w4a16_grouped_quant.yaml new file mode 100644 index 000000000..bbd1406ce --- /dev/null +++ b/tests/e2e/vLLM/configs/WNA16/w4a16_grouped_quant.yaml @@ -0,0 +1,6 @@ +cadence: "nightly" +test_type: "regression" +model: TinyLlama/TinyLlama-1.1B-Chat-v1.0 +scheme: W4A16 +dataset_id: HuggingFaceH4/ultrachat_200k +dataset_split: train_sft \ No newline at end of file diff --git a/tests/e2e/vLLM/configs/WNA16/w8a16_channel_quant.yaml b/tests/e2e/vLLM/configs/WNA16/w8a16_channel_quant.yaml new file mode 100644 index 000000000..f9adbc506 --- /dev/null +++ b/tests/e2e/vLLM/configs/WNA16/w8a16_channel_quant.yaml @@ -0,0 +1,7 @@ +cadence: "nightly" +test_type: "regression" +model: TinyLlama/TinyLlama-1.1B-Chat-v1.0 +scheme: W8A16_channel +dataset_id: HuggingFaceH4/ultrachat_200k +dataset_split: train_sft +recipe: tests/e2e/vLLM/recipes/WNA16/recipe_w8a16_channel_quant.yaml \ No newline at end of file diff --git a/tests/e2e/vLLM/configs/WNA16/w8a16_grouped_quant.yaml b/tests/e2e/vLLM/configs/WNA16/w8a16_grouped_quant.yaml new file mode 100644 index 000000000..4e9a278a5 --- /dev/null +++ b/tests/e2e/vLLM/configs/WNA16/w8a16_grouped_quant.yaml @@ -0,0 +1,6 @@ +cadence: "nightly" +test_type: "regression" +model: TinyLlama/TinyLlama-1.1B-Chat-v1.0 +scheme: W8A16 +dataset_id: HuggingFaceH4/ultrachat_200k +dataset_split: train_sft \ No newline at end of file diff --git a/tests/e2e/vLLM/recipes/FP8/recipe_fp8_weight_only_channel.yaml b/tests/e2e/vLLM/recipes/FP8/recipe_fp8_weight_only_channel.yaml new file mode 100644 index 000000000..84d6505cb --- /dev/null +++ b/tests/e2e/vLLM/recipes/FP8/recipe_fp8_weight_only_channel.yaml @@ -0,0 +1,9 @@ +quant_stage: + quant_modifiers: + QuantizationModifier: + sequential_update: false + ignore: [lm_head] + config_groups: + group_0: + weights: {num_bits: 8, type: float, symmetric: true, strategy: channel, dynamic: false} + targets: [Linear] diff --git a/tests/e2e/vLLM/recipes/FP8/recipe_fp8_weight_only_per_tensor.yaml b/tests/e2e/vLLM/recipes/FP8/recipe_fp8_weight_only_per_tensor.yaml new file mode 100644 index 000000000..8a6dfbde6 --- /dev/null +++ 
b/tests/e2e/vLLM/recipes/FP8/recipe_fp8_weight_only_per_tensor.yaml @@ -0,0 +1,9 @@ +quant_stage: + quant_modifiers: + QuantizationModifier: + sequential_update: false + ignore: [lm_head] + config_groups: + group_0: + weights: {num_bits: 8, type: float, symmetric: true, strategy: tensor, dynamic: false} + targets: [Linear] diff --git a/tests/e2e/vLLM/recipes/INT8/recipe_int8_channel_weight_static_per_tensor_act.yaml b/tests/e2e/vLLM/recipes/INT8/recipe_int8_channel_weight_static_per_tensor_act.yaml new file mode 100644 index 000000000..6cfa275af --- /dev/null +++ b/tests/e2e/vLLM/recipes/INT8/recipe_int8_channel_weight_static_per_tensor_act.yaml @@ -0,0 +1,10 @@ +quant_stage: + quant_modifiers: + QuantizationModifier: + sequential_update: false + ignore: [lm_head] + config_groups: + group_0: + weights: {num_bits: 8, type: int, symmetric: true, strategy: channel} + input_activations: {num_bits: 8, type: int, symmetric: true, strategy: tensor} + targets: [Linear] diff --git a/tests/e2e/vLLM/recipes/INT8/recipe_int8_tensor_weight_static_per_tensor_act.yaml b/tests/e2e/vLLM/recipes/INT8/recipe_int8_tensor_weight_static_per_tensor_act.yaml new file mode 100644 index 000000000..6ddcc63b4 --- /dev/null +++ b/tests/e2e/vLLM/recipes/INT8/recipe_int8_tensor_weight_static_per_tensor_act.yaml @@ -0,0 +1,10 @@ +quant_stage: + quant_modifiers: + QuantizationModifier: + sequential_update: false + ignore: [lm_head] + config_groups: + group_0: + weights: {num_bits: 8, type: int, symmetric: true, strategy: tensor} + input_activations: {num_bits: 8, type: int, symmetric: true, strategy: tensor} + targets: [Linear] diff --git a/tests/e2e/vLLM/recipes/WNA16/recipe_w4a16_channel_quant.yaml b/tests/e2e/vLLM/recipes/WNA16/recipe_w4a16_channel_quant.yaml new file mode 100644 index 000000000..b667b2d10 --- /dev/null +++ b/tests/e2e/vLLM/recipes/WNA16/recipe_w4a16_channel_quant.yaml @@ -0,0 +1,9 @@ +quant_stage: + quant_modifiers: + QuantizationModifier: + sequential_update: false + ignore: [lm_head] + config_groups: + group_0: + weights: {num_bits: 4, type: int, symmetric: true, strategy: channel, dynamic: false} + targets: [Linear] diff --git a/tests/e2e/vLLM/recipes/WNA16/recipe_w8a16_channel_quant.yaml b/tests/e2e/vLLM/recipes/WNA16/recipe_w8a16_channel_quant.yaml new file mode 100644 index 000000000..bafd7928d --- /dev/null +++ b/tests/e2e/vLLM/recipes/WNA16/recipe_w8a16_channel_quant.yaml @@ -0,0 +1,9 @@ +quant_stage: + quant_modifiers: + QuantizationModifier: + sequential_update: false + ignore: [lm_head] + config_groups: + group_0: + weights: {num_bits: 8, type: int, symmetric: true, strategy: channel, dynamic: false} + targets: [Linear] diff --git a/tests/e2e/vLLM/test_vllm.py b/tests/e2e/vLLM/test_vllm.py new file mode 100644 index 000000000..a4a47c6d2 --- /dev/null +++ b/tests/e2e/vLLM/test_vllm.py @@ -0,0 +1,144 @@ +import shutil +import unittest + +import pytest +from datasets import load_dataset +from parameterized import parameterized_class +from transformers import AutoTokenizer + +from llmcompressor.modifiers.quantization import QuantizationModifier +from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot +from tests.testing_utils import parse_params, requires_gpu, requires_torch + +try: + from vllm import LLM, SamplingParams + + vllm_installed = True +except ImportError: + vllm_installed = False + +# Defines the file paths to the directories containing the test configs +# for each of the quantization schemes +WNA16 = "tests/e2e/vLLM/configs/WNA16" +FP8 = 
"tests/e2e/vLLM/configs/FP8" +INT8 = "tests/e2e/vLLM/configs/INT8" +CONFIGS = [WNA16, FP8, INT8] + + +@requires_gpu +@requires_torch +@pytest.mark.skipif(not vllm_installed, reason="vLLM is not installed, skipping test") +@parameterized_class(parse_params(CONFIGS)) +class TestvLLM(unittest.TestCase): + """ + The following test quantizes a model using a preset scheme or recipe, + runs the model using vLLM, and then pushes the model to the hub for + future use. Each test case is focused on a specific quantization type + (e.g W4A16 with grouped quantization, W4N16 with channel quantization). + To add a new test case, a new config has to be added to one of the folders + listed in the `CONFIGS` folder. If the test case is for a data type not listed + in `CONFIGS`, a new folder can be created and added to the list. The tests + run on a cadence defined by the `cadence` field. Each config defines the model + to quantize. Optionally, a dataset id and split can be provided for calibration. + Finally, all config files must list a scheme. The scheme can be a preset scheme + from https://github.com/neuralmagic/compressed-tensors/blob/main/src/compressed_tensors/quantization/quant_scheme.py + or another identifier which can be used for the particular test case. If a recipe + is not provided, it is assumed that the scheme provided is a preset scheme and will + be used for quantization. Otherwise, the recipe will always be used if given. + """ # noqa: E501 + + model = None + scheme = None + dataset_id = None + dataset_split = None + recipe = None + + def setUp(self): + print("========== RUNNING ==============") + print(self.scheme) + + self.save_dir = None + self.device = "cuda:0" + self.oneshot_kwargs = {} + self.num_calibration_samples = 256 + self.max_seq_length = 1048 + self.prompts = [ + "The capital of France is", + "The president of the US is", + "My name is", + ] + + def test_vllm(self): + # Load model. + loaded_model = SparseAutoModelForCausalLM.from_pretrained( + self.model, device_map=self.device, torch_dtype="auto" + ) + tokenizer = AutoTokenizer.from_pretrained(self.model) + + def preprocess(example): + return { + "text": tokenizer.apply_chat_template( + example["messages"], + tokenize=False, + ) + } + + def tokenize(sample): + return tokenizer( + sample["text"], + padding=False, + max_length=self.max_seq_length, + truncation=True, + add_special_tokens=False, + ) + + if self.dataset_id: + ds = load_dataset(self.dataset_id, split=self.dataset_split) + ds = ds.shuffle(seed=42).select(range(self.num_calibration_samples)) + ds = ds.map(preprocess) + ds = ds.map(tokenize, remove_columns=ds.column_names) + self.oneshot_kwargs["dataset"] = ds + self.oneshot_kwargs["max_seq_length"] = self.max_seq_length + self.oneshot_kwargs["num_calibration_samples"] = ( + self.num_calibration_samples + ) + + self.save_dir = self.model.split("/")[1] + f"-{self.scheme}" + self.oneshot_kwargs["model"] = loaded_model + if self.recipe: + self.oneshot_kwargs["recipe"] = self.recipe + else: + # Test assumes that if a recipe was not provided, using + # a compatible preset sceme + self.oneshot_kwargs["recipe"] = QuantizationModifier( + targets="Linear", scheme=self.scheme, ignore=["lm_head"] + ) + + # Apply quantization. 
+ print("ONESHOT KWARGS", self.oneshot_kwargs) + oneshot( + **self.oneshot_kwargs, + clear_sparse_session=True, + oneshot_device=self.device, + ) + self.oneshot_kwargs["model"].save_pretrained(self.save_dir) + tokenizer.save_pretrained(self.save_dir) + # Run vLLM with saved model + print("================= RUNNING vLLM =========================") + sampling_params = SamplingParams(temperature=0.80, top_p=0.95) + llm = LLM(model=self.save_dir) + outputs = llm.generate(self.prompts, sampling_params) + print("================= vLLM GENERATION ======================") + for output in outputs: + assert output + prompt = output.prompt + generated_text = output.outputs[0].text + print("PROMPT", prompt) + print("GENERATED TEXT", generated_text) + + print("================= UPLOADING TO HUB ======================") + self.oneshot_kwargs["model"].push_to_hub(f"nm-testing/{self.save_dir}-e2e") + tokenizer.push_to_hub(f"nm-testing/{self.save_dir}-e2e") + + def tearDown(self): + shutil.rmtree(self.save_dir) diff --git a/tests/llmcompressor/modifiers/quantization/gptq/utils/test_gptq_wrapper.py b/tests/llmcompressor/modifiers/quantization/gptq/utils/test_gptq_wrapper.py new file mode 100644 index 000000000..203d1fe03 --- /dev/null +++ b/tests/llmcompressor/modifiers/quantization/gptq/utils/test_gptq_wrapper.py @@ -0,0 +1,41 @@ +from collections import OrderedDict + +import torch +from compressed_tensors.quantization.lifecycle.apply import apply_quantization_config +from compressed_tensors.quantization.quant_config import QuantizationConfig +from compressed_tensors.quantization.quant_scheme import preset_name_to_scheme +from loguru import logger + +from llmcompressor.modifiers.quantization.gptq.utils.gptq_wrapper import GPTQWrapper + + +def test_ignore(): + model = torch.nn.Sequential( + OrderedDict( + [ + ("first_layer", torch.nn.Linear(2, 3)), + ("second_layer", torch.nn.Linear(3, 5)), + ] + ) + ) + + config = QuantizationConfig( + config_groups={"group_0": preset_name_to_scheme("W8A8", targets=["Linear"])}, + ignore=["first_layer"], + ) + apply_quantization_config(model, config) + + messages = [] + logger.add(lambda m: messages.append(m)) + + with torch.no_grad(): + first_compressor = GPTQWrapper("first_layer", model.first_layer) + first_compressor.add_batch(torch.ones(2), None) + first_compressor.compress() + + second_compressor = GPTQWrapper("second_layer", model.second_layer) + second_compressor.add_batch(torch.ones(3), None) + second_compressor.compress() + + assert sum("Skipping unquantized layer first_layer" in m for m in messages) == 1 + assert sum("Skipping unquantized layer second_layer" in m for m in messages) == 0 diff --git a/tests/llmcompressor/transformers/compression/configs/actorder_1.1b.yaml b/tests/llmcompressor/transformers/compression/configs/actorder_1.1b.yaml new file mode 100644 index 000000000..4cb398810 --- /dev/null +++ b/tests/llmcompressor/transformers/compression/configs/actorder_1.1b.yaml @@ -0,0 +1,5 @@ +cadence: "nightly" +test_type: "regression" +model_stub: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" +new_recipe: "tests/llmcompressor/transformers/compression/recipes/new_quant_actorder.yaml" +ppl_threshold: 20 \ No newline at end of file diff --git a/tests/llmcompressor/transformers/compression/recipes/new_quant_actorder.yaml b/tests/llmcompressor/transformers/compression/recipes/new_quant_actorder.yaml new file mode 100644 index 000000000..21f249948 --- /dev/null +++ b/tests/llmcompressor/transformers/compression/recipes/new_quant_actorder.yaml @@ -0,0 +1,19 
@@ +test_stage: + quant_modifiers: + QuantizationModifier: + ignore: ["lm_head", "model.layers.0.mlp.down_proj"] + config_groups: + group_0: + weights: + num_bits: 4 + type: "int" + symmetric: False + strategy: "group" + group_size: 128 + actorder: True + input_activations: null + output_activations: null + targets: ["Linear"] + GPTQModifier: + block_size: 128 + sequential_update: False \ No newline at end of file diff --git a/tests/llmcompressor/transformers/compression/run_compressed_configs/fp8_dynamic.yaml b/tests/llmcompressor/transformers/compression/run_compressed_configs/fp8_dynamic.yaml new file mode 100644 index 000000000..6159646ed --- /dev/null +++ b/tests/llmcompressor/transformers/compression/run_compressed_configs/fp8_dynamic.yaml @@ -0,0 +1,3 @@ +cadence: "commit" +test_type: "regression" +model_stub: "nm-testing/tinyllama-fp8-dynamic-compressed" \ No newline at end of file diff --git a/tests/llmcompressor/transformers/compression/run_compressed_configs/w4a16.yaml b/tests/llmcompressor/transformers/compression/run_compressed_configs/w4a16.yaml new file mode 100644 index 000000000..844cf457d --- /dev/null +++ b/tests/llmcompressor/transformers/compression/run_compressed_configs/w4a16.yaml @@ -0,0 +1,3 @@ +cadence: "commit" +test_type: "regression" +model_stub: "nm-testing/tinyllama-w4a16-compressed" \ No newline at end of file diff --git a/tests/llmcompressor/transformers/compression/run_compressed_configs/w8a16_dense.yaml b/tests/llmcompressor/transformers/compression/run_compressed_configs/w8a16_dense.yaml new file mode 100644 index 000000000..367d3fd4f --- /dev/null +++ b/tests/llmcompressor/transformers/compression/run_compressed_configs/w8a16_dense.yaml @@ -0,0 +1,3 @@ +cadence: "commit" +test_type: "regression" +model_stub: "nm-testing/tinyllama-w8a16-dense" \ No newline at end of file diff --git a/tests/llmcompressor/transformers/compression/run_compressed_configs/w8a8.yaml b/tests/llmcompressor/transformers/compression/run_compressed_configs/w8a8.yaml new file mode 100644 index 000000000..844cf457d --- /dev/null +++ b/tests/llmcompressor/transformers/compression/run_compressed_configs/w8a8.yaml @@ -0,0 +1,3 @@ +cadence: "commit" +test_type: "regression" +model_stub: "nm-testing/tinyllama-w4a16-compressed" \ No newline at end of file diff --git a/tests/llmcompressor/transformers/compression/test_run_compressed.py b/tests/llmcompressor/transformers/compression/test_run_compressed.py new file mode 100644 index 000000000..97070377b --- /dev/null +++ b/tests/llmcompressor/transformers/compression/test_run_compressed.py @@ -0,0 +1,59 @@ +import shutil +import tempfile +import unittest + +import torch +from parameterized import parameterized_class +from transformers import AutoTokenizer + +from llmcompressor.transformers import SparseAutoModelForCausalLM +from tests.testing_utils import parse_params, requires_gpu, requires_torch + +CONFIG_DIR = "tests/llmcompressor/transformers/compression/run_compressed_configs" + + +@requires_torch +@requires_gpu +@parameterized_class(parse_params(CONFIG_DIR)) +class TestQuantizationMatches(unittest.TestCase): + model_stub = None + + @classmethod + def setUpClass(cls): + cls.test_dir = tempfile.mkdtemp() + + cls.compressed_model = SparseAutoModelForCausalLM.from_pretrained( + cls.model_stub, torch_dtype="auto", device_map="auto", run_compressed=True + ) + cls.uncompressed_model = SparseAutoModelForCausalLM.from_pretrained( + cls.model_stub, torch_dtype="auto", device_map="auto", run_compressed=False + ) + cls.tokenizer = 
AutoTokenizer.from_pretrained(cls.model_stub) + cls.device = cls.compressed_model.device + + def test_compressed_matches_uncompressed(self): + SAMPLE_INPUT = [ + "I love 4-bit quantization because", + "What is the capital of Paris?", + "def fibonacci(n):", + ] + + inputs = self.tokenizer(SAMPLE_INPUT, return_tensors="pt", padding=True).to( + self.device + ) + compressed_output = self.tokenizer.batch_decode( + self.compressed_model.generate(**inputs, max_length=50) + ) + uncompressed_output = self.tokenizer.batch_decode( + self.uncompressed_model.generate(**inputs, max_length=50) + ) + + for idx in range(len(SAMPLE_INPUT)): + assert compressed_output[idx] == uncompressed_output[idx] + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.test_dir) + del cls.compressed_model + del cls.uncompressed_model + torch.cuda.empty_cache() diff --git a/tests/llmcompressor/utils/test_helpers.py b/tests/llmcompressor/utils/test_helpers.py index ef515eb32..dc9e44954 100644 --- a/tests/llmcompressor/utils/test_helpers.py +++ b/tests/llmcompressor/utils/test_helpers.py @@ -1,9 +1,12 @@ +from types import SimpleNamespace + import pytest from llmcompressor.utils import ( ALL_TOKEN, convert_to_bool, flatten_iterable, + getattr_chain, interpolate, parse_kwarg_tuples, validate_str_iterable, @@ -90,3 +93,34 @@ def test_interpolate(x_cur, x0, x1, y0, y1, inter_func, out): def test_pass_kwargs_tuples(): kwargs = parse_kwarg_tuples(("--input_1", 1, "--input_2", "two", "--input_3", "2")) assert kwargs == dict(input_1=1, input_2="two", input_3=2) + + +def test_getattr_chain(): + base = SimpleNamespace() + base.a = None + base.b = SimpleNamespace() + base.b.c = "value" + base.b.d = None + + # test base cases + assert getattr_chain(base, "", None) is None + with pytest.raises(AttributeError): + getattr_chain(base, "") + + # test single layer + assert getattr_chain(base, "a") is None + assert getattr_chain(base, "a", "default") is None + assert getattr_chain(base, "b") == base.b + + assert getattr_chain(base, "dne", None) is None + with pytest.raises(AttributeError): + getattr_chain(base, "dne") + + # test multi layer + assert getattr_chain(base, "b.c") == "value" + assert getattr_chain(base, "b.d") is None + assert getattr_chain(base, "b.d", "default") is None + + assert getattr_chain(base, "b.d.dne", "default") == "default" + with pytest.raises(AttributeError): + getattr_chain(base, "b.d.dne") diff --git a/tests/testing_utils.py b/tests/testing_utils.py index d1d9494df..ca1f05d74 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -68,39 +68,49 @@ def _validate_test_config(config: dict): # Set cadence in the config. 
The environment must set if nightly, weekly or commit # tests are running def parse_params( - configs_directory: str, type: Optional[str] = None + configs_directory: Union[list, str], type: Optional[str] = None ) -> List[Union[dict, CustomTestConfig]]: - # parses the config file provided - assert os.path.isdir( - configs_directory - ), f"Config_directory {configs_directory} is not a directory" + # parses the config files provided config_dicts = [] - for file in os.listdir(configs_directory): - config = _load_yaml(configs_directory, file) - if not config: - continue - - cadence = os.environ.get("CADENCE", "commit") - expected_cadence = config.get("cadence") - - if not isinstance(expected_cadence, list): - expected_cadence = [expected_cadence] - if cadence in expected_cadence: - if type == "custom": - config = CustomTestConfig(**config) + + def _parse_configs_dir(current_config_dir): + assert os.path.isdir( + current_config_dir + ), f"Config_directory {current_config_dir} is not a directory" + + for file in os.listdir(current_config_dir): + config = _load_yaml(current_config_dir, file) + if not config: + continue + + cadence = os.environ.get("CADENCE", "commit") + expected_cadence = config.get("cadence") + + if not isinstance(expected_cadence, list): + expected_cadence = [expected_cadence] + if cadence in expected_cadence: + if type == "custom": + config = CustomTestConfig(**config) + else: + if not _validate_test_config(config): + raise ValueError( + "The config provided does not comply with the expected " + "structure. See tests.data.TestConfig for the expected " + "fields." + ) + config_dicts.append(config) else: - if not _validate_test_config(config): - raise ValueError( - "The config provided does not comply with the expected " - "structure. See tests.data.TestConfig for the expected " - "fields." - ) - config_dicts.append(config) - else: - logging.info( - f"Skipping testing model: {file} for cadence: {config['cadence']}" - ) + logging.info( + f"Skipping testing model: {file} for cadence: {config['cadence']}" + ) + + if isinstance(configs_directory, list): + for config in configs_directory: + _parse_configs_dir(config) + else: + _parse_configs_dir(configs_directory) + return config_dicts
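+
+# Illustrative usage: the e2e vLLM tests pass a list of config directories,
+# e.g. parse_params(["tests/e2e/vLLM/configs/WNA16", "tests/e2e/vLLM/configs/FP8"]),
+# and receive back a single merged list of config dicts whose cadence matches
+# the CADENCE environment variable.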