code cleanup
Signed-off-by: Brian Dellabetta <[email protected]>
brian-dellabetta committed Mar 10, 2025
1 parent 31884cf commit b03124a
Showing 5 changed files with 71 additions and 263 deletions.
212 changes: 70 additions & 142 deletions src/llmcompressor/modifiers/awq/base.py
@@ -13,7 +13,6 @@
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward
from llmcompressor.pytorch.utils import (
-    pseudo_quantize_tensor,
tensor_forward_with_input_args,
)
from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
@@ -49,7 +48,7 @@ class AWQMapping:
"re:.*input_layernorm",
["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"],
),
-    # TODO this generally results in higher perplexity for llama 2 7B on wikitext
+    # TODO: only add this mapping when the v_proj/o_proj shapes match; should we check during validation and skip when they do not?
AWQMapping("re:.*v_proj", ["re:.*o_proj"]),
AWQMapping(
"re:.*post_attention_layernorm",
@@ -127,8 +126,6 @@ class AWQModifier(Modifier):
smoothing (the second entry of the mappings list).
:param num_calibration_steps: number of samples to use for calibration, or None to
use the whole dataset
-    :param calibration_function: optional function to use for the forward pass, or None
-        to use the default tensor_module_forward
:param group_size: number of weights to group together for scaling
:param max_chunk_memory: maximum memory to use for each chunk of input activations
:param bits: number of bits to quantize the weights to
@@ -144,17 +141,15 @@ class AWQModifier(Modifier):
mappings: List[AWQMapping] = DEFAULT_AWQ_MAPPINGS
ignore: List[str] = []
num_calibration_steps: Optional[int] = None
-    calibration_function: Optional[Callable] = None
group_size: int = 128
max_chunk_memory: int = 1024 * 1024 * 1024
bits: int = 4
symmetric: bool = False
duo_scaling: bool = True
-    apply_clip: bool = True

-    resolved_mappings_: List[ResolvedMapping] = []
-    scales_: Dict[str, torch.Tensor | List[torch.Tensor]] = {}
-    module_kwargs_: Dict = {}
+    _resolved_mappings: List[ResolvedMapping] = []
+    _scales: Dict[str, torch.Tensor | List[torch.Tensor]] = {}
+    _module_kwargs: Dict = {}
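The fields above are the modifier's public knobs; the underscore-prefixed ones this commit renames are internal state. As a rough usage sketch, not taken from this commit, constructing the modifier with these defaults spelled out might look like the following; the import path is inferred from this file's package, and the ignore entry is a made-up example.

from llmcompressor.modifiers.awq import AWQModifier

# Hypothetical configuration mirroring the field defaults above; "lm_head"
# is an illustrative ignore pattern, not something this commit prescribes.
awq = AWQModifier(
    ignore=["lm_head"],
    group_size=128,   # weights per quantization group
    bits=4,           # target weight bit width
    symmetric=False,  # zero-point (asymmetric) quantization
    duo_scaling=True,
)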

def on_initialize(self, state: State, **kwargs) -> bool:
"""
@@ -164,7 +159,7 @@ def on_initialize(self, state: State, **kwargs) -> bool:
:return: True on a successful run, False otherwise
"""

-        self.resolved_mappings_ = self._get_resolved_mappings(state.model)
+        self._set_resolved_mappings(state.model)

calibration_dataloader = state.data.calib

@@ -184,17 +179,18 @@ def on_finalize(self, state: State, **kwargs) -> bool:
:param state: unused
:return: True
"""
-        if self.scales_ is not None:
-            self.scales_.clear()
-        if self.resolved_mappings_ is not None:
-            self.resolved_mappings_.clear()
+        if self._scales is not None:
+            self._scales.clear()
+        if self._resolved_mappings is not None:
+            self._resolved_mappings.clear()

return True

-    def _get_resolved_mappings(self, model: Module) -> List[ResolvedMapping]:
+    def _set_resolved_mappings(self, model: Module) -> None:
"""
Transforms the list of activations to smooth and their corresponding weights
-        into ResolvedMapping objects, resolving regular expressions.
+        into ResolvedMapping objects, resolving regular expressions.
+        Result is stored in _resolved_mappings.
For each activation in the mapping list, we find the corresponding weight to
balance by searching for the longest substring. For instance, if our balance
@@ -239,7 +235,8 @@ def _get_resolved_mappings(self, model: Module) -> List[ResolvedMapping]:
parent_name=parent_name,
)
)
-        return resolved_mappings
+        self._resolved_mappings = resolved_mappings
+        return

def _setup_scale_hooks(self):
"""
@@ -251,14 +248,14 @@ def create_hook_fn(layer_name):
def hook_fn(module, inp, out):
inp = inp[0].cpu().detach()

-                if layer_name in self.scales_:
-                    self.scales_[layer_name].append(inp)
+                if layer_name in self._scales:
+                    self._scales[layer_name].append(inp)
                else:
-                    self.scales_[layer_name] = [inp]
+                    self._scales[layer_name] = [inp]

return hook_fn

-        for mapping in self.resolved_mappings_:
+        for mapping in self._resolved_mappings:
name = mapping.smooth_name
# storing inps to first balance layer
# is enough, as other balance layers
@@ -288,7 +285,6 @@ def _calibrate(self, model: Module, calibration_dataloader: List):
model,
calibration_dataloader,
self.num_calibration_steps,
-            self.calibration_function,
)

# remove the hooks now that we are done calibrating
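The hooks set up in _setup_scale_hooks and torn down here follow PyTorch's standard forward-hook pattern: cache the first positional input of each smoothing target on every calibration pass, then detach the hooks. A self-contained sketch of that pattern, with toy names rather than this repo's API:

import torch

scales = {}

def make_hook(layer_name):
    def hook_fn(module, inp, out):
        # cache the layer's first input, moved off-device, as _setup_scale_hooks does
        scales.setdefault(layer_name, []).append(inp[0].detach().cpu())
    return hook_fn

layer = torch.nn.Linear(8, 8)
handle = layer.register_forward_hook(make_hook("layer0"))
for _ in range(3):  # stand-in for the calibration loop
    layer(torch.randn(2, 8))
handle.remove()  # hooks are removed once calibration is done
assert len(scales["layer0"]) == 3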
@@ -299,12 +295,12 @@ def _concat_collected_activations(self):
Concatenate the collected activation values from each forward pass into a single
tensor for each layer
-        :postcondition: each layer in self.scales_ will have a single tensor containing
+        :postcondition: each layer in self._scales will have a single tensor containing
all the activation values seen during calibration
"""
-        for mapping in self.resolved_mappings_:
+        for mapping in self._resolved_mappings:
name = mapping.smooth_name
-            self.scales_[name] = torch.cat(self.scales_[name], dim=0)
+            self._scales[name] = torch.cat(self._scales[name], dim=0)

torch.cuda.empty_cache()
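The concatenation step is plain torch.cat along the batch dimension; a toy example of the resulting shape:

import torch

# three calibration batches of shape [batch, hidden] merge into one tensor
batches = [torch.randn(2, 8) for _ in range(3)]
merged = torch.cat(batches, dim=0)
assert merged.shape == (6, 8)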

@@ -318,12 +314,11 @@ def _apply_smoothing(self, model: Module):
:param model: model to apply smoothing to
"""
logger.info("Smoothing activation scales...")
-        for mapping in tqdm(self.resolved_mappings_):
+        for mapping in tqdm(self._resolved_mappings):
smooth_layer = mapping.smooth_layer
balance_layers = mapping.balance_layers
-            balance_names = mapping.balance_names

-            activations = self.scales_[mapping.smooth_name]
+            activations = self._scales[mapping.smooth_name]

module2inspect = mapping.parent

@@ -370,7 +365,7 @@ def _apply_smoothing(self, model: Module):
fp16_output = self._forward_input_with_kwargs(
module=module2inspect,
inputs=inp,
-                input_kwargs=self._sanitize_kwargs(self.module_kwargs_, module2inspect),
+                input_kwargs=self._sanitize_kwargs(self._module_kwargs, module2inspect),
)
fp16_output = fp16_output.clip(
torch.finfo(fp16_output.dtype).min, torch.finfo(fp16_output.dtype).max
@@ -407,15 +402,6 @@ def smooth(module):
smooth(layer)
smooth(smooth_layer)

-            if self.apply_clip:
-                clip_list = self._search_best_clip(
-                    balance_layers=balance_layers,
-                    balance_names=balance_names,
-                    input_feat=inp,
-                )
-
-                _apply_clip(model, clip_list)
-
# clear out allocated smoothing scales
torch.cuda.empty_cache()
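The smoothing applied above relies on a scale invariance: dividing the smoothing layer's output channels by s while multiplying the balance layers' matching input channels by s leaves the composed computation unchanged in exact arithmetic, while making the balanced weights friendlier to quantize. A minimal sketch of that identity, assuming bias-free linear layers:

import torch

s = torch.rand(8) + 0.5                      # per-channel scales
smooth = torch.nn.Linear(8, 8, bias=False)   # produces the smoothed activations
balance = torch.nn.Linear(8, 8, bias=False)  # absorbs the scales

x = torch.randn(4, 8)
ref = balance(smooth(x))

with torch.no_grad():
    smooth.weight.div_(s.view(-1, 1))    # scale output channels down (weight rows)
    balance.weight.mul_(s.view(1, -1))   # scale input channels up (weight columns)

assert torch.allclose(ref, balance(smooth(x)), atol=1e-5)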

@@ -432,7 +418,7 @@ def _compute_best_scale(
Compute loss and select best scales
L(s) = || Q(W * s) (s^-1 * X) - W * X ||
-        Q: weight quantization function | pseudo_quantize_tensor(W * s)
+        Q: weight quantization function | _pseudo_quantize_tensor(W * s)
X: inputs from calib dataset | X
W: original weights in FP16 | layer
s: per channel scaling factor | s^-1 * X
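Concretely, _compute_best_scale grid-searches the exponent used to build s from the activation statistics and keeps the scales with the lowest reconstruction error. A toy version of that search under the loss above, with a crude stand-in quantizer instead of _pseudo_quantize_tensor:

import torch

def fake_quant(w, levels=16):  # stand-in for _pseudo_quantize_tensor
    scale = w.abs().amax() / (levels / 2 - 1)
    return torch.round(w / scale) * scale

W = torch.randn(8, 8)          # weights [out, in]
x = torch.randn(32, 8)         # calibration inputs
x_mean = x.abs().mean(dim=0)   # per-input-channel activation magnitude

best_err, best_s = float("inf"), None
for i in range(20):            # n_grid = 20
    ratio = i / 20
    s = x_mean.pow(ratio).clamp(min=1e-4)
    s = s / (s.max() * s.min()).sqrt()
    # L(s) = || Q(W * s) (s^-1 * X) - W * X ||
    err = ((x / s) @ fake_quant(W * s).T - x @ W.T).pow(2).mean()
    if err < best_err:
        best_err, best_s = err.item(), s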
@@ -461,7 +447,7 @@
else:
scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1)
scales = scales / (scales.max() * scales.min()).sqrt()
-            scales_view = scales.view(1, -1).to(device)
+            _scalesview = scales.view(1, -1).to(device)

# avoid scaling values that overflow
scales[torch.isinf(scales)] = 1
@@ -470,22 +456,22 @@
# Q(W * s)
for fc in linears2scale:
with align_module_device(fc):
-                    fc.weight.mul_(scales_view)
+                    fc.weight.mul_(_scalesview)
update_offload_parameter(
fc,
"weight",
-                        pseudo_quantize_tensor(
+                        _pseudo_quantize_tensor(
w=fc.weight.data,
symmetric=self.symmetric,
bit_width=self.bits,
group_size=self.group_size,
)[0]
-                        / scales_view,
+                        / _scalesview,
)

# W * X
int_w_output = self._forward_input_with_kwargs(
-                module=module2inspect, inputs=x, input_kwargs=self.module_kwargs_
+                module=module2inspect, inputs=x, input_kwargs=self._module_kwargs
)
int_w_output = int_w_output.clip(
torch.finfo(int_w_output.dtype).min, torch.finfo(int_w_output.dtype).max
@@ -606,7 +592,7 @@ def forward(self, *args, **kwargs):
best_device
)

-        self.module_kwargs_ = layer_kwargs
+        self._module_kwargs = layer_kwargs

def _forward_input_with_kwargs(
self,
@@ -622,92 +608,14 @@ def _forward_input_with_kwargs(
:param input_kwargs: additional arguments to pass to the module
:return: the first output tensor from the forward pass
"""
-        kwargs = input_kwargs or self.module_kwargs_
+        kwargs = input_kwargs or self._module_kwargs
kwargs = self._sanitize_kwargs(kwargs, module)
return tensor_forward_with_input_args(
module=module,
inputs=inputs,
input_kwargs=kwargs,
)[0]

-    @torch.no_grad()
-    def _search_best_clip(self, balance_layers, balance_names, input_feat):
-        clip_list = []
-        avoid_clipping = ["q_", "k_", "query", "key", "Wqkv"]
-
-        for name, layer in zip(balance_names, balance_layers):
-            # due to qk bmm, it is hard to clip precisely
-            if any([_ in name for _ in avoid_clipping]):
-                continue
-
-            max_val = self._compute_best_clip(layer.weight, input_feat)
-            clip_list.append((name, max_val))
-
-        return clip_list
-
-    @torch.no_grad()
-    def _compute_best_clip(
-        self,
-        w: torch.Tensor,
-        input_feat: torch.Tensor,
-        n_grid=20,
-        max_shrink=0.5,
-        n_sample_token=512,
-    ):
-        assert w.dim() == 2
-        org_w_shape = w.shape
-        # w [co, ci] -> [co, 1, n_group, group size]
-        # input_feat [n_token, ci] -> [1, n_token, n_group, group size]
-        group_size = self.group_size if self.group_size > 0 else org_w_shape[1]
-        input_feat = input_feat.view(-1, input_feat.shape[-1])
-        input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
-
-        # Compute input feature step size (minimum 1)
-        step_size = max(1, input_feat.shape[1] // n_sample_token)
-        input_feat = input_feat[:, ::step_size]
-
-        w = w.reshape(org_w_shape[0], 1, -1, group_size)
-
-        oc_batch_size = 256 if org_w_shape[0] % 256 == 0 else 64  # prevent OOM
-        assert org_w_shape[0] % oc_batch_size == 0
-        w_all = w
-        best_max_val_all = []
-
-        for i_b in range(org_w_shape[0] // oc_batch_size):
-            w = w_all[i_b * oc_batch_size : (i_b + 1) * oc_batch_size]
-
-            org_max_val = w.abs().amax(dim=-1, keepdim=True)  # co, 1, n_group, 1
-
-            best_max_val = org_max_val.clone()
-            min_errs = torch.ones_like(org_max_val) * 1e9
-            input_feat = input_feat.to(w.device)
-            org_out = (input_feat * w).sum(dim=-1)  # co, n_token, n_group
-
-            for i_s in range(int(max_shrink * n_grid)):
-                max_val = org_max_val * (1 - i_s / n_grid)
-                min_val = -max_val
-                cur_w = torch.clamp(w, min_val, max_val)
-                q_w = pseudo_quantize_tensor(
-                    w=cur_w,
-                    symmetric=self.symmetric,
-                    group_size=group_size,
-                    bit_width=self.bits,
-                )[0]
-                cur_out = (input_feat * q_w).sum(dim=-1)
-
-                # co, 1, n_group, 1
-                err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
-                del cur_w
-                del cur_out
-                cur_best_idx = err < min_errs
-                min_errs[cur_best_idx] = err[cur_best_idx]
-                best_max_val[cur_best_idx] = max_val[cur_best_idx]
-            best_max_val_all.append(best_max_val)
-
-        best_max_val = torch.cat(best_max_val_all, dim=0)
-
-        return best_max_val.squeeze(1)

def _sanitize_kwargs(self, inputs_kwargs, module):
"""
Remove the arguments that are not supported in the module's
@@ -728,22 +636,42 @@ def _sanitize_kwargs(self, inputs_kwargs, module):
return sanitized_kwargs
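The sanitization itself can be done by filtering against the signature of the module's forward method; a plausible sketch of the idea (the repo's actual helper may differ):

import inspect
import torch

def sanitize_kwargs(kwargs, module):
    # keep only kwargs that module.forward() actually accepts
    params = inspect.signature(module.forward).parameters
    return {k: v for k, v in kwargs.items() if k in params}

layer = torch.nn.Linear(4, 4)
kwargs = {"input": torch.randn(1, 4), "attention_mask": None}
sanitize_kwargs(kwargs, layer)  # drops attention_mask; Linear.forward only takes input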


-@torch.no_grad()
-def _apply_clip(module, clip_list: Tuple[str, torch.Tensor]):
-    """
-    Apply clipping to the weights of the given module
-
-    :post-condition: the weights of the module are clipped to the given maximum values
-    :param module: module to apply clipping to
-    :param clip_list: list of tuples containing the name of the layer and the maximum
-        value to clip the weights to
-    """
-    for name, max_val in clip_list:
-        _, layer = get_layer(target=name, module=module)
-        assert isinstance(layer, torch.nn.Linear)
-        with align_module_device(layer):
-            max_val = max_val.to(layer.weight.device)
-            org_shape = layer.weight.shape
-            layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
-            layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val)
-            layer.weight.data = layer.weight.data.reshape(org_shape)
+def _pseudo_quantize_tensor(
+    w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1
+):
+    org_w_shape = w.shape
+    if group_size > 0:
+        assert org_w_shape[-1] % group_size == 0, f"org_w_shape ({org_w_shape[-1]}) must be a multiple of group_size ({group_size})!"
+        w = w.reshape(-1, group_size)
+    assert w.dim() == 2
+    assert torch.isnan(w).sum() == 0
+
+    # zero point quantization
+    if not symmetric:
+        max_val = w.amax(dim=1, keepdim=True)
+        min_val = w.amin(dim=1, keepdim=True)
+        max_int = 2**bit_width - 1
+        min_int = 0
+        scales = (max_val - min_val).clamp(min=1e-5) / max_int
+        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
+        w = (
+            torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros
+        ) * scales
+        zeros = (zeros - 2 ** (bit_width - 1)).view(org_w_shape[0], -1)
+    else:
+        max_val = w.abs().amax(dim=1, keepdim=True)
+        max_val = max_val.clamp(min=1e-5)
+        max_int = 2 ** (bit_width - 1) - 1
+        min_int = -(2 ** (bit_width - 1))
+        scales = max_val / max_int
+        zeros = None
+        w = torch.clamp(torch.round(w / scales), min_int, max_int) * scales
+
+    assert torch.isnan(scales).sum() == 0
+    assert torch.isnan(w).sum() == 0
+
+    scales = scales.view(org_w_shape[0], -1)
+    w = w.reshape(org_w_shape)
+
+    return w, scales, zeros
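A quick sanity check of the helper added above, assuming it is imported into scope: 4-bit asymmetric quantization with group size 64 returns the dequantized weight plus per-group scales and shifted zero points.

import torch

w = torch.randn(16, 128)
w_q, scales, zeros = _pseudo_quantize_tensor(
    w, symmetric=False, bit_width=4, group_size=64
)
assert w_q.shape == w.shape     # dequantized weight keeps its shape
assert scales.shape == (16, 2)  # 128 / 64 = 2 groups per output row
assert zeros.shape == (16, 2)   # zero points shifted by 2 ** (bit_width - 1)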
1 change: 0 additions & 1 deletion src/llmcompressor/observers/__init__.py
@@ -5,4 +5,3 @@
from .base import *
from .min_max import *
from .mse import *
-from .rtn import *