working with larger num_calibration_samples
Signed-off-by: Brian Dellabetta <[email protected]>
brian-dellabetta committed Feb 20, 2025
1 parent 5cb055c commit 9273ef3
Showing 1 changed file with 21 additions and 17 deletions.
38 changes: 21 additions & 17 deletions src/llmcompressor/modifiers/awq/base.py
@@ -158,11 +158,14 @@ def on_initialize(self, state: State, **kwargs) -> bool:
 
         calibration_dataloader = state.data.calib
 
-        self._set_module_kwargs(state.model, calibration_dataloader)
-        self._setup_scale_hooks()
-        self._calibrate(state.model, calibration_dataloader)
-        self._concat_collected_activations()
-        self._apply_smoothing(state.model)
+        # TODO is it ok to wrap the whole model in this context?
+        # I don't think we ever want gradients or to use kv cache
+        with calibration_forward_context(state.model):
+            self._set_module_kwargs(state.model, calibration_dataloader)
+            self._setup_scale_hooks()
+            self._calibrate(state.model, calibration_dataloader)
+            self._concat_collected_activations()
+            self._apply_smoothing(state.model)
 
         return True
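
Note: calibration_forward_context is pulled in from llmcompressor's utilities, and its implementation is not part of this diff. As a rough sketch of what the TODO above assumes such a context provides (no gradient tracking, no KV cache), something along these lines would work; the function name and the use_cache toggle are illustrative assumptions, not the library's actual code:

    from contextlib import contextmanager

    import torch


    @contextmanager
    def sketch_calibration_forward_context(model):
        """Illustrative approximation only; not llmcompressor's implementation."""
        # Remember and disable the HF-style KV cache, if the model exposes one.
        prior_use_cache = getattr(getattr(model, "config", None), "use_cache", None)
        if prior_use_cache is not None:
            model.config.use_cache = False
        try:
            # Calibration only needs forward passes, so gradient tracking is off.
            with torch.no_grad():
                yield model
        finally:
            if prior_use_cache is not None:
                model.config.use_cache = prior_use_cache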

@@ -272,13 +275,13 @@ def _calibrate(self, model: Module, calibration_dataloader: List):
                 " CompressionSession to run the AWQ modifier"
             )
 
-        with calibration_forward_context(model):
-            run_calibration_forward(
-                model,
-                calibration_dataloader,
-                self.num_calibration_steps,
-                self.calibration_function,
-            )
+        # with calibration_forward_context(model):
+        run_calibration_forward(
+            model,
+            calibration_dataloader,
+            self.num_calibration_steps,
+            self.calibration_function,
+        )
 
         # remove the hooks now that we are done calibrating
         self.remove_hooks()
@@ -356,10 +359,9 @@ def _apply_smoothing(self, model: Module):
             x_mean = (x_sum / num_elements).to(inp.dtype)
 
             # [STEP 3]: Compute output of module
-            with torch.no_grad():
-                fp16_output = self._forward_input_with_kwargs(
-                    module=module2inspect, inputs=inp, input_kwargs=self.module_kwargs_
-                )
+            fp16_output = self._forward_input_with_kwargs(
+                module=module2inspect, inputs=inp, input_kwargs=self.module_kwargs_
+            )
 
             # [STEP 4]: Compute loss
             best_scales = self._compute_best_scale(
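
Note: the removals in _calibrate and _apply_smoothing lean on the fact that both now execute inside the outer calibration_forward_context opened in on_initialize, so disabling gradients again inside them is a no-op. A minimal standalone check of that nesting behavior, using plain torch.no_grad as a stand-in for the outer context:

    import torch

    with torch.no_grad():                  # stands in for the outer calibration context
        assert not torch.is_grad_enabled()
        with torch.no_grad():              # the now-removed inner guard changes nothing
            assert not torch.is_grad_enabled()
        # grad mode is still disabled after the inner block exits
        assert not torch.is_grad_enabled()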
@@ -488,7 +490,9 @@ def _compute_best_scale(
             logger.debug(history)
             raise Exception
 
-        assert torch.isnan(best_scales).sum() == 0, best_scales
+        assert (
+            torch.isnan(best_scales).sum() == 0
+        ), f"Nan found in scales: {best_scales}"
 
         return best_scales.detach().cpu()
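
Note: the reworked assert only changes the failure message; instead of dumping the bare tensor as the assertion value, it now states explicitly that NaNs were found. A small standalone illustration of the new message (the example tensor is made up):

    import torch

    best_scales = torch.tensor([1.0, float("nan"), 0.5])
    try:
        assert (
            torch.isnan(best_scales).sum() == 0
        ), f"Nan found in scales: {best_scales}"
    except AssertionError as err:
        # prints roughly: Nan found in scales: tensor([1.0000, nan, 0.5000])
        print(err)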
