Skip to content

Commit

Permalink
remove pre_initialize_structure
Browse files Browse the repository at this point in the history
Signed-off-by: Kyle Sayers <[email protected]>
  • Loading branch information
kylesayrs committed Feb 17, 2025
1 parent c2db397 commit 9b3e216
Show file tree
Hide file tree
Showing 15 changed files with 18 additions and 235 deletions.
1 change: 0 additions & 1 deletion src/llmcompressor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,5 @@
create_session,
finalize,
initialize,
pre_initialize_structure,
reset_session,
)
2 changes: 0 additions & 2 deletions src/llmcompressor/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
create_session,
finalize,
initialize,
pre_initialize_structure,
reset_session,
)
from llmcompressor.core.state import Data, Hardware, ModifiedState, State
Expand All @@ -37,7 +36,6 @@
"create_session",
"active_session",
"reset_session",
"pre_initialize_structure",
"initialize",
"finalize",
"apply",
Expand Down
53 changes: 2 additions & 51 deletions src/llmcompressor/core/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,11 @@ class CompressionLifecycle:
:type event_lifecycle: Optional[EventLifecycle]
"""

state: Optional[State] = None
state: Optional[State] = field(default_factory=State)
recipe_container: RecipeContainer = field(default_factory=RecipeContainer)
modifiers: List[StageModifiers] = field(default_factory=list)
event_lifecycle: Optional[EventLifecycle] = None

initialized_structure: bool = False
initialized_: bool = False
finalized: bool = False
event_called: bool = False
Expand All @@ -64,48 +63,9 @@ def reset(self):
except Exception as e:
logger.warning(f"Exception during finalizing modifier: {e}")

self.state = None
self.recipe_container = RecipeContainer()
self.modifiers = []
self.event_lifecycle = None

self.initialized_structure = False
self.initialized_ = False
self.finalized = False
self.event_called = False
self.__init__()
logger.info("Compression lifecycle reset")

def pre_initialize_structure(self, **kwargs) -> List[Any]:
"""
Pre-initialize the structure of the compression lifecycle.
:param kwargs: Additional arguments to update the state with
:return: List of data returned from pre-initialization of modifiers
:rtype: List[Any]
"""
logger.debug("Pre-initializing structure")
self._check_create_state()
extras = self.state.update(**kwargs)
extras = self.recipe_container.update(**extras)

self._check_compile_recipe()
mod_data = []
for mod in self.modifiers:
data = mod.pre_initialize_structure(state=self.state, **extras)
logger.debug("Pre-initialized modifier: {}", mod)
if data is not None:
mod_data.append(data)

self.initialized_structure = True
applied_stage_names = [mod.unique_id for mod in self.modifiers if mod.applied]
self.recipe_container.update_applied_stages(applied_stage_names)
logger.info(
"Compression lifecycle structure pre-initialized for {} modifiers",
len(self.modifiers),
)

return mod_data

def initialize(self, **kwargs) -> List[Any]:
"""
Initialize the compression lifecycle.
Expand All @@ -115,7 +75,6 @@ def initialize(self, **kwargs) -> List[Any]:
:rtype: List[Any]
"""
logger.debug("Initializing compression lifecycle")
self._check_create_state()
extras = self.state.update(**kwargs)
extras = self.recipe_container.update(**extras)

Expand Down Expand Up @@ -229,14 +188,6 @@ def event(self, event_type: EventType, **kwargs) -> List[Any]:

return mod_data

def _check_create_state(self):
if self.state is not None:
return

logger.debug("Creating new State instance for compression lifecycle")
self.state = State()
logger.info("State created for compression lifecycle")

def _check_compile_recipe(self):
if not self.recipe_container.check_compile_recipe():
return
Expand Down
39 changes: 0 additions & 39 deletions src/llmcompressor/core/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,45 +65,6 @@ def state(self) -> State:
"""
return self._lifecycle.state

