Skip to content

Commit 507b1a4

Browse files
authored
Extend remove_hooks to remove subsets (#1021)
## Purpose ## * Allow subsets of hooks to be removed * Not strictly needed but helps promote code clarity in the case of wanda which adds and removes subsets of hooks at different times. ## Postrequisites ## * #1023 * Layer compressor deprecation ## Changes ## * Change the datatype of `_hooks` from `List` to `Set` * Add `handles` argument to `HooksMixin.remove_hooks` ## Testing ## * Added `test_remove_hooks_parameterized` test --------- Signed-off-by: Kyle Sayers <[email protected]>
1 parent ba8563c commit 507b1a4

File tree

2 files changed

+36
-7
lines changed

2 files changed

+36
-7
lines changed

src/llmcompressor/modifiers/utils/hooks.py

+15-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import contextlib
22
from functools import wraps
3-
from typing import Any, Callable, ClassVar, List, Union
3+
from typing import Any, Callable, ClassVar, Optional, Set, Union
44

55
import torch
66
from loguru import logger
@@ -30,7 +30,7 @@ class HooksMixin(BaseModel):
3030
"""
3131

3232
_HOOKS_DISABLED: ClassVar[bool] = False # attached to global HooksMixin
33-
_hooks: List[RemovableHandle] = [] # attached to local subclasses
33+
_hooks: Set[RemovableHandle] = set() # attached to local subclasses
3434

3535
@classmethod
3636
@contextlib.contextmanager
@@ -70,14 +70,22 @@ def wrapped_hook(*args, **kwargs):
7070

7171
register_function = getattr(target, f"register_{hook_type}_hook")
7272
handle = register_function(wrapped_hook, **kwargs)
73-
self._hooks.append(handle)
73+
self._hooks.add(handle)
7474
logger.debug(f"{self} added {handle}")
7575

7676
return handle
7777

78-
def remove_hooks(self):
79-
"""Remove all hooks belonging to a modifier"""
80-
for hook in self._hooks:
78+
def remove_hooks(self, handles: Optional[Set[RemovableHandle]] = None):
79+
"""
80+
Removes hooks registered by this modifier
81+
82+
:param handles: optional list of handles to remove, defaults to all hooks
83+
registerd by this modifier
84+
"""
85+
if handles is None:
86+
handles = self._hooks
87+
88+
for hook in handles:
8189
hook.remove()
8290

83-
self._hooks = []
91+
self._hooks -= handles

tests/llmcompressor/modifiers/utils/test_hooks.py

+21
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,27 @@ def test_remove_hooks():
6464
assert mod_a.hook_called and not mod_b.hook_called
6565

6666

67+
def test_remove_hooks_parameterized():
68+
model = DummyModel()
69+
70+
mod_a = ModA()
71+
mod_a_pre_hook = mod_a.register_hook(model.linear1, mod_a.hook, "forward_pre")
72+
mod_a_post_hook = mod_a.register_hook(model.linear1, mod_a.hook, "forward")
73+
74+
mod_b = ModB()
75+
mod_b_pre_hook = mod_b.register_hook(model.linear2, mod_b.hook, "forward_pre")
76+
mod_b_post_hook = mod_b.register_hook(model.linear2, mod_b.hook, "forward")
77+
78+
mod_a.remove_hooks(set([mod_a_post_hook]))
79+
mod_b.remove_hooks(set([mod_b_pre_hook]))
80+
81+
assert len(mod_a._hooks) == 1 and next(iter(mod_a._hooks)) == mod_a_pre_hook
82+
assert len(mod_b._hooks) == 1 and next(iter(mod_b._hooks)) == mod_b_post_hook
83+
84+
model(model.dummy_inputs)
85+
assert mod_a.hook_called and mod_b.hook_called
86+
87+
6788
def test_disable_hooks():
6889
model = DummyModel()
6990

0 commit comments

Comments
 (0)