From 28f8bca24d542ce6854f9924dc556e2c03eee156 Mon Sep 17 00:00:00 2001
From: Brian Dellabetta
Date: Thu, 20 Feb 2025 17:15:54 +0000
Subject: [PATCH] working with larger num_calibration_samples

Signed-off-by: Brian Dellabetta
---
 src/llmcompressor/modifiers/awq/base.py | 46 ++++++++++++++-----------
 1 file changed, 26 insertions(+), 20 deletions(-)

diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py
index 6675291bc..5e985559b 100644
--- a/src/llmcompressor/modifiers/awq/base.py
+++ b/src/llmcompressor/modifiers/awq/base.py
@@ -2,7 +2,7 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
-from accelerate.utils import align_module_device
+from compressed_tensors.utils import align_module_device, update_offload_parameter
 from loguru import logger
 from pydantic import ConfigDict
 from torch.nn import Module
@@ -158,11 +158,14 @@ def on_initialize(self, state: State, **kwargs) -> bool:
 
         calibration_dataloader = state.data.calib
 
-        self._set_module_kwargs(state.model, calibration_dataloader)
-        self._setup_scale_hooks()
-        self._calibrate(state.model, calibration_dataloader)
-        self._concat_collected_activations()
-        self._apply_smoothing(state.model)
+        # TODO is it ok to wrap the whole model in this context?
+        # I don't think we ever want gradients or to use kv cache
+        with calibration_forward_context(state.model):
+            self._set_module_kwargs(state.model, calibration_dataloader)
+            self._setup_scale_hooks()
+            self._calibrate(state.model, calibration_dataloader)
+            self._concat_collected_activations()
+            self._apply_smoothing(state.model)
 
         return True
 
@@ -272,13 +275,13 @@ def _calibrate(self, model: Module, calibration_dataloader: List):
                 " CompressionSession to run the AWQ modifier"
             )
 
-        with calibration_forward_context(model):
-            run_calibration_forward(
-                model,
-                calibration_dataloader,
-                self.num_calibration_steps,
-                self.calibration_function,
-            )
+        # with calibration_forward_context(model):
+        run_calibration_forward(
+            model,
+            calibration_dataloader,
+            self.num_calibration_steps,
+            self.calibration_function,
+        )
 
         # remove the hooks now that we are done calibrating
         self.remove_hooks()
@@ -356,10 +359,9 @@ def _apply_smoothing(self, model: Module):
             x_mean = (x_sum / num_elements).to(inp.dtype)
 
             # [STEP 3]: Compute output of module
-            with torch.no_grad():
-                fp16_output = self._forward_input_with_kwargs(
-                    module=module2inspect, inputs=inp, input_kwargs=self.module_kwargs_
-                )
+            fp16_output = self._forward_input_with_kwargs(
+                module=module2inspect, inputs=inp, input_kwargs=self.module_kwargs_
+            )
 
             # [STEP 4]: Compute loss
             best_scales = self._compute_best_scale(
@@ -459,14 +461,16 @@ def _compute_best_scale(
             for fc in linears2scale:
                 with align_module_device(fc):
                     fc.weight.mul_(scales_view)
-                    fc.weight.data = (
+                    update_offload_parameter(
+                        fc,
+                        "weight",
                         pseudo_quantize_tensor(
                             w=fc.weight.data,
                             symmetric=self.symmetric,
                             bit_width=self.bits,
                             group_size=self.group_size,
                         )[0]
-                        / scales_view
+                        / scales_view,
                     )
 
             # W * X
@@ -488,7 +492,9 @@ def _compute_best_scale(
             logger.debug(history)
             raise Exception
 
-        assert torch.isnan(best_scales).sum() == 0, best_scales
+        assert (
+            torch.isnan(best_scales).sum() == 0
+        ), f"Nan found in scales: {best_scales}"
 
         return best_scales.detach().cpu()
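
For context on the TODO in on_initialize: the reasoning is that calibration never needs gradients or a KV cache, so wrapping the entire pipeline (module kwargs setup, scale hooks, the calibration forward passes, and smoothing) in a single forward context should be safe. The sketch below shows one way such a context could behave; it is a hypothetical illustration, not llm-compressor's calibration_forward_context, and the use_cache handling assumes a Hugging Face-style model.config.

import contextlib

import torch
from torch.nn import Module


@contextlib.contextmanager
def _no_grad_no_kv_cache(model: Module):
    """Sketch: disable gradients and (if present) the HF KV cache for calibration."""
    prior_cache = getattr(getattr(model, "config", None), "use_cache", None)
    if prior_cache is not None:
        # KV cache only helps incremental decoding; single-pass calibration never reuses it
        model.config.use_cache = False
    was_training = model.training
    model.eval()  # calibration should see inference behavior (e.g. no dropout)
    try:
        with torch.no_grad():  # AWQ only needs activations, never gradients
            yield model
    finally:
        if prior_cache is not None:
            model.config.use_cache = prior_cache
        if was_training:
            model.train()

Used the same way the patch uses calibration_forward_context, i.e. `with _no_grad_no_kv_cache(state.model): ...` around the whole AWQ flow, which also makes the removed torch.no_grad() in _apply_smoothing redundant.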