diff --git a/src/llmcompressor/modifiers/obcq/base.py b/src/llmcompressor/modifiers/obcq/base.py
index 394156306..1329f17de 100644
--- a/src/llmcompressor/modifiers/obcq/base.py
+++ b/src/llmcompressor/modifiers/obcq/base.py
@@ -8,7 +8,7 @@
     update_offload_parameter,
 )
 from loguru import logger
-from pydantic import PrivateAttr
+from pydantic import Field, PrivateAttr
 
 from llmcompressor.core import State
 from llmcompressor.modifiers import Modifier
@@ -85,7 +85,8 @@ class SparseGPTModifier(SparsityModifierMixin, Modifier):
     # data pipeline arguments
     sequential_update: Optional[bool] = False  # deprecated
     sequential_targets: Union[str, List[str], None] = None
-    targets: Union[str, List[str], None] = None  # alias sequential_targets
+    targets: Union[str, List[str]] = ["Linear"]
+    ignore: List[str] = Field(default_factory=list)
 
     # private variables
     _prune_n: Optional[int] = PrivateAttr(default=None)
diff --git a/src/llmcompressor/modifiers/obcq/sgpt_mixin.py b/src/llmcompressor/modifiers/obcq/sgpt_mixin.py
index c3cf585fc..721ef3fd4 100644
--- a/src/llmcompressor/modifiers/obcq/sgpt_mixin.py
+++ b/src/llmcompressor/modifiers/obcq/sgpt_mixin.py
@@ -1,4 +1,5 @@
 import warnings
+from abc import abstractmethod
 from collections import defaultdict
 from functools import partial
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -6,7 +7,7 @@
 import numpy
 import torch
 from loguru import logger
-from pydantic import Field, field_validator, model_validator
+from pydantic import Field, PrivateAttr, field_validator, model_validator
 
 from llmcompressor.core import State
 from llmcompressor.modifiers import Modifier
@@ -19,7 +20,7 @@
 from llmcompressor.utils.pytorch.module import (
     get_layers,
     get_no_split_params,
-    get_prunable_layers,
+    match_layers_params,
 )
 
@@ -34,9 +35,15 @@ class SparsityModifierMixin(HooksMixin):
     # data pipeline arguments
     sequential_update: Optional[bool] = False  # deprecated
     sequential_targets: Union[str, List[str], None] = None
-    targets: Union[str, List[str], None] = None  # alias sequential_targets
+    targets: Union[str, List[str]] = ["Linear"]
     ignore: List[str] = Field(default_factory=list)
 
+    # private variables
+    _prune_n: Optional[int] = PrivateAttr(default=None)
+    _prune_m: Optional[int] = PrivateAttr(default=None)
+    _module_names: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)
+    _module_sparsities: Dict[torch.nn.Module, float] = PrivateAttr(default_factory=dict)
+
     @field_validator("sequential_update", mode="before")
     def validate_sequential_update(cls, value: bool) -> bool:
         if not value:
@@ -62,14 +69,12 @@ def validate_sparsity_profile(cls, value: Optional[str]) -> bool:
         return value
 
     @model_validator(mode="after")
-    def validate_model_after(model: "Modifier") -> "Modifier":
+    def validate_model_after(model: "SparsityModifierMixin") -> "Modifier":
         sparsity = model.sparsity
         profile = model.sparsity_profile
         owl_m = model.owl_m
         owl_lmbda = model.owl_lmbda
         mask_structure = model.mask_structure
-        targets = model.targets
-        sequential_targets = model.sequential_targets
 
         if profile == "owl" and ((owl_m is not None) ^ (owl_lmbda is not None)):
             raise ValueError("Must provide both `owl_m` and `owl_lmbda` or neither")
@@ -80,27 +85,31 @@ def validate_model_after(model: "Modifier") -> "Modifier":
         if owl_m is not None and sparsity is not None:
             raise ValueError("Cannot provide both sparsity and owl parameters")
 
-        if targets is not None:
-            if sequential_targets is not None:
-                raise ValueError("Cannot use both `targets` and `sequential_targets`")
-            model.sequential_targets = targets
-            model.targets = None
-
         model._prune_n, model._prune_m = model._split_mask_structure(mask_structure)
 
         return model
 
+    @abstractmethod
+    def calibrate_module(
+        self,
+        module: torch.nn.Module,
+        args: Tuple[torch.Tensor, ...],
+        _output: torch.Tensor,
+    ):
+        raise NotImplementedError()
+
     def on_initialize(self, state: "State", **kwargs) -> bool:
         """
         Initialize and run the OBCQ algorithm on the current state
 
         :param state: session state storing input model and calibration data
         """
-        model = state.model
-        dataloader = state.data.calib
+        model: torch.nn.Module = state.model
+        dataloader: torch.utils.data.DataLoader = state.data.calib
 
         # infer module and sequential targets
         self.sequential_targets = self._infer_sequential_targets(model)
+        layers = get_layers(self.sequential_targets, model)
 
         # infer layer sparsities
         if self.sparsity_profile == "owl":
@@ -108,10 +117,9 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
                 "Using OWL to infer target layer-wise sparsities from "
                 f"{len(dataloader) if dataloader else 0} calibration samples..."
             )
-            self.sparsity = self._infer_owl_layer_sparsity()
+            self.sparsity = self._infer_owl_layer_sparsity(model, layers, dataloader)
 
-        # get layers and validate sparsity
-        layers = get_layers(self.sequential_targets, model)
+        # validate sparsity
         if isinstance(self.sparsity, (list, dict)) and len(layers) != len(
             self.sparsity
         ):
@@ -121,18 +129,21 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
             )
 
         # register hooks
-        for index, (name, layer) in enumerate(layers.items()):
+        target_modules = match_layers_params(self.targets, model)
+        for index, (layer_name, layer) in enumerate(layers.items()):
             if isinstance(self.sparsity, dict):
-                layer_sparsity = self.sparsity[name]
+                layer_sparsity = self.sparsity[layer_name]
             elif isinstance(self.sparsity, list):
                 layer_sparsity = self.sparsity[index]
             else:
                 layer_sparsity = self.sparsity
 
-            for name, module in get_prunable_layers(layer).items():
-                self._module_names[module] = name
-                self._module_sparsities[module] = layer_sparsity
-                self.register_hook(module, self.calibrate_module, "forward")
+            # TODO: match module or param
+            for name, module in layer.named_modules(prefix=layer_name):
+                if module in target_modules.values():
+                    self._module_names[module] = name
+                    self._module_sparsities[module] = layer_sparsity
+                    self.register_hook(module, self.calibrate_module, "forward")
 
         # infer and run pipeline
         model_name = state.model.__class__.__name__
@@ -177,8 +188,6 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
                 run_basic(state.model, state.data.calib, self)
             return True
 
-        return True
-
     def _infer_sequential_targets(
         self, model: torch.nn.Module
     ) -> Union[str, List[str]]:
@@ -188,15 +197,23 @@ def _infer_sequential_targets(
             return [self.sequential_targets]
 
         return self.sequential_targets
 
-    def _infer_owl_layer_sparsity(self, activations):
+    def _infer_owl_layer_sparsity(
+        self,
+        model: torch.nn.Module,
+        layers: Dict[str, torch.nn.Module],
+        dataloader: torch.utils.data.DataLoader,
+    ) -> Dict[str, float]:
+        activations = self._get_activations(model, dataloader)
         groups = {}
-        for name, layer in self.compressible_layers_.items():
-            prunable_layers = get_prunable_layers(layer)
+
+        target_modules = match_layers_params(self.targets, model)
+        for layer_name, layer in layers.items():
             z = [
-                m.weight.abs() * activations[f"{name}.{n}"].unsqueeze(0)
-                for n, m in prunable_layers.items()
+                module.weight.abs() * activations[f"{layer_name}.{name}"].unsqueeze(0)
+                for name, module in layer.named_modules()
+                if module in target_modules.values()
             ]
-            groups[name] = torch.cat([item.flatten().cpu() for item in z])
+            groups[layer_name] = torch.cat([item.flatten().cpu() for item in z])
 
         del activations
diff --git a/src/llmcompressor/modifiers/pruning/wanda/base.py b/src/llmcompressor/modifiers/pruning/wanda/base.py
index fb3696933..291e5ae48 100644
--- a/src/llmcompressor/modifiers/pruning/wanda/base.py
+++ b/src/llmcompressor/modifiers/pruning/wanda/base.py
@@ -7,7 +7,7 @@
     update_offload_parameter,
 )
 from loguru import logger
-from pydantic import PrivateAttr
+from pydantic import Field, PrivateAttr
 
 from llmcompressor.core import State
 from llmcompressor.modifiers import Modifier
@@ -69,8 +69,9 @@ class WandaPruningModifier(SparsityModifierMixin, Modifier):
     # data pipeline arguments
     sequential_update: Optional[bool] = False  # deprecated
     sequential_targets: Union[str, List[str], None] = None
-    targets: Union[str, List[str], None] = None  # alias sequential_targets
+    targets: Union[str, List[str]] = ["Linear"]
+    ignore: List[str] = Field(default_factory=list)
 
     # private variables
     _prune_n: Optional[int] = PrivateAttr(default=None)
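
A note for reviewers: both new call sites assume `match_layers_params(self.targets, model)` returns a mapping of fully qualified module names to matched modules, since they test membership against its `.values()`. Below is a rough stand-in for that assumed contract, matching on module class name only (the real helper lives in `llmcompressor.utils.pytorch.module` and, per the `# TODO: match module or param` note, may also grow parameter matching):

```python
from typing import Dict, List, Union

import torch


def match_class_names(
    targets: Union[str, List[str]], model: torch.nn.Module
) -> Dict[str, torch.nn.Module]:
    # Hypothetical stand-in for match_layers_params: resolve a targets spec
    # by module class name, returning {fully_qualified_name: module} so that
    # `module in result.values()` membership tests work at the call sites.
    if isinstance(targets, str):
        targets = [targets]
    return {
        name: module
        for name, module in model.named_modules()
        if type(module).__name__ in targets
    }
```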
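The registration loop leans on `named_modules(prefix=layer_name)` to produce names rooted at the layer's fully qualified name, which keeps `_module_names` consistent with the `f"{layer_name}.{name}"` keys used for the OWL activations. A quick illustration of that PyTorch behavior:

```python
import torch

layer = torch.nn.Sequential(torch.nn.Linear(4, 4))

# named_modules(prefix=...) yields the layer itself under the prefix,
# then each submodule under "<prefix>.<relative_name>".
print([name for name, _ in layer.named_modules(prefix="model.layers.0")])
# ['model.layers.0', 'model.layers.0.0']
```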
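Taken together: `targets` now selects which modules within each layer are calibrated and pruned (defaulting to every `Linear`), `sequential_targets` keeps its role of choosing the layers traversed by the sequential pipeline, and the newly declared `ignore` list is reserved for excluding matched modules. A minimal usage sketch, where `lm_head` and `LlamaDecoderLayer` are illustrative names rather than defaults:

```python
from llmcompressor.modifiers.obcq import SparseGPTModifier

# Prune every Linear module in each decoder layer to 2:4 sparsity,
# leaving the output head dense.
modifier = SparseGPTModifier(
    sparsity=0.5,
    mask_structure="2:4",
    targets=["Linear"],                        # new default, shown explicitly
    ignore=["lm_head"],                        # illustrative exclusion
    sequential_targets=["LlamaDecoderLayer"],  # illustrative layer class
)
```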