diff --git a/src/llmcompressor/datasets/__init__.py b/src/llmcompressor/datasets/__init__.py
new file mode 100644
index 000000000..0b81cc724
--- /dev/null
+++ b/src/llmcompressor/datasets/__init__.py
@@ -0,0 +1,8 @@
+# flake8: noqa
+
+from .utils import (
+    format_calibration_data,
+    get_calibration_dataloader,
+    get_processed_dataset,
+    make_dataset_splits,
+)
diff --git a/src/llmcompressor/datasets/utils.py b/src/llmcompressor/datasets/utils.py
new file mode 100644
index 000000000..0d36cb3ac
--- /dev/null
+++ b/src/llmcompressor/datasets/utils.py
@@ -0,0 +1,191 @@
+import re
+from typing import Any, Callable, Dict, List, Optional
+
+import torch
+from datasets import Dataset
+from loguru import logger
+from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
+from transformers.data import default_data_collator
+
+from llmcompressor.args import DatasetArguments
+from llmcompressor.transformers.finetune.data import TextGenerationDataset
+from llmcompressor.typing import Processor
+
+
+def get_processed_dataset(
+    dataset_args: DatasetArguments,
+    processor: Processor,
+    do_oneshot: bool = False,
+    do_train: bool = True,
+) -> Optional[Dict[str, Dataset]]:
+    """
+    Loads and tokenizes the datasets specified by dataset_args, producing one
+    Dataset per enabled flow.
+    :param dataset_args: DatasetArguments that contain dataset loading and
+        processing params
+    :param processor: processor or tokenizer to use for dataset tokenization
+    :param do_oneshot: True for the oneshot pathway
+    :param do_train: True for the train pathway
+    :return: dict of "train"/"calibration" datasets, or None if no dataset is set
+    """
+    if dataset_args.dataset is None:
+        logger.warning(
+            "Running oneshot without calibration data. This is expected for "
+            "weight-only and dynamic quantization"
+        )
+        return
+
+    splits = dataset_args.splits
+    tokenized_datasets = {}
+
+    def _get_split_name(inp_str):
+        # strip out split name, for ex train[60%:] -> train
+        match = re.match(r"(\w*)\[.*\]", inp_str)
+        if match is not None:
+            return match.group(1)
+        return inp_str
+
+    if splits is None:
+        splits = {"all": None}
+    elif isinstance(splits, str):
+        splits = {_get_split_name(splits): splits}
+    elif isinstance(splits, List):
+        splits = {_get_split_name(s): s for s in splits}
+
+    # default to custom dataset if dataset provided isn't a string
+    registry_id = (
+        dataset_args.dataset if isinstance(dataset_args.dataset, str) else "custom"
+    )
+    for split_name, split_str in splits.items():
+        dataset = dataset_args.dataset
+        if hasattr(dataset, "column_names") and "input_ids" in dataset.column_names:
+            # dataset is already tokenized
+            tokenized_datasets[split_name] = dataset
+        else:
+            # dataset needs to be tokenized
+            dataset_manager = TextGenerationDataset.load_from_registry(
+                registry_id,
+                dataset_args=dataset_args,
+                split=split_str,
+                processor=processor,
+            )
+            tokenized_datasets[split_name] = dataset_manager(add_labels=do_train)
+
+    return make_dataset_splits(
+        tokenized_datasets,
+        do_oneshot=do_oneshot,
+        do_train=do_train,
+    )
+
+
+def get_calibration_dataloader(
+    dataset_args: DatasetArguments,
+    processor: Processor,
+) -> torch.utils.data.DataLoader:
+    """
+    Get the dataloader used for oneshot calibration.
+    :param dataset_args: DatasetArguments that contains the dataset parameters.
+    :param processor: Processor or the tokenizer of the model.
+    :return: PyTorch dataloader object that contains the calibration dataset.
+ """ + if dataset_args.dataset is None: + # weight-only quantization or dynamic quantization + return + + datasets = get_processed_dataset( + dataset_args=dataset_args, + processor=processor, + do_oneshot=True, + do_train=False, + ) + + calibration_dataset = datasets.get("calibration") + + return format_calibration_data( + tokenized_dataset=calibration_dataset, + num_calibration_samples=dataset_args.num_calibration_samples, + do_shuffle=dataset_args.shuffle_calibration_samples, + collate_fn=dataset_args.data_collator, + ) + + +def format_calibration_data( + tokenized_dataset: Dataset, + num_calibration_samples: Optional[int] = None, + do_shuffle: bool = True, + collate_fn: Callable = default_data_collator, +) -> List[torch.Tensor]: + """ + Creates a dataloader out of the calibration dataset split, trimming it to + the desired number of calibration samples + :param tokenized_dataset: dataset to convert to dataloader + :param num_calibration_samples: number of data samples to convert + :param do_shuffle: whether to shuffle the dataset before selecting calibration + samples, true by default + :param collate_fn: optional custom collate function, or use default + :return: list of trimmed calibration data tensors + """ + safe_calibration_samples = len(tokenized_dataset) + if num_calibration_samples is not None: + safe_calibration_samples = min(len(tokenized_dataset), num_calibration_samples) + if safe_calibration_samples != num_calibration_samples: + logger.warn( + f"Requested {num_calibration_samples} calibration samples but " + f"the provided dataset only has {safe_calibration_samples}. " + ) + + if do_shuffle: + tokenized_dataset = tokenized_dataset.shuffle() + tokenized_calibration = tokenized_dataset.select(range(safe_calibration_samples)) + + dataloader_params = { + "batch_size": 1, + "sampler": RandomSampler(tokenized_calibration) + if do_shuffle + else SequentialSampler(tokenized_calibration), + "collate_fn": collate_fn, + "pin_memory": True, + } + + calibration_dataloader = DataLoader(tokenized_calibration, **dataloader_params) + + return calibration_dataloader + + +def make_dataset_splits( + tokenized_datasets: Dict[str, Any], + do_oneshot: bool = True, + do_train: bool = False, +) -> Dict[str, Dataset]: + """ + Restructures the datasets dictionary based on what tasks will be run + train + :param tokenized_datasets: dictionary of processed datasets + :param do_oneshot: Whether to store the calibration dataset + :return: A dataset corresponding to either train or calibration (oneshot) + """ + + # handles case where all splits are contained in a single dataset + if "all" in tokenized_datasets and len(tokenized_datasets) == 1: + tokenized_datasets = tokenized_datasets.get("all") + if isinstance(tokenized_datasets, Dataset): + tokenized_datasets = {"train": tokenized_datasets} + + train_split = calib_split = None + + if do_train: + if "train" not in tokenized_datasets: + raise ValueError("--do_train requires a train dataset") + train_split = tokenized_datasets["train"] + if do_oneshot: + calib_split = tokenized_datasets.get("calibration") + if calib_split is None: + if "train" not in tokenized_datasets: + raise ValueError("--do_oneshot requires a calibration dataset") + calib_split = tokenized_datasets["train"] + + split_datasets = { + "train": train_split, + "calibration": calib_split, + } + return split_datasets diff --git a/src/llmcompressor/entrypoints/oneshot.py b/src/llmcompressor/entrypoints/oneshot.py index ea6481043..cfaf83f92 100644 --- a/src/llmcompressor/entrypoints/oneshot.py 
+++ b/src/llmcompressor/entrypoints/oneshot.py
@@ -7,9 +7,7 @@
 
 from llmcompressor.args import parse_args
 from llmcompressor.core.session_functions import active_session
-from llmcompressor.transformers.finetune.data.data_helpers import (
-    get_calibration_dataloader,
-)
+from llmcompressor.datasets import get_calibration_dataloader
 from llmcompressor.transformers.finetune.text_generation import (
     initialize_model_from_path,
     initialize_processor_from_path,
diff --git a/src/llmcompressor/transformers/finetune/data/data_helpers.py b/src/llmcompressor/transformers/finetune/data/data_helpers.py
index bd28de314..ff56cfbb9 100644
--- a/src/llmcompressor/transformers/finetune/data/data_helpers.py
+++ b/src/llmcompressor/transformers/finetune/data/data_helpers.py
@@ -1,74 +1,18 @@
 import logging
 import os
-import re
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Dict, Optional
 
-import torch
 from datasets import Dataset, load_dataset
-from loguru import logger
-from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
-from transformers.data import default_data_collator
 
 LOGGER = logging.getLogger(__name__)
 LABELS_MASK_VALUE = -100
 
 __all__ = [
-    "format_calibration_data",
     "get_raw_dataset",
-    "make_dataset_splits",
     "get_custom_datasets_from_path",
-    "get_calibration_dataloader",
 ]
 
 
-def format_calibration_data(
-    tokenized_dataset: Dataset,
-    num_calibration_samples: Optional[int] = None,
-    do_shuffle: bool = True,
-    collate_fn: Callable = default_data_collator,
-    accelerator: Optional[Any] = None,
-) -> List[torch.Tensor]:
-    """
-    Creates a dataloader out of the calibration dataset split, trimming it to
-    the desired number of calibration samples
-
-    :param tokenized_dataset: dataset to convert to dataloader
-    :param num_calibration_samples: number of data samples to convert
-    :param do_shuffle: whether to shuffle the dataset before selecting calibration
-        samples, true by default
-    :param collate_fn: optional custom collate function, or use default
-    :param accelerator: optional accelerator for if preparing in FSDP mode
-    :return: list of trimmed calibration data tensors
-    """
-    safe_calibration_samples = len(tokenized_dataset)
-    if num_calibration_samples is not None:
-        safe_calibration_samples = min(len(tokenized_dataset), num_calibration_samples)
-        if safe_calibration_samples != num_calibration_samples:
-            LOGGER.warn(
-                f"Requested {num_calibration_samples} calibration samples but "
-                f"the provided dataset only has {safe_calibration_samples}. "
" - ) - - if do_shuffle: - tokenized_dataset = tokenized_dataset.shuffle() - tokenized_calibration = tokenized_dataset.select(range(safe_calibration_samples)) - - dataloader_params = { - "batch_size": 1, - "sampler": RandomSampler(tokenized_calibration) - if do_shuffle - else SequentialSampler(tokenized_calibration), - "collate_fn": collate_fn, - "pin_memory": True, - } - - calib_dataloader = DataLoader(tokenized_calibration, **dataloader_params) - if accelerator: - calib_dataloader = accelerator.prepare(calib_dataloader) - - return calib_dataloader - - def get_raw_dataset( dataset_args, cache_dir: Optional[str] = None, @@ -94,47 +38,6 @@ def get_raw_dataset( return raw_datasets -def make_dataset_splits( - tokenized_datasets: Dict[str, Any], - do_train: bool = False, - do_oneshot: bool = False, -) -> Dict[str, Dataset]: - """ - Restructures the datasets dictionary based on what tasks will be run - train - - :param tokenized_datasets: dictionary of processed datasets - :param do_oneshot: Whether to store the calibration dataset - - :return: Datasets to be used by the requested tasks - """ - - # handles case where all splits are contained in a single dataset - if "all" in tokenized_datasets and len(tokenized_datasets) == 1: - tokenized_datasets = tokenized_datasets.get("all") - if isinstance(tokenized_datasets, Dataset): - tokenized_datasets = {"train": tokenized_datasets} - - train_split = calib_split = None - - if do_train: - if "train" not in tokenized_datasets: - raise ValueError("--do_train requires a train dataset") - train_split = tokenized_datasets["train"] - if do_oneshot: - calib_split = tokenized_datasets.get("calibration") - if calib_split is None: - if "train" not in tokenized_datasets: - raise ValueError("--do_oneshot requires a calibration dataset") - calib_split = tokenized_datasets["train"] - - split_datasets = { - "train": train_split, - "calibration": calib_split, - } - return split_datasets - - def get_custom_datasets_from_path(path: str, ext: str = "json") -> Dict[str, str]: """ Get a dictionary of custom datasets from a directory path. Support HF's load_dataset @@ -232,78 +135,3 @@ def do_transform(candidate: str) -> bool: transform_dataset_key(dataset_key) return data_files - - -def get_calibration_dataloader( - dataset_args, - processor, - add_labels: bool = False, # for oneshot - do_oneshot=True, -) -> torch.utils.data.DataLoader: - """ - Loads datasets for each flow based on dataset_args, stores a Dataset for each - enabled flow in self.datasets - - :param processor: processor or tokenizer to use for dataset tokenization - :param add_labels: if True, add labels column to dataset splits - """ - if dataset_args.dataset is None: - logger.info( - "Running oneshot without calibration data. 
-            "weight-only and dynamic quantization"
-        )
-        return
-
-    splits = dataset_args.splits
-    tokenized_datasets = {}
-
-    def _get_split_name(inp_str):
-        # strip out split name, for ex train[60%:] -> train
-        match = re.match(r"(\w*)\[.*\]", inp_str)
-        if match is not None:
-            return match.group(1)
-        return inp_str
-
-    if splits is None:
-        splits = {"all": None}
-    elif isinstance(splits, str):
-        splits = {_get_split_name(splits): splits}
-    elif isinstance(splits, List):
-        splits = {_get_split_name(s): s for s in splits}
-
-    # default to custom dataset if dataset provided isn't a string
-    registry_id = (
-        dataset_args.dataset if isinstance(dataset_args.dataset, str) else "custom"
-    )
-    for split_name, split_str in splits.items():
-        dataset = dataset_args.dataset
-        if hasattr(dataset, "column_names") and "input_ids" in dataset.column_names:
-            # dataset is already tokenized
-            tokenized_datasets[split_name] = dataset
-        else:
-            # dataset needs to be tokenized
-            from llmcompressor.transformers.finetune.data.base import (
-                TextGenerationDataset,
-            )
-
-            dataset_manager = TextGenerationDataset.load_from_registry(
-                registry_id,
-                dataset_args=dataset_args,
-                split=split_str,
-                processor=processor,
-            )
-            tokenized_datasets[split_name] = dataset_manager(add_labels=add_labels)
-
-    datasets = make_dataset_splits(
-        tokenized_datasets,
-        do_oneshot=do_oneshot,
-    )
-
-    calibration_dataset = datasets.get("calibration")
-
-    return format_calibration_data(
-        tokenized_dataset=calibration_dataset,
-        num_calibration_samples=dataset_args.num_calibration_samples,
-        do_shuffle=dataset_args.shuffle_calibration_samples,
-        collate_fn=dataset_args.data_collator,
-    )
diff --git a/src/llmcompressor/transformers/finetune/runner.py b/src/llmcompressor/transformers/finetune/runner.py
index 75d963aa5..b45153b4f 100644
--- a/src/llmcompressor/transformers/finetune/runner.py
+++ b/src/llmcompressor/transformers/finetune/runner.py
@@ -23,10 +23,6 @@
 )
 from llmcompressor.recipe import Recipe, StageRunType
 from llmcompressor.transformers.finetune.data import TextGenerationDataset
-from llmcompressor.transformers.finetune.data.data_helpers import (
-    format_calibration_data,
-    make_dataset_splits,
-)
 from llmcompressor.typing import Processor
 
 
@@ -119,6 +115,8 @@ def _get_split_name(inp_str):
             )
             tokenized_datasets[split_name] = dataset_manager(add_labels=add_labels)
 
+        from llmcompressor.datasets import make_dataset_splits
+
         self.datasets = make_dataset_splits(
             tokenized_datasets,
             do_train=self._training_args.do_train,
@@ -164,6 +162,7 @@ def run_sequential_stages(
         :param checkpoint: optional checkpoint to pick up a stage from
         """
+
         recipe_obj = Recipe.create_instance(self._recipe_args.recipe)
 
         with self.trainer.accelerator.main_process_first():
             checkpoint_dir = self._model_args.model
@@ -197,6 +196,7 @@ def run_sequential_stages(
             # run stage
             if run_type is StageRunType.ONESHOT:
                 from llmcompressor import Oneshot
+                from llmcompressor.datasets import format_calibration_data
 
                 self._model_args.model = model
 
@@ -213,7 +213,6 @@ def run_sequential_stages(
                     num_calibration_samples=self._dataset_args.num_calibration_samples,
                     do_shuffle=self._dataset_args.shuffle_calibration_samples,
                     collate_fn=self._dataset_args.data_collator,
-                    accelerator=self.trainer.accelerator,
                 )
 
                 if do_preprocess:
diff --git a/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py b/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py
index 39165ffe6..a7138b186 100644
--- a/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py
+++ b/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py
@@ -1,10 +1,8 @@
 import pytest
 
 from llmcompressor.args import DatasetArguments
-from llmcompressor.transformers.finetune.data.data_helpers import (
-    get_raw_dataset,
-    make_dataset_splits,
-)
+from llmcompressor.datasets import make_dataset_splits
+from llmcompressor.transformers.finetune.data.data_helpers import get_raw_dataset
 
 
 @pytest.mark.unit
diff --git a/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py b/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py
index 7198e0da3..3fc174acb 100644
--- a/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py
+++ b/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py
@@ -11,10 +11,8 @@
     RecipeArguments,
     TrainingArguments,
 )
+from llmcompressor.datasets import format_calibration_data
 from llmcompressor.transformers import TextGenerationDataset
-from llmcompressor.transformers.finetune.data.data_helpers import (
-    format_calibration_data,
-)
 from llmcompressor.transformers.finetune.runner import StageRunner
 
 
diff --git a/tests/llmcompressor/transformers/obcq/test_obcq_owl.py b/tests/llmcompressor/transformers/obcq/test_obcq_owl.py
index 4948c6da3..17effeb7a 100644
--- a/tests/llmcompressor/transformers/obcq/test_obcq_owl.py
+++ b/tests/llmcompressor/transformers/obcq/test_obcq_owl.py
@@ -3,10 +3,8 @@
 from datasets import Dataset
 from transformers import AutoModelForCausalLM
 
+from llmcompressor.datasets import format_calibration_data
 from llmcompressor.modifiers.obcq import SparseGPTModifier
-from llmcompressor.transformers.finetune.data.data_helpers import (
-    format_calibration_data,
-)
 from llmcompressor.utils.pytorch.module import get_layers
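
Not part of the diff above: a minimal usage sketch of the relocated helpers in llmcompressor.datasets, showing the split handling and dataloader creation that the new module exposes. The toy pre-tokenized dataset, the sample count, and the variable names below are assumptions for illustration only; in practice calibration data flows through get_calibration_dataloader with DatasetArguments and a processor.

# Usage sketch (illustrative only; not part of this change).
from datasets import Dataset

from llmcompressor.datasets import format_calibration_data, make_dataset_splits

# A tiny pre-tokenized dataset standing in for a real calibration split.
tokenized = Dataset.from_dict({"input_ids": [[1, 2, 3], [4, 5, 6], [7, 8, 9]]})

# With a single "all" split, make_dataset_splits treats it as "train" and, when
# do_oneshot=True, reuses it as the calibration split.
splits = make_dataset_splits({"all": tokenized}, do_oneshot=True, do_train=False)

# Trim to two samples and wrap them in a batch_size=1 DataLoader.
dataloader = format_calibration_data(
    tokenized_dataset=splits["calibration"],
    num_calibration_samples=2,
    do_shuffle=False,
)

for batch in dataloader:
    print(batch["input_ids"].shape)  # torch.Size([1, 3])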