address gh comment on updating offloaded parameter
Signed-off-by: Brian Dellabetta <[email protected]>
brian-dellabetta committed Mar 10, 2025
1 parent 1e90168 commit b464290
Showing 2 changed files with 9 additions and 8 deletions.
16 changes: 9 additions & 7 deletions src/llmcompressor/modifiers/awq/base.py
@@ -12,9 +12,7 @@
 from llmcompressor.core import State
 from llmcompressor.modifiers import Modifier
 from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward
-from llmcompressor.pytorch.utils import (
-    tensor_forward_with_input_args,
-)
+from llmcompressor.pytorch.utils import tensor_forward_with_input_args
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
 from llmcompressor.utils.helpers import calibration_forward_context
 from llmcompressor.utils.pytorch.module import (
@@ -189,7 +187,7 @@ def on_finalize(self, state: State, **kwargs) -> bool:
     def _set_resolved_mappings(self, model: Module) -> None:
         """
         Transforms the list of activations to smooth and their corresponding weights
-        into ResolvedMapping objects, resolving regular expressions.
+        into ResolvedMapping objects, resolving regular expressions.
         Result is stored in _resolved_mappings.

         For each activation in the mapping list, we find the corresponding weight to
@@ -386,12 +384,15 @@ def smooth(module):
             elif module == smooth_layer:
                 if module.weight.ndim == 1:
                     module.weight.div_(scales.to(module.weight.device))
+                    update_offload_parameter(module, "weight")
                 else:
                     module.weight.div_(
                         scales.view(-1, 1).to(module.weight.device)
                     )
+                    update_offload_parameter(module, "weight")
                 if hasattr(module, "bias") and module.bias is not None:
                     module.bias.div_(scales.to(module.bias.device))
+                    update_offload_parameter(module, "bias")

         parent = get_fsdp_parent(mapping.smooth_name, model)
         if parent is not None:
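The three added `update_offload_parameter` calls are what the commit title refers to: when a parameter may be offloaded (e.g. to CPU via accelerate hooks), an in-place `div_` only edits the live on-device tensor, so the stored copy must be synced afterwards. A minimal sketch of the pattern, assuming `update_offload_parameter` is importable from `compressed_tensors.utils` and re-reads the module's current parameter when no replacement tensor is passed:

```python
import torch
from compressed_tensors.utils import update_offload_parameter  # assumed import path

layer = torch.nn.Linear(4, 4)
scales = torch.rand(4) + 0.5

# In-place edit, mirroring smooth() above; only the live tensor changes.
layer.weight.div_(scales.view(-1, 1).to(layer.weight.device))
# Sync the (possibly offloaded) stored copy so both views stay consistent.
update_offload_parameter(layer, "weight")
```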
@@ -636,13 +637,14 @@ def _sanitize_kwargs(self, inputs_kwargs, module):
         return sanitized_kwargs


-
 def _pseudo_quantize_tensor(
     w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1
 ):
     org_w_shape = w.shape
     if group_size > 0:
-        assert org_w_shape[-1] % group_size == 0, f"org_w_shape ({org_w_shape[-1]}) must be a multiple of group_size ({group_size})!"
+        assert (
+            org_w_shape[-1] % group_size == 0
+        ), f"org_w_shape ({org_w_shape[-1]}) must be a multiple of group_size ({group_size})!"
         w = w.reshape(-1, group_size)
     assert w.dim() == 2
     assert torch.isnan(w).sum() == 0
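For context, the reshape guarded by this assert gives each `group_size`-wide slice its own quantization parameters: the 2-D weight is flattened so that every row is one group. A small illustration of the shape bookkeeping:

```python
import torch

w = torch.randn(4, 8)                 # (out_features, in_features)
group_size = 4
assert w.shape[-1] % group_size == 0  # the condition the reformatted assert enforces
groups = w.reshape(-1, group_size)    # (8, 4): two groups per original row
```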
@@ -658,7 +660,7 @@ def _pseudo_quantize_tensor(
         w = (
             torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros
         ) * scales
-        zeros = (zeros - 2**(bit_width-1)).view(org_w_shape[0], -1)
+        zeros = (zeros - 2 ** (bit_width - 1)).view(org_w_shape[0], -1)
     else:
         max_val = w.abs().amax(dim=1, keepdim=True)
         max_val = max_val.clamp(min=1e-5)
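The clamp/round expression in this hunk is a standard asymmetric fake-quantization round trip: map to an integer grid, clamp, and map back to float. A minimal sketch of that branch, with AWQ-style per-row min/max scale and zero-point computation assumed for the lines elided from the hunk:

```python
import torch

def pseudo_quantize_asymmetric(w: torch.Tensor, bit_width: int = 4) -> torch.Tensor:
    # Per-row range mapped onto the integer grid [0, 2**bit_width - 1].
    max_val = w.amax(dim=1, keepdim=True)
    min_val = w.amin(dim=1, keepdim=True)
    max_int, min_int = 2**bit_width - 1, 0
    scales = (max_val - min_val).clamp(min=1e-5) / max_int
    zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
    # Quantize, clamp, then dequantize back to float ("fake" quantization).
    return (torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros) * scales

w = torch.randn(2, 8)
max_err = (w - pseudo_quantize_asymmetric(w)).abs().max()  # bounded by ~half a scale step per row
```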
1 change: 0 additions & 1 deletion src/llmcompressor/pytorch/utils/helpers.py
@@ -1235,4 +1235,3 @@ def swap_modules(
     parent.__setattr__(sections[-1], submodule_to_replace)

     return cur
-