[Callbacks] Remove on_update #1199

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
src/llmcompressor/modifiers/distillation/output/base.py (2 changes: 1 addition & 1 deletion)

@@ -123,7 +123,7 @@ def on_start(self, state: State, event: Event, **kwargs):
             teacher_wrapper.kd_enabled = True
         self.wrapped_kd_model_.kd_enabled = True

-    def on_update(self, state: State, event: Event, **kwargs):
+    def on_event(self, state: State, event: Event, **kwargs):
         if event.type_ == EventType.LOSS_CALCULATED and event.should_update(
             self.start, self.end, self.update
         ):
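The rename keeps the distillation modifier's existing gate: the hook reacts only to LOSS_CALCULATED events that event.should_update accepts for the modifier's start/end/update window. A minimal sketch of that gate pulled out into a standalone helper (the helper name and bare-argument form are illustrative; only the event.type_ check and the event.should_update call come from the hunk above, and should_update's exact semantics are assumed from its arguments):

from llmcompressor.core import Event, EventType


def should_handle_loss(event: Event, start: float, end: float, update: float) -> bool:
    # Illustrative helper mirroring the gate above: react only to LOSS_CALCULATED
    # events that fall inside the modifier's start/end/update window.
    return event.type_ == EventType.LOSS_CALCULATED and event.should_update(
        start, end, update
    )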
src/llmcompressor/modifiers/modifier.py (29 changes: 4 additions & 25 deletions)

@@ -138,7 +138,7 @@ def finalize(self, state: State, **kwargs):
     def update_event(self, state: State, event: Event, **kwargs):
         """
         Update modifier based on the given event. In turn calls
-        on_start, on_update, and on_end based on the event and
+        on_start, on_event, and on_end based on the event and
         modifier settings. Returns immediately if the modifier is
         not initialized

@@ -148,12 +148,10 @@ def update_event(self, state: State, event: Event, **kwargs):
         :param kwargs: Additional arguments for updating the modifier
         """
         if not self.initialized_:
-            return
+            raise RuntimeError("Please call `initialize()` before triggering events")

         if self.finalized_:
-            raise RuntimeError("cannot update a finalized modifier")
-
-        self.on_event(state, event, **kwargs)
+            raise RuntimeError("Cannot trigger events after `finalize()`")

         # handle starting the modifier if needed
         if (
@@ -163,9 +161,8 @@ def update_event(self, state: State, event: Event, **kwargs):
         ):
             self.on_start(state, event, **kwargs)
             self.started_ = True
-            self.on_update(state, event, **kwargs)

-            return
+        self.on_event(state, event, **kwargs)

         # handle ending the modifier if needed
         if (
@@ -175,12 +172,6 @@ def update_event(self, state: State, event: Event, **kwargs):
         ):
             self.on_end(state, event, **kwargs)
             self.ended_ = True
-            self.on_update(state, event, **kwargs)
-
-            return
-
-        if self.started_ and not self.ended_:
-            self.on_update(state, event, **kwargs)

     def should_start(self, event: Event) -> bool:
         """
@@ -239,18 +230,6 @@ def on_start(self, state: State, event: Event, **kwargs):
         """
         pass

-    def on_update(self, state: State, event: Event, **kwargs):
-        """
-        on_update is called when the model in question must be
-        updated based on passed in event. Must be implemented by the
-        inheriting modifier.
-
-        :param state: The current state of the model
-        :param event: The event that triggered the update
-        :param kwargs: Additional arguments for updating the model
-        """
-        pass
-
     def on_end(self, state: State, event: Event, **kwargs):
         """
         on_end is called when the modifier ends and must be implemented
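With on_update removed, update_event drives a flatter lifecycle: it raises if the modifier has not been initialized or has already been finalized, calls on_start once the start condition is met, forwards every event to on_event, and calls on_end once the end condition is met. A minimal sketch of a subclass written against that contract (the EpochLoggingModifier name and its print statements are illustrative assumptions; Modifier, State, Event, and EventType are the names used elsewhere in this diff):

from llmcompressor.core import Event, EventType, State
from llmcompressor.modifiers import Modifier


class EpochLoggingModifier(Modifier):
    """Hypothetical modifier: its whole behavior fits in the remaining hooks."""

    def on_initialize(self, state: State, **kwargs) -> bool:
        # one-time setup before any events are delivered
        return True

    def on_start(self, state: State, event: Event, **kwargs):
        print("modifier became active")

    def on_event(self, state: State, event: Event, **kwargs):
        # called for every event after initialize(); filter by type as needed
        if event.type_ == EventType.BATCH_START:
            print("batch starting while the modifier is initialized")

    def on_end(self, state: State, event: Event, **kwargs):
        print("modifier finished")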
src/llmcompressor/modifiers/pruning/constant/base.py (2 changes: 1 addition & 1 deletion)

@@ -56,7 +56,7 @@ def on_start(self, state: State, event: Event, **kwargs):
         self.enable_masks()

     @torch.no_grad()
-    def on_update(self, state: State, event: Event, **kwargs):
+    def on_event(self, state: State, event: Event, **kwargs):
         if self._use_hooks:
             # hooks are used to update, so nothing to do here
             return
src/llmcompressor/modifiers/pruning/magnitude/base.py (27 changes: 16 additions & 11 deletions)

@@ -1,5 +1,8 @@
+import warnings
 from typing import Any, Dict, List, Union

+from pydantic import field_validator
+
 from llmcompressor.core import Event, EventType, ModelParameterizedLayer, State
 from llmcompressor.modifiers import Modifier
 from llmcompressor.modifiers.pruning.helpers import (
@@ -25,7 +28,7 @@ class MagnitudePruningModifier(Modifier, LayerParamMasking):
     update_scheduler: str = "cubic"
     scheduler_args: Dict[str, Any] = {}
     mask_structure: str = "unstructured"
-    leave_enabled: bool = True
+    leave_enabled: bool = False
     apply_globally: bool = False

     parameterized_layers_: Dict[str, ModelParameterizedLayer] = None
@@ -35,6 +38,14 @@ class MagnitudePruningModifier(Modifier, LayerParamMasking):
     mask_creator_function_: MaskCreatorType = None
     current_sparsity_: float = None

+    @field_validator("leave_enabled")
+    def validate_leave_enabled(value: bool) -> bool:
+        warnings.warn(
+            "MagnitudePruningModifier.leave_enabled has been deprecated",
+            DeprecationWarning,
+        )
+        return False
+
     def on_initialize(self, state: State, **kwargs) -> bool:
         if self.apply_globally:
             raise NotImplementedError("global pruning not implemented yet for PyTorch")
@@ -75,9 +86,8 @@ def on_initialize(self, state: State, **kwargs) -> bool:
         return True

     def on_finalize(self, state: State, **kwargs) -> bool:
-        if not self.leave_enabled:
-            for layer_param_name, _ in self.parameterized_layers_.items():
-                self.remove_mask(layer_param_name)
+        for layer_param_name, _ in self.parameterized_layers_.items():
+            self.remove_mask(layer_param_name)

         return True

@@ -97,7 +107,7 @@ def on_start(self, state: State, event: Event, **kwargs):

         self.enable_masks()

-    def on_update(self, state: State, event: Event, **kwargs):
+    def on_event(self, state: State, event: Event, **kwargs):
         if event.type_ == EventType.BATCH_START:
             sparsity = self.scheduler_function_(event, state)
             if sparsity != self.current_sparsity_:
@@ -119,12 +129,7 @@ def on_update(self, state: State, event: Event, **kwargs):
             self._update_masks(event)

     def on_end(self, state: State, event: Event, **kwargs):
-        if not self.leave_enabled:
-            self.disable_masks()
-
-    def on_event(self, state: State, event: Event, **kwargs):
-        if event.current_index >= self.end and self.leave_enabled:
-            self._update_masks(event)
+        self.disable_masks()

     def _update_masks(self, event: Event):
         if event.type_ == EventType.OPTIM_PRE_STEP and not self._use_hooks:
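The leave_enabled deprecation above relies on a pydantic field validator to emit a DeprecationWarning and pin the value to False. A standalone sketch of the same pattern against a made-up ExampleModifier config (pydantic v2 assumed; ExampleModifier is not a real class, and the real validator in the diff omits the classmethod decorator):

import warnings

from pydantic import BaseModel, field_validator


class ExampleModifier(BaseModel):
    """Hypothetical config model mirroring the deprecation pattern in the diff above."""

    leave_enabled: bool = False

    @field_validator("leave_enabled")
    @classmethod
    def validate_leave_enabled(cls, value: bool) -> bool:
        # warn when the deprecated field is supplied, then force the new default
        warnings.warn(
            "ExampleModifier.leave_enabled has been deprecated",
            DeprecationWarning,
        )
        return False


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    modifier = ExampleModifier(leave_enabled=True)

assert modifier.leave_enabled is False
assert any(issubclass(w.category, DeprecationWarning) for w in caught)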
Changes to an additional file (name not shown in this view):

@@ -121,7 +121,7 @@ def on_start(self, state: State, event: Event, **kwargs):
         module = state.model
         module.apply(update_weight_zp_scale)

-    def on_update(self, state: State, event: Event, **kwargs):
+    def on_event(self, state: State, event: Event, **kwargs):
         if event.type_ == EventType.BATCH_START:
             if self.check_should_disable_observer(event):
                 module = state.model
Changes to an additional test file (name not shown in this view):

@@ -104,9 +104,8 @@ def test_constant_pruning_modifier_e2e(model, optimizer):
     assert manipulated_sparsities != expected_sparsities, "Sparsity manipulation failed"

     # apply modifier
-
-    modifier.on_update(state, event=Event(type_=EventType.OPTIM_PRE_STEP))
-    modifier.on_update(state, event=Event(type_=EventType.OPTIM_POST_STEP))
+    modifier.on_event(state, event=Event(type_=EventType.OPTIM_PRE_STEP))
+    modifier.on_event(state, event=Event(type_=EventType.OPTIM_POST_STEP))
     modifier.on_end(state, None)

     # copy old mask settings as finalize will remove them
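The test drives the hooks directly; callers that go through update_event instead now get a fail-fast contract rather than a silent return, since triggering events before initialize() (or after finalize()) raises a RuntimeError. A rough sketch of that behavior with a hypothetical no-op modifier (the NoOpModifier class and its hook bodies are assumptions; the update_event contract is the one shown in this PR's modifier.py diff):

from llmcompressor.core import Event, EventType, State
from llmcompressor.modifiers import Modifier


class NoOpModifier(Modifier):
    """Hypothetical modifier used only to exercise the update_event contract."""

    def on_initialize(self, state: State, **kwargs) -> bool:
        return True

    def on_finalize(self, state: State, **kwargs) -> bool:
        return True


modifier = NoOpModifier()
try:
    # no initialize() yet, so this is expected to raise rather than silently return
    modifier.update_event(state=None, event=Event(type_=EventType.BATCH_START))
except RuntimeError as err:
    print(err)  # "Please call `initialize()` before triggering events"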