Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Callbacks] Remove pre_initialize_structure #1160

Merged
merged 29 commits into from
Feb 26, 2025
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
9b3e216
remove pre_initialize_structure
kylesayrs Feb 17, 2025
3bff7d1
remove preinit event
kylesayrs Feb 17, 2025
14e47a5
remove order test
kylesayrs Feb 17, 2025
c6a9e6b
Merge branch 'main' into kylesayrs/remove-preinitialize-structure
kylesayrs Feb 17, 2025
6b882bb
consolodate saving
kylesayrs Feb 18, 2025
bb35a74
typos
kylesayrs Feb 18, 2025
71903ff
add todos
kylesayrs Feb 18, 2025
d39d375
dreggs, style
kylesayrs Feb 18, 2025
7cc5a6d
typo
kylesayrs Feb 18, 2025
9865fa3
adjust typehint
kylesayrs Feb 18, 2025
68ce624
allow prepending
kylesayrs Feb 18, 2025
5b7cc03
check saved recipe contents
kylesayrs Feb 18, 2025
bdc4fa5
consolidate saving paths
kylesayrs Feb 18, 2025
a83b0aa
remove broken import
kylesayrs Feb 18, 2025
4efd116
Merge remote-tracking branch 'origin' into kylesayrs/consolidate-saving
kylesayrs Feb 18, 2025
b9f0bd1
add back def
kylesayrs Feb 18, 2025
29ab794
Merge remote-tracking branch 'origin' into kylesayrs/remove-preinitia…
kylesayrs Feb 18, 2025
0a2642b
save state
kylesayrs Feb 18, 2025
0c70881
Merge branch 'kylesayrs/consolidate-saving' into kylesayrs/remove-pre…
kylesayrs Feb 18, 2025
60371ef
remove verbose messages
kylesayrs Feb 18, 2025
bf9a8cd
fix double initialization
kylesayrs Feb 18, 2025
3d64d57
rename function
kylesayrs Feb 25, 2025
55984c4
remove accidentally added files
kylesayrs Feb 25, 2025
05fa5f6
Merge remote-tracking branch 'origin' into kylesayrs/remove-preinitia…
kylesayrs Feb 25, 2025
53d762e
Merge remote-tracking branch 'origin' into kylesayrs/remove-preinitia…
kylesayrs Feb 25, 2025
e026658
add debug statement
kylesayrs Feb 25, 2025
00961a0
pass model to stage runner
kylesayrs Feb 25, 2025
c7678b0
remove breakpoint
kylesayrs Feb 26, 2025
049e3cc
Merge branch 'main' into kylesayrs/remove-preinitialize-structure
dsikka Feb 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/llmcompressor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,5 @@
create_session,
finalize,
initialize,
pre_initialize_structure,
reset_session,
)
2 changes: 0 additions & 2 deletions src/llmcompressor/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
create_session,
finalize,
initialize,
pre_initialize_structure,
reset_session,
)
from llmcompressor.core.state import Data, Hardware, ModifiedState, State
Expand All @@ -37,7 +36,6 @@
"create_session",
"active_session",
"reset_session",
"pre_initialize_structure",
"initialize",
"finalize",
"apply",
Expand Down
31 changes: 0 additions & 31 deletions src/llmcompressor/core/events/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ class EventType(Enum):
The purpose of each EventType is to trigger the corresponding
modifier callback during training or post training pipelines.

:param PRE_INIT: Event type for pre-initialization.
:param INITIALIZE: Event type for initialization.
:param FINALIZE: Event type for finalization.
:param BATCH_START: Event type for the start of a batch.
Expand All @@ -38,7 +37,6 @@ class EventType(Enum):
"""

# training lifecycle
PRE_INIT = "pre_init"
INITIALIZE = "initialize"
FINALIZE = "finalize"

Expand All @@ -51,35 +49,6 @@ class EventType(Enum):
OPTIM_PRE_STEP = "optim_pre_step"
OPTIM_POST_STEP = "optim_post_step"

def order(self) -> int:
"""
Returns the priority order of the current EventType.
Lower values have higher priority.

:raises ValueError: if the event type is invalid.
:return: The order of the event type, lower has higher priority.
:rtype: int
"""
if self == EventType.PRE_INIT:
return 0
elif self == EventType.INITIALIZE:
return 10
elif self == EventType.FINALIZE:
return 20
elif self == EventType.BATCH_START:
return 100
elif self == EventType.LOSS_CALCULATED:
return 110
elif self == EventType.OPTIM_PRE_STEP:
return 120
elif self == EventType.OPTIM_POST_STEP:
return 130
elif self == EventType.BATCH_END:
return 140
else:
logger.error("Invalid event type: {}", self)
raise ValueError(f"Invalid event type {self}")


@dataclass
class Event:
Expand Down
99 changes: 24 additions & 75 deletions src/llmcompressor/core/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@
)
from llmcompressor.core.state import State
from llmcompressor.modifiers import StageModifiers
from llmcompressor.recipe import RecipeContainer
from llmcompressor.recipe import (
RecipeArgsInput,
RecipeContainer,
RecipeInput,
RecipeStageInput,
)

__all__ = ["CompressionLifecycle"]

Expand All @@ -38,12 +43,11 @@ class CompressionLifecycle:
:type event_lifecycle: Optional[EventLifecycle]
"""

