[Bugfix] SparseGPT, Pipelines (#1130)
## Purpose ##
* SparseGPT
  * Fix behavior where `targets` specifies which modules to sparsify, not which layers to target
  * Fix broken behavior with `_infer_owl_layer_sparsity` and add a test
  * Fix owl argument validation
  * Add type hints and abstract methods for clarity
* Pipelines
  * Fix bug revealed by decorators added to the llama model definition in the latest transformers release
    * huggingface/transformers#35757
  * For the sequential pipeline, this revealed a bug in `torch.fx._symbolic_trace` where wrapped functions were not being handled properly (a minimal illustration follows this list)
    * Future work could involve upstreaming a bug fix
  * Fix issue caused by changes to the llama model definition
    * huggingface/transformers#34858
  * For the layer sequential pipeline, this challenges the assumption that each layer's input is the previous layer's output (which was known to be a fragile assumption)
  * Fix issue related to basic pipeline slowdowns and inaccuracy
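
As a rough, hedged illustration of the tracing issue (the decorator below is a stand-in, not the actual transformers `deprecate_kwarg` implementation): `torch.fx`'s argument handling in `Tracer.create_args_for_root`/`_patch_function` can misbehave when `forward` is a wrapped function, so this PR unwraps the class-level `forward` before tracing and restores it afterwards via `preserve_attr`.

```python
import functools
import inspect


def deprecate_kwarg(fn):
    # stand-in for the decorator that now wraps the llama forward in transformers
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    return wrapper


class Layer:
    @deprecate_kwarg
    def forward(self, x):
        return x


wrapped = Layer.forward
# functools.wraps records the original function via __wrapped__, so it can be
# recovered; the sequential pipeline unwraps the class-level forward before
# calling HFTracer.trace, then restores the wrapped attribute afterwards
Layer.forward = inspect.unwrap(wrapped)
assert Layer.forward is wrapped.__wrapped__
```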

## Changes ##
* SparseGPT
  * Fully separate `targets` and `sequential_targets` (see the recipe sketch after this list)
    * Modify the hook-adding logic to reflect this change
  * Fix behavior of `_infer_owl_layer_sparsity` and add a test
  * Code clarity
    * Add additional type hints
    * Designate `calibrate_module` as an abstract method on the sgpt mixin
* Pipelines
  * Sequential pipeline: unwrap the model forward function to avoid issues with pytorch function patching
  * Layer sequential pipeline: add `maybe_inject_pos_embeddings` to hackily support models that pass `position_embeddings` between layers
  * Basic pipeline: fix `on_sequential_batch_end` to be called at the end of the epoch, rather than after every batch
    * Calling it after every batch was likely causing slowdowns
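
As a hedged sketch of the separated arguments: `targets` and `ignore` now select which modules are sparsified, while `sequential_targets` selects the layers used for sequential calibration. Values below are illustrative, borrowed from the 2:4 llama example rather than mandated by this PR.

```python
# a minimal sketch, assuming the 2:4 llama example's settings; sparsity,
# ignore, and sequential_targets values are illustrative, while
# targets=["Linear"] is the new default introduced by this change
from llmcompressor.modifiers.obcq import SparseGPTModifier

recipe = SparseGPTModifier(
    sparsity=0.5,
    mask_structure="2:4",
    targets=["Linear"],  # modules to sparsify
    ignore=["lm_head"],  # modules skipped even if they match a target
    sequential_targets=["LlamaDecoderLayer"],  # layers calibrated one at a time
)
```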

## Followups ##
* Remove deprecated `sequential_update` option from examples and tests

## Testing ##
* Added `tests/llmcompressor/transformers/obcq/test_obcq_owl.py` (a hedged sketch of the related argument validation follows this list)
* Tested OBCQ+llama with the sequential, layer sequential, and basic pipelines independently
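
The added test exercises OWL sparsity inference end to end; as a smaller, self-contained illustration of the owl argument validation fixed here, checks along these lines should pass (the actual test file's contents are an assumption):

```python
import pytest

from llmcompressor.modifiers.obcq import SparseGPTModifier


def test_owl_args_all_or_none():
    # providing owl_m without owl_lmbda and sparsity_profile="owl" is rejected
    with pytest.raises(ValueError):
        SparseGPTModifier(sparsity=0.7, owl_m=5)

    # providing the full triple (profile, owl_m, owl_lmbda) validates
    SparseGPTModifier(sparsity=0.7, sparsity_profile="owl", owl_m=5, owl_lmbda=0.08)
```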

## Regression Evaluations ##
Models were compressed using `examples/sparse_2of4_quantization_fp8/llama3_8b_2of4.py` without the fp8 option.
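
A reconstruction of the evaluation call implied by the output headers below (the exact invocation, CLI vs. Python API, is an assumption):

```python
# hedged reconstruction from the lm-evaluation-harness output headers below
import lm_eval

results = lm_eval.simple_evaluate(
    model="vllm",
    model_args="pretrained=<compressed-model-path>,dtype=bfloat16,add_bos_token=True",
    tasks=["winogrande"],
    num_fewshot=5,
    batch_size=1,
)
```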

<details><summary>sparsegpt</summary>

Main
```
vllm (pretrained=/home/kyle/llm-compressor/Meta-Llama-3-8B-InstructSparseGPTModifierMAIN,dtype=bfloat16,add_bos_token=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: 1
|  Tasks   |Version|Filter|n-shot|Metric|   |Value |   |Stderr|
|----------|------:|------|-----:|------|---|-----:|---|-----:|
|winogrande|      1|none  |     5|acc   |↑  |0.6243|±  |0.0136|
```

This branch

```
vllm (pretrained=/home/kyle/llm-compressor/Meta-Llama-3-8B-InstructSparseGPTModifierFEATURE,dtype=bfloat16,add_bos_token=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: 1
|  Tasks   |Version|Filter|n-shot|Metric|   |Value |   |Stderr|
|----------|------:|------|-----:|------|---|-----:|---|-----:|
|winogrande|      1|none  |     5|acc   |↑  |0.6306|±  |0.0136|
```
</details>

To test Wanda, the `SparseGPTModifier` was replaced with the `WandaPruningModifier`.
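
A minimal sketch of the swap (recipe values and import path are assumptions based on the 2:4 example, not taken from the regression run itself):

```python
from llmcompressor.modifiers.pruning import WandaPruningModifier

recipe = WandaPruningModifier(
    sparsity=0.5,
    mask_structure="2:4",
    ignore=["lm_head"],
)
```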

<details><summary>wanda</summary>

Main
```
vllm (pretrained=/home/kyle/llm-compressor/Meta-Llama-3-8B-InstructWandaPruningModifierMAIN,dtype=bfloat16,add_bos_token=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: 1
|  Tasks   |Version|Filter|n-shot|Metric|   |Value |   |Stderr|
|----------|------:|------|-----:|------|---|-----:|---|-----:|
|winogrande|      1|none  |     5|acc   |↑  |0.5912|±  |0.0138|
```

This branch
```
vllm (pretrained=/home/kyle/llm-compressor/Meta-Llama-3-8B-InstructWandaPruningModifierFEATURE,dtype=bfloat16,add_bos_token=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: 1
|  Tasks   |Version|Filter|n-shot|Metric|   |Value |   |Stderr|
|----------|------:|------|-----:|------|---|-----:|---|-----:|
|winogrande|      1|none  |     5|acc   |↑  |0.5817|±  |0.0139|
```
</details>

---------

Signed-off-by: Kyle Sayers <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
kylesayrs and dsikka authored Feb 11, 2025
1 parent e604f41 commit b55ec42
Showing 10 changed files with 165 additions and 80 deletions.
20 changes: 4 additions & 16 deletions src/llmcompressor/modifiers/obcq/base.py
@@ -1,5 +1,5 @@
import contextlib
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, Optional, Tuple

import torch
from compressed_tensors.utils import (
@@ -69,31 +69,19 @@ class SparseGPTModifier(SparsityModifierMixin, Modifier):
to compress every layer in the model. Alias for `targets`
:param targets: list of layer names to compress during OBCQ, or '__ALL__'
to compress every layer in the model. Alias for `sequential_targets`
:param ignore: optional list of module class names or submodule names to not
quantize even if they match a target. Defaults to empty list.
"""

# modifier arguments
sparsity: Optional[Union[float, List[float]]] = None
sparsity_profile: Optional[str] = None
mask_structure: str = "0:0"
owl_m: Optional[int] = None
owl_lmbda: Optional[float] = None
block_size: int = 128
dampening_frac: Optional[float] = 0.01
preserve_sparsity_mask: bool = False
offload_hessians: bool = False

# data pipeline arguments
sequential_update: Optional[bool] = False # deprecated
sequential_targets: Union[str, List[str], None] = None
targets: Union[str, List[str], None] = None # alias sequential_targets

# private variables
_prune_n: Optional[int] = PrivateAttr(default=None)
_prune_m: Optional[int] = PrivateAttr(default=None)
_hessians: Dict[torch.nn.Module, torch.Tensor] = PrivateAttr(default_factory=dict)
_num_samples: Dict[torch.nn.Module, int] = PrivateAttr(default_factory=dict)
_module_names: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)
_module_sparsities: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)
_hessians: Dict[torch.nn.Module, torch.Tensor] = PrivateAttr(default_factory=dict)

def calibrate_module(
self,
89 changes: 52 additions & 37 deletions src/llmcompressor/modifiers/obcq/sgpt_mixin.py
@@ -1,15 +1,15 @@
import warnings
from abc import abstractmethod
from collections import defaultdict
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy
import torch
from loguru import logger
from pydantic import Field, field_validator, model_validator
from pydantic import Field, PrivateAttr, field_validator, model_validator

from llmcompressor.core import State
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.utils.hooks import HooksMixin
from llmcompressor.pipelines.basic import run_pipeline as run_basic
from llmcompressor.pipelines.layer_sequential import (
@@ -20,12 +20,13 @@
get_layers,
get_no_split_params,
get_prunable_layers,
match_targets,
)


class SparsityModifierMixin(HooksMixin):
# modifier arguments
sparsity: Optional[Union[float, List[float]]] = None
sparsity: Optional[Union[float, List[float]]]
sparsity_profile: Optional[str] = None
mask_structure: str = "0:0"
owl_m: Optional[int] = None
@@ -34,9 +35,15 @@ class SparsityModifierMixin(HooksMixin):
# data pipeline arguments
sequential_update: Optional[bool] = False # deprecated
sequential_targets: Union[str, List[str], None] = None
targets: Union[str, List[str], None] = None # alias sequential_targets
targets: Union[str, List[str]] = ["Linear"]
ignore: List[str] = Field(default_factory=list)

# private variables
_prune_n: Optional[int] = PrivateAttr(default=None)
_prune_m: Optional[int] = PrivateAttr(default=None)
_module_names: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)
_module_sparsities: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)

@field_validator("sequential_update", mode="before")
def validate_sequential_update(cls, value: bool) -> bool:
if not value:
@@ -62,77 +69,80 @@ def validate_sparsity_profile(cls, value: Optional[str]) -> bool:
return value

@model_validator(mode="after")
def validate_model_after(model: "Modifier") -> "Modifier":
sparsity = model.sparsity
def validate_model_after(model: "SparsityModifierMixin") -> "SparsityModifierMixin":
profile = model.sparsity_profile
owl_m = model.owl_m
owl_lmbda = model.owl_lmbda
mask_structure = model.mask_structure
targets = model.targets
sequential_targets = model.sequential_targets

if profile == "owl" and ((owl_m is not None) ^ (owl_lmbda is not None)):
raise ValueError("Must provide both `owl_m` and `owl_lmbda` or neither")

if profile != "owl" and (owl_m is not None or owl_lmbda is not None):
raise ValueError("Must provide both `owl_m` and `owl_lmbda`")

if owl_m is not None and sparsity is not None:
raise ValueError("Cannot provide both sparsity and owl parameters")

if targets is not None:
if sequential_targets is not None:
raise ValueError("Cannot use both `targets` and `sequential_targets`")
model.sequential_targets = targets
model.targets = None
has_owl_m = owl_m is not None
has_owl_lmbda = owl_lmbda is not None
has_owl = profile == "owl"
owl_args = (has_owl_m, has_owl_lmbda, has_owl)
if any(owl_args) and not all(owl_args):
raise ValueError(
'Must provide all of `profile="owl"`, `owl_m` and `owl_lmbda` or none'
)

model._prune_n, model._prune_m = model._split_mask_structure(mask_structure)

return model

@abstractmethod
def calibrate_module(
self,
module: torch.nn.Module,
args: Tuple[torch.Tensor, ...],
_output: torch.Tensor,
):
raise NotImplementedError()

def on_initialize(self, state: "State", **kwargs) -> bool:
"""
Initialize and run the OBCQ algorithm on the current state
:param state: session state storing input model and calibration data
"""
model = state.model
dataloader = state.data.calib
model: torch.nn.Module = state.model
dataloader: torch.utils.data.DataLoader = state.data.calib

# infer module and sequential targets
self.sequential_targets = self._infer_sequential_targets(model)
layers = get_layers(self.sequential_targets, model)
target_layers = get_layers(self.targets, model) # layers containing targets

# infer layer sparsities
if self.sparsity_profile == "owl":
logger.info(
"Using OWL to infer target layer-wise sparsities from "
f"{len(dataloader) if dataloader else 0} calibration samples..."
)
self.sparsity = self._infer_owl_layer_sparsity()
self.sparsity = self._infer_owl_layer_sparsity(model, layers, dataloader)

# get layers and validate sparsity
layers = get_layers(self.sequential_targets, model)
if isinstance(self.sparsity, (list, dict)) and len(layers) != len(
if isinstance(self.sparsity, (list, dict)) and len(target_layers) != len(
self.sparsity
):
raise ValueError(
f"{self.__repr_name__} was initialized with {len(self.sparsity)} "
f"sparsities values, but model only has {len(layers)} layers"
f"sparsities values, but model has {len(layers)} target layers"
)

# register hooks
for index, (name, layer) in enumerate(layers.items()):
for index, (layer_name, layer) in enumerate(target_layers.items()):
if isinstance(self.sparsity, dict):
layer_sparsity = self.sparsity[name]
layer_sparsity = self.sparsity[layer_name]
elif isinstance(self.sparsity, list):
layer_sparsity = self.sparsity[index]
else:
layer_sparsity = self.sparsity

for name, module in get_prunable_layers(layer).items():
self._module_names[module] = name
self._module_sparsities[module] = layer_sparsity
self.register_hook(module, self.calibrate_module, "forward")
name = f"{layer_name}.{name}"
if not match_targets(name, self.ignore)[0]:
self._module_names[module] = name
self._module_sparsities[module] = layer_sparsity
self.register_hook(module, self.calibrate_module, "forward")

# infer and run pipeline
model_name = state.model.__class__.__name__
@@ -177,8 +187,6 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
run_basic(state.model, state.data.calib, self)
return True

return True

def _infer_sequential_targets(
self, model: torch.nn.Module
) -> Union[str, List[str]]:
@@ -188,9 +196,16 @@ def _infer_sequential_targets(
return [self.sequential_targets]
return self.sequential_targets

def _infer_owl_layer_sparsity(self, activations):
def _infer_owl_layer_sparsity(
self,
model: torch.nn.Module,
layers: Dict[str, torch.nn.Module],
dataloader: torch.utils.data.DataLoader,
) -> Dict[str, float]:
activations = self._get_activations(model, dataloader)

groups = {}
for name, layer in self.compressible_layers_.items():
for name, layer in layers.items():
prunable_layers = get_prunable_layers(layer)
z = [
m.weight.abs() * activations[f"{name}.{n}"].unsqueeze(0)
20 changes: 3 additions & 17 deletions src/llmcompressor/modifiers/pruning/wanda/base.py
@@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, Tuple

import torch
from compressed_tensors.utils import (
@@ -58,29 +58,15 @@ class WandaPruningModifier(SparsityModifierMixin, Modifier):
to compress every layer in the model. Alias for `targets`
:param targets: list of layer names to compress during OBCQ, or '__ALL__'
to compress every layer in the model. Alias for `sequential_targets`
:param ignore: optional list of module class names or submodule names to not
quantize even if they match a target. Defaults to empty list.
"""

# sparsity arguments
sparsity: Optional[Union[float, List[float]]] = None
sparsity_profile: Optional[str] = None
mask_structure: str = "0:0"
owl_m: Optional[int] = None
owl_lmbda: Optional[float] = None

# data pipeline arguments
sequential_update: Optional[bool] = False # deprecated
sequential_targets: Union[str, List[str], None] = None
targets: Union[str, List[str], None] = None # alias sequential_targets

# private variables
_prune_n: Optional[int] = PrivateAttr(default=None)
_prune_m: Optional[int] = PrivateAttr(default=None)
_row_scalars: Dict[torch.nn.Module, torch.Tensor] = PrivateAttr(
default_factory=dict
)
_num_samples: Dict[torch.nn.Module, int] = PrivateAttr(default_factory=dict)
_module_names: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)
_module_sparsities: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)

def calibrate_module(
self,
6 changes: 3 additions & 3 deletions src/llmcompressor/pipelines/basic/pipeline.py
@@ -40,6 +40,6 @@ def run_pipeline(
batch = tensors_to_device(batch, model_device)
model(**batch)

# TODO: replace with a lifecycle event
if callback_modifier:
callback_modifier.on_sequential_batch_end()
# TODO: replace with a lifecycle event
if callback_modifier:
callback_modifier.on_sequential_batch_end()
31 changes: 30 additions & 1 deletion src/llmcompressor/pipelines/layer_sequential/helpers.py
@@ -15,7 +15,12 @@
from llmcompressor.pytorch.utils.helpers import tensors_to_device
from llmcompressor.utils.helpers import calibration_forward_context

__all__ = ["match_modules", "capture_first_layer_intermediates", "to_next_layer_kwargs"]
__all__ = [
"match_modules",
"capture_first_layer_intermediates",
"to_next_layer_kwargs",
"maybe_inject_pos_embeddings",
]


def match_modules(model: Module, target_names: List[str]) -> List[Module]:
@@ -126,3 +131,27 @@ class EarlyStopException(Exception):

_args: Tuple[Any, ...]
_kwargs: Dict[str, Any]


def maybe_inject_pos_embeddings(
output: Dict[str, Any],
next_layer: Module,
inputs: Dict[str, Any],
) -> Dict[str, Any]:
"""
As of https://github.com/huggingface/transformers/pull/34858, positional embeddings
must be passed into each decoder call as kwargs
:param output: output of the previous layer
:param next_layer: next layer to call
:param inputs: inputs to next layer
"""
signature = inspect.signature(next_layer.forward)
if (
"position_embeddings" in signature.parameters.keys()
and "position_embeddings" in inputs
and "position_embeddings" not in output
):
output["position_embeddings"] = inputs["position_embeddings"]

return output
6 changes: 5 additions & 1 deletion src/llmcompressor/pipelines/layer_sequential/pipeline.py
@@ -9,6 +9,7 @@
from llmcompressor.pipelines.layer_sequential.helpers import (
capture_first_layer_intermediates,
match_modules,
maybe_inject_pos_embeddings,
to_next_layer_kwargs,
)
from llmcompressor.utils.helpers import calibration_forward_context
@@ -79,6 +80,9 @@ def run_pipeline(
output = layer(**inputs)

if layer_index < num_layers - 1:
output = to_next_layer_kwargs(output, layers[layer_index + 1])
next_layer = layers[layer_index + 1]
output = to_next_layer_kwargs(output, next_layer)
output = maybe_inject_pos_embeddings(output, next_layer, inputs)

intermediates.delete(batch_index)
intermediates.update(batch_index, output)
21 changes: 19 additions & 2 deletions src/llmcompressor/pipelines/sequential/helpers.py
@@ -1,7 +1,7 @@
import inspect
from collections import deque
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Set
from typing import Any, Callable, Dict, List, Set, Union

from compressed_tensors import has_offloaded_params
from compressed_tensors.quantization import find_name_or_class_matches
@@ -13,7 +13,7 @@
from transformers.utils.fx import HFTracer

from llmcompressor.modifiers.utils.hooks import HooksMixin
from llmcompressor.utils.helpers import calibration_forward_context
from llmcompressor.utils.helpers import calibration_forward_context, preserve_attr

__all__ = ["trace_subgraphs", "Subgraph"]

@@ -114,6 +114,7 @@ def get_tracer(
:param sequential_targets: modules which are sequential targets
:param ignore: modules which are ignored
"""
# TODO: redefine skip_trace_modules to all non-ancestors of sequential_targets
offloaded_modules = set(m for m in model.modules() if has_offloaded_params(m))
skip_trace_modules = sequential_targets | offloaded_modules | ignore

@@ -132,6 +133,22 @@ def is_leaf_module(self, module: Module, module_qualified_name: str) -> bool:
module, module_qualified_name
)

def trace(self, root: Union[Module, Callable], *args, **kwargs) -> Graph:
if isinstance(root, Module):
with preserve_attr(type(root), "forward"):
# due to a bug in Tracer.create_args_for_root (_patch_function),
# we must unwrap function wrappers prior to tracing, for example
# the `deprecate_kwarg` by transformers which wraps forward

# we override the class method because the
# class method is the one being traced
type(root).forward = inspect.unwrap(type(root).forward)

return super().trace(root, *args, **kwargs)

else:
return super().trace(root, *args, **kwargs)

return SequentialTracer()


