[Callbacks] Consolidate Saving Methods (#1168)

## Purpose ##
* Simplify all methods of saving into one point, namely the wrapped
`save_pretrained` function
* Precursor to #1160
* Needed to provide a single saving entrypoint that works on top of existing recipes

## Background ## 
All of the steps that need to be performed during saving:
1. Save the model weights, potentially compressed
2. Save the processor
3. Update the recipe checkpoint
4. Copy any necessary python files from the model cache
5. Only save on the main process

After these changes, steps (1, 2, 3, 4) will be done within the
`save_pretrained` function, and (5) will be the responsibility of the
caller. Step (3) will be implemented by #1160 so as not to conflict with the
existing logic in `pre_init`.
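
A minimal sketch of the resulting helper (paraphrased from the new `save_checkpoint` in `src/llmcompressor/pytorch/model_load/helpers.py` below; the comments map its calls back to the numbered steps above):

```python
from transformers import PreTrainedModel

from llmcompressor.typing import Processor


def save_checkpoint(
    save_path: str,
    model: PreTrainedModel,
    processor: Processor,
    save_safetensors: bool = True,
    save_compressed: bool = True,
):
    # (1) save the weights, optionally compressed; the wrapped save_pretrained
    # is also expected to save the recipe and (4) copy python files from the
    # model cache
    model.save_pretrained(
        save_path,
        save_safetensors=save_safetensors,
        save_compressed=save_compressed,
    )

    # (2) save the processor or tokenizer alongside the weights
    if processor is not None:
        processor.save_pretrained(save_path)
```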

All of the places where a model is saved are:
* If an output dir is specified, at the end of the main function
* Between stages of the stage runner
* Between epochs of the HF Trainer
* By the user after oneshot/training completes

After these changes, all of these call sites will be replaced by a single
`save_checkpoint` function, which calls `save_pretrained` to perform all of
the necessary steps.
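
At each call site the pattern then looks roughly like the following (a sketch based on the runner and `text_generation.py` changes below; `trainer`, `training_args`, `model_args`, and `processor` are the objects already in scope at those call sites):

```python
# (5) only the main process writes the checkpoint; the other ranks wait
if trainer.accelerator.is_main_process:
    save_checkpoint(
        save_path=training_args.output_dir,
        model=trainer.model,
        processor=processor,
        save_safetensors=training_args.save_safetensors,
        save_compressed=model_args.save_compressed,
    )
trainer.accelerator.wait_for_everyone()
```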

## Changes ##
* Remove `save_model_and_recipe`
  * Saving recipes is now handled by the `save_pretrained` function
* Implement `save_checkpoint`
  * Single entrypoint for saving a model and its processor
  * Performs actions (1, 2, 4)
* Replace all locations where a model is saved with `save_checkpoint`
  * All applicable callers now only save on the main process (5)
* Remove support for `modify_fsdp_model_save_pretrained` and
`unwrap_and_export_model`, to be added back in a future release

---------

Signed-off-by: Kyle Sayers <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
kylesayrs and dsikka authored Feb 25, 2025
1 parent d810e4a commit 6e101b2
Showing 6 changed files with 64 additions and 194 deletions.
43 changes: 11 additions & 32 deletions src/llmcompressor/pytorch/model_load/helpers.py
@@ -6,6 +6,7 @@
from loguru import logger
from safetensors import safe_open
from torch.nn import Module
from transformers import PreTrainedModel

from llmcompressor.core import active_session, create_session, pre_initialize_structure
from llmcompressor.typing import Processor
@@ -14,20 +15,19 @@

__all__ = [
"initialize_recipe",
"save_model_and_recipe",
"copy_python_files_from_model_cache",
"fallback_to_cpu",
"parse_dtype",
"get_session_model",
"get_completed_stages",
"save_completed_stages",
"save_checkpoint",
]


def initialize_recipe(model: Module, recipe_path: str):
"""
Initializes a recipe that has been previously applied to the model
:param model: PyTorch model to apply structure to
:param recipe_path: path to recipe to apply to the model
"""
@@ -49,43 +49,22 @@ def initialize_recipe(model: Module, recipe_path: str):
logger.info(f"Applied {msg} to the model")


def save_model_and_recipe(
model: Module,
def save_checkpoint(
save_path: str,
processor: Optional[Processor] = None,
save_safetensors: bool = False,
save_compressed: bool = False,
model: PreTrainedModel,
processor: Processor,
save_safetensors: bool = True,
save_compressed: bool = True,
):
"""
Save a model, processor and the currently loaded recipe to file
:param model: pytorch model to save
:param save_path: path to save output to
:param processor: model processor or tokenizer to save
:param save_safetensors: whether to save as safetensors or pickle (bin)
:param save_compressed: whether to compress sparse weights on disk
"""
# avoid circular import
from llmcompressor.transformers.utils.helpers import RECIPE_FILE_NAME

# saving the model also saves the recipe
model.save_pretrained(
save_path, save_compressed=save_compressed, safe_serialization=save_safetensors
save_path,
save_safetensors=save_safetensors,
save_compressed=save_compressed,
)

if processor is not None:
processor.save_pretrained(save_path)

logger.info("Saving output to {}".format(os.path.abspath(save_path)))

recipe_path = os.path.join(save_path, RECIPE_FILE_NAME)
session = active_session()
recipe_yaml_str = session.get_serialized_recipe()
with open(recipe_path, "w") as fp:
fp.write(recipe_yaml_str)

# copy python files from cache dir to save_path if any
copy_python_files_from_model_cache(model, save_path)


def fallback_to_cpu(device: str) -> str:
"""
14 changes: 10 additions & 4 deletions src/llmcompressor/transformers/finetune/runner.py
@@ -17,6 +17,7 @@
from llmcompressor.pytorch.model_load.helpers import (
get_completed_stages,
get_session_model,
save_checkpoint,
save_completed_stages,
)
from llmcompressor.recipe import Recipe, StageRunType
@@ -26,7 +27,6 @@
make_dataset_splits,
)
from llmcompressor.typing import Processor
from llmcompressor.utils.fsdp.helpers import save_model_and_recipe


class StageRunner:
@@ -231,14 +231,20 @@ def run_sequential_stages(self, checkpoint: Optional[str] = None):

checkpoint = None

if self._training_args.output_dir:
save_model_and_recipe(
model=self.trainer.model,
# save model between stages
if (
self._training_args.output_dir
!= TrainingArguments.__dataclass_fields__["output_dir"].default
and self.trainer.accelerator.is_main_process
):
save_checkpoint(
save_path=self._output_dir,
model=self.trainer.model,
processor=self.processor,
save_safetensors=self._training_args.save_safetensors,
save_compressed=self._model_args.save_compressed,
)
self.trainer.accelerator.wait_for_everyone()

# save stage to checkpoint dir
if self.trainer.accelerator.is_main_process:
48 changes: 10 additions & 38 deletions src/llmcompressor/transformers/finetune/session_mixin.py
@@ -23,15 +23,13 @@
from llmcompressor.modifiers.distillation.utils.pytorch.model_wrapper import (
KDModelWrapper,
)
from llmcompressor.pytorch.model_load.helpers import get_session_model
from llmcompressor.pytorch.model_load.helpers import get_session_model, save_checkpoint
from llmcompressor.pytorch.utils import ModuleSparsificationInfo
from llmcompressor.transformers import RECIPE_FILE_NAME
from llmcompressor.transformers.finetune.callbacks import (
DisableHalfPrecisionCallback,
TrainingLoopCallbacks,
)
from llmcompressor.utils.fsdp.context import summon_full_params_context
from llmcompressor.utils.fsdp.helpers import is_fsdp_model, save_pretrained_fsdp
from llmcompressor.utils.pytorch import qat_active

if TYPE_CHECKING:
@@ -64,8 +62,8 @@ class SessionManagerMixIn:
def __init__(
self,
recipe: str,
data_args: "DatasetArguments",
model_args: "ModelArguments",
data_args: Optional["DatasetArguments"] = None,
teacher: Optional[Union[Module, str]] = None,
recipe_args: Optional[Union[Dict[str, Any], str]] = None,
**kwargs,
@@ -183,7 +181,6 @@ def initialize_structure(self, stage: Optional[str] = None):
"""
Initialize any recipe structural changes such as quantization on the model,
return immediately if session has already been initialized
:param stage: Optional stage of recipe to run, or None to run all stages
"""
session = active_session()
@@ -399,44 +396,19 @@ def save_model(self, output_dir: str, _internal_call=False, _is_oneshot=False):

# knowledge distillation requires making wrappers transparent during
if isinstance(self.model, KDModelWrapper):
self.model.prepare_for_save()
self.model.prepare_for_save() # TODO: move to finalize

if not is_fsdp_model(self.model):
self.model.save_pretrained(
# save checkpoint
self.save_state()
if self.accelerator.is_main_process:
processor = getattr(self, "processing_class", self.tokenizer)
save_checkpoint(
output_dir,
save_compressed=self.model_args.save_compressed,
safe_serialization=self.args.save_safetensors,
)
else: # FSDP model
save_pretrained_fsdp(
model=self.model,
accelerator=self.accelerator,
output_dir=output_dir,
processor=processor,
save_safetensors=self.args.save_safetensors,
save_compressed=self.model_args.save_compressed,
save_safetensors=self.metadata.get("save_safetensors", False),
)

self.save_state()
processor = getattr(self, "processing_class", self.tokenizer)
if processor is not None:
processor.save_pretrained(output_dir)

if not self.recipe:
return

if self.accelerator.is_main_process:
# save recipe, will contain modifiers from the model's original recipe as
# well as those added from self.recipe
recipe_path = os.path.join(output_dir, RECIPE_FILE_NAME)
session = active_session()
recipe_yaml_str = session.get_serialized_recipe()
with open(recipe_path, "w") as fp:
fp.write(recipe_yaml_str)

logger.info(
f"Saved LLM Compressor recipe with model state to {recipe_path}"
)

self.accelerator.wait_for_everyone()

if isinstance(self.model, KDModelWrapper):
20 changes: 13 additions & 7 deletions src/llmcompressor/transformers/finetune/text_generation.py
@@ -46,12 +46,12 @@
get_session_model,
initialize_recipe,
parse_dtype,
save_checkpoint,
)
from llmcompressor.recipe import Recipe, StageRunType
from llmcompressor.transformers.finetune.runner import StageRunner
from llmcompressor.transformers.finetune.trainer import Trainer
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
modify_fsdp_model_save_pretrained,
modify_save_pretrained,
patch_tied_tensors_bug,
)
@@ -415,7 +415,10 @@ def main(

# wrap model.save_pretrained
if is_fsdp_model(model):
modify_fsdp_model_save_pretrained(trainer, processor)
raise NotImplementedError(
"FSDP models are not supported in the current release but will be "
"suported in future releases of LLM Compressor"
)
else:
modify_save_pretrained(model)

@@ -440,16 +443,19 @@
stage_runner.train(checkpoint)

# save if model was provided as a string or custom output_dir was set

if isinstance(model_args.model, str) or (
training_args.output_dir
!= TrainingArguments.__dataclass_fields__["output_dir"].default
and trainer.accelerator.is_main_process
):
model.save_pretrained(
training_args.output_dir, save_compressed=model_args.save_compressed
save_checkpoint(
save_path=training_args.output_dir,
model=model,
processor=processor,
save_safetensors=True,
save_compressed=model_args.save_compressed,
)
if processor is not None:
processor.save_pretrained(training_args.output_dir)
trainer.accelerator.wait_for_everyone()

# Clean up the CompressionSession before exit if requested
if recipe_args.clear_sparse_session: