[Training] Unifying Preprocess + Postprocessing logic for Train/Oneshot (#1212)

Order of reviews:
#1206
#1207
#1209
#1212  <-- Here
#1214

SUMMARY:
* Move the preprocessing and postprocessing logic out of
`src/llmcompressor/transformers/finetune/text_generation.py` and into
`src/llmcompressor/entrypoints/utils.py`

TEST PLAN:
Pass tests
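For orientation, the public entrypoint is not changed by this move: the `def oneshot(**kwargs) -> PreTrainedModel` wrapper in the diff below keeps its signature, and preprocessing/postprocessing now run through the shared `pre_process`/`post_process` helpers instead of private `Oneshot` methods. A minimal usage sketch, assuming the usual `from llmcompressor import oneshot` re-export; the model id, dataset name, and recipe path are illustrative placeholders, not part of this PR:

# Hedged example: argument values are illustrative only.
from llmcompressor import oneshot

model = oneshot(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # loaded by pre_process()
    dataset="open_platypus",                     # calibration dataset
    recipe="recipe.yaml",                        # compression recipe to apply
    output_dir="./compressed-model",             # post_process() saves here
)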
horheynm authored Mar 6, 2025
1 parent 14ac2e7 commit 9d82f35
Showing 4 changed files with 294 additions and 290 deletions.
1 change: 1 addition & 0 deletions src/llmcompressor/entrypoints/__init__.py
@@ -1,2 +1,3 @@
 # flake8: noqa
 from .oneshot import Oneshot, oneshot
+from .utils import post_process, pre_process
116 changes: 5 additions & 111 deletions src/llmcompressor/entrypoints/oneshot.py
@@ -1,21 +1,12 @@
-from pathlib import PosixPath
 from typing import Optional

-from loguru import logger
 from torch.utils.data import DataLoader
 from transformers import PreTrainedModel

 from llmcompressor.args import parse_args
 from llmcompressor.core.session_functions import active_session
 from llmcompressor.datasets import get_calibration_dataloader
-from llmcompressor.transformers.finetune.text_generation import (
-    initialize_model_from_path,
-    initialize_processor_from_path,
-)
-from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
-    modify_save_pretrained,
-    patch_tied_tensors_bug,
-)
+from llmcompressor.entrypoints.utils import post_process, pre_process

 __all__ = ["Oneshot", "oneshot"]

@@ -71,7 +62,7 @@ class Oneshot:
        Initializes the `Oneshot` object by parsing input arguments, performing
        preprocessing, and setting instance attributes.
-    run(**kwargs):
+    __call__(**kwargs):
        Performs the one-shot calibration process by preparing a calibration
        dataloader, applying recipe modifiers to the model, and executing
        postprocessing steps.
@@ -86,17 +77,6 @@ class Oneshot:
        defined in the recipe. Each action is executed via the global
        `CompressionSession`.
-    _pre_process():
-        Handles preprocessing steps, including model initialization,
-        tokenizer/processor setup, and resolving tied embedding issues.
-    check_tied_embeddings():
-        Logs a warning if `tie_word_embeddings=True`, which may interfere with
-        saving in the one-shot workflow.
-    _post_process():
-        Executes postprocessing steps such as saving the model and resetting
-        lifecycle actions, especially when a custom `output_dir` is specified.
     """

     def __init__(
@@ -151,7 +131,7 @@ def from_args(

         # only run for the first oneshot call
         if do_preprocess:
-            instance._pre_process()
+            pre_process(model_args)

         # Set instance attributes
         instance.model = instance.model_args.model
@@ -172,7 +152,7 @@ def __call__(self):
         """
         # TODO: move back once stage runner is removed
         # Preprocess the model and tokenizer/processor
-        self._pre_process()
+        pre_process(self.model_args)
         self.model = self.model_args.model
         self.recipe = self.recipe_args.recipe
         self.processor = self.model_args.processor
@@ -183,24 +163,7 @@ def __call__(self):
         self.apply_recipe_modifiers(
             calibration_dataloader=calibration_dataloader,
         )
-        self._post_process()
-
-    def save(self):
-        """
-        Saves the model and tokenizer/processor to the output directory.
-        The model is saved in a compressed format if specified in `model_args`.
-        The tokenizer or processor, if available, is also saved.
-        Raises:
-            ValueError: If saving fails due to an invalid `output_dir` or other issues.
-        """
-        self.model.save_pretrained(
-            self.output_dir,
-            save_compressed=self.model_args.save_compressed,
-        )
-        if self.processor is not None:
-            self.processor.save_pretrained(self.output_dir)
+        post_process(model_args=self.model_args, output_dir=self.output_dir)

     def apply_recipe_modifiers(
         self,
@@ -236,75 +199,6 @@ def apply_recipe_modifiers(
         session.initialize(**session_kwargs)
         session.finalize(**session_kwargs)

-    def _pre_process(self):
-        """
-        Prepares the model and tokenizer/processor for calibration.
-        - Initializes the model if it's specified as a path or string.
-        - Applies patches to fix tied tensor issues and modifies `save_pretrained`
-          behavior.
-        - Initializes the processor if specified as a path or `None`.
-        - Sets the minimum tokens per module if `dataset_args` are provided.
-        Raises:
-            FileNotFoundError: If the model or processor path is invalid.
-        """
-        self.check_tied_embeddings()
-
-        # Initialize model
-        if isinstance(self.model_args.model, (str, PosixPath)):
-            self.model_args.model, _ = initialize_model_from_path(self.model_args)
-
-        patch_tied_tensors_bug(self.model_args.model)
-        modify_save_pretrained(self.model_args.model)
-
-        # Initialize processor
-        if isinstance(self.model_args.processor, (str, type(None))):
-            self.model_args.processor = initialize_processor_from_path(
-                self.model_args, self.model_args.model
-            )
-        # TODO: move to init once stage runner is removed
-        self.processor = self.model_args.processor
-
-        # Set minimum tokens per module if data arguments are provided
-        if self.dataset_args:
-            self.min_tokens_per_module = self.dataset_args.min_tokens_per_module
-
-    def check_tied_embeddings(self):
-        """
-        Logs a warning if the model has tied word embeddings.
-        The `tie_word_embeddings` flag may cause issues during saving in the one-shot
-        calibration workflow due to shared tensor addresses.
-        """
-        if self.model_args.tie_word_embeddings:
-            logger.debug(
-                "The tie_word_embeddings flag is by default set to False. "
-                "This guarantees that the one-shot algorithm saves the final "
-                "weights without errors. Detected tie_word_embeddings=True. "
-                "This may cause issues with the one-shot algorithm on save."
-            )
-
-    def _post_process(self):
-        """
-        Executes post-calibration steps.
-        This method saves the model and resets lifecycle actions if the `output_dir`
-        is not the default directory.
-        Raises:
-            ValueError: If saving fails due to invalid configurations.
-        """
-        if self.output_dir is not None:
-            self.save()
-            return
-
-        logger.warning(
-            "Optimized model not saved. To save, please provide",
-            "`output_dir` as input arg.",
-            "Ex. `oneshot(..., output_dir=...)`",
-        )
-

 def oneshot(**kwargs) -> PreTrainedModel:
     one_shot = Oneshot(**kwargs)
(The remaining changed files, presumably src/llmcompressor/entrypoints/utils.py and src/llmcompressor/transformers/finetune/text_generation.py per the commit summary, did not load in this view.)

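The diff for the new src/llmcompressor/entrypoints/utils.py is not expanded above. As a rough guide to what moved, here is a hedged reconstruction assembled from the `_pre_process`, `check_tied_embeddings`, `save`, and `_post_process` bodies removed from oneshot.py. Only the two call signatures, `pre_process(model_args)` and `post_process(model_args=..., output_dir=...)`, come from the diff; the import locations mirror the imports removed above and may themselves have moved in this PR, and everything else is an assumption that may differ from the merged file.

# Hedged reconstruction of src/llmcompressor/entrypoints/utils.py, based on the
# method bodies removed from oneshot.py above. The `min_tokens_per_module`
# handling from the old `_pre_process` is omitted here because the new helper
# receives only `model_args`.
from pathlib import PosixPath
from typing import Optional

from loguru import logger

from llmcompressor.transformers.finetune.text_generation import (
    initialize_model_from_path,
    initialize_processor_from_path,
)
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
    modify_save_pretrained,
    patch_tied_tensors_bug,
)


def pre_process(model_args):
    """Prepare the model and tokenizer/processor before calibration or training."""
    check_tied_embeddings(model_args)

    # Initialize the model if it was given as a path or model id
    if isinstance(model_args.model, (str, PosixPath)):
        model_args.model, _ = initialize_model_from_path(model_args)

    patch_tied_tensors_bug(model_args.model)
    modify_save_pretrained(model_args.model)

    # Initialize the processor/tokenizer if it was given as a path or left unset
    if isinstance(model_args.processor, (str, type(None))):
        model_args.processor = initialize_processor_from_path(
            model_args, model_args.model
        )


def post_process(model_args, output_dir: Optional[str] = None):
    """Save the (optionally compressed) model and processor after the run."""
    if output_dir is not None:
        model_args.model.save_pretrained(
            output_dir, save_compressed=model_args.save_compressed
        )
        if model_args.processor is not None:
            model_args.processor.save_pretrained(output_dir)
        return

    logger.warning(
        "Optimized model not saved. To save, please provide "
        "`output_dir` as an input arg, e.g. `oneshot(..., output_dir=...)`."
    )


def check_tied_embeddings(model_args):
    """Warn when `tie_word_embeddings=True`, which can break saving after one-shot."""
    if model_args.tie_word_embeddings:
        logger.debug(
            "Detected tie_word_embeddings=True. This may cause issues when "
            "saving the model in the one-shot workflow."
        )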