diff --git a/src/llmcompressor/transformers/finetune/data/data_args.py b/src/llmcompressor/transformers/finetune/data/data_args.py
index 7d0bc14ce..6154be4c2 100644
--- a/src/llmcompressor/transformers/finetune/data/data_args.py
+++ b/src/llmcompressor/transformers/finetune/data/data_args.py
@@ -1,8 +1,6 @@
 from dataclasses import dataclass, field
 from typing import Any, Callable, Dict, List, Optional, Union
 
-from transformers import DefaultDataCollator
-
 
 @dataclass
 class DVCDatasetTrainingArguments:
@@ -60,9 +58,12 @@ class CustomDataTrainingArguments(DVCDatasetTrainingArguments):
         },
     )
 
-    data_collator: Callable[[Any], Any] = field(
-        default_factory=lambda: DefaultDataCollator(),
-        metadata={"help": "The function to used to form a batch from the dataset"},
+    data_collator: Optional[Callable[[Any], Any]] = field(
+        default=None,
+        metadata={
+            "help": "The function used to form a batch from the dataset. Defaults to "
+            "`DataCollatorWithPadding` with the model tokenizer if None is provided"
+        },
     )
 
 
diff --git a/src/llmcompressor/transformers/finetune/data/data_helpers.py b/src/llmcompressor/transformers/finetune/data/data_helpers.py
index 23c70e561..88fce7595 100644
--- a/src/llmcompressor/transformers/finetune/data/data_helpers.py
+++ b/src/llmcompressor/transformers/finetune/data/data_helpers.py
@@ -1,11 +1,17 @@
 import logging
 import os
+import warnings
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
 from datasets import Dataset, load_dataset
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
-from transformers.data import default_data_collator
+from transformers.data.data_collator import (
+    DataCollatorWithPadding,
+    default_data_collator,
+)
+
+from llmcompressor.typing import Processor
 
 LOGGER = logging.getLogger(__name__)
 LABELS_MASK_VALUE = -100
@@ -21,7 +27,9 @@
 def format_calibration_data(
     tokenized_dataset: Dataset,
     num_calibration_samples: Optional[int] = None,
+    batch_size: int = 1,
     do_shuffle: bool = True,
+    processor: Optional[Processor] = None,
     collate_fn: Callable = default_data_collator,
     accelerator: Optional[Any] = None,
 ) -> List[torch.Tensor]:
@@ -37,6 +45,13 @@
     :param accelerator: optional accelerator for if preparing in FSDP mode
+    :param batch_size: batch size to use for the calibration dataloader
+    :param processor: processor used to pad batches when collate_fn is not provided
     :return: list of trimmed calibration data tensors
     """
+    # shuffle
+    if do_shuffle:
+        tokenized_dataset = tokenized_dataset.shuffle()
+
+    # truncate samples
     safe_calibration_samples = len(tokenized_dataset)
     if num_calibration_samples is not None:
         safe_calibration_samples = min(len(tokenized_dataset), num_calibration_samples)
@@ -45,13 +60,22 @@
             f"Requested {num_calibration_samples} calibration samples but "
             f"the provided dataset only has {safe_calibration_samples}. "
         )
-
-    if do_shuffle:
-        tokenized_dataset = tokenized_dataset.shuffle()
     tokenized_calibration = tokenized_dataset.select(range(safe_calibration_samples))
 
+    # collate data
+    if collate_fn is None:
+        tokenizer = getattr(processor, "tokenizer", processor)
+        if tokenizer is None:
+            warnings.warn(
+                "Could not find processor, attempting to collate without padding "
+                "(may fail for batch_size > 1)"
+            )
+            collate_fn = default_data_collator
+        else:
+            collate_fn = DataCollatorWithPadding(tokenizer)
+
     dataloader_params = {
-        "batch_size": 1,
+        "batch_size": batch_size,
         "sampler": RandomSampler(tokenized_calibration)
         if do_shuffle
         else SequentialSampler(tokenized_calibration),
diff --git a/src/llmcompressor/transformers/finetune/runner.py b/src/llmcompressor/transformers/finetune/runner.py
index 0a07c45eb..ce3b0ae49 100644
--- a/src/llmcompressor/transformers/finetune/runner.py
+++ b/src/llmcompressor/transformers/finetune/runner.py
@@ -144,8 +144,10 @@ def one_shot(self, stage: Optional[str] = None):
         calib_data = format_calibration_data(
             tokenized_dataset=self.get_dataset_split("calibration"),
             num_calibration_samples=self._data_args.num_calibration_samples,
+            batch_size=self._training_args.per_device_oneshot_batch_size,
             do_shuffle=self._data_args.shuffle_calibration_samples,
             collate_fn=self._data_args.data_collator,
+            processor=self.processor,
             accelerator=self.trainer.accelerator,
         )
 
diff --git a/src/llmcompressor/transformers/finetune/training_args.py b/src/llmcompressor/transformers/finetune/training_args.py
index c04fa2807..56d4eb6a4 100644
--- a/src/llmcompressor/transformers/finetune/training_args.py
+++ b/src/llmcompressor/transformers/finetune/training_args.py
@@ -32,6 +32,12 @@ class TrainingArguments(HFTrainingArgs):
             )
         },
     )
+    per_device_oneshot_batch_size: int = field(
+        default=1,
+        metadata={
+            "help": "The batch size per GPU/XPU/TPU/MPS/NPU core/CPU for oneshot"
+        },
+    )
     save_compressed: Optional[bool] = field(
         default=True,
         metadata={"help": "Whether to compress sparse models during save"},
diff --git a/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py b/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py
index 812b26a56..2c936a363 100644
--- a/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py
+++ b/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py
@@ -1,7 +1,11 @@
+# TODO: rename to `test_data_helpers.py`
 import pytest
+import torch
+from datasets import Dataset
 
 from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments
 from llmcompressor.transformers.finetune.data.data_helpers import (
+    format_calibration_data,
     get_raw_dataset,
     make_dataset_splits,
 )
@@ -53,3 +57,18 @@ def test_separate_datasets():
     split_datasets = make_dataset_splits(
         datasets, do_train=True, do_eval=True, do_predict=True
     )
+
+
+@pytest.mark.unit
+def test_format_calibration_data():
+    tokenized_dataset = Dataset.from_dict(
+        {"input_ids": torch.randint(0, 512, (8, 2048))}
+    )
+
+    calibration_dataloader = format_calibration_data(
+        tokenized_dataset, num_calibration_samples=4, batch_size=2
+    )
+
+    batch = next(iter(calibration_dataloader))
+
+    assert batch["input_ids"].size(0) == 2