diff --git a/src/llmcompressor/pytorch/model_load/helpers.py b/src/llmcompressor/pytorch/model_load/helpers.py
index 5ddc7ebd5..e2e1a91b7 100644
--- a/src/llmcompressor/pytorch/model_load/helpers.py
+++ b/src/llmcompressor/pytorch/model_load/helpers.py
@@ -6,6 +6,7 @@
 from loguru import logger
 from safetensors import safe_open
 from torch.nn import Module
+from transformers import PreTrainedModel
 
 from llmcompressor.core import active_session, create_session, pre_initialize_structure
 from llmcompressor.typing import Processor
@@ -14,20 +15,19 @@
 
 __all__ = [
     "initialize_recipe",
-    "save_model_and_recipe",
     "copy_python_files_from_model_cache",
     "fallback_to_cpu",
     "parse_dtype",
    "get_session_model",
     "get_completed_stages",
     "save_completed_stages",
+    "save_checkpoint",
 ]
 
 
 def initialize_recipe(model: Module, recipe_path: str):
     """
     Initializes a recipe that has been previously applied to the model
-
     :param model: PyTorch model to apply structure to
     :param recipe_path: path to recipe to apply to the model
     """
@@ -49,43 +49,22 @@ def initialize_recipe(model: Module, recipe_path: str):
     logger.info(f"Applied {msg} to the model")
 
 
-def save_model_and_recipe(
-    model: Module,
+def save_checkpoint(
     save_path: str,
-    processor: Optional[Processor] = None,
-    save_safetensors: bool = False,
-    save_compressed: bool = False,
+    model: PreTrainedModel,
+    processor: Processor,
+    save_safetensors: bool = True,
+    save_compressed: bool = True,
 ):
-    """
-    Save a model, processor and the currently loaded recipe to file
-
-    :param model: pytorch model to save
-    :param save_path: path to save output to
-    :param processor: model processor or tokenizer to save
-    :param save_safetensors: whether to save as safetensors or pickle (bin)
-    :param save_compressed: whether to compress sparse weights on disk
-    """
-    # avoid circular import
-    from llmcompressor.transformers.utils.helpers import RECIPE_FILE_NAME
-
+    # saving the model also saves the recipe
     model.save_pretrained(
-        save_path, save_compressed=save_compressed, safe_serialization=save_safetensors
+        save_path,
+        save_safetensors=save_safetensors,
+        save_compressed=save_compressed,
     )
-
     if processor is not None:
         processor.save_pretrained(save_path)
 
-    logger.info("Saving output to {}".format(os.path.abspath(save_path)))
-
-    recipe_path = os.path.join(save_path, RECIPE_FILE_NAME)
-    session = active_session()
-    recipe_yaml_str = session.get_serialized_recipe()
-    with open(recipe_path, "w") as fp:
-        fp.write(recipe_yaml_str)
-
-    # copy python files from cache dir to save_path if any
-    copy_python_files_from_model_cache(model, save_path)
-
 
 def fallback_to_cpu(device: str) -> str:
     """
diff --git a/src/llmcompressor/transformers/finetune/runner.py b/src/llmcompressor/transformers/finetune/runner.py
index 37f7fbb12..dd45b7daf 100644
--- a/src/llmcompressor/transformers/finetune/runner.py
+++ b/src/llmcompressor/transformers/finetune/runner.py
@@ -17,6 +17,7 @@
 from llmcompressor.pytorch.model_load.helpers import (
     get_completed_stages,
     get_session_model,
+    save_checkpoint,
     save_completed_stages,
 )
 from llmcompressor.recipe import Recipe, StageRunType
@@ -26,7 +27,6 @@
     make_dataset_splits,
 )
 from llmcompressor.typing import Processor
-from llmcompressor.utils.fsdp.helpers import save_model_and_recipe
 
 
 class StageRunner:
@@ -231,14 +231,20 @@ def run_sequential_stages(self, checkpoint: Optional[str] = None):
 
                 checkpoint = None
 
-            if self._training_args.output_dir:
-                save_model_and_recipe(
-                    model=self.trainer.model,
+            # save model between stages
+            if (
+                self._training_args.output_dir
+                != TrainingArguments.__dataclass_fields__["output_dir"].default
+                and self.trainer.accelerator.is_main_process
+            ):
+                save_checkpoint(
                     save_path=self._output_dir,
+                    model=self.trainer.model,
                     processor=self.processor,
                     save_safetensors=self._training_args.save_safetensors,
                     save_compressed=self._model_args.save_compressed,
                 )
+            self.trainer.accelerator.wait_for_everyone()
 
             # save stage to checkpoint dir
             if self.trainer.accelerator.is_main_process:
diff --git a/src/llmcompressor/transformers/finetune/session_mixin.py b/src/llmcompressor/transformers/finetune/session_mixin.py
index dcf1dacb7..c6e35c2fc 100644
--- a/src/llmcompressor/transformers/finetune/session_mixin.py
+++ b/src/llmcompressor/transformers/finetune/session_mixin.py
@@ -23,15 +23,13 @@
 from llmcompressor.modifiers.distillation.utils.pytorch.model_wrapper import (
     KDModelWrapper,
 )
-from llmcompressor.pytorch.model_load.helpers import get_session_model
+from llmcompressor.pytorch.model_load.helpers import get_session_model, save_checkpoint
 from llmcompressor.pytorch.utils import ModuleSparsificationInfo
-from llmcompressor.transformers import RECIPE_FILE_NAME
 from llmcompressor.transformers.finetune.callbacks import (
     DisableHalfPrecisionCallback,
     TrainingLoopCallbacks,
 )
 from llmcompressor.utils.fsdp.context import summon_full_params_context
-from llmcompressor.utils.fsdp.helpers import is_fsdp_model, save_pretrained_fsdp
 from llmcompressor.utils.pytorch import qat_active
 
 if TYPE_CHECKING:
@@ -64,8 +62,8 @@ class SessionManagerMixIn:
     def __init__(
         self,
         recipe: str,
+        data_args: "DatasetArguments",
         model_args: "ModelArguments",
-        data_args: Optional["DatasetArguments"] = None,
         teacher: Optional[Union[Module, str]] = None,
         recipe_args: Optional[Union[Dict[str, Any], str]] = None,
         **kwargs,
@@ -183,7 +181,6 @@ def initialize_structure(self, stage: Optional[str] = None):
         """
         Initialize any recipe structural changes such as quantization on the model,
         return immediately if session has already been initialized
-
         :param stage: Optional stage of recipe to run, or None to run all stages
         """
         session = active_session()
@@ -399,44 +396,19 @@ def save_model(self, output_dir: str, _internal_call=False, _is_oneshot=False):
 
         # knowledge distillation requires making wrappers transparent during
         if isinstance(self.model, KDModelWrapper):
-            self.model.prepare_for_save()
+            self.model.prepare_for_save()  # TODO: move to finalize
 
-        if not is_fsdp_model(self.model):
-            self.model.save_pretrained(
+        # save checkpoint
+        self.save_state()
+        if self.accelerator.is_main_process:
+            processor = getattr(self, "processing_class", self.tokenizer)
+            save_checkpoint(
                 output_dir,
-                save_compressed=self.model_args.save_compressed,
-                safe_serialization=self.args.save_safetensors,
-            )
-        else:  # FSDP model
-            save_pretrained_fsdp(
                 model=self.model,
-                accelerator=self.accelerator,
-                output_dir=output_dir,
+                processor=processor,
+                save_safetensors=self.args.save_safetensors,
                 save_compressed=self.model_args.save_compressed,
-                save_safetensors=self.metadata.get("save_safetensors", False),
-            )
-
-        self.save_state()
-        processor = getattr(self, "processing_class", self.tokenizer)
-        if processor is not None:
-            processor.save_pretrained(output_dir)
-
-        if not self.recipe:
-            return
-
-        if self.accelerator.is_main_process:
-            # save recipe, will contain modifiers from the model's original recipe as
-            # well as those added from self.recipe
-            recipe_path = os.path.join(output_dir, RECIPE_FILE_NAME)
-            session = active_session()
-            recipe_yaml_str = session.get_serialized_recipe()
-            with open(recipe_path, "w") as fp:
-                fp.write(recipe_yaml_str)
-
-            logger.info(
-                f"Saved LLM Compressor recipe with model state to {recipe_path}"
             )
-
         self.accelerator.wait_for_everyone()
 
         if isinstance(self.model, KDModelWrapper):
diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py
index 9cb733c30..c1be354db 100644
--- a/src/llmcompressor/transformers/finetune/text_generation.py
+++ b/src/llmcompressor/transformers/finetune/text_generation.py
@@ -46,12 +46,12 @@
     get_session_model,
     initialize_recipe,
     parse_dtype,
+    save_checkpoint,
 )
 from llmcompressor.recipe import Recipe, StageRunType
 from llmcompressor.transformers.finetune.runner import StageRunner
 from llmcompressor.transformers.finetune.trainer import Trainer
 from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
-    modify_fsdp_model_save_pretrained,
     modify_save_pretrained,
     patch_tied_tensors_bug,
 )
@@ -415,7 +415,10 @@ def main(
 
     # wrap model.save_pretrained
     if is_fsdp_model(model):
-        modify_fsdp_model_save_pretrained(trainer, processor)
+        raise NotImplementedError(
+            "FSDP models are not supported in the current release but will be "
+            "supported in future releases of LLM Compressor"
+        )
     else:
         modify_save_pretrained(model)
 
@@ -440,16 +443,19 @@
     stage_runner.train(checkpoint)
 
     # save if model was provided as a string or custom output_dir was set
-    if isinstance(model_args.model, str) or (
+    if (
         training_args.output_dir
         != TrainingArguments.__dataclass_fields__["output_dir"].default
+        and trainer.accelerator.is_main_process
     ):
-        model.save_pretrained(
-            training_args.output_dir, save_compressed=model_args.save_compressed
+        save_checkpoint(
+            save_path=training_args.output_dir,
+            model=model,
+            processor=processor,
+            save_safetensors=True,
+            save_compressed=model_args.save_compressed,
         )
-        if processor is not None:
-            processor.save_pretrained(training_args.output_dir)
+    trainer.accelerator.wait_for_everyone()
 
     # Clean up the CompressionSession before exit if requested
     if recipe_args.clear_sparse_session:
diff --git a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
index f64857fec..b3ac28383 100644
--- a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
+++ b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
@@ -16,6 +16,7 @@
 )
 from loguru import logger
 from safetensors.torch import storage_ptr
+from transformers import PreTrainedModel
 
 from llmcompressor.core import active_session
 from llmcompressor.pytorch.model_load.helpers import copy_python_files_from_model_cache
@@ -26,81 +27,11 @@
     SparsityConfigMetadata,
 )
 from llmcompressor.transformers.utils import RECIPE_FILE_NAME
-from llmcompressor.typing import Processor
-from llmcompressor.utils.fsdp.helpers import (
-    find_and_move_state_dicts_to_cpu,
-    unwrap_and_export_model,
-)
 
-__all__ = ["modify_save_pretrained", "modify_fsdp_model_save_pretrained"]
+__all__ = ["modify_save_pretrained"]
 
 
-def modify_fsdp_model_save_pretrained(trainer, processor: Processor):
-    """
-    Overrides a PreTrainedModel's save_pretrained() method with a wrapped version that
-    supports compression for fsdp model
-    """
-
-    def save_pretrained_compressed(save_pretrained_method):
-        if getattr(save_pretrained_method, "_overridden", False):
-            # `model.save_pretrained` has already been replaced, return.
-            return save_pretrained_method
-
-        # Keep a weak reference to the model class and unbound save_pretrained
-        # method so we can call the original
-        original_save_pretrained = save_pretrained_method.__func__
-        del save_pretrained_method
-
-        @wraps(original_save_pretrained)
-        def save_pretrained_wrapper(
-            save_directory: str,
-            **kwargs,
-        ):
-            """
-            Wrapper around PreTrainedModel.save_pretrained(), adds functionality for
-            saving models in a compressed format on disk. The compression format is
-            saved to the model's config file
-
-            :param save_directory: output directory to save model to
-            :param sparsity_config: optional sparsity config to compress model with,
-                if no config is provided it will be inferred from the model
-            :param quantization_format: optional compression format for quantized
-                models. If none is provided it will be inferred from the model
-            :param save_compressed: whether or not to compress the model on disk
-            :param skip_compression_stats: whether to skip the calculation of
-                compression statistics (such as global sparsity and sparsity structure) when
-                saving a model in dense format
-            :param kwargs: additional kwargs to pass on to model.save_pretrained
-            """
-            try:
-                trainer.save_model(output_dir=save_directory, _is_oneshot=True)
-            except AssertionError:
-                # fallback to this in the case of quantization
-                unwrap_and_export_model(
-                    model=trainer.model,
-                    accelerator=trainer.accelerator,
-                    output_dir=save_directory,
-                    processor=processor,
-                )
-                # only allow the main process move the state
-                # dicts to cpu
-                if trainer.accelerator.is_main_process:
-                    # assuming quantization is the last step
-                    # we no longer need the original model
-                    # and can safely delete it to save memory
-                    del trainer.model
-                    find_and_move_state_dicts_to_cpu(save_directory)
-
-        save_pretrained_wrapper._overriden = True
-        return save_pretrained_wrapper
-
-    # wrap save_pretrained
-    trainer.model.save_pretrained = save_pretrained_compressed(
-        trainer.model.save_pretrained
-    )
-
-
-def modify_save_pretrained(model: torch.nn.Module):
+def modify_save_pretrained(model: PreTrainedModel):
     """
     Overrides a PreTrainedModel's save_pretrained() method with a wrapped version that
     supports compression
@@ -124,6 +55,7 @@ def save_pretrained_wrapper(
             sparsity_config: Optional[SparsityCompressionConfig] = None,
             quantization_format: Optional[str] = None,
             save_compressed: bool = True,
+            safe_serialization: bool = True,
             skip_compression_stats: bool = False,
             disable_sparse_compression: bool = False,
             **kwargs,
@@ -189,19 +121,16 @@ def skip(*args, **kwargs):
             # make sure we're on the main process when saving
             if state_dict is not None and len(state_dict) > 0:
                 compressed_state_dict = compressor.compress(model, state_dict)
-
-                kwargs["safe_serialization"] = kwargs.get("safe_serialization", True)
                 original_save_pretrained.__get__(model, model_class)(
-                    save_directory, state_dict=compressed_state_dict, **kwargs
+                    save_directory,
+                    state_dict=compressed_state_dict,
+                    safe_serialization=safe_serialization,
+                    **kwargs,
                 )
                 compressor.update_config(save_directory)
 
-            recipe_path = os.path.join(save_directory, RECIPE_FILE_NAME)
-            session = active_session()
-
-            if (recipe_yaml_str := session.get_serialized_recipe()) is not None:
-                with open(recipe_path, "w") as fp:
-                    fp.write(recipe_yaml_str)
+            # TODO: update existing recipe
+            update_and_save_recipe(model.name_or_path, save_directory)
 
             # copy python files from cache dir to save_path if any
             copy_python_files_from_model_cache(model, save_directory)
@@ -321,3 +250,13 @@ def get_model_compressor(
         sparsity_config=sparsity_config,
         quantization_format=quantization_format,
     )
+
+
+def update_and_save_recipe(model_path: str, save_directory: str):
+    # TODO: update existing recipe
+    recipe_path = os.path.join(save_directory, RECIPE_FILE_NAME)
+    session = active_session()
+
+    if (recipe_yaml_str := session.get_serialized_recipe()) is not None:
+        with open(recipe_path, "w") as fp:
+            fp.write(recipe_yaml_str)
diff --git a/src/llmcompressor/utils/fsdp/helpers.py b/src/llmcompressor/utils/fsdp/helpers.py
index 3a3248fa5..53fc04ca8 100644
--- a/src/llmcompressor/utils/fsdp/helpers.py
+++ b/src/llmcompressor/utils/fsdp/helpers.py
@@ -17,15 +17,11 @@
 from torch.nn import Module
 
 from llmcompressor.core.state import State
-from llmcompressor.pytorch.model_load.helpers import save_model_and_recipe
-from llmcompressor.typing import Processor
-from llmcompressor.utils.pytorch import set_layer
 
 __all__ = [
     "is_fsdp_model",
     "maybe_get_wrapped",
     "set_wrapped_model",
-    "unwrap_and_export_model",
     "save_pretrained_fsdp",
     "get_fsdp_parent",
     "find_and_move_state_dicts_to_cpu",
@@ -72,34 +68,6 @@ def set_wrapped_model(state: State, wrapped_model: Module):
     state.model = wrapped_model
 
 
-def unwrap_and_export_model(model, accelerator, output_dir: str, processor: Processor):
-    """
-    Recursively unwraps an FSDP model, then saves the unwrapped model and the
-    currently active recipe to disk
-
-    :param model: model to unwrap
-    :param accelerator: Accelerator instance used to perform unwrapping
-    :param output_dir: where to save output model
-    :param processor: processor used by the model
-    """
-    full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
-    with FullyShardedDataParallel.state_dict_type(
-        model,
-        StateDictType.FULL_STATE_DICT,
-        full_state_dict_config,
-    ):
-        unwrapped_model = accelerator.unwrap_model(model)
-        for name, module in unwrapped_model.named_modules():
-            if isinstance(module, FullyShardedDataParallel):
-                set_layer(name, accelerator.unwrap_model(module), unwrapped_model)
-
-    save_model_and_recipe(
-        model=unwrapped_model,
-        save_path=output_dir,
-        processor=processor,
-    )
-
-
 def find_and_move_state_dicts_to_cpu(output_dir: str):
     """
     Looks for state dicts in the output directory and overwrites them
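
For reference, a minimal usage sketch of the save path introduced by this patch (not part of the diff above; the model name, tokenizer, and output directory are illustrative, and it assumes the model's save_pretrained has first been wrapped by modify_save_pretrained so the compression kwargs are accepted):

from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor.pytorch.model_load.helpers import save_checkpoint
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
    modify_save_pretrained,
)

# illustrative model/processor pair; any PreTrainedModel + tokenizer works here
model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

# wrap save_pretrained so compressed serialization and recipe saving are available
modify_save_pretrained(model)

# a single call now writes the weights (and, via the wrapper, the active recipe)
# plus the processor/tokenizer to the output directory
save_checkpoint(
    save_path="./compressed-checkpoint",  # illustrative output directory
    model=model,
    processor=tokenizer,
    save_safetensors=True,
    save_compressed=True,
)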