|
1 | 1 | import contextlib
|
2 | 2 | from functools import wraps
|
3 |
| -from typing import Any, Callable, ClassVar, List, Union |
| 3 | +from typing import Any, Callable, ClassVar, Optional, Set, Union |
4 | 4 |
|
5 | 5 | import torch
|
6 | 6 | from loguru import logger
|
@@ -30,7 +30,7 @@ class HooksMixin(BaseModel):
|
30 | 30 | """
|
31 | 31 |
|
32 | 32 | _HOOKS_DISABLED: ClassVar[bool] = False # attached to global HooksMixin
|
33 |
| - _hooks: List[RemovableHandle] = [] # attached to local subclasses |
| 33 | + _hooks: Set[RemovableHandle] = set() # attached to local subclasses |
34 | 34 |
|
35 | 35 | @classmethod
|
36 | 36 | @contextlib.contextmanager
|
@@ -70,14 +70,22 @@ def wrapped_hook(*args, **kwargs):
|
70 | 70 |
|
71 | 71 | register_function = getattr(target, f"register_{hook_type}_hook")
|
72 | 72 | handle = register_function(wrapped_hook, **kwargs)
|
73 |
| - self._hooks.append(handle) |
| 73 | + self._hooks.add(handle) |
74 | 74 | logger.debug(f"{self} added {handle}")
|
75 | 75 |
|
76 | 76 | return handle
|
77 | 77 |
|
78 |
| - def remove_hooks(self): |
79 |
| - """Remove all hooks belonging to a modifier""" |
80 |
| - for hook in self._hooks: |
| 78 | + def remove_hooks(self, handles: Optional[Set[RemovableHandle]] = None): |
| 79 | + """ |
| 80 | + Removes hooks registered by this modifier |
| 81 | +
|
| 82 | + :param handles: optional list of handles to remove, defaults to all hooks |
| 83 | + registerd by this modifier |
| 84 | + """ |
| 85 | + if handles is None: |
| 86 | + handles = self._hooks |
| 87 | + |
| 88 | + for hook in handles: |
81 | 89 | hook.remove()
|
82 | 90 |
|
83 |
| - self._hooks = [] |
| 91 | + self._hooks -= handles |
0 commit comments