diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py
index 796ad65cf..bdf3f8628 100644
--- a/src/llmcompressor/modifiers/awq/base.py
+++ b/src/llmcompressor/modifiers/awq/base.py
@@ -13,7 +13,6 @@
 from llmcompressor.modifiers import Modifier
 from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward
 from llmcompressor.pytorch.utils import (
-    pseudo_quantize_tensor,
     tensor_forward_with_input_args,
 )
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
@@ -49,7 +48,7 @@ class AWQMapping:
         "re:.*input_layernorm",
         ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"],
     ),
-    # TODO this generally results in higher perplexity for llama 2 7B on wikitext
+    # TODO this should only be added if v_proj/o_proj shapes match up; should we check during validation and skip it if they do not?
     AWQMapping("re:.*v_proj", ["re:.*o_proj"]),
     AWQMapping(
         "re:.*post_attention_layernorm",
@@ -127,8 +126,6 @@ class AWQModifier(Modifier):
         smoothing (the second entry of the mappings list).
     :param num_calibration_steps: number of samples to use for calibration, or None
         to use the whole dataset
-    :param calibration_function: optional function to use for the forward pass, or None
-        to use the default tensor_module_forward
     :param group_size: number of weights to group together for scaling
     :param max_chunk_memory: maximum memory to use for each chunk of input activations
     :param bits: number of bits to quantize the weights to
@@ -144,17 +141,15 @@ class AWQModifier(Modifier):
     mappings: List[AWQMapping] = DEFAULT_AWQ_MAPPINGS
     ignore: List[str] = []
     num_calibration_steps: Optional[int] = None
-    calibration_function: Optional[Callable] = None
     group_size: int = 128
     max_chunk_memory: int = 1024 * 1024 * 1024
     bits: int = 4
     symmetric: bool = False
     duo_scaling: bool = True
-    apply_clip: bool = True

-    resolved_mappings_: List[ResolvedMapping] = []
-    scales_: Dict[str, torch.Tensor | List[torch.Tensor]] = {}
-    module_kwargs_: Dict = {}
+    _resolved_mappings: List[ResolvedMapping] = []
+    _scales: Dict[str, torch.Tensor | List[torch.Tensor]] = {}
+    _module_kwargs: Dict = {}

     def on_initialize(self, state: State, **kwargs) -> bool:
         """
@@ -164,7 +159,7 @@ def on_initialize(self, state: State, **kwargs) -> bool:
         :return: True on a successful run, False otherwise
         """

-        self.resolved_mappings_ = self._get_resolved_mappings(state.model)
+        self._set_resolved_mappings(state.model)

         calibration_dataloader = state.data.calib

@@ -184,17 +179,18 @@ def on_finalize(self, state: State, **kwargs) -> bool:
         :param state: unused
         :return: True
         """
-        if self.scales_ is not None:
-            self.scales_.clear()
-        if self.resolved_mappings_ is not None:
-            self.resolved_mappings_.clear()
+        if self._scales is not None:
+            self._scales.clear()
+        if self._resolved_mappings is not None:
+            self._resolved_mappings.clear()

         return True

-    def _get_resolved_mappings(self, model: Module) -> List[ResolvedMapping]:
+    def _set_resolved_mappings(self, model: Module) -> None:
         """
         Transforms the list of activations to smooth and their corresponding weights
-        into ResolvedMapping objects, resolving regular expressions.
+        into ResolvedMapping objects, resolving regular expressions.
+        Result is stored in _resolved_mappings.

         For each activation in the mapping list, we find the corresponding weight to
         balance by searching for the longest substring. For instance, if our balance
@@ -239,7 +235,8 @@ def _get_resolved_mappings(self, model: Module) -> List[ResolvedMapping]:
                         parent_name=parent_name,
                     )
                 )
-        return resolved_mappings
+        self._resolved_mappings = resolved_mappings
+        return

     def _setup_scale_hooks(self):
         """
@@ -251,14 +248,14 @@ def create_hook_fn(layer_name):
             def hook_fn(module, inp, out):
                 inp = inp[0].cpu().detach()

-                if layer_name in self.scales_:
-                    self.scales_[layer_name].append(inp)
+                if layer_name in self._scales:
+                    self._scales[layer_name].append(inp)
                 else:
-                    self.scales_[layer_name] = [inp]
+                    self._scales[layer_name] = [inp]

             return hook_fn

-        for mapping in self.resolved_mappings_:
+        for mapping in self._resolved_mappings:
             name = mapping.smooth_name
             # storing inps to first balance layer
             # is enough, as other balance layers
@@ -288,7 +285,6 @@ def _calibrate(self, model: Module, calibration_dataloader: List):
             model,
             calibration_dataloader,
             self.num_calibration_steps,
-            self.calibration_function,
         )

         # remove the hooks now that we are done calibrating
@@ -299,12 +295,12 @@ def _concat_collected_activations(self):
         Concatenate the collected activation values from each forward pass into
         a single tensor for each layer

-        :postcondition: each layer in self.scales_ will have a single tensor containing
+        :postcondition: each layer in self._scales will have a single tensor containing
             all the activation values seen during calibration
         """
-        for mapping in self.resolved_mappings_:
+        for mapping in self._resolved_mappings:
             name = mapping.smooth_name
-            self.scales_[name] = torch.cat(self.scales_[name], dim=0)
+            self._scales[name] = torch.cat(self._scales[name], dim=0)

         torch.cuda.empty_cache()

@@ -318,12 +314,11 @@ def _apply_smoothing(self, model: Module):
         :param model: model to apply smoothing to
         """
         logger.info("Smoothing activation scales...")
-        for mapping in tqdm(self.resolved_mappings_):
+        for mapping in tqdm(self._resolved_mappings):
             smooth_layer = mapping.smooth_layer
             balance_layers = mapping.balance_layers
-            balance_names = mapping.balance_names

-            activations = self.scales_[mapping.smooth_name]
+            activations = self._scales[mapping.smooth_name]

             module2inspect = mapping.parent

@@ -370,7 +365,7 @@ def _apply_smoothing(self, model: Module):
             fp16_output = self._forward_input_with_kwargs(
                 module=module2inspect,
                 inputs=inp,
-                input_kwargs=self._sanitize_kwargs(self.module_kwargs_, module2inspect),
+                input_kwargs=self._sanitize_kwargs(self._module_kwargs, module2inspect),
             )
             fp16_output = fp16_output.clip(
                 torch.finfo(fp16_output.dtype).min, torch.finfo(fp16_output.dtype).max
@@ -407,15 +402,6 @@ def smooth(module):
                     smooth(layer)
             smooth(smooth_layer)

-            if self.apply_clip:
-                clip_list = self._search_best_clip(
-                    balance_layers=balance_layers,
-                    balance_names=balance_names,
-                    input_feat=inp,
-                )
-
-                _apply_clip(model, clip_list)
-
             # clear out allocated smoothing scales
             torch.cuda.empty_cache()

@@ -432,7 +418,7 @@ def _compute_best_scale(
         Compute loss and select best scales

         L(s) = || Q(W * s) (s^-1 * X) - W * X ||
-        Q: weight quantization function | pseudo_quantize_tensor(W * s)
+        Q: weight quantization function | _pseudo_quantize_tensor(W * s)
         X: inputs from calib dataset    | X
         W: original weights in FP16     | layer
         s: per channel scaling factor   | s^-1 * X
@@ -470,22 +456,22 @@
             # Q(W * s)
             for fc in linears2scale:
                 with align_module_device(fc):
                     fc.weight.mul_(scales_view)
                     update_offload_parameter(
                         fc,
                         "weight",
-                        pseudo_quantize_tensor(
+                        _pseudo_quantize_tensor(
                             w=fc.weight.data,
                             symmetric=self.symmetric,
                             bit_width=self.bits,
                             group_size=self.group_size,
                         )[0]
                         / scales_view,
                     )

             # W * X
             int_w_output = self._forward_input_with_kwargs(
-                module=module2inspect, inputs=x, input_kwargs=self.module_kwargs_
+                module=module2inspect, inputs=x, input_kwargs=self._module_kwargs
             )
             int_w_output = int_w_output.clip(
                 torch.finfo(int_w_output.dtype).min, torch.finfo(int_w_output.dtype).max
@@ -606,7 +592,7 @@ def forward(self, *args, **kwargs):
             best_device
         )

-        self.module_kwargs_ = layer_kwargs
+        self._module_kwargs = layer_kwargs

     def _forward_input_with_kwargs(
         self,
@@ -622,7 +608,7 @@ def _forward_input_with_kwargs(
         :param input_kwargs: additional arguments to pass to the module
         :return: the first output tensor from the forward pass
         """
-        kwargs = input_kwargs or self.module_kwargs_
+        kwargs = input_kwargs or self._module_kwargs
         kwargs = self._sanitize_kwargs(kwargs, module)
         return tensor_forward_with_input_args(
             module=module,
@@ -630,84 +616,6 @@ def _forward_input_with_kwargs(
             input_kwargs=kwargs,
         )[0]

-    @torch.no_grad()
-    def _search_best_clip(self, balance_layers, balance_names, input_feat):
-        clip_list = []
-        avoid_clipping = ["q_", "k_", "query", "key", "Wqkv"]
-
-        for name, layer in zip(balance_names, balance_layers):
-            # due to qk bmm, it is hard to clip precisely
-            if any([_ in name for _ in avoid_clipping]):
-                continue
-
-            max_val = self._compute_best_clip(layer.weight, input_feat)
-            clip_list.append((name, max_val))
-
-        return clip_list
-
-    @torch.no_grad()
-    def _compute_best_clip(
-        self,
-        w: torch.Tensor,
-        input_feat: torch.Tensor,
-        n_grid=20,
-        max_shrink=0.5,
-        n_sample_token=512,
-    ):
-        assert w.dim() == 2
-        org_w_shape = w.shape
-        # w           [co, ci]      -> [co, 1, n_group, group size]
-        # input_feat  [n_token, ci] -> [1, n_token, n_group, group size]
-        group_size = self.group_size if self.group_size > 0 else org_w_shape[1]
-        input_feat = input_feat.view(-1, input_feat.shape[-1])
-        input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
-
-        # Compute input feature step size (minimum 1)
-        step_size = max(1, input_feat.shape[1] // n_sample_token)
-        input_feat = input_feat[:, ::step_size]
-
-        w = w.reshape(org_w_shape[0], 1, -1, group_size)
-
-        oc_batch_size = 256 if org_w_shape[0] % 256 == 0 else 64  # prevent OOM
-        assert org_w_shape[0] % oc_batch_size == 0
-        w_all = w
-        best_max_val_all = []
-
-        for i_b in range(org_w_shape[0] // oc_batch_size):
-            w = w_all[i_b * oc_batch_size : (i_b + 1) * oc_batch_size]
-
-            org_max_val = w.abs().amax(dim=-1, keepdim=True)  # co, 1, n_group, 1
-
-            best_max_val = org_max_val.clone()
-            min_errs = torch.ones_like(org_max_val) * 1e9
-            input_feat = input_feat.to(w.device)
-            org_out = (input_feat * w).sum(dim=-1)  # co, n_token, n_group
-
-            for i_s in range(int(max_shrink * n_grid)):
-                max_val = org_max_val * (1 - i_s / n_grid)
-                min_val = -max_val
-                cur_w = torch.clamp(w, min_val, max_val)
-                q_w = pseudo_quantize_tensor(
-                    w=cur_w,
-                    symmetric=self.symmetric,
-                    group_size=group_size,
-                    bit_width=self.bits,
-                )[0]
-                cur_out = (input_feat * q_w).sum(dim=-1)
-
-                # co, 1, n_group, 1
-                err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
-                del cur_w
-                del cur_out
-                cur_best_idx = err < min_errs
-                min_errs[cur_best_idx] = err[cur_best_idx]
-                best_max_val[cur_best_idx] = max_val[cur_best_idx]
-            best_max_val_all.append(best_max_val)
-
-        best_max_val = torch.cat(best_max_val_all, dim=0)
-
-        return best_max_val.squeeze(1)
-
     def _sanitize_kwargs(self, inputs_kwargs, module):
         """
         Remove the arguments that are not supported in the module's
@@ -728,22 +636,42 @@ def _sanitize_kwargs(self, inputs_kwargs, module):
         return sanitized_kwargs


-@torch.no_grad()
-def _apply_clip(module, clip_list: Tuple[str, torch.Tensor]):
-    """
-    Apply clipping to the weights of the given module
-
-    :post-condition: the weights of the module are clipped to the given maximum
-        values
-    :param module: module to apply clipping to
-    :param clip_list: list of tuples containing the name of the layer and the maximum
-        value to clip the weights to
-    """
-    for name, max_val in clip_list:
-        _, layer = get_layer(target=name, module=module)
-        assert isinstance(layer, torch.nn.Linear)
-        with align_module_device(layer):
-            max_val = max_val.to(layer.weight.device)
-            org_shape = layer.weight.shape
-            layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
-            layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val)
-            layer.weight.data = layer.weight.data.reshape(org_shape)
+def _pseudo_quantize_tensor(
+    w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1
+):
+    org_w_shape = w.shape
+    if group_size > 0:
+        assert org_w_shape[-1] % group_size == 0, (
+            f"org_w_shape ({org_w_shape[-1]}) must be a multiple of group_size ({group_size})!"
+        )
+        w = w.reshape(-1, group_size)
+    assert w.dim() == 2
+    assert torch.isnan(w).sum() == 0
+
+    # zero point quantization
+    if not symmetric:
+        max_val = w.amax(dim=1, keepdim=True)
+        min_val = w.amin(dim=1, keepdim=True)
+        max_int = 2**bit_width - 1
+        min_int = 0
+        scales = (max_val - min_val).clamp(min=1e-5) / max_int
+        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
+        w = (
+            torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros
+        ) * scales
+        zeros = (zeros - 2 ** (bit_width - 1)).view(org_w_shape[0], -1)
+    else:
+        max_val = w.abs().amax(dim=1, keepdim=True)
+        max_val = max_val.clamp(min=1e-5)
+        max_int = 2 ** (bit_width - 1) - 1
+        min_int = -(2 ** (bit_width - 1))
+        scales = max_val / max_int
+        zeros = None
+        w = torch.clamp(torch.round(w / scales), min_int, max_int) * scales
+
+    assert torch.isnan(scales).sum() == 0
+    assert torch.isnan(w).sum() == 0
+
+    scales = scales.view(org_w_shape[0], -1)
+    w = w.reshape(org_w_shape)
+
+    return w, scales, zeros
diff --git a/src/llmcompressor/observers/__init__.py b/src/llmcompressor/observers/__init__.py
index e16d9d93b..4c3ee5a88 100644
--- a/src/llmcompressor/observers/__init__.py
+++ b/src/llmcompressor/observers/__init__.py
@@ -5,4 +5,3 @@
 from .base import *
 from .min_max import *
 from .mse import *
-from .rtn import *
diff --git a/src/llmcompressor/observers/rtn.py b/src/llmcompressor/observers/rtn.py
deleted file mode 100644
index 889b03318..000000000
--- a/src/llmcompressor/observers/rtn.py
+++ /dev/null
@@ -1,58 +0,0 @@
-from typing import Any, Optional, Tuple
-
-import torch
-from compressed_tensors.quantization.quant_args import QuantizationArgs
-from compressed_tensors.quantization.utils import calculate_qparams
-from compressed_tensors.utils import deprecated
-
-from llmcompressor.observers.base import Observer
-from llmcompressor.pytorch.utils import pseudo_quantize_tensor
-
-__all__ = ["RoundToNearestObserver"]
-
-
-@Observer.register("rtn")
-class RoundToNearestObserver(Observer):
-    """
-    Implements a quantization observer that calculates scale and zero point based on the
-    minimum and maximum values of the tensor being observed. If averaging_constant is
-    specified, then the scales are updated using a moving average
-    """
-
-    def calculate_qparams(
-        self,
-        observed: torch.Tensor,
-        reduce_dims: Optional[Tuple[int]] = None,
-        tensor_id: Optional[Any] = None,
-    ) -> Tuple[torch.FloatTensor, torch.IntTensor]:
-        """
-        Updates the observed min and max using a moving average smoothed by the
-        averaging_constant. Set the averaging_constant to 1.0 to disable averaging.
-
-        :param observed: observed tensor to calculate quantization parameters for
-        :param reduce_dims: optional tuple of dimensions to reduce along,
-            returned scale and zero point will be shaped (1,) along the
-            reduced dimensions
-        :param tensor_id: Optional id if different ranges of observed tensors are
-            passed, useful for sharding tensors by group_size
-        :return: tuple of scale and zero point derived from the observed tensor
-        """
-
-        _, scales, zp = pseudo_quantize_tensor(
-            observed,
-            symmetric=self.quantization_args.symmetric,
-            bit_width=self.quantization_args.num_bits,
-            group_size=self.quantization_args.group_size or -1,
-        )
-        return (scales, zp)
-
-    def get_qparams_along_dim(
-        self, observed: torch.Tensor, dim: int, tensor_id: Optional[Any] = None
-    ):
-        """
-        Calculate quantization parameters along the specified dimension
-        """
-        reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
-        return self.calculate_qparams(
-            observed, reduce_dims=reduce_dims, tensor_id=tensor_id
-        )
diff --git a/src/llmcompressor/pytorch/utils/helpers.py b/src/llmcompressor/pytorch/utils/helpers.py
index 3a42dcb8a..feeb5ed1c 100644
--- a/src/llmcompressor/pytorch/utils/helpers.py
+++ b/src/llmcompressor/pytorch/utils/helpers.py
@@ -87,8 +87,6 @@
     "detach",
     "adjust_quantization_for_onnx_export",
     "get_dependency_order",
-    "pseudo_quantize_tensor",
-    "pseudo_dequantize_linear",
     "tensor_forward_with_input_args",
     "sanitize_kwargs_for_module",
 ]
@@ -1238,62 +1236,3 @@ def swap_modules(

     return cur

-
-def pseudo_quantize_tensor(
-    w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1
-):
-    org_w_shape = w.shape
-    if group_size > 0:
-        assert org_w_shape[-1] % group_size == 0, (
-            f"org_w_shape ({org_w_shape[-1]}) must be a multiple of group_size ({group_size})!"
-        )
-        w = w.reshape(-1, group_size)
-    assert w.dim() == 2
-    assert torch.isnan(w).sum() == 0
-
-    # zero point quantization
-    if not symmetric:
-        max_val = w.amax(dim=1, keepdim=True)
-        min_val = w.amin(dim=1, keepdim=True)
-        max_int = 2**bit_width - 1
-        min_int = 0
-        scales = (max_val - min_val).clamp(min=1e-5) / max_int
-        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
-        w = (
-            torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros
-        ) * scales
-        zeros = (zeros - 2 ** (bit_width - 1)).view(org_w_shape[0], -1)
-    else:
-        max_val = w.abs().amax(dim=1, keepdim=True)
-        max_val = max_val.clamp(min=1e-5)
-        max_int = 2 ** (bit_width - 1) - 1
-        min_int = -(2 ** (bit_width - 1))
-        scales = max_val / max_int
-        zeros = None
-        w = torch.clamp(torch.round(w / scales), min_int, max_int) * scales
-
-    assert torch.isnan(scales).sum() == 0
-    assert torch.isnan(w).sum() == 0
-
-    scales = scales.view(org_w_shape[0], -1)
-    w = w.reshape(org_w_shape)
-
-    return w, scales, zeros
-
-
-def pseudo_dequantize_linear(
-    w: torch.Tensor,
-    scales: torch.Tensor,
-    zeros: Optional[torch.Tensor] = None,
-    symmetric: bool = False,
-):
-    # get repeated count
-    repeat_count = w.weight.data.shape[-1] // scales.shape[-1]
-    scales = scales.repeat(1, repeat_count).reshape(w.weight.data.shape)
-
-    # dequantize
-    if not symmetric:
-        zeros = zeros.repeat(1, repeat_count).reshape(w.weight.data.shape)
-        w = (w.weight.data - zeros) * scales
-    else:
-        w = w.weight.data * scales
-
-    return w
diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py
index 918238718..3fff33e23 100644
--- a/tests/llmcompressor/modifiers/awq/test_base.py
+++ b/tests/llmcompressor/modifiers/awq/test_base.py
@@ -24,5 +24,5 @@ def test_awq_is_registered(self):
         self.assertIsInstance(
             modifier,
             AWQModifier,
-            "PyTorch AWQModifier not registered",
+            "AWQModifier not registered",
         )
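
A note on the `_compute_best_scale` loss documented above, L(s) = || Q(W * s) (s^-1 * X) - W * X ||: the method walks a small grid of interpolation ratios between activation and weight magnitudes and keeps whichever scales give the lowest MSE against the original fp16 output. A minimal sketch of that search; `x_mean`, `w_mean`, `int_w_forward`, and `n_grid` stand in for internals not shown in this diff:

```python
import torch


def grid_search_best_scales(
    x_mean: torch.Tensor,       # per-channel mean of |X| over the calibration set
    w_mean: torch.Tensor,       # per-channel mean of |W| across the balance layers
    fp16_output: torch.Tensor,  # module output with unquantized weights
    int_w_forward,              # callable: scales -> output with Q(W * s) / s applied
    n_grid: int = 20,
    duo_scaling: bool = True,
) -> torch.Tensor:
    """Try n_grid ratios; keep the scales that minimize L(s)."""
    best_loss, best_scales = float("inf"), None
    for i in range(n_grid):
        ratio = i / n_grid
        # trade activation magnitude against weight magnitude
        if duo_scaling:
            scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp(min=1e-4)
        else:
            scales = x_mean.pow(ratio).clamp(min=1e-4)
        scales = scales / (scales.max() * scales.min()).sqrt()  # normalize around 1.0
        loss = (fp16_output - int_w_forward(scales)).float().pow(2).mean().item()
        if loss < best_loss:
            best_loss, best_scales = loss, scales.clone()
    return best_scales
```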
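
And a round-trip sketch of `_pseudo_quantize_tensor` at its new home, using the `AWQModifier` defaults above (4-bit asymmetric, `group_size=128`); the import path follows this diff, and the tensor shapes are illustrative only:

```python
import torch

from llmcompressor.modifiers.awq.base import _pseudo_quantize_tensor

w = torch.randn(512, 512, dtype=torch.float16)
w_q, scales, zeros = _pseudo_quantize_tensor(
    w, symmetric=False, bit_width=4, group_size=128
)

assert w_q.shape == w.shape               # fake-quant weights keep the original shape
assert scales.shape == (512, 512 // 128)  # one scale per (output row, group)
assert zeros.shape == scales.shape        # zero points only on the asymmetric path

# w_q is w snapped onto a 16-level grid per group and dequantized back,
# so the worst-case error is about half a quantization step
print((w_q - w).abs().max())
```

Grouped asymmetric quantization keeps one (scale, zero point) pair per 128 input channels, which the (512, 4) parameter shapes above reflect.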