
Commit

WIP
Signed-off-by: Kyle Sayers <[email protected]>
kylesayrs committed Jan 30, 2025
1 parent 999d660 commit 29f93d3
Showing 5 changed files with 60 additions and 10 deletions.
11 changes: 6 additions & 5 deletions src/llmcompressor/transformers/finetune/data/data_args.py
@@ -1,8 +1,6 @@
 from dataclasses import dataclass, field
 from typing import Any, Callable, Dict, List, Optional, Union
 
-from transformers import DefaultDataCollator
-
 
 @dataclass
 class DVCDatasetTrainingArguments:
@@ -60,9 +58,12 @@ class CustomDataTrainingArguments(DVCDatasetTrainingArguments):
         },
     )
 
-    data_collator: Callable[[Any], Any] = field(
-        default_factory=lambda: DefaultDataCollator(),
-        metadata={"help": "The function to used to form a batch from the dataset"},
+    data_collator: Optional[Callable[[Any], Any]] = field(
+        default=None,
+        metadata={
+            "help": "The function used to form a batch from the dataset. Defaults "
+            "to `DataCollatorWithPadding` with the model tokenizer if None is provided"
+        },
     )


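
As a reference for callers, a minimal sketch of the new default behavior (the `DataTrainingArguments` constructor arguments and the "open_platypus" dataset name are illustrative assumptions, not part of this diff):

from transformers import DefaultDataCollator

from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments

# Leaving data_collator unset (None) defers collation to format_calibration_data,
# which builds a DataCollatorWithPadding from the model's tokenizer.
default_args = DataTrainingArguments(dataset="open_platypus")

# Passing an explicit collator keeps the previous behavior.
explicit_args = DataTrainingArguments(
    dataset="open_platypus",
    data_collator=DefaultDataCollator(),
)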
32 changes: 27 additions & 5 deletions src/llmcompressor/transformers/finetune/data/data_helpers.py
@@ -1,11 +1,17 @@
 import logging
 import os
+import warnings
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
 from datasets import Dataset, load_dataset
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
-from transformers.data import default_data_collator
+from transformers.data.data_collator import (
+    DataCollatorWithPadding,
+    default_data_collator,
+)
+
+from llmcompressor.typing import Processor
 
 LOGGER = logging.getLogger(__name__)
 LABELS_MASK_VALUE = -100
@@ -21,7 +27,9 @@
 def format_calibration_data(
     tokenized_dataset: Dataset,
     num_calibration_samples: Optional[int] = None,
+    batch_size: int = 1,
     do_shuffle: bool = True,
+    processor: Optional[Processor] = None,
     collate_fn: Callable = default_data_collator,
     accelerator: Optional[Any] = None,
 ) -> List[torch.Tensor]:
@@ -37,6 +45,11 @@ def format_calibration_data(
     :param accelerator: optional accelerator for if preparing in FSDP mode
     :return: list of trimmed calibration data tensors
     """
+    # shuffle
+    if do_shuffle:
+        tokenized_dataset = tokenized_dataset.shuffle()
+
+    # truncate samples
     safe_calibration_samples = len(tokenized_dataset)
     if num_calibration_samples is not None:
         safe_calibration_samples = min(len(tokenized_dataset), num_calibration_samples)
@@ -45,13 +58,22 @@
                 f"Requested {num_calibration_samples} calibration samples but "
                 f"the provided dataset only has {safe_calibration_samples}. "
             )
-
-    if do_shuffle:
-        tokenized_dataset = tokenized_dataset.shuffle()
     tokenized_calibration = tokenized_dataset.select(range(safe_calibration_samples))
 
+    # collate data
+    if collate_fn is None:
+        tokenizer = getattr(processor, "tokenizer", processor)
+        if tokenizer is None:
+            warnings.warn(
+                "Could not find processor, attempting to collate without padding "
+                "(may fail for batch_size > 1)"
+            )
+            collate_fn = default_data_collator
+        else:
+            collate_fn = DataCollatorWithPadding(tokenizer)
+
     dataloader_params = {
-        "batch_size": 1,
+        "batch_size": batch_size,
         "sampler": RandomSampler(tokenized_calibration)
         if do_shuffle
         else SequentialSampler(tokenized_calibration),
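
To illustrate the new fallback path end to end, a usage sketch (the gpt2 tokenizer and the pad-token assignment are placeholder choices for this example, not part of the change):

from datasets import Dataset
from transformers import AutoTokenizer

from llmcompressor.transformers.finetune.data.data_helpers import format_calibration_data

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

# Variable-length samples: batch_size > 1 only works if the collator pads
samples = ["a short calibration sample", "a slightly longer calibration sample"]
tokenized = Dataset.from_dict(dict(tokenizer(samples)))

# collate_fn=None triggers the new behavior: DataCollatorWithPadding is built
# from the provided processor (or tokenizer) before the DataLoader is created.
loader = format_calibration_data(
    tokenized_dataset=tokenized,
    num_calibration_samples=2,
    batch_size=2,
    processor=tokenizer,
    collate_fn=None,
)
batch = next(iter(loader))  # "input_ids" padded to a common length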
2 changes: 2 additions & 0 deletions src/llmcompressor/transformers/finetune/runner.py
@@ -144,8 +144,10 @@ def one_shot(self, stage: Optional[str] = None):
         calib_data = format_calibration_data(
             tokenized_dataset=self.get_dataset_split("calibration"),
             num_calibration_samples=self._data_args.num_calibration_samples,
+            batch_size=self._training_args.per_device_oneshot_batch_size,
             do_shuffle=self._data_args.shuffle_calibration_samples,
             collate_fn=self._data_args.data_collator,
+            processor=self.processor,
             accelerator=self.trainer.accelerator,
         )

6 changes: 6 additions & 0 deletions src/llmcompressor/transformers/finetune/training_args.py
@@ -32,6 +32,12 @@ class TrainingArguments(HFTrainingArgs):
             )
         },
     )
+    per_device_oneshot_batch_size: int = field(
+        default=1,
+        metadata={
+            "help": "The batch size per GPU/XPU/TPU/MPS/NPU core/CPU for oneshot"
+        },
+    )
     save_compressed: Optional[bool] = field(
         default=True,
         metadata={"help": "Whether to compress sparse models during save"},
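
A short sketch of setting the new argument (output_dir is required by the underlying HF TrainingArguments; the value here is a placeholder):

from llmcompressor.transformers.finetune.training_args import TrainingArguments

# Oneshot calibration runs with batches of 4 samples per device;
# per_device_train_batch_size still controls finetuning batches.
args = TrainingArguments(
    output_dir="./oneshot_output",
    per_device_oneshot_batch_size=4,
)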
(fifth changed file: tests for the finetune data helpers; file path not shown)
@@ -1,7 +1,11 @@
+# TODO: rename to `test_data_helpers.py`
 import pytest
+import torch
+from datasets import Dataset
 
 from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments
 from llmcompressor.transformers.finetune.data.data_helpers import (
+    format_calibration_data,
     get_raw_dataset,
     make_dataset_splits,
 )
@@ -53,3 +57,18 @@ def test_separate_datasets():
     split_datasets = make_dataset_splits(
         datasets, do_train=True, do_eval=True, do_predict=True
     )
+
+
+@pytest.mark.unit
+def test_format_calibration_data():
+    tokenized_dataset = Dataset.from_dict(
+        {"input_ids": torch.randint(0, 512, (8, 2048))}
+    )
+
+    calibration_dataloader = format_calibration_data(
+        tokenized_dataset, num_calibration_samples=4, batch_size=2
+    )
+
+    batch = next(iter(calibration_dataloader))
+
+    assert batch["input_ids"].size(0) == 2
