Bdellabe/awq modifier v3 #1177

Open

brian-dellabetta wants to merge 22 commits into base: main from bdellabe/awq-modifier-v3

Changes from all commits (22 commits)
98a5b73
cherry picked files from stale PR #181 branch awq-feature-branch
brian-dellabetta Feb 18, 2025
2611966
updated to be compatible with latest, unit tests passing
brian-dellabetta Feb 18, 2025
88aeab8
switch to using HooksMixin api
brian-dellabetta Feb 18, 2025
2b74ccf
pydantic serialization issue fix
brian-dellabetta Feb 18, 2025
cb5956e
switch to accelerate with align_module_device
brian-dellabetta Feb 19, 2025
5cb055c
AWQ running but OOMs unless NUM_CALIBRATION_SAMPLES and MAX_SEQUENCE_…
brian-dellabetta Feb 19, 2025
28f8bca
working with larger num_calibration_samples
brian-dellabetta Feb 20, 2025
2226bfd
fix pile dataset issue
brian-dellabetta Feb 20, 2025
5ca7eb2
updated config dataclasses
brian-dellabetta Feb 24, 2025
405aeb3
OOM error resolved
brian-dellabetta Feb 25, 2025
e819fcd
codereview updates
brian-dellabetta Feb 25, 2025
e801307
minor touchups
brian-dellabetta Feb 25, 2025
386ead2
updates from debugging
brian-dellabetta Mar 3, 2025
32b0b53
styling
brian-dellabetta Mar 4, 2025
31884cf
slightly improved rtn calculate_qparams logic
brian-dellabetta Mar 5, 2025
b03124a
code cleanup
brian-dellabetta Mar 10, 2025
4488a8c
rename smoothquant private vars
brian-dellabetta Mar 10, 2025
1e90168
Merge branch 'main' into bdellabe/awq-modifier-v3
brian-dellabetta Mar 10, 2025
b464290
address gh comment on updating offloaded parameter
brian-dellabetta Mar 10, 2025
e0cb4d4
drop pile dataset, lint error fixes
brian-dellabetta Mar 10, 2025
06a12bf
style fixes
brian-dellabetta Mar 10, 2025
38d1548
fix update_offload_parameter
brian-dellabetta Mar 11, 2025
3 changes: 3 additions & 0 deletions src/llmcompressor/modifiers/awq/__init__.py
@@ -0,0 +1,3 @@
# flake8: noqa

from .base import *
689 changes: 689 additions & 0 deletions src/llmcompressor/modifiers/awq/base.py

Large diffs are not rendered by default.
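Since the 689-line base.py diff is not rendered here, below is a minimal usage sketch of how the new AWQModifier is expected to be driven. Assumptions not taken from this PR: the oneshot entrypoint from llmcompressor.transformers used in the project's examples, the placeholder model id and calibration dataset, and default AWQModifier construction (as exercised by the registration test later in this diff).

# Minimal usage sketch (not part of this diff); model id, dataset name, and
# calibration settings are placeholders.
from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.transformers import oneshot

oneshot(
    model="Qwen/Qwen2.5-0.5B-Instruct",   # placeholder model id
    dataset="open_platypus",              # placeholder calibration dataset
    recipe=[AWQModifier()],               # default construction, per the registration test
    max_seq_length=512,                   # kept small; commit notes mention OOMs at larger values
    num_calibration_samples=128,
)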

82 changes: 37 additions & 45 deletions src/llmcompressor/modifiers/smoothquant/base.py
@@ -2,8 +2,9 @@
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from compressed_tensors.utils.offload import is_module_offloaded
from accelerate.utils import align_module_device
from loguru import logger
from pydantic import ConfigDict
from torch.nn import Module

from llmcompressor.core import State
@@ -99,14 +100,17 @@ class SmoothQuantModifier(Modifier):
to use the default tensor_module_forward
"""

# Allow arbitrary types because AWQMapping has field of type torch.nn.Module
model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True)

smoothing_strength: float = 0.5
mappings: Optional[List[Union[Tuple, List]]] = None
ignore: Optional[List[str]] = None
num_calibration_steps: Optional[int] = None
calibration_function: Optional[Callable] = None

resolved_mappings_: Optional[List] = None
scales_: Optional[Dict] = None
_resolved_mappings: Optional[List[SmoothQuantMapping]] = None
_scales: Optional[Dict] = None

def on_initialize(self, state: State, **kwargs) -> bool:
"""
@@ -128,8 +132,8 @@ def on_initialize(self, state: State, **kwargs) -> bool:

self.ignore = [] if not self.ignore else self.ignore
self.mappings = self._infer_mappings_from_model(state.model)
self.resolved_mappings_ = self._resolve_mappings(state.model)
self.scales_ = {}
self._resolved_mappings = self._resolve_mappings(state.model)
self._scales = {}

calibration_dataloader = state.data.calib

@@ -146,10 +150,10 @@ def on_finalize(self, state: State, **kwargs) -> bool:
:param state: unused
:return: True
"""
if self.scales_ is not None:
self.scales_.clear()
if self.resolved_mappings_ is not None:
self.resolved_mappings_.clear()
if self._scales is not None:
self._scales.clear()
if self._resolved_mappings is not None:
self._resolved_mappings.clear()

return True

@@ -166,7 +170,7 @@ def _infer_mappings_from_model(
)

@handle_mapping_resolution_errors
def _resolve_mappings(self, model: Module) -> List:
def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]:
"""
Transforms the list of activations to smooth and their corresponding weights
into SmoothQuantMapping objects, resolving regular expressions.
@@ -215,21 +219,21 @@ def hook_fn(module, inp, out):
latest_mins = torch.min(out, dim=0)[0]
latest_maxes = torch.max(out, dim=0)[0]

if layer_name in self.scales_:
self.scales_[layer_name].min_channel_vals = torch.minimum(
self.scales_[layer_name].min_channel_vals, latest_mins
if layer_name in self._scales:
self._scales[layer_name].min_channel_vals = torch.minimum(
self._scales[layer_name].min_channel_vals, latest_mins
)
self.scales_[layer_name].max_channel_vals = torch.maximum(
self.scales_[layer_name].max_channel_vals, latest_maxes
self._scales[layer_name].max_channel_vals = torch.maximum(
self._scales[layer_name].max_channel_vals, latest_maxes
)
else:
self.scales_[layer_name] = SmoothQuantScale(
self._scales[layer_name] = SmoothQuantScale(
min_channel_vals=latest_mins, max_channel_vals=latest_maxes
)

return hook_fn

for mapping in self.resolved_mappings_:
for mapping in self._resolved_mappings:
name = mapping.smooth_name
layer = mapping.smooth_layer
self.register_hook(layer, create_hook_fn(name), "forward")
@@ -274,10 +278,10 @@ def _apply_smoothing(self, model: Module):
This modifies the weights of the model in-place.
"""
logger.info("Smoothing activation scales...")
for mapping in self.resolved_mappings_:
for mapping in self._resolved_mappings:
activation_scales = ( # get dynamic range for each activation channel
self.scales_[mapping.smooth_name].max_channel_vals
- self.scales_[mapping.smooth_name].min_channel_vals
self._scales[mapping.smooth_name].max_channel_vals
- self._scales[mapping.smooth_name].min_channel_vals
)
smooth_layer = mapping.smooth_layer
balance_layers = mapping.balance_layers
@@ -289,22 +293,16 @@ def _apply_smoothing(self, model: Module):

@torch.no_grad()
def smooth(module):
offloaded = is_module_offloaded(module)
if offloaded:
module._hf_hook.pre_forward(module)

if module in balance_layers:
module.weight.mul_(scales.view(1, -1))
elif module == smooth_layer:
if module.weight.ndim == 1:
module.weight.div_(scales)
else:
module.weight.div_(scales.view(-1, 1))
if hasattr(module, "bias") and module.bias is not None:
module.bias.div_(scales)

if offloaded:
module._hf_hook.post_forward(module, None)
with align_module_device(module):
if module in balance_layers:
module.weight.mul_(scales.view(1, -1))
elif module == smooth_layer:
if module.weight.ndim == 1:
module.weight.div_(scales)
else:
module.weight.div_(scales.view(-1, 1))
if hasattr(module, "bias") and module.bias is not None:
module.bias.div_(scales)

parent = get_fsdp_parent(mapping.smooth_name, model)
if parent is not None:
@@ -329,15 +327,9 @@ def _calculate_smoothing_scales(
# get the channel-wise dynamic range for each layer to be balanced
weight_scales = []
for layer in balance_layers:
offloaded = is_module_offloaded(layer)
if offloaded:
layer._hf_hook.pre_forward(layer)

scale = layer.weight.abs().max(dim=0, keepdim=True)[0]
weight_scales.append(scale)

if offloaded:
layer._hf_hook.post_forward(layer, None)
with align_module_device(layer):
scale = layer.weight.abs().max(dim=0, keepdim=True)[0]
weight_scales.append(scale)

weight_scales = 2.0 * torch.cat(weight_scales, dim=0).max(dim=0)[0]

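For reference, the offload handling above changes from manual _hf_hook.pre_forward / post_forward calls to accelerate's align_module_device context manager. A minimal sketch of the new pattern, mirroring the smooth closure in this diff; the helper name and scales tensor are illustrative, not part of the PR.

# Sketch of the align_module_device pattern used above. Inside the context,
# an offloaded module's parameters are placed on its execution device so the
# in-place edits below can be applied.
import torch
from accelerate.utils import align_module_device


@torch.no_grad()
def divide_weight_by_scales(module: torch.nn.Module, scales: torch.Tensor) -> None:
    with align_module_device(module):
        if module.weight.ndim == 1:
            module.weight.div_(scales)
        else:
            module.weight.div_(scales.view(-1, 1))
        if getattr(module, "bias", None) is not None:
            module.bias.div_(scales)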
41 changes: 41 additions & 0 deletions src/llmcompressor/pytorch/utils/helpers.py
@@ -2,6 +2,8 @@
Utility / helper functions
"""

import functools
import inspect
import os
import random
import re
@@ -85,6 +87,8 @@
"detach",
"adjust_quantization_for_onnx_export",
"get_dependency_order",
"tensor_forward_with_input_args",
"sanitize_kwargs_for_module",
]


@@ -680,6 +684,43 @@ def mask_difference(old_mask: Tensor, new_mask: Tensor) -> Tensor:
return -1.0 * newly_masked + newly_unmasked


def sanitize_kwargs_for_module(
kwargs: Dict[str, Any], module: Module
) -> Dict[str, Any]:
"""
Sanitize the kwargs for a Module by removing any keys that are not
in the signature of the forward method.
:param kwargs: the kwargs to sanitize
:param module: the Module to sanitize the kwargs for
:return: the sanitized kwargs for the callable object
"""
if not isinstance(kwargs, dict):
raise TypeError(f"Expected a dictionary as kwargs, but got {kwargs}")

allowed_params = inspect.signature(module.forward).parameters
return {key: value for key, value in kwargs.items() if key in allowed_params}


def tensor_forward_with_input_args(
module: Module, inputs: Tensor, input_kwargs: Dict[str, Any]
) -> Tensor:
"""
Forward the given inputs through the given module with the given input_kwargs.
This function is a wrapper around tensors_module_forward that ensures that the
input_kwargs are sanitized and passed to the module as keyword arguments during
the forward pass.
:param module: the module to forward the inputs through
:param inputs: the inputs to forward through the module
:param input_kwargs: the keyword arguments to pass to the
module during the forward pass
:return: the output of the module after forwarding the inputs through it
"""
inputs = inputs.to(next(module.parameters()).device)
input_kwargs = sanitize_kwargs_for_module(input_kwargs, module)

return tensors_module_forward(inputs, functools.partial(module, **input_kwargs))


##############################
#
# pytorch module helper functions
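A short usage sketch of the two helpers added above; the layer and kwargs are illustrative, and the behavior mirrors the new tests at the end of this diff.

# Illustrative use of sanitize_kwargs_for_module and tensor_forward_with_input_args:
# keys that the module's forward() does not accept are dropped before the call.
import torch
from torch.nn import Linear

from llmcompressor.pytorch.utils.helpers import (
    sanitize_kwargs_for_module,
    tensor_forward_with_input_args,
)

layer = Linear(10, 20)
kwargs = {"not_in_signature": 123}

print(sanitize_kwargs_for_module(kwargs, layer))  # {} -- unknown key filtered out
out = tensor_forward_with_input_args(layer, torch.randn(1, 10), kwargs)
print(out.shape)                                  # torch.Size([1, 20])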
5 changes: 4 additions & 1 deletion src/llmcompressor/transformers/finetune/runner.py
@@ -3,6 +3,7 @@
import re
from typing import List, Optional

import datasets
import torch
from loguru import logger
from torch.utils.data import Dataset
@@ -102,7 +103,9 @@ def _get_split_name(inp_str):
)
for split_name, split_str in splits.items():
dataset = self._dataset_args.dataset
if hasattr(dataset, "column_names") and "input_ids" in dataset.column_names:
if isinstance(dataset, datasets.Dataset) or (
hasattr(dataset, "column_names") and "input_ids" in dataset.column_names
):
# dataset is already tokenized
tokenized_datasets[split_name] = dataset
else:
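The runner change above lets an already-tokenized dataset pass straight through instead of being re-tokenized. A small sketch of an input that now takes that path; the token ids are placeholders.

# Sketch of a calibration split the updated check treats as already tokenized:
# it is a datasets.Dataset and exposes an "input_ids" column.
from datasets import Dataset

calib = Dataset.from_dict(
    {
        "input_ids": [[101, 2023, 102]],   # placeholder token ids
        "attention_mask": [[1, 1, 1]],
    }
)
print("input_ids" in calib.column_names)   # True -> used as-is, no re-tokenization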
20 changes: 20 additions & 0 deletions src/llmcompressor/utils/pytorch/module.py
@@ -60,6 +60,7 @@
"get_layers_params",
"get_matching_layer",
"get_no_split_params",
"get_parent_by_name",
]


@@ -338,3 +339,22 @@ def get_no_split_params(module: Module) -> Union[str, List[str]]:
if hasattr(model, "_no_split_modules"):
return model._no_split_modules
return ALL_TARGET


def get_parent_by_name(layer_name: str, model: Module) -> Tuple[str, Module]:
"""
Get the parent layer of a layer by name.
:param layer_name: Name of the layer to find the parent of.
:param model: Model to search for the parent layer.
:return: Tuple containing the name of the parent layer
and the parent layer itself.
"""
if not any(layer_name == name for name, _ in model.named_modules()):
raise ValueError(f"Layer '{layer_name}' not found in model")

parent_name_parts = layer_name.split(".")[:-1]
if not parent_name_parts:
return "", model

parent_name = ".".join(parent_name_parts)
return get_layer(parent_name, model)
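An illustrative example of get_parent_by_name on a toy nested module (the Toy class below is not part of this PR); for a top-level layer the function returns an empty name and the model itself, per the early return above.

# Illustrative example for get_parent_by_name using a toy nested module.
import torch.nn as nn

from llmcompressor.utils.pytorch.module import get_parent_by_name


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        # submodule names as seen by named_modules(): "block", "block.0"
        self.block = nn.Sequential(nn.Linear(4, 4))


model = Toy()
name, parent = get_parent_by_name("block.0", model)
print(name, parent is model.block)  # block True

name, parent = get_parent_by_name("block", model)
print(name, parent is model)        # (empty string) True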
Empty file.
28 changes: 28 additions & 0 deletions tests/llmcompressor/modifiers/awq/test_base.py
@@ -0,0 +1,28 @@
import unittest

import pytest

from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.modifiers.factory import ModifierFactory
from tests.llmcompressor.modifiers.conf import setup_modifier_factory


@pytest.mark.unit
class TestAWQIsRegistered(unittest.TestCase):
def setUp(self):
self.kwargs = {}
setup_modifier_factory()

def test_awq_is_registered(self):
modifier = ModifierFactory.create(
type_="AWQModifier",
allow_experimental=False,
allow_registered=True,
**self.kwargs,
)

self.assertIsInstance(
modifier,
AWQModifier,
"AWQModifier not registered",
)
@@ -21,11 +21,11 @@ def test_successful_map(self):
modifier = LogarithmicEqualizationModifier(mappings=mappings)

modifier.ignore = []
modifier.resolved_mappings_ = modifier._resolve_mappings(self.state.model)
modifier._resolved_mappings = modifier._resolve_mappings(self.state.model)

self.assertEqual(len(modifier.resolved_mappings_), len(mappings))
self.assertEqual(len(modifier._resolved_mappings), len(mappings))

mapping = modifier.resolved_mappings_[0]
mapping = modifier._resolved_mappings[0]
self.assertEqual(mapping.smooth_name, mappings[0][1])
self.assertIsInstance(mapping.smooth_layer, Linear)
self.assertIsInstance(mapping.balance_layers[0], Linear)
@@ -19,11 +19,11 @@ def test_successful_map(self):
modifier = SmoothQuantModifier(mappings=mappings)

modifier.ignore = []
modifier.resolved_mappings_ = modifier._resolve_mappings(self.state.model)
modifier._resolved_mappings = modifier._resolve_mappings(self.state.model)

self.assertEqual(len(modifier.resolved_mappings_), len(mappings))
self.assertEqual(len(modifier._resolved_mappings), len(mappings))

mapping = modifier.resolved_mappings_[0]
mapping = modifier._resolved_mappings[0]
self.assertEqual(mapping.smooth_name, mappings[0][1])
self.assertIsInstance(mapping.smooth_layer, Linear)
self.assertIsInstance(mapping.balance_layers[0], Linear)
42 changes: 42 additions & 0 deletions tests/llmcompressor/pytorch/utils/test_helpers.py
@@ -16,9 +16,11 @@
get_optim_learning_rate,
mask_difference,
memory_aware_threshold,
sanitize_kwargs_for_module,
set_optim_learning_rate,
tensor_density,
tensor_export,
tensor_forward_with_input_args,
tensor_sample,
tensor_sparsity,
tensors_batch_size,
@@ -855,3 +857,43 @@ def test_memory_aware_threshold(tensor, idx):

if prior_state is not None:
os.environ[MEMORY_BOUNDED] = prior_state


class TestSanitizeKwargsForModule:
@pytest.fixture
def module(self):
return Linear(10, 20)

def test_sanitize_kwargs_for_module_not_dict(self, module):
# Test with kwargs that are not a dictionary
with pytest.raises(TypeError):
sanitize_kwargs_for_module("not a dictionary", module)

def test_sanitize_kwargs_for_module_not_in_signature(self, module):
# Test with kwargs that are not in the signature of the forward method
kwargs = {"not_in_signature": 123}
sanitized_kwargs = sanitize_kwargs_for_module(kwargs, module)
assert sanitized_kwargs == {}

def test_sanitize_kwargs_for_module_in_signature(self, module):
# Test with kwargs that are in the signature of the forward method
kwargs = {"input": torch.randn(1, 10)}
sanitized_kwargs = sanitize_kwargs_for_module(kwargs, module)
assert sanitized_kwargs == kwargs


class TestTensorForwardWithInputArgs:
@pytest.fixture
def module(self):
return Linear(10, 20)

def test_tensor_forward_with_input_args(self, module):
# Test with valid inputs and input_kwargs
inputs = torch.randn(1, 10)
input_kwargs = {}
output = tensor_forward_with_input_args(module, inputs, input_kwargs)
assert output.shape == (1, 20)

# Test with input_kwargs that are not in the signature of the forward method
input_kwargs = {"not_in_signature": 123}
tensor_forward_with_input_args(module, inputs, input_kwargs)
Empty file.