diff --git a/src/llmcompressor/entrypoints/__init__.py b/src/llmcompressor/entrypoints/__init__.py
index dd1d4aa83..299ab9084 100644
--- a/src/llmcompressor/entrypoints/__init__.py
+++ b/src/llmcompressor/entrypoints/__init__.py
@@ -1,2 +1,3 @@
 # flake8: noqa
 from .oneshot import Oneshot, oneshot
+from .utils import post_process, pre_process
diff --git a/src/llmcompressor/entrypoints/oneshot.py b/src/llmcompressor/entrypoints/oneshot.py
index cfaf83f92..21e29057b 100644
--- a/src/llmcompressor/entrypoints/oneshot.py
+++ b/src/llmcompressor/entrypoints/oneshot.py
@@ -1,21 +1,12 @@
-from pathlib import PosixPath
 from typing import Optional
 
-from loguru import logger
 from torch.utils.data import DataLoader
 from transformers import PreTrainedModel
 
 from llmcompressor.args import parse_args
 from llmcompressor.core.session_functions import active_session
 from llmcompressor.datasets import get_calibration_dataloader
-from llmcompressor.transformers.finetune.text_generation import (
-    initialize_model_from_path,
-    initialize_processor_from_path,
-)
-from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
-    modify_save_pretrained,
-    patch_tied_tensors_bug,
-)
+from llmcompressor.entrypoints.utils import post_process, pre_process
 
 __all__ = ["Oneshot", "oneshot"]
 
@@ -71,7 +62,7 @@ class Oneshot:
             Initializes the `Oneshot` object by parsing input arguments, performing
             preprocessing, and setting instance attributes.
 
-        run(**kwargs):
+        __call__(**kwargs):
             Performs the one-shot calibration process by preparing a calibration
             dataloader, applying recipe modifiers to the model, and executing
             postprocessing steps.
@@ -86,17 +77,6 @@ class Oneshot:
             defined in the recipe. Each action is executed via the global
             `CompressionSession`.
 
-        _pre_process():
-            Handles preprocessing steps, including model initialization,
-            tokenizer/processor setup, and resolving tied embedding issues.
-
-        check_tied_embeddings():
-            Logs a warning if `tie_word_embeddings=True`, which may interfere with
-            saving in the one-shot workflow.
-
-        _post_process():
-            Executes postprocessing steps such as saving the model and resetting
-            lifecycle actions, especially when a custom `output_dir` is specified.
     """
 
     def __init__(
@@ -151,7 +131,7 @@ def from_args(
 
         # only run for the first oneshot call
         if do_preprocess:
-            instance._pre_process()
+            pre_process(model_args)
 
         # Set instance attributes
        instance.model = instance.model_args.model
@@ -172,7 +152,7 @@ def __call__(self):
         """
         # TODO: move back once stage runner is removed
         # Preprocess the model and tokenizer/processor
-        self._pre_process()
+        pre_process(self.model_args)
         self.model = self.model_args.model
         self.recipe = self.recipe_args.recipe
         self.processor = self.model_args.processor
@@ -183,24 +163,7 @@ def __call__(self):
         self.apply_recipe_modifiers(
             calibration_dataloader=calibration_dataloader,
         )
-        self._post_process()
-
-    def save(self):
-        """
-        Saves the model and tokenizer/processor to the output directory.
-
-        The model is saved in a compressed format if specified in `model_args`.
-        The tokenizer or processor, if available, is also saved.
-
-        Raises:
-            ValueError: If saving fails due to an invalid `output_dir` or other issues.
-        """
-        self.model.save_pretrained(
-            self.output_dir,
-            save_compressed=self.model_args.save_compressed,
-        )
-        if self.processor is not None:
-            self.processor.save_pretrained(self.output_dir)
+        post_process(model_args=self.model_args, output_dir=self.output_dir)
 
     def apply_recipe_modifiers(
         self,
@@ -236,75 +199,6 @@ def apply_recipe_modifiers(
         session.initialize(**session_kwargs)
         session.finalize(**session_kwargs)
 
-    def _pre_process(self):
-        """
-        Prepares the model and tokenizer/processor for calibration.
-
-        - Initializes the model if it's specified as a path or string.
-        - Applies patches to fix tied tensor issues and modifies `save_pretrained`
-        behavior.
-        - Initializes the processor if specified as a path or `None`.
-        - Sets the minimum tokens per module if `dataset_args` are provided.
-
-        Raises:
-            FileNotFoundError: If the model or processor path is invalid.
-        """
-        self.check_tied_embeddings()
-
-        # Initialize model
-        if isinstance(self.model_args.model, (str, PosixPath)):
-            self.model_args.model, _ = initialize_model_from_path(self.model_args)
-
-        patch_tied_tensors_bug(self.model_args.model)
-        modify_save_pretrained(self.model_args.model)
-
-        # Initialize processor
-        if isinstance(self.model_args.processor, (str, type(None))):
-            self.model_args.processor = initialize_processor_from_path(
-                self.model_args, self.model_args.model
-            )
-        # TODO: move to init once stage runner is removed
-        self.processor = self.model_args.processor
-
-        # Set minimum tokens per module if data arguments are provided
-        if self.dataset_args:
-            self.min_tokens_per_module = self.dataset_args.min_tokens_per_module
-
-    def check_tied_embeddings(self):
-        """
-        Logs a warning if the model has tied word embeddings.
-
-        The `tie_word_embeddings` flag may cause issues during saving in the one-shot
-        calibration workflow due to shared tensor addresses.
-        """
-        if self.model_args.tie_word_embeddings:
-            logger.debug(
-                "The tie_word_embeddings flag is by default set to False. "
-                "This guarantees that the one-shot algorithm saves the final "
-                "weights without errors. Detected tie_word_embeddings=True. "
-                "This may cause issues with the one-shot algorithm on save."
-            )
-
-    def _post_process(self):
-        """
-        Executes post-calibration steps.
-
-        This method saves the model and resets lifecycle actions if the `output_dir`
-        is not the default directory.
-
-        Raises:
-            ValueError: If saving fails due to invalid configurations.
-        """
-        if self.output_dir is not None:
-            self.save()
-            return
-
-        logger.warning(
-            "Optimized model not saved. To save, please provide",
-            "`output_dir` as input arg.",
-            "Ex. `oneshot(..., output_dir=...)`",
-        )
-
 
 def oneshot(**kwargs) -> PreTrainedModel:
     one_shot = Oneshot(**kwargs)
diff --git a/src/llmcompressor/entrypoints/utils.py b/src/llmcompressor/entrypoints/utils.py
new file mode 100644
index 000000000..c8cc3ba07
--- /dev/null
+++ b/src/llmcompressor/entrypoints/utils.py
@@ -0,0 +1,272 @@
+import inspect
+import os
+from pathlib import PosixPath
+from typing import Optional, Tuple
+
+from loguru import logger
+from torch.nn import Module
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoProcessor,
+    PreTrainedModel,
+    set_seed,
+)
+from transformers.utils.quantization_config import CompressedTensorsConfig
+
+from llmcompressor.args import ModelArguments, TrainingArguments
+from llmcompressor.pytorch.model_load.helpers import fallback_to_cpu, parse_dtype
+from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
+    modify_save_pretrained,
+    patch_tied_tensors_bug,
+)
+from llmcompressor.transformers.utils.helpers import (
+    detect_last_checkpoint,
+    is_model_ct_quantized_from_path,
+)
+from llmcompressor.typing import Processor
+from llmcompressor.utils.fsdp.helpers import is_fsdp_model
+
+
+def pre_process(model_args: "ModelArguments"):
+    """
+    Prepares the model and tokenizer/processor for calibration.
+    - Initializes the model if it's specified as a path or string.
+    - Applies patches to fix tied tensor issues and modifies `save_pretrained`
+    behavior.
+    - Initializes the processor if specified as a path or `None`.
+    - Sets the minimum tokens per module if `dataset_args` are provided.
+    Raises:
+        FileNotFoundError: If the model or processor path is invalid.
+    """
+    _warn_tied_embeddings(model_args.tie_word_embeddings)
+
+    # Initialize model
+    if isinstance(model_args.model, (str, PosixPath)):
+        model, distill_teacher = initialize_model_from_path(model_args)
+        if is_fsdp_model(model):
+            raise NotImplementedError(
+                "FSDP models are not supported in the current release but will be "
+                "supported in future releases of LLM Compressor."
+            )
+        model_args.model = model
+        model_args.distill_teacher = distill_teacher
+
+    # Initialize processor
+    if isinstance(model_args.processor, (str, type(None))):
+        model_args.processor = initialize_processor_from_path(
+            model_args, model_args.model
+        )
+
+    # untie tie_word_embeddings weights
+    patch_tied_tensors_bug(model_args.model)
+
+    # wrap model.save_pretrained
+    modify_save_pretrained(model_args.model)
+
+
+def post_process(
+    model_args: "ModelArguments",
+    output_dir: Optional[str] = None,
+):
+    """
+    Saves the model and tokenizer/processor to the output directory.
+
+    If the `output_dir` is not the default directory, the method resets lifecycle
+    actions. The model is saved in a compressed format if specified in `model_args`.
+    Additionally, the tokenizer or processor, if available, is also saved.
+
+    Raises:
+        ValueError: If saving fails due to an invalid `output_dir` or other issues.
+    """
+    if output_dir is not None:
+        model_args.model.save_pretrained(
+            output_dir,
+            save_compressed=model_args.save_compressed,
+        )
+        if model_args.processor:
+            model_args.processor.save_pretrained(output_dir)
+        return
+
+    logger.warning(
+        "Optimized model is not saved. To save, please provide "
+        "`output_dir` as input arg. "
+        "Ex. `oneshot(..., output_dir=...)`"
+    )
+
+
+def _warn_tied_embeddings(tie_word_embeddings: bool = False):
+    """
+    Logs a warning if the model has tied word embeddings.
+    The `tie_word_embeddings` flag may cause issues during saving in the one-shot
+    calibration workflow due to shared tensor addresses.
+    """
+    if tie_word_embeddings:
+        logger.debug(
+            "The tie_word_embeddings flag is by default set to False. "
+            "This guarantees that the one-shot algorithm saves the final "
+            "weights without errors. Detected tie_word_embeddings=True. "
+            "This may cause issues with the one-shot algorithm on save."
+        )
+
+
+def initialize_model_from_path(
+    model_args: ModelArguments,
+    training_args: Optional[TrainingArguments] = None,
+) -> Tuple[PreTrainedModel, Optional[PreTrainedModel]]:
+    # Load pretrained model
+    # The .from_pretrained methods guarantee that only one local process can
+    # concurrently download model & vocab.
+    model_path = model_args.model
+    config = AutoConfig.from_pretrained(
+        model_args.config_name if model_args.config_name else model_path,
+        cache_dir=model_args.cache_dir,
+        revision=model_args.model_revision,
+        use_auth_token=True if model_args.use_auth_token else None,
+        tie_word_embeddings=model_args.tie_word_embeddings,
+        trust_remote_code=model_args.trust_remote_code_model,
+    )
+
+    last_checkpoint = None
+    teacher = None
+
+    if training_args is not None:
+        # Load teacher configuration if applicable
+        teacher_config = (
+            AutoConfig.from_pretrained(
+                model_args.distill_teacher,
+                use_auth_token=True if model_args.use_auth_token else None,
+                tie_word_embeddings=model_args.tie_word_embeddings,
+                trust_remote_code=model_args.trust_remote_code_model,
+            )
+            if model_args.distill_teacher
+            else None
+        )
+
+        # Detect last checkpoint
+        last_checkpoint = detect_last_checkpoint(training_args, model_args=model_args)
+
+        # Set seed before initializing model
+        set_seed(training_args.seed)
+
+        # Initialize teacher model if teacher path is provided
+        if model_args.distill_teacher is not None:
+            teacher_device_map = (
+                None
+                if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true"
+                else "auto"
+            )
+            teacher_kwargs = {
+                "config": teacher_config,
+                "cache_dir": model_args.cache_dir,
+                "use_auth_token": True if model_args.use_auth_token else None,
+                "torch_dtype": parse_dtype(model_args.precision),
+                "device_map": teacher_device_map,
+                "trust_remote_code": model_args.trust_remote_code_model,
+            }
+
+            teacher = AutoModelForCausalLM.from_pretrained(
+                model_args.distill_teacher,
+                **teacher_kwargs,
+            )
+            if "sequence_length" in teacher_kwargs:
+                teacher.seqlen = teacher_kwargs["sequence_length"]
+
+    model_path = (
+        last_checkpoint or model_args.model
+        if hasattr(model_args, "model")
+        else model_args.model_name_or_path
+    )
+
+    # Fallback to CPU if GPU requested and not available
+    model_args.oneshot_device = fallback_to_cpu(model_args.oneshot_device)
+
+    device_map = model_args.oneshot_device
+    if training_args is not None and training_args.do_train:
+        device_map = "auto"
+
+    model_kwargs = {
+        "config": config,
+        "cache_dir": model_args.cache_dir,
+        "revision": model_args.model_revision,
+        "use_auth_token": True if model_args.use_auth_token else None,
+        "torch_dtype": parse_dtype(model_args.precision),
+        "device_map": device_map,
+        "trust_remote_code": model_args.trust_remote_code_model,
+    }
+
+    # optimized models must be decompressed to carry out oneshot/train/etc
+    if is_model_ct_quantized_from_path(model_path):
+        model_kwargs["quantization_config"] = CompressedTensorsConfig(
+            run_compressed=False
+        )
+
+    model = AutoModelForCausalLM.from_pretrained(
+        model_path,
+        **model_kwargs,
+    )
+    if "sequence_length" in model_kwargs:
+        model.seqlen = model_kwargs["sequence_length"]
+
+    return model, teacher
+
+
+def initialize_processor_from_path(
+    model_args: ModelArguments,
+    model: PreTrainedModel,
+    teacher: Optional[PreTrainedModel] = None,
+) -> Processor:
+    processor_src = model_args.processor or get_processor_name_from_model(
+        model, teacher
+    )
+    # The use_fast=True option is not currently supported safely in Transformers
+    # See: https://github.com/huggingface/transformers/pull/34836#issuecomment-2491809727  # noqa: E501
+    try:
+        processor = AutoProcessor.from_pretrained(
+            processor_src,
+            cache_dir=model_args.cache_dir,
+            use_fast=True,
+            revision=model_args.model_revision,
+            use_auth_token=True if model_args.use_auth_token else None,
+            trust_remote_code=model_args.trust_remote_code_model,
+        )
+    except Exception:
+        logger.debug("Could not load fast processor, loading slow processor instead")
+        processor = AutoProcessor.from_pretrained(
+            processor_src,
+            cache_dir=model_args.cache_dir,
+            use_fast=False,
+            revision=model_args.model_revision,
+            use_auth_token=True if model_args.use_auth_token else None,
+            trust_remote_code=model_args.trust_remote_code_model,
+        )
+
+    return processor
+
+
+def get_processor_name_from_model(student: Module, teacher: Optional[Module]) -> str:
+    """
+    Get a processor/tokenizer source used for both student and teacher, assuming
+    that they could be shared
+
+    :param student: the student model
+    :param teacher: the teacher model
+    :return: the source for the processor/tokenizer shared between teacher and model
+    """
+    if teacher is not None and teacher not in ("disable", "self"):
+        student_forward_params = list(
+            inspect.signature(student.forward).parameters.keys()
+        )
+        teacher_forward_params = list(
+            inspect.signature(teacher.forward).parameters.keys()
+        )
+        diff = [p for p in student_forward_params if p not in teacher_forward_params]
+        if diff:
+            raise RuntimeError(
+                "Teacher tokenizer cannot be used for student "
+                f"due to missing args: {diff}"
+            )
+        src_model = teacher
+    else:
+        src_model = student
+    return src_model.config._name_or_path
diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py
index d03867b85..66652b686 100644
--- a/src/llmcompressor/transformers/finetune/text_generation.py
+++ b/src/llmcompressor/transformers/finetune/text_generation.py
@@ -17,22 +17,12 @@
 # Adapted from https://github.com/huggingface/transformers
 # vllm-project: no copyright
 
-import os
 import warnings
 from pathlib import PosixPath
-from typing import Optional
 
 from compressed_tensors.utils.helpers import deprecated
 from loguru import logger
-from transformers import (
-    AutoConfig,
-    AutoModelForCausalLM,
-    AutoProcessor,
-    HfArgumentParser,
-    PreTrainedModel,
-    set_seed,
-)
-from transformers.utils.quantization_config import CompressedTensorsConfig
+from transformers import HfArgumentParser
 
 from llmcompressor.args import (
     DatasetArguments,
@@ -41,11 +31,7 @@
     TrainingArguments,
 )
 from llmcompressor.core import reset_session
-from llmcompressor.pytorch.model_load.helpers import (
-    fallback_to_cpu,
-    parse_dtype,
-    save_checkpoint,
-)
+from llmcompressor.pytorch.model_load.helpers import save_checkpoint
 from llmcompressor.recipe import Recipe, StageRunType
 from llmcompressor.transformers.finetune.runner import StageRunner
 from llmcompressor.transformers.finetune.trainer import Trainer
@@ -53,14 +39,6 @@
     modify_save_pretrained,
     patch_tied_tensors_bug,
 )
-from llmcompressor.transformers.sparsification.sparse_model import (
-    get_processor_name_from_model,
-)
-from llmcompressor.transformers.utils.helpers import (
-    detect_last_checkpoint,
-    is_model_ct_quantized_from_path,
-)
-from llmcompressor.typing import Processor
 from llmcompressor.utils.fsdp.helpers import is_fsdp_model
 
 
@@ -73,15 +51,6 @@ def train(**kwargs):
     main(model_args, dataset_args, recipe_args, training_args)
 
 
-def eval(**kwargs):
-    """
-    CLI entrypoint for running evaluation
-    """
-    model_args, dataset_args, recipe_args, training_args = parse_args(**kwargs)
-    training_args.do_eval = True
-    main(model_args, dataset_args, recipe_args, training_args)
-
-
 @deprecated(
     message=(
         "`from llmcompressor.transformers import oneshot` is deprecated, "
@@ -98,10 +67,14 @@ def apply(**kwargs):
     """
     CLI entrypoint for any of training, oneshot
     """
-    report_to = kwargs.get("report_to", None)
-    model_args, dataset_args, recipe_args, training_args = parse_args(**kwargs)
+    from llmcompressor.args import parse_args
+
+    model_args, dataset_args, recipe_args, training_args, _ = parse_args(
+        include_training_args=True, **kwargs
+    )
 
     training_args.run_stages = True
+    report_to = kwargs.get("report_to", None)
     if report_to is None:  # user didn't specify any reporters
         # get rid of the reporters inferred from hugging face
         training_args.report_to = []
@@ -123,7 +96,6 @@ def parse_args(**kwargs):
             src/llmcompressor/transformers/utils/arg_parser/recipe_args.py
         * training_args in
             src/llmcompressor/transformers/utils/arg_parser/training_args.py
-
     """
     parser = HfArgumentParser(
         (ModelArguments, DatasetArguments, RecipeArguments, TrainingArguments)
@@ -161,147 +133,6 @@ def parse_args(**kwargs):
     return model_args, dataset_args, recipe_args, training_args
 
 
-def initialize_model_from_path(
-    model_args: ModelArguments,
-    training_args: Optional[TrainingArguments] = None,
-):
-    # Load pretrained model
-    # The .from_pretrained methods guarantee that only one local process can
-    # concurrently download model & vocab.
-    model_path = model_args.model
-    config = AutoConfig.from_pretrained(
-        model_args.config_name if model_args.config_name else model_path,
-        cache_dir=model_args.cache_dir,
-        revision=model_args.model_revision,
-        use_auth_token=True if model_args.use_auth_token else None,
-        tie_word_embeddings=model_args.tie_word_embeddings,
-        trust_remote_code=model_args.trust_remote_code_model,
-    )
-
-    last_checkpoint = None
-    teacher = None
-
-    if training_args is not None:
-        # Load teacher configuration if applicable
-        teacher_config = (
-            AutoConfig.from_pretrained(
-                model_args.distill_teacher,
-                use_auth_token=True if model_args.use_auth_token else None,
-                tie_word_embeddings=model_args.tie_word_embeddings,
-                trust_remote_code=model_args.trust_remote_code_model,
-            )
-            if model_args.distill_teacher
-            else None
-        )
-
-        # Detect last checkpoint
-        last_checkpoint = detect_last_checkpoint(training_args, model_args=model_args)
-
-        # Set seed before initializing model
-        set_seed(training_args.seed)
-
-        # Initialize teacher model if teacher path is provided
-        if model_args.distill_teacher is not None:
-            teacher_device_map = (
-                None
-                if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true"
-                else "auto"
-            )
-            teacher_kwargs = {
-                "config": teacher_config,
-                "cache_dir": model_args.cache_dir,
-                "use_auth_token": True if model_args.use_auth_token else None,
-                "torch_dtype": parse_dtype(model_args.precision),
-                "device_map": teacher_device_map,
-                "trust_remote_code": model_args.trust_remote_code_model,
-            }
-
-            teacher = AutoModelForCausalLM.from_pretrained(
-                model_args.distill_teacher,
-                **teacher_kwargs,
-            )
-            if "sequence_length" in teacher_kwargs:
-                teacher.seqlen = teacher_kwargs["sequence_length"]
-
-    model_path = (
-        last_checkpoint or model_args.model
-        if hasattr(model_args, "model")
-        else model_args.model_name_or_path
-    )
-
-    # Fallback to CPU if GPU requested and not available
-    model_args.oneshot_device = fallback_to_cpu(model_args.oneshot_device)
-
-    # Trainer handles device assignment for FSDP and training, don't do mapping here
-    # if running oneshot outside of FSDP, apply user device settings
-
-    fsdp_enabled = os.environ.get("ACCELERATE_USE_FSDP", "false") == "true"
-
-    device_map = model_args.oneshot_device
-    if not fsdp_enabled and training_args is not None and training_args.do_train:
-        device_map = "auto"
-
-    model_kwargs = {
-        "config": config,
-        "cache_dir": model_args.cache_dir,
-        "revision": model_args.model_revision,
-        "use_auth_token": True if model_args.use_auth_token else None,
-        "torch_dtype": parse_dtype(model_args.precision),
-        "device_map": device_map,
-        "trust_remote_code": model_args.trust_remote_code_model,
-    }
-
-    # this calls from_pretrained under the hood so should be FSDP safe
-
-    # optimized models must be decompressed to carry out oneshot/train/etc
-    if is_model_ct_quantized_from_path(model_path):
-        model_kwargs["quantization_config"] = CompressedTensorsConfig(
-            run_compressed=False
-        )
-
-    model = AutoModelForCausalLM.from_pretrained(
-        model_path,
-        **model_kwargs,
-    )
-    if "sequence_length" in model_kwargs:
-        model.seqlen = model_kwargs["sequence_length"]
-
-    return model, teacher
-
-
-def initialize_processor_from_path(
-    model_args: ModelArguments,
-    model: PreTrainedModel,
-    teacher: Optional[PreTrainedModel] = None,
-) -> Processor:
-    processor_src = model_args.processor or get_processor_name_from_model(
-        model, teacher
-    )
-    # The use_fast=True option is not currently supported safely in Transformers
-    # See: https://github.com/huggingface/transformers/pull/34836#issuecomment-2491809727  # noqa: E501
-    try:
-        processor = AutoProcessor.from_pretrained(
-            processor_src,
-            cache_dir=model_args.cache_dir,
-            use_fast=True,
-            revision=model_args.model_revision,
-            use_auth_token=True if model_args.use_auth_token else None,
-            trust_remote_code=model_args.trust_remote_code_model,
-        )
-    except Exception:
-        logger.debug("Could not load fast processor, loading slow processor instead")
-        processor = AutoProcessor.from_pretrained(
-            processor_src,
-            cache_dir=model_args.cache_dir,
-            use_fast=False,
-            revision=model_args.model_revision,
-            use_auth_token=True if model_args.use_auth_token else None,
-            trust_remote_code=model_args.trust_remote_code_model,
-        )
-
-    return processor
-
-
 def main(
     model_args: ModelArguments,
     dataset_args: DatasetArguments,
@@ -326,10 +157,15 @@ def main(
 
     :param model_args: Arguments pertaining to which model/config/tokenizer we are
         going to fine-tune from
-    :param dataset_args: Arguments pertaining to what data we are going to input
-        our model for training
+    :param dataset_args: Arguments pertaining to what data we are
+        going to input our model for training
     :param training_args: Arguments pertaining to training loop configuration
     """
+    from llmcompressor.args import TrainingArguments
+    from llmcompressor.entrypoints.utils import (
+        initialize_model_from_path,
+        initialize_processor_from_path,
+    )
 
     # Temporary warning, to be removed
     if model_args.tie_word_embeddings is True:
@@ -426,6 +262,7 @@ def main(
 
         # exit immediately
         return
 
+    # Training
     if training_args.do_train:
         checkpoint = None
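
Usage sketch (reviewer note, not part of the patch): for callers of `oneshot` the end-to-end flow should be unchanged by this refactor; `pre_process`/`post_process` are simply free functions now invoked by `Oneshot`. A minimal sketch, assuming the existing `oneshot` keyword arguments (`model`, `recipe`, `output_dir`) and using placeholder model/recipe identifiers that are not part of this diff:

```python
# Placeholder identifiers; any HF model stub and LLM Compressor recipe will do.
from llmcompressor.entrypoints import oneshot

compressed_model = oneshot(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # loaded by pre_process()
    recipe="recipe.yaml",                        # modifiers applied via the session
    output_dir="./tinyllama-compressed",         # triggers the post_process() save
)
```

Depending on the recipe, calibration-dataset arguments may also be required; those code paths are untouched by this diff.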
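The same helpers can also be exercised directly, which mirrors what `Oneshot.from_args` and `Oneshot.__call__` now do. A sketch under the assumption that `ModelArguments` can be constructed directly with just a model stub; only the attribute names used below appear in this diff, everything else is assumed:

```python
from llmcompressor.args import ModelArguments
from llmcompressor.entrypoints import post_process, pre_process

# Assumed constructor usage; in the entrypoints, model_args comes from parse_args().
model_args = ModelArguments(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")

# Loads the model/processor, unties shared embedding tensors, and wraps
# model.save_pretrained so it can emit compressed checkpoints.
pre_process(model_args)

# ... apply recipe modifiers / run calibration against the active session ...

# Saves the model and processor (compressed if model_args.save_compressed) when an
# output_dir is given; otherwise only a warning is logged.
post_process(model_args=model_args, output_dir="./tinyllama-compressed")
```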