From 98a5b73cc635ca9cbd703f0849ec0ff799b183fb Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 18 Feb 2025 17:40:58 +0000 Subject: [PATCH 01/21] cherry picked files from stale PR #181 branch awq-feature-branch Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/__init__.py | 3 + src/llmcompressor/modifiers/awq/base.py | 720 ++++++++++++++++++ src/llmcompressor/pytorch/utils/helpers.py | 111 +++ .../transformers/finetune/data/__init__.py | 1 + .../transformers/finetune/data/pile.py | 45 ++ src/llmcompressor/utils/pytorch/module.py | 20 + tests/llmcompressor/modifiers/awq/__init__.py | 0 .../llmcompressor/modifiers/awq/test_base.py | 28 + .../pytorch/utils/test_helpers.py | 42 + .../finetune/data/test_registry.py | 17 + tests/llmcompressor/utils/pytorch/__init__.py | 0 .../utils/pytorch/test_module.py | 31 + 12 files changed, 1018 insertions(+) create mode 100644 src/llmcompressor/modifiers/awq/__init__.py create mode 100644 src/llmcompressor/modifiers/awq/base.py create mode 100644 src/llmcompressor/transformers/finetune/data/pile.py create mode 100644 tests/llmcompressor/modifiers/awq/__init__.py create mode 100644 tests/llmcompressor/modifiers/awq/test_base.py create mode 100644 tests/llmcompressor/utils/pytorch/__init__.py create mode 100644 tests/llmcompressor/utils/pytorch/test_module.py diff --git a/src/llmcompressor/modifiers/awq/__init__.py b/src/llmcompressor/modifiers/awq/__init__.py new file mode 100644 index 000000000..8bdc93d14 --- /dev/null +++ b/src/llmcompressor/modifiers/awq/__init__.py @@ -0,0 +1,3 @@ +# flake8: noqa + +from .base import * diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py new file mode 100644 index 000000000..71767b0b4 --- /dev/null +++ b/src/llmcompressor/modifiers/awq/base.py @@ -0,0 +1,720 @@ +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from loguru import logger +from torch.nn import Module +from tqdm import tqdm + +from llmcompressor.core import Event, State +from llmcompressor.modifiers import Modifier +from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward +from llmcompressor.pytorch.utils import ( + clear_memory, + pseudo_quantize_tensor, + tensor_forward_with_input_args, +) +from llmcompressor.utils.fsdp.helpers import get_fsdp_parent +from llmcompressor.utils.pytorch.module import ( + get_layer, + get_layers, + get_matching_layer, + get_parent_by_name, +) + +DEFAULT_AWQ_MAPPINGS = [ + [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*input_layernorm"], + [["re:.*gate_proj", "re:.*up_proj"], "re:.*post_attention_layernorm"], + [["re:.*down_proj"], "re:.*up_proj"], +] + +__all__ = ["AWQScale", "AWQMapping", "AWQModifier"] + + +@dataclass +class AWQScale: + """ + Dataclass for storing the input activations of a layer to be smoothed + """ + + inps: Union[List[torch.Tensor], torch.Tensor] + + +@dataclass +class AWQMapping: + """ + Dataclass for storing the mapping between an activation layer and the following + weights that must be balanced during smoothing + + :param smooth_name: name of the activation layer + :param smooth_layer: PyTorch module storing the activation layer + :param balance_layers: list of PyTorch modules that smooth_layer feeds into, must be + balanced to offset the smoothing of smooth_layer + :param balance_names: optional list of names of the balance_layers + :param parent: parent module of the balance_layers + :param parent_name: name of the parent module 
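+
+    Example (hypothetical layer names, for illustration only): with the default
+    mappings, smooth_name might be "model.layers.0.input_layernorm", with
+    balance_names covering the matching q_proj/k_proj/v_proj projections and
+    parent_name "model.layers.0.self_attn"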
+ """ + + smooth_name: str + smooth_layer: Module + balance_layers: List[Module] + balance_names: Optional[List[str]] = None + parent: Optional[Module] = None + parent_name: Optional[str] = None + + +class AWQModifier(Modifier): + """ + Implements the AWQ (Activation-Weighted Quantization) algorithm, + as described in https://arxiv.org/pdf/2306.00978. The algorithm + significantly reduces quantization error by protecting only 1% + of the most salient weight channels. + + Instead of focusing on the weight values directly, AWQ identifies + salient channels based on the activation distribution. + To further minimize quantization error, the algorithm scales up these + salient channels using an equivalent transformation. The scaling factor + is determined offline by collecting activation statistics + + Because this modifier manipulates the weights of the model, it can only be used in + in one-shot and not during training. Activation ranges are determined by running a + small set of calibration data through the model. + + example recipe: + ```yaml + AWQModifier: + bits: 4 + mappings: [ + [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*self_attn_layer_norm"], + [["re:.*fc1"], "re:.*final_layer_norm"] + ] + ignore: ["model.decoder.final_layer_norm"] + ``` + + :param mappings: list activation layers to smooth, and which layers to + scale the output such that activations are smoothed. + Each entry of the mapping list should be a list itself, in which the first + entry is a list of layers who share the same input activation (the one to be + to smoothed) and the second entry is the layer whose output is scaled to + achieve the smoothing. + If regex is used, it matches layers with the largest overlap in module name. + :param ignore: list of layers to ignore, even if they match a regex in mappings. + It should match the name of layers whose outputs are scaled to achieve + smoothing (the second entry of the mappings list). 
+ :param num_calibration_steps: number of samples to use for calibration, or None to + use the whole dataset + :param calibration_function: optional function to use for the forward pass, or None + to use the default tensor_module_forward + :param group_size: number of weights to group together for scaling + :param max_chunk_memory: maximum memory to use for each chunk of input activations + :param bits: number of bits to quantize the weights to + :param symmetric: whether to use symmetric quantization + :param duo_scaling: whether to use duo scaling, which uses both input activations + and weights to determine the scaling factor + :param apply_clip: whether to apply clipping to the weights after scaling + """ + + mappings: List[Tuple] = DEFAULT_AWQ_MAPPINGS + ignore: Optional[List[str]] = None + num_calibration_steps: Optional[int] = None + calibration_function: Optional[Callable] = None + group_size: int = 128 + max_chunk_memory: int = 1024 * 1024 * 1024 + bits: int = 4 + symmetric: bool = True + duo_scaling: bool = True + apply_clip: bool = True + + hooks_: Optional[List] = None + resolved_mappings_: Optional[List] = None + scales_: Optional[Dict] = None + module_kwargs_: Optional[Dict] = None + + def on_initialize_structure(self, state: State, **kwargs): + pass # nothing needed for this modifier + + def on_initialize(self, state: State, **kwargs) -> bool: + """ + Initialize and run AWQ on the given state + + :param state: state to run AWQ on + :return: True on a successful run, False otherwise + """ + if not (self.end is None or self.end == -1): + raise ValueError( + f"{self.__class__.__name__} can only be applied during one-shot. " + f" Expected end to be None or -1, got {self.end}" + ) + if self.start and self.start != -1: + raise ValueError( + f"{self.__class__.__name__} can only be applied during one-shot. " + f"Expected start to be None or -1, got {self.end}" + ) + + self.ignore = [] if not self.ignore else self.ignore + self.resolved_mappings_ = self._resolve_mappings(state.model) + self.scales_ = {} + + calibration_dataloader = state.data.calib + self.hooks_ = [] + + self._get_module_kwargs(state.model, calibration_dataloader) + self._setup_scale_hooks() + self._calibrate(state.model, calibration_dataloader) + self._concat_collected_activations() + self._apply_smoothing(state.model) + + return True + + def on_start(self, state: State, event: Event, **kwargs): + pass + + def on_update(self, state: State, event: Event, **kwargs): + pass + + def on_end(self, state: State, event: Event, **kwargs): + pass + + def on_event(self, state: State, event: Event, **kwargs): + pass + + def on_finalize(self, state: State, **kwargs) -> bool: + """ + Clean up by clearing the scale and mapping data + + :param state: unused + :return: True + """ + if self.scales_ is not None: + self.scales_.clear() + if self.resolved_mappings_ is not None: + self.resolved_mappings_.clear() + + return True + + def _resolve_mappings(self, model: Module) -> List: + """ + Transforms the list of activations to smooth and their corresponding weights + into AWQMapping objects, resolving regular expressions. + + For each activation in the mapping list, we find the corresponding weight to + balance by searching for the longest substring. 
For instance, if our balance + weight is ".*re:.*q_proj" and the activation is "re:.*self_attn_layer_norm" we + would match model.layer.0.p_proj to model.layer.0.self_attn_layer_norm and + repeat for model.layer.1 and so on + """ + resolved_mappings = [] + for to_balance, to_smooth in self.mappings: + to_smooth_layers = get_layers(to_smooth, model) + for layer_name, smooth_layer in to_smooth_layers.items(): + if layer_name not in self.ignore: + balance_layers, balance_names = [], [] + for balance_suffix in to_balance: + # find the submodule that matches the activation layer + balance_name, balance_layer = get_matching_layer( + balance_suffix, layer_name, model + ) + if balance_layer: + balance_layers.append(balance_layer) + balance_names.append(balance_name) + + # each mapping can contain multiple layers to balance, but only + # one layer to smooth + + if len(balance_layers) == 1: + # for single balance layer, parent is the balance layer + parent_name, parent = balance_name, balance_layer + else: + # for multiple balance layers, + # parent of any balance layer is the parent + parent_name, parent = get_parent_by_name( + layer_name=balance_name, model=model + ) + mapping = AWQMapping( + layer_name, + smooth_layer, + balance_layers, + balance_names=balance_names, + parent=parent, + parent_name=parent_name, + ) + resolved_mappings.append(mapping) + return resolved_mappings + + def _setup_scale_hooks(self): + """ + Attach a forward hook to each activation we want to smooth. This allows us to + calculate the dynamic range during calibration + """ + + def create_hook_fn(layer_name): + def hook_fn(module, inp, out): + inp = inp[0] + inp.cpu().detach() + + if layer_name in self.scales_: + self.scales_[layer_name].inps.append(inp) + else: + self.scales_[layer_name] = AWQScale(inps=[inp]) + + return hook_fn + + for mapping in self.resolved_mappings_: + name = mapping.smooth_name + # storing inps to first balance layer + # is enough, as other balance layers + # get the same input + layer = mapping.balance_layers[0] + self.hooks_.append(layer.register_forward_hook(create_hook_fn(name))) + + @torch.no_grad() + def _calibrate(self, model: Module, calibration_dataloader: List): + """ + Catch the output dynamic ranges of each layer that will be smoothed by running + forward passes with calibration_dataloader + """ + class_name = self.__class__.__name__.replace("PyTorch", "") + logger.info( + f"Running {class_name} calibration with " + f"{len(calibration_dataloader)} samples..." 
+ ) + if not calibration_dataloader: + raise ValueError( + "Calibration data loader not set, must populate the calib_data field of" + " CompressionSession to run the AWQ modifier" + ) + + run_calibration_forward( + model, + calibration_dataloader, + self.num_calibration_steps, + self.calibration_function, + ) + + # remove the hooks now that we are done calibrating + for hook in self.hooks_: + hook.remove() + del self.hooks_ + + def _concat_collected_activations(self): + """ + Concatenate the collected activation values from each forward pass into a single + tensor for each layer + + :postcondition: each layer in self.scales_ will have a single tensor containing + all the activation values seen during calibration + """ + for mapping in self.resolved_mappings_: + name = mapping.smooth_name + self.scales_[name].inps = torch.cat(self.scales_[name].inps, dim=0) + + torch.cuda.empty_cache() + + @torch.no_grad() + def _apply_smoothing(self, model: Module): + """ + Calculate the best scaling factors for each layer to smooth activations and + apply the scaling factors to the weights of the next layer to offset the + smoothing + + :param model: model to apply smoothing to + """ + logger.info("Smoothing activation scales...") + for mapping in tqdm(self.resolved_mappings_): + smooth_layer = mapping.smooth_layer + balance_layers = mapping.balance_layers + balance_names = mapping.balance_names + + activations = self.scales_[mapping.smooth_name].inps + + module2inspect = mapping.parent + + # [STEP 1]: Compute per-channel mean of normalised weights + # All layer weights are concatted together + weight = torch.cat([_m.weight for _m in balance_layers], dim=0) + org_shape = weight.shape + # The weights are reshaped to be organised by quantization group + weight = weight.view(-1, self.group_size) + # Calculates the relative magnitude of the weights within + # each of the quantization groups, and rescales each group + # individually so that each group has weights on a 0-1 scale. 
+ w_scale = weight.abs() / (weight.abs().amax(dim=1, keepdim=True) + 1e-6) + # Resizes the rescaled weight matrix back up to its original dimensions + w_scale = w_scale.view(org_shape) + # Gets the average rescaled magnitude for each output channel + w_mean = w_scale.mean(0) + + # [STEP 2]: Compute per-channel mean of the input activation with chunking + # move inp to cpu to avoid memory leak + inp = activations + inp_flat = inp.cpu().abs().view(-1, inp.shape[-1]) + num_elements = inp_flat.size(0) + num_channels = inp_flat.size(1) + element_size_bytes = inp_flat.element_size() * 2 # multiplied by 2 for FP32 + + # Calculate chunk size dynamically based on max_chunk_memory + chunk_size = int( + self.max_chunk_memory // (element_size_bytes * num_channels) + ) + chunk_size = min(chunk_size, num_elements) + + # Use float32 for sum calculation + x_sum = torch.zeros(num_channels, dtype=torch.float32, device=inp.device) + + for i in range(0, num_elements, chunk_size): + end = min(i + chunk_size, num_elements) + chunk_sum = inp_flat[i:end].to(torch.float32).sum(dim=0) + x_sum += chunk_sum.to(inp.device) + + x_mean = (x_sum / num_elements).to(inp.dtype) + + # [STEP 3]: Compute output of module + with torch.no_grad(): + fp16_output = self._forward_input_with_kwargs( + module=module2inspect, inputs=inp, input_kwargs=self.module_kwargs_ + ) + + # [STEP 4]: Compute loss + best_scales = self._compute_best_scale( + inp, w_mean, x_mean, module2inspect, balance_layers, fp16_output + ) + + scales = best_scales + + @torch.no_grad() + def smooth(module): + if module in balance_layers: + module.weight.mul_(scales.view(1, -1).to(module.weight.device)) + elif module == smooth_layer: + if module.weight.ndim == 1: + module.weight.div_(scales.to(module.weight.device)) + else: + module.weight.div_(scales.view(-1, 1).to(module.weight.device)) + if hasattr(module, "bias") and module.bias is not None: + module.bias.div_(scales.to(module.bias.device)) + + parent = get_fsdp_parent(mapping.smooth_name, model) + if parent is not None: + parent.apply(smooth) + else: + # if we're not running with FSDP we can apply smoothing directly + for layer in balance_layers: + smooth(layer) + smooth(smooth_layer) + + if self.apply_clip: + clip_list = self._search_best_clip( + balance_layers=balance_layers, + balance_names=balance_names, + input_feat=inp, + ) + + _apply_clip(model, clip_list) + + # clear out allocated smoothing scales + torch.cuda.empty_cache() + + def _compute_best_scale( + self, + x: torch.Tensor, + w_mean: torch.Tensor, + x_mean: torch.Tensor, + module2inspect: torch.nn.Module, + linears2scale: List[torch.nn.Linear], + fp16_output: torch.Tensor, + ): + """ + Compute loss and select best scales + + L(s) = || Q(W * s) (s^-1 * X) - W * X || + Q: weight quantization function | pseudo_quantize_tensor(W * s) + X: inputs from calib dataset | X + W: original weights in FP16 | layer + s: per channel scaling factor | s^-1 * X + """ + n_grid = 20 + history = [] + best_ratio = -1 + best_scales = None + best_error = float("inf") + + org_sd = {k: v.cpu() for k, v in module2inspect.state_dict().items()} + + device = x.device + x_mean = x_mean.view(-1).to(device) + w_mean = w_mean.view(-1).to(device) + + for ratio in range(n_grid): + # create new scales + ratio = ratio / n_grid + + # NOTE: s^-1 * x is fused here, according to paper + if self.duo_scaling: + scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp( + min=1e-4 + ) + else: + scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1) + scales = scales / (scales.max() * 
scales.min()).sqrt() + scales_view = scales.view(1, -1).to(device) + + # avoid scaling values that overflow + scales[torch.isinf(scales)] = 1 + scales[torch.isnan(scales)] = 1 + + # Q(W * s) + for fc in linears2scale: + fc.weight.mul_(scales_view) + fc.weight.data = ( + pseudo_quantize_tensor( + w=fc.weight.data, + symmetric=self.symmetric, + bit_width=self.bits, + group_size=self.group_size, + )[0] + / scales_view + ) + + # W * X + int_w_output = self._forward_input_with_kwargs( + module=module2inspect, inputs=x, input_kwargs=self.module_kwargs_ + ) + + # compute mean squared error (L2 norm) + loss = self._compute_loss(fp16_output, int_w_output, device) + + history.append(loss) + if loss < best_error: + best_error = loss + best_ratio = ratio + best_scales = scales.clone() + module2inspect.load_state_dict(org_sd) + + if best_ratio == -1: + logger.debug(history) + raise Exception + + assert torch.isnan(best_scales).sum() == 0, best_scales + + return best_scales.detach().cpu() + + @torch.no_grad() + def _compute_loss( + self, + fp16_output: torch.Tensor, + int_w_output: torch.Tensor, + device: torch.device, + ): + loss = 0.0 + fp16_output_flat = fp16_output.view(-1) + int_w_output_flat = int_w_output.view(-1) + num_elements = fp16_output_flat.size(0) + element_size_bytes = fp16_output.element_size() + + # Calculate chunk size dynamically based on max_chunk_memory + # Divide the max_chunk_memory by twice the element size + chunk_size = self.max_chunk_memory // (element_size_bytes * 2) + chunk_size = min(chunk_size, num_elements) + + # Split the computation into chunks + fp16_chunks = torch.split(fp16_output_flat, chunk_size) + int_w_chunks = torch.split(int_w_output_flat, chunk_size) + + # Compute the loss for each chunk + for fp16_chunk, int_w_chunk in zip(fp16_chunks, int_w_chunks): + chunk_loss = ( + (fp16_chunk.to(device) - int_w_chunk.to(device)) + .float() + .pow(2) + .sum() + .item() + ) + loss += chunk_loss + + # Normalize the loss by the total number of elements + loss /= num_elements + + return loss + + def _get_module_kwargs(self, model, dataloader): + _, modules = next(iter(get_layers("re:.*layers", model).items())) + + samples = [batch["input_ids"] for batch in dataloader] + + samples = torch.cat(samples, dim=0) + + inps = [] + layer_kwargs = {} + + best_device = "cuda" + modules[0] = modules[0].to(best_device) + # self.awq_model.move_embed(self.model, best_device) + + # get input and kwargs to layer 0 + # with_kwargs is only supported in PyTorch 2.0 + # use this Catcher hack for now + class Catcher(torch.nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self, *args, **kwargs): + # assume first input to forward is hidden states + if len(args) > 0: + hidden_states = args[0] + del args + else: + first_key = list(kwargs.keys())[0] + hidden_states = kwargs.pop(first_key) + + inps.append(hidden_states) + layer_kwargs.update(kwargs) + raise ValueError # early exit to break later inference + + # patch layer 0 to catch input and kwargs + modules[0] = Catcher(modules[0]) + try: + model(samples.to(next(model.parameters()).device)) + except ValueError: # work with early exit + pass + modules[0] = modules[0].module # restore + + # Update the layer kwargs with `prepare_inputs_for_generation` method + # that takes care of everything to avoid unexpected errors. + layer_kwargs = model.prepare_inputs_for_generation(samples, **layer_kwargs) + # Pop the input_ids as they are not needed at all. 
+ layer_kwargs.pop("input_ids") + + del samples + inps = inps[0] + + torch.cuda.empty_cache() + + if layer_kwargs.get("attention_mask") is not None: + layer_kwargs["attention_mask"] = layer_kwargs["attention_mask"].to( + best_device + ) + + self.module_kwargs_ = layer_kwargs + + def _forward_input_with_kwargs( + self, + module: Module, + inputs: torch.Tensor, + input_kwargs: Optional[Dict[str, Any]] = None, + ): + """ + Forward pass with input arguments + + :param module: module to run forward pass on + :param inputs: input tensor to pass to the module + :param input_kwargs: additional arguments to pass to the module + :return: the first output tensor from the forward pass + """ + kwargs = input_kwargs or self.module_kwargs_ or {} + return tensor_forward_with_input_args( + module=module, + inputs=inputs, + input_kwargs=kwargs, + )[0] + + @torch.no_grad() + def _search_best_clip(self, balance_layers, balance_names, input_feat): + clip_list = [] + avoid_clipping = ["q_", "k_", "query", "key", "Wqkv"] + + for name, layer in zip(balance_names, balance_layers): + # due to qk bmm, it is hard to clip precisely + if any([_ in name for _ in avoid_clipping]): + continue + + max_val = self._compute_best_clip(layer.weight, input_feat) + clip_list.append((name, max_val)) + + return clip_list + + @torch.no_grad() + def _compute_best_clip( + self, + w: torch.Tensor, + input_feat: torch.Tensor, + n_grid=20, + max_shrink=0.5, + n_sample_token=512, + ): + assert w.dim() == 2 + org_w_shape = w.shape + # w [co, ci] -> [co, 1, n_group, group size] + # input_feat [n_token, ci] -> [1, n_token, n_group, group size] + group_size = self.group_size if self.group_size > 0 else org_w_shape[1] + input_feat = input_feat.view(-1, input_feat.shape[-1]) + input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size) + + # Compute input feature step size (minimum 1) + step_size = max(1, input_feat.shape[1] // n_sample_token) + input_feat = input_feat[:, ::step_size] + + w = w.reshape(org_w_shape[0], 1, -1, group_size) + + oc_batch_size = 256 if org_w_shape[0] % 256 == 0 else 64 # prevent OOM + assert org_w_shape[0] % oc_batch_size == 0 + w_all = w + best_max_val_all = [] + + for i_b in range(org_w_shape[0] // oc_batch_size): + w = w_all[i_b * oc_batch_size : (i_b + 1) * oc_batch_size] + + org_max_val = w.abs().amax(dim=-1, keepdim=True) # co, 1, n_group, 1 + + best_max_val = org_max_val.clone() + min_errs = torch.ones_like(org_max_val) * 1e9 + input_feat = input_feat.to(w.device) + org_out = (input_feat * w).sum(dim=-1) # co, n_token, n_group + + for i_s in range(int(max_shrink * n_grid)): + max_val = org_max_val * (1 - i_s / n_grid) + min_val = -max_val + cur_w = torch.clamp(w, min_val, max_val) + q_w = pseudo_quantize_tensor( + w=cur_w, + symmetric=self.symmetric, + group_size=group_size, + bit_width=self.bits, + )[0] + cur_out = (input_feat * q_w).sum(dim=-1) + + # co, 1, n_group, 1 + err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape) + del cur_w + del cur_out + cur_best_idx = err < min_errs + min_errs[cur_best_idx] = err[cur_best_idx] + best_max_val[cur_best_idx] = max_val[cur_best_idx] + best_max_val_all.append(best_max_val) + + best_max_val = torch.cat(best_max_val_all, dim=0) + + clear_memory(input_feat) + clear_memory(org_out) + + return best_max_val.squeeze(1) + + +@torch.no_grad() +def _apply_clip(module, clip_list: Tuple[str, torch.Tensor]): + """ + Apply clipping to the weights of the given module + + :post-condition: the weights of the module are clipped to the given maximum values + 
:param module: module to apply clipping to + :param clip_list: list of tuples containing the name of the layer and the maximum + value to clip the weights to + """ + for name, max_val in clip_list: + _, layer = get_layer(target=name, module=module) + assert isinstance(layer, torch.nn.Linear) + max_val = max_val.to(layer.weight.device) + org_shape = layer.weight.shape + layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1) + layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val) + layer.weight.data = layer.weight.data.reshape(org_shape) diff --git a/src/llmcompressor/pytorch/utils/helpers.py b/src/llmcompressor/pytorch/utils/helpers.py index 1a0724e6c..38071f4d0 100644 --- a/src/llmcompressor/pytorch/utils/helpers.py +++ b/src/llmcompressor/pytorch/utils/helpers.py @@ -2,6 +2,9 @@ Utility / helper functions """ +import functools +import gc +import inspect import os import random import re @@ -85,6 +88,11 @@ "detach", "adjust_quantization_for_onnx_export", "get_dependency_order", + "pseudo_quantize_tensor", + "pseudo_dequantize_linear", + "tensor_forward_with_input_args", + "sanitize_kwargs_for_module", + "clear_memory", ] @@ -680,6 +688,43 @@ def mask_difference(old_mask: Tensor, new_mask: Tensor) -> Tensor: return -1.0 * newly_masked + newly_unmasked +def sanitize_kwargs_for_module( + kwargs: Dict[str, Any], module: Module +) -> Dict[str, Any]: + """ + Sanitize the kwargs for a Module by removing any keys that are not + in the signature of the forward method. + :param kwargs: the kwargs to sanitize + :param module: the Module to sanitize the kwargs for + :return: the sanitized kwargs for the callable object + """ + if not isinstance(kwargs, dict): + raise TypeError(f"Expected a dictionary as kwargs, but got {kwargs}") + + allowed_params = inspect.signature(module.forward).parameters + return {key: value for key, value in kwargs.items() if key in allowed_params} + + +def tensor_forward_with_input_args( + module: Module, inputs: Tensor, input_kwargs: Dict[str, Any] +) -> Tensor: + """ + Forward the given inputs through the given module with the given input_kwargs. + This function is a wrapper around tensors_module_forward that ensures that the + input_kwargs are sanitized and passed to the module as keyword arguments during + the forward pass. 
+ :param module: the module to forward the inputs through + :param inputs: the inputs to forward through the module + :param input_kwargs: the keyword arguments to pass to the + module during the forward pass + :return: the output of the module after forwarding the inputs through it + """ + inputs = inputs.to(next(module.parameters()).device) + input_kwargs = sanitize_kwargs_for_module(input_kwargs, module) + + return tensors_module_forward(inputs, functools.partial(module, **input_kwargs)) + + ############################## # # pytorch module helper functions @@ -1194,3 +1239,69 @@ def swap_modules( parent.__setattr__(sections[-1], submodule_to_replace) return cur + + +def pseudo_quantize_tensor( + w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1 +): + org_w_shape = w.shape + if group_size > 0: + assert org_w_shape[-1] % group_size == 0 + w = w.reshape(-1, group_size) + assert w.dim() == 2 + assert torch.isnan(w).sum() == 0 + + if not symmetric: + max_val = w.amax(dim=1, keepdim=True) + min_val = w.amin(dim=1, keepdim=True) + max_int = 2**bit_width - 1 + min_int = 0 + scales = (max_val - min_val).clamp(min=1e-5) / max_int + zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int) + w = ( + torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros + ) * scales + zeros = zeros.view(org_w_shape[0], -1) + else: + max_val = w.abs().amax(dim=1, keepdim=True) + max_val = max_val.clamp(min=1e-5) + max_int = 2 ** (bit_width - 1) - 1 + min_int = -(2 ** (bit_width - 1)) + scales = max_val / max_int + zeros = None + w = torch.clamp(torch.round(w / scales), min_int, max_int) * scales + + assert torch.isnan(scales).sum() == 0 + assert torch.isnan(w).sum() == 0 + + scales = scales.view(org_w_shape[0], -1) + w = w.reshape(org_w_shape) + + return w, scales, zeros + + +def pseudo_dequantize_linear( + w: torch.Tensor, + scales: torch.Tensor, + zeros: Optional[torch.Tensor] = None, + symmetric: bool = False, +): + # get repeated count + repeat_count = w.weight.data.shape[-1] // scales.shape[-1] + scales = scales.repeat(1, repeat_count).reshape(w.weight.data.shape) + + # dequantize + if not symmetric: + zeros = zeros.repeat(1, repeat_count).reshape(w.weight.data.shape) + w = (w.weight.data - zeros) * scales + else: + w = w.weight.data * scales + + return w + + +def clear_memory(value: Optional[Any] = None): + if value is not None: + del value + gc.collect() + torch.cuda.empty_cache() diff --git a/src/llmcompressor/transformers/finetune/data/__init__.py b/src/llmcompressor/transformers/finetune/data/__init__.py index a53caed1b..b72efa7c7 100644 --- a/src/llmcompressor/transformers/finetune/data/__init__.py +++ b/src/llmcompressor/transformers/finetune/data/__init__.py @@ -8,6 +8,7 @@ from .flickr_30k import Flickr30K from .gsm8k import GSM8KDataset from .open_platypus import OpenPlatypusDataset +from .pile import PileEvalDataset from .ptb import PtbDataset from .ultrachat_200k import UltraChatDataset from .wikitext import WikiTextDataset diff --git a/src/llmcompressor/transformers/finetune/data/pile.py b/src/llmcompressor/transformers/finetune/data/pile.py new file mode 100644 index 000000000..b3a99ea0e --- /dev/null +++ b/src/llmcompressor/transformers/finetune/data/pile.py @@ -0,0 +1,45 @@ +from copy import deepcopy +from typing import Optional + +from llmcompressor.transformers.finetune.data import TextGenerationDataset + + +@TextGenerationDataset.register(name="pile_eval") +class PileEvalDataset(TextGenerationDataset): + """ + Child text generation 
class for the PileEval dataset + :param data_args: configuration settings for dataset loading + :param split: split from dataset to load, for instance `test` or `train[:5%]` + :param tokenizer: tokenizer to use on dataset + """ + + def __init__(self, data_args, split, tokenizer): + data_args = deepcopy(data_args) + data_args.dataset = "mit-han-lab/pile-val-backup" + super().__init__( + text_column="text", data_args=data_args, split=split, tokenizer=tokenizer + ) + + def get_raw_dataset(self, cache_dir: Optional[str] = None): + """ + Load the raw dataset from Hugging Face, using cached copy if available. + Additionally reformats the entries to fit the template. + :param cache_dir: disk location to search for cached dataset + :return: the requested dataset + """ + raw_dataset = super().get_raw_dataset(cache_dir=cache_dir) + + def restructure_fn(sample): + sample["text"] = sample["text"].strip() + return sample + + raw_dataset = self.map( + raw_dataset, + function=restructure_fn, + batched=False, + remove_columns=["meta"], + num_proc=self.data_args.preprocessing_num_workers, + load_from_cache_file=not self.data_args.overwrite_cache, + desc="Restructuring Pile Dataset", + ) + return raw_dataset diff --git a/src/llmcompressor/utils/pytorch/module.py b/src/llmcompressor/utils/pytorch/module.py index 8f7eadb53..c980f00c8 100644 --- a/src/llmcompressor/utils/pytorch/module.py +++ b/src/llmcompressor/utils/pytorch/module.py @@ -60,6 +60,7 @@ "get_layers_params", "get_matching_layer", "get_no_split_params", + "get_parent_by_name", ] @@ -338,3 +339,22 @@ def get_no_split_params(module: Module) -> Union[str, List[str]]: if hasattr(model, "_no_split_modules"): return model._no_split_modules return ALL_TARGET + + +def get_parent_by_name(layer_name: str, model: Module) -> Tuple[str, Module]: + """ + Get the parent layer of a layer by name. + :param layer_name: Name of the layer to find the parent of. + :param model: Model to search for the parent layer. + :return: Tuple containing the name of the parent layer + and the parent layer itself. 
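+
+    Example (hypothetical name): for layer_name "model.layers.0.self_attn.q_proj"
+    this returns ("model.layers.0.self_attn", <that submodule>); a top-level
+    layer name with no dots returns ("", model)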
+ """ + if not any(layer_name == name for name, _ in model.named_modules()): + raise ValueError(f"Layer '{layer_name}' not found in model") + + parent_name_parts = layer_name.split(".")[:-1] + if not parent_name_parts: + return "", model + + parent_name = ".".join(parent_name_parts) + return get_layer(parent_name, model) diff --git a/tests/llmcompressor/modifiers/awq/__init__.py b/tests/llmcompressor/modifiers/awq/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py new file mode 100644 index 000000000..918238718 --- /dev/null +++ b/tests/llmcompressor/modifiers/awq/test_base.py @@ -0,0 +1,28 @@ +import unittest + +import pytest + +from llmcompressor.modifiers.awq import AWQModifier +from llmcompressor.modifiers.factory import ModifierFactory +from tests.llmcompressor.modifiers.conf import setup_modifier_factory + + +@pytest.mark.unit +class TestAWQIsRegistered(unittest.TestCase): + def setUp(self): + self.kwargs = {} + setup_modifier_factory() + + def test_awq_is_registered(self): + modifier = ModifierFactory.create( + type_="AWQModifier", + allow_experimental=False, + allow_registered=True, + **self.kwargs, + ) + + self.assertIsInstance( + modifier, + AWQModifier, + "PyTorch AWQModifier not registered", + ) diff --git a/tests/llmcompressor/pytorch/utils/test_helpers.py b/tests/llmcompressor/pytorch/utils/test_helpers.py index e2f0133f1..b683cab55 100644 --- a/tests/llmcompressor/pytorch/utils/test_helpers.py +++ b/tests/llmcompressor/pytorch/utils/test_helpers.py @@ -16,9 +16,11 @@ get_optim_learning_rate, mask_difference, memory_aware_threshold, + sanitize_kwargs_for_module, set_optim_learning_rate, tensor_density, tensor_export, + tensor_forward_with_input_args, tensor_sample, tensor_sparsity, tensors_batch_size, @@ -855,3 +857,43 @@ def test_memory_aware_threshold(tensor, idx): if prior_state is not None: os.environ[MEMORY_BOUNDED] = prior_state + + +class TestSanitizeKwargsForModule: + @pytest.fixture + def module(self): + return Linear(10, 20) + + def test_sanitize_kwargs_for_module_not_dict(self, module): + # Test with kwargs that are not a dictionary + with pytest.raises(TypeError): + sanitize_kwargs_for_module("not a dictionary", module) + + def test_sanitize_kwargs_for_module_not_in_signature(self, module): + # Test with kwargs that are not in the signature of the forward method + kwargs = {"not_in_signature": 123} + sanitized_kwargs = sanitize_kwargs_for_module(kwargs, module) + assert sanitized_kwargs == {} + + def test_sanitize_kwargs_for_module_in_signature(self, module): + # Test with kwargs that are in the signature of the forward method + kwargs = {"input": torch.randn(1, 10)} + sanitized_kwargs = sanitize_kwargs_for_module(kwargs, module) + assert sanitized_kwargs == kwargs + + +class TestTensorForwardWithInputArgs: + @pytest.fixture + def module(self): + return Linear(10, 20) + + def test_tensor_forward_with_input_args(self, module): + # Test with valid inputs and input_kwargs + inputs = torch.randn(1, 10) + input_kwargs = {} + output = tensor_forward_with_input_args(module, inputs, input_kwargs) + assert output.shape == (1, 20) + + # Test with input_kwargs that are not in the signature of the forward method + input_kwargs = {"not_in_signature": 123} + tensor_forward_with_input_args(module, inputs, input_kwargs) diff --git a/tests/llmcompressor/transformers/finetune/data/test_registry.py b/tests/llmcompressor/transformers/finetune/data/test_registry.py 
index 694a9b6d3..d2391e61d 100644 --- a/tests/llmcompressor/transformers/finetune/data/test_registry.py +++ b/tests/llmcompressor/transformers/finetune/data/test_registry.py @@ -4,6 +4,7 @@ from llmcompressor.transformers.finetune.data import ( C4Dataset, OpenPlatypusDataset, + PileEvalDataset, TextGenerationDataset, WikiTextDataset, ) @@ -57,3 +58,19 @@ def test_open_platypus_initializes(tiny_llama_tokenizer): assert op_manager.data_args.text_column == "text" assert not op_manager.padding assert op_manager.max_seq_length == data_args.max_seq_length + + +@pytest.mark.usefixtures("tiny_llama_tokenizer") +def test_pile_eval_initializes(tiny_llama_tokenizer): + data_args = DatasetArguments(dataset="pile_eval", pad_to_max_length=False) + pile_eval_manager = TextGenerationDataset.load_from_registry( + data_args.dataset, + data_args=data_args, + split=None, + tokenizer=tiny_llama_tokenizer, + ) + assert isinstance(pile_eval_manager, TextGenerationDataset) + assert isinstance(pile_eval_manager, PileEvalDataset) + assert pile_eval_manager.text_column == "text" + assert not pile_eval_manager.padding + assert pile_eval_manager.max_seq_length == data_args.max_seq_length diff --git a/tests/llmcompressor/utils/pytorch/__init__.py b/tests/llmcompressor/utils/pytorch/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/llmcompressor/utils/pytorch/test_module.py b/tests/llmcompressor/utils/pytorch/test_module.py new file mode 100644 index 000000000..4600377fa --- /dev/null +++ b/tests/llmcompressor/utils/pytorch/test_module.py @@ -0,0 +1,31 @@ +import unittest + +import torch.nn as nn + +from llmcompressor.utils.pytorch import get_parent_by_name + + +class TestGetParentByName(unittest.TestCase): + def setUp(self): + self.model = nn.Sequential( + nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 10), nn.Softmax(dim=1) + ) + + def test_get_parent_by_name(self): + # Test getting the parent of a non-existent layer + with self.assertRaises(ValueError): + get_parent_by_name("non_existent_layer", self.model) + + # Test getting the parent of the first layer + name, parent = get_parent_by_name("0", self.model) + self.assertEqual(parent, self.model) + + # Test getting the parent of a nested layer + nested_model = nn.Sequential( + nn.Linear(10, 20), + nn.Sequential(nn.ReLU(), nn.Linear(20, 10)), + nn.Softmax(dim=1), + ) + name, parent = get_parent_by_name("1.1", nested_model) + self.assertEqual(parent, nested_model[1]) + self.assertEqual(name, "1") From 261196628e422655143ad4b2466c0d4f5ba851da Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 18 Feb 2025 20:32:53 +0000 Subject: [PATCH 02/21] updated to be compatible with latest, unit tests passing Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 26 ++++-------- src/llmcompressor/pytorch/utils/helpers.py | 7 ---- .../transformers/finetune/data/pile.py | 42 +++++++------------ .../finetune/data/test_registry.py | 4 +- 4 files changed, 23 insertions(+), 56 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 71767b0b4..07ad3093f 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -10,7 +10,6 @@ from llmcompressor.modifiers import Modifier from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward from llmcompressor.pytorch.utils import ( - clear_memory, pseudo_quantize_tensor, tensor_forward_with_input_args, ) @@ -83,12 +82,12 @@ class AWQModifier(Modifier): example recipe: 
```yaml AWQModifier: - bits: 4 - mappings: [ + bits: 4 + mappings: [ [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*self_attn_layer_norm"], [["re:.*fc1"], "re:.*final_layer_norm"] - ] - ignore: ["model.decoder.final_layer_norm"] + ] + ignore: ["model.decoder.final_layer_norm"] ``` :param mappings: list activation layers to smooth, and which layers to @@ -166,18 +165,6 @@ def on_initialize(self, state: State, **kwargs) -> bool: return True - def on_start(self, state: State, event: Event, **kwargs): - pass - - def on_update(self, state: State, event: Event, **kwargs): - pass - - def on_end(self, state: State, event: Event, **kwargs): - pass - - def on_event(self, state: State, event: Event, **kwargs): - pass - def on_finalize(self, state: State, **kwargs) -> bool: """ Clean up by clearing the scale and mapping data @@ -694,8 +681,9 @@ def _compute_best_clip( best_max_val = torch.cat(best_max_val_all, dim=0) - clear_memory(input_feat) - clear_memory(org_out) + #TODO this appears unneeded, clear_memory removed + # clear_memory(input_feat) + # clear_memory(org_out) return best_max_val.squeeze(1) diff --git a/src/llmcompressor/pytorch/utils/helpers.py b/src/llmcompressor/pytorch/utils/helpers.py index 38071f4d0..00de09d40 100644 --- a/src/llmcompressor/pytorch/utils/helpers.py +++ b/src/llmcompressor/pytorch/utils/helpers.py @@ -92,7 +92,6 @@ "pseudo_dequantize_linear", "tensor_forward_with_input_args", "sanitize_kwargs_for_module", - "clear_memory", ] @@ -1299,9 +1298,3 @@ def pseudo_dequantize_linear( return w - -def clear_memory(value: Optional[Any] = None): - if value is not None: - del value - gc.collect() - torch.cuda.empty_cache() diff --git a/src/llmcompressor/transformers/finetune/data/pile.py b/src/llmcompressor/transformers/finetune/data/pile.py index b3a99ea0e..4eef5f7eb 100644 --- a/src/llmcompressor/transformers/finetune/data/pile.py +++ b/src/llmcompressor/transformers/finetune/data/pile.py @@ -1,7 +1,11 @@ from copy import deepcopy -from typing import Optional +from typing import TYPE_CHECKING from llmcompressor.transformers.finetune.data import TextGenerationDataset +from llmcompressor.typing import Processor + +if TYPE_CHECKING: + from llmcompressor.args import DatasetArguments @TextGenerationDataset.register(name="pile_eval") @@ -13,33 +17,15 @@ class PileEvalDataset(TextGenerationDataset): :param tokenizer: tokenizer to use on dataset """ - def __init__(self, data_args, split, tokenizer): + def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): data_args = deepcopy(data_args) + data_args.text_column = "text" data_args.dataset = "mit-han-lab/pile-val-backup" - super().__init__( - text_column="text", data_args=data_args, split=split, tokenizer=tokenizer - ) - - def get_raw_dataset(self, cache_dir: Optional[str] = None): - """ - Load the raw dataset from Hugging Face, using cached copy if available. - Additionally reformats the entries to fit the template. 
- :param cache_dir: disk location to search for cached dataset - :return: the requested dataset - """ - raw_dataset = super().get_raw_dataset(cache_dir=cache_dir) - - def restructure_fn(sample): - sample["text"] = sample["text"].strip() - return sample + super().__init__(data_args=data_args, split=split, processor=processor) - raw_dataset = self.map( - raw_dataset, - function=restructure_fn, - batched=False, - remove_columns=["meta"], - num_proc=self.data_args.preprocessing_num_workers, - load_from_cache_file=not self.data_args.overwrite_cache, - desc="Restructuring Pile Dataset", - ) - return raw_dataset + def dataset_template(self, sample): + return { + "text": self.processor.apply_chat_template( + sample["text"].strip(), + ), + } diff --git a/tests/llmcompressor/transformers/finetune/data/test_registry.py b/tests/llmcompressor/transformers/finetune/data/test_registry.py index d2391e61d..3a2540eb9 100644 --- a/tests/llmcompressor/transformers/finetune/data/test_registry.py +++ b/tests/llmcompressor/transformers/finetune/data/test_registry.py @@ -67,10 +67,10 @@ def test_pile_eval_initializes(tiny_llama_tokenizer): data_args.dataset, data_args=data_args, split=None, - tokenizer=tiny_llama_tokenizer, + processor=tiny_llama_tokenizer, ) assert isinstance(pile_eval_manager, TextGenerationDataset) assert isinstance(pile_eval_manager, PileEvalDataset) - assert pile_eval_manager.text_column == "text" + assert pile_eval_manager.data_args.text_column == "text" assert not pile_eval_manager.padding assert pile_eval_manager.max_seq_length == data_args.max_seq_length From 88aeab8f6cca9bc65b9a999e5d017e30584b40c8 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 18 Feb 2025 21:20:02 +0000 Subject: [PATCH 03/21] switch to using HooksMixin api Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 44 +++++++++++-------- .../modifiers/smoothquant/base.py | 4 +- src/llmcompressor/pytorch/utils/helpers.py | 2 - 3 files changed, 27 insertions(+), 23 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 07ad3093f..9a9155e7b 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -2,11 +2,12 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch +from compressed_tensors.utils.offload import is_module_offloaded from loguru import logger from torch.nn import Module from tqdm import tqdm -from llmcompressor.core import Event, State +from llmcompressor.core import State from llmcompressor.modifiers import Modifier from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward from llmcompressor.pytorch.utils import ( @@ -14,6 +15,7 @@ tensor_forward_with_input_args, ) from llmcompressor.utils.fsdp.helpers import get_fsdp_parent +from llmcompressor.utils.helpers import calibration_forward_context from llmcompressor.utils.pytorch.module import ( get_layer, get_layers, @@ -124,14 +126,10 @@ class AWQModifier(Modifier): duo_scaling: bool = True apply_clip: bool = True - hooks_: Optional[List] = None - resolved_mappings_: Optional[List] = None + resolved_mappings_: Optional[List[AWQMapping]] = None scales_: Optional[Dict] = None module_kwargs_: Optional[Dict] = None - def on_initialize_structure(self, state: State, **kwargs): - pass # nothing needed for this modifier - def on_initialize(self, state: State, **kwargs) -> bool: """ Initialize and run AWQ on the given state @@ -155,7 +153,6 @@ def on_initialize(self, state: State, 
**kwargs) -> bool: self.scales_ = {} calibration_dataloader = state.data.calib - self.hooks_ = [] self._get_module_kwargs(state.model, calibration_dataloader) self._setup_scale_hooks() @@ -179,7 +176,7 @@ def on_finalize(self, state: State, **kwargs) -> bool: return True - def _resolve_mappings(self, model: Module) -> List: + def _resolve_mappings(self, model: Module) -> List[AWQMapping]: """ Transforms the list of activations to smooth and their corresponding weights into AWQMapping objects, resolving regular expressions. @@ -252,7 +249,7 @@ def hook_fn(module, inp, out): # is enough, as other balance layers # get the same input layer = mapping.balance_layers[0] - self.hooks_.append(layer.register_forward_hook(create_hook_fn(name))) + self.register_hook(layer, create_hook_fn(name), "forward") @torch.no_grad() def _calibrate(self, model: Module, calibration_dataloader: List): @@ -271,17 +268,16 @@ def _calibrate(self, model: Module, calibration_dataloader: List): " CompressionSession to run the AWQ modifier" ) - run_calibration_forward( - model, - calibration_dataloader, - self.num_calibration_steps, - self.calibration_function, - ) + with calibration_forward_context(model): + run_calibration_forward( + model, + calibration_dataloader, + self.num_calibration_steps, + self.calibration_function, + ) # remove the hooks now that we are done calibrating - for hook in self.hooks_: - hook.remove() - del self.hooks_ + self.remove_hooks() def _concat_collected_activations(self): """ @@ -370,6 +366,13 @@ def _apply_smoothing(self, model: Module): @torch.no_grad() def smooth(module): + # TODO calls to module._hf_hook.pre_forward(module) and + # module._hf_hook.post_forward(module, None) appear a couple places + # in SmoothQuantModifier, do we need them anywhere else? + offloaded = is_module_offloaded(module) + if offloaded: + module._hf_hook.pre_forward(module) + if module in balance_layers: module.weight.mul_(scales.view(1, -1).to(module.weight.device)) elif module == smooth_layer: @@ -380,6 +383,9 @@ def smooth(module): if hasattr(module, "bias") and module.bias is not None: module.bias.div_(scales.to(module.bias.device)) + if offloaded: + module._hf_hook.post_forward(module, None) + parent = get_fsdp_parent(mapping.smooth_name, model) if parent is not None: parent.apply(smooth) @@ -681,7 +687,7 @@ def _compute_best_clip( best_max_val = torch.cat(best_max_val_all, dim=0) - #TODO this appears unneeded, clear_memory removed + # TODO this appears unneeded, clear_memory removed # clear_memory(input_feat) # clear_memory(org_out) diff --git a/src/llmcompressor/modifiers/smoothquant/base.py b/src/llmcompressor/modifiers/smoothquant/base.py index aa3317198..71b9bd9f6 100644 --- a/src/llmcompressor/modifiers/smoothquant/base.py +++ b/src/llmcompressor/modifiers/smoothquant/base.py @@ -105,7 +105,7 @@ class SmoothQuantModifier(Modifier): num_calibration_steps: Optional[int] = None calibration_function: Optional[Callable] = None - resolved_mappings_: Optional[List] = None + resolved_mappings_: Optional[List[SmoothQuantMapping]] = None scales_: Optional[Dict] = None def on_initialize(self, state: State, **kwargs) -> bool: @@ -166,7 +166,7 @@ def _infer_mappings_from_model( ) @handle_mapping_resolution_errors - def _resolve_mappings(self, model: Module) -> List: + def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]: """ Transforms the list of activations to smooth and their corresponding weights into SmoothQuantMapping objects, resolving regular expressions. 
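For reference while reviewing: a self-contained toy version of the per-layer scale search that `_compute_best_scale` performs (illustration only, not part of the patch set). Random tensors stand in for real weights and activations, the module forward pass is a plain matmul, and the per-group weight statistics are simplified to a per-channel normalization; only the grid-search structure and the loss from the AWQModifier docstring are kept.

```python
import torch

def fake_quant(w: torch.Tensor, bits: int = 4) -> torch.Tensor:
    # symmetric per-row fake quantization, mirroring pseudo_quantize_tensor(symmetric=True)
    qmax = 2 ** (bits - 1) - 1
    scales = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5) / qmax
    return torch.clamp(torch.round(w / scales), -qmax - 1, qmax) * scales

torch.manual_seed(0)
W = torch.randn(16, 32)       # stand-in weight [out_features, in_features]
X = torch.randn(64, 32)       # stand-in calibration activations [tokens, in_features]
x_mean = X.abs().mean(dim=0)  # per-input-channel activation magnitude
w_mean = (W.abs() / (W.abs().amax(dim=0, keepdim=True) + 1e-6)).mean(dim=0)

best_ratio, best_err = None, float("inf")
for i in range(20):           # n_grid = 20, as in _compute_best_scale
    ratio = i / 20
    s = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp(min=1e-4)
    s = s / (s.max() * s.min()).sqrt()
    # L(s) = || Q(W * s) (s^-1 * X) - W * X ||, the loss from the AWQModifier docstring
    err = (fake_quant(W * s) @ (X / s).T - W @ X.T).pow(2).mean().item()
    if err < best_err:
        best_ratio, best_err = ratio, err
print(best_ratio, best_err)
```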
diff --git a/src/llmcompressor/pytorch/utils/helpers.py b/src/llmcompressor/pytorch/utils/helpers.py index 00de09d40..094ef0b8b 100644 --- a/src/llmcompressor/pytorch/utils/helpers.py +++ b/src/llmcompressor/pytorch/utils/helpers.py @@ -3,7 +3,6 @@ """ import functools -import gc import inspect import os import random @@ -1297,4 +1296,3 @@ def pseudo_dequantize_linear( w = w.weight.data * scales return w - From 2b74ccf0fdc0abc40f7e88146a5f25088c8834d2 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 18 Feb 2025 22:56:14 +0000 Subject: [PATCH 04/21] pydantic serialization issue fix Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 4 ++++ src/llmcompressor/modifiers/smoothquant/base.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 9a9155e7b..a0608b573 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -4,6 +4,7 @@ import torch from compressed_tensors.utils.offload import is_module_offloaded from loguru import logger +from pydantic import ConfigDict from torch.nn import Module from tqdm import tqdm @@ -115,6 +116,9 @@ class AWQModifier(Modifier): :param apply_clip: whether to apply clipping to the weights after scaling """ + # Allow arbitrary types because AWQMapping has field of type torch.nn.Module + model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True) + mappings: List[Tuple] = DEFAULT_AWQ_MAPPINGS ignore: Optional[List[str]] = None num_calibration_steps: Optional[int] = None diff --git a/src/llmcompressor/modifiers/smoothquant/base.py b/src/llmcompressor/modifiers/smoothquant/base.py index 71b9bd9f6..845798f07 100644 --- a/src/llmcompressor/modifiers/smoothquant/base.py +++ b/src/llmcompressor/modifiers/smoothquant/base.py @@ -4,6 +4,7 @@ import torch from compressed_tensors.utils.offload import is_module_offloaded from loguru import logger +from pydantic import ConfigDict from torch.nn import Module from llmcompressor.core import State @@ -99,6 +100,9 @@ class SmoothQuantModifier(Modifier): to use the default tensor_module_forward """ + # Allow arbitrary types because AWQMapping has field of type torch.nn.Module + model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True) + smoothing_strength: float = 0.5 mappings: Optional[List[Union[Tuple, List]]] = None ignore: Optional[List[str]] = None From cb5956ed2557d803d414ff79c3ff786e31e75a8c Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Wed, 19 Feb 2025 18:00:07 +0000 Subject: [PATCH 05/21] switch to accelerate with align_module_device Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 68 +++++++++---------- .../modifiers/smoothquant/base.py | 40 ++++------- 2 files changed, 45 insertions(+), 63 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index a0608b573..ebd5b0936 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -2,7 +2,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch -from compressed_tensors.utils.offload import is_module_offloaded +from accelerate.utils import align_module_device from loguru import logger from pydantic import ConfigDict from torch.nn import Module @@ -318,7 +318,7 @@ def _apply_smoothing(self, model: Module): # [STEP 1]: Compute per-channel mean of normalised weights # All layer weights are concatted together - weight = 
torch.cat([_m.weight for _m in balance_layers], dim=0) + weight = torch.cat([bl.weight for bl in balance_layers], dim=0) org_shape = weight.shape # The weights are reshaped to be organised by quantization group weight = weight.view(-1, self.group_size) @@ -373,22 +373,18 @@ def smooth(module): # TODO calls to module._hf_hook.pre_forward(module) and # module._hf_hook.post_forward(module, None) appear a couple places # in SmoothQuantModifier, do we need them anywhere else? - offloaded = is_module_offloaded(module) - if offloaded: - module._hf_hook.pre_forward(module) - - if module in balance_layers: - module.weight.mul_(scales.view(1, -1).to(module.weight.device)) - elif module == smooth_layer: - if module.weight.ndim == 1: - module.weight.div_(scales.to(module.weight.device)) - else: - module.weight.div_(scales.view(-1, 1).to(module.weight.device)) - if hasattr(module, "bias") and module.bias is not None: - module.bias.div_(scales.to(module.bias.device)) - - if offloaded: - module._hf_hook.post_forward(module, None) + with align_module_device(module): + if module in balance_layers: + module.weight.mul_(scales.view(1, -1).to(module.weight.device)) + elif module == smooth_layer: + if module.weight.ndim == 1: + module.weight.div_(scales.to(module.weight.device)) + else: + module.weight.div_( + scales.view(-1, 1).to(module.weight.device) + ) + if hasattr(module, "bias") and module.bias is not None: + module.bias.div_(scales.to(module.bias.device)) parent = get_fsdp_parent(mapping.smooth_name, model) if parent is not None: @@ -461,16 +457,17 @@ def _compute_best_scale( # Q(W * s) for fc in linears2scale: - fc.weight.mul_(scales_view) - fc.weight.data = ( - pseudo_quantize_tensor( - w=fc.weight.data, - symmetric=self.symmetric, - bit_width=self.bits, - group_size=self.group_size, - )[0] - / scales_view - ) + with align_module_device(fc): + fc.weight.mul_(scales_view) + fc.weight.data = ( + pseudo_quantize_tensor( + w=fc.weight.data, + symmetric=self.symmetric, + bit_width=self.bits, + group_size=self.group_size, + )[0] + / scales_view + ) # W * X int_w_output = self._forward_input_with_kwargs( @@ -691,10 +688,6 @@ def _compute_best_clip( best_max_val = torch.cat(best_max_val_all, dim=0) - # TODO this appears unneeded, clear_memory removed - # clear_memory(input_feat) - # clear_memory(org_out) - return best_max_val.squeeze(1) @@ -711,8 +704,9 @@ def _apply_clip(module, clip_list: Tuple[str, torch.Tensor]): for name, max_val in clip_list: _, layer = get_layer(target=name, module=module) assert isinstance(layer, torch.nn.Linear) - max_val = max_val.to(layer.weight.device) - org_shape = layer.weight.shape - layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1) - layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val) - layer.weight.data = layer.weight.data.reshape(org_shape) + with align_module_device(layer): + max_val = max_val.to(layer.weight.device) + org_shape = layer.weight.shape + layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1) + layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val) + layer.weight.data = layer.weight.data.reshape(org_shape) diff --git a/src/llmcompressor/modifiers/smoothquant/base.py b/src/llmcompressor/modifiers/smoothquant/base.py index 845798f07..1b1e0aee6 100644 --- a/src/llmcompressor/modifiers/smoothquant/base.py +++ b/src/llmcompressor/modifiers/smoothquant/base.py @@ -2,7 +2,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union import torch -from compressed_tensors.utils.offload import 
is_module_offloaded +from accelerate.utils import align_module_device from loguru import logger from pydantic import ConfigDict from torch.nn import Module @@ -293,22 +293,16 @@ def _apply_smoothing(self, model: Module): @torch.no_grad() def smooth(module): - offloaded = is_module_offloaded(module) - if offloaded: - module._hf_hook.pre_forward(module) - - if module in balance_layers: - module.weight.mul_(scales.view(1, -1)) - elif module == smooth_layer: - if module.weight.ndim == 1: - module.weight.div_(scales) - else: - module.weight.div_(scales.view(-1, 1)) - if hasattr(module, "bias") and module.bias is not None: - module.bias.div_(scales) - - if offloaded: - module._hf_hook.post_forward(module, None) + with align_module_device(module): + if module in balance_layers: + module.weight.mul_(scales.view(1, -1)) + elif module == smooth_layer: + if module.weight.ndim == 1: + module.weight.div_(scales) + else: + module.weight.div_(scales.view(-1, 1)) + if hasattr(module, "bias") and module.bias is not None: + module.bias.div_(scales) parent = get_fsdp_parent(mapping.smooth_name, model) if parent is not None: @@ -333,15 +327,9 @@ def _calculate_smoothing_scales( # get the channel-wise dynamic range for each layer to be balanced weight_scales = [] for layer in balance_layers: - offloaded = is_module_offloaded(layer) - if offloaded: - layer._hf_hook.pre_forward(layer) - - scale = layer.weight.abs().max(dim=0, keepdim=True)[0] - weight_scales.append(scale) - - if offloaded: - layer._hf_hook.post_forward(layer, None) + with align_module_device(layer): + scale = layer.weight.abs().max(dim=0, keepdim=True)[0] + weight_scales.append(scale) weight_scales = 2.0 * torch.cat(weight_scales, dim=0).max(dim=0)[0] From 5cb055ce0adad9053bcc3e38ebae14675cd3a2b6 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Wed, 19 Feb 2025 23:11:09 +0000 Subject: [PATCH 06/21] AWQ running but OOMs unless NUM_CALIBRATION_SAMPLES and MAX_SEQUENCE_LENGTH are very low Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 6 +++--- src/llmcompressor/transformers/finetune/data/__init__.py | 2 +- src/llmcompressor/transformers/finetune/data/pile.py | 8 +++++--- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index ebd5b0936..6675291bc 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -158,7 +158,7 @@ def on_initialize(self, state: State, **kwargs) -> bool: calibration_dataloader = state.data.calib - self._get_module_kwargs(state.model, calibration_dataloader) + self._set_module_kwargs(state.model, calibration_dataloader) self._setup_scale_hooks() self._calibrate(state.model, calibration_dataloader) self._concat_collected_activations() @@ -530,7 +530,7 @@ def _compute_loss( return loss - def _get_module_kwargs(self, model, dataloader): + def _set_module_kwargs(self, model, dataloader) -> None: _, modules = next(iter(get_layers("re:.*layers", model).items())) samples = [batch["input_ids"] for batch in dataloader] @@ -575,7 +575,7 @@ def forward(self, *args, **kwargs): # Update the layer kwargs with `prepare_inputs_for_generation` method # that takes care of everything to avoid unexpected errors. - layer_kwargs = model.prepare_inputs_for_generation(samples, **layer_kwargs) + layer_kwargs |= model.prepare_inputs_for_generation(samples, **layer_kwargs) # Pop the input_ids as they are not needed at all. 
layer_kwargs.pop("input_ids") diff --git a/src/llmcompressor/transformers/finetune/data/__init__.py b/src/llmcompressor/transformers/finetune/data/__init__.py index b72efa7c7..186a076bd 100644 --- a/src/llmcompressor/transformers/finetune/data/__init__.py +++ b/src/llmcompressor/transformers/finetune/data/__init__.py @@ -8,7 +8,7 @@ from .flickr_30k import Flickr30K from .gsm8k import GSM8KDataset from .open_platypus import OpenPlatypusDataset -from .pile import PileEvalDataset +from .pile import PileValDataset from .ptb import PtbDataset from .ultrachat_200k import UltraChatDataset from .wikitext import WikiTextDataset diff --git a/src/llmcompressor/transformers/finetune/data/pile.py b/src/llmcompressor/transformers/finetune/data/pile.py index 4eef5f7eb..ccdb92056 100644 --- a/src/llmcompressor/transformers/finetune/data/pile.py +++ b/src/llmcompressor/transformers/finetune/data/pile.py @@ -1,6 +1,8 @@ from copy import deepcopy from typing import TYPE_CHECKING +from loguru import logger + from llmcompressor.transformers.finetune.data import TextGenerationDataset from llmcompressor.typing import Processor @@ -8,10 +10,10 @@ from llmcompressor.args import DatasetArguments -@TextGenerationDataset.register(name="pile_eval") -class PileEvalDataset(TextGenerationDataset): +@TextGenerationDataset.register(name="mit-han-lab/pile-val-backup", alias="pile_val") +class PileValDataset(TextGenerationDataset): """ - Child text generation class for the PileEval dataset + Child text generation class for "The Pile" dataset :param data_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` :param tokenizer: tokenizer to use on dataset From 28f8bca24d542ce6854f9924dc556e2c03eee156 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 20 Feb 2025 17:15:54 +0000 Subject: [PATCH 07/21] working with larger num_calibration_samples Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 46 ++++++++++++++----------- 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 6675291bc..5e985559b 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -2,7 +2,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch -from accelerate.utils import align_module_device +from compressed_tensors.utils import align_module_device, update_offload_parameter from loguru import logger from pydantic import ConfigDict from torch.nn import Module @@ -158,11 +158,14 @@ def on_initialize(self, state: State, **kwargs) -> bool: calibration_dataloader = state.data.calib - self._set_module_kwargs(state.model, calibration_dataloader) - self._setup_scale_hooks() - self._calibrate(state.model, calibration_dataloader) - self._concat_collected_activations() - self._apply_smoothing(state.model) + # TODO is it ok to wrap the whole model in this context? 
+ # I don't think we ever want gradients or to use kv cache + with calibration_forward_context(state.model): + self._set_module_kwargs(state.model, calibration_dataloader) + self._setup_scale_hooks() + self._calibrate(state.model, calibration_dataloader) + self._concat_collected_activations() + self._apply_smoothing(state.model) return True @@ -272,13 +275,13 @@ def _calibrate(self, model: Module, calibration_dataloader: List): " CompressionSession to run the AWQ modifier" ) - with calibration_forward_context(model): - run_calibration_forward( - model, - calibration_dataloader, - self.num_calibration_steps, - self.calibration_function, - ) + # with calibration_forward_context(model): + run_calibration_forward( + model, + calibration_dataloader, + self.num_calibration_steps, + self.calibration_function, + ) # remove the hooks now that we are done calibrating self.remove_hooks() @@ -356,10 +359,9 @@ def _apply_smoothing(self, model: Module): x_mean = (x_sum / num_elements).to(inp.dtype) # [STEP 3]: Compute output of module - with torch.no_grad(): - fp16_output = self._forward_input_with_kwargs( - module=module2inspect, inputs=inp, input_kwargs=self.module_kwargs_ - ) + fp16_output = self._forward_input_with_kwargs( + module=module2inspect, inputs=inp, input_kwargs=self.module_kwargs_ + ) # [STEP 4]: Compute loss best_scales = self._compute_best_scale( @@ -459,14 +461,16 @@ def _compute_best_scale( for fc in linears2scale: with align_module_device(fc): fc.weight.mul_(scales_view) - fc.weight.data = ( + update_offload_parameter( + fc, + "weight", pseudo_quantize_tensor( w=fc.weight.data, symmetric=self.symmetric, bit_width=self.bits, group_size=self.group_size, )[0] - / scales_view + / scales_view, ) # W * X @@ -488,7 +492,9 @@ def _compute_best_scale( logger.debug(history) raise Exception - assert torch.isnan(best_scales).sum() == 0, best_scales + assert ( + torch.isnan(best_scales).sum() == 0 + ), f"Nan found in scales: {best_scales}" return best_scales.detach().cpu() From 2226bfdf1673d845f05d75589c48a5e21c80798c Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 20 Feb 2025 18:37:18 +0000 Subject: [PATCH 08/21] fix pile dataset issue Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 5 ++++- src/llmcompressor/transformers/finetune/data/pile.py | 8 +------- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 5e985559b..160f2106a 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -25,7 +25,10 @@ ) DEFAULT_AWQ_MAPPINGS = [ - [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*input_layernorm"], + [ + ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj", "re:.*o_proj"], + "re:.*input_layernorm", + ], [["re:.*gate_proj", "re:.*up_proj"], "re:.*post_attention_layernorm"], [["re:.*down_proj"], "re:.*up_proj"], ] diff --git a/src/llmcompressor/transformers/finetune/data/pile.py b/src/llmcompressor/transformers/finetune/data/pile.py index ccdb92056..f420ba2a5 100644 --- a/src/llmcompressor/transformers/finetune/data/pile.py +++ b/src/llmcompressor/transformers/finetune/data/pile.py @@ -1,8 +1,6 @@ from copy import deepcopy from typing import TYPE_CHECKING -from loguru import logger - from llmcompressor.transformers.finetune.data import TextGenerationDataset from llmcompressor.typing import Processor @@ -26,8 +24,4 @@ def __init__(self, data_args: "DatasetArguments", split: str, processor: Process 
super().__init__(data_args=data_args, split=split, processor=processor) def dataset_template(self, sample): - return { - "text": self.processor.apply_chat_template( - sample["text"].strip(), - ), - } + return {"text": sample["text"].strip()} From 5ca7eb27056c8d4fb0a6c315801fe0da2ee83f01 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Mon, 24 Feb 2025 15:02:07 -0500 Subject: [PATCH 09/21] updated config dataclasses Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 94 ++++++++++++++++--------- 1 file changed, 60 insertions(+), 34 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 160f2106a..73a6abb51 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -24,15 +24,6 @@ get_parent_by_name, ) -DEFAULT_AWQ_MAPPINGS = [ - [ - ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj", "re:.*o_proj"], - "re:.*input_layernorm", - ], - [["re:.*gate_proj", "re:.*up_proj"], "re:.*post_attention_layernorm"], - [["re:.*down_proj"], "re:.*up_proj"], -] - __all__ = ["AWQScale", "AWQMapping", "AWQModifier"] @@ -48,8 +39,42 @@ class AWQScale: @dataclass class AWQMapping: """ - Dataclass for storing the mapping between an activation layer and the following - weights that must be balanced during smoothing + Dataclass storing config of activation mappings to smooth + The output activations of smooth_layer are input activations + into the balance_layers + + `AWQMapping`s are resolved into `ResolvedMapping`s, which + retain pointers to the actual `torch.nn.Module`s and additional + metadata at runtime + """ + + smooth_layer: str + balance_layers: list[str] + + +DEFAULT_AWQ_MAPPINGS: list[AWQMapping] = [ + AWQMapping( + "re:.*input_layernorm", + ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], + ), + AWQMapping( + "re:.*post_attention_layernorm", + ["re:.*gate_proj", "re:.*up_proj"], + ), + AWQMapping( + "re:.*up_proj", + ["re:.*down_proj"], + ), + # TODO check with this uncommented + # AWQMapping("re:.*v_proj", ["re:.*o_proj"]), +] + + +@dataclass +class ResolvedMapping: + """ + Dataclass for storing the resolved mappings between an activation layer + and the following weights that must be balanced during smoothing :param smooth_name: name of the activation layer :param smooth_layer: PyTorch module storing the activation layer @@ -89,9 +114,11 @@ class AWQModifier(Modifier): ```yaml AWQModifier: bits: 4 - mappings: [ - [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*self_attn_layer_norm"], - [["re:.*fc1"], "re:.*final_layer_norm"] + mappings: + - smooth_layer: "re:.*self_attn_layer_norm" + balance_layers: ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"] + - smooth_layer: "re:.*final_layer_norm" + balance_layers: ["re:.*fc1"] ] ignore: ["model.decoder.final_layer_norm"] ``` @@ -119,10 +146,10 @@ class AWQModifier(Modifier): :param apply_clip: whether to apply clipping to the weights after scaling """ - # Allow arbitrary types because AWQMapping has field of type torch.nn.Module + # Allow arbitrary types because AWQMapping has fields of type torch.nn.Module model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True) - mappings: List[Tuple] = DEFAULT_AWQ_MAPPINGS + mappings: List[AWQMapping] = DEFAULT_AWQ_MAPPINGS ignore: Optional[List[str]] = None num_calibration_steps: Optional[int] = None calibration_function: Optional[Callable] = None @@ -133,7 +160,7 @@ class AWQModifier(Modifier): duo_scaling: bool = True apply_clip: bool = True - resolved_mappings_: 
Optional[List[AWQMapping]] = None + resolved_mappings_: Optional[List[ResolvedMapping]] = None scales_: Optional[Dict] = None module_kwargs_: Optional[Dict] = None @@ -156,13 +183,11 @@ def on_initialize(self, state: State, **kwargs) -> bool: ) self.ignore = [] if not self.ignore else self.ignore - self.resolved_mappings_ = self._resolve_mappings(state.model) + self.resolved_mappings_ = self._get_resolved_mappings(state.model) self.scales_ = {} calibration_dataloader = state.data.calib - # TODO is it ok to wrap the whole model in this context? - # I don't think we ever want gradients or to use kv cache with calibration_forward_context(state.model): self._set_module_kwargs(state.model, calibration_dataloader) self._setup_scale_hooks() @@ -186,10 +211,10 @@ def on_finalize(self, state: State, **kwargs) -> bool: return True - def _resolve_mappings(self, model: Module) -> List[AWQMapping]: + def _get_resolved_mappings(self, model: Module) -> List[ResolvedMapping]: """ Transforms the list of activations to smooth and their corresponding weights - into AWQMapping objects, resolving regular expressions. + into ResolvedMapping objects, resolving regular expressions. For each activation in the mapping list, we find the corresponding weight to balance by searching for the longest substring. For instance, if our balance @@ -197,13 +222,13 @@ def _resolve_mappings(self, model: Module) -> List[AWQMapping]: would match model.layer.0.p_proj to model.layer.0.self_attn_layer_norm and repeat for model.layer.1 and so on """ - resolved_mappings = [] - for to_balance, to_smooth in self.mappings: - to_smooth_layers = get_layers(to_smooth, model) + resolved_mappings: list[ResolvedMapping] = [] + for mapping in self.mappings: + to_smooth_layers = get_layers(mapping.smooth_layer, model) for layer_name, smooth_layer in to_smooth_layers.items(): if layer_name not in self.ignore: balance_layers, balance_names = [], [] - for balance_suffix in to_balance: + for balance_suffix in mapping.balance_layers: # find the submodule that matches the activation layer balance_name, balance_layer = get_matching_layer( balance_suffix, layer_name, model @@ -224,15 +249,16 @@ def _resolve_mappings(self, model: Module) -> List[AWQMapping]: parent_name, parent = get_parent_by_name( layer_name=balance_name, model=model ) - mapping = AWQMapping( - layer_name, - smooth_layer, - balance_layers, - balance_names=balance_names, - parent=parent, - parent_name=parent_name, + resolved_mappings.append( + ResolvedMapping( + layer_name, + smooth_layer, + balance_layers, + balance_names=balance_names, + parent=parent, + parent_name=parent_name, + ) ) - resolved_mappings.append(mapping) return resolved_mappings def _setup_scale_hooks(self): From 405aeb3cf16627ab40dbb3613aff69f9bc384957 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 25 Feb 2025 00:24:09 +0000 Subject: [PATCH 10/21] OOM error resolved Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 73a6abb51..0dbcbc811 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -65,7 +65,7 @@ class AWQMapping: "re:.*up_proj", ["re:.*down_proj"], ), - # TODO check with this uncommented + # TODO this generally results in higher perplexity for llama 2 7B on wikitext # AWQMapping("re:.*v_proj", ["re:.*o_proj"]), ] @@ -269,8 +269,7 @@ def _setup_scale_hooks(self): def 
create_hook_fn(layer_name): def hook_fn(module, inp, out): - inp = inp[0] - inp.cpu().detach() + inp = inp[0].cpu().detach() if layer_name in self.scales_: self.scales_[layer_name].inps.append(inp) @@ -365,8 +364,8 @@ def _apply_smoothing(self, model: Module): # [STEP 2]: Compute per-channel mean of the input activation with chunking # move inp to cpu to avoid memory leak - inp = activations - inp_flat = inp.cpu().abs().view(-1, inp.shape[-1]) + inp = activations.to(weight.device) + inp_flat = activations.cpu().abs().view(-1, inp.shape[-1]) num_elements = inp_flat.size(0) num_channels = inp_flat.size(1) element_size_bytes = inp_flat.element_size() * 2 # multiplied by 2 for FP32 From e819fcdf0fd7031e9b7bf88e8bf949f0cdfd29a8 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 25 Feb 2025 20:19:50 +0000 Subject: [PATCH 11/21] codereview updates Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 42 +++++-------------- .../transformers/finetune/runner.py | 3 +- 2 files changed, 13 insertions(+), 32 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 0dbcbc811..f71feff6f 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -4,7 +4,7 @@ import torch from compressed_tensors.utils import align_module_device, update_offload_parameter from loguru import logger -from pydantic import ConfigDict +from pydantic import ConfigDict, Field from torch.nn import Module from tqdm import tqdm @@ -24,16 +24,7 @@ get_parent_by_name, ) -__all__ = ["AWQScale", "AWQMapping", "AWQModifier"] - - -@dataclass -class AWQScale: - """ - Dataclass for storing the input activations of a layer to be smoothed - """ - - inps: Union[List[torch.Tensor], torch.Tensor] +__all__ = ["AWQMapping", "AWQModifier"] @dataclass @@ -161,8 +152,8 @@ class AWQModifier(Modifier): apply_clip: bool = True resolved_mappings_: Optional[List[ResolvedMapping]] = None - scales_: Optional[Dict] = None - module_kwargs_: Optional[Dict] = None + scales_: Dict[str, torch.Tensor | List[torch.Tensor]] = Field(default_factory=dict) + module_kwargs_: Dict = Field(default_factory=dict) def on_initialize(self, state: State, **kwargs) -> bool: """ @@ -171,20 +162,9 @@ def on_initialize(self, state: State, **kwargs) -> bool: :param state: state to run AWQ on :return: True on a successful run, False otherwise """ - if not (self.end is None or self.end == -1): - raise ValueError( - f"{self.__class__.__name__} can only be applied during one-shot. " - f" Expected end to be None or -1, got {self.end}" - ) - if self.start and self.start != -1: - raise ValueError( - f"{self.__class__.__name__} can only be applied during one-shot. 
" - f"Expected start to be None or -1, got {self.end}" - ) - + self.ignore = [] if not self.ignore else self.ignore self.resolved_mappings_ = self._get_resolved_mappings(state.model) - self.scales_ = {} calibration_dataloader = state.data.calib @@ -272,9 +252,9 @@ def hook_fn(module, inp, out): inp = inp[0].cpu().detach() if layer_name in self.scales_: - self.scales_[layer_name].inps.append(inp) + self.scales_[layer_name].append(inp) else: - self.scales_[layer_name] = AWQScale(inps=[inp]) + self.scales_[layer_name] = [inp] return hook_fn @@ -324,7 +304,7 @@ def _concat_collected_activations(self): """ for mapping in self.resolved_mappings_: name = mapping.smooth_name - self.scales_[name].inps = torch.cat(self.scales_[name].inps, dim=0) + self.scales_[name] = torch.cat(self.scales_[name], dim=0) torch.cuda.empty_cache() @@ -343,7 +323,7 @@ def _apply_smoothing(self, model: Module): balance_layers = mapping.balance_layers balance_names = mapping.balance_names - activations = self.scales_[mapping.smooth_name].inps + activations = self.scales_[mapping.smooth_name] module2inspect = mapping.parent @@ -445,7 +425,7 @@ def _compute_best_scale( module2inspect: torch.nn.Module, linears2scale: List[torch.nn.Linear], fp16_output: torch.Tensor, - ): + ) -> torch.Tensor: """ Compute loss and select best scales @@ -639,7 +619,7 @@ def _forward_input_with_kwargs( :param input_kwargs: additional arguments to pass to the module :return: the first output tensor from the forward pass """ - kwargs = input_kwargs or self.module_kwargs_ or {} + kwargs = input_kwargs or self.module_kwargs_ return tensor_forward_with_input_args( module=module, inputs=inputs, diff --git a/src/llmcompressor/transformers/finetune/runner.py b/src/llmcompressor/transformers/finetune/runner.py index 769b84248..d8fc556fc 100644 --- a/src/llmcompressor/transformers/finetune/runner.py +++ b/src/llmcompressor/transformers/finetune/runner.py @@ -6,6 +6,7 @@ import torch from loguru import logger from torch.utils.data import Dataset +import datasets from llmcompressor.args import ( DatasetArguments, @@ -106,7 +107,7 @@ def _get_split_name(inp_str): ) for split_name, split_str in splits.items(): dataset = self._data_args.dataset - if hasattr(dataset, "column_names") and "input_ids" in dataset.column_names: + if isinstance(dataset, datasets.Dataset) or (hasattr(dataset, "column_names") and "input_ids" in dataset.column_names): # dataset is already tokenized tokenized_datasets[split_name] = dataset else: From e8013073ea0773c851f519bfb77ec678785454e3 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 25 Feb 2025 23:02:26 +0000 Subject: [PATCH 12/21] minor touchups Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 46 ++++++++++++++++------ src/llmcompressor/pytorch/utils/helpers.py | 2 +- 2 files changed, 35 insertions(+), 13 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index f71feff6f..e773b191d 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -1,3 +1,4 @@ +import inspect from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -57,7 +58,7 @@ class AWQMapping: ["re:.*down_proj"], ), # TODO this generally results in higher perplexity for llama 2 7B on wikitext - # AWQMapping("re:.*v_proj", ["re:.*o_proj"]), + AWQMapping("re:.*v_proj", ["re:.*o_proj"]), ] @@ -141,7 +142,7 @@ class AWQModifier(Modifier): model_config: ConfigDict = 
ConfigDict(arbitrary_types_allowed=True) mappings: List[AWQMapping] = DEFAULT_AWQ_MAPPINGS - ignore: Optional[List[str]] = None + ignore: List[str] = [] num_calibration_steps: Optional[int] = None calibration_function: Optional[Callable] = None group_size: int = 128 @@ -151,9 +152,9 @@ class AWQModifier(Modifier): duo_scaling: bool = True apply_clip: bool = True - resolved_mappings_: Optional[List[ResolvedMapping]] = None - scales_: Dict[str, torch.Tensor | List[torch.Tensor]] = Field(default_factory=dict) - module_kwargs_: Dict = Field(default_factory=dict) + resolved_mappings_: List[ResolvedMapping] = [] + scales_: Dict[str, torch.Tensor | List[torch.Tensor]] = {} + module_kwargs_: Dict = {} def on_initialize(self, state: State, **kwargs) -> bool: """ @@ -162,8 +163,7 @@ def on_initialize(self, state: State, **kwargs) -> bool: :param state: state to run AWQ on :return: True on a successful run, False otherwise """ - - self.ignore = [] if not self.ignore else self.ignore + self.resolved_mappings_ = self._get_resolved_mappings(state.model) calibration_dataloader = state.data.calib @@ -368,7 +368,12 @@ def _apply_smoothing(self, model: Module): # [STEP 3]: Compute output of module fp16_output = self._forward_input_with_kwargs( - module=module2inspect, inputs=inp, input_kwargs=self.module_kwargs_ + module=module2inspect, + inputs=inp, + input_kwargs=self._sanitize_kwargs(self.module_kwargs_, module2inspect), + ) + fp16_output = fp16_output.clip( + torch.finfo(fp16_output.dtype).min, torch.finfo(fp16_output.dtype).max ) # [STEP 4]: Compute loss @@ -380,9 +385,6 @@ def _apply_smoothing(self, model: Module): @torch.no_grad() def smooth(module): - # TODO calls to module._hf_hook.pre_forward(module) and - # module._hf_hook.post_forward(module, None) appear a couple places - # in SmoothQuantModifier, do we need them anywhere else? with align_module_device(module): if module in balance_layers: module.weight.mul_(scales.view(1, -1).to(module.weight.device)) @@ -589,7 +591,7 @@ def forward(self, *args, **kwargs): # Update the layer kwargs with `prepare_inputs_for_generation` method # that takes care of everything to avoid unexpected errors. - layer_kwargs |= model.prepare_inputs_for_generation(samples, **layer_kwargs) + layer_kwargs = model.prepare_inputs_for_generation(samples, **layer_kwargs) # Pop the input_ids as they are not needed at all. layer_kwargs.pop("input_ids") @@ -620,6 +622,7 @@ def _forward_input_with_kwargs( :return: the first output tensor from the forward pass """ kwargs = input_kwargs or self.module_kwargs_ + kwargs = self._sanitize_kwargs(kwargs, module) return tensor_forward_with_input_args( module=module, inputs=inputs, @@ -704,6 +707,25 @@ def _compute_best_clip( return best_max_val.squeeze(1) + def _sanitize_kwargs(self, inputs_kwargs, module): + """ + Remove the arguments that are not supported in the module's + forward pass to avoid breaking behaviour between different versions + of transformers. + + Args: + inputs_kwargs (`dict`): + The input dictionary to pass to the model layer + module (`torch.nn.Module`): + Target module to quantize. 
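
        Example (editor's illustration, not part of the original patch; uses a
        plain ``torch.nn.Linear`` as a stand-in layer and a hypothetical kwargs
        dict)::

            import inspect
            import torch

            layer = torch.nn.Linear(4, 4)
            candidate_kwargs = {"input": torch.ones(1, 4), "use_cache": True}
            allowed = inspect.signature(layer.forward).parameters
            sanitized = {
                k: v
                for k, v in candidate_kwargs.items()
                if k in allowed and k != "use_cache"
            }
            # sanitized keeps only "input"; "use_cache" is dropped, mirroring
            # the explicit exclusion in the implementation below.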
+ """ + module_signature = inspect.signature(module.forward).parameters + sanitized_kwargs = {} + for k, v in inputs_kwargs.items(): + if k in module_signature and k != "use_cache": + sanitized_kwargs[k] = v + return sanitized_kwargs + @torch.no_grad() def _apply_clip(module, clip_list: Tuple[str, torch.Tensor]): diff --git a/src/llmcompressor/pytorch/utils/helpers.py b/src/llmcompressor/pytorch/utils/helpers.py index 094ef0b8b..c273a712f 100644 --- a/src/llmcompressor/pytorch/utils/helpers.py +++ b/src/llmcompressor/pytorch/utils/helpers.py @@ -1249,7 +1249,7 @@ def pseudo_quantize_tensor( assert w.dim() == 2 assert torch.isnan(w).sum() == 0 - if not symmetric: + if symmetric: max_val = w.amax(dim=1, keepdim=True) min_val = w.amin(dim=1, keepdim=True) max_int = 2**bit_width - 1 From 386ead2eed1ab93737e849674e17794397e5f1af Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Mon, 3 Mar 2025 21:56:00 +0000 Subject: [PATCH 13/21] updates from debugging Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 11 ++-- src/llmcompressor/observers/__init__.py | 1 + src/llmcompressor/observers/rtn.py | 58 +++++++++++++++++++ src/llmcompressor/pytorch/utils/helpers.py | 7 ++- .../transformers/finetune/runner.py | 6 +- 5 files changed, 73 insertions(+), 10 deletions(-) create mode 100644 src/llmcompressor/observers/rtn.py diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index e773b191d..796ad65cf 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -49,6 +49,8 @@ class AWQMapping: "re:.*input_layernorm", ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], ), + # TODO this generally results in higher perplexity for llama 2 7B on wikitext + AWQMapping("re:.*v_proj", ["re:.*o_proj"]), AWQMapping( "re:.*post_attention_layernorm", ["re:.*gate_proj", "re:.*up_proj"], @@ -57,8 +59,6 @@ class AWQMapping: "re:.*up_proj", ["re:.*down_proj"], ), - # TODO this generally results in higher perplexity for llama 2 7B on wikitext - AWQMapping("re:.*v_proj", ["re:.*o_proj"]), ] @@ -148,7 +148,7 @@ class AWQModifier(Modifier): group_size: int = 128 max_chunk_memory: int = 1024 * 1024 * 1024 bits: int = 4 - symmetric: bool = True + symmetric: bool = False duo_scaling: bool = True apply_clip: bool = True @@ -487,6 +487,9 @@ def _compute_best_scale( int_w_output = self._forward_input_with_kwargs( module=module2inspect, inputs=x, input_kwargs=self.module_kwargs_ ) + int_w_output = int_w_output.clip( + torch.finfo(int_w_output.dtype).min, torch.finfo(int_w_output.dtype).max + ) # compute mean squared error (L2 norm) loss = self._compute_loss(fp16_output, int_w_output, device) @@ -598,8 +601,6 @@ def forward(self, *args, **kwargs): del samples inps = inps[0] - torch.cuda.empty_cache() - if layer_kwargs.get("attention_mask") is not None: layer_kwargs["attention_mask"] = layer_kwargs["attention_mask"].to( best_device diff --git a/src/llmcompressor/observers/__init__.py b/src/llmcompressor/observers/__init__.py index 4c3ee5a88..e16d9d93b 100644 --- a/src/llmcompressor/observers/__init__.py +++ b/src/llmcompressor/observers/__init__.py @@ -5,3 +5,4 @@ from .base import * from .min_max import * from .mse import * +from .rtn import * diff --git a/src/llmcompressor/observers/rtn.py b/src/llmcompressor/observers/rtn.py new file mode 100644 index 000000000..47b6846b6 --- /dev/null +++ b/src/llmcompressor/observers/rtn.py @@ -0,0 +1,58 @@ +from typing import Any, Optional, Tuple + +import torch +from 
compressed_tensors.quantization.quant_args import QuantizationArgs +from compressed_tensors.quantization.utils import calculate_qparams +from compressed_tensors.utils import deprecated + +from llmcompressor.observers.base import Observer +from llmcompressor.pytorch.utils import pseudo_quantize_tensor + +__all__ = ["RoundToNearestObserver"] + + +@Observer.register("rtn") +class RoundToNearestObserver(Observer): + """ + Implements a quantization observer that calculates scale and zero point based on the + minimum and maximum values of the tensor being observed. If averaging_constant is + specified, then the scales are updated using a moving average + """ + + def calculate_qparams( + self, + observed: torch.Tensor, + reduce_dims: Optional[Tuple[int]] = None, + tensor_id: Optional[Any] = None, + ) -> Tuple[torch.FloatTensor, torch.IntTensor]: + """ + Updates the observed min and max using a moving average smoothed by the + averaging_constant. Set the averaging_constant to 1.0 to disable averaging. + + :param observed: observed tensor to calculate quantization parameters for + :param reduce_dims: optional tuple of dimensions to reduce along, + returned scale and zero point will be shaped (1,) along the + reduced dimensions + :param tensor_id: Optional id if different ranges of observed tensors are + passed, useful for sharding tensors by group_size + :return: tuple of scale and zero point derived from the observed tensor + """ + + _, scales, zp = pseudo_quantize_tensor( + observed, + symmetric=self.quantization_args.symmetric, + bit_width=self.quantization_args.num_bits, + group_size=-1, #self.quantization_args.group_size, + ) + return (scales, zp) + + def get_qparams_along_dim( + self, observed: torch.Tensor, dim: int, tensor_id: Optional[Any] = None + ): + """ + Calculate quantization parameters along the specified dimension + """ + reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim) + return self.calculate_qparams( + observed, reduce_dims=reduce_dims, tensor_id=tensor_id + ) diff --git a/src/llmcompressor/pytorch/utils/helpers.py b/src/llmcompressor/pytorch/utils/helpers.py index c273a712f..374fd6d7e 100644 --- a/src/llmcompressor/pytorch/utils/helpers.py +++ b/src/llmcompressor/pytorch/utils/helpers.py @@ -1244,12 +1244,13 @@ def pseudo_quantize_tensor( ): org_w_shape = w.shape if group_size > 0: - assert org_w_shape[-1] % group_size == 0 + assert org_w_shape[-1] % group_size == 0, f"org_w_shape ({org_w_shape[-1]}) must be a multiple of group_size ({group_size})!" 
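    # (Editor's illustration, not part of the patch: with group_size=4 a [2, 8]
    #  weight is viewed as [4, 4], so a scale and zero point are computed per
    #  group of four values and reshaped back to [2, 2] at the end. For one
    #  4-bit group [-0.50, -0.25, 0.00, 0.75] the asymmetric branch below gives
    #      scales = (0.75 - (-0.50)) / 15 = 0.0833
    #      zeros  = -round(-0.50 / 0.0833) = 6
    #  so round(w / scales) + zeros = [0, 3, 6, 15], and (q - zeros) * scales
    #  recovers the original values exactly; the trailing
    #  `zeros - 2**(bit_width - 1)` shift stores the 6 as -2, i.e. a signed
    #  zero-point representation.)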
w = w.reshape(-1, group_size) assert w.dim() == 2 assert torch.isnan(w).sum() == 0 - if symmetric: + # zero point quantization + if not symmetric: max_val = w.amax(dim=1, keepdim=True) min_val = w.amin(dim=1, keepdim=True) max_int = 2**bit_width - 1 @@ -1259,7 +1260,7 @@ def pseudo_quantize_tensor( w = ( torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros ) * scales - zeros = zeros.view(org_w_shape[0], -1) + zeros = (zeros- 2**(bit_width-1)).view(org_w_shape[0], -1) else: max_val = w.abs().amax(dim=1, keepdim=True) max_val = max_val.clamp(min=1e-5) diff --git a/src/llmcompressor/transformers/finetune/runner.py b/src/llmcompressor/transformers/finetune/runner.py index d8fc556fc..31cded7c9 100644 --- a/src/llmcompressor/transformers/finetune/runner.py +++ b/src/llmcompressor/transformers/finetune/runner.py @@ -3,10 +3,10 @@ import re from typing import List, Optional +import datasets import torch from loguru import logger from torch.utils.data import Dataset -import datasets from llmcompressor.args import ( DatasetArguments, @@ -107,7 +107,9 @@ def _get_split_name(inp_str): ) for split_name, split_str in splits.items(): dataset = self._data_args.dataset - if isinstance(dataset, datasets.Dataset) or (hasattr(dataset, "column_names") and "input_ids" in dataset.column_names): + if isinstance(dataset, datasets.Dataset) or ( + hasattr(dataset, "column_names") and "input_ids" in dataset.column_names + ): # dataset is already tokenized tokenized_datasets[split_name] = dataset else: From 32b0b534e58cf2b3e8e80923f7d0204e35a86151 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 4 Mar 2025 18:17:55 +0000 Subject: [PATCH 14/21] styling Signed-off-by: Brian Dellabetta --- src/llmcompressor/pytorch/utils/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llmcompressor/pytorch/utils/helpers.py b/src/llmcompressor/pytorch/utils/helpers.py index 374fd6d7e..3a42dcb8a 100644 --- a/src/llmcompressor/pytorch/utils/helpers.py +++ b/src/llmcompressor/pytorch/utils/helpers.py @@ -1260,7 +1260,7 @@ def pseudo_quantize_tensor( w = ( torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros ) * scales - zeros = (zeros- 2**(bit_width-1)).view(org_w_shape[0], -1) + zeros = (zeros - 2**(bit_width-1)).view(org_w_shape[0], -1) else: max_val = w.abs().amax(dim=1, keepdim=True) max_val = max_val.clamp(min=1e-5) From 31884cf50b4b32a0f54ea0adc03a6f0b4911e1a4 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Wed, 5 Mar 2025 17:30:50 +0000 Subject: [PATCH 15/21] slightly improved rtn calculate_qparams logic Signed-off-by: Brian Dellabetta --- src/llmcompressor/observers/rtn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llmcompressor/observers/rtn.py b/src/llmcompressor/observers/rtn.py index 47b6846b6..889b03318 100644 --- a/src/llmcompressor/observers/rtn.py +++ b/src/llmcompressor/observers/rtn.py @@ -42,7 +42,7 @@ def calculate_qparams( observed, symmetric=self.quantization_args.symmetric, bit_width=self.quantization_args.num_bits, - group_size=-1, #self.quantization_args.group_size, + group_size=self.quantization_args.group_size or -1, ) return (scales, zp) From b03124a3f0ccff3757df56a389200110f22d7474 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Mon, 10 Mar 2025 19:37:17 +0000 Subject: [PATCH 16/21] code cleanup Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 212 ++++++------------ src/llmcompressor/observers/__init__.py | 1 - src/llmcompressor/observers/rtn.py | 58 ----- 
src/llmcompressor/pytorch/utils/helpers.py | 61 ----- .../llmcompressor/modifiers/awq/test_base.py | 2 +- 5 files changed, 71 insertions(+), 263 deletions(-) delete mode 100644 src/llmcompressor/observers/rtn.py diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 796ad65cf..bdf3f8628 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -13,7 +13,6 @@ from llmcompressor.modifiers import Modifier from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward from llmcompressor.pytorch.utils import ( - pseudo_quantize_tensor, tensor_forward_with_input_args, ) from llmcompressor.utils.fsdp.helpers import get_fsdp_parent @@ -49,7 +48,7 @@ class AWQMapping: "re:.*input_layernorm", ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], ), - # TODO this generally results in higher perplexity for llama 2 7B on wikitext + # TODO this should only be added if v_proj/o_proj shapes match up, should we check during validation and skip if this is not the case? AWQMapping("re:.*v_proj", ["re:.*o_proj"]), AWQMapping( "re:.*post_attention_layernorm", @@ -127,8 +126,6 @@ class AWQModifier(Modifier): smoothing (the second entry of the mappings list). :param num_calibration_steps: number of samples to use for calibration, or None to use the whole dataset - :param calibration_function: optional function to use for the forward pass, or None - to use the default tensor_module_forward :param group_size: number of weights to group together for scaling :param max_chunk_memory: maximum memory to use for each chunk of input activations :param bits: number of bits to quantize the weights to @@ -144,17 +141,15 @@ class AWQModifier(Modifier): mappings: List[AWQMapping] = DEFAULT_AWQ_MAPPINGS ignore: List[str] = [] num_calibration_steps: Optional[int] = None - calibration_function: Optional[Callable] = None group_size: int = 128 max_chunk_memory: int = 1024 * 1024 * 1024 bits: int = 4 symmetric: bool = False duo_scaling: bool = True - apply_clip: bool = True - resolved_mappings_: List[ResolvedMapping] = [] - scales_: Dict[str, torch.Tensor | List[torch.Tensor]] = {} - module_kwargs_: Dict = {} + _resolved_mappings: List[ResolvedMapping] = [] + _scales: Dict[str, torch.Tensor | List[torch.Tensor]] = {} + _module_kwargs: Dict = {} def on_initialize(self, state: State, **kwargs) -> bool: """ @@ -164,7 +159,7 @@ def on_initialize(self, state: State, **kwargs) -> bool: :return: True on a successful run, False otherwise """ - self.resolved_mappings_ = self._get_resolved_mappings(state.model) + self._set_resolved_mappings(state.model) calibration_dataloader = state.data.calib @@ -184,17 +179,18 @@ def on_finalize(self, state: State, **kwargs) -> bool: :param state: unused :return: True """ - if self.scales_ is not None: - self.scales_.clear() - if self.resolved_mappings_ is not None: - self.resolved_mappings_.clear() + if self._scales is not None: + self._scales.clear() + if self._resolved_mappings is not None: + self._resolved_mappings.clear() return True - def _get_resolved_mappings(self, model: Module) -> List[ResolvedMapping]: + def _set_resolved_mappings(self, model: Module) -> None: """ Transforms the list of activations to smooth and their corresponding weights - into ResolvedMapping objects, resolving regular expressions. + into ResolvedMapping objects, resolving regular expressions. + Result is stored in _resolved_mappings. 
For each activation in the mapping list, we find the corresponding weight to balance by searching for the longest substring. For instance, if our balance @@ -239,7 +235,8 @@ def _get_resolved_mappings(self, model: Module) -> List[ResolvedMapping]: parent_name=parent_name, ) ) - return resolved_mappings + self._resolved_mappings = resolved_mappings + return def _setup_scale_hooks(self): """ @@ -251,14 +248,14 @@ def create_hook_fn(layer_name): def hook_fn(module, inp, out): inp = inp[0].cpu().detach() - if layer_name in self.scales_: - self.scales_[layer_name].append(inp) + if layer_name in self._scales: + self._scales[layer_name].append(inp) else: - self.scales_[layer_name] = [inp] + self._scales[layer_name] = [inp] return hook_fn - for mapping in self.resolved_mappings_: + for mapping in self._resolved_mappings: name = mapping.smooth_name # storing inps to first balance layer # is enough, as other balance layers @@ -288,7 +285,6 @@ def _calibrate(self, model: Module, calibration_dataloader: List): model, calibration_dataloader, self.num_calibration_steps, - self.calibration_function, ) # remove the hooks now that we are done calibrating @@ -299,12 +295,12 @@ def _concat_collected_activations(self): Concatenate the collected activation values from each forward pass into a single tensor for each layer - :postcondition: each layer in self.scales_ will have a single tensor containing + :postcondition: each layer in self._scales will have a single tensor containing all the activation values seen during calibration """ - for mapping in self.resolved_mappings_: + for mapping in self._resolved_mappings: name = mapping.smooth_name - self.scales_[name] = torch.cat(self.scales_[name], dim=0) + self._scales[name] = torch.cat(self._scales[name], dim=0) torch.cuda.empty_cache() @@ -318,12 +314,11 @@ def _apply_smoothing(self, model: Module): :param model: model to apply smoothing to """ logger.info("Smoothing activation scales...") - for mapping in tqdm(self.resolved_mappings_): + for mapping in tqdm(self._resolved_mappings): smooth_layer = mapping.smooth_layer balance_layers = mapping.balance_layers - balance_names = mapping.balance_names - activations = self.scales_[mapping.smooth_name] + activations = self._scales[mapping.smooth_name] module2inspect = mapping.parent @@ -370,7 +365,7 @@ def _apply_smoothing(self, model: Module): fp16_output = self._forward_input_with_kwargs( module=module2inspect, inputs=inp, - input_kwargs=self._sanitize_kwargs(self.module_kwargs_, module2inspect), + input_kwargs=self._sanitize_kwargs(self._module_kwargs, module2inspect), ) fp16_output = fp16_output.clip( torch.finfo(fp16_output.dtype).min, torch.finfo(fp16_output.dtype).max @@ -407,15 +402,6 @@ def smooth(module): smooth(layer) smooth(smooth_layer) - if self.apply_clip: - clip_list = self._search_best_clip( - balance_layers=balance_layers, - balance_names=balance_names, - input_feat=inp, - ) - - _apply_clip(model, clip_list) - # clear out allocated smoothing scales torch.cuda.empty_cache() @@ -432,7 +418,7 @@ def _compute_best_scale( Compute loss and select best scales L(s) = || Q(W * s) (s^-1 * X) - W * X || - Q: weight quantization function | pseudo_quantize_tensor(W * s) + Q: weight quantization function | _pseudo_quantize_tensor(W * s) X: inputs from calib dataset | X W: original weights in FP16 | layer s: per channel scaling factor | s^-1 * X @@ -461,7 +447,7 @@ def _compute_best_scale( else: scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1) scales = scales / (scales.max() * scales.min()).sqrt() - 
scales_view = scales.view(1, -1).to(device) + _scalesview = scales.view(1, -1).to(device) # avoid scaling values that overflow scales[torch.isinf(scales)] = 1 @@ -470,22 +456,22 @@ def _compute_best_scale( # Q(W * s) for fc in linears2scale: with align_module_device(fc): - fc.weight.mul_(scales_view) + fc.weight.mul_(_scalesview) update_offload_parameter( fc, "weight", - pseudo_quantize_tensor( + _pseudo_quantize_tensor( w=fc.weight.data, symmetric=self.symmetric, bit_width=self.bits, group_size=self.group_size, )[0] - / scales_view, + / _scalesview, ) # W * X int_w_output = self._forward_input_with_kwargs( - module=module2inspect, inputs=x, input_kwargs=self.module_kwargs_ + module=module2inspect, inputs=x, input_kwargs=self._module_kwargs ) int_w_output = int_w_output.clip( torch.finfo(int_w_output.dtype).min, torch.finfo(int_w_output.dtype).max @@ -606,7 +592,7 @@ def forward(self, *args, **kwargs): best_device ) - self.module_kwargs_ = layer_kwargs + self._module_kwargs = layer_kwargs def _forward_input_with_kwargs( self, @@ -622,7 +608,7 @@ def _forward_input_with_kwargs( :param input_kwargs: additional arguments to pass to the module :return: the first output tensor from the forward pass """ - kwargs = input_kwargs or self.module_kwargs_ + kwargs = input_kwargs or self._module_kwargs kwargs = self._sanitize_kwargs(kwargs, module) return tensor_forward_with_input_args( module=module, @@ -630,84 +616,6 @@ def _forward_input_with_kwargs( input_kwargs=kwargs, )[0] - @torch.no_grad() - def _search_best_clip(self, balance_layers, balance_names, input_feat): - clip_list = [] - avoid_clipping = ["q_", "k_", "query", "key", "Wqkv"] - - for name, layer in zip(balance_names, balance_layers): - # due to qk bmm, it is hard to clip precisely - if any([_ in name for _ in avoid_clipping]): - continue - - max_val = self._compute_best_clip(layer.weight, input_feat) - clip_list.append((name, max_val)) - - return clip_list - - @torch.no_grad() - def _compute_best_clip( - self, - w: torch.Tensor, - input_feat: torch.Tensor, - n_grid=20, - max_shrink=0.5, - n_sample_token=512, - ): - assert w.dim() == 2 - org_w_shape = w.shape - # w [co, ci] -> [co, 1, n_group, group size] - # input_feat [n_token, ci] -> [1, n_token, n_group, group size] - group_size = self.group_size if self.group_size > 0 else org_w_shape[1] - input_feat = input_feat.view(-1, input_feat.shape[-1]) - input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size) - - # Compute input feature step size (minimum 1) - step_size = max(1, input_feat.shape[1] // n_sample_token) - input_feat = input_feat[:, ::step_size] - - w = w.reshape(org_w_shape[0], 1, -1, group_size) - - oc_batch_size = 256 if org_w_shape[0] % 256 == 0 else 64 # prevent OOM - assert org_w_shape[0] % oc_batch_size == 0 - w_all = w - best_max_val_all = [] - - for i_b in range(org_w_shape[0] // oc_batch_size): - w = w_all[i_b * oc_batch_size : (i_b + 1) * oc_batch_size] - - org_max_val = w.abs().amax(dim=-1, keepdim=True) # co, 1, n_group, 1 - - best_max_val = org_max_val.clone() - min_errs = torch.ones_like(org_max_val) * 1e9 - input_feat = input_feat.to(w.device) - org_out = (input_feat * w).sum(dim=-1) # co, n_token, n_group - - for i_s in range(int(max_shrink * n_grid)): - max_val = org_max_val * (1 - i_s / n_grid) - min_val = -max_val - cur_w = torch.clamp(w, min_val, max_val) - q_w = pseudo_quantize_tensor( - w=cur_w, - symmetric=self.symmetric, - group_size=group_size, - bit_width=self.bits, - )[0] - cur_out = (input_feat * q_w).sum(dim=-1) - - # co, 1, n_group, 
1 - err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape) - del cur_w - del cur_out - cur_best_idx = err < min_errs - min_errs[cur_best_idx] = err[cur_best_idx] - best_max_val[cur_best_idx] = max_val[cur_best_idx] - best_max_val_all.append(best_max_val) - - best_max_val = torch.cat(best_max_val_all, dim=0) - - return best_max_val.squeeze(1) - def _sanitize_kwargs(self, inputs_kwargs, module): """ Remove the arguments that are not supported in the module's @@ -728,22 +636,42 @@ def _sanitize_kwargs(self, inputs_kwargs, module): return sanitized_kwargs -@torch.no_grad() -def _apply_clip(module, clip_list: Tuple[str, torch.Tensor]): - """ - Apply clipping to the weights of the given module - :post-condition: the weights of the module are clipped to the given maximum values - :param module: module to apply clipping to - :param clip_list: list of tuples containing the name of the layer and the maximum - value to clip the weights to - """ - for name, max_val in clip_list: - _, layer = get_layer(target=name, module=module) - assert isinstance(layer, torch.nn.Linear) - with align_module_device(layer): - max_val = max_val.to(layer.weight.device) - org_shape = layer.weight.shape - layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1) - layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val) - layer.weight.data = layer.weight.data.reshape(org_shape) +def _pseudo_quantize_tensor( + w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1 +): + org_w_shape = w.shape + if group_size > 0: + assert org_w_shape[-1] % group_size == 0, f"org_w_shape ({org_w_shape[-1]}) must be a multiple of group_size ({group_size})!" + w = w.reshape(-1, group_size) + assert w.dim() == 2 + assert torch.isnan(w).sum() == 0 + + # zero point quantization + if not symmetric: + max_val = w.amax(dim=1, keepdim=True) + min_val = w.amin(dim=1, keepdim=True) + max_int = 2**bit_width - 1 + min_int = 0 + scales = (max_val - min_val).clamp(min=1e-5) / max_int + zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int) + w = ( + torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros + ) * scales + zeros = (zeros - 2**(bit_width-1)).view(org_w_shape[0], -1) + else: + max_val = w.abs().amax(dim=1, keepdim=True) + max_val = max_val.clamp(min=1e-5) + max_int = 2 ** (bit_width - 1) - 1 + min_int = -(2 ** (bit_width - 1)) + scales = max_val / max_int + zeros = None + w = torch.clamp(torch.round(w / scales), min_int, max_int) * scales + + assert torch.isnan(scales).sum() == 0 + assert torch.isnan(w).sum() == 0 + + scales = scales.view(org_w_shape[0], -1) + w = w.reshape(org_w_shape) + + return w, scales, zeros diff --git a/src/llmcompressor/observers/__init__.py b/src/llmcompressor/observers/__init__.py index e16d9d93b..4c3ee5a88 100644 --- a/src/llmcompressor/observers/__init__.py +++ b/src/llmcompressor/observers/__init__.py @@ -5,4 +5,3 @@ from .base import * from .min_max import * from .mse import * -from .rtn import * diff --git a/src/llmcompressor/observers/rtn.py b/src/llmcompressor/observers/rtn.py deleted file mode 100644 index 889b03318..000000000 --- a/src/llmcompressor/observers/rtn.py +++ /dev/null @@ -1,58 +0,0 @@ -from typing import Any, Optional, Tuple - -import torch -from compressed_tensors.quantization.quant_args import QuantizationArgs -from compressed_tensors.quantization.utils import calculate_qparams -from compressed_tensors.utils import deprecated - -from llmcompressor.observers.base import Observer -from 
llmcompressor.pytorch.utils import pseudo_quantize_tensor - -__all__ = ["RoundToNearestObserver"] - - -@Observer.register("rtn") -class RoundToNearestObserver(Observer): - """ - Implements a quantization observer that calculates scale and zero point based on the - minimum and maximum values of the tensor being observed. If averaging_constant is - specified, then the scales are updated using a moving average - """ - - def calculate_qparams( - self, - observed: torch.Tensor, - reduce_dims: Optional[Tuple[int]] = None, - tensor_id: Optional[Any] = None, - ) -> Tuple[torch.FloatTensor, torch.IntTensor]: - """ - Updates the observed min and max using a moving average smoothed by the - averaging_constant. Set the averaging_constant to 1.0 to disable averaging. - - :param observed: observed tensor to calculate quantization parameters for - :param reduce_dims: optional tuple of dimensions to reduce along, - returned scale and zero point will be shaped (1,) along the - reduced dimensions - :param tensor_id: Optional id if different ranges of observed tensors are - passed, useful for sharding tensors by group_size - :return: tuple of scale and zero point derived from the observed tensor - """ - - _, scales, zp = pseudo_quantize_tensor( - observed, - symmetric=self.quantization_args.symmetric, - bit_width=self.quantization_args.num_bits, - group_size=self.quantization_args.group_size or -1, - ) - return (scales, zp) - - def get_qparams_along_dim( - self, observed: torch.Tensor, dim: int, tensor_id: Optional[Any] = None - ): - """ - Calculate quantization parameters along the specified dimension - """ - reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim) - return self.calculate_qparams( - observed, reduce_dims=reduce_dims, tensor_id=tensor_id - ) diff --git a/src/llmcompressor/pytorch/utils/helpers.py b/src/llmcompressor/pytorch/utils/helpers.py index 3a42dcb8a..feeb5ed1c 100644 --- a/src/llmcompressor/pytorch/utils/helpers.py +++ b/src/llmcompressor/pytorch/utils/helpers.py @@ -87,8 +87,6 @@ "detach", "adjust_quantization_for_onnx_export", "get_dependency_order", - "pseudo_quantize_tensor", - "pseudo_dequantize_linear", "tensor_forward_with_input_args", "sanitize_kwargs_for_module", ] @@ -1238,62 +1236,3 @@ def swap_modules( return cur - -def pseudo_quantize_tensor( - w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1 -): - org_w_shape = w.shape - if group_size > 0: - assert org_w_shape[-1] % group_size == 0, f"org_w_shape ({org_w_shape[-1]}) must be a multiple of group_size ({group_size})!" 
- w = w.reshape(-1, group_size) - assert w.dim() == 2 - assert torch.isnan(w).sum() == 0 - - # zero point quantization - if not symmetric: - max_val = w.amax(dim=1, keepdim=True) - min_val = w.amin(dim=1, keepdim=True) - max_int = 2**bit_width - 1 - min_int = 0 - scales = (max_val - min_val).clamp(min=1e-5) / max_int - zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int) - w = ( - torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros - ) * scales - zeros = (zeros - 2**(bit_width-1)).view(org_w_shape[0], -1) - else: - max_val = w.abs().amax(dim=1, keepdim=True) - max_val = max_val.clamp(min=1e-5) - max_int = 2 ** (bit_width - 1) - 1 - min_int = -(2 ** (bit_width - 1)) - scales = max_val / max_int - zeros = None - w = torch.clamp(torch.round(w / scales), min_int, max_int) * scales - - assert torch.isnan(scales).sum() == 0 - assert torch.isnan(w).sum() == 0 - - scales = scales.view(org_w_shape[0], -1) - w = w.reshape(org_w_shape) - - return w, scales, zeros - - -def pseudo_dequantize_linear( - w: torch.Tensor, - scales: torch.Tensor, - zeros: Optional[torch.Tensor] = None, - symmetric: bool = False, -): - # get repeated count - repeat_count = w.weight.data.shape[-1] // scales.shape[-1] - scales = scales.repeat(1, repeat_count).reshape(w.weight.data.shape) - - # dequantize - if not symmetric: - zeros = zeros.repeat(1, repeat_count).reshape(w.weight.data.shape) - w = (w.weight.data - zeros) * scales - else: - w = w.weight.data * scales - - return w diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py index 918238718..3fff33e23 100644 --- a/tests/llmcompressor/modifiers/awq/test_base.py +++ b/tests/llmcompressor/modifiers/awq/test_base.py @@ -24,5 +24,5 @@ def test_awq_is_registered(self): self.assertIsInstance( modifier, AWQModifier, - "PyTorch AWQModifier not registered", + "AWQModifier not registered", ) From 4488a8c583409849028d7fd1bc8a3427a0bbf156 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Mon, 10 Mar 2025 19:39:30 +0000 Subject: [PATCH 17/21] rename smoothquant private vars Signed-off-by: Brian Dellabetta --- .../modifiers/smoothquant/base.py | 36 +++++++++---------- .../logarithmic_equalization/test_pytorch.py | 6 ++-- .../modifiers/smoothquant/test_pytorch.py | 6 ++-- 3 files changed, 24 insertions(+), 24 deletions(-) diff --git a/src/llmcompressor/modifiers/smoothquant/base.py b/src/llmcompressor/modifiers/smoothquant/base.py index 1b1e0aee6..037fe1219 100644 --- a/src/llmcompressor/modifiers/smoothquant/base.py +++ b/src/llmcompressor/modifiers/smoothquant/base.py @@ -109,8 +109,8 @@ class SmoothQuantModifier(Modifier): num_calibration_steps: Optional[int] = None calibration_function: Optional[Callable] = None - resolved_mappings_: Optional[List[SmoothQuantMapping]] = None - scales_: Optional[Dict] = None + _resolved_mappings: Optional[List[SmoothQuantMapping]] = None + _scales: Optional[Dict] = None def on_initialize(self, state: State, **kwargs) -> bool: """ @@ -132,8 +132,8 @@ def on_initialize(self, state: State, **kwargs) -> bool: self.ignore = [] if not self.ignore else self.ignore self.mappings = self._infer_mappings_from_model(state.model) - self.resolved_mappings_ = self._resolve_mappings(state.model) - self.scales_ = {} + self._resolved_mappings = self._resolve_mappings(state.model) + self._scales = {} calibration_dataloader = state.data.calib @@ -150,10 +150,10 @@ def on_finalize(self, state: State, **kwargs) -> bool: :param state: unused :return: True """ - if self.scales_ is 
not None: - self.scales_.clear() - if self.resolved_mappings_ is not None: - self.resolved_mappings_.clear() + if self._scales is not None: + self._scales.clear() + if self._resolved_mappings is not None: + self._resolved_mappings.clear() return True @@ -219,21 +219,21 @@ def hook_fn(module, inp, out): latest_mins = torch.min(out, dim=0)[0] latest_maxes = torch.max(out, dim=0)[0] - if layer_name in self.scales_: - self.scales_[layer_name].min_channel_vals = torch.minimum( - self.scales_[layer_name].min_channel_vals, latest_mins + if layer_name in self._scales: + self._scales[layer_name].min_channel_vals = torch.minimum( + self._scales[layer_name].min_channel_vals, latest_mins ) - self.scales_[layer_name].max_channel_vals = torch.maximum( - self.scales_[layer_name].max_channel_vals, latest_maxes + self._scales[layer_name].max_channel_vals = torch.maximum( + self._scales[layer_name].max_channel_vals, latest_maxes ) else: - self.scales_[layer_name] = SmoothQuantScale( + self._scales[layer_name] = SmoothQuantScale( min_channel_vals=latest_mins, max_channel_vals=latest_maxes ) return hook_fn - for mapping in self.resolved_mappings_: + for mapping in self._resolved_mappings: name = mapping.smooth_name layer = mapping.smooth_layer self.register_hook(layer, create_hook_fn(name), "forward") @@ -278,10 +278,10 @@ def _apply_smoothing(self, model: Module): This modifies the weights of the model in-place. """ logger.info("Smoothing activation scales...") - for mapping in self.resolved_mappings_: + for mapping in self._resolved_mappings: activation_scales = ( # get dynamic range for each activation channel - self.scales_[mapping.smooth_name].max_channel_vals - - self.scales_[mapping.smooth_name].min_channel_vals + self._scales[mapping.smooth_name].max_channel_vals + - self._scales[mapping.smooth_name].min_channel_vals ) smooth_layer = mapping.smooth_layer balance_layers = mapping.balance_layers diff --git a/tests/llmcompressor/pytorch/modifiers/logarithmic_equalization/test_pytorch.py b/tests/llmcompressor/pytorch/modifiers/logarithmic_equalization/test_pytorch.py index d485c0637..e84f66e83 100644 --- a/tests/llmcompressor/pytorch/modifiers/logarithmic_equalization/test_pytorch.py +++ b/tests/llmcompressor/pytorch/modifiers/logarithmic_equalization/test_pytorch.py @@ -21,11 +21,11 @@ def test_successful_map(self): modifier = LogarithmicEqualizationModifier(mappings=mappings) modifier.ignore = [] - modifier.resolved_mappings_ = modifier._resolve_mappings(self.state.model) + modifier._resolved_mappings = modifier._resolve_mappings(self.state.model) - self.assertEqual(len(modifier.resolved_mappings_), len(mappings)) + self.assertEqual(len(modifier._resolved_mappings), len(mappings)) - mapping = modifier.resolved_mappings_[0] + mapping = modifier._resolved_mappings[0] self.assertEqual(mapping.smooth_name, mappings[0][1]) self.assertIsInstance(mapping.smooth_layer, Linear) self.assertIsInstance(mapping.balance_layers[0], Linear) diff --git a/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py b/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py index 7977c4546..cbb60f030 100644 --- a/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py +++ b/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py @@ -19,11 +19,11 @@ def test_successful_map(self): modifier = SmoothQuantModifier(mappings=mappings) modifier.ignore = [] - modifier.resolved_mappings_ = modifier._resolve_mappings(self.state.model) + modifier._resolved_mappings = 
modifier._resolve_mappings(self.state.model) - self.assertEqual(len(modifier.resolved_mappings_), len(mappings)) + self.assertEqual(len(modifier._resolved_mappings), len(mappings)) - mapping = modifier.resolved_mappings_[0] + mapping = modifier._resolved_mappings[0] self.assertEqual(mapping.smooth_name, mappings[0][1]) self.assertIsInstance(mapping.smooth_layer, Linear) self.assertIsInstance(mapping.balance_layers[0], Linear) From b46429023943d68ef9818f167df63c661edd00d6 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Mon, 10 Mar 2025 21:56:38 +0000 Subject: [PATCH 18/21] address gh comment on updating offloaded parameter Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 16 +++++++++------- src/llmcompressor/pytorch/utils/helpers.py | 1 - 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index bdf3f8628..e5e8f1006 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -12,9 +12,7 @@ from llmcompressor.core import State from llmcompressor.modifiers import Modifier from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward -from llmcompressor.pytorch.utils import ( - tensor_forward_with_input_args, -) +from llmcompressor.pytorch.utils import tensor_forward_with_input_args from llmcompressor.utils.fsdp.helpers import get_fsdp_parent from llmcompressor.utils.helpers import calibration_forward_context from llmcompressor.utils.pytorch.module import ( @@ -189,7 +187,7 @@ def on_finalize(self, state: State, **kwargs) -> bool: def _set_resolved_mappings(self, model: Module) -> None: """ Transforms the list of activations to smooth and their corresponding weights - into ResolvedMapping objects, resolving regular expressions. + into ResolvedMapping objects, resolving regular expressions. Result is stored in _resolved_mappings. For each activation in the mapping list, we find the corresponding weight to @@ -386,12 +384,15 @@ def smooth(module): elif module == smooth_layer: if module.weight.ndim == 1: module.weight.div_(scales.to(module.weight.device)) + update_offload_parameter(module, "weight") else: module.weight.div_( scales.view(-1, 1).to(module.weight.device) ) + update_offload_parameter(module, "weight") if hasattr(module, "bias") and module.bias is not None: module.bias.div_(scales.to(module.bias.device)) + update_offload_parameter(module, "bias") parent = get_fsdp_parent(mapping.smooth_name, model) if parent is not None: @@ -636,13 +637,14 @@ def _sanitize_kwargs(self, inputs_kwargs, module): return sanitized_kwargs - def _pseudo_quantize_tensor( w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1 ): org_w_shape = w.shape if group_size > 0: - assert org_w_shape[-1] % group_size == 0, f"org_w_shape ({org_w_shape[-1]}) must be a multiple of group_size ({group_size})!" + assert ( + org_w_shape[-1] % group_size == 0 + ), f"org_w_shape ({org_w_shape[-1]}) must be a multiple of group_size ({group_size})!" 
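    # (Editor's note on the smooth() changes above, illustrative only: for
    #  modules offloaded via accelerate hooks, in-place edits made inside
    #  align_module_device() may only touch the temporarily materialized copy,
    #  so update_offload_parameter() is called afterwards to push the modified
    #  tensor back to the offloaded store. The assumed pattern, using the
    #  compressed-tensors helpers imported at the top of this module:
    #
    #      with align_module_device(module):
    #          module.weight.div_(scales.view(-1, 1).to(module.weight.device))
    #          update_offload_parameter(module, "weight")
    #  )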
w = w.reshape(-1, group_size) assert w.dim() == 2 assert torch.isnan(w).sum() == 0 @@ -658,7 +660,7 @@ def _pseudo_quantize_tensor( w = ( torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros ) * scales - zeros = (zeros - 2**(bit_width-1)).view(org_w_shape[0], -1) + zeros = (zeros - 2 ** (bit_width - 1)).view(org_w_shape[0], -1) else: max_val = w.abs().amax(dim=1, keepdim=True) max_val = max_val.clamp(min=1e-5) diff --git a/src/llmcompressor/pytorch/utils/helpers.py b/src/llmcompressor/pytorch/utils/helpers.py index feeb5ed1c..305961e3a 100644 --- a/src/llmcompressor/pytorch/utils/helpers.py +++ b/src/llmcompressor/pytorch/utils/helpers.py @@ -1235,4 +1235,3 @@ def swap_modules( parent.__setattr__(sections[-1], submodule_to_replace) return cur - From e0cb4d4c4e24d6445d539190cf08d3d5aa2a1253 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Mon, 10 Mar 2025 22:02:28 +0000 Subject: [PATCH 19/21] drop pile dataset, lint error fixes Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 7 +++-- .../transformers/finetune/data/__init__.py | 1 - .../transformers/finetune/data/pile.py | 27 ------------------- .../finetune/data/test_registry.py | 1 - 4 files changed, 3 insertions(+), 33 deletions(-) delete mode 100644 src/llmcompressor/transformers/finetune/data/pile.py diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index e5e8f1006..fba5f699a 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -1,11 +1,11 @@ import inspect from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Union import torch from compressed_tensors.utils import align_module_device, update_offload_parameter from loguru import logger -from pydantic import ConfigDict, Field +from pydantic import ConfigDict from torch.nn import Module from tqdm import tqdm @@ -16,7 +16,6 @@ from llmcompressor.utils.fsdp.helpers import get_fsdp_parent from llmcompressor.utils.helpers import calibration_forward_context from llmcompressor.utils.pytorch.module import ( - get_layer, get_layers, get_matching_layer, get_parent_by_name, @@ -146,7 +145,7 @@ class AWQModifier(Modifier): duo_scaling: bool = True _resolved_mappings: List[ResolvedMapping] = [] - _scales: Dict[str, torch.Tensor | List[torch.Tensor]] = {} + _scales: Dict[str, Union[torch.Tensor, List[torch.Tensor]]] = {} _module_kwargs: Dict = {} def on_initialize(self, state: State, **kwargs) -> bool: diff --git a/src/llmcompressor/transformers/finetune/data/__init__.py b/src/llmcompressor/transformers/finetune/data/__init__.py index 186a076bd..a53caed1b 100644 --- a/src/llmcompressor/transformers/finetune/data/__init__.py +++ b/src/llmcompressor/transformers/finetune/data/__init__.py @@ -8,7 +8,6 @@ from .flickr_30k import Flickr30K from .gsm8k import GSM8KDataset from .open_platypus import OpenPlatypusDataset -from .pile import PileValDataset from .ptb import PtbDataset from .ultrachat_200k import UltraChatDataset from .wikitext import WikiTextDataset diff --git a/src/llmcompressor/transformers/finetune/data/pile.py b/src/llmcompressor/transformers/finetune/data/pile.py deleted file mode 100644 index f420ba2a5..000000000 --- a/src/llmcompressor/transformers/finetune/data/pile.py +++ /dev/null @@ -1,27 +0,0 @@ -from copy import deepcopy -from typing import TYPE_CHECKING - -from llmcompressor.transformers.finetune.data import TextGenerationDataset -from 
llmcompressor.typing import Processor - -if TYPE_CHECKING: - from llmcompressor.args import DatasetArguments - - -@TextGenerationDataset.register(name="mit-han-lab/pile-val-backup", alias="pile_val") -class PileValDataset(TextGenerationDataset): - """ - Child text generation class for "The Pile" dataset - :param data_args: configuration settings for dataset loading - :param split: split from dataset to load, for instance `test` or `train[:5%]` - :param tokenizer: tokenizer to use on dataset - """ - - def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): - data_args = deepcopy(data_args) - data_args.text_column = "text" - data_args.dataset = "mit-han-lab/pile-val-backup" - super().__init__(data_args=data_args, split=split, processor=processor) - - def dataset_template(self, sample): - return {"text": sample["text"].strip()} diff --git a/tests/llmcompressor/transformers/finetune/data/test_registry.py b/tests/llmcompressor/transformers/finetune/data/test_registry.py index ce872fba9..29895b4a4 100644 --- a/tests/llmcompressor/transformers/finetune/data/test_registry.py +++ b/tests/llmcompressor/transformers/finetune/data/test_registry.py @@ -4,7 +4,6 @@ from llmcompressor.transformers.finetune.data import ( C4Dataset, OpenPlatypusDataset, - PileEvalDataset, TextGenerationDataset, WikiTextDataset, ) From 06a12bfe0dd7c870ed55f5623e5674bebd3e7588 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Mon, 10 Mar 2025 22:18:30 +0000 Subject: [PATCH 20/21] style fixes Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index fba5f699a..0b1ac3a84 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -45,7 +45,8 @@ class AWQMapping: "re:.*input_layernorm", ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], ), - # TODO this should only be added if v_proj/o_proj shapes match up, should we check during validation and skip if this is not the case? + # TODO this should only be added if v_proj/o_proj shapes match up + # should we check during validation and skip if this is not the case? AWQMapping("re:.*v_proj", ["re:.*o_proj"]), AWQMapping( "re:.*post_attention_layernorm", @@ -641,9 +642,10 @@ def _pseudo_quantize_tensor( ): org_w_shape = w.shape if group_size > 0: - assert ( - org_w_shape[-1] % group_size == 0 - ), f"org_w_shape ({org_w_shape[-1]}) must be a multiple of group_size ({group_size})!" + assert org_w_shape[-1] % group_size == 0, ( + f"org_w_shape ({org_w_shape[-1]}) must be a multiple " + + f"of group_size ({group_size})!" 
+ ) w = w.reshape(-1, group_size) assert w.dim() == 2 assert torch.isnan(w).sum() == 0 From 38d15482801603a2c71745d21aff7bab1b5b34fb Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 11 Mar 2025 23:14:07 +0000 Subject: [PATCH 21/21] fix update_offload_parameter Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 0b1ac3a84..18d2c4541 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -383,16 +383,25 @@ def smooth(module): module.weight.mul_(scales.view(1, -1).to(module.weight.device)) elif module == smooth_layer: if module.weight.ndim == 1: - module.weight.div_(scales.to(module.weight.device)) - update_offload_parameter(module, "weight") + update_offload_parameter( + module, + "weight", + module.weight.div(scales.to(module.weight.device)), + ) else: - module.weight.div_( - scales.view(-1, 1).to(module.weight.device) + update_offload_parameter( + module, + "weight", + module.weight.div( + scales.view(-1, 1).to(module.weight.device) + ), ) - update_offload_parameter(module, "weight") if hasattr(module, "bias") and module.bias is not None: - module.bias.div_(scales.to(module.bias.device)) - update_offload_parameter(module, "bias") + update_offload_parameter( + module, + "bias", + module.bias.div(scales.to(module.bias.device)), + ) parent = get_fsdp_parent(mapping.smooth_name, model) if parent is not None: