[Callbacks] Remove MagnitudePruningModifier.leave_enabled (#1198)
## Purpose ##
* Simplify the modifier lifecycle by removing the ability for a modifier
to affect the model after its `end` event (see the lifecycle sketch below)
* This allows the `on_event` method to be removed in a future change
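
For orientation, a minimal lifecycle sketch: the hook names match those visible in the diff below, but the bodies are placeholders, not the real implementation.

```python
# Hedged sketch of the modifier lifecycle after this change. Hook names are
# taken from the diff below; bodies are placeholder comments only.
class ModifierLifecycleSketch:
    def on_initialize(self, state, **kwargs) -> bool:
        # set up masks and the sparsity scheduler
        return True

    def on_update(self, state, event, **kwargs):
        # update masks while the modifier is active, i.e. before its `end` event
        ...

    def on_end(self, state, event, **kwargs):
        # the `end` event fires: disable masks; after this change the modifier
        # never touches the model again, so no catch-all on_event is needed
        ...

    def on_finalize(self, state, **kwargs) -> bool:
        # remove masks entirely
        return True
```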

## Background ##
* The `leave_enabled` option was originally intended as a shortcut to
simplify recipes that applied magnitude pruning during an iterative
pruning stage and then needed the masks to stay enabled during
stabilization SFT
* This change makes such recipes more explicit by requiring a
ConstantPruningModifier to take over after the MagnitudePruningModifier
becomes inactive, as sketched below
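
A minimal sketch of such a recipe, expressed as a Python recipe string. The stage names and modifier fields here are illustrative assumptions, not taken from this commit; only the two modifier class names are real.

```python
# Illustrative recipe: stage and field names are assumptions and have not
# been checked against the llmcompressor recipe schema. Magnitude pruning
# ramps sparsity until its `end` event; ConstantPruningModifier then holds
# the masks fixed during stabilization SFT, replacing leave_enabled=True.
recipe = """
pruning_stage:
    pruning_modifiers:
        MagnitudePruningModifier:
            init_sparsity: 0.0
            final_sparsity: 0.5
            start: 0
            end: 5
stabilization_stage:
    pruning_modifiers:
        ConstantPruningModifier:
            start: 5
"""
```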

## Changes ##
* Remove `MagnitudePruningModifier.leave_enabled`, emitting a deprecation
warning if the field is explicitly set (see the snippet below)
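
A small sketch of the new behavior. The import path and the constructor arguments other than `leave_enabled` are assumptions; the warning and the coercion to `False` follow directly from the field validator in the diff.

```python
import warnings

# Import path is an assumption; in this repo the class is defined in
# src/llmcompressor/modifiers/pruning/magnitude/base.py.
from llmcompressor.modifiers.pruning import MagnitudePruningModifier

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    modifier = MagnitudePruningModifier(
        init_sparsity=0.0,  # args besides leave_enabled are illustrative
        final_sparsity=0.5,
        start=0,
        end=5,
        leave_enabled=True,  # explicitly setting the deprecated field...
    )

# ...triggers the validator: a DeprecationWarning is emitted and the value
# is coerced back to False.
assert modifier.leave_enabled is False
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```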

Signed-off-by: Kyle Sayers <[email protected]>
kylesayrs authored Mar 7, 2025
1 parent 4607036 commit 2a59554
Showing 1 changed file with 15 additions and 10 deletions.
src/llmcompressor/modifiers/pruning/magnitude/base.py

```diff
--- a/src/llmcompressor/modifiers/pruning/magnitude/base.py
+++ b/src/llmcompressor/modifiers/pruning/magnitude/base.py
@@ -1,5 +1,8 @@
+import warnings
 from typing import Any, Dict, List, Union
 
+from pydantic import field_validator
+
 from llmcompressor.core import Event, EventType, ModelParameterizedLayer, State
 from llmcompressor.modifiers import Modifier
 from llmcompressor.modifiers.pruning.helpers import (
@@ -25,7 +28,7 @@ class MagnitudePruningModifier(Modifier, LayerParamMasking):
     update_scheduler: str = "cubic"
     scheduler_args: Dict[str, Any] = {}
     mask_structure: str = "unstructured"
-    leave_enabled: bool = True
+    leave_enabled: bool = False
     apply_globally: bool = False
 
     parameterized_layers_: Dict[str, ModelParameterizedLayer] = None
@@ -35,6 +38,14 @@ class MagnitudePruningModifier(Modifier, LayerParamMasking):
     mask_creator_function_: MaskCreatorType = None
     current_sparsity_: float = None
 
+    @field_validator("leave_enabled")
+    def validate_leave_enabled(value: bool) -> bool:
+        warnings.warn(
+            "MagnitudePruningModifier.leave_enabled has been deprecated",
+            DeprecationWarning,
+        )
+        return False
+
     def on_initialize(self, state: State, **kwargs) -> bool:
         if self.apply_globally:
             raise NotImplementedError("global pruning not implemented yet for PyTorch")
@@ -75,9 +86,8 @@ def on_initialize(self, state: State, **kwargs) -> bool:
         return True
 
     def on_finalize(self, state: State, **kwargs) -> bool:
-        if not self.leave_enabled:
-            for layer_param_name, _ in self.parameterized_layers_.items():
-                self.remove_mask(layer_param_name)
+        for layer_param_name, _ in self.parameterized_layers_.items():
+            self.remove_mask(layer_param_name)
 
         return True
 
@@ -119,12 +129,7 @@ def on_update(self, state: State, event: Event, **kwargs):
             self._update_masks(event)
 
     def on_end(self, state: State, event: Event, **kwargs):
-        if not self.leave_enabled:
-            self.disable_masks()
-
-    def on_event(self, state: State, event: Event, **kwargs):
-        if event.current_index >= self.end and self.leave_enabled:
-            self._update_masks(event)
+        self.disable_masks()
 
     def _update_masks(self, event: Event):
         if event.type_ == EventType.OPTIM_PRE_STEP and not self._use_hooks:
```
