working with larger num_calibration_samples
Signed-off-by: Brian Dellabetta <[email protected]>
brian-dellabetta committed Feb 20, 2025
1 parent 5cb055c commit 28f8bca
Showing 1 changed file with 26 additions and 20 deletions.
src/llmcompressor/modifiers/awq/base.py (46 changes: 26 additions & 20 deletions)

@@ -2,7 +2,7 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
-from accelerate.utils import align_module_device
+from compressed_tensors.utils import align_module_device, update_offload_parameter
 from loguru import logger
 from pydantic import ConfigDict
 from torch.nn import Module

@@ -158,11 +158,14 @@ def on_initialize(self, state: State, **kwargs) -> bool:
 
         calibration_dataloader = state.data.calib
 
-        self._set_module_kwargs(state.model, calibration_dataloader)
-        self._setup_scale_hooks()
-        self._calibrate(state.model, calibration_dataloader)
-        self._concat_collected_activations()
-        self._apply_smoothing(state.model)
+        # TODO is it ok to wrap the whole model in this context?
+        # I don't think we ever want gradients or to use kv cache
+        with calibration_forward_context(state.model):
+            self._set_module_kwargs(state.model, calibration_dataloader)
+            self._setup_scale_hooks()
+            self._calibrate(state.model, calibration_dataloader)
+            self._concat_collected_activations()
+            self._apply_smoothing(state.model)
 
         return True
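
For readers unfamiliar with the helper, calibration_forward_context is used here so the entire AWQ pipeline runs without gradients or generation caches. Below is a minimal sketch of what such a forward-only context typically provides; it is illustrative only, assumes a Hugging Face style model with a config.use_cache flag, and is not the llmcompressor implementation.

import contextlib

import torch


@contextlib.contextmanager
def forward_only_context(model: torch.nn.Module):
    # Illustrative stand-in for calibration_forward_context: put the model in
    # eval mode, disable the kv cache if present, and turn off autograd.
    was_training = model.training
    prior_use_cache = getattr(getattr(model, "config", None), "use_cache", None)
    model.eval()
    if prior_use_cache is not None:
        model.config.use_cache = False
    try:
        with torch.no_grad():
            yield model
    finally:
        if was_training:
            model.train()
        if prior_use_cache is not None:
            model.config.use_cache = prior_use_cache

Wrapping the whole sequence once, as the new code does, means the calibration forward passes, the hook collection, and the smoothing math all inherit these settings, which matters when num_calibration_samples is large and no activation gradients should ever be retained.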

@@ -272,13 +275,13 @@ def _calibrate(self, model: Module, calibration_dataloader: List):
                 " CompressionSession to run the AWQ modifier"
             )
 
-        with calibration_forward_context(model):
-            run_calibration_forward(
-                model,
-                calibration_dataloader,
-                self.num_calibration_steps,
-                self.calibration_function,
-            )
+        # with calibration_forward_context(model):
+        run_calibration_forward(
+            model,
+            calibration_dataloader,
+            self.num_calibration_steps,
+            self.calibration_function,
+        )
 
         # remove the hooks now that we are done calibrating
         self.remove_hooks()
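
The inner calibration_forward_context is commented out because this call is now nested inside the context opened in on_initialize, making it redundant. Conceptually, run_calibration_forward pushes a bounded number of batches through the model so the registered hooks can record input activations; the rough sketch below is written under that assumption, and the helper and argument names are illustrative rather than the library's API.

import itertools

import torch


def run_bounded_calibration(model, dataloader, num_steps=None, forward_fn=None):
    # Push at most num_steps batches through the model; outputs are discarded
    # because the scale hooks capture what the AWQ search needs.
    device = next(model.parameters()).device
    batches = itertools.islice(dataloader, num_steps) if num_steps else dataloader
    for batch in batches:
        batch = {
            key: value.to(device) if isinstance(value, torch.Tensor) else value
            for key, value in batch.items()
        }
        (forward_fn or model)(**batch)

Because the outer context already disables gradients, no autograd graph is retained across these forward passes.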

@@ -356,10 +359,9 @@ def _apply_smoothing(self, model: Module):
             x_mean = (x_sum / num_elements).to(inp.dtype)
 
             # [STEP 3]: Compute output of module
-            with torch.no_grad():
-                fp16_output = self._forward_input_with_kwargs(
-                    module=module2inspect, inputs=inp, input_kwargs=self.module_kwargs_
-                )
+            fp16_output = self._forward_input_with_kwargs(
+                module=module2inspect, inputs=inp, input_kwargs=self.module_kwargs_
+            )
 
             # [STEP 4]: Compute loss
             best_scales = self._compute_best_scale(
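
The explicit torch.no_grad() is dropped for the same reason: the fp16 reference output is now computed inside the outer calibration context. That reference feeds the grid search in [STEP 4], which, in the usual AWQ formulation, picks the per-channel scaling whose quantized output stays closest to this fp16 output. A toy sketch of that objective follows; it reflects the assumed formulation, not the project's exact code.

import torch


def output_mse(fp16_output: torch.Tensor, quant_output: torch.Tensor) -> float:
    # Reconstruction error between the unquantized and fake-quantized outputs.
    return (fp16_output - quant_output).float().pow(2).mean().item()


def pick_best_ratio(fp16_output: torch.Tensor, outputs_by_ratio: dict) -> float:
    # Grid search: keep the scaling ratio with the smallest output error.
    return min(outputs_by_ratio, key=lambda r: output_mse(fp16_output, outputs_by_ratio[r]))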

@@ -459,14 +461,16 @@ def _compute_best_scale(
             for fc in linears2scale:
                 with align_module_device(fc):
                     fc.weight.mul_(scales_view)
-                    fc.weight.data = (
+                    update_offload_parameter(
+                        fc,
+                        "weight",
                         pseudo_quantize_tensor(
                             w=fc.weight.data,
                             symmetric=self.symmetric,
                             bit_width=self.bits,
                             group_size=self.group_size,
                         )[0]
-                        / scales_view
+                        / scales_view,
                     )
 
             # W * X
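
Two things happen in this hunk. First, pseudo_quantize_tensor performs group-wise fake quantization (quantize then dequantize) so the error introduced by the candidate scales can be measured in floating point; a simplified symmetric version of that operation looks roughly like the sketch below (an illustrative helper, assuming in_features is divisible by group_size). Second, writing the result back through update_offload_parameter instead of assigning to fc.weight.data presumably keeps any CPU- or disk-offloaded copy of the weight in sync, so the temporary quantized weights are not silently lost when align_module_device returns the module to its offloaded state.

import torch


def fake_quantize_groupwise(w: torch.Tensor, bit_width: int = 4, group_size: int = 128) -> torch.Tensor:
    # Symmetric group-wise quantize-dequantize: each group of `group_size`
    # input channels shares one scale derived from its absolute maximum.
    out_features, in_features = w.shape
    grouped = w.reshape(out_features, in_features // group_size, group_size)
    q_max = 2 ** (bit_width - 1) - 1
    scale = grouped.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / q_max
    q = torch.clamp(torch.round(grouped / scale), -q_max - 1, q_max)
    return (q * scale).reshape(out_features, in_features)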

@@ -488,7 +492,9 @@
             logger.debug(history)
             raise Exception
 
-        assert torch.isnan(best_scales).sum() == 0, best_scales
+        assert (
+            torch.isnan(best_scales).sum() == 0
+        ), f"Nan found in scales: {best_scales}"
 
         return best_scales.detach().cpu()