state: Optional[State] = None
state: State = field(default_factory=State)
recipe_container: RecipeContainer = field(default_factory=RecipeContainer)
modifiers: List[StageModifiers] = field(default_factory=list)
event_lifecycle: Optional[EventLifecycle] = None

initialized_structure: bool = False
initialized_: bool = False
finalized: bool = False
event_called: bool = False
Expand All @@ -64,66 +68,35 @@ def reset(self):
except Exception as e:
logger.warning(f"Exception during finalizing modifier: {e}")

self.state = None
self.recipe_container = RecipeContainer()
self.modifiers = []
self.event_lifecycle = None

self.initialized_structure = False
self.initialized_ = False
self.finalized = False
self.event_called = False
self.__init__()
logger.info("Compression lifecycle reset")

def pre_initialize_structure(self, **kwargs) -> List[Any]:
"""
Pre-initialize the structure of the compression lifecycle.

:param kwargs: Additional arguments to update the state with
:return: List of data returned from pre-initialization of modifiers
:rtype: List[Any]
"""
logger.debug("Pre-initializing structure")
self._check_create_state()
extras = self.state.update(**kwargs)
extras = self.recipe_container.update(**extras)

self._check_compile_recipe()
mod_data = []
for mod in self.modifiers:
data = mod.pre_initialize_structure(state=self.state, **extras)
logger.debug("Pre-initialized modifier: {}", mod)
if data is not None:
mod_data.append(data)

self.initialized_structure = True
applied_stage_names = [mod.unique_id for mod in self.modifiers if mod.applied]
self.recipe_container.update_applied_stages(applied_stage_names)
logger.info(
"Compression lifecycle structure pre-initialized for {} modifiers",
len(self.modifiers),
)

return mod_data

def initialize(self, **kwargs) -> List[Any]:
def initialize(
self,
recipe: Optional[RecipeInput] = None,
recipe_stage: Optional[RecipeStageInput] = None,
recipe_args: Optional[RecipeArgsInput] = None,
**kwargs,
) -> List[Any]:
"""
Initialize the compression lifecycle.

:param kwargs: Additional arguments to update the state with
:return: List of data returned from initialization of modifiers
:rtype: List[Any]
"""
logger.debug("Initializing compression lifecycle")
self._check_create_state()
extras = self.state.update(**kwargs)
extras = self.recipe_container.update(**extras)
self.state.update(**kwargs)
if self.initialized_: # TODO: do not initialize twice
return

self._check_compile_recipe()
logger.debug("Initializing compression lifecycle")
self.recipe_container.append(recipe, recipe_stage, recipe_args)
self.modifiers = self.recipe_container.get_modifiers()
self._set_model_layer_prefix()

mod_data = []
for mod in self.modifiers:
data = mod.initialize(state=self.state, **extras)
data = mod.initialize(state=self.state, **kwargs)
logger.debug("Initialized modifier: {}", mod)
if data is not None:
mod_data.append(data)
Expand Down Expand Up @@ -190,7 +163,7 @@ def event(self, event_type: EventType, **kwargs) -> List[Any]:
logger.error("Cannot invoke event after finalizing")
raise ValueError("Cannot invoke event after finalizing")

if event_type in [EventType.PRE_INIT, EventType.INITIALIZE, EventType.FINALIZE]:
if event_type in [EventType.INITIALIZE, EventType.FINALIZE]:
logger.error(
"Cannot invoke {} event. Use the corresponding method instead.",
event_type,
Expand Down Expand Up @@ -229,30 +202,6 @@ def event(self, event_type: EventType, **kwargs) -> List[Any]:

return mod_data

def _check_create_state(self):
if self.state is not None:
return

logger.debug("Creating new State instance for compression lifecycle")
self.state = State()
logger.info("State created for compression lifecycle")

def _check_compile_recipe(self):
if not self.recipe_container.check_compile_recipe():
return

logger.debug(
"Compiling recipe and creating modifiers for compression lifecycle"
)
self.modifiers = self.recipe_container.compiled_recipe.create_modifier()
for mod in self.modifiers:
if mod.unique_id in self.recipe_container.applied_stages:
mod.applied = True
logger.info(
"Recipe compiled and {} modifiers created",
len(self.modifiers),
)

def _check_setup_event_lifecycle(self, event_type: EventType):
if self.event_lifecycle is not None:
return
Expand Down
39 changes: 0 additions & 39 deletions src/llmcompressor/core/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,45 +65,6 @@ def state(self) -> State:
"""
return self._lifecycle.state

def pre_initialize_structure(
self,
model: Any,
recipe: Union[str, List[str], Recipe, List[Recipe], None] = None,
recipe_stage: Union[str, List[str], None] = None,
recipe_args: Union[Dict[str, Any], List[Dict[str, Any]], None] = None,
**kwargs,
) -> ModifiedState:
"""
A method to pre-initialize the structure of the model for compression.
This will run the pre-initialize structure method for each modifier in the
session's lifecycle. This will also set the session's state to the
pre-initialized state. Takes care of cases when the model(s) structure
has been previously modified by a modifier.

:param model: the model to pre-initialize the structure for
:param recipe: the recipe to use for the compression, can be a path to a
recipe file, a raw recipe string, a recipe object, or a list
of recipe objects.
:param recipe_stage: the stage to use for the compression
:param recipe_args: the args to use for overriding the recipe defaults
:return: A ModifiedState instance holding the modified model and modifier_data
after pre-initializing the structure
"""
mod_data = self._lifecycle.pre_initialize_structure(
model=model,
recipe=recipe,
recipe_stage=recipe_stage,
recipe_args=recipe_args,
**kwargs,
)

return ModifiedState(
model=self.state.model,
optimizer=None,
loss=None,
modifier_data=mod_data,
)

def initialize(
self,
recipe: Union[str, List[str], "Recipe", List["Recipe"], None] = None,
Expand Down
13 changes: 1 addition & 12 deletions src/llmcompressor/core/session_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
"create_session",
"active_session",
"reset_session",
"pre_initialize_structure",
"initialize",
"finalize",
"apply",
Expand Down Expand Up @@ -60,16 +59,6 @@ def reset_session():
session._lifecycle.reset()


def pre_initialize_structure(**kwargs):
"""
A method to pre-initialize the structure of the model for the active session

:param kwargs: the kwargs to pass to the active session's pre-initialize-structure
method
"""
active_session().pre_initialize_structure(**kwargs)


def initialize(
recipe: Union[str, List[str], "Recipe", List["Recipe"], None] = None,
recipe_stage: Union[str, List[str], None] = None,
Expand Down Expand Up @@ -213,7 +202,7 @@ def event(cls, event_type: EventType, **kwargs) -> ModifiedState:
:param kwargs: additional kwargs to pass to the current session's event method
:return: the modified state of the active session after invoking the event
"""
if event_type in [EventType.PRE_INIT, EventType.INITIALIZE, EventType.FINALIZE]:
if event_type in [EventType.INITIALIZE, EventType.FINALIZE]:
raise ValueError(
f"Cannot invoke {event_type} event. "
f"Use the corresponding method instead."
Expand Down
19 changes: 0 additions & 19 deletions src/llmcompressor/modifiers/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,6 @@ class ModifierInterface(ABC):
Defines the contract that all modifiers must implement
"""

@property
@abstractmethod
def initialized_structure(self) -> bool:
"""
:return: True if the modifier structure has been
applied to the model
"""
raise NotImplementedError()

@property
@abstractmethod
def initialized(self) -> bool:
Expand Down Expand Up @@ -58,16 +49,6 @@ def calculate_end(self) -> float:
"""
raise NotImplementedError()

@abstractmethod
def pre_initialize_structure(self, state: State, **kwargs):
"""
Apply the modifier structure to the model

:param state: The current state of the model
:param kwargs: Additional arguments for the modifier
"""
raise NotImplementedError()

@abstractmethod
def initialize(self, state: State, **kwargs):
"""
Expand Down
31 changes: 0 additions & 31 deletions src/llmcompressor/modifiers/modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,11 @@ class Modifier(ModifierInterface, HooksMixin):
end: Optional[float] = None
update: Optional[float] = None

initialized_structure_: bool = False
initialized_: bool = False
finalized_: bool = False
started_: bool = False
ended_: bool = False

@property
def initialized_structure(self) -> bool:
"""
:return: True if the modifier structure has been
applied to the model
"""
return self.initialized_structure_

@property
def initialized(self) -> bool:
"""
Expand Down Expand Up @@ -78,15 +69,6 @@ def calculate_end(self) -> float:
"""
return self.end if self.end is not None else -1

def pre_initialize_structure(self, state: State, **kwargs):
"""
:param state: The current state of the model
:param kwargs: Additional arguments for initializing the structure
of the model in question
"""
self.on_initialize_structure(state, **kwargs)
self.initialized_structure_ = True

def initialize(self, state: State, **kwargs):
"""
Initialize the modifier for the given model and state.
Expand Down Expand Up @@ -221,19 +203,6 @@ def should_end(self, event: Event):

return self.end is not None and current >= self.end

def on_initialize_structure(self, state: State, **kwargs):
"""
on_initialize_structure is called before the model is initialized
with the modifier structure.

TODO: Depreciate this function as part of the lifecycle

:param state: The current state of the model
:param kwargs: Additional arguments for initializing the structure
of the model in question
"""
pass

@abstractmethod
def on_initialize(self, state: State, **kwargs) -> bool:
"""
Expand Down
Loading