diff --git a/src/llmcompressor/__init__.py b/src/llmcompressor/__init__.py index 6f174a59e..f979a7453 100644 --- a/src/llmcompressor/__init__.py +++ b/src/llmcompressor/__init__.py @@ -40,7 +40,6 @@ create_session, finalize, initialize, - pre_initialize_structure, reset_session, ) from llmcompressor.entrypoints import Oneshot, oneshot diff --git a/src/llmcompressor/core/__init__.py b/src/llmcompressor/core/__init__.py index 75335164d..47e710943 100644 --- a/src/llmcompressor/core/__init__.py +++ b/src/llmcompressor/core/__init__.py @@ -15,7 +15,6 @@ create_session, finalize, initialize, - pre_initialize_structure, reset_session, ) from llmcompressor.core.state import Data, Hardware, ModifiedState, State @@ -36,7 +35,6 @@ "create_session", "active_session", "reset_session", - "pre_initialize_structure", "initialize", "finalize", "apply", diff --git a/src/llmcompressor/core/events/event.py b/src/llmcompressor/core/events/event.py index 9d5d48d63..89eb780c8 100644 --- a/src/llmcompressor/core/events/event.py +++ b/src/llmcompressor/core/events/event.py @@ -27,7 +27,6 @@ class EventType(Enum): The purpose of each EventType is to trigger the corresponding modifier callback during training or post training pipelines. - :param PRE_INIT: Event type for pre-initialization. :param INITIALIZE: Event type for initialization. :param FINALIZE: Event type for finalization. :param BATCH_START: Event type for the start of a batch. @@ -38,7 +37,6 @@ class EventType(Enum): """ # training lifecycle - PRE_INIT = "pre_init" INITIALIZE = "initialize" FINALIZE = "finalize" @@ -51,35 +49,6 @@ class EventType(Enum): OPTIM_PRE_STEP = "optim_pre_step" OPTIM_POST_STEP = "optim_post_step" - def order(self) -> int: - """ - Returns the priority order of the current EventType. - Lower values have higher priority. - - :raises ValueError: if the event type is invalid. - :return: The order of the event type, lower has higher priority. 
- :rtype: int - """ - if self == EventType.PRE_INIT: - return 0 - elif self == EventType.INITIALIZE: - return 10 - elif self == EventType.FINALIZE: - return 20 - elif self == EventType.BATCH_START: - return 100 - elif self == EventType.LOSS_CALCULATED: - return 110 - elif self == EventType.OPTIM_PRE_STEP: - return 120 - elif self == EventType.OPTIM_POST_STEP: - return 130 - elif self == EventType.BATCH_END: - return 140 - else: - logger.error("Invalid event type: {}", self) - raise ValueError(f"Invalid event type {self}") - @dataclass class Event: diff --git a/src/llmcompressor/core/lifecycle.py b/src/llmcompressor/core/lifecycle.py index b76a57523..e69882800 100644 --- a/src/llmcompressor/core/lifecycle.py +++ b/src/llmcompressor/core/lifecycle.py @@ -18,7 +18,12 @@ ) from llmcompressor.core.state import State from llmcompressor.modifiers import StageModifiers -from llmcompressor.recipe import RecipeContainer +from llmcompressor.recipe import ( + RecipeArgsInput, + RecipeContainer, + RecipeInput, + RecipeStageInput, +) __all__ = ["CompressionLifecycle"] @@ -38,7 +43,7 @@ class CompressionLifecycle: :type event_lifecycle: Optional[EventLifecycle] """ - state: Optional[State] = None + state: State = field(default_factory=State) recipe_container: RecipeContainer = field(default_factory=RecipeContainer) modifiers: List[StageModifiers] = field(default_factory=list) event_lifecycle: Optional[EventLifecycle] = None @@ -62,46 +67,16 @@ def reset(self): except Exception as e: logger.warning(f"Exception during finalizing modifier: {e}") - self.state = None - self.recipe_container = RecipeContainer() - self.modifiers = [] - self.event_lifecycle = None - - self.initialized_ = False - self.finalized = False + self.__init__() logger.info("Compression lifecycle reset") - def pre_initialize_structure(self, **kwargs) -> List[Any]: - """ - Pre-initialize the structure of the compression lifecycle. - - :param kwargs: Additional arguments to update the state with - :return: List of data returned from pre-initialization of modifiers - :rtype: List[Any] - """ - logger.debug("Pre-initializing structure") - self._check_create_state() - extras = self.state.update(**kwargs) - extras = self.recipe_container.update(**extras) - - self._check_compile_recipe() - mod_data = [] - for mod in self.modifiers: - data = mod.pre_initialize_structure(state=self.state, **extras) - logger.debug("Pre-initialized modifier: {}", mod) - if data is not None: - mod_data.append(data) - - applied_stage_names = [mod.unique_id for mod in self.modifiers if mod.applied] - self.recipe_container.update_applied_stages(applied_stage_names) - logger.info( - "Compression lifecycle structure pre-initialized for {} modifiers", - len(self.modifiers), - ) - - return mod_data - - def initialize(self, **kwargs) -> List[Any]: + def initialize( + self, + recipe: Optional[RecipeInput] = None, + recipe_stage: Optional[RecipeStageInput] = None, + recipe_args: Optional[RecipeArgsInput] = None, + **kwargs, + ) -> List[Any]: """ Initialize the compression lifecycle. 
@@ -109,16 +84,18 @@ def initialize(self, **kwargs) -> List[Any]: :return: List of data returned from initialization of modifiers :rtype: List[Any] """ - logger.debug("Initializing compression lifecycle") - self._check_create_state() - extras = self.state.update(**kwargs) - extras = self.recipe_container.update(**extras) + self.state.update(**kwargs) + if self.initialized_: # TODO: do not initialize twice + return - self._check_compile_recipe() + logger.debug("Initializing compression lifecycle") + self.recipe_container.append(recipe, recipe_stage, recipe_args) + self.modifiers = self.recipe_container.get_modifiers() self._set_model_layer_prefix() + mod_data = [] for mod in self.modifiers: - data = mod.initialize(state=self.state, **extras) + data = mod.initialize(state=self.state, **kwargs) logger.debug("Initialized modifier: {}", mod) if data is not None: mod_data.append(data) @@ -185,7 +162,7 @@ def event(self, event_type: EventType, **kwargs) -> List[Any]: logger.error("Cannot invoke event after finalizing") raise ValueError("Cannot invoke event after finalizing") - if event_type in [EventType.PRE_INIT, EventType.INITIALIZE, EventType.FINALIZE]: + if event_type in [EventType.INITIALIZE, EventType.FINALIZE]: logger.error( "Cannot invoke {} event. Use the corresponding method instead.", event_type, @@ -223,30 +200,6 @@ def event(self, event_type: EventType, **kwargs) -> List[Any]: return mod_data - def _check_create_state(self): - if self.state is not None: - return - - logger.debug("Creating new State instance for compression lifecycle") - self.state = State() - logger.info("State created for compression lifecycle") - - def _check_compile_recipe(self): - if not self.recipe_container.check_compile_recipe(): - return - - logger.debug( - "Compiling recipe and creating modifiers for compression lifecycle" - ) - self.modifiers = self.recipe_container.compiled_recipe.create_modifier() - for mod in self.modifiers: - if mod.unique_id in self.recipe_container.applied_stages: - mod.applied = True - logger.info( - "Recipe compiled and {} modifiers created", - len(self.modifiers), - ) - def _check_setup_event_lifecycle(self, event_type: EventType): if self.event_lifecycle is not None: return diff --git a/src/llmcompressor/core/session.py b/src/llmcompressor/core/session.py index 888db3f1e..f028510bc 100644 --- a/src/llmcompressor/core/session.py +++ b/src/llmcompressor/core/session.py @@ -65,45 +65,6 @@ def state(self) -> State: """ return self._lifecycle.state - def pre_initialize_structure( - self, - model: Any, - recipe: Union[str, List[str], Recipe, List[Recipe], None] = None, - recipe_stage: Union[str, List[str], None] = None, - recipe_args: Union[Dict[str, Any], List[Dict[str, Any]], None] = None, - **kwargs, - ) -> ModifiedState: - """ - A method to pre-initialize the structure of the model for compression. - This will run the pre-initialize structure method for each modifier in the - session's lifecycle. This will also set the session's state to the - pre-initialized state. Takes care of cases when the model(s) structure - has been previously modified by a modifier. - - :param model: the model to pre-initialize the structure for - :param recipe: the recipe to use for the compression, can be a path to a - recipe file, a raw recipe string, a recipe object, or a list - of recipe objects. 
- :param recipe_stage: the stage to use for the compression - :param recipe_args: the args to use for overriding the recipe defaults - :return: A ModifiedState instance holding the modified model and modifier_data - after pre-initializing the structure - """ - mod_data = self._lifecycle.pre_initialize_structure( - model=model, - recipe=recipe, - recipe_stage=recipe_stage, - recipe_args=recipe_args, - **kwargs, - ) - - return ModifiedState( - model=self.state.model, - optimizer=None, - loss=None, - modifier_data=mod_data, - ) - def initialize( self, recipe: Union[str, List[str], "Recipe", List["Recipe"], None] = None, diff --git a/src/llmcompressor/core/session_functions.py b/src/llmcompressor/core/session_functions.py index da54872c4..4d12f22ff 100644 --- a/src/llmcompressor/core/session_functions.py +++ b/src/llmcompressor/core/session_functions.py @@ -11,7 +11,6 @@ "create_session", "active_session", "reset_session", - "pre_initialize_structure", "initialize", "finalize", "callbacks", @@ -59,16 +58,6 @@ def reset_session(): session._lifecycle.reset() -def pre_initialize_structure(**kwargs): - """ - A method to pre-initialize the structure of the model for the active session - - :param kwargs: the kwargs to pass to the active session's pre-initialize-structure - method - """ - active_session().pre_initialize_structure(**kwargs) - - def initialize( recipe: Union[str, List[str], "Recipe", List["Recipe"], None] = None, recipe_stage: Union[str, List[str], None] = None, @@ -156,7 +145,7 @@ def event(cls, event_type: EventType, **kwargs) -> ModifiedState: :param kwargs: additional kwargs to pass to the current session's event method :return: the modified state of the active session after invoking the event """ - if event_type in [EventType.PRE_INIT, EventType.INITIALIZE, EventType.FINALIZE]: + if event_type in [EventType.INITIALIZE, EventType.FINALIZE]: raise ValueError( f"Cannot invoke {event_type} event. " f"Use the corresponding method instead." 
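
With pre_initialize_structure removed above, recipe handling collapses into a single initialize() call at the session level. A minimal usage sketch of the consolidated flow, assuming a model object is already loaded; the recipe path, recipe_args values, and variable names below are placeholders, not taken from this patch:

# sketch: one-step initialization after this change (placeholder model/recipe values)
from llmcompressor.core import create_session, finalize, initialize, reset_session

create_session()
initialize(
    model=model,                          # assumes a loaded torch.nn.Module / PreTrainedModel
    recipe="recipe.yaml",                 # placeholder: recipe path or raw YAML string
    recipe_args={"num_calibration_steps": 512},  # placeholder override
)
# ... run batch / optimizer events via callbacks ...
finalize()
reset_session()
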
diff --git a/src/llmcompressor/modifiers/interface.py b/src/llmcompressor/modifiers/interface.py index e3a3786b4..f1c73c54b 100644 --- a/src/llmcompressor/modifiers/interface.py +++ b/src/llmcompressor/modifiers/interface.py @@ -11,15 +11,6 @@ class ModifierInterface(ABC): Defines the contract that all modifiers must implement """ - @property - @abstractmethod - def initialized_structure(self) -> bool: - """ - :return: True if the modifier structure has been - applied to the model - """ - raise NotImplementedError() - @property @abstractmethod def initialized(self) -> bool: @@ -58,16 +49,6 @@ def calculate_end(self) -> float: """ raise NotImplementedError() - @abstractmethod - def pre_initialize_structure(self, state: State, **kwargs): - """ - Apply the modifier structure to the model - - :param state: The current state of the model - :param kwargs: Additional arguments for the modifier - """ - raise NotImplementedError() - @abstractmethod def initialize(self, state: State, **kwargs): """ diff --git a/src/llmcompressor/modifiers/modifier.py b/src/llmcompressor/modifiers/modifier.py index 65b4a4029..4092cc3de 100644 --- a/src/llmcompressor/modifiers/modifier.py +++ b/src/llmcompressor/modifiers/modifier.py @@ -29,20 +29,11 @@ class Modifier(ModifierInterface, HooksMixin): end: Optional[float] = None update: Optional[float] = None - initialized_structure_: bool = False initialized_: bool = False finalized_: bool = False started_: bool = False ended_: bool = False - @property - def initialized_structure(self) -> bool: - """ - :return: True if the modifier structure has been - applied to the model - """ - return self.initialized_structure_ - @property def initialized(self) -> bool: """ @@ -78,15 +69,6 @@ def calculate_end(self) -> float: """ return self.end if self.end is not None else -1 - def pre_initialize_structure(self, state: State, **kwargs): - """ - :param state: The current state of the model - :param kwargs: Additional arguments for initializing the structure - of the model in question - """ - self.on_initialize_structure(state, **kwargs) - self.initialized_structure_ = True - def initialize(self, state: State, **kwargs): """ Initialize the modifier for the given model and state. @@ -221,19 +203,6 @@ def should_end(self, event: Event): return self.end is not None and current >= self.end - def on_initialize_structure(self, state: State, **kwargs): - """ - on_initialize_structure is called before the model is initialized - with the modifier structure. 
- - TODO: Depreciate this function as part of the lifecycle - - :param state: The current state of the model - :param kwargs: Additional arguments for initializing the structure - of the model in question - """ - pass - @abstractmethod def on_initialize(self, state: State, **kwargs) -> bool: """ diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 65e1c90e0..525ba1301 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -62,9 +62,8 @@ class GPTQModifier(Modifier, HooksMixin): | actorder: False Lifecycle: - - on_initialize_structure - - _build_quant_modifier - on_initialize + - _build_quant_modifier - register_hook(module, compress_module, "forward") - run_sequential / run_layer_sequential / run_basic - make_empty_hessian @@ -141,16 +140,16 @@ def validate_sequential_update(cls, value: bool) -> bool: return True - def on_initialize_structure(self, state: State, **kwargs): + def _check_build_quant_modifier(self, model: torch.nn.Module): """ Check the model's quantization state matches that expected by this modifier, adding a default quantization scheme if needed - TODO: Depreciate and fold into `on_initialize` + # TODO: build modifier during recipe validation :param state: session state storing input model and calibration data """ - quantization_already_active = qat_active(state.model) + quantization_already_active = qat_active(model) if isinstance(self.quantize, bool): if not self.quantize and quantization_already_active: logger.warning( @@ -191,18 +190,15 @@ def on_initialize_structure(self, state: State, **kwargs): self._build_quant_modifier_from_dict(self.quantize) self.quantize = True - if self._quantization_modifier: - self._quantization_modifier.on_initialize_structure(state, **kwargs) - def on_initialize(self, state: State, **kwargs) -> bool: """ Initialize and run the GPTQ algorithm on the current state :param state: session state storing input model and calibration data """ - # initialize quantization modifier - if not self.initialized_structure_: - self.on_initialize_structure(state, **kwargs) + # build quantization modifier + self._check_build_quant_modifier(state.model) + if self._quantization_modifier: self._quantization_modifier.initialize(state, **kwargs) if not self.quantize: diff --git a/src/llmcompressor/modifiers/stage.py b/src/llmcompressor/modifiers/stage.py index 7e63245b6..fe773bcb5 100644 --- a/src/llmcompressor/modifiers/stage.py +++ b/src/llmcompressor/modifiers/stage.py @@ -19,7 +19,7 @@ class StageModifiers(ModifierInterface, BaseModel): :param index: The index of the stage, if applicable :param group: The group name of the stage, if applicable :param applied: Flag for indicating if this stage has has already been - applied to the model, through structure initialization or finalization + applied to the model through finalization """ modifiers: List["Modifier"] = Field(default_factory=list) @@ -27,14 +27,6 @@ class StageModifiers(ModifierInterface, BaseModel): group: Optional[str] = None applied: bool = False - @property - def initialized_structure(self) -> bool: - """ - :return: True if any of the stage modifiers have initialized structure, - False otherwise - """ - return any(mod.initialized_structure for mod in self.modifiers) - @property def initialized(self) -> bool: """ @@ -93,20 +85,6 @@ def calculate_end(self) -> float: """ return max(mod.calculate_end() for mod in self.modifiers) - def 
pre_initialize_structure(self, state: "State", **kwargs): - """ - Pre initialize the structure for all stage modifiers mark the stage applied - - :param state: The current state of the training - :param kwargs: Additional kwargs to pass to the modifier(s) - pre_initialize_structure method - """ - for modifier in self.modifiers: - modifier.pre_initialize_structure(state, **kwargs) - - self.applied = True - state.loggers.system.info(tag="stage", string="Model structure initialized") - def initialize(self, state: "State", **kwargs): """ Initialize all the stage modifiers diff --git a/src/llmcompressor/pytorch/model_load/helpers.py b/src/llmcompressor/pytorch/model_load/helpers.py index e2e1a91b7..850fba32f 100644 --- a/src/llmcompressor/pytorch/model_load/helpers.py +++ b/src/llmcompressor/pytorch/model_load/helpers.py @@ -8,13 +8,12 @@ from torch.nn import Module from transformers import PreTrainedModel -from llmcompressor.core import active_session, create_session, pre_initialize_structure +from llmcompressor.core import active_session from llmcompressor.typing import Processor COMPLETED_STAGES_FILENAME = "completed_stages.json" __all__ = [ - "initialize_recipe", "copy_python_files_from_model_cache", "fallback_to_cpu", "parse_dtype", @@ -25,30 +24,6 @@ ] -def initialize_recipe(model: Module, recipe_path: str): - """ - Initializes a recipe that has been previously applied to the model - :param model: PyTorch model to apply structure to - :param recipe_path: path to recipe to apply to the model - """ - if not active_session(): - create_session() - pre_initialize_structure(model=model, recipe=recipe_path) - - # no need to reload if no recipe was applied - if recipe_path is None: - return - - session = active_session() - num_stages = len(session.lifecycle.recipe_container.compiled_recipe.stages) - msg = ( - "an unstaged recipe" - if num_stages == 1 - else f"a staged recipe with {num_stages} stages" - ) - logger.info(f"Applied {msg} to the model") - - def save_checkpoint( save_path: str, model: PreTrainedModel, diff --git a/src/llmcompressor/recipe/__init__.py b/src/llmcompressor/recipe/__init__.py index e02a18b39..bb4df06af 100644 --- a/src/llmcompressor/recipe/__init__.py +++ b/src/llmcompressor/recipe/__init__.py @@ -9,7 +9,7 @@ RecipeMetaData, ) from .modifier import RecipeModifier -from .recipe import Recipe, RecipeTuple +from .recipe import Recipe, RecipeArgsInput, RecipeInput, RecipeStageInput, RecipeTuple from .stage import RecipeStage, StageRunType __all__ = [ @@ -26,4 +26,7 @@ "Recipe", "RecipeTuple", "StageRunType", + "RecipeInput", + "RecipeStageInput", + "RecipeArgsInput", ] diff --git a/src/llmcompressor/recipe/container.py b/src/llmcompressor/recipe/container.py index 5cae0dd2c..90c9c1dad 100644 --- a/src/llmcompressor/recipe/container.py +++ b/src/llmcompressor/recipe/container.py @@ -1,8 +1,14 @@ from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Union +from typing import List, Optional from llmcompressor.modifiers import Modifier -from llmcompressor.recipe.recipe import Recipe, RecipeTuple +from llmcompressor.recipe.recipe import ( + Recipe, + RecipeArgsInput, + RecipeInput, + RecipeStageInput, + RecipeTuple, +) __all__ = ["RecipeContainer"] @@ -22,52 +28,42 @@ class RecipeContainer: recipes: List[RecipeTuple] = field(default_factory=list) applied_stages: List[str] = field(default_factory=list) - def update( + def prepend( self, - recipe: Union[ - str, List[str], Recipe, List[Recipe], Modifier, List[Modifier], None - ] = None, - 
recipe_stage: Union[str, List[str], List[List[str]], None] = None, - recipe_args: Union[Dict[str, Any], List[Dict[str, Any]], None] = None, - **kwargs, - ) -> Dict: - """ - Update the recipes in the container. If a recipe is provided, it will - reset any existing compiled_recipe in the container. Must call - `check_compile_recipe` to re-compile the recipes into a single compiled_recipe. - If no recipe is provided, does nothing and returns the kwargs. - - Can provide multiple recipes to update the container with: - >>> container = RecipeContainer() - >>> recipe_str_1 = ''' - ... test_stage: - ... pruning_modifiers: - ... ConstantPruningModifier: - ... start: 0.0 - ... end: 2.0 - ... targets: ['re:.*weight'] - ... ''' - >>> recipe_str_2 = ''' - ... test_stage: - ... pruning_modifiers: - ... ConstantPruningModifier: - ... start: 3.0 - ... end: 4.0 - ... targets: ['re:.*weight'] - ... ''' - >>> result = container.update(recipe=[recipe_str_1, recipe_str_2]) - - :param recipe: the recipe to update the container with - :param recipe_stage: the recipe stage to update the container with - :param recipe_args: the recipe args to update the recipe with - :param kwargs: additional kwargs to return - :return: the passed in kwargs - """ - if recipe is None or isinstance(recipe, list) and len(recipe) == 0: - return kwargs + recipe: Optional[RecipeInput] = None, + recipe_stage: Optional[RecipeStageInput] = None, + recipe_args: Optional[RecipeArgsInput] = None, + ): + recipe_tuples = self._prepare_tuples(recipe, recipe_stage, recipe_args) + self.recipes = recipe_tuples + self.recipes + self._check_compile_recipe() + + def append( + self, + recipe: Optional[RecipeInput] = None, + recipe_stage: Optional[RecipeStageInput] = None, + recipe_args: Optional[RecipeArgsInput] = None, + ): + recipe_tuples = self._prepare_tuples(recipe, recipe_stage, recipe_args) + self.recipes = self.recipes + recipe_tuples + self._check_compile_recipe() - self.compiled_recipe = None + def get_modifiers(self) -> List[Modifier]: + if self.compiled_recipe is None: + return [] + return self.compiled_recipe.create_modifier() + + def _prepare_tuples( + self, + recipe: Optional[RecipeInput] = None, + recipe_stage: Optional[RecipeStageInput] = None, + recipe_args: Optional[RecipeArgsInput] = None, + ) -> List[RecipeTuple]: + if recipe is None or (isinstance(recipe, list) and len(recipe) == 0): + return [] + + # prepare recipe if isinstance(recipe, Modifier) or ( isinstance(recipe, list) and all(isinstance(mod, Modifier) for mod in recipe) @@ -77,6 +73,12 @@ def update( if not isinstance(recipe, list): recipe = [recipe] + recipe = [ + Recipe.create_instance(rec) if isinstance(rec, str) else rec + for rec in recipe + ] + + # prepare stage if recipe_stage is None: recipe_stage = [None] * len(recipe) else: @@ -85,22 +87,23 @@ def update( if not isinstance(recipe_stage[0], list): recipe_stage = [recipe_stage] * len(recipe) + # prepare args if recipe_args is None: recipe_args = [{}] * len(recipe) elif not isinstance(recipe_args, list): recipe_args = [recipe_args] * len(recipe) + # validation if len(recipe) != len(recipe_stage) or len(recipe) != len(recipe_args): raise ValueError( "recipe, recipe_stage, and recipe_args must be the same length" ) - for rec, stage, args in zip(recipe, recipe_stage, recipe_args): - if isinstance(rec, str): - rec = Recipe.create_instance(rec) - self.recipes.append(RecipeTuple(rec, stage, args)) - - return kwargs + # create tuples + return [ + RecipeTuple(rec, stage, args) + for rec, stage, args in zip(recipe, 
recipe_stage, recipe_args) + ] def update_applied_stages(self, new_stages: List[str]): """ @@ -113,7 +116,7 @@ def update_applied_stages(self, new_stages: List[str]): if stage not in self.applied_stages: self.applied_stages.append(stage) - def check_compile_recipe(self) -> bool: + def _check_compile_recipe(self): """ Check if the recipes need to be compiled into a single recipe and compile them if they do. @@ -122,9 +125,6 @@ def check_compile_recipe(self) -> bool: """ if self.compiled_recipe is None and self.recipes: self.compiled_recipe = Recipe.simplify_combine_recipes(self.recipes) - return True - - return False def check_any_recipe_exists(self) -> bool: """ diff --git a/src/llmcompressor/recipe/recipe.py b/src/llmcompressor/recipe/recipe.py index 1e9851ba8..f48c4a568 100644 --- a/src/llmcompressor/recipe/recipe.py +++ b/src/llmcompressor/recipe/recipe.py @@ -14,7 +14,13 @@ from llmcompressor.recipe.metadata import RecipeMetaData from llmcompressor.recipe.stage import RecipeStage -__all__ = ["Recipe", "RecipeTuple"] +__all__ = [ + "Recipe", + "RecipeTuple", + "RecipeInput", + "RecipeStageInput", + "RecipeArgsInput", +] class Recipe(RecipeBase): @@ -150,7 +156,7 @@ def create_instance( @staticmethod def simplify_recipe( - recipe: Union["Recipe", "RecipeTuple"], shift: Optional[int] = None + recipe: Union[str, "Recipe", "RecipeTuple"], shift: Optional[int] = None ) -> "Recipe": """ Simplify a RecipeTuple by removing stages that are not in the target_stages @@ -177,6 +183,9 @@ def simplify_recipe( defaults to None (No shift) :return: The simplified Recipe instance """ + if isinstance(recipe, str): + recipe = Recipe.create_instance(recipe) + if isinstance(recipe, Recipe): recipe.evaluate(shift=shift) return recipe @@ -212,7 +221,7 @@ def simplify_recipe( @staticmethod def simplify_combine_recipes( - recipes: List[Union["Recipe", "RecipeTuple"]], + recipes: List[Union[str, "Recipe", "RecipeTuple"]], ) -> "Recipe": """ A method to combine multiple recipes into one recipe @@ -571,6 +580,11 @@ def _get_yaml_dict(self) -> Dict[str, Any]: return yaml_recipe_dict +RecipeInput = Union[str, List[str], Recipe, List[Recipe], Modifier, List[Modifier]] +RecipeStageInput = Union[str, List[str], List[List[str]]] +RecipeArgsInput = Union[Dict[str, Any], List[Dict[str, Any]]] + + @dataclass class RecipeTuple: """ diff --git a/src/llmcompressor/transformers/finetune/runner.py b/src/llmcompressor/transformers/finetune/runner.py index dd45b7daf..1735a99b8 100644 --- a/src/llmcompressor/transformers/finetune/runner.py +++ b/src/llmcompressor/transformers/finetune/runner.py @@ -6,6 +6,7 @@ import torch from loguru import logger from torch.utils.data import Dataset +from transformers import PreTrainedModel from llmcompressor.args import ( DatasetArguments, @@ -154,7 +155,9 @@ def train(self, checkpoint: str, stage: Optional[str] = None): # this includes saving the state, optimizer and scheduler self.trainer.save_model(output_dir=self._output_dir) - def run_sequential_stages(self, checkpoint: Optional[str] = None): + def run_sequential_stages( + self, model: PreTrainedModel, checkpoint: Optional[str] = None + ): """ Run the recipe stage by stage, allowing for alternating between one-shot and finetuning flows. Optionally save the model output at the end of each stage @@ -181,12 +184,6 @@ def run_sequential_stages(self, checkpoint: Optional[str] = None): "the stage name." 
) - # just load structure if stage has already applied - if stage_name in completed_stages: - self.trainer.initialize_structure(stage=stage) - self.trainer.accelerator.wait_for_everyone() - continue - # setup checkpoint dir, TODO: this should be optional self._output_dir = os.path.join( self.parent_output_dir, "stage_" + stage_name @@ -201,7 +198,6 @@ def run_sequential_stages(self, checkpoint: Optional[str] = None): if run_type is StageRunType.ONESHOT: from llmcompressor import Oneshot - model = get_session_model() self._model_args.model = model oneshot = Oneshot.from_args( diff --git a/src/llmcompressor/transformers/finetune/session_mixin.py b/src/llmcompressor/transformers/finetune/session_mixin.py index c6e35c2fc..27882d7d6 100644 --- a/src/llmcompressor/transformers/finetune/session_mixin.py +++ b/src/llmcompressor/transformers/finetune/session_mixin.py @@ -17,7 +17,6 @@ create_session, finalize, initialize, - pre_initialize_structure, ) from llmcompressor.metrics import LoggerManager from llmcompressor.modifiers.distillation.utils.pytorch.model_wrapper import ( @@ -177,25 +176,6 @@ def initialize_session( torch.cuda.empty_cache() - def initialize_structure(self, stage: Optional[str] = None): - """ - Initialize any recipe structural changes such as quantization on the model, - return immediately if session has already been initialized - :param stage: Optional stage of recipe to run, or None to run all stages - """ - session = active_session() - if session.lifecycle.initialized_: - return False - - pre_initialize_structure( - model=self.model, - recipe=self.recipe, - recipe_stage=stage, - recipe_args=self.recipe_args, - ) - logger.info(f"Initialized LLM Compressor structure from recipe {self.recipe}") - torch.cuda.empty_cache() - def finalize_session(self): """ Wrap up training by finalizing all modifiers initialized in the current session diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py index c1be354db..9a3623f60 100644 --- a/src/llmcompressor/transformers/finetune/text_generation.py +++ b/src/llmcompressor/transformers/finetune/text_generation.py @@ -40,11 +40,9 @@ RecipeArguments, TrainingArguments, ) -from llmcompressor.core import pre_initialize_structure, reset_session +from llmcompressor.core import reset_session from llmcompressor.pytorch.model_load.helpers import ( fallback_to_cpu, - get_session_model, - initialize_recipe, parse_dtype, save_checkpoint, ) @@ -383,11 +381,6 @@ def main( if isinstance(processor, str) or processor is None: processor = initialize_processor_from_path(model_args, model, teacher) - pre_initialize_structure(model=model) - - # initialize session manager - initialize_recipe(model, None) - # Load datasets stage_runner = StageRunner( model_args=model_args, @@ -401,7 +394,7 @@ def main( calib_dataset = stage_runner.get_dataset_split("calibration") trainer = Trainer( - model_init=get_session_model, + model_init=lambda: model, teacher=teacher, recipe=recipe_args.recipe, recipe_args=recipe_args.recipe_args, @@ -429,7 +422,7 @@ def main( checkpoint = None if last_checkpoint is not None: checkpoint = last_checkpoint - stage_runner.run_sequential_stages(checkpoint) + stage_runner.run_sequential_stages(model, checkpoint) # exit immediately return diff --git a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py index b3ac28383..5959f1699 100644 --- 
a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py +++ b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py @@ -20,6 +20,7 @@ from llmcompressor.core import active_session from llmcompressor.pytorch.model_load.helpers import copy_python_files_from_model_cache +from llmcompressor.recipe.recipe import Recipe from llmcompressor.transformers.compression.quantization_format import ( infer_quantization_format, ) @@ -27,6 +28,7 @@ SparsityConfigMetadata, ) from llmcompressor.transformers.utils import RECIPE_FILE_NAME +from llmcompressor.transformers.utils.helpers import infer_recipe_from_model_path __all__ = ["modify_save_pretrained"] @@ -129,7 +131,7 @@ def skip(*args, **kwargs): ) compressor.update_config(save_directory) - # TODO: update existing recipe + # update existing recipe update_and_save_recipe(model.name_or_path, save_directory) # copy python files from cache dir to save_path if any @@ -253,10 +255,17 @@ def get_model_compressor( def update_and_save_recipe(model_path: str, save_directory: str): - # TODO: update existing recipe - recipe_path = os.path.join(save_directory, RECIPE_FILE_NAME) - session = active_session() + recipes_to_save = [] + existing_recipe = infer_recipe_from_model_path(model_path) + if existing_recipe is not None: + recipes_to_save.append(existing_recipe) + + new_recipe = active_session().lifecycle.recipe_container.compiled_recipe + if new_recipe is not None: + recipes_to_save.append(new_recipe) - if (recipe_yaml_str := session.get_serialized_recipe()) is not None: - with open(recipe_path, "w") as fp: - fp.write(recipe_yaml_str) + recipe = Recipe.simplify_combine_recipes(recipes_to_save) + + # save recipe + recipe_path = os.path.join(save_directory, RECIPE_FILE_NAME) + recipe.yaml(recipe_path) diff --git a/src/llmcompressor/transformers/utils/helpers.py b/src/llmcompressor/transformers/utils/helpers.py index 0d9fd4f7d..cddd45d4f 100644 --- a/src/llmcompressor/transformers/utils/helpers.py +++ b/src/llmcompressor/transformers/utils/helpers.py @@ -109,7 +109,7 @@ def infer_recipe_from_model_path(model_path: Union[str, Path]) -> Optional[str]: recipe = recipe_from_huggingface_model_id(hf_stub=model_path) if recipe is None: - logger.info("Failed to infer the recipe from the model_path") + logger.debug("Failed to infer the recipe from the model_path") return recipe @@ -140,14 +140,10 @@ def recipe_from_huggingface_model_id( return None try: - logger.info( - "Attempting to download a recipe ", - f"{hf_stub} " f"from {HUGGINGFACE_CO_URL_HOME}", - ) recipe = hf_hub_download(repo_id=hf_stub, filename=recipe_file_name) logger.info(f"Found recipe: {recipe_file_name} for model ID: {hf_stub}.") - except Exception as e: - logger.error( + except Exception as e: # TODO: narrow acceptable exceptions + logger.debug( ( f"Unable to find recipe {recipe_file_name} " f"for model ID: {hf_stub}: {e}." 
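
To make the recipe-stacking behavior of the reworked update_and_save_recipe concrete, here is a small sketch of the same flow; the output directories are hypothetical placeholders:

# sketch: combine a recipe already saved with the model and the current session's
# compiled recipe, then write one merged recipe.yaml (paths are placeholders)
from llmcompressor.core import active_session
from llmcompressor.recipe import Recipe
from llmcompressor.transformers.utils.helpers import infer_recipe_from_model_path

existing_recipe = infer_recipe_from_model_path("./stage_0_output")  # may be None
session_recipe = active_session().lifecycle.recipe_container.compiled_recipe

recipes_to_save = [r for r in (existing_recipe, session_recipe) if r is not None]
combined = Recipe.simplify_combine_recipes(recipes_to_save)
combined.yaml("./stage_1_output/recipe.yaml")
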
diff --git a/src/llmcompressor/utils/fsdp/helpers.py b/src/llmcompressor/utils/fsdp/helpers.py index 53fc04ca8..51c08010b 100644 --- a/src/llmcompressor/utils/fsdp/helpers.py +++ b/src/llmcompressor/utils/fsdp/helpers.py @@ -1,19 +1,11 @@ import operator -from pathlib import Path from typing import Optional -from loguru import logger - try: - from torch.distributed.fsdp import ( - FullStateDictConfig, - FullyShardedDataParallel, - StateDictType, - ) + from torch.distributed.fsdp import FullyShardedDataParallel except ImportError: FullyShardedDataParallel = None -import torch from torch.nn import Module from llmcompressor.core.state import State @@ -22,9 +14,7 @@ "is_fsdp_model", "maybe_get_wrapped", "set_wrapped_model", - "save_pretrained_fsdp", "get_fsdp_parent", - "find_and_move_state_dicts_to_cpu", ] @@ -68,63 +58,6 @@ def set_wrapped_model(state: State, wrapped_model: Module): state.model = wrapped_model -def find_and_move_state_dicts_to_cpu(output_dir: str): - """ - Looks for state dicts in the output directory and overwrites them - with cpu state dicts. - - this is needed for quantized models trained with FSDP as the state dict - contains device information, which can cause issues when loading the model - using transformers AutoModel.from_pretrained(...) if the device information - is not removed, assumes the state dicts are named pytorch_model*.bin - """ - - for model_file in Path(output_dir).rglob("pytorch_model*.bin"): - loaded_dict = torch.load(model_file) - for key, value in loaded_dict.items(): - if isinstance(value, torch.Tensor): - loaded_dict[key] = value.cpu() - - torch.save(loaded_dict, model_file) - logger.info(f"Moved state dict {model_file} to cpu") - - -def save_pretrained_fsdp( - model, - accelerator, - output_dir, - save_safetensors: bool = True, - save_compressed: bool = False, -): - full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) - """ - Gathers the full FSDP state dict of the model onto rank0 GPU, then uses it to save - the pretrained FSDP model to disk - - :param model: model to save - :param accelerator: Accelerator instance used to perform unwrapping - :param output_dir: where to save output model - :param save_safetensors: True to safe in safetensors format, otherwise .bin - :param save_compressed: whether to compress sparse weights on disk - """ - with FullyShardedDataParallel.state_dict_type( - model, StateDictType.FULL_STATE_DICT, full_state_dict_config - ): - state_dict = accelerator.get_state_dict(model, unwrap=False) - - if accelerator.is_main_process: - accelerator.unwrap_model(model).save_pretrained( - output_dir, - is_main_process=accelerator.is_main_process, - save_function=accelerator.save, - state_dict=state_dict, - save_compressed=save_compressed, - safe_serialization=save_safetensors, - ) - - accelerator.wait_for_everyone() - - def get_fsdp_parent(layer_name: str, model: Module) -> Optional[Module]: """ Gets the closest parent of layer_name that is wrapped by FSDP. 
If no FSDP wrapper diff --git a/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py b/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py index c1f0cb425..0c5ad534d 100644 --- a/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py +++ b/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py @@ -71,7 +71,7 @@ def test_create_default_quant_modifier(self): assert modifier._quantization_modifier is None testing_harness = LifecyleTestingHarness(model=LinearNet()) - modifier.on_initialize_structure(testing_harness.get_state()) + modifier._check_build_quant_modifier(testing_harness.get_state().model) assert modifier.quantize assert isinstance(modifier._quantization_modifier, QuantizationModifier) modifier._quantization_modifier.create_init_config() @@ -105,10 +105,6 @@ def test_set_quant_if_modifer_already_exists(self): modifier = GPTQModifier(block_size=128) assert not modifier._quantization_modifier - - modifier.on_initialize_structure(testing_harness.get_state()) - # since quantization modifier is already applied, quantization must be set in - # GPTQ assert modifier.quantize @@ -142,7 +138,7 @@ def test_set_quant_in_gptq(self): assert modifier._quantization_modifier is None testing_harness = LifecyleTestingHarness(model=LinearNet()) - modifier.on_initialize_structure(testing_harness.get_state()) + modifier._check_build_quant_modifier(testing_harness.get_state().model) assert modifier.quantize self.assertIsInstance(modifier._quantization_modifier, QuantizationModifier) diff --git a/tests/llmcompressor/transformers/obcq/test_consecutive_runs.py b/tests/llmcompressor/transformers/obcq/test_consecutive_runs.py index b1331afc0..b2176d0fe 100644 --- a/tests/llmcompressor/transformers/obcq/test_consecutive_runs.py +++ b/tests/llmcompressor/transformers/obcq/test_consecutive_runs.py @@ -8,8 +8,8 @@ from transformers import AutoModelForCausalLM from transformers.utils.quantization_config import CompressedTensorsConfig +from llmcompressor.recipe import Recipe from llmcompressor.transformers.utils import is_model_ct_quantized_from_path -from llmcompressor.transformers.utils.helpers import infer_recipe_from_model_path from tests.testing_utils import parse_params, requires_gpu CONFIGS_DIRECTORY = "tests/llmcompressor/transformers/obcq/obcq_configs/consec_runs" @@ -28,7 +28,6 @@ def _test_consecutive_runs( from llmcompressor import oneshot from llmcompressor.core import active_session - from llmcompressor.pytorch.model_load.helpers import initialize_recipe from llmcompressor.pytorch.utils.helpers import tensor_sparsity from llmcompressor.utils.pytorch import qat_active @@ -61,11 +60,7 @@ def _test_consecutive_runs( self.assertEqual(len(stages), 1) session.reset() - recipe = infer_recipe_from_model_path(model_path=self.output_first) - if recipe: - initialize_recipe(model=first_model, recipe_path=recipe) - - # reload saved model and up sparsity to 0.7 + # reload saved model and increase sparsity to 0.7 oneshot( model=self.output_first, dataset=self.dataset, @@ -87,11 +82,6 @@ def _test_consecutive_runs( assert math.isclose(layer_0_sparse.item(), 0.7, rel_tol=tolerance) assert qat_active(second_model) - session = active_session() - session_recipe = session.lifecycle.recipe_container.compiled_recipe - stages = [stage.group for stage in session_recipe.stages] - self.assertEqual(len(stages), 2) - recipe_path = self.output_second / "recipe.yaml" recipe_data = yaml.safe_load(recipe_path.read_text()) stage_keys = recipe_data.keys() @@ -99,6 
+89,24 @@ def _test_consecutive_runs( self.assertIn("test_stage_0", stage_keys) self.assertIn("test_stage_1", stage_keys) + # check saved modifier names are same + stage0_modifier_names = list( + list(recipe_data["test_stage_0"].values())[0].keys() + ) + exp_stage0_modifier_names = [ + mod.type + for mod in Recipe.create_instance(self.first_recipe).stages[0].modifiers + ] + stage1_modifier_names = list( + list(recipe_data["test_stage_1"].values())[0].keys() + ) + exp_stage1_modifier_names = [ + mod.type + for mod in Recipe.create_instance(self.second_recipe).stages[0].modifiers + ] + self.assertEqual(stage0_modifier_names, exp_stage0_modifier_names) + self.assertEqual(stage1_modifier_names, exp_stage1_modifier_names) + def tearDown(self): shutil.rmtree(self.output) diff --git a/tests/unit/core/events/test_event.py b/tests/unit/core/events/test_event.py index 06f78ec53..de18dcb28 100644 --- a/tests/unit/core/events/test_event.py +++ b/tests/unit/core/events/test_event.py @@ -3,13 +3,6 @@ from llmcompressor.core import Event, EventType -@pytest.mark.smoke -def test_event_type_order(): - assert EventType.PRE_INIT.order() == 0 - assert EventType.INITIALIZE.order() == 10 - assert EventType.FINALIZE.order() == 20 - - @pytest.mark.smoke def test_event_epoch_based(): event = Event(steps_per_epoch=10)
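
For reference, a short sketch of the reworked RecipeContainer surface exercised by the changes above: append()/prepend() plus get_modifiers() replace update()/check_compile_recipe(). The YAML string is a placeholder that mirrors the example previously shown in the removed update() docstring:

# sketch: appending a recipe and retrieving compiled stage modifiers
from llmcompressor.recipe import RecipeContainer

recipe_str = """
test_stage:
  pruning_modifiers:
    ConstantPruningModifier:
      start: 0.0
      end: 2.0
      targets: ['re:.*weight']
"""

container = RecipeContainer()
container.append(recipe=recipe_str)          # compiles the recipe internally
stage_modifiers = container.get_modifiers()  # [] until a recipe has been appended
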