diff --git a/src/llmcompressor/__init__.py b/src/llmcompressor/__init__.py
index 264d434f0..044d449c9 100644
--- a/src/llmcompressor/__init__.py
+++ b/src/llmcompressor/__init__.py
@@ -41,6 +41,5 @@
     create_session,
     finalize,
     initialize,
-    pre_initialize_structure,
     reset_session,
 )
diff --git a/src/llmcompressor/core/__init__.py b/src/llmcompressor/core/__init__.py
index 171c95395..ed4134af7 100644
--- a/src/llmcompressor/core/__init__.py
+++ b/src/llmcompressor/core/__init__.py
@@ -16,7 +16,6 @@
     create_session,
     finalize,
     initialize,
-    pre_initialize_structure,
     reset_session,
 )
 from llmcompressor.core.state import Data, Hardware, ModifiedState, State
@@ -37,7 +36,6 @@
     "create_session",
     "active_session",
     "reset_session",
-    "pre_initialize_structure",
     "initialize",
     "finalize",
     "apply",
diff --git a/src/llmcompressor/core/lifecycle.py b/src/llmcompressor/core/lifecycle.py
index 232d76b83..f41ab15df 100644
--- a/src/llmcompressor/core/lifecycle.py
+++ b/src/llmcompressor/core/lifecycle.py
@@ -38,12 +38,11 @@ class CompressionLifecycle:
     :type event_lifecycle: Optional[EventLifecycle]
     """
 
-    state: Optional[State] = None
+    state: Optional[State] = field(default_factory=State)
     recipe_container: RecipeContainer = field(default_factory=RecipeContainer)
     modifiers: List[StageModifiers] = field(default_factory=list)
     event_lifecycle: Optional[EventLifecycle] = None
 
-    initialized_structure: bool = False
     initialized_: bool = False
     finalized: bool = False
     event_called: bool = False
@@ -64,48 +63,9 @@ def reset(self):
             except Exception as e:
                 logger.warning(f"Exception during finalizing modifier: {e}")
 
-        self.state = None
-        self.recipe_container = RecipeContainer()
-        self.modifiers = []
-        self.event_lifecycle = None
-
-        self.initialized_structure = False
-        self.initialized_ = False
-        self.finalized = False
-        self.event_called = False
+        self.__init__()
 
         logger.info("Compression lifecycle reset")
 
-    def pre_initialize_structure(self, **kwargs) -> List[Any]:
-        """
-        Pre-initialize the structure of the compression lifecycle.
-
-        :param kwargs: Additional arguments to update the state with
-        :return: List of data returned from pre-initialization of modifiers
-        :rtype: List[Any]
-        """
-        logger.debug("Pre-initializing structure")
-        self._check_create_state()
-        extras = self.state.update(**kwargs)
-        extras = self.recipe_container.update(**extras)
-
-        self._check_compile_recipe()
-        mod_data = []
-        for mod in self.modifiers:
-            data = mod.pre_initialize_structure(state=self.state, **extras)
-            logger.debug("Pre-initialized modifier: {}", mod)
-            if data is not None:
-                mod_data.append(data)
-
-        self.initialized_structure = True
-        applied_stage_names = [mod.unique_id for mod in self.modifiers if mod.applied]
-        self.recipe_container.update_applied_stages(applied_stage_names)
-        logger.info(
-            "Compression lifecycle structure pre-initialized for {} modifiers",
-            len(self.modifiers),
-        )
-
-        return mod_data
-
     def initialize(self, **kwargs) -> List[Any]:
         """
         Initialize the compression lifecycle.
@@ -115,7 +75,6 @@ def initialize(self, **kwargs) -> List[Any]:
         :rtype: List[Any]
         """
         logger.debug("Initializing compression lifecycle")
-        self._check_create_state()
         extras = self.state.update(**kwargs)
         extras = self.recipe_container.update(**extras)
 
@@ -229,14 +188,6 @@ def event(self, event_type: EventType, **kwargs) -> List[Any]:
 
         return mod_data
 
-    def _check_create_state(self):
-        if self.state is not None:
-            return
-
-        logger.debug("Creating new State instance for compression lifecycle")
-        self.state = State()
-        logger.info("State created for compression lifecycle")
-
     def _check_compile_recipe(self):
         if not self.recipe_container.check_compile_recipe():
             return
diff --git a/src/llmcompressor/core/session.py b/src/llmcompressor/core/session.py
index 7c489f36f..07eb2dc57 100644
--- a/src/llmcompressor/core/session.py
+++ b/src/llmcompressor/core/session.py
@@ -65,45 +65,6 @@ def state(self) -> State:
         """
         return self._lifecycle.state
 
-    def pre_initialize_structure(
-        self,
-        model: Any,
-        recipe: Union[str, List[str], Recipe, List[Recipe], None] = None,
-        recipe_stage: Union[str, List[str], None] = None,
-        recipe_args: Union[Dict[str, Any], List[Dict[str, Any]], None] = None,
-        **kwargs,
-    ) -> ModifiedState:
-        """
-        A method to pre-initialize the structure of the model for compression.
-        This will run the pre-initialize structure method for each modifier in the
-        session's lifecycle. This will also set the session's state to the
-        pre-initialized state. Takes care of cases when the model(s) structure
-        has been previously modified by a modifier.
-
-        :param model: the model to pre-initialize the structure for
-        :param recipe: the recipe to use for the compression, can be a path to a
-            recipe file, a raw recipe string, a recipe object, or a list
-            of recipe objects.
-        :param recipe_stage: the stage to use for the compression
-        :param recipe_args: the args to use for overriding the recipe defaults
-        :return: A ModifiedState instance holding the modified model and modifier_data
-            after pre-initializing the structure
-        """
-        mod_data = self._lifecycle.pre_initialize_structure(
-            model=model,
-            recipe=recipe,
-            recipe_stage=recipe_stage,
-            recipe_args=recipe_args,
-            **kwargs,
-        )
-
-        return ModifiedState(
-            model=self.state.model,
-            optimizer=None,
-            loss=None,
-            modifier_data=mod_data,
-        )
-
     def initialize(
         self,
         recipe: Union[str, List[str], "Recipe", List["Recipe"], None] = None,
diff --git a/src/llmcompressor/core/session_functions.py b/src/llmcompressor/core/session_functions.py
index 9a123a030..0dc522c7c 100644
--- a/src/llmcompressor/core/session_functions.py
+++ b/src/llmcompressor/core/session_functions.py
@@ -11,7 +11,6 @@
     "create_session",
    "active_session",
     "reset_session",
-    "pre_initialize_structure",
     "initialize",
     "finalize",
     "apply",
@@ -60,16 +59,6 @@ def reset_session():
     session._lifecycle.reset()
 
 
-def pre_initialize_structure(**kwargs):
-    """
-    A method to pre-initialize the structure of the model for the active session
-
-    :param kwargs: the kwargs to pass to the active session's pre-initialize-structure
-        method
-    """
-    active_session().pre_initialize_structure(**kwargs)
-
-
 def initialize(
     recipe: Union[str, List[str], "Recipe", List["Recipe"], None] = None,
     recipe_stage: Union[str, List[str], None] = None,
diff --git a/src/llmcompressor/modifiers/interface.py b/src/llmcompressor/modifiers/interface.py
index e3a3786b4..f1c73c54b 100644
--- a/src/llmcompressor/modifiers/interface.py
+++ b/src/llmcompressor/modifiers/interface.py
@@ -11,15 +11,6 @@ class ModifierInterface(ABC):
     Defines the contract that all modifiers must implement
     """
 
-    @property
-    @abstractmethod
-    def initialized_structure(self) -> bool:
-        """
-        :return: True if the modifier structure has been
-            applied to the model
-        """
-        raise NotImplementedError()
-
     @property
     @abstractmethod
     def initialized(self) -> bool:
@@ -58,16 +49,6 @@ def calculate_end(self) -> float:
         """
         raise NotImplementedError()
 
-    @abstractmethod
-    def pre_initialize_structure(self, state: State, **kwargs):
-        """
-        Apply the modifier structure to the model
-
-        :param state: The current state of the model
-        :param kwargs: Additional arguments for the modifier
-        """
-        raise NotImplementedError()
-
     @abstractmethod
     def initialize(self, state: State, **kwargs):
         """
diff --git a/src/llmcompressor/modifiers/modifier.py b/src/llmcompressor/modifiers/modifier.py
index 65b4a4029..4092cc3de 100644
--- a/src/llmcompressor/modifiers/modifier.py
+++ b/src/llmcompressor/modifiers/modifier.py
@@ -29,20 +29,11 @@ class Modifier(ModifierInterface, HooksMixin):
     end: Optional[float] = None
     update: Optional[float] = None
 
-    initialized_structure_: bool = False
     initialized_: bool = False
     finalized_: bool = False
     started_: bool = False
     ended_: bool = False
 
-    @property
-    def initialized_structure(self) -> bool:
-        """
-        :return: True if the modifier structure has been
-            applied to the model
-        """
-        return self.initialized_structure_
-
     @property
     def initialized(self) -> bool:
         """
@@ -78,15 +69,6 @@ def calculate_end(self) -> float:
         """
         return self.end if self.end is not None else -1
 
-    def pre_initialize_structure(self, state: State, **kwargs):
-        """
-        :param state: The current state of the model
-        :param kwargs: Additional arguments for initializing the structure
-            of the model in question
-        """
-        self.on_initialize_structure(state, **kwargs)
-        self.initialized_structure_ = True
-
     def initialize(self, state: State, **kwargs):
         """
         Initialize the modifier for the given model and state.
@@ -221,19 +203,6 @@ def should_end(self, event: Event):
 
         return self.end is not None and current >= self.end
 
-    def on_initialize_structure(self, state: State, **kwargs):
-        """
-        on_initialize_structure is called before the model is initialized
-        with the modifier structure.
-
-        TODO: Depreciate this function as part of the lifecycle
-
-        :param state: The current state of the model
-        :param kwargs: Additional arguments for initializing the structure
-            of the model in question
-        """
-        pass
-
     @abstractmethod
     def on_initialize(self, state: State, **kwargs) -> bool:
         """
diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py
index 65e1c90e0..b68be2c59 100644
--- a/src/llmcompressor/modifiers/quantization/gptq/base.py
+++ b/src/llmcompressor/modifiers/quantization/gptq/base.py
@@ -62,9 +62,8 @@ class GPTQModifier(Modifier, HooksMixin):
     |            actorder: False
 
     Lifecycle:
-        - on_initialize_structure
-            - _build_quant_modifier
         - on_initialize
+            - _build_quant_modifier
         - register_hook(module, compress_module, "forward")
         - run_sequential / run_layer_sequential / run_basic
             - make_empty_hessian
@@ -141,16 +140,16 @@ def validate_sequential_update(cls, value: bool) -> bool:
 
         return True
 
-    def on_initialize_structure(self, state: State, **kwargs):
+    def _maybe_build_quant_modifier(self, model: torch.nn.Module):
         """
         Check the model's quantization state matches that expected by this
         modifier, adding a default quantization scheme if needed
 
-        TODO: Depreciate and fold into `on_initialize`
+        # TODO: build modifier during recipe validation
 
         :param state: session state storing input model and calibration data
         """
-        quantization_already_active = qat_active(state.model)
+        quantization_already_active = qat_active(model)
         if isinstance(self.quantize, bool):
             if not self.quantize and quantization_already_active:
                 logger.warning(
@@ -191,18 +190,15 @@ def on_initialize_structure(self, state: State, **kwargs):
             self._build_quant_modifier_from_dict(self.quantize)
         self.quantize = True
 
-        if self._quantization_modifier:
-            self._quantization_modifier.on_initialize_structure(state, **kwargs)
-
     def on_initialize(self, state: State, **kwargs) -> bool:
         """
         Initialize and run the GPTQ algorithm on the current state
 
         :param state: session state storing input model and calibration data
         """
-        # initialize quantization modifier
-        if not self.initialized_structure_:
-            self.on_initialize_structure(state, **kwargs)
+        # build quantization modifier
+        self._maybe_build_quant_modifier(state.model)
+
         if self._quantization_modifier:
             self._quantization_modifier.initialize(state, **kwargs)
         if not self.quantize:
diff --git a/src/llmcompressor/modifiers/stage.py b/src/llmcompressor/modifiers/stage.py
index 7e63245b6..fe773bcb5 100644
--- a/src/llmcompressor/modifiers/stage.py
+++ b/src/llmcompressor/modifiers/stage.py
@@ -19,7 +19,7 @@ class StageModifiers(ModifierInterface, BaseModel):
     :param index: The index of the stage, if applicable
     :param group: The group name of the stage, if applicable
     :param applied: Flag for indicating if this stage has has already been
-        applied to the model, through structure initialization or finalization
+        applied to the model through finalization
     """
 
     modifiers: List["Modifier"] = Field(default_factory=list)
@@ -27,14 +27,6 @@ class StageModifiers(ModifierInterface, BaseModel):
     group: Optional[str] = None
     applied: bool = False
 
-    @property
-    def initialized_structure(self) -> bool:
-        """
-        :return: True if any of the stage modifiers have initialized structure,
-            False otherwise
-        """
-        return any(mod.initialized_structure for mod in self.modifiers)
-
     @property
     def initialized(self) -> bool:
         """
@@ -93,20 +85,6 @@ def calculate_end(self) -> float:
         """
         return max(mod.calculate_end() for mod in self.modifiers)
 
-    def pre_initialize_structure(self, state: "State", **kwargs):
-        """
-        Pre initialize the structure for all stage modifiers mark the stage applied
-
-        :param state: The current state of the training
-        :param kwargs: Additional kwargs to pass to the modifier(s)
-            pre_initialize_structure method
-        """
-        for modifier in self.modifiers:
-            modifier.pre_initialize_structure(state, **kwargs)
-
-        self.applied = True
-        state.loggers.system.info(tag="stage", string="Model structure initialized")
-
     def initialize(self, state: "State", **kwargs):
         """
         Initialize all the stage modifiers
diff --git a/src/llmcompressor/pytorch/model_load/helpers.py b/src/llmcompressor/pytorch/model_load/helpers.py
index 5ddc7ebd5..ce6b186f0 100644
--- a/src/llmcompressor/pytorch/model_load/helpers.py
+++ b/src/llmcompressor/pytorch/model_load/helpers.py
@@ -7,7 +7,7 @@
 from safetensors import safe_open
 from torch.nn import Module
 
-from llmcompressor.core import active_session, create_session, pre_initialize_structure
+from llmcompressor.core import active_session, create_session
 from llmcompressor.typing import Processor
 
 COMPLETED_STAGES_FILENAME = "completed_stages.json"
@@ -33,7 +33,6 @@ def initialize_recipe(model: Module, recipe_path: str):
     """
     if not active_session():
         create_session()
-    pre_initialize_structure(model=model, recipe=recipe_path)
 
     # no need to reload if no recipe was applied
     if recipe_path is None:
diff --git a/src/llmcompressor/transformers/finetune/runner.py b/src/llmcompressor/transformers/finetune/runner.py
index 769b84248..1a8b01a3d 100644
--- a/src/llmcompressor/transformers/finetune/runner.py
+++ b/src/llmcompressor/transformers/finetune/runner.py
@@ -238,12 +238,6 @@ def run_sequential_stages(self, checkpoint: Optional[str] = None):
                     "the stage name."
                 )
 
-            # just load structure if stage has already applied
-            if stage_name in completed_stages:
-                self.trainer.initialize_structure(stage=stage)
-                self.trainer.accelerator.wait_for_everyone()
-                continue
-
             # setup checkpoint dir, TODO: this should be optional
             self._output_dir = os.path.join(
                 self.parent_output_dir, "stage_" + stage_name
diff --git a/src/llmcompressor/transformers/finetune/session_mixin.py b/src/llmcompressor/transformers/finetune/session_mixin.py
index a44b08b4f..9b812ff27 100644
--- a/src/llmcompressor/transformers/finetune/session_mixin.py
+++ b/src/llmcompressor/transformers/finetune/session_mixin.py
@@ -18,7 +18,6 @@
     create_session,
     finalize,
     initialize,
-    pre_initialize_structure,
 )
 from llmcompressor.metrics import LoggerManager
 from llmcompressor.modifiers.distillation.utils.pytorch.model_wrapper import (
@@ -167,26 +166,6 @@ def initialize_session(
 
         torch.cuda.empty_cache()
 
-    def initialize_structure(self, stage: Optional[str] = None):
-        """
-        Initialize any recipe structural changes such as quantization on the model,
-        return immediately if session has already been initialized
-
-        :param stage: Optional stage of recipe to run, or None to run all stages
-        """
-        session = active_session()
-        if session.lifecycle.initialized_:
-            return False
-
-        pre_initialize_structure(
-            model=self.model,
-            recipe=self.recipe,
-            recipe_stage=stage,
-            recipe_args=self.recipe_args,
-        )
-        logger.info(f"Initialized LLM Compressor structure from recipe {self.recipe}")
-        torch.cuda.empty_cache()
-
     def finalize_session(self):
         """
         Wrap up training by finalizing all modifiers initialized in the current session
@@ -399,14 +378,12 @@ def train(self, *args, stage: Optional[str] = None, **kwargs):
     def evaluate(self, *args, **kwargs):
         """
         Run a sparsification evaluation cycle.
-        Runs initialize_structure for the sparse session before calling
-        super().evaluate() and finalization of the session after.
 
         :param args: positional args to pass to super().evaluate()
         :param kwargs: keyword args to pass to super().evaluate()
         :return: the output from super.evaluate()
         """
-        self.initialize_structure()
+        # TODO remove
 
         output = super().evaluate(*args, **kwargs)
         self.finalize_session()
@@ -416,14 +393,13 @@
 
     def predict(self, *args, **kwargs):
         """
         Run a sparsification prediction cycle.
-        Runs initialize_structure for the sparse session before calling
-        super().predict() and finalization of the session after.
 
         :param args: positional args to pass to super().predict()
         :param kwargs: keyword args to pass to super().predict()
         :return: the output from super.predict()
         """
-        self.initialize_structure()
+        # TODO remove
+
         output = super().predict(*args, **kwargs)
         self.finalize_session()
diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py
index d79d8cbbe..9f4f81685 100644
--- a/src/llmcompressor/transformers/finetune/text_generation.py
+++ b/src/llmcompressor/transformers/finetune/text_generation.py
@@ -39,10 +39,9 @@
     RecipeArguments,
     TrainingArguments,
 )
-from llmcompressor.core import pre_initialize_structure, reset_session
+from llmcompressor.core import reset_session
 from llmcompressor.pytorch.model_load.helpers import (
     fallback_to_cpu,
-    get_session_model,
     initialize_recipe,
     parse_dtype,
 )
@@ -384,8 +383,6 @@ def main(
     if isinstance(processor, str) or processor is None:
         processor = initialize_processor_from_path(model_args, model, teacher)
 
-    pre_initialize_structure(model=model)
-
     # initialize session manager
     initialize_recipe(model, None)
 
@@ -403,7 +400,7 @@ def main(
     calib_dataset = stage_runner.get_dataset_split("calibration")
 
     trainer = Trainer(
-        model_init=get_session_model,
+        model_init=lambda: model,
         teacher=teacher,
         recipe=recipe_args.recipe,
         recipe_args=recipe_args.recipe_args,
diff --git a/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py b/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py
index c1f0cb425..6d50bfff6 100644
--- a/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py
+++ b/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py
@@ -71,7 +71,7 @@ def test_create_default_quant_modifier(self):
         assert modifier._quantization_modifier is None
 
         testing_harness = LifecyleTestingHarness(model=LinearNet())
-        modifier.on_initialize_structure(testing_harness.get_state())
+        modifier._maybe_build_quant_modifier(testing_harness.get_state().model)
         assert modifier.quantize
         assert isinstance(modifier._quantization_modifier, QuantizationModifier)
         modifier._quantization_modifier.create_init_config()
@@ -105,10 +105,6 @@ def test_set_quant_if_modifer_already_exists(self):
 
         modifier = GPTQModifier(block_size=128)
         assert not modifier._quantization_modifier
-
-        modifier.on_initialize_structure(testing_harness.get_state())
-        # since quantization modifier is already applied, quantization must be set in
-        # GPTQ
         assert modifier.quantize
 
@@ -142,7 +138,7 @@ def test_set_quant_in_gptq(self):
         assert modifier._quantization_modifier is None
 
         testing_harness = LifecyleTestingHarness(model=LinearNet())
-        modifier.on_initialize_structure(testing_harness.get_state())
+        modifier._maybe_build_quant_modifier(testing_harness.get_state().model)
         assert modifier.quantize
         self.assertIsInstance(modifier._quantization_modifier, QuantizationModifier)
diff --git a/tests/llmcompressor/transformers/finetune/test_session_mixin.py b/tests/llmcompressor/transformers/finetune/test_session_mixin.py
index 93bd74cd1..f10c0ed51 100644
--- a/tests/llmcompressor/transformers/finetune/test_session_mixin.py
+++ b/tests/llmcompressor/transformers/finetune/test_session_mixin.py
@@ -64,5 +64,4 @@ def test_mixin_session_init(mixin_trainer):
     mixin_trainer.initialize_session(epoch=0.0, checkpoint=None)
     session = active_session()
 
-    assert not session.lifecycle.initialized_structure
     assert session.lifecycle.initialized_
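Usage sketch (not part of the patch above): with pre_initialize_structure removed, recipe structure is applied inside the single initialize call. This is a minimal, hedged example built only from the llmcompressor.core helpers that remain exported in this diff; the model and recipe values are illustrative placeholders, and the initialize keyword names are assumed from the call sites shown above.

    import torch

    from llmcompressor.core import active_session, create_session, finalize, initialize

    # Illustrative stand-ins (not from this patch): any torch.nn.Module works here,
    # and recipe may be a path, raw YAML string, or Recipe object per the removed docstring.
    model = torch.nn.Linear(8, 8)
    recipe = None

    # Mirrors the guard used in pytorch/model_load/helpers.py above.
    if not active_session():
        create_session()

    # One-phase setup: structure changes that previously needed
    # pre_initialize_structure now happen inside initialize().
    initialize(model=model, recipe=recipe)
    assert active_session().lifecycle.initialized_

    # ... calibration / training / evaluation would run here ...

    finalize()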