Bdellabe/Rtuli awq modifier v3 #1177

Open
brian-dellabetta wants to merge 22 commits into main from bdellabe/awq-modifier-v3

Commits (22)
98a5b73
cherry picked files from stale PR #181 branch awq-feature-branch
brian-dellabetta Feb 18, 2025
2611966
updated to be compatible with latest, unit tests passing
brian-dellabetta Feb 18, 2025
88aeab8
switch to using HooksMixin api
brian-dellabetta Feb 18, 2025
2b74ccf
pydantic serialization issue fix
brian-dellabetta Feb 18, 2025
cb5956e
switch to accelerate with align_module_device
brian-dellabetta Feb 19, 2025
5cb055c
AWQ running but OOMs unless NUM_CALIBRATION_SAMPLES and MAX_SEQUENCE_…
brian-dellabetta Feb 19, 2025
28f8bca
working with larger num_calibration_samples
brian-dellabetta Feb 20, 2025
2226bfd
fix pile dataset issue
brian-dellabetta Feb 20, 2025
5ca7eb2
updated config dataclasses
brian-dellabetta Feb 24, 2025
405aeb3
OOM error resolved
brian-dellabetta Feb 25, 2025
e819fcd
codereview updates
brian-dellabetta Feb 25, 2025
e801307
minor touchups
brian-dellabetta Feb 25, 2025
386ead2
updates from debugging
brian-dellabetta Mar 3, 2025
32b0b53
styling
brian-dellabetta Mar 4, 2025
31884cf
slightly improved rtn calculate_qparams logic
brian-dellabetta Mar 5, 2025
b03124a
code cleanup
brian-dellabetta Mar 10, 2025
4488a8c
rename smoothquant private vars
brian-dellabetta Mar 10, 2025
1e90168
Merge branch 'main' into bdellabe/awq-modifier-v3
brian-dellabetta Mar 10, 2025
b464290
address gh comment on updating offloaded parameter
brian-dellabetta Mar 10, 2025
e0cb4d4
drop pile dataset, lint error fixes
brian-dellabetta Mar 10, 2025
06a12bf
style fixes
brian-dellabetta Mar 10, 2025
38d1548
fix update_offload_parameter
brian-dellabetta Mar 11, 2025
68 changes: 31 additions & 37 deletions src/llmcompressor/modifiers/awq/base.py
@@ -2,7 +2,7 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
-from compressed_tensors.utils.offload import is_module_offloaded
+from accelerate.utils import align_module_device
 from loguru import logger
 from pydantic import ConfigDict
 from torch.nn import Module
@@ -318,7 +318,7 @@ def _apply_smoothing(self, model: Module):

             # [STEP 1]: Compute per-channel mean of normalised weights
             # All layer weights are concatted together
-            weight = torch.cat([_m.weight for _m in balance_layers], dim=0)
+            weight = torch.cat([bl.weight for bl in balance_layers], dim=0)
             org_shape = weight.shape
             # The weights are reshaped to be organised by quantization group
             weight = weight.view(-1, self.group_size)
@@ -373,22 +373,18 @@ def smooth(module):
                 # TODO calls to module._hf_hook.pre_forward(module) and
                 # module._hf_hook.post_forward(module, None) appear a couple places
                 # in SmoothQuantModifier, do we need them anywhere else?
-                offloaded = is_module_offloaded(module)
-                if offloaded:
-                    module._hf_hook.pre_forward(module)
-
-                if module in balance_layers:
-                    module.weight.mul_(scales.view(1, -1).to(module.weight.device))
-                elif module == smooth_layer:
-                    if module.weight.ndim == 1:
-                        module.weight.div_(scales.to(module.weight.device))
-                    else:
-                        module.weight.div_(scales.view(-1, 1).to(module.weight.device))
-                    if hasattr(module, "bias") and module.bias is not None:
-                        module.bias.div_(scales.to(module.bias.device))
-
-                if offloaded:
-                    module._hf_hook.post_forward(module, None)
+                with align_module_device(module):
+                    if module in balance_layers:
+                        module.weight.mul_(scales.view(1, -1).to(module.weight.device))
+                    elif module == smooth_layer:
+                        if module.weight.ndim == 1:
+                            module.weight.div_(scales.to(module.weight.device))
+                        else:
+                            module.weight.div_(
+                                scales.view(-1, 1).to(module.weight.device)
+                            )
+                        if hasattr(module, "bias") and module.bias is not None:
+                            module.bias.div_(scales.to(module.bias.device))

             parent = get_fsdp_parent(mapping.smooth_name, model)
             if parent is not None:
@@ -461,16 +457,17 @@ def _compute_best_scale(

         # Q(W * s)
         for fc in linears2scale:
-            fc.weight.mul_(scales_view)
-            fc.weight.data = (
-                pseudo_quantize_tensor(
-                    w=fc.weight.data,
-                    symmetric=self.symmetric,
-                    bit_width=self.bits,
-                    group_size=self.group_size,
-                )[0]
-                / scales_view
-            )
+            with align_module_device(fc):
+                fc.weight.mul_(scales_view)
+                fc.weight.data = (
+                    pseudo_quantize_tensor(
+                        w=fc.weight.data,
+                        symmetric=self.symmetric,
+                        bit_width=self.bits,
+                        group_size=self.group_size,
+                    )[0]
+                    / scales_view
+                )

         # W * X
         int_w_output = self._forward_input_with_kwargs(
@@ -691,10 +688,6 @@ def _compute_best_clip(

         best_max_val = torch.cat(best_max_val_all, dim=0)

-        # TODO this appears unneeded, clear_memory removed
-        # clear_memory(input_feat)
-        # clear_memory(org_out)
-
         return best_max_val.squeeze(1)

@@ -711,8 +704,9 @@ def _apply_clip(module, clip_list: Tuple[str, torch.Tensor]):
     for name, max_val in clip_list:
         _, layer = get_layer(target=name, module=module)
         assert isinstance(layer, torch.nn.Linear)
-        max_val = max_val.to(layer.weight.device)
-        org_shape = layer.weight.shape
-        layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
-        layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val)
-        layer.weight.data = layer.weight.data.reshape(org_shape)
+        with align_module_device(layer):
+            max_val = max_val.to(layer.weight.device)
+            org_shape = layer.weight.shape
+            layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
+            layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val)
+            layer.weight.data = layer.weight.data.reshape(org_shape)
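
The change repeated throughout this file is dropping the manual is_module_offloaded / _hf_hook.pre_forward / _hf_hook.post_forward bookkeeping in favor of accelerate's align_module_device context manager (available in recent accelerate releases). A minimal sketch of the pattern, using a toy nn.Linear and scale vector as placeholders rather than code from this PR:

    import torch
    from accelerate.utils import align_module_device

    # Toy stand-ins (not from this PR): a layer and per-input-channel scales.
    layer = torch.nn.Linear(8, 8)
    scales = torch.rand(8) + 0.5

    # align_module_device keeps the module's parameters on its execution
    # device for the duration of the block (a no-op for a module that is not
    # offloaded), replacing the explicit pre_forward/post_forward hook calls.
    with torch.no_grad(), align_module_device(layer):
        layer.weight.mul_(scales.view(1, -1).to(layer.weight.device))

For genuinely offloaded modules, in-place edits made while the parameters are onloaded may still need to be written back to the offloaded copy; the later commits "address gh comment on updating offloaded parameter" and "fix update_offload_parameter" deal with exactly that.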
40 changes: 14 additions & 26 deletions src/llmcompressor/modifiers/smoothquant/base.py
@@ -2,7 +2,7 @@
 from typing import Callable, Dict, List, Optional, Tuple, Union

 import torch
-from compressed_tensors.utils.offload import is_module_offloaded
+from accelerate.utils import align_module_device
 from loguru import logger
 from pydantic import ConfigDict
 from torch.nn import Module
@@ -293,22 +293,16 @@ def _apply_smoothing(self, model: Module):

             @torch.no_grad()
             def smooth(module):
-                offloaded = is_module_offloaded(module)
-                if offloaded:
-                    module._hf_hook.pre_forward(module)
-
-                if module in balance_layers:
-                    module.weight.mul_(scales.view(1, -1))
-                elif module == smooth_layer:
-                    if module.weight.ndim == 1:
-                        module.weight.div_(scales)
-                    else:
-                        module.weight.div_(scales.view(-1, 1))
-                    if hasattr(module, "bias") and module.bias is not None:
-                        module.bias.div_(scales)
-
-                if offloaded:
-                    module._hf_hook.post_forward(module, None)
+                with align_module_device(module):
+                    if module in balance_layers:
+                        module.weight.mul_(scales.view(1, -1))
+                    elif module == smooth_layer:
+                        if module.weight.ndim == 1:
+                            module.weight.div_(scales)
+                        else:
+                            module.weight.div_(scales.view(-1, 1))
+                        if hasattr(module, "bias") and module.bias is not None:
+                            module.bias.div_(scales)

             parent = get_fsdp_parent(mapping.smooth_name, model)
             if parent is not None:
@@ -333,15 +327,9 @@ def _calculate_smoothing_scales(
         # get the channel-wise dynamic range for each layer to be balanced
         weight_scales = []
         for layer in balance_layers:
-            offloaded = is_module_offloaded(layer)
-            if offloaded:
-                layer._hf_hook.pre_forward(layer)
-
-            scale = layer.weight.abs().max(dim=0, keepdim=True)[0]
-            weight_scales.append(scale)
-
-            if offloaded:
-                layer._hf_hook.post_forward(layer, None)
+            with align_module_device(layer):
+                scale = layer.weight.abs().max(dim=0, keepdim=True)[0]
+                weight_scales.append(scale)

         weight_scales = 2.0 * torch.cat(weight_scales, dim=0).max(dim=0)[0]

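For context on the surrounding code: _calculate_smoothing_scales collects a per-input-channel dynamic range from every balance layer and then doubles the channel-wise maximum, as in the weight_scales lines above. A toy illustration with made-up shapes (the two weight tensors are placeholders, not code from this PR):

    import torch

    # Hypothetical balance-layer weights sharing the same 8 input channels.
    w_q = torch.randn(16, 8)
    w_k = torch.randn(32, 8)

    weight_scales = []
    for w in (w_q, w_k):
        # per-input-channel max |W|, shape (1, in_features)
        weight_scales.append(w.abs().max(dim=0, keepdim=True)[0])

    # channel-wise max across layers, doubled -> shape (in_features,)
    weight_scales = 2.0 * torch.cat(weight_scales, dim=0).max(dim=0)[0]
    print(weight_scales.shape)  # torch.Size([8])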