18
18
)
19
19
from llmcompressor .core .state import State
20
20
from llmcompressor .modifiers import StageModifiers
21
- from llmcompressor .recipe import RecipeContainer
21
+ from llmcompressor .recipe import (
22
+ RecipeArgsInput ,
23
+ RecipeContainer ,
24
+ RecipeInput ,
25
+ RecipeStageInput ,
26
+ )
22
27
23
28
__all__ = ["CompressionLifecycle" ]
24
29
@@ -38,7 +43,7 @@ class CompressionLifecycle:
38
43
:type event_lifecycle: Optional[EventLifecycle]
39
44
"""
40
45
41
- state : Optional [ State ] = None
46
+ state : State = field ( default_factory = State )
42
47
recipe_container : RecipeContainer = field (default_factory = RecipeContainer )
43
48
modifiers : List [StageModifiers ] = field (default_factory = list )
44
49
event_lifecycle : Optional [EventLifecycle ] = None
@@ -62,63 +67,35 @@ def reset(self):
62
67
except Exception as e :
63
68
logger .warning (f"Exception during finalizing modifier: { e } " )
64
69
65
- self .state = None
66
- self .recipe_container = RecipeContainer ()
67
- self .modifiers = []
68
- self .event_lifecycle = None
69
-
70
- self .initialized_ = False
71
- self .finalized = False
70
+ self .__init__ ()
72
71
logger .info ("Compression lifecycle reset" )
73
72
74
- def pre_initialize_structure (self , ** kwargs ) -> List [Any ]:
75
- """
76
- Pre-initialize the structure of the compression lifecycle.
77
-
78
- :param kwargs: Additional arguments to update the state with
79
- :return: List of data returned from pre-initialization of modifiers
80
- :rtype: List[Any]
81
- """
82
- logger .debug ("Pre-initializing structure" )
83
- self ._check_create_state ()
84
- extras = self .state .update (** kwargs )
85
- extras = self .recipe_container .update (** extras )
86
-
87
- self ._check_compile_recipe ()
88
- mod_data = []
89
- for mod in self .modifiers :
90
- data = mod .pre_initialize_structure (state = self .state , ** extras )
91
- logger .debug ("Pre-initialized modifier: {}" , mod )
92
- if data is not None :
93
- mod_data .append (data )
94
-
95
- applied_stage_names = [mod .unique_id for mod in self .modifiers if mod .applied ]
96
- self .recipe_container .update_applied_stages (applied_stage_names )
97
- logger .info (
98
- "Compression lifecycle structure pre-initialized for {} modifiers" ,
99
- len (self .modifiers ),
100
- )
101
-
102
- return mod_data
103
-
104
- def initialize (self , ** kwargs ) -> List [Any ]:
73
+ def initialize (
74
+ self ,
75
+ recipe : Optional [RecipeInput ] = None ,
76
+ recipe_stage : Optional [RecipeStageInput ] = None ,
77
+ recipe_args : Optional [RecipeArgsInput ] = None ,
78
+ ** kwargs ,
79
+ ) -> List [Any ]:
105
80
"""
106
81
Initialize the compression lifecycle.
107
82
108
83
:param kwargs: Additional arguments to update the state with
109
84
:return: List of data returned from initialization of modifiers
110
85
:rtype: List[Any]
111
86
"""
112
- logger .debug ("Initializing compression lifecycle" )
113
- self ._check_create_state ()
114
- extras = self .state .update (** kwargs )
115
- extras = self .recipe_container .update (** extras )
87
+ self .state .update (** kwargs )
88
+ if self .initialized_ : # TODO: do not initialize twice
89
+ return
116
90
117
- self ._check_compile_recipe ()
91
+ logger .debug ("Initializing compression lifecycle" )
92
+ self .recipe_container .append (recipe , recipe_stage , recipe_args )
93
+ self .modifiers = self .recipe_container .get_modifiers ()
118
94
self ._set_model_layer_prefix ()
95
+
119
96
mod_data = []
120
97
for mod in self .modifiers :
121
- data = mod .initialize (state = self .state , ** extras )
98
+ data = mod .initialize (state = self .state , ** kwargs )
122
99
logger .debug ("Initialized modifier: {}" , mod )
123
100
if data is not None :
124
101
mod_data .append (data )
@@ -185,7 +162,7 @@ def event(self, event_type: EventType, **kwargs) -> List[Any]:
185
162
logger .error ("Cannot invoke event after finalizing" )
186
163
raise ValueError ("Cannot invoke event after finalizing" )
187
164
188
- if event_type in [EventType .PRE_INIT , EventType . INITIALIZE , EventType .FINALIZE ]:
165
+ if event_type in [EventType .INITIALIZE , EventType .FINALIZE ]:
189
166
logger .error (
190
167
"Cannot invoke {} event. Use the corresponding method instead." ,
191
168
event_type ,
@@ -223,30 +200,6 @@ def event(self, event_type: EventType, **kwargs) -> List[Any]:
223
200
224
201
return mod_data
225
202
226
- def _check_create_state (self ):
227
- if self .state is not None :
228
- return
229
-
230
- logger .debug ("Creating new State instance for compression lifecycle" )
231
- self .state = State ()
232
- logger .info ("State created for compression lifecycle" )
233
-
234
- def _check_compile_recipe (self ):
235
- if not self .recipe_container .check_compile_recipe ():
236
- return
237
-
238
- logger .debug (
239
- "Compiling recipe and creating modifiers for compression lifecycle"
240
- )
241
- self .modifiers = self .recipe_container .compiled_recipe .create_modifier ()
242
- for mod in self .modifiers :
243
- if mod .unique_id in self .recipe_container .applied_stages :
244
- mod .applied = True
245
- logger .info (
246
- "Recipe compiled and {} modifiers created" ,
247
- len (self .modifiers ),
248
- )
249
-
250
203
def _check_setup_event_lifecycle (self , event_type : EventType ):
251
204
if self .event_lifecycle is not None :
252
205
return
0 commit comments