diff --git a/src/llmcompressor/modifiers/pruning/magnitude/base.py b/src/llmcompressor/modifiers/pruning/magnitude/base.py
index e557ef091..fb0fa1817 100644
--- a/src/llmcompressor/modifiers/pruning/magnitude/base.py
+++ b/src/llmcompressor/modifiers/pruning/magnitude/base.py
@@ -1,5 +1,8 @@
+import warnings
 from typing import Any, Dict, List, Union
 
+from pydantic import field_validator
+
 from llmcompressor.core import Event, EventType, ModelParameterizedLayer, State
 from llmcompressor.modifiers import Modifier
 from llmcompressor.modifiers.pruning.helpers import (
@@ -25,7 +28,7 @@ class MagnitudePruningModifier(Modifier, LayerParamMasking):
     update_scheduler: str = "cubic"
     scheduler_args: Dict[str, Any] = {}
     mask_structure: str = "unstructured"
-    leave_enabled: bool = True
+    leave_enabled: bool = False
     apply_globally: bool = False
 
     parameterized_layers_: Dict[str, ModelParameterizedLayer] = None
@@ -35,6 +38,14 @@ class MagnitudePruningModifier(Modifier, LayerParamMasking):
     mask_creator_function_: MaskCreatorType = None
     current_sparsity_: float = None
 
+    @field_validator("leave_enabled")
+    def validate_leave_enabled(value: bool) -> bool:
+        warnings.warn(
+            "MagnitudePruningModifier.leave_enabled has been deprecated",
+            DeprecationWarning,
+        )
+        return False
+
     def on_initialize(self, state: State, **kwargs) -> bool:
         if self.apply_globally:
             raise NotImplementedError("global pruning not implemented yet for PyTorch")
@@ -75,9 +86,8 @@ def on_initialize(self, state: State, **kwargs) -> bool:
         return True
 
     def on_finalize(self, state: State, **kwargs) -> bool:
-        if not self.leave_enabled:
-            for layer_param_name, _ in self.parameterized_layers_.items():
-                self.remove_mask(layer_param_name)
+        for layer_param_name, _ in self.parameterized_layers_.items():
+            self.remove_mask(layer_param_name)
 
         return True
 
@@ -119,12 +129,7 @@ def on_update(self, state: State, event: Event, **kwargs):
         self._update_masks(event)
 
     def on_end(self, state: State, event: Event, **kwargs):
-        if not self.leave_enabled:
-            self.disable_masks()
-
-    def on_event(self, state: State, event: Event, **kwargs):
-        if event.current_index >= self.end and self.leave_enabled:
-            self._update_masks(event)
+        self.disable_masks()
 
     def _update_masks(self, event: Event):
         if event.type_ == EventType.OPTIM_PRE_STEP and not self._use_hooks:
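
Below is a minimal, self-contained sketch of the validator pattern introduced above, for reviewers who want to see the deprecation behavior in isolation. ExampleModifier is a hypothetical stand-in, not the real MagnitudePruningModifier (whose other required fields are omitted here); the sketch assumes pydantic v2 and uses the explicit classmethod form of field_validator.

import warnings

from pydantic import BaseModel, field_validator


class ExampleModifier(BaseModel):
    # Hypothetical stand-in for MagnitudePruningModifier, showing only the deprecated field.
    leave_enabled: bool = False

    @field_validator("leave_enabled")
    @classmethod
    def validate_leave_enabled(cls, value: bool) -> bool:
        # Mirror the diff: warn and coerce the field to False regardless of the input value.
        warnings.warn(
            "ExampleModifier.leave_enabled has been deprecated",
            DeprecationWarning,
        )
        return False


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    modifier = ExampleModifier(leave_enabled=True)

assert modifier.leave_enabled is False  # an explicit True is still coerced to False
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

Note that pydantic only runs the validator when the field is set explicitly (defaults are not validated unless validate_default=True), so constructing the modifier without leave_enabled does not emit the warning.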