def pre_initialize_structure(
self,
model: Any,
recipe: Union[str, List[str], Recipe, List[Recipe], None] = None,
recipe_stage: Union[str, List[str], None] = None,
recipe_args: Union[Dict[str, Any], List[Dict[str, Any]], None] = None,
**kwargs,
) -> ModifiedState:
"""
A method to pre-initialize the structure of the model for compression.
This will run the pre-initialize structure method for each modifier in the
session's lifecycle. This will also set the session's state to the
pre-initialized state. Takes care of cases when the model(s) structure
has been previously modified by a modifier.
:param model: the model to pre-initialize the structure for
:param recipe: the recipe to use for the compression, can be a path to a
recipe file, a raw recipe string, a recipe object, or a list
of recipe objects.
:param recipe_stage: the stage to use for the compression
:param recipe_args: the args to use for overriding the recipe defaults
:return: A ModifiedState instance holding the modified model and modifier_data
after pre-initializing the structure
"""
mod_data = self._lifecycle.pre_initialize_structure(
model=model,
recipe=recipe,
recipe_stage=recipe_stage,
recipe_args=recipe_args,
**kwargs,
)

return ModifiedState(
model=self.state.model,
optimizer=None,
loss=None,
modifier_data=mod_data,
)

def initialize(
self,
recipe: Union[str, List[str], "Recipe", List["Recipe"], None] = None,
Expand Down
11 changes: 0 additions & 11 deletions src/llmcompressor/core/session_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
"create_session",
"active_session",
"reset_session",
"pre_initialize_structure",
"initialize",
"finalize",
"apply",
Expand Down Expand Up @@ -60,16 +59,6 @@ def reset_session():
session._lifecycle.reset()


def pre_initialize_structure(**kwargs):
"""
A method to pre-initialize the structure of the model for the active session
:param kwargs: the kwargs to pass to the active session's pre-initialize-structure
method
"""
active_session().pre_initialize_structure(**kwargs)


def initialize(
recipe: Union[str, List[str], "Recipe", List["Recipe"], None] = None,
recipe_stage: Union[str, List[str], None] = None,
Expand Down
19 changes: 0 additions & 19 deletions src/llmcompressor/modifiers/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,6 @@ class ModifierInterface(ABC):
Defines the contract that all modifiers must implement
"""

@property
@abstractmethod
def initialized_structure(self) -> bool:
"""
:return: True if the modifier structure has been
applied to the model
"""
raise NotImplementedError()

@property
@abstractmethod
def initialized(self) -> bool:
Expand Down Expand Up @@ -58,16 +49,6 @@ def calculate_end(self) -> float:
"""
raise NotImplementedError()

@abstractmethod
def pre_initialize_structure(self, state: State, **kwargs):
"""
Apply the modifier structure to the model
:param state: The current state of the model
:param kwargs: Additional arguments for the modifier
"""
raise NotImplementedError()

@abstractmethod
def initialize(self, state: State, **kwargs):
"""
Expand Down
31 changes: 0 additions & 31 deletions src/llmcompressor/modifiers/modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,11 @@ class Modifier(ModifierInterface, HooksMixin):
end: Optional[float] = None
update: Optional[float] = None

initialized_structure_: bool = False
initialized_: bool = False
finalized_: bool = False
started_: bool = False
ended_: bool = False

@property
def initialized_structure(self) -> bool:
"""
:return: True if the modifier structure has been
applied to the model
"""
return self.initialized_structure_

@property
def initialized(self) -> bool:
"""
Expand Down Expand Up @@ -78,15 +69,6 @@ def calculate_end(self) -> float:
"""
return self.end if self.end is not None else -1

def pre_initialize_structure(self, state: State, **kwargs):
"""
:param state: The current state of the model
:param kwargs: Additional arguments for initializing the structure
of the model in question
"""
self.on_initialize_structure(state, **kwargs)
self.initialized_structure_ = True

def initialize(self, state: State, **kwargs):
"""
Initialize the modifier for the given model and state.
Expand Down Expand Up @@ -221,19 +203,6 @@ def should_end(self, event: Event):

return self.end is not None and current >= self.end

def on_initialize_structure(self, state: State, **kwargs):
"""
on_initialize_structure is called before the model is initialized
with the modifier structure.
TODO: Deprecate this function as part of the lifecycle
:param state: The current state of the model
:param kwargs: Additional arguments for initializing the structure
of the model in question
"""
pass

@abstractmethod
def on_initialize(self, state: State, **kwargs) -> bool:
"""
Expand Down
18 changes: 7 additions & 11 deletions src/llmcompressor/modifiers/quantization/gptq/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,8 @@ class GPTQModifier(Modifier, HooksMixin):
| actorder: False
Lifecycle:
- on_initialize_structure
- _build_quant_modifier
- on_initialize
- _build_quant_modifier
- register_hook(module, compress_module, "forward")
- run_sequential / run_layer_sequential / run_basic
- make_empty_hessian
Expand Down Expand Up @@ -141,16 +140,16 @@ def validate_sequential_update(cls, value: bool) -> bool:

return True

def on_initialize_structure(self, state: State, **kwargs):
def _maybe_build_quant_modifier(self, model: torch.nn.Module):
"""
Check the model's quantization state matches that expected by this modifier,
adding a default quantization scheme if needed
TODO: Deprecate and fold into `on_initialize`
# TODO: build modifier during recipe validation
:param state: session state storing input model and calibration data
"""
quantization_already_active = qat_active(state.model)
quantization_already_active = qat_active(model)
if isinstance(self.quantize, bool):
if not self.quantize and quantization_already_active:
logger.warning(
Expand Down Expand Up @@ -191,18 +190,15 @@ def on_initialize_structure(self, state: State, **kwargs):
self._build_quant_modifier_from_dict(self.quantize)
self.quantize = True

if self._quantization_modifier:
self._quantization_modifier.on_initialize_structure(state, **kwargs)

def on_initialize(self, state: State, **kwargs) -> bool:
"""
Initialize and run the GPTQ algorithm on the current state
:param state: session state storing input model and calibration data
"""
# initialize quantization modifier
if not self.initialized_structure_:
self.on_initialize_structure(state, **kwargs)
# build quantization modifier
self._maybe_build_quant_modifier(state.model)

if self._quantization_modifier:
self._quantization_modifier.initialize(state, **kwargs)
if not self.quantize:
Expand Down
24 changes: 1 addition & 23 deletions src/llmcompressor/modifiers/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,14 @@ class StageModifiers(ModifierInterface, BaseModel):
:param index: The index of the stage, if applicable
:param group: The group name of the stage, if applicable
:param applied: Flag for indicating if this stage has already been
applied to the model, through structure initialization or finalization
applied to the model through finalization
"""

modifiers: List["Modifier"] = Field(default_factory=list)
index: Optional[int] = None
group: Optional[str] = None
applied: bool = False

@property
def initialized_structure(self) -> bool:
"""
:return: True if any of the stage modifiers have initialized structure,
False otherwise
"""
return any(mod.initialized_structure for mod in self.modifiers)

@property
def initialized(self) -> bool:
"""
Expand Down Expand Up @@ -93,20 +85,6 @@ def calculate_end(self) -> float:
"""
return max(mod.calculate_end() for mod in self.modifiers)

def pre_initialize_structure(self, state: "State", **kwargs):
"""
Pre-initialize the structure for all stage modifiers and mark the stage applied
:param state: The current state of the training
:param kwargs: Additional kwargs to pass to the modifier(s)
pre_initialize_structure method
"""
for modifier in self.modifiers:
modifier.pre_initialize_structure(state, **kwargs)

self.applied = True
state.loggers.system.info(tag="stage", string="Model structure initialized")

def initialize(self, state: "State", **kwargs):
"""
Initialize all the stage modifiers
Expand Down
3 changes: 1 addition & 2 deletions src/llmcompressor/pytorch/model_load/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from safetensors import safe_open
from torch.nn import Module

from llmcompressor.core import active_session, create_session, pre_initialize_structure
from llmcompressor.core import active_session, create_session
from llmcompressor.typing import Processor

COMPLETED_STAGES_FILENAME = "completed_stages.json"
Expand All @@ -33,7 +33,6 @@ def initialize_recipe(model: Module, recipe_path: str):
"""
if not active_session():
create_session()
pre_initialize_structure(model=model, recipe=recipe_path)

# no need to reload if no recipe was applied
if recipe_path is None:
Expand Down
6 changes: 0 additions & 6 deletions src/llmcompressor/transformers/finetune/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,12 +238,6 @@ def run_sequential_stages(self, checkpoint: Optional[str] = None):
"the stage name."
)

# just load structure if stage has already applied
if stage_name in completed_stages:
self.trainer.initialize_structure(stage=stage)
self.trainer.accelerator.wait_for_everyone()
continue

# setup checkpoint dir, TODO: this should be optional
self._output_dir = os.path.join(
self.parent_output_dir, "stage_" + stage_name
Expand Down
Loading

0 comments on commit 9b3e216

Please sign in to comment.