diff --git a/README.md b/README.md
index e61e2a49e..3ae778835 100644
--- a/README.md
+++ b/README.md
@@ -23,6 +23,32 @@
 * SmoothQuant
 * SparseGPT
+### When to Use Which Optimization
+
+#### PTQ
+PTQ reduces the precision of quantizable weights (e.g., linear layers) to a lower bit-width. Supported formats are:
+
+##### [W4A16](./examples/quantization_w4a16/README.md)
+- Uses GPTQ to compress weights to 4 bits. Requires a calibration dataset.
+- Useful for speed ups in low QPS regimes thanks to the additional weight compression.
+- Recommended for any GPU type.
+##### [W8A8-INT8](./examples/quantization_w8a8_int8/README.md)
+- Uses channel-wise quantization to compress weights to 8 bits with GPTQ, and dynamic per-token quantization to compress activations to 8 bits. Requires a calibration dataset for weight quantization. Activation quantization is carried out during inference on vLLM.
+- Useful for speed ups in high QPS regimes or offline serving on vLLM.
+- Recommended for NVIDIA GPUs with compute capability <8.9 (Ampere, Turing, Volta, Pascal, or older).
+##### [W8A8-FP8](./examples/quantization_w8a8_fp8/README.md)
+- Uses channel-wise quantization to compress weights to 8 bits, and dynamic per-token quantization to compress activations to 8 bits. Does not require a calibration dataset. Activation quantization is carried out during inference on vLLM.
+- Useful for speed ups in high QPS regimes or offline serving on vLLM.
+- Recommended for NVIDIA GPUs with compute capability >=8.9 (Hopper and Ada Lovelace).
+
+#### Sparsification
+Sparsification reduces model complexity by pruning selected weight values to zero while retaining essential weights in a subset of parameters. Supported formats include:
+
+##### [2:4-Sparsity with FP8 Weight, FP8 Input Activation](./examples/sparse_2of4_quantization_fp8/README.md)
+- Combines (1) semi-structured sparsity (SparseGPT), where two out of every four contiguous weights in a tensor are set to zero, with (2) channel-wise quantization to compress weights to 8 bits and dynamic per-token quantization to compress activations to 8 bits.
+- Useful for faster inference than W8A8-FP8, with almost no drop in evaluation accuracy ([blog](https://neuralmagic.com/blog/24-sparse-llama-fp8-sota-performance-for-nvidia-hopper-gpus/)). Note: small models may experience accuracy drops when the remaining non-zero weights are insufficient to capture the original weight distribution.
+- Recommended for NVIDIA GPUs with compute capability >=8.9 (Hopper and Ada Lovelace).
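For reference, the PTQ flows above are driven through the `oneshot` entrypoint touched later in this diff. Below is a minimal W4A16 sketch; the model name, dataset choice, and sample counts are illustrative, and exact import paths can differ between releases:

```python
from transformers import AutoModelForCausalLM

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier

# Illustrative model/dataset choices; any causal LM and registered
# calibration dataset should work similarly.
model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype="auto"
)

# GPTQ W4A16: 4-bit weights, 16-bit activations; keep lm_head unquantized.
recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])

oneshot(
    model=model,
    dataset="open_platypus",  # registered calibration dataset
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=512,
    output_dir="TinyLlama-1.1B-W4A16",  # saved in compressed format
)
```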
+ ## Installation @@ -35,16 +61,16 @@ pip install llmcompressor ### End-to-End Examples Applying quantization with `llmcompressor`: -* [Activation quantization to `int8`](examples/quantization_w8a8_int8) -* [Activation quantization to `fp8`](examples/quantization_w8a8_fp8) -* [Weight only quantization to `int4`](examples/quantization_w4a16) -* [Quantizing MoE LLMs](examples/quantizing_moe) -* [Quantizing Vision-Language Models](examples/multimodal_vision) -* [Quantizing Audio-Language Models](examples/multimodal_audio) +* [Activation quantization to `int8`](examples/quantization_w8a8_int8/README.md) +* [Activation quantization to `fp8`](examples/quantization_w8a8_fp8/README.md) +* [Weight only quantization to `int4`](examples/quantization_w4a16/README.md) +* [Quantizing MoE LLMs](examples/quantizing_moe/README.md) +* [Quantizing Vision-Language Models](examples/multimodal_vision/README.md) +* [Quantizing Audio-Language Models](examples/multimodal_audio/README.md) ### User Guides Deep dives into advanced usage of `llmcompressor`: -* [Quantizing with large models with the help of `accelerate`](examples/big_models_with_accelerate) +* [Quantizing with large models with the help of `accelerate`](examples/big_models_with_accelerate/README.md) ## Quick Tour diff --git a/examples/trl_mixin/ex_trl_distillation.py b/examples/trl_mixin/ex_trl_distillation.py index ebd14c5d2..4ebb53276 100644 --- a/examples/trl_mixin/ex_trl_distillation.py +++ b/examples/trl_mixin/ex_trl_distillation.py @@ -19,12 +19,12 @@ max_seq_length = 512 # Load gsm8k using SparseML dataset tools -data_args = DatasetArguments( +dataset_args = DatasetArguments( dataset="gsm8k", dataset_config_name="main", max_seq_length=max_seq_length ) dataset_manager = TextGenerationDataset.load_from_registry( - data_args.dataset, - data_args=data_args, + dataset_args.dataset, + dataset_args=dataset_args, split="train", processor=tokenizer, ) @@ -69,7 +69,7 @@ train_dataset=train_dataset, data_collator=data_collator, trl_sft_config_args=trl_sft_config_args, - data_args=data_args, + dataset_args=dataset_args, model_args=model_args, ) trainer.train() diff --git a/src/llmcompressor/args/__init__.py b/src/llmcompressor/args/__init__.py index d60435c42..26ad530b6 100644 --- a/src/llmcompressor/args/__init__.py +++ b/src/llmcompressor/args/__init__.py @@ -4,3 +4,4 @@ from .model_arguments import ModelArguments from .recipe_arguments import RecipeArguments from .training_arguments import TrainingArguments +from .utils import parse_args diff --git a/src/llmcompressor/args/utils.py b/src/llmcompressor/args/utils.py new file mode 100644 index 000000000..810d2f6ab --- /dev/null +++ b/src/llmcompressor/args/utils.py @@ -0,0 +1,73 @@ +from typing import Tuple + +from loguru import logger +from transformers import HfArgumentParser + +from llmcompressor.args import ( + DatasetArguments, + ModelArguments, + RecipeArguments, + TrainingArguments, +) +from llmcompressor.transformers.utils.helpers import resolve_processor_from_model_args + + +def parse_args( + include_training_args: bool = False, **kwargs +) -> Tuple[ModelArguments, DatasetArguments, RecipeArguments, TrainingArguments, str]: + """ + Keyword arguments passed in from `oneshot` or `train` will + separate the arguments into the following: + + * ModelArguments in + src/llmcompressor/args/model_args.py + * DatasetArguments in + src/llmcompressor/args/dataset_args.py + * RecipeArguments in + src/llmcompressor/args/recipe_args.py + * TrainingArguments in + src/llmcompressor/args/training_args.py + + 
ModelArguments, DatasetArguments, and RecipeArguments are used for both + `oneshot` and `train`. TrainingArguments is only used for `train`. + + """ + + # pop output_dir, used as an attr in TrainingArguments, where oneshot is not used + output_dir = kwargs.pop("output_dir", None) + + parser_args = (ModelArguments, DatasetArguments, RecipeArguments) + if include_training_args: + parser_args += (TrainingArguments,) + + parser = HfArgumentParser(parser_args) + parsed_args = parser.parse_dict(kwargs) + + training_args = None + if include_training_args: + model_args, dataset_args, recipe_args, training_args = parsed_args + if output_dir is not None: + training_args.output_dir = output_dir + else: + model_args, dataset_args, recipe_args = parsed_args + + if recipe_args.recipe_args is not None: + if not isinstance(recipe_args.recipe_args, dict): + arg_dict = {} + for recipe_arg in recipe_args.recipe_args: + key, value = recipe_arg.split("=") + arg_dict[key] = value + recipe_args.recipe_args = arg_dict + + # raise depreciation warnings + if dataset_args.remove_columns is not None: + logger.warn( + "`remove_columns` argument is depreciated. When tokenizing datasets, all " + "columns which are invalid inputs the tokenizer will be removed", + DeprecationWarning, + ) + + # silently assign tokenizer to processor + resolve_processor_from_model_args(model_args) + + return model_args, dataset_args, recipe_args, training_args, output_dir diff --git a/src/llmcompressor/datasets/__init__.py b/src/llmcompressor/datasets/__init__.py new file mode 100644 index 000000000..0b81cc724 --- /dev/null +++ b/src/llmcompressor/datasets/__init__.py @@ -0,0 +1,8 @@ +# flake8: noqa + +from .utils import ( + format_calibration_data, + get_calibration_dataloader, + get_processed_dataset, + make_dataset_splits, +) diff --git a/src/llmcompressor/datasets/utils.py b/src/llmcompressor/datasets/utils.py new file mode 100644 index 000000000..0d36cb3ac --- /dev/null +++ b/src/llmcompressor/datasets/utils.py @@ -0,0 +1,191 @@ +import re +from typing import Any, Callable, Dict, List, Optional + +import torch +from datasets import Dataset +from loguru import logger +from torch.utils.data import DataLoader, RandomSampler, SequentialSampler +from transformers.data import default_data_collator + +from llmcompressor.args import DatasetArguments +from llmcompressor.transformers.finetune.data import TextGenerationDataset +from llmcompressor.typing import Processor + + +def get_processed_dataset( + dataset_args: DatasetArguments, + processor: Processor, + do_oneshot: bool = False, + do_train: bool = True, +) -> Optional[Dict[str, Dataset]]: + """ + Loads datasets for each flow based on dataset_args, stores a Dataset for each + enabled flow in datasets + :param dataset_args: DatasetArguments that contain dataset loading and + processing params + :param processor: processor or tokenizer to use for dataset tokenization + :param do_oneshot: True for oneshot pathway + :param do_train: True for train pathway + :return: A dataset corresponding to either train or calibration (oneshot) + """ + if dataset_args.dataset is None: + logger.warning( + "Running oneshot without calibration data. 
This is expected for " + "weight-only and dynamic quantization" + ) + return + + splits = dataset_args.splits + tokenized_datasets = {} + + def _get_split_name(inp_str): + # strip out split name, for ex train[60%:] -> train + match = re.match(r"(\w*)\[.*\]", inp_str) + if match is not None: + return match.group(1) + return inp_str + + if splits is None: + splits = {"all": None} + elif isinstance(splits, str): + splits = {_get_split_name(splits): splits} + elif isinstance(splits, List): + splits = {_get_split_name(s): s for s in splits} + + # default to custom dataset if dataset provided isn't a string + registry_id = ( + dataset_args.dataset if isinstance(dataset_args.dataset, str) else "custom" + ) + for split_name, split_str in splits.items(): + dataset = dataset_args.dataset + if hasattr(dataset, "column_names") and "input_ids" in dataset.column_names: + # dataset is already tokenized + tokenized_datasets[split_name] = dataset + else: + # dataset needs to be tokenized + dataset_manager = TextGenerationDataset.load_from_registry( + registry_id, + dataset_args=dataset_args, + split=split_str, + processor=processor, + ) + tokenized_datasets[split_name] = dataset_manager(add_labels=do_train) + + return make_dataset_splits( + tokenized_datasets, + do_oneshot=do_oneshot, + do_train=do_train, + ) + + +def get_calibration_dataloader( + dataset_args: DatasetArguments, + processor: Processor, +) -> torch.utils.data.DataLoader: + """ + Get the dataloader used for oneshot calibration. + :param dataset_args: DatasetArguments that contains the dataset parameters. + :param processor: Processor or the tokenizer of the model. + :return: PyTorch dataloader object that contains the calibration dataset. + """ + if dataset_args.dataset is None: + # weight-only quantization or dynamic quantization + return + + datasets = get_processed_dataset( + dataset_args=dataset_args, + processor=processor, + do_oneshot=True, + do_train=False, + ) + + calibration_dataset = datasets.get("calibration") + + return format_calibration_data( + tokenized_dataset=calibration_dataset, + num_calibration_samples=dataset_args.num_calibration_samples, + do_shuffle=dataset_args.shuffle_calibration_samples, + collate_fn=dataset_args.data_collator, + ) + + +def format_calibration_data( + tokenized_dataset: Dataset, + num_calibration_samples: Optional[int] = None, + do_shuffle: bool = True, + collate_fn: Callable = default_data_collator, +) -> List[torch.Tensor]: + """ + Creates a dataloader out of the calibration dataset split, trimming it to + the desired number of calibration samples + :param tokenized_dataset: dataset to convert to dataloader + :param num_calibration_samples: number of data samples to convert + :param do_shuffle: whether to shuffle the dataset before selecting calibration + samples, true by default + :param collate_fn: optional custom collate function, or use default + :return: list of trimmed calibration data tensors + """ + safe_calibration_samples = len(tokenized_dataset) + if num_calibration_samples is not None: + safe_calibration_samples = min(len(tokenized_dataset), num_calibration_samples) + if safe_calibration_samples != num_calibration_samples: + logger.warn( + f"Requested {num_calibration_samples} calibration samples but " + f"the provided dataset only has {safe_calibration_samples}. 
" + ) + + if do_shuffle: + tokenized_dataset = tokenized_dataset.shuffle() + tokenized_calibration = tokenized_dataset.select(range(safe_calibration_samples)) + + dataloader_params = { + "batch_size": 1, + "sampler": RandomSampler(tokenized_calibration) + if do_shuffle + else SequentialSampler(tokenized_calibration), + "collate_fn": collate_fn, + "pin_memory": True, + } + + calibration_dataloader = DataLoader(tokenized_calibration, **dataloader_params) + + return calibration_dataloader + + +def make_dataset_splits( + tokenized_datasets: Dict[str, Any], + do_oneshot: bool = True, + do_train: bool = False, +) -> Dict[str, Dataset]: + """ + Restructures the datasets dictionary based on what tasks will be run + train + :param tokenized_datasets: dictionary of processed datasets + :param do_oneshot: Whether to store the calibration dataset + :return: A dataset corresponding to either train or calibration (oneshot) + """ + + # handles case where all splits are contained in a single dataset + if "all" in tokenized_datasets and len(tokenized_datasets) == 1: + tokenized_datasets = tokenized_datasets.get("all") + if isinstance(tokenized_datasets, Dataset): + tokenized_datasets = {"train": tokenized_datasets} + + train_split = calib_split = None + + if do_train: + if "train" not in tokenized_datasets: + raise ValueError("--do_train requires a train dataset") + train_split = tokenized_datasets["train"] + if do_oneshot: + calib_split = tokenized_datasets.get("calibration") + if calib_split is None: + if "train" not in tokenized_datasets: + raise ValueError("--do_oneshot requires a calibration dataset") + calib_split = tokenized_datasets["train"] + + split_datasets = { + "train": train_split, + "calibration": calib_split, + } + return split_datasets diff --git a/src/llmcompressor/entrypoints/__init__.py b/src/llmcompressor/entrypoints/__init__.py index dd1d4aa83..299ab9084 100644 --- a/src/llmcompressor/entrypoints/__init__.py +++ b/src/llmcompressor/entrypoints/__init__.py @@ -1,2 +1,3 @@ # flake8: noqa from .oneshot import Oneshot, oneshot +from .utils import post_process, pre_process diff --git a/src/llmcompressor/entrypoints/oneshot.py b/src/llmcompressor/entrypoints/oneshot.py index 9c57b423a..21e29057b 100644 --- a/src/llmcompressor/entrypoints/oneshot.py +++ b/src/llmcompressor/entrypoints/oneshot.py @@ -1,26 +1,14 @@ -from pathlib import PosixPath -from typing import Optional, Tuple +from typing import Optional -from loguru import logger from torch.utils.data import DataLoader -from transformers import HfArgumentParser, PreTrainedModel +from transformers import PreTrainedModel -from llmcompressor.args import DatasetArguments, ModelArguments, RecipeArguments +from llmcompressor.args import parse_args from llmcompressor.core.session_functions import active_session -from llmcompressor.transformers.finetune.data.data_helpers import ( - get_calibration_dataloader, -) -from llmcompressor.transformers.finetune.text_generation import ( - initialize_model_from_path, - initialize_processor_from_path, -) -from llmcompressor.transformers.sparsification.compressed_tensors_utils import ( - modify_save_pretrained, - patch_tied_tensors_bug, -) -from llmcompressor.transformers.utils.helpers import resolve_processor_from_model_args - -__all__ = ["Oneshot", "oneshot", "parse_oneshot_args"] +from llmcompressor.datasets import get_calibration_dataloader +from llmcompressor.entrypoints.utils import post_process, pre_process + +__all__ = ["Oneshot", "oneshot"] class Oneshot: @@ -36,7 +24,7 @@ class Oneshot: `kwargs` 
are parsed into: - `model_args`: Arguments for loading and configuring a pretrained model (e.g., `AutoModelForCausalLM`). - - `data_args`: Arguments for dataset-related configurations, such as + - `dataset_args`: Arguments for dataset-related configurations, such as calibration dataloaders. - `recipe_args`: Arguments for defining and configuring recipes that specify optimization actions. @@ -74,7 +62,7 @@ class Oneshot: Initializes the `Oneshot` object by parsing input arguments, performing preprocessing, and setting instance attributes. - run(**kwargs): + __call__(**kwargs): Performs the one-shot calibration process by preparing a calibration dataloader, applying recipe modifiers to the model, and executing postprocessing steps. @@ -89,17 +77,6 @@ class Oneshot: defined in the recipe. Each action is executed via the global `CompressionSession`. - _pre_process(): - Handles preprocessing steps, including model initialization, - tokenizer/processor setup, and resolving tied embedding issues. - - check_tied_embeddings(): - Logs a warning if `tie_word_embeddings=True`, which may interfere with - saving in the one-shot workflow. - - _post_process(): - Executes postprocessing steps such as saving the model and resetting - lifecycle actions, especially when a custom `output_dir` is specified. """ def __init__( @@ -109,24 +86,23 @@ def __init__( """ Initializes the `Oneshot` class with provided arguments. - Parses the input keyword arguments into `model_args`, `data_args`, and + Parses the input keyword arguments into `model_args`, `dataset_args`, and `recipe_args`. Performs preprocessing to initialize the model and tokenizer/processor. :param model_args: ModelArguments parameters, responsible for controlling model loading and saving logic - :param data_args: DatasetArguments parameters, responsible for controlling + :param dataset_args: DatasetArguments parameters, responsible for controlling dataset loading, preprocessing and dataloader loading :param recipe_args: RecipeArguments parameters, responsible for containing recipe-related parameters :param output_dir: Path to save the output model after carrying out oneshot """ - - model_args, data_args, recipe_args, output_dir = parse_oneshot_args(**kwargs) + model_args, dataset_args, recipe_args, _, output_dir = parse_args(**kwargs) self.model_args = model_args - self.data_args = data_args + self.dataset_args = dataset_args self.recipe_args = recipe_args self.output_dir = output_dir @@ -137,20 +113,25 @@ def __init__( @classmethod def from_args( - cls, model_args, data_args, recipe_args, output_dir, do_preprocess: bool = True + cls, + model_args, + dataset_args, + recipe_args, + output_dir, + do_preprocess: bool = True, ): """ Used only for the stage runner to populate the args. 
""" instance = super().__new__(cls) instance.model_args = model_args - instance.data_args = data_args + instance.dataset_args = dataset_args instance.recipe_args = recipe_args instance.output_dir = output_dir # only run for the first oneshot call if do_preprocess: - instance._pre_process() + pre_process(model_args) # Set instance attributes instance.model = instance.model_args.model @@ -171,35 +152,18 @@ def __call__(self): """ # TODO: move back once stage runner is removed # Preprocess the model and tokenizer/processor - self._pre_process() + pre_process(self.model_args) self.model = self.model_args.model self.recipe = self.recipe_args.recipe self.processor = self.model_args.processor calibration_dataloader = get_calibration_dataloader( - self.data_args, self.processor + self.dataset_args, self.processor ) self.apply_recipe_modifiers( calibration_dataloader=calibration_dataloader, ) - self._post_process() - - def save(self): - """ - Saves the model and tokenizer/processor to the output directory. - - The model is saved in a compressed format if specified in `model_args`. - The tokenizer or processor, if available, is also saved. - - Raises: - ValueError: If saving fails due to an invalid `output_dir` or other issues. - """ - self.model.save_pretrained( - self.output_dir, - save_compressed=self.model_args.save_compressed, - ) - if self.processor is not None: - self.processor.save_pretrained(self.output_dir) + post_process(model_args=self.model_args, output_dir=self.output_dir) def apply_recipe_modifiers( self, @@ -235,139 +199,9 @@ def apply_recipe_modifiers( session.initialize(**session_kwargs) session.finalize(**session_kwargs) - def _pre_process(self): - """ - Prepares the model and tokenizer/processor for calibration. - - - Initializes the model if it's specified as a path or string. - - Applies patches to fix tied tensor issues and modifies `save_pretrained` - behavior. - - Initializes the processor if specified as a path or `None`. - - Sets the minimum tokens per module if `data_args` are provided. - - Raises: - FileNotFoundError: If the model or processor path is invalid. - """ - self.check_tied_embeddings() - - # Initialize model - if isinstance(self.model_args.model, (str, PosixPath)): - self.model_args.model, _ = initialize_model_from_path(self.model_args) - - patch_tied_tensors_bug(self.model_args.model) - modify_save_pretrained(self.model_args.model) - - # Initialize processor - if isinstance(self.model_args.processor, (str, type(None))): - self.model_args.processor = initialize_processor_from_path( - self.model_args, self.model_args.model - ) - # TODO: move to init once stage runner is removed - self.processor = self.model_args.processor - - # Set minimum tokens per module if data arguments are provided - if self.data_args: - self.min_tokens_per_module = self.data_args.min_tokens_per_module - - def check_tied_embeddings(self): - """ - Logs a warning if the model has tied word embeddings. - - The `tie_word_embeddings` flag may cause issues during saving in the one-shot - calibration workflow due to shared tensor addresses. - """ - if self.model_args.tie_word_embeddings: - logger.debug( - "The tie_word_embeddings flag is by default set to False. " - "This guarantees that the one-shot algorithm saves the final " - "weights without errors. Detected tie_word_embeddings=True. " - "This may cause issues with the one-shot algorithm on save." - ) - - def _post_process(self): - """ - Executes post-calibration steps. 
- - This method saves the model and resets lifecycle actions if the `output_dir` - is not the default directory. - - Raises: - ValueError: If saving fails due to invalid configurations. - """ - if self.output_dir is not None: - self.save() - return - - logger.warning( - "Optimized model not saved. To save, please provide", - "`output_dir` as input arg.", - "Ex. `oneshot(..., output_dir=...)`", - ) - def oneshot(**kwargs) -> PreTrainedModel: one_shot = Oneshot(**kwargs) one_shot() return one_shot.model - - -def parse_oneshot_args( - **kwargs, -) -> Tuple[ModelArguments, DatasetArguments, RecipeArguments, str]: - """ - Parses kwargs by grouping into model, data or training arg groups: - * model_args in - src/llmcompressor/transformers/utils/arg_parser/model_args.py - * data_args in - src/llmcompressor/transformers/utils/arg_parser/data_args.py - * recipe_args in - src/llmcompressor/transformers/utils/arg_parser/recipe_args.py - * training_args in - src/llmcompressor/transformers/utils/arg_parser/training_args.py - """ - output_dir = kwargs.pop("output_dir", None) - - parser = HfArgumentParser((ModelArguments, DatasetArguments, RecipeArguments)) - - if not kwargs: - - def _get_output_dir_from_argv() -> Optional[str]: - import sys - - output_dir = None - if "--output_dir" in sys.argv: - index = sys.argv.index("--output_dir") - sys.argv.pop(index) - if index < len(sys.argv): # Check if value exists afer the flag - output_dir = sys.argv.pop(index) - - return output_dir - - output_dir = _get_output_dir_from_argv() or output_dir - parsed_args = parser.parse_args_into_dataclasses() - else: - parsed_args = parser.parse_dict(kwargs) - - model_args, data_args, recipe_args = parsed_args - - if recipe_args.recipe_args is not None: - if not isinstance(recipe_args.recipe_args, dict): - arg_dict = {} - for recipe_arg in recipe_args.recipe_args: - key, value = recipe_arg.split("=") - arg_dict[key] = value - recipe_args.recipe_args = arg_dict - - # raise depreciation warnings - if data_args.remove_columns is not None: - logger.waning( - "`remove_columns` argument is depreciated. 
When tokenizing datasets, all " - "columns which are invalid inputs the tokenizer will be removed", - DeprecationWarning, - ) - - # silently assign tokenizer to processor - resolve_processor_from_model_args(model_args) - - return model_args, data_args, recipe_args, output_dir diff --git a/src/llmcompressor/entrypoints/utils.py b/src/llmcompressor/entrypoints/utils.py new file mode 100644 index 000000000..c8cc3ba07 --- /dev/null +++ b/src/llmcompressor/entrypoints/utils.py @@ -0,0 +1,272 @@ +import inspect +import os +from pathlib import PosixPath +from typing import Optional, Tuple + +from loguru import logger +from torch.nn import Module +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoProcessor, + PreTrainedModel, + set_seed, +) +from transformers.utils.quantization_config import CompressedTensorsConfig + +from llmcompressor.args import ModelArguments, TrainingArguments +from llmcompressor.pytorch.model_load.helpers import fallback_to_cpu, parse_dtype +from llmcompressor.transformers.sparsification.compressed_tensors_utils import ( + modify_save_pretrained, + patch_tied_tensors_bug, +) +from llmcompressor.transformers.utils.helpers import ( + detect_last_checkpoint, + is_model_ct_quantized_from_path, +) +from llmcompressor.typing import Processor +from llmcompressor.utils.fsdp.helpers import is_fsdp_model + + +def pre_process(model_args: "ModelArguments"): + """ + Prepares the model and tokenizer/processor for calibration. + - Initializes the model if it's specified as a path or string. + - Applies patches to fix tied tensor issues and modifies `save_pretrained` + behavior. + - Initializes the processor if specified as a path or `None`. + - Sets the minimum tokens per module if `dataset_args` are provided. + Raises: + FileNotFoundError: If the model or processor path is invalid. + """ + _warn_tied_embeddings(model_args.tie_word_embeddings) + + # Initialize model + if isinstance(model_args.model, (str, PosixPath)): + model, distill_teacher = initialize_model_from_path(model_args) + if is_fsdp_model(model): + raise NotImplementedError( + "FSDP models are not supported in the current release but will be " + "suported in future releases of LLM Compressor." + ) + model_args.model = model + model_args.distill_teacher = distill_teacher + + # Initialize processor + if isinstance(model_args.processor, (str, type(None))): + model_args.processor = initialize_processor_from_path( + model_args, model_args.model + ) + + # untie tie_word_embeddings weights + patch_tied_tensors_bug(model_args.model) + + # wrap model.save_pretrained + modify_save_pretrained(model_args.model) + + +def post_process( + model_args: "ModelArguments", + output_dir: Optional[str] = None, +): + """ + Saves the model and tokenizer/processor to the output directory. + + If the `output_dir` is not the default directory, the method resets lifecycle + actions. The model is saved in a compressed format if specified in `model_args`. + Additionally, the tokenizer or processor, if available, is also saved. + + Raises: + ValueError: If saving fails due to an invalid `output_dir` or other issues. + """ + if output_dir is not None: + model_args.model.save_pretrained( + output_dir, + save_compressed=model_args.save_compressed, + ) + if model_args.processor: + model_args.processor.save_pretrained(output_dir) + return + + logger.warning( + "Optimized model is not saved. To save, please provide", + "`output_dir` as input arg.", + "Ex. 
`oneshot(..., output_dir=...)`", + ) + + +def _warn_tied_embeddings(tie_word_embeddings: bool = False): + """ + Logs a warning if the model has tied word embeddings. + The `tie_word_embeddings` flag may cause issues during saving in the one-shot + calibration workflow due to shared tensor addresses. + """ + if tie_word_embeddings: + logger.debug( + "The tie_word_embeddings flag is by default set to False. " + "This guarantees that the one-shot algorithm saves the final " + "weights without errors. Detected tie_word_embeddings=True. " + "This may cause issues with the one-shot algorithm on save." + ) + + +def initialize_model_from_path( + model_args: ModelArguments, + training_args: Optional[TrainingArguments] = None, +) -> Tuple[PreTrainedModel, Optional[PreTrainedModel]]: + # Load pretrained model + # The .from_pretrained methods guarantee that only one local process can + # concurrently download model & vocab. + model_path = model_args.model + config = AutoConfig.from_pretrained( + model_args.config_name if model_args.config_name else model_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + tie_word_embeddings=model_args.tie_word_embeddings, + trust_remote_code=model_args.trust_remote_code_model, + ) + + last_checkpoint = None + teacher = None + + if training_args is not None: + # Load teacher configuration if applicable + teacher_config = ( + AutoConfig.from_pretrained( + model_args.distill_teacher, + use_auth_token=True if model_args.use_auth_token else None, + tie_word_embeddings=model_args.tie_word_embeddings, + trust_remote_code=model_args.trust_remote_code_model, + ) + if model_args.distill_teacher + else None + ) + + # Detect last checkpoint + last_checkpoint = detect_last_checkpoint(training_args, model_args=model_args) + + # Set seed before initializing model + set_seed(training_args.seed) + + # Initialize teacher model if teacher path is provided + if model_args.distill_teacher is not None: + teacher_device_map = ( + None + if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" + else "auto" + ) + teacher_kwargs = { + "config": teacher_config, + "cache_dir": model_args.cache_dir, + "use_auth_token": True if model_args.use_auth_token else None, + "torch_dtype": parse_dtype(model_args.precision), + "device_map": teacher_device_map, + "trust_remote_code": model_args.trust_remote_code_model, + } + + teacher = AutoModelForCausalLM.from_pretrained( + model_args.distill_teacher, + **teacher_kwargs, + ) + if "sequence_length" in teacher_kwargs: + teacher.seqlen = teacher_kwargs["sequence_length"] + + model_path = ( + last_checkpoint or model_args.model + if hasattr(model_args, "model") + else model_args.model_name_or_path + ) + + # Fallback to CPU if GPU requested and not available + model_args.oneshot_device = fallback_to_cpu(model_args.oneshot_device) + + device_map = model_args.oneshot_device + if training_args is not None and training_args.do_train: + device_map = "auto" + + model_kwargs = { + "config": config, + "cache_dir": model_args.cache_dir, + "revision": model_args.model_revision, + "use_auth_token": True if model_args.use_auth_token else None, + "torch_dtype": parse_dtype(model_args.precision), + "device_map": device_map, + "trust_remote_code": model_args.trust_remote_code_model, + } + + # optimized models must be decompressed to carry out oneshot/train/etc + if is_model_ct_quantized_from_path(model_path): + model_kwargs["quantization_config"] = CompressedTensorsConfig( + 
run_compressed=False + ) + + model = AutoModelForCausalLM.from_pretrained( + model_path, + **model_kwargs, + ) + if "sequence_length" in model_kwargs: + model.seqlen = model_kwargs["sequence_length"] + + return model, teacher + + +def initialize_processor_from_path( + model_args: ModelArguments, + model: PreTrainedModel, + teacher: Optional[PreTrainedModel] = None, +) -> Processor: + processor_src = model_args.processor or get_processor_name_from_model( + model, teacher + ) + # The use_fast=True option is not currently supported safely in Transformers + # See: https://github.com/huggingface/transformers/pull/34836#issuecomment-2491809727 # noqa: E501 + try: + processor = AutoProcessor.from_pretrained( + processor_src, + cache_dir=model_args.cache_dir, + use_fast=True, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + trust_remote_code=model_args.trust_remote_code_model, + ) + except Exception: + logger.debug("Could not load fast processor, loading slow processor instead") + processor = AutoProcessor.from_pretrained( + processor_src, + cache_dir=model_args.cache_dir, + use_fast=False, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + trust_remote_code=model_args.trust_remote_code_model, + ) + + return processor + + +def get_processor_name_from_model(student: Module, teacher: Optional[Module]) -> str: + """ + Get a processor/tokenizer source used for both student and teacher, assuming + that they could be shared + + :param student: the student model + :param teacher: the teacher model + :return: the source for the processor/tokenizer shared between teacher and model + """ + if teacher is not None and teacher not in ("disable", "self"): + student_forward_params = list( + inspect.signature(student.forward).parameters.keys() + ) + teacher_forward_params = list( + inspect.signature(teacher.forward).parameters.keys() + ) + diff = [p for p in student_forward_params if p not in teacher_forward_params] + if diff: + raise RuntimeError( + "Teacher tokenizer cannot be used for student " + f"due to missing args: {diff}" + ) + src_model = teacher + else: + src_model = student + return src_model.config._name_or_path diff --git a/src/llmcompressor/logger.py b/src/llmcompressor/logger.py index 332daeb3d..686da7564 100644 --- a/src/llmcompressor/logger.py +++ b/src/llmcompressor/logger.py @@ -53,9 +53,9 @@ class LoggerConfig: metrics_disabled: bool = False -def configure_logger(config: Optional[LoggerConfig] = None): +def configure_logger(config: Optional[LoggerConfig] = None) -> None: """ - Configure the metrics for LLM Compressor. + Configure the logger for LLM Compressor. This function sets up the console and file logging as per the specified or default parameters. 
@@ -68,9 +68,9 @@ def configure_logger(config: Optional[LoggerConfig] = None): # env vars get priority if (disabled := os.getenv("LLM_COMPRESSOR_LOG_DISABLED")) is not None: - logger_config.disabled = disabled.lower() + logger_config.disabled = disabled.lower() == "true" if (clear_loggers := os.getenv("LLM_COMPRESSOR_CLEAR_LOGGERS")) is not None: - logger_config.clear_loggers = clear_loggers.lower() + logger_config.clear_loggers = clear_loggers.lower() == "true" if (console_log_level := os.getenv("LLM_COMPRESSOR_LOG_LEVEL")) is not None: logger_config.console_log_level = console_log_level.upper() if (log_file := os.getenv("LLM_COMPRESSOR_LOG_FILE")) is not None: diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py index bde819e43..593682cbd 100644 --- a/src/llmcompressor/transformers/finetune/data/base.py +++ b/src/llmcompressor/transformers/finetune/data/base.py @@ -31,7 +31,7 @@ class TextGenerationDataset(RegistryMixin): 3. Tokenize dataset using model tokenizer/processor 4. Apply post processing such as grouping text and/or adding labels for finetuning - :param data_args: configuration settings for dataset loading + :param dataset_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` :param processor: processor or tokenizer to use on dataset """ @@ -41,11 +41,11 @@ class TextGenerationDataset(RegistryMixin): def __init__( self, - data_args: DatasetArguments, + dataset_args: DatasetArguments, split: str, processor: Processor, ): - self.data_args = data_args + self.dataset_args = dataset_args self.split = split self.processor = processor @@ -58,23 +58,23 @@ def __init__( self.tokenizer.pad_token = self.tokenizer.eos_token # configure sequence length - max_seq_length = data_args.max_seq_length - if data_args.max_seq_length > self.tokenizer.model_max_length: + max_seq_length = dataset_args.max_seq_length + if dataset_args.max_seq_length > self.tokenizer.model_max_length: logger.warning( f"The max_seq_length passed ({max_seq_length}) is larger than " f"maximum length for model ({self.tokenizer.model_max_length}). " f"Using max_seq_length={self.tokenizer.model_max_length}." 
) self.max_seq_length = min( - data_args.max_seq_length, self.tokenizer.model_max_length + dataset_args.max_seq_length, self.tokenizer.model_max_length ) # configure padding self.padding = ( False - if self.data_args.concatenate_data + if self.dataset_args.concatenate_data else "max_length" - if self.data_args.pad_to_max_length + if self.dataset_args.pad_to_max_length else False ) @@ -83,7 +83,7 @@ def __init__( self.padding = False def __call__(self, add_labels: bool = True) -> DatasetType: - dataset = self.data_args.dataset + dataset = self.dataset_args.dataset if isinstance(dataset, str): # load dataset: load from huggingface or disk @@ -96,8 +96,8 @@ def __call__(self, add_labels: bool = True) -> DatasetType: dataset, self.preprocess, batched=False, - num_proc=self.data_args.preprocessing_num_workers, - load_from_cache_file=not self.data_args.overwrite_cache, + num_proc=self.dataset_args.preprocessing_num_workers, + load_from_cache_file=not self.dataset_args.overwrite_cache, desc="Preprocessing", ) logger.debug(f"Dataset after preprocessing: {get_columns(dataset)}") @@ -121,20 +121,20 @@ def __call__(self, add_labels: bool = True) -> DatasetType: # regardless of `batched` argument remove_columns=get_columns(dataset), # assumes that input names # and output names are disjoint - num_proc=self.data_args.preprocessing_num_workers, - load_from_cache_file=not self.data_args.overwrite_cache, + num_proc=self.dataset_args.preprocessing_num_workers, + load_from_cache_file=not self.dataset_args.overwrite_cache, desc="Tokenizing", ) logger.debug(f"Model kwargs after tokenizing: {get_columns(dataset)}") - if self.data_args.concatenate_data: + if self.dataset_args.concatenate_data: # postprocess: group text dataset = self.map( dataset, self.group_text, batched=True, - num_proc=self.data_args.preprocessing_num_workers, - load_from_cache_file=not self.data_args.overwrite_cache, + num_proc=self.dataset_args.preprocessing_num_workers, + load_from_cache_file=not self.dataset_args.overwrite_cache, desc="Concatenating data", ) logger.debug(f"Model kwargs after concatenating: {get_columns(dataset)}") @@ -145,8 +145,8 @@ def __call__(self, add_labels: bool = True) -> DatasetType: dataset, self.add_labels, batched=False, # not compatible with batching, need row lengths - num_proc=self.data_args.preprocessing_num_workers, - load_from_cache_file=not self.data_args.overwrite_cache, + num_proc=self.dataset_args.preprocessing_num_workers, + load_from_cache_file=not self.dataset_args.overwrite_cache, desc="Adding labels", ) logger.debug(f"Model kwargs after adding labels: {get_columns(dataset)}") @@ -165,27 +165,31 @@ def load_dataset(self): :param cache_dir: disk location to search for cached dataset :return: the requested dataset """ - if self.data_args.dataset_path is not None: - if self.data_args.dvc_data_repository is not None: - self.data_args.raw_kwargs["storage_options"] = { - "url": self.data_args.dvc_data_repository + if self.dataset_args.dataset_path is not None: + if self.dataset_args.dvc_data_repository is not None: + self.dataset_args.raw_kwargs["storage_options"] = { + "url": self.dataset_args.dvc_data_repository } - self.data_args.raw_kwargs["data_files"] = self.data_args.dataset_path + self.dataset_args.raw_kwargs["data_files"] = ( + self.dataset_args.dataset_path + ) else: - self.data_args.raw_kwargs["data_files"] = get_custom_datasets_from_path( - self.data_args.dataset_path, - self.data_args.dataset - if hasattr(self.data_args, "dataset") - else self.data_args.dataset_name, + 
self.dataset_args.raw_kwargs["data_files"] = ( + get_custom_datasets_from_path( + self.dataset_args.dataset_path, + self.dataset_args.dataset + if hasattr(self.dataset_args, "dataset") + else self.dataset_args.dataset_name, + ) ) - logger.debug(f"Loading dataset {self.data_args.dataset}") + logger.debug(f"Loading dataset {self.dataset_args.dataset}") return get_raw_dataset( - self.data_args, + self.dataset_args, None, split=self.split, - streaming=self.data_args.streaming, - **self.data_args.raw_kwargs, + streaming=self.dataset_args.streaming, + **self.dataset_args.raw_kwargs, ) @cached_property @@ -194,7 +198,7 @@ def preprocess(self) -> Union[Callable[[LazyRow], Any], None]: The function must return keys which correspond to processor/tokenizer kwargs, optionally including PROMPT_KEY """ - preprocessing_func = self.data_args.preprocessing_func + preprocessing_func = self.dataset_args.preprocessing_func if callable(preprocessing_func): return preprocessing_func @@ -218,9 +222,9 @@ def dataset_template(self) -> Union[Callable[[Any], Any], None]: def rename_columns(self, dataset: DatasetType) -> DatasetType: # rename columns to match processor/tokenizer kwargs column_names = get_columns(dataset) - if self.data_args.text_column in column_names and "text" not in column_names: - logger.debug(f"Renaming column `{self.data_args.text_column}` to `text`") - dataset = dataset.rename_column(self.data_args.text_column, "text") + if self.dataset_args.text_column in column_names and "text" not in column_names: + logger.debug(f"Renaming column `{self.dataset_args.text_column}` to `text`") + dataset = dataset.rename_column(self.dataset_args.text_column, "text") return dataset diff --git a/src/llmcompressor/transformers/finetune/data/c4.py b/src/llmcompressor/transformers/finetune/data/c4.py index 988a4adc3..e4fe6431c 100644 --- a/src/llmcompressor/transformers/finetune/data/c4.py +++ b/src/llmcompressor/transformers/finetune/data/c4.py @@ -13,14 +13,16 @@ class C4Dataset(TextGenerationDataset): """ Child text generation class for the C4 dataset - :param data_args: configuration settings for dataset loading + :param dataset_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` :param processor: processor or tokenizer to use on dataset """ - def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): - data_args = deepcopy(data_args) - data_args.dataset = "allenai/c4" - data_args.text_column = "text" + def __init__( + self, dataset_args: "DatasetArguments", split: str, processor: Processor + ): + dataset_args = deepcopy(dataset_args) + dataset_args.dataset = "allenai/c4" + dataset_args.text_column = "text" - super().__init__(data_args=data_args, split=split, processor=processor) + super().__init__(dataset_args=dataset_args, split=split, processor=processor) diff --git a/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py b/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py index b005724aa..fcc67482f 100644 --- a/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py +++ b/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py @@ -13,19 +13,21 @@ class CNNDailyMailDataset(TextGenerationDataset): """ Text generation class for the CNN/DailyMail dataset - :param data_args: configuration settings for dataset loading + :param dataset_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` :param processor: 
processor or tokenizer to use on dataset """ SAMPLE_TEMPLATE = "Article:\n{article}\n\n### Summarization:\n{highlights}\n" - def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): - data_args = deepcopy(data_args) - data_args.dataset = "cnn_dailymail" - data_args.dataset_config_name = "3.0.0" + def __init__( + self, dataset_args: "DatasetArguments", split: str, processor: Processor + ): + dataset_args = deepcopy(dataset_args) + dataset_args.dataset = "cnn_dailymail" + dataset_args.dataset_config_name = "3.0.0" - super().__init__(data_args=data_args, split=split, processor=processor) + super().__init__(dataset_args=dataset_args, split=split, processor=processor) def dataset_template(self, sample): return { diff --git a/src/llmcompressor/transformers/finetune/data/custom.py b/src/llmcompressor/transformers/finetune/data/custom.py index 7cff3c1d9..1239e08be 100644 --- a/src/llmcompressor/transformers/finetune/data/custom.py +++ b/src/llmcompressor/transformers/finetune/data/custom.py @@ -7,7 +7,7 @@ class CustomDataset(TextGenerationDataset): Child text generation class for custom local dataset supporting load for csv and json - :param data_args: configuration settings for dataset loading + :param dataset_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` Can also be set to None to load all the splits :param processor: processor or tokenizer to use on dataset diff --git a/src/llmcompressor/transformers/finetune/data/data_helpers.py b/src/llmcompressor/transformers/finetune/data/data_helpers.py index cf9b81f69..ff56cfbb9 100644 --- a/src/llmcompressor/transformers/finetune/data/data_helpers.py +++ b/src/llmcompressor/transformers/finetune/data/data_helpers.py @@ -1,76 +1,20 @@ import logging import os -import re -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Dict, Optional -import torch from datasets import Dataset, load_dataset -from loguru import logger -from torch.utils.data import DataLoader, RandomSampler, SequentialSampler -from transformers.data import default_data_collator LOGGER = logging.getLogger(__name__) LABELS_MASK_VALUE = -100 __all__ = [ - "format_calibration_data", "get_raw_dataset", - "make_dataset_splits", "get_custom_datasets_from_path", - "get_calibration_dataloader", ] -def format_calibration_data( - tokenized_dataset: Dataset, - num_calibration_samples: Optional[int] = None, - do_shuffle: bool = True, - collate_fn: Callable = default_data_collator, - accelerator: Optional[Any] = None, -) -> List[torch.Tensor]: - """ - Creates a dataloader out of the calibration dataset split, trimming it to - the desired number of calibration samples - - :param tokenized_dataset: dataset to convert to dataloader - :param num_calibration_samples: number of data samples to convert - :param do_shuffle: whether to shuffle the dataset before selecting calibration - samples, true by default - :param collate_fn: optional custom collate function, or use default - :param accelerator: optional accelerator for if preparing in FSDP mode - :return: list of trimmed calibration data tensors - """ - safe_calibration_samples = len(tokenized_dataset) - if num_calibration_samples is not None: - safe_calibration_samples = min(len(tokenized_dataset), num_calibration_samples) - if safe_calibration_samples != num_calibration_samples: - LOGGER.warn( - f"Requested {num_calibration_samples} calibration samples but " - f"the provided dataset only has 
{safe_calibration_samples}. " - ) - - if do_shuffle: - tokenized_dataset = tokenized_dataset.shuffle() - tokenized_calibration = tokenized_dataset.select(range(safe_calibration_samples)) - - dataloader_params = { - "batch_size": 1, - "sampler": RandomSampler(tokenized_calibration) - if do_shuffle - else SequentialSampler(tokenized_calibration), - "collate_fn": collate_fn, - "pin_memory": True, - } - - calib_dataloader = DataLoader(tokenized_calibration, **dataloader_params) - if accelerator: - calib_dataloader = accelerator.prepare(calib_dataloader) - - return calib_dataloader - - def get_raw_dataset( - data_args, + dataset_args, cache_dir: Optional[str] = None, streaming: Optional[bool] = False, **kwargs, @@ -84,57 +28,16 @@ def get_raw_dataset( """ raw_datasets = load_dataset( - data_args.dataset, - data_args.dataset_config_name, + dataset_args.dataset, + dataset_args.dataset_config_name, cache_dir=cache_dir, streaming=streaming, - trust_remote_code=data_args.trust_remote_code_data, + trust_remote_code=dataset_args.trust_remote_code_data, **kwargs, ) return raw_datasets -def make_dataset_splits( - tokenized_datasets: Dict[str, Any], - do_train: bool = False, - do_oneshot: bool = False, -) -> Dict[str, Dataset]: - """ - Restructures the datasets dictionary based on what tasks will be run - train - - :param tokenized_datasets: dictionary of processed datasets - :param do_oneshot: Whether to store the calibration dataset - - :return: Datasets to be used by the requested tasks - """ - - # handles case where all splits are contained in a single dataset - if "all" in tokenized_datasets and len(tokenized_datasets) == 1: - tokenized_datasets = tokenized_datasets.get("all") - if isinstance(tokenized_datasets, Dataset): - tokenized_datasets = {"train": tokenized_datasets} - - train_split = calib_split = None - - if do_train: - if "train" not in tokenized_datasets: - raise ValueError("--do_train requires a train dataset") - train_split = tokenized_datasets["train"] - if do_oneshot: - calib_split = tokenized_datasets.get("calibration") - if calib_split is None: - if "train" not in tokenized_datasets: - raise ValueError("--do_oneshot requires a calibration dataset") - calib_split = tokenized_datasets["train"] - - split_datasets = { - "train": train_split, - "calibration": calib_split, - } - return split_datasets - - def get_custom_datasets_from_path(path: str, ext: str = "json") -> Dict[str, str]: """ Get a dictionary of custom datasets from a directory path. Support HF's load_dataset @@ -232,76 +135,3 @@ def do_transform(candidate: str) -> bool: transform_dataset_key(dataset_key) return data_files - - -def get_calibration_dataloader( - data_args, - processor, - add_labels: bool = False, # for oneshot - do_oneshot=True, -) -> torch.utils.data.DataLoader: - """ - Loads datasets for each flow based on data_args, stores a Dataset for each - enabled flow in self.datasets - - :param processor: processor or tokenizer to use for dataset tokenization - :param add_labels: if True, add labels column to dataset splits - """ - if data_args.dataset is None: - logger.info( - "Running oneshot without calibration data. 
This is expected for " - "weight-only and dynamic quantization" - ) - return - - splits = data_args.splits - tokenized_datasets = {} - - def _get_split_name(inp_str): - # strip out split name, for ex train[60%:] -> train - match = re.match(r"(\w*)\[.*\]", inp_str) - if match is not None: - return match.group(1) - return inp_str - - if splits is None: - splits = {"all": None} - elif isinstance(splits, str): - splits = {_get_split_name(splits): splits} - elif isinstance(splits, List): - splits = {_get_split_name(s): s for s in splits} - - # default to custom dataset if dataset provided isn't a string - registry_id = data_args.dataset if isinstance(data_args.dataset, str) else "custom" - for split_name, split_str in splits.items(): - dataset = data_args.dataset - if hasattr(dataset, "column_names") and "input_ids" in dataset.column_names: - # dataset is already tokenized - tokenized_datasets[split_name] = dataset - else: - # dataset needs to be tokenized - from llmcompressor.transformers.finetune.data.base import ( - TextGenerationDataset, - ) - - dataset_manager = TextGenerationDataset.load_from_registry( - registry_id, - data_args=data_args, - split=split_str, - processor=processor, - ) - tokenized_datasets[split_name] = dataset_manager(add_labels=add_labels) - - datasets = make_dataset_splits( - tokenized_datasets, - do_oneshot=do_oneshot, - ) - - calibration_dataset = datasets.get("calibration") - - return format_calibration_data( - tokenized_dataset=calibration_dataset, - num_calibration_samples=data_args.num_calibration_samples, - do_shuffle=data_args.shuffle_calibration_samples, - collate_fn=data_args.data_collator, - ) diff --git a/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py b/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py index 3de833738..8a7892c13 100644 --- a/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py +++ b/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py @@ -13,7 +13,7 @@ class EvolCodeAlpacaDataset(TextGenerationDataset): """ Child text generation class for the Evol Code Alpaca dataset - :param data_args: configuration settings for dataset loading + :param dataset_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` :param processor: processor or tokenizer to use on dataset """ @@ -25,12 +25,14 @@ class EvolCodeAlpacaDataset(TextGenerationDataset): "\n\n### Response:\n" ) - def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): - data_args = deepcopy(data_args) - data_args.dataset = "theblackcat102/evol-codealpaca-v1" - data_args.text_column = "text" + def __init__( + self, dataset_args: "DatasetArguments", split: str, processor: Processor + ): + dataset_args = deepcopy(dataset_args) + dataset_args.dataset = "theblackcat102/evol-codealpaca-v1" + dataset_args.text_column = "text" - super().__init__(data_args, split=split, processor=processor) + super().__init__(dataset_args, split=split, processor=processor) def dataset_template(self, sample): prompt = self.EVOL_ALPACA_TEMPLATE.format(instruction=sample["instruction"]) diff --git a/src/llmcompressor/transformers/finetune/data/flickr_30k.py b/src/llmcompressor/transformers/finetune/data/flickr_30k.py index 6e11c3aaf..8ada07a0e 100644 --- a/src/llmcompressor/transformers/finetune/data/flickr_30k.py +++ b/src/llmcompressor/transformers/finetune/data/flickr_30k.py @@ -13,7 +13,7 @@ @TextGenerationDataset.register(name="flickr", alias="flickr30k") class 
Flickr30K(TextGenerationDataset): """ - :param data_args: configuration settings for dataset loading + :param dataset_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` :param processor: processor or tokenizer to use on dataset """ @@ -31,11 +31,13 @@ class Flickr30K(TextGenerationDataset): "{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" ) - def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): - data_args = deepcopy(data_args) - data_args.dataset = "lmms-lab/flickr30k" + def __init__( + self, dataset_args: "DatasetArguments", split: str, processor: Processor + ): + dataset_args = deepcopy(dataset_args) + dataset_args.dataset = "lmms-lab/flickr30k" - super().__init__(data_args=data_args, split=split, processor=processor) + super().__init__(dataset_args=dataset_args, split=split, processor=processor) if ( self.tokenizer is not None diff --git a/src/llmcompressor/transformers/finetune/data/gsm8k.py b/src/llmcompressor/transformers/finetune/data/gsm8k.py index 4f61d1726..ae1318571 100644 --- a/src/llmcompressor/transformers/finetune/data/gsm8k.py +++ b/src/llmcompressor/transformers/finetune/data/gsm8k.py @@ -13,19 +13,21 @@ class GSM8KDataset(TextGenerationDataset): """ Child text generation class for the Grade School Math 8k dataset - :param data_args: configuration settings for dataset loading + :param dataset_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` :param processor: processor or tokenizer to use on dataset """ GSM_TEMPLATE = "Question: {question}\nAnswer:" - def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): - data_args = deepcopy(data_args) - data_args.dataset = "gsm8k" - data_args.text_column = "text" + def __init__( + self, dataset_args: "DatasetArguments", split: str, processor: Processor + ): + dataset_args = deepcopy(dataset_args) + dataset_args.dataset = "gsm8k" + dataset_args.text_column = "text" - super().__init__(data_args=data_args, split=split, processor=processor) + super().__init__(dataset_args=dataset_args, split=split, processor=processor) def dataset_template(self, sample): prompt = self.GSM_TEMPLATE.format(question=sample["question"]) diff --git a/src/llmcompressor/transformers/finetune/data/open_platypus.py b/src/llmcompressor/transformers/finetune/data/open_platypus.py index 33c9ddc86..81413e785 100644 --- a/src/llmcompressor/transformers/finetune/data/open_platypus.py +++ b/src/llmcompressor/transformers/finetune/data/open_platypus.py @@ -13,7 +13,7 @@ class OpenPlatypusDataset(TextGenerationDataset): """ Child text generation class for the Open Platypus dataset - :param data_args: configuration settings for dataset loading + :param dataset_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` :param processor: processor or tokenizer to use on dataset """ @@ -28,11 +28,13 @@ class OpenPlatypusDataset(TextGenerationDataset): "instruction}\n\n### Response:\n", } - def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): - data_args = deepcopy(data_args) - data_args.dataset = "garage-bAInd/Open-Platypus" - data_args.text_column = "text" - super().__init__(data_args=data_args, split=split, processor=processor) + def __init__( + self, dataset_args: "DatasetArguments", split: str, processor: Processor + ): + dataset_args = deepcopy(dataset_args) + 
dataset_args.dataset = "garage-bAInd/Open-Platypus" + dataset_args.text_column = "text" + super().__init__(dataset_args=dataset_args, split=split, processor=processor) def dataset_template(self, sample): if "input" in sample and sample["input"] != "": diff --git a/src/llmcompressor/transformers/finetune/data/ptb.py b/src/llmcompressor/transformers/finetune/data/ptb.py index 7966fe4d0..8f03ad509 100644 --- a/src/llmcompressor/transformers/finetune/data/ptb.py +++ b/src/llmcompressor/transformers/finetune/data/ptb.py @@ -13,18 +13,20 @@ class PtbDataset(TextGenerationDataset): """ Child text generation class for the PTB dataset - :param data_args: configuration settings for dataset loading + :param dataset_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` :param processor: processor or tokenizer to use on dataset """ - def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): - data_args = deepcopy(data_args) - data_args.dataset = "ptb_text_only" - data_args.text_column = "sentence" + def __init__( + self, dataset_args: "DatasetArguments", split: str, processor: Processor + ): + dataset_args = deepcopy(dataset_args) + dataset_args.dataset = "ptb_text_only" + dataset_args.text_column = "sentence" super().__init__( - data_args=data_args, + dataset_args=dataset_args, split=split, processor=processor, ) diff --git a/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py b/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py index fad57c076..296eb3db5 100644 --- a/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py +++ b/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py @@ -15,7 +15,7 @@ class UltraChatDataset(TextGenerationDataset): """ Child text generation class for the Ultra Chat 200k dataset - :param data_args: configuration settings for dataset loading + :param dataset_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` :param processor: processor or tokenizer to use on dataset """ @@ -33,15 +33,17 @@ class UltraChatDataset(TextGenerationDataset): "{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" ) - def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): - data_args = deepcopy(data_args) - data_args.dataset = "HuggingFaceH4/ultrachat_200k" - data_args.text_column = "messages" + def __init__( + self, dataset_args: "DatasetArguments", split: str, processor: Processor + ): + dataset_args = deepcopy(dataset_args) + dataset_args.dataset = "HuggingFaceH4/ultrachat_200k" + dataset_args.text_column = "messages" if split in ["train", "test"]: split += "_sft" - super().__init__(data_args=data_args, split=split, processor=processor) + super().__init__(dataset_args=dataset_args, split=split, processor=processor) if ( self.tokenizer is not None diff --git a/src/llmcompressor/transformers/finetune/data/wikitext.py b/src/llmcompressor/transformers/finetune/data/wikitext.py index 868c9d951..73142d671 100644 --- a/src/llmcompressor/transformers/finetune/data/wikitext.py +++ b/src/llmcompressor/transformers/finetune/data/wikitext.py @@ -13,18 +13,20 @@ class WikiTextDataset(TextGenerationDataset): """ Child text generation class for the Open Platypus dataset - :param data_args: configuration settings for dataset loading + :param dataset_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` 
:param processor: processor or tokenizer to use on dataset """ - def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): - data_args = deepcopy(data_args) - data_args.dataset = "Salesforce/wikitext" - data_args.text_column = "text" + def __init__( + self, dataset_args: "DatasetArguments", split: str, processor: Processor + ): + dataset_args = deepcopy(dataset_args) + dataset_args.dataset = "Salesforce/wikitext" + dataset_args.text_column = "text" super().__init__( - data_args=data_args, + dataset_args=dataset_args, split=split, processor=processor, ) diff --git a/src/llmcompressor/transformers/finetune/runner.py b/src/llmcompressor/transformers/finetune/runner.py index 1735a99b8..b45153b4f 100644 --- a/src/llmcompressor/transformers/finetune/runner.py +++ b/src/llmcompressor/transformers/finetune/runner.py @@ -23,10 +23,6 @@ ) from llmcompressor.recipe import Recipe, StageRunType from llmcompressor.transformers.finetune.data import TextGenerationDataset -from llmcompressor.transformers.finetune.data.data_helpers import ( - format_calibration_data, - make_dataset_splits, -) from llmcompressor.typing import Processor @@ -42,19 +38,19 @@ class StageRunner: - train() :param model_args: Arguments pertaining to model/config/processor - :param data_args: Arguments pertaining to what data to use for different flows + :param dataset_args: Arguments pertaining to what data to use for different flows :param training_args: Arguments pertaining to training loop configuration :model: unwrapped model to run flows on """ def __init__( self, - data_args: "DatasetArguments", + dataset_args: "DatasetArguments", model_args: "ModelArguments", training_args: "TrainingArguments", recipe_args: "RecipeArguments", ): - self._data_args = data_args + self._dataset_args = dataset_args self._model_args = model_args self._training_args = training_args self._recipe_args = recipe_args @@ -67,13 +63,13 @@ def __init__( def populate_datasets(self, processor: Processor, add_labels: bool = True): """ - Loads datasets for each flow based on data_args, stores a Dataset for each + Loads datasets for each flow based on dataset_args, stores a Dataset for each enabled flow in self.datasets :param processor: processor or tokenizer to use for dataset tokenization :param add_labels: if True, add labels column to dataset splits """ - if self._data_args.dataset is None: + if self._dataset_args.dataset is None: self.processor = self._model_args.processor logger.info( "Running oneshot without calibration data. 
This is expected for " @@ -81,7 +77,7 @@ def populate_datasets(self, processor: Processor, add_labels: bool = True): ) return - splits = self._data_args.splits + splits = self._dataset_args.splits tokenized_datasets = {} def _get_split_name(inp_str): @@ -100,12 +96,12 @@ def _get_split_name(inp_str): # default to custom dataset if dataset provided isn't a string registry_id = ( - self._data_args.dataset - if isinstance(self._data_args.dataset, str) + self._dataset_args.dataset + if isinstance(self._dataset_args.dataset, str) else "custom" ) for split_name, split_str in splits.items(): - dataset = self._data_args.dataset + dataset = self._dataset_args.dataset if hasattr(dataset, "column_names") and "input_ids" in dataset.column_names: # dataset is already tokenized tokenized_datasets[split_name] = dataset @@ -113,12 +109,14 @@ def _get_split_name(inp_str): # dataset needs to be tokenized dataset_manager = TextGenerationDataset.load_from_registry( registry_id, - data_args=self._data_args, + dataset_args=self._dataset_args, split=split_str, processor=processor, ) tokenized_datasets[split_name] = dataset_manager(add_labels=add_labels) + from llmcompressor.datasets import make_dataset_splits + self.datasets = make_dataset_splits( tokenized_datasets, do_train=self._training_args.do_train, @@ -164,6 +162,7 @@ def run_sequential_stages( :param checkpoint: optional checkpoint to pick up a stage from """ + recipe_obj = Recipe.create_instance(self._recipe_args.recipe) with self.trainer.accelerator.main_process_first(): checkpoint_dir = self._model_args.model @@ -197,12 +196,13 @@ def run_sequential_stages( # run stage if run_type is StageRunType.ONESHOT: from llmcompressor import Oneshot + from llmcompressor.datasets import format_calibration_data self._model_args.model = model oneshot = Oneshot.from_args( model_args=self._model_args, - data_args=self._data_args, + dataset_args=self._dataset_args, recipe_args=self._recipe_args, output_dir=self._training_args.output_dir, do_preprocess=do_preprocess, @@ -210,10 +210,9 @@ def run_sequential_stages( calib_data = format_calibration_data( tokenized_dataset=self.get_dataset_split("calibration"), - num_calibration_samples=self._data_args.num_calibration_samples, - do_shuffle=self._data_args.shuffle_calibration_samples, - collate_fn=self._data_args.data_collator, - accelerator=self.trainer.accelerator, + num_calibration_samples=self._dataset_args.num_calibration_samples, + do_shuffle=self._dataset_args.shuffle_calibration_samples, + collate_fn=self._dataset_args.data_collator, ) if do_preprocess: diff --git a/src/llmcompressor/transformers/finetune/session_mixin.py b/src/llmcompressor/transformers/finetune/session_mixin.py index 27882d7d6..f64916e69 100644 --- a/src/llmcompressor/transformers/finetune/session_mixin.py +++ b/src/llmcompressor/transformers/finetune/session_mixin.py @@ -54,14 +54,14 @@ class SessionManagerMixIn: :param recipe: path to recipe file to apply during training :param recipe_args: additional kwargs to use for evaluating recipe - :param data_args: kwargs for configuring dataset loading + :param dataset_args: kwargs for configuring dataset loading :param teacher: optional teacher model to use for distillation """ def __init__( self, recipe: str, - data_args: "DatasetArguments", + dataset_args: "DatasetArguments", model_args: "ModelArguments", teacher: Optional[Union[Module, str]] = None, recipe_args: Optional[Union[Dict[str, Any], str]] = None, @@ -77,7 +77,7 @@ def __init__( self.metadata = None if training_args is not None: - # 
trl_sft_trainer pathway. Both training_args and data_args + # trl_sft_trainer pathway. Both training_args and dataset_args # have `max_seq_length` which causes collision error. This is the # only shared parameter, where training arg is `TRLSFTConfig` that # inherits HuggingFace's `TrainingArguments` @@ -87,7 +87,7 @@ def __init__( training_args_dict.pop("max_seq_length") ) logger.warning( - "Detected `max_seq_length` in both data_args ", + "Detected `max_seq_length` in both dataset_args ", "and training_args. This is expected for TRL in distillation. ", "Updating metadata to `training_args_max_seq_length`", ) @@ -95,7 +95,7 @@ def __init__( self.metadata = self._extract_metadata( metadata_args=METADATA_ARGS, training_args_dict=training_args_dict, - data_args_dict=asdict(data_args) if data_args else {}, + dataset_args_dict=asdict(dataset_args) if dataset_args else {}, ) # setup metrics and session @@ -125,8 +125,8 @@ def __init__( if self.is_fsdp_enabled: self._prepare_model_for_fsdp() - if data_args is not None: - self.min_tokens_per_module = data_args.min_tokens_per_module + if dataset_args is not None: + self.min_tokens_per_module = dataset_args.min_tokens_per_module def initialize_session( self, @@ -459,16 +459,16 @@ def _extract_metadata( self, metadata_args: List[str], training_args_dict: Dict[str, Any], - data_args_dict: Dict[str, Any], + dataset_args_dict: Dict[str, Any], ) -> Dict[str, Any]: metadata = {} - if not training_args_dict.keys().isdisjoint(data_args_dict.keys()): + if not training_args_dict.keys().isdisjoint(dataset_args_dict.keys()): raise ValueError( "Found common keys in `training_args` and `data args`. " "This is prohibitive and may lead to undesired behavior." ) - args_dict = {**training_args_dict, **data_args_dict} + args_dict = {**training_args_dict, **dataset_args_dict} for arg in metadata_args: if arg not in args_dict.keys(): diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py index 9a3623f60..66652b686 100644 --- a/src/llmcompressor/transformers/finetune/text_generation.py +++ b/src/llmcompressor/transformers/finetune/text_generation.py @@ -17,22 +17,12 @@ # Adapted from https://github.com/huggingface/transformers # vllm-project: no copyright -import os import warnings from pathlib import PosixPath -from typing import Optional from compressed_tensors.utils.helpers import deprecated from loguru import logger -from transformers import ( - AutoConfig, - AutoModelForCausalLM, - AutoProcessor, - HfArgumentParser, - PreTrainedModel, - set_seed, -) -from transformers.utils.quantization_config import CompressedTensorsConfig +from transformers import HfArgumentParser from llmcompressor.args import ( DatasetArguments, @@ -41,11 +31,7 @@ TrainingArguments, ) from llmcompressor.core import reset_session -from llmcompressor.pytorch.model_load.helpers import ( - fallback_to_cpu, - parse_dtype, - save_checkpoint, -) +from llmcompressor.pytorch.model_load.helpers import save_checkpoint from llmcompressor.recipe import Recipe, StageRunType from llmcompressor.transformers.finetune.runner import StageRunner from llmcompressor.transformers.finetune.trainer import Trainer @@ -53,14 +39,6 @@ modify_save_pretrained, patch_tied_tensors_bug, ) -from llmcompressor.transformers.sparsification.sparse_model import ( - get_processor_name_from_model, -) -from llmcompressor.transformers.utils.helpers import ( - detect_last_checkpoint, - is_model_ct_quantized_from_path, -) -from llmcompressor.typing import 
Processor from llmcompressor.utils.fsdp.helpers import is_fsdp_model @@ -68,18 +46,9 @@ def train(**kwargs): """ CLI entrypoint for running training """ - model_args, data_args, recipe_args, training_args = parse_args(**kwargs) + model_args, dataset_args, recipe_args, training_args = parse_args(**kwargs) training_args.do_train = True - main(model_args, data_args, recipe_args, training_args) - - -def eval(**kwargs): - """ - CLI entrypoint for running evaluation - """ - model_args, data_args, recipe_args, training_args = parse_args(**kwargs) - training_args.do_eval = True - main(model_args, data_args, recipe_args, training_args) + main(model_args, dataset_args, recipe_args, training_args) @deprecated( @@ -98,14 +67,18 @@ def apply(**kwargs): """ CLI entrypoint for any of training, oneshot """ - report_to = kwargs.get("report_to", None) - model_args, data_args, recipe_args, training_args = parse_args(**kwargs) + from llmcompressor.args import parse_args + + model_args, dataset_args, recipe_args, training_args, _ = parse_args( + include_training_args=True, **kwargs + ) training_args.run_stages = True + report_to = kwargs.get("report_to", None) if report_to is None: # user didn't specify any reporters # get rid of the reporters inferred from hugging face training_args.report_to = [] - main(model_args, data_args, recipe_args, training_args) + main(model_args, dataset_args, recipe_args, training_args) def compress(**kwargs): @@ -117,13 +90,12 @@ def parse_args(**kwargs): Parses kwargs by grouping into model, data or training arg groups: * model_args in src/llmcompressor/transformers/utils/arg_parser/model_args.py - * data_args in - src/llmcompressor/transformers/utils/arg_parser/data_args.py + * dataset_args in + src/llmcompressor/transformers/utils/arg_parser/dataset_args.py * recipe_args in src/llmcompressor/transformers/utils/arg_parser/recipe_args.py * training_args in src/llmcompressor/transformers/utils/arg_parser/training_args.py - """ parser = HfArgumentParser( (ModelArguments, DatasetArguments, RecipeArguments, TrainingArguments) @@ -134,7 +106,7 @@ def parse_args(**kwargs): else: parsed_args = parser.parse_dict(kwargs) - model_args, data_args, recipe_args, training_args = parsed_args + model_args, dataset_args, recipe_args, training_args = parsed_args if recipe_args.recipe_args is not None: if not isinstance(recipe_args.recipe_args, dict): arg_dict = {} @@ -144,7 +116,7 @@ def parse_args(**kwargs): recipe_args.recipe_args = arg_dict # raise depreciation warnings - if data_args.remove_columns is not None: + if dataset_args.remove_columns is not None: warnings.warn( "`remove_columns` argument is depreciated. When tokenizing datasets, all " "columns which are invalid inputs the tokenizer will be removed", @@ -158,153 +130,12 @@ def parse_args(**kwargs): model_args.processor = model_args.tokenizer model_args.tokenizer = None - return model_args, data_args, recipe_args, training_args - - -def initialize_model_from_path( - model_args: ModelArguments, - training_args: Optional[TrainingArguments] = None, -): - # Load pretrained model - # The .from_pretrained methods guarantee that only one local process can - # concurrently download model & vocab. 
- model_path = model_args.model - config = AutoConfig.from_pretrained( - model_args.config_name if model_args.config_name else model_path, - cache_dir=model_args.cache_dir, - revision=model_args.model_revision, - use_auth_token=True if model_args.use_auth_token else None, - tie_word_embeddings=model_args.tie_word_embeddings, - trust_remote_code=model_args.trust_remote_code_model, - ) - - last_checkpoint = None - teacher = None - - if training_args is not None: - # Load teacher configuration if applicable - teacher_config = ( - AutoConfig.from_pretrained( - model_args.distill_teacher, - use_auth_token=True if model_args.use_auth_token else None, - tie_word_embeddings=model_args.tie_word_embeddings, - trust_remote_code=model_args.trust_remote_code_model, - ) - if model_args.distill_teacher - else None - ) - - # Detect last checkpoint - last_checkpoint = detect_last_checkpoint(training_args, model_args=model_args) - - # Set seed before initializing model - set_seed(training_args.seed) - - # Initialize teacher model if teacher path is provided - if model_args.distill_teacher is not None: - teacher_device_map = ( - None - if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" - else "auto" - ) - teacher_kwargs = { - "config": teacher_config, - "cache_dir": model_args.cache_dir, - "use_auth_token": True if model_args.use_auth_token else None, - "torch_dtype": parse_dtype(model_args.precision), - "device_map": teacher_device_map, - "trust_remote_code": model_args.trust_remote_code_model, - } - - teacher = AutoModelForCausalLM.from_pretrained( - model_args.distill_teacher, - **teacher_kwargs, - ) - if "sequence_length" in teacher_kwargs: - teacher.seqlen = teacher_kwargs["sequence_length"] - - model_path = ( - last_checkpoint or model_args.model - if hasattr(model_args, "model") - else model_args.model_name_or_path - ) - - # Fallback to CPU if GPU requested and not available - model_args.oneshot_device = fallback_to_cpu(model_args.oneshot_device) - - # Trainer handles device assignment for FSDP and training, don't do mapping here - # if running oneshot outside of FSDP, apply user device settings - - fsdp_enabled = os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" - - device_map = model_args.oneshot_device - if not fsdp_enabled and training_args is not None and training_args.do_train: - device_map = "auto" - - model_kwargs = { - "config": config, - "cache_dir": model_args.cache_dir, - "revision": model_args.model_revision, - "use_auth_token": True if model_args.use_auth_token else None, - "torch_dtype": parse_dtype(model_args.precision), - "device_map": device_map, - "trust_remote_code": model_args.trust_remote_code_model, - } - - # this calls from_pretrained under the hood so should be FSDP safe - - # optimized models must be decompressed to carry out oneshot/train/etc - if is_model_ct_quantized_from_path(model_path): - model_kwargs["quantization_config"] = CompressedTensorsConfig( - run_compressed=False - ) - - model = AutoModelForCausalLM.from_pretrained( - model_path, - **model_kwargs, - ) - if "sequence_length" in model_kwargs: - model.seqlen = model_kwargs["sequence_length"] - - return model, teacher - - -def initialize_processor_from_path( - model_args: ModelArguments, - model: PreTrainedModel, - teacher: Optional[PreTrainedModel] = None, -) -> Processor: - processor_src = model_args.processor or get_processor_name_from_model( - model, teacher - ) - # The use_fast=True option is not currently supported safely in Transformers - # See: 
https://github.com/huggingface/transformers/pull/34836#issuecomment-2491809727 # noqa: E501 - try: - processor = AutoProcessor.from_pretrained( - processor_src, - cache_dir=model_args.cache_dir, - use_fast=True, - revision=model_args.model_revision, - use_auth_token=True if model_args.use_auth_token else None, - trust_remote_code=model_args.trust_remote_code_model, - ) - except Exception: - logger.debug("Could not load fast processor, loading slow processor instead") - processor = AutoProcessor.from_pretrained( - processor_src, - cache_dir=model_args.cache_dir, - use_fast=False, - revision=model_args.model_revision, - use_auth_token=True if model_args.use_auth_token else None, - trust_remote_code=model_args.trust_remote_code_model, - ) - - return processor + return model_args, dataset_args, recipe_args, training_args def main( model_args: ModelArguments, - data_args: DatasetArguments, + dataset_args: DatasetArguments, recipe_args: RecipeArguments, training_args: TrainingArguments, ): @@ -326,10 +157,15 @@ def main( :param model_args: Arguments pertaining to which model/config/tokenizer we are going to fine-tune from - :param data_args: Arguments pertaining to what data we are going to input our model - for training + :param dataset_args: Arguments pertaining to what data we are + going to input our model for training :param training_args: Arguments pertaining to training loop configuration """ + from llmcompressor.args import TrainingArguments + from llmcompressor.entrypoints.utils import ( + initialize_model_from_path, + initialize_processor_from_path, + ) # Temporary warning, to be removed if model_args.tie_word_embeddings is True: @@ -384,7 +220,7 @@ def main( # Load datasets stage_runner = StageRunner( model_args=model_args, - data_args=data_args, + dataset_args=dataset_args, training_args=training_args, recipe_args=recipe_args, ) @@ -400,10 +236,10 @@ def main( recipe_args=recipe_args.recipe_args, args=training_args, model_args=model_args, - data_args=data_args, + dataset_args=dataset_args, train_dataset=train_dataset or calib_dataset, processing_class=processor, - data_collator=data_args.data_collator, + data_collator=dataset_args.data_collator, ) # wrap model.save_pretrained @@ -426,6 +262,7 @@ def main( # exit immediately return + # Training if training_args.do_train: checkpoint = None diff --git a/src/llmcompressor/transformers/tracing/debug.py b/src/llmcompressor/transformers/tracing/debug.py index ccce917a7..2bb399b3c 100644 --- a/src/llmcompressor/transformers/tracing/debug.py +++ b/src/llmcompressor/transformers/tracing/debug.py @@ -63,11 +63,11 @@ def trace( print("Loaded model") # Prepare sample data - data_args = DatasetArguments(**get_dataset_kwargs(modality)) + dataset_args = DatasetArguments(**get_dataset_kwargs(modality)) dataset = TextGenerationDataset.load_from_registry( - data_args.dataset, - data_args=data_args, - split=data_args.splits["calibration"], + dataset_args.dataset, + dataset_args=dataset_args, + split=dataset_args.splits["calibration"], processor=processor, )(add_labels=False) sample_input = next(iter(dataset)) @@ -89,7 +89,7 @@ def trace( "\nAttempting trace\n" f" model_id={model_id}\n" f" model_class={model_class.__name__}\n" - f" dataset={data_args.dataset}\n" + f" dataset={dataset_args.dataset}\n" f" split={dataset.split}\n" f" inputs={sample_input.keys()}\n" f" sequential_targets={sequential_targets}\n" diff --git a/tests/llmcompressor/entrypoints/test_oneshot.py b/tests/llmcompressor/entrypoints/test_oneshot.py index 4a7f2a5a7..ba0cb3a3a 100644 --- 
a/tests/llmcompressor/entrypoints/test_oneshot.py +++ b/tests/llmcompressor/entrypoints/test_oneshot.py @@ -1,7 +1,7 @@ from transformers import AutoModelForCausalLM from llmcompressor import Oneshot -from llmcompressor.entrypoints.oneshot import parse_oneshot_args +from llmcompressor.args import parse_args def test_oneshot_from_args(): @@ -17,7 +17,7 @@ def test_oneshot_from_args(): output_dir = "bar_output_dir" - model_args, data_args, recipe_args, output_dir = parse_oneshot_args( + model_args, dataset_args, recipe_args, _, output_dir = parse_args( model=model, dataset=dataset, recipe=recipe, @@ -26,10 +26,10 @@ def test_oneshot_from_args(): output_dir=output_dir, ) - oneshot = Oneshot.from_args(model_args, data_args, recipe_args, output_dir) + oneshot = Oneshot.from_args(model_args, dataset_args, recipe_args, output_dir) assert oneshot.model == model assert oneshot.model_args is model_args - assert oneshot.data_args is data_args + assert oneshot.dataset_args is dataset_args assert oneshot.recipe_args is recipe_args assert oneshot.model_args is model_args assert oneshot.output_dir is output_dir diff --git a/tests/llmcompressor/transformers/compression/configs/channelwise_15m.yaml b/tests/llmcompressor/transformers/compression/configs/channelwise_15m.yaml index 7cf010f66..628521890 100644 --- a/tests/llmcompressor/transformers/compression/configs/channelwise_15m.yaml +++ b/tests/llmcompressor/transformers/compression/configs/channelwise_15m.yaml @@ -1,5 +1,4 @@ cadence: "commit" test_type: "regression" model_stub: "Xenova/llama2.c-stories15M" -new_recipe: "tests/llmcompressor/transformers/compression/recipes/new_quant_channel.yaml" -ppl_threshold: 30000 \ No newline at end of file +new_recipe: "tests/llmcompressor/transformers/compression/recipes/new_quant_channel.yaml" \ No newline at end of file diff --git a/tests/llmcompressor/transformers/compression/configs/fp8_15m.yaml b/tests/llmcompressor/transformers/compression/configs/fp8_15m.yaml index 9b136327a..6837be14e 100644 --- a/tests/llmcompressor/transformers/compression/configs/fp8_15m.yaml +++ b/tests/llmcompressor/transformers/compression/configs/fp8_15m.yaml @@ -1,5 +1,4 @@ cadence: "commit" test_type: "regression" model_stub: "Xenova/llama2.c-stories15M" -new_recipe: "tests/llmcompressor/transformers/compression/recipes/new_quant_fp8.yaml" -ppl_threshold: 30000 \ No newline at end of file +new_recipe: "tests/llmcompressor/transformers/compression/recipes/new_quant_fp8.yaml" \ No newline at end of file diff --git a/tests/llmcompressor/transformers/compression/configs/inputs_15m.yaml b/tests/llmcompressor/transformers/compression/configs/inputs_15m.yaml index 38981e2ca..ca3c1286b 100644 --- a/tests/llmcompressor/transformers/compression/configs/inputs_15m.yaml +++ b/tests/llmcompressor/transformers/compression/configs/inputs_15m.yaml @@ -1,5 +1,4 @@ cadence: "commit" test_type: "regression" model_stub: "Xenova/llama2.c-stories15M" -new_recipe: "tests/llmcompressor/transformers/compression/recipes/new_quant_full.yaml" -ppl_threshold: 30000 \ No newline at end of file +new_recipe: "tests/llmcompressor/transformers/compression/recipes/new_quant_full.yaml" \ No newline at end of file diff --git a/tests/llmcompressor/transformers/compression/configs/weights_only_1.1b.yaml b/tests/llmcompressor/transformers/compression/configs/weights_only_1.1b.yaml index 3c9d18f2c..50ccd0aa3 100644 --- a/tests/llmcompressor/transformers/compression/configs/weights_only_1.1b.yaml +++ 
b/tests/llmcompressor/transformers/compression/configs/weights_only_1.1b.yaml @@ -1,5 +1,4 @@ cadence: "nightly" test_type: "regression" model_stub: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" -new_recipe: "tests/llmcompressor/transformers/compression/recipes/new_quant_weight.yaml" -ppl_threshold: 20 \ No newline at end of file +new_recipe: "tests/llmcompressor/transformers/compression/recipes/new_quant_weight.yaml" \ No newline at end of file diff --git a/tests/llmcompressor/transformers/compression/configs/weights_only_15m.yaml b/tests/llmcompressor/transformers/compression/configs/weights_only_15m.yaml index 564c961a0..d7aa73f58 100644 --- a/tests/llmcompressor/transformers/compression/configs/weights_only_15m.yaml +++ b/tests/llmcompressor/transformers/compression/configs/weights_only_15m.yaml @@ -1,5 +1,4 @@ cadence: "commit" test_type: "regression" model_stub: "Xenova/llama2.c-stories15M" -new_recipe: "tests/llmcompressor/transformers/compression/recipes/new_quant_weight.yaml" -ppl_threshold: 30000 \ No newline at end of file +new_recipe: "tests/llmcompressor/transformers/compression/recipes/new_quant_weight.yaml" \ No newline at end of file diff --git a/tests/llmcompressor/transformers/compression/decompression_configs/w4a16.yaml b/tests/llmcompressor/transformers/compression/decompression_configs/w4a16.yaml index 144044f28..330023a80 100644 --- a/tests/llmcompressor/transformers/compression/decompression_configs/w4a16.yaml +++ b/tests/llmcompressor/transformers/compression/decompression_configs/w4a16.yaml @@ -1,4 +1,4 @@ -cadence: "commit" +cadence: "nightly" test_type: "regression" compressed_model_stub: "nm-testing/tinyllama-w4a16-compressed" skeleton_model_stub: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" \ No newline at end of file diff --git a/tests/llmcompressor/transformers/compression/decompression_configs/w8a16_dense.yaml b/tests/llmcompressor/transformers/compression/decompression_configs/w8a16_dense.yaml index 95e73b148..337e6c19e 100644 --- a/tests/llmcompressor/transformers/compression/decompression_configs/w8a16_dense.yaml +++ b/tests/llmcompressor/transformers/compression/decompression_configs/w8a16_dense.yaml @@ -1,4 +1,4 @@ -cadence: "commit" +cadence: "nightly" test_type: "regression" compressed_model_stub: "nm-testing/tinyllama-w8a16-dense" skeleton_model_stub: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" \ No newline at end of file diff --git a/tests/llmcompressor/transformers/compression/test_decompress.py b/tests/llmcompressor/transformers/compression/test_decompress.py index 616dd0dfe..b18cba80e 100644 --- a/tests/llmcompressor/transformers/compression/test_decompress.py +++ b/tests/llmcompressor/transformers/compression/test_decompress.py @@ -3,6 +3,7 @@ import tempfile import unittest +import torch from compressed_tensors import QUANTIZATION_CONFIG_NAME from compressed_tensors.compressors import ModelCompressor from compressed_tensors.quantization import QuantizationStatus @@ -113,16 +114,16 @@ def test_hf_quantizer_decompress_match_manual_decompress(self): ) inputs = inputs.to(self.decompressed_model_manual.device) - decompressed_model_manual_output = self.tokenizer.batch_decode( - self.decompressed_model_manual.generate(**inputs, max_length=50) + decompressed_model_manual_output = self.decompressed_model_manual.generate( + **inputs, max_length=50 ) - decompressed_model_hf_quantizer_out = self.tokenizer.batch_decode( + decompressed_model_hf_quantizer_out = ( self.decompressed_model_hf_quantizer.generate(**inputs, max_length=50) 
) - assert ( - decompressed_model_hf_quantizer_out == decompressed_model_manual_output + assert torch.equal( + decompressed_model_hf_quantizer_out, decompressed_model_manual_output ) @classmethod diff --git a/tests/llmcompressor/transformers/compression/test_quantization.py b/tests/llmcompressor/transformers/compression/test_quantization.py index e68d8b42a..8a4f46fb5 100644 --- a/tests/llmcompressor/transformers/compression/test_quantization.py +++ b/tests/llmcompressor/transformers/compression/test_quantization.py @@ -55,7 +55,7 @@ def tearDownClass(cls): @staticmethod def _run_oneshot(model, recipe, dataset, output_dir): - num_calibration_samples = 512 + num_calibration_samples = 64 max_seq_length = 512 pad_to_max_length = False @@ -68,7 +68,7 @@ def _run_oneshot(model, recipe, dataset, output_dir): recipe=recipe, pad_to_max_length=pad_to_max_length, clear_sparse_session=False, - splits={"calibration": "train_gen[:5%]"}, + splits={"calibration": "train_gen[:1%]"}, save_compressed=False, ) return model @@ -123,10 +123,10 @@ def test_quantization_reload(self): assert o_zp.dtype == n_zp.dtype assert torch.equal(o_zp, n_zp) - def _get_dataloader(self, data_args, tokenizer): + def _get_dataloader(self, dataset_args, tokenizer): dataset_manager = TextGenerationDataset.load_from_registry( - data_args.dataset, - data_args=data_args, + dataset_args.dataset, + dataset_args=dataset_args, split="train_gen[:5%]", processor=tokenizer, ) @@ -142,12 +142,14 @@ def _get_dataloader(self, data_args, tokenizer): @torch.no_grad() def test_perplexity(self): + if self.ppl_threshold is None: + pytest.skip("Skipping perplexity calculation.") tokenizer = AutoTokenizer.from_pretrained(self.model_stub) - data_args = DatasetArguments( + dataset_args = DatasetArguments( dataset="ultrachat-200k", max_seq_length=self.max_seq_length, ) - dataloader = self._get_dataloader(data_args, tokenizer) + dataloader = self._get_dataloader(dataset_args, tokenizer) total_ppl = 0.0 total_non_nan = 0 diff --git a/tests/llmcompressor/transformers/compression/test_run_compressed.py b/tests/llmcompressor/transformers/compression/test_run_compressed.py index 616dd0dfe..b18cba80e 100644 --- a/tests/llmcompressor/transformers/compression/test_run_compressed.py +++ b/tests/llmcompressor/transformers/compression/test_run_compressed.py @@ -3,6 +3,7 @@ import tempfile import unittest +import torch from compressed_tensors import QUANTIZATION_CONFIG_NAME from compressed_tensors.compressors import ModelCompressor from compressed_tensors.quantization import QuantizationStatus @@ -113,16 +114,16 @@ def test_hf_quantizer_decompress_match_manual_decompress(self): ) inputs = inputs.to(self.decompressed_model_manual.device) - decompressed_model_manual_output = self.tokenizer.batch_decode( - self.decompressed_model_manual.generate(**inputs, max_length=50) + decompressed_model_manual_output = self.decompressed_model_manual.generate( + **inputs, max_length=50 ) - decompressed_model_hf_quantizer_out = self.tokenizer.batch_decode( + decompressed_model_hf_quantizer_out = ( self.decompressed_model_hf_quantizer.generate(**inputs, max_length=50) ) - assert ( - decompressed_model_hf_quantizer_out == decompressed_model_manual_output + assert torch.equal( + decompressed_model_hf_quantizer_out, decompressed_model_manual_output ) @classmethod diff --git a/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py b/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py index 7b475fdb5..a7138b186 100644 --- 
a/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py +++ b/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py @@ -1,18 +1,16 @@ import pytest from llmcompressor.args import DatasetArguments -from llmcompressor.transformers.finetune.data.data_helpers import ( - get_raw_dataset, - make_dataset_splits, -) +from llmcompressor.datasets import make_dataset_splits +from llmcompressor.transformers.finetune.data.data_helpers import get_raw_dataset @pytest.mark.unit def test_combined_datasets(): - data_args = DatasetArguments( + dataset_args = DatasetArguments( dataset="wikitext", dataset_config_name="wikitext-2-raw-v1" ) - raw_wikitext2 = get_raw_dataset(data_args) + raw_wikitext2 = get_raw_dataset(dataset_args) datasets = {"all": raw_wikitext2} split_datasets = make_dataset_splits(datasets, do_train=True) assert split_datasets.get("train") is not None @@ -23,13 +21,13 @@ def test_combined_datasets(): @pytest.mark.unit def test_separate_datasets(): - splits = {"train": "train[:10%]", "validation": "train[10%:20%]"} - data_args = DatasetArguments( + splits = {"train": "train[:5%]", "validation": "train[10%:20%]"} + dataset_args = DatasetArguments( dataset="wikitext", dataset_config_name="wikitext-2-raw-v1" ) datasets = {} for split_name, split_str in splits.items(): - raw_wikitext2 = get_raw_dataset(data_args, split=split_str) + raw_wikitext2 = get_raw_dataset(dataset_args, split=split_str) datasets[split_name] = raw_wikitext2 split_datasets = make_dataset_splits(datasets, do_train=True) diff --git a/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py b/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py index dcc602877..3fc174acb 100644 --- a/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py +++ b/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py @@ -11,17 +11,15 @@ RecipeArguments, TrainingArguments, ) +from llmcompressor.datasets import format_calibration_data from llmcompressor.transformers import TextGenerationDataset -from llmcompressor.transformers.finetune.data.data_helpers import ( - format_calibration_data, -) from llmcompressor.transformers.finetune.runner import StageRunner @pytest.mark.unit class TestConcentrationTokenization(unittest.TestCase): def setUp(self): - self.data_args = DatasetArguments( + self.dataset_args = DatasetArguments( dataset="wikitext", dataset_config_name="wikitext-2-raw-v1", concatenate_data=True, @@ -33,8 +31,8 @@ def prepare_fixture(self, tiny_llama_tokenizer): def test_concatenation_tokenization(self): wiki_manager = TextGenerationDataset.load_from_registry( - self.data_args.dataset, - data_args=self.data_args, + self.dataset_args.dataset, + dataset_args=self.dataset_args, split="train[:5%]", processor=self.tiny_llama_tokenizer, ) @@ -54,7 +52,7 @@ def test_concatenation_tokenization(self): @pytest.mark.unit class TestNoPaddingTokenization(unittest.TestCase): def setUp(self): - self.data_args = DatasetArguments( + self.dataset_args = DatasetArguments( dataset="open_platypus", pad_to_max_length=False ) @@ -65,9 +63,9 @@ def prepare_fixture(self, tiny_llama_tokenizer): @pytest.mark.usefixtures("tiny_llama_tokenizer") def test_no_padding_tokenization(self): op_manager = TextGenerationDataset.load_from_registry( - self.data_args.dataset, - data_args=self.data_args, - split="train[5%:10%]", + self.dataset_args.dataset, + dataset_args=self.dataset_args, + split="train[5%:7%]", processor=self.tiny_llama_tokenizer, ) dataset = 
op_manager.load_dataset() # load @@ -75,14 +73,14 @@ def test_no_padding_tokenization(self): dataset, op_manager.preprocess, batched=False, - num_proc=op_manager.data_args.preprocessing_num_workers, + num_proc=op_manager.dataset_args.preprocessing_num_workers, ) dataset = op_manager.rename_columns(dataset) # rename self.assertGreater(len(dataset), 0) ex_item = dataset[0]["text"] self.assertIn("Below is an instruction that describes a task", ex_item) - self.assertEqual(dataset.split, "train[5%:10%]") + self.assertEqual(dataset.split, "train[5%:7%]") tokenized_dataset = op_manager() self.assertIn("input_ids", tokenized_dataset.features) self.assertIn("labels", tokenized_dataset.features) @@ -97,7 +95,9 @@ def test_no_padding_tokenization(self): @pytest.mark.unit class TestMaxSeqLenClipped(unittest.TestCase): def setUp(self): - self.data_args = DatasetArguments(dataset="open_platypus", max_seq_length=4096) + self.dataset_args = DatasetArguments( + dataset="open_platypus", max_seq_length=4096 + ) @pytest.fixture(autouse=True) def prepare_fixture(self, tiny_llama_tokenizer): @@ -105,9 +105,9 @@ def prepare_fixture(self, tiny_llama_tokenizer): def test_max_seq_len_clipped(self): op_manager = TextGenerationDataset.load_from_registry( - self.data_args.dataset, - data_args=self.data_args, - split="train[80%:]", + self.dataset_args.dataset, + dataset_args=self.dataset_args, + split="train[95%:]", processor=self.tiny_llama_tokenizer, ) @@ -119,7 +119,7 @@ def test_max_seq_len_clipped(self): @pytest.mark.unit class TestDatasetKwargsAndPercent(unittest.TestCase): def setUp(self): - self.data_args = DatasetArguments( + self.dataset_args = DatasetArguments( dataset="wikitext", raw_kwargs={ "data_files": { @@ -134,17 +134,17 @@ def prepare_fixture(self, tiny_llama_tokenizer): def test_dataset_kwargs_and_percentages(self): c4_manager_a = TextGenerationDataset.load_from_registry( - self.data_args.dataset, - data_args=self.data_args, - split="train[5%:10%]", + self.dataset_args.dataset, + dataset_args=self.dataset_args, + split="train[5%:6%]", processor=self.tiny_llama_tokenizer, ) raw_dataset_a = c4_manager_a.load_dataset() c4_manager_b = TextGenerationDataset.load_from_registry( - self.data_args.dataset, - data_args=self.data_args, - split="train[5%:15%]", + self.dataset_args.dataset, + dataset_args=self.dataset_args, + split="train[6%:8%]", processor=self.tiny_llama_tokenizer, ) raw_dataset_b = c4_manager_b.load_dataset() @@ -162,19 +162,19 @@ def prepare_fixture(self, tiny_llama_tokenizer): [ ["ptb", "penn_treebank", "train[:5%]", False], ["gsm8k", "main", "train[:5%]", True], - ["ultrachat_200k", "default", "train_sft[:2%]", False], + ["ultrachat_200k", "default", "train_sft[:1%]", False], ] ) def test_datasets(self, dataset_key, dataset_config, split, do_concat): - data_args = DatasetArguments( + dataset_args = DatasetArguments( dataset=dataset_key, dataset_config_name=dataset_config, concatenate_data=do_concat, trust_remote_code_data=True, ) manager = TextGenerationDataset.load_from_registry( - data_args.dataset, - data_args=data_args, + dataset_args.dataset, + dataset_args=dataset_args, split=split, processor=self.tiny_llama_tokenizer, ) @@ -205,7 +205,7 @@ def prepare_fixture(self, tiny_llama_tokenizer): self.tiny_llama_tokenizer = tiny_llama_tokenizer def setUp(self): - self.data_args = DatasetArguments( + self.dataset_args = DatasetArguments( dataset="evolcodealpaca", dataset_config_name=None, concatenate_data=False, @@ -213,8 +213,8 @@ def setUp(self): def test_evol(self): evol_manager = 
TextGenerationDataset.load_from_registry( - self.data_args.dataset, - data_args=self.data_args, + self.dataset_args.dataset, + dataset_args=self.dataset_args, split="train[:2%]", processor=self.tiny_llama_tokenizer, ) @@ -234,7 +234,7 @@ def test_evol(self): @pytest.mark.unit class TestStreamLoading(unittest.TestCase): def setUp(self): - self.data_args = DatasetArguments( + self.dataset_args = DatasetArguments( dataset="wikitext", dataset_config_name="wikitext-2-raw-v1", concatenate_data=True, @@ -247,8 +247,8 @@ def prepare_fixture(self, tiny_llama_tokenizer): def test_stream_loading(self): manager = TextGenerationDataset.load_from_registry( - self.data_args.dataset, - data_args=self.data_args, + self.dataset_args.dataset, + dataset_args=self.dataset_args, split="train", processor=self.tiny_llama_tokenizer, ) @@ -271,11 +271,9 @@ class TestSplitLoading(unittest.TestCase): def prepare_fixture(self, tiny_llama_tokenizer): self.tiny_llama_tokenizer = tiny_llama_tokenizer - @parameterized.expand( - [["train"], ["train[60%:]"], [{"train": "train[:20%]"}], [None]] - ) + @parameterized.expand([["train[95%:]"], [{"train": "train[:5%]"}]]) def test_split_loading(self, split_def): - data_args = DatasetArguments( + dataset_args = DatasetArguments( dataset="open_platypus", splits=split_def, trust_remote_code_data=True, @@ -285,7 +283,7 @@ def test_split_loading(self, split_def): recipe_args = RecipeArguments() stage_runner = StageRunner( model_args=model_args, - data_args=data_args, + dataset_args=dataset_args, training_args=training_args, recipe_args=recipe_args, ) @@ -302,7 +300,7 @@ class TestTokenizationDataset(unittest.TestCase): def prepare_fixture(self, tiny_llama_tokenizer): self.tiny_llama_tokenizer = tiny_llama_tokenizer dataset = load_dataset("garage-bAInd/Open-Platypus")["train"] - self.num_calib_samples = 256 + self.num_calib_samples = 64 self.max_seq_len = 512 self.dataset = dataset.shuffle(seed=42).select(range(self.num_calib_samples)) @@ -321,7 +319,7 @@ def preprocess(sample): ) stage_runner = StageRunner( model_args=None, - data_args=DatasetArguments( + dataset_args=DatasetArguments( dataset=tokenized_dataset, shuffle_calibration_samples=False ), training_args=TrainingArguments(do_oneshot=True), @@ -339,7 +337,7 @@ def preprocess(sample): calib_dataloader = format_calibration_data( tokenized_dataset=calib_dataset, num_calibration_samples=self.num_calib_samples, - do_shuffle=stage_runner._data_args.shuffle_calibration_samples, + do_shuffle=stage_runner._dataset_args.shuffle_calibration_samples, ) self.assertEqual(len(calib_dataloader), self.num_calib_samples) dataloader_sample = next(iter(calib_dataloader))["input_ids"] diff --git a/tests/llmcompressor/transformers/finetune/data/test_registry.py b/tests/llmcompressor/transformers/finetune/data/test_registry.py index 694a9b6d3..29895b4a4 100644 --- a/tests/llmcompressor/transformers/finetune/data/test_registry.py +++ b/tests/llmcompressor/transformers/finetune/data/test_registry.py @@ -11,49 +11,49 @@ @pytest.mark.usefixtures("tiny_llama_tokenizer") def test_c4_initializes(tiny_llama_tokenizer): - data_args = DatasetArguments(dataset="c4", concatenate_data=True) + dataset_args = DatasetArguments(dataset="c4", concatenate_data=True) c4_manager = TextGenerationDataset.load_from_registry( - data_args.dataset, - data_args=data_args, + dataset_args.dataset, + dataset_args=dataset_args, split=None, processor=tiny_llama_tokenizer, ) assert isinstance(c4_manager, TextGenerationDataset) assert isinstance(c4_manager, C4Dataset) - assert 
c4_manager.data_args.text_column == "text" + assert c4_manager.dataset_args.text_column == "text" assert not c4_manager.padding - assert c4_manager.max_seq_length == data_args.max_seq_length + assert c4_manager.max_seq_length == dataset_args.max_seq_length @pytest.mark.usefixtures("tiny_llama_tokenizer") def test_wikitext_initializes(tiny_llama_tokenizer): - data_args = DatasetArguments( + dataset_args = DatasetArguments( dataset="wikitext", dataset_config_name="wikitext-2-raw-v1" ) wiki_manager = TextGenerationDataset.load_from_registry( - data_args.dataset, - data_args=data_args, + dataset_args.dataset, + dataset_args=dataset_args, split=None, processor=tiny_llama_tokenizer, ) assert isinstance(wiki_manager, TextGenerationDataset) assert isinstance(wiki_manager, WikiTextDataset) - assert wiki_manager.data_args.text_column == "text" + assert wiki_manager.dataset_args.text_column == "text" assert wiki_manager.padding == "max_length" - assert wiki_manager.max_seq_length == data_args.max_seq_length + assert wiki_manager.max_seq_length == dataset_args.max_seq_length @pytest.mark.usefixtures("tiny_llama_tokenizer") def test_open_platypus_initializes(tiny_llama_tokenizer): - data_args = DatasetArguments(dataset="open_platypus", pad_to_max_length=False) + dataset_args = DatasetArguments(dataset="open_platypus", pad_to_max_length=False) op_manager = TextGenerationDataset.load_from_registry( - data_args.dataset, - data_args=data_args, + dataset_args.dataset, + dataset_args=dataset_args, split=None, processor=tiny_llama_tokenizer, ) assert isinstance(op_manager, TextGenerationDataset) assert isinstance(op_manager, OpenPlatypusDataset) - assert op_manager.data_args.text_column == "text" + assert op_manager.dataset_args.text_column == "text" assert not op_manager.padding - assert op_manager.max_seq_length == data_args.max_seq_length + assert op_manager.max_seq_length == dataset_args.max_seq_length diff --git a/tests/llmcompressor/transformers/finetune/finetune_oneshot_configs/config.yaml b/tests/llmcompressor/transformers/finetune/finetune_oneshot_configs/config.yaml index 044407c5d..30b4658cb 100644 --- a/tests/llmcompressor/transformers/finetune/finetune_oneshot_configs/config.yaml +++ b/tests/llmcompressor/transformers/finetune/finetune_oneshot_configs/config.yaml @@ -4,5 +4,5 @@ model: "Xenova/llama2.c-stories15M" dataset: wikitext dataset_config_name: "wikitext-2-raw-v1" recipe: "tests/llmcompressor/transformers/finetune/test_alternate_recipe.yaml" -num_train_epochs: 1 +num_train_epochs: 0.25 concat_txt: False \ No newline at end of file diff --git a/tests/llmcompressor/transformers/finetune/test_finetune_no_recipe_custom_dataset.py b/tests/llmcompressor/transformers/finetune/test_finetune_no_recipe_custom_dataset.py index f8f8d9827..37524069c 100644 --- a/tests/llmcompressor/transformers/finetune/test_finetune_no_recipe_custom_dataset.py +++ b/tests/llmcompressor/transformers/finetune/test_finetune_no_recipe_custom_dataset.py @@ -108,7 +108,6 @@ def create_mock_file(self, extension, content, path, filename): def tearDown(self): shutil.rmtree(self.output) - self.monkeypatch.undo() @pytest.mark.integration @@ -121,11 +120,8 @@ class TestOneshotCustomDatasetSmall(TestFinetuneNoRecipeCustomDataset): def setUp(self): import torch - self.monkeypatch = pytest.MonkeyPatch() - if torch.cuda.is_available(): self.device = "cuda:0" - self.monkeypatch.setenv("CUDA_VISIBLE_DEVICES", "0") else: self.device = "cpu" @@ -147,15 +143,12 @@ def setUp(self): import torch from transformers import 
AutoModelForCausalLM - self.monkeypatch = pytest.MonkeyPatch() self.device = "cuda:0" self.output = "./oneshot_output" - self.monkeypatch.setenv("CUDA_VISIBLE_DEVICES", "0") self.model = AutoModelForCausalLM.from_pretrained( self.model, device_map=self.device, torch_dtype=torch.bfloat16 ) - self.monkeypatch = pytest.MonkeyPatch() def test_oneshot_then_finetune_gpu(self): self._test_finetune_wout_recipe_custom_dataset() diff --git a/tests/llmcompressor/transformers/finetune/test_oneshot_and_finetune.py b/tests/llmcompressor/transformers/finetune/test_oneshot_and_finetune.py index 870503496..d3bc611d0 100644 --- a/tests/llmcompressor/transformers/finetune/test_oneshot_and_finetune.py +++ b/tests/llmcompressor/transformers/finetune/test_oneshot_and_finetune.py @@ -19,9 +19,9 @@ class TestOneshotAndFinetune(unittest.TestCase): def _test_oneshot_and_finetune(self): from llmcompressor.transformers import apply - splits = {"train": "train[:30%]", "calibration": "train[30%:40%]"} + splits = {"train": "train[:5%]", "calibration": "train[5%:10%]"} if self.dataset == "ultrachat-200k": - splits = {"train": "train_gen[:30%]", "calibration": "train_gen[30%:40%]"} + splits = {"train": "train_gen[:5%]", "calibration": "train_gen[5%:10%]"} apply( model=self.model, @@ -30,6 +30,7 @@ def _test_oneshot_and_finetune(self): output_dir=self.output, recipe=self.recipe, num_train_epochs=self.num_train_epochs, + num_calibration_samples=64, concatenate_data=self.concat_txt, splits=splits, oneshot_device=self.device, diff --git a/tests/llmcompressor/transformers/finetune/test_oneshot_and_finetune_with_tokenizer.py b/tests/llmcompressor/transformers/finetune/test_oneshot_and_finetune_with_tokenizer.py index 509464a34..45b25818b 100644 --- a/tests/llmcompressor/transformers/finetune/test_oneshot_and_finetune_with_tokenizer.py +++ b/tests/llmcompressor/transformers/finetune/test_oneshot_and_finetune_with_tokenizer.py @@ -21,7 +21,6 @@ def setUp(self): self.output = "./finetune_output" # finetune workflows in general seem to have trouble with multi-gpus # use just one atm - self.monkeypatch = pytest.MonkeyPatch() def test_oneshot_and_finetune_with_tokenizer(self): from datasets import load_dataset @@ -29,8 +28,6 @@ def test_oneshot_and_finetune_with_tokenizer(self): from llmcompressor.transformers import compress - self.monkeypatch.setenv("CUDA_VISIBLE_DEVICES", "0") - recipe_str = ( "tests/llmcompressor/transformers/finetune/test_alternate_recipe.yaml" ) @@ -71,4 +68,3 @@ def test_oneshot_and_finetune_with_tokenizer(self): def tearDown(self): shutil.rmtree(self.output) - self.monkeypatch.undo() diff --git a/tests/llmcompressor/transformers/finetune/test_oneshot_then_finetune.py b/tests/llmcompressor/transformers/finetune/test_oneshot_then_finetune.py index d4d65469d..e8e0ae426 100644 --- a/tests/llmcompressor/transformers/finetune/test_oneshot_then_finetune.py +++ b/tests/llmcompressor/transformers/finetune/test_oneshot_then_finetune.py @@ -27,7 +27,7 @@ def test_oneshot_sparsification_then_finetune(self): concatenate_data = False num_calibration_samples = 64 output_dir = self.output / "oneshot_out" - splits = {"calibration": "train[:10%]"} + splits = {"calibration": "train[:5%]"} with create_session(): oneshot( @@ -56,8 +56,7 @@ def test_oneshot_sparsification_then_finetune(self): dataset = "open_platypus" concatenate_data = False output_dir = self.output / "finetune_out" - splits = "train[:50%]" - max_steps = 25 + splits = "train[5%:7%]" with create_session(): train( @@ -65,11 +64,10 @@ def 
test_oneshot_sparsification_then_finetune(self): distill_teacher=distill_teacher, dataset=dataset, output_dir=output_dir, - num_calibration_samples=num_calibration_samples, + num_train_epochs=0.05, recipe=recipe_str, concatenate_data=concatenate_data, splits=splits, - max_steps=max_steps, ) # test reloading checkpoint and final model @@ -85,11 +83,10 @@ def test_oneshot_sparsification_then_finetune(self): distill_teacher=distill_teacher, dataset=dataset, output_dir=output_dir, - num_calibration_samples=num_calibration_samples, + num_train_epochs=0.05, recipe=recipe_str, concatenate_data=concatenate_data, splits=splits, - max_steps=max_steps, resume_from_checkpoint=True, # use last checkpoint ) @@ -106,7 +103,7 @@ def test_oneshot_quantization_then_finetune(self): concatenate_data = False num_calibration_samples = 64 output_dir = self.output / "oneshot_out" - splits = {"calibration": "train[:10%]"} + splits = {"calibration": "train[:5%]"} with create_session(): oneshot( @@ -130,17 +127,17 @@ def test_oneshot_quantization_then_finetune(self): dataset = "open_platypus" concatenate_data = False output_dir = self.output / "finetune_out" - splits = {"calibration": "train[:10%]", "train": "train[:10%]"} + splits = {"calibration": "train[:5%]", "train": "train[5%:7%]"} with create_session(): train( model=model, dataset=dataset, output_dir=output_dir, - num_calibration_samples=num_calibration_samples, recipe=recipe, concatenate_data=concatenate_data, splits=splits, + num_train_epochs=0.05, ) # test reloading checkpoint and final model @@ -152,10 +149,10 @@ def test_oneshot_quantization_then_finetune(self): model=model, dataset=dataset, output_dir=output_dir, - num_calibration_samples=num_calibration_samples, recipe=recipe, concatenate_data=concatenate_data, splits=splits, + num_train_epochs=0.05, resume_from_checkpoint=True, # use last checkpoint ) diff --git a/tests/llmcompressor/transformers/finetune/test_session_mixin.py b/tests/llmcompressor/transformers/finetune/test_session_mixin.py index 4fa981de9..65e5140bf 100644 --- a/tests/llmcompressor/transformers/finetune/test_session_mixin.py +++ b/tests/llmcompressor/transformers/finetune/test_session_mixin.py @@ -15,7 +15,7 @@ def __init__( recipe: Optional[str], recipe_args: Optional[Union[Dict[str, Any], str]] = None, model_args: Optional[Union[Dict[str, Any], str]] = None, - data_args: Optional[Union[Dict[str, Any], str]] = None, + dataset_args: Optional[Union[Dict[str, Any], str]] = None, teacher: Optional[Union[Module, str]] = None, **kwargs, ): @@ -24,7 +24,7 @@ def __init__( recipe=recipe, recipe_args=recipe_args, model_args=model_args, - data_args=data_args, + dataset_args=dataset_args, teacher=teacher, **kwargs, ) diff --git a/tests/llmcompressor/transformers/obcq/test_obcq_completion.py b/tests/llmcompressor/transformers/obcq/test_obcq_completion.py index 1016cf422..5528a443e 100644 --- a/tests/llmcompressor/transformers/obcq/test_obcq_completion.py +++ b/tests/llmcompressor/transformers/obcq/test_obcq_completion.py @@ -26,14 +26,14 @@ def labeled_dataloader(self, dataset_name, model_name): from llmcompressor.transformers.finetune.data import TextGenerationDataset tokenizer = AutoTokenizer.from_pretrained(model_name) - data_args = DatasetArguments( + dataset_args = DatasetArguments( dataset=dataset_name, max_seq_length=512, pad_to_max_length=False, ) dataset_manager = TextGenerationDataset.load_from_registry( - data_args.dataset, - data_args=data_args, + dataset_args.dataset, + dataset_args=dataset_args, split="train", 
processor=tokenizer, ) diff --git a/tests/llmcompressor/transformers/obcq/test_obcq_owl.py b/tests/llmcompressor/transformers/obcq/test_obcq_owl.py index 4948c6da3..17effeb7a 100644 --- a/tests/llmcompressor/transformers/obcq/test_obcq_owl.py +++ b/tests/llmcompressor/transformers/obcq/test_obcq_owl.py @@ -3,10 +3,8 @@ from datasets import Dataset from transformers import AutoModelForCausalLM +from llmcompressor.datasets import format_calibration_data from llmcompressor.modifiers.obcq import SparseGPTModifier -from llmcompressor.transformers.finetune.data.data_helpers import ( - format_calibration_data, -) from llmcompressor.utils.pytorch.module import get_layers diff --git a/tests/unit/test_logger.py b/tests/unit/test_logger.py index 0e0ac8925..1796293f7 100644 --- a/tests/unit/test_logger.py +++ b/tests/unit/test_logger.py @@ -103,3 +103,16 @@ def test_environment_variable_disable_logging(monkeypatch, capsys): captured = capsys.readouterr() assert captured.out == "" assert captured.err == "" + + +def test_environment_variable_enable_logging(monkeypatch, capsys): + # Test environment variable to enable logging + monkeypatch.setenv("LLM_COMPRESSOR_LOG_DISABLED", "false") + + configure_logger(config=LoggerConfig()) + logger.info("Info message") + logger.error("Error message") + + captured = capsys.readouterr() + assert captured.out.count("Info message") == 1 + assert captured.out.count("Error message") == 1
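
The hunks above converge on one pattern for building calibration data under the renamed arguments: construct a `DatasetArguments`, load a dataset manager from the registry with the `dataset_args=` keyword, tokenize it, then build a dataloader with `format_calibration_data`, which now lives in `llmcompressor.datasets`. As a point of reference only (not part of the patch), here is a minimal sketch of that flow using the identifiers that appear in the hunks above; the model stub, split slice, and sample count are placeholders borrowed from the tests, not prescribed values.

```python
from transformers import AutoTokenizer

from llmcompressor.args import DatasetArguments
from llmcompressor.datasets import format_calibration_data  # relocated helper
from llmcompressor.transformers import TextGenerationDataset

# Placeholder tokenizer; the tests use the same tiny llama stub.
tokenizer = AutoTokenizer.from_pretrained("Xenova/llama2.c-stories15M")

# Dataset configuration now travels as `dataset_args` everywhere.
dataset_args = DatasetArguments(
    dataset="wikitext",
    dataset_config_name="wikitext-2-raw-v1",
    max_seq_length=512,
)

# Registry id is the dataset name; the keyword is `dataset_args`, not `data_args`.
dataset_manager = TextGenerationDataset.load_from_registry(
    dataset_args.dataset,
    dataset_args=dataset_args,
    split="train[:1%]",
    processor=tokenizer,
)
tokenized_dataset = dataset_manager(add_labels=False)

# Build the calibration dataloader from the tokenized split.
calib_dataloader = format_calibration_data(
    tokenized_dataset=tokenized_dataset,
    num_calibration_samples=64,
    do_shuffle=dataset_args.shuffle_calibration_samples,
)
```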