Merge branch 'main' into attn_quant
horheynm authored Mar 10, 2025
2 parents 3d19401 + 2a59554 commit 0eb4c60
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions src/llmcompressor/modifiers/pruning/magnitude/base.py
@@ -1,5 +1,8 @@
+import warnings
 from typing import Any, Dict, List, Union
 
+from pydantic import field_validator
+
 from llmcompressor.core import Event, EventType, ModelParameterizedLayer, State
 from llmcompressor.modifiers import Modifier
 from llmcompressor.modifiers.pruning.helpers import (
@@ -25,7 +28,7 @@ class MagnitudePruningModifier(Modifier, LayerParamMasking):
     update_scheduler: str = "cubic"
     scheduler_args: Dict[str, Any] = {}
     mask_structure: str = "unstructured"
-    leave_enabled: bool = True
+    leave_enabled: bool = False
     apply_globally: bool = False
 
     parameterized_layers_: Dict[str, ModelParameterizedLayer] = None
@@ -35,6 +38,14 @@ class MagnitudePruningModifier(Modifier, LayerParamMasking):
     mask_creator_function_: MaskCreatorType = None
     current_sparsity_: float = None
 
+    @field_validator("leave_enabled")
+    def validate_leave_enabled(value: bool) -> bool:
+        warnings.warn(
+            "MagnitudePruningModifier.leave_enabled has been deprecated",
+            DeprecationWarning,
+        )
+        return False
+
     def on_initialize(self, state: State, **kwargs) -> bool:
         if self.apply_globally:
             raise NotImplementedError("global pruning not implemented yet for PyTorch")
@@ -75,9 +86,8 @@ def on_initialize(self, state: State, **kwargs) -> bool:
         return True
 
     def on_finalize(self, state: State, **kwargs) -> bool:
-        if not self.leave_enabled:
-            for layer_param_name, _ in self.parameterized_layers_.items():
-                self.remove_mask(layer_param_name)
+        for layer_param_name, _ in self.parameterized_layers_.items():
+            self.remove_mask(layer_param_name)
 
         return True
 
@@ -119,12 +129,7 @@ def on_update(self, state: State, event: Event, **kwargs):
         self._update_masks(event)
 
     def on_end(self, state: State, event: Event, **kwargs):
-        if not self.leave_enabled:
-            self.disable_masks()
-
-    def on_event(self, state: State, event: Event, **kwargs):
-        if event.current_index >= self.end and self.leave_enabled:
-            self._update_masks(event)
+        self.disable_masks()
 
     def _update_masks(self, event: Event):
         if event.type_ == EventType.OPTIM_PRE_STEP and not self._use_hooks:
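
Taken together, these hunks deprecate leave_enabled: the pydantic validator coerces any user-supplied value to False and emits a DeprecationWarning, and on_finalize/on_end now remove and disable masks unconditionally. A minimal standalone sketch of the validator pattern, using a hypothetical ToyModifier stand-in rather than the real MagnitudePruningModifier (whose other constructor fields are not shown in this diff):

import warnings

from pydantic import BaseModel, field_validator


class ToyModifier(BaseModel):
    # Stand-in for MagnitudePruningModifier's deprecated field.
    leave_enabled: bool = False

    @field_validator("leave_enabled")
    def validate_leave_enabled(value: bool) -> bool:
        # Mirrors the diff: warn on any supplied value and coerce it to False.
        warnings.warn(
            "leave_enabled has been deprecated",
            DeprecationWarning,
        )
        return False


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    modifier = ToyModifier(leave_enabled=True)

assert modifier.leave_enabled is False
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

Note that, as in the diff, the validator runs only when a value is explicitly passed (pydantic does not validate defaults unless validate_default=True is set), so constructing the modifier without leave_enabled emits no warning.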
