
Commit

WIP
Signed-off-by: Kyle Sayers <[email protected]>
kylesayrs committed Feb 7, 2025
1 parent eed650c commit b9bea3c
Showing 3 changed files with 54 additions and 35 deletions.
5 changes: 3 additions & 2 deletions src/llmcompressor/modifiers/obcq/base.py
@@ -8,7 +8,7 @@
update_offload_parameter,
)
from loguru import logger
-from pydantic import PrivateAttr
+from pydantic import Field, PrivateAttr

from llmcompressor.core import State
from llmcompressor.modifiers import Modifier
@@ -85,7 +85,8 @@ class SparseGPTModifier(SparsityModifierMixin, Modifier):
# data pipeline arguments
sequential_update: Optional[bool] = False # deprecated
sequential_targets: Union[str, List[str], None] = None
-    targets: Union[str, List[str], None] = None  # alias sequential_targets
+    targets: Union[str, List[str]] = ["Linear"]
+    ignore: List[str] = Field(default_factory=list)

# private variables
_prune_n: Optional[int] = PrivateAttr(default=None)
77 changes: 47 additions & 30 deletions src/llmcompressor/modifiers/obcq/sgpt_mixin.py
@@ -1,12 +1,13 @@
import warnings
from abc import abstractmethod
from collections import defaultdict
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy
import torch
from loguru import logger
-from pydantic import Field, field_validator, model_validator
+from pydantic import Field, PrivateAttr, field_validator, model_validator

from llmcompressor.core import State
from llmcompressor.modifiers import Modifier
@@ -19,7 +20,7 @@
from llmcompressor.utils.pytorch.module import (
get_layers,
get_no_split_params,
-    get_prunable_layers,
+    match_layers_params,
)


@@ -34,9 +35,15 @@ class SparsityModifierMixin(HooksMixin):
# data pipeline arguments
sequential_update: Optional[bool] = False # deprecated
sequential_targets: Union[str, List[str], None] = None
-    targets: Union[str, List[str], None] = None  # alias sequential_targets
+    targets: Union[str, List[str]] = ["Linear"]
+    ignore: List[str] = Field(default_factory=list)

+    # private variables
+    _prune_n: Optional[int] = PrivateAttr(default=None)
+    _prune_m: Optional[int] = PrivateAttr(default=None)
+    _module_names: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)
+    _module_sparsities: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)

@field_validator("sequential_update", mode="before")
def validate_sequential_update(cls, value: bool) -> bool:
if not value:
@@ -62,14 +69,12 @@ def validate_sparsity_profile(cls, value: Optional[str]) -> bool:
return value

@model_validator(mode="after")
def validate_model_after(model: "Modifier") -> "Modifier":
def validate_model_after(model: "SparsityModifierMixin") -> "Modifier":
sparsity = model.sparsity
profile = model.sparsity_profile
owl_m = model.owl_m
owl_lmbda = model.owl_lmbda
mask_structure = model.mask_structure
-        targets = model.targets
-        sequential_targets = model.sequential_targets

if profile == "owl" and ((owl_m is not None) ^ (owl_lmbda is not None)):
raise ValueError("Must provide both `owl_m` and `owl_lmbda` or neither")
@@ -80,38 +85,41 @@ def validate_model_after(model: "Modifier") -> "Modifier":
if owl_m is not None and sparsity is not None:
raise ValueError("Cannot provide both sparsity and owl parameters")

-        if targets is not None:
-            if sequential_targets is not None:
-                raise ValueError("Cannot use both `targets` and `sequential_targets`")
-            model.sequential_targets = targets
-            model.targets = None

model._prune_n, model._prune_m = model._split_mask_structure(mask_structure)

return model

+    @abstractmethod
+    def calibrate_module(
+        self,
+        module: torch.nn.Module,
+        args: Tuple[torch.Tensor, ...],
+        _output: torch.Tensor,
+    ):
+        raise NotImplementedError()

def on_initialize(self, state: "State", **kwargs) -> bool:
"""
Initialize and run the OBCQ algorithm on the current state
:param state: session state storing input model and calibration data
"""
-        model = state.model
-        dataloader = state.data.calib
+        model: torch.nn.Module = state.model
+        dataloader: torch.utils.data.DataLoader = state.data.calib

# infer module and sequential targets
self.sequential_targets = self._infer_sequential_targets(model)
+        layers = get_layers(self.sequential_targets, model)

# infer layer sparsities
if self.sparsity_profile == "owl":
logger.info(
"Using OWL to infer target layer-wise sparsities from "
f"{len(dataloader) if dataloader else 0} calibration samples..."
)
-            self.sparsity = self._infer_owl_layer_sparsity()
+            self.sparsity = self._infer_owl_layer_sparsity(model, layers, dataloader)

-        # get layers and validate sparsity
-        layers = get_layers(self.sequential_targets, model)
if isinstance(self.sparsity, (list, dict)) and len(layers) != len(
self.sparsity
):
@@ -121,18 +129,21 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
)

# register hooks
-        for index, (name, layer) in enumerate(layers.items()):
+        target_modules = match_layers_params(self.targets, model)
+        for index, (layer_name, layer) in enumerate(layers.items()):
if isinstance(self.sparsity, dict):
-                layer_sparsity = self.sparsity[name]
+                layer_sparsity = self.sparsity[layer_name]
elif isinstance(self.sparsity, list):
layer_sparsity = self.sparsity[index]
else:
layer_sparsity = self.sparsity

-            for name, module in get_prunable_layers(layer).items():
-                self._module_names[module] = name
-                self._module_sparsities[module] = layer_sparsity
-                self.register_hook(module, self.calibrate_module, "forward")
+            # TODO: match module or param
+            for name, module in layer.named_modules(prefix=layer_name):
+                if module in target_modules.values():
+                    self._module_names[module] = name
+                    self._module_sparsities[module] = layer_sparsity
+                    self.register_hook(module, self.calibrate_module, "forward")

# infer and run pipeline
model_name = state.model.__class__.__name__
@@ -177,8 +188,6 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
run_basic(state.model, state.data.calib, self)
return True

-        return True

def _infer_sequential_targets(
self, model: torch.nn.Module
) -> Union[str, List[str]]:
@@ -188,15 +197,23 @@ def _infer_sequential_targets(
return [self.sequential_targets]
return self.sequential_targets

-    def _infer_owl_layer_sparsity(self, activations):
+    def _infer_owl_layer_sparsity(
+        self,
+        model: torch.nn.Module,
+        layers: Dict[str, torch.nn.Module],
+        dataloader: torch.utils.data.DataLoader,
+    ) -> Dict[str, float]:
+        activations = self._get_activations(model, dataloader)
groups = {}
-        for name, layer in self.compressible_layers_.items():
-            prunable_layers = get_prunable_layers(layer)

+        target_modules = match_layers_params(self.targets, model)
+        for layer_name, layer in layers.items():
z = [
-                m.weight.abs() * activations[f"{name}.{n}"].unsqueeze(0)
-                for n, m in prunable_layers.items()
+                module.weight.abs() * activations[f"{layer_name}.{name}"].unsqueeze(0)
+                for name, module in layer.named_modules()
+                if module in target_modules.values()
]
-            groups[name] = torch.cat([item.flatten().cpu() for item in z])
+            groups[layer_name] = torch.cat([item.flatten().cpu() for item in z])

del activations

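For context on the module-selection step introduced above (the new `targets`/`ignore` fields plus `match_layers_params`), here is a minimal, purely illustrative sketch of the kind of matching this change appears to rely on. The pattern semantics (bare class names or regex-style full-name patterns) and the helper name `match_modules` are assumptions for illustration, not the library's actual implementation:

import re
from typing import Dict, List

import torch


def match_modules(
    model: torch.nn.Module, targets: List[str], ignore: List[str]
) -> Dict[str, torch.nn.Module]:
    """Illustrative only: pick modules whose class name or full dotted name
    matches a `targets` entry and does not match an `ignore` entry."""

    def matched(patterns: List[str], name: str, module: torch.nn.Module) -> bool:
        return any(
            pattern == module.__class__.__name__ or re.fullmatch(pattern, name)
            for pattern in patterns
        )

    return {
        name: module
        for name, module in model.named_modules()
        if matched(targets, name, module) and not matched(ignore, name, module)
    }

With the defaults added in this commit, targets=["Linear"] would select every Linear submodule, and a hypothetical "lm_head" entry in ignore would drop the output projection from calibration.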
7 changes: 4 additions & 3 deletions src/llmcompressor/modifiers/pruning/wanda/base.py
@@ -7,7 +7,7 @@
update_offload_parameter,
)
from loguru import logger
-from pydantic import PrivateAttr
+from pydantic import PrivateAttr, Field

from llmcompressor.core import State
from llmcompressor.modifiers import Modifier
@@ -69,8 +69,9 @@ class WandaPruningModifier(SparsityModifierMixin, Modifier):

# data pipeline arguments
sequential_update: Optional[bool] = False # deprecated
-    sequential_targets: Union[str, List[str], None] = None
-    targets: Union[str, List[str], None] = None  # alias sequential_targets
+    sequential_targets: Union[str, List[str]] = None
+    targets: Union[str, List[str], None] = ["Linear"]
+    ignore: List[str] = Field(default_factory=list)

# private variables
_prune_n: Optional[int] = PrivateAttr(default=None)
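As a usage note, a minimal sketch of how the reworked fields might be configured once this lands; the `ignore` value and the 2:4 mask are illustrative assumptions, not part of this commit:

from llmcompressor.modifiers.obcq import SparseGPTModifier

# Hypothetical recipe: prune every Linear module to 50% sparsity with a 2:4
# mask, skipping the output head; sequential_targets is left unset so the
# sequential layers are inferred from the model architecture.
modifier = SparseGPTModifier(
    sparsity=0.5,
    mask_structure="2:4",
    targets=["Linear"],
    ignore=["lm_head"],
)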
