Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ possibly-missing-attribute = "ignore" # mypy is more permissive with attribute a
possibly-missing-import = "ignore" # mypy is more permissive with imports
no-matching-overload = "ignore" # mypy is more permissive with overloads
unresolved-reference = "ignore" # mypy is more permissive with references
missing-argument = "ignore"

[tool.coverage.run]
source = ["src/pruna"]
Expand Down
19 changes: 19 additions & 0 deletions src/pruna/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Any, Callable, Tuple

from pruna.data.datasets.audio import (
Expand Down Expand Up @@ -77,6 +78,24 @@
"image_classification_collate",
{"img_size": 32},
),
# our full CIFAR10 has 50k train and 10k test
"TinyCIFAR10": (
partial(setup_cifar10_dataset, train_sample_size=800, test_sample_size=100),
"image_classification_collate",
{"img_size": 32},
),
# our full MNIST has 60k train and 10k test
"TinyMNIST": (
partial(setup_mnist_dataset, train_sample_size=800, test_sample_size=100),
"image_classification_collate",
{"img_size": 28},
),
# our full ImageNet has 100k train and 10k val
"TinyImageNet": (
partial(setup_imagenet_dataset, train_sample_size=1000, test_sample_size=100),
"image_classification_collate",
{"img_size": 224},
),
"DrawBench": (setup_drawbench_dataset, "prompt_collate", {}),
"PartiPrompts": (setup_parti_prompts_dataset, "prompt_collate", {}),
"GenAIBench": (setup_genai_bench_dataset, "prompt_collate", {}),
Expand Down
72 changes: 66 additions & 6 deletions src/pruna/data/datasets/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Tuple

from datasets import load_dataset
from torch.utils.data import Dataset

from pruna.data.utils import split_train_into_train_val, split_val_into_val_test
from pruna.data.utils import (
define_sample_size_for_dataset,
split_train_into_train_val,
split_val_into_val_test,
stratify_dataset,
)


def setup_mnist_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
def setup_mnist_dataset(
seed: int, fraction: float = 1.0, train_sample_size: int | None = None, test_sample_size: int | None = None
) -> Tuple[Dataset, Dataset, Dataset]:
"""
Setup the MNIST dataset.

Expand All @@ -30,18 +39,35 @@ def setup_mnist_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
----------
seed : int
The seed to use.
fraction : float
The fraction of the dataset to use.
train_sample_size : int | None
The sample size to use for the train dataset.
test_sample_size : int | None
The sample size to use for the test dataset.

Returns
-------
Tuple[Dataset, Dataset, Dataset]
The MNIST dataset.
"""
train_ds, test_ds = load_dataset("ylecun/mnist", split=["train", "test"]) # type: ignore[misc]

train_sample_size = define_sample_size_for_dataset(train_ds, fraction, train_sample_size)
test_sample_size = define_sample_size_for_dataset(test_ds, fraction, test_sample_size)

train_ds = stratify_dataset(train_ds, train_sample_size, seed)
test_ds = stratify_dataset(test_ds, test_sample_size, seed)

train_ds, val_ds = split_train_into_train_val(train_ds, seed)
val_ds, test_ds = split_val_into_val_test(val_ds, seed)

return train_ds, val_ds, test_ds # type: ignore[return-value]


def setup_imagenet_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
def setup_imagenet_dataset(
seed: int, fraction: float = 1.0, train_sample_size: int | None = None, test_sample_size: int | None = None
) -> Tuple[Dataset, Dataset, Dataset]:
"""
Setup the ImageNet dataset.

Expand All @@ -51,33 +77,67 @@ def setup_imagenet_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
----------
seed : int
The seed to use.
fraction : float
The fraction of the dataset to use.
train_sample_size : int | None
The sample size to use for the train dataset.
test_sample_size : int | None
The sample size to use for the test dataset.

Returns
-------
Tuple[Dataset, Dataset, Dataset]
The ImageNet dataset.
"""
train_ds, val = load_dataset("zh-plus/tiny-imagenet", split=["train", "valid"]) # type: ignore[misc]
train_sample_size = define_sample_size_for_dataset(train_ds, fraction, train_sample_size)
train_ds = stratify_dataset(train_ds, train_sample_size, seed)
val_ds, test_ds = split_val_into_val_test(val, seed)
test_sample_size = define_sample_size_for_dataset(test_ds, fraction, test_sample_size)
test_ds = stratify_dataset(test_ds, test_sample_size, seed)
return train_ds, val_ds, test_ds # type: ignore[return-value]


def setup_cifar10_dataset(
    seed: int, fraction: float = 1.0, train_sample_size: int | None = None, test_sample_size: int | None = None
) -> Tuple[Dataset, Dataset, Dataset]:
    """
    Setup the CIFAR-10 dataset.

    The original CIFAR-10 dataset from uoft-cs/cifar10 has an 'img' column,
    but this function renames it to 'image' to ensure compatibility with
    the image_classification_collate function which expects an 'image' column.

    License: unspecified

    Parameters
    ----------
    seed : int
        The seed to use.
    fraction : float
        The fraction of the dataset to use.
    train_sample_size : int | None
        The sample size to use for the train dataset.
    test_sample_size : int | None
        The sample size to use for the test dataset.

    Returns
    -------
    Tuple[Dataset, Dataset, Dataset]
        The CIFAR-10 dataset with columns: 'image' (PIL Image) and 'label' (int).

    Raises
    ------
    ValueError
        If ``fraction`` is below 1.0 and a sample size is given as well.
    """
    train_ds, test_ds = load_dataset("uoft-cs/cifar10", split=["train", "test"])

    # Rename 'img' column to 'image' to match collate function expectations
    train_ds = train_ds.rename_column("img", "image")
    test_ds = test_ds.rename_column("img", "image")

    # Resolve the requested sizes first so that conflicting arguments fail
    # before any subsampling work is done.
    train_sample_size = define_sample_size_for_dataset(train_ds, fraction, train_sample_size)
    test_sample_size = define_sample_size_for_dataset(test_ds, fraction, test_sample_size)

    train_ds = stratify_dataset(train_ds, train_sample_size, seed)
    test_ds = stratify_dataset(test_ds, test_sample_size, seed)

    # The validation set is carved out of the (already subsampled) train split.
    train_ds, val_ds = split_train_into_train_val(train_ds, seed)
    return train_ds, val_ds, test_ds  # type: ignore[return-value]
58 changes: 58 additions & 0 deletions src/pruna/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

import random
from typing import Any, Tuple, Union

import torch
Expand Down Expand Up @@ -181,3 +182,60 @@ def recover_text_from_dataloader(dataloader: DataLoader, tokenizer: Any) -> list
raise ValueError()
texts.extend(out)
return texts


def stratify_dataset(dataset: Dataset, sample_size: int, seed: int = 42) -> Dataset:
    """
    Subsample the dataset down to a specific size.

    Despite the name, no per-class balancing is performed: the indices are
    shuffled with a seeded random generator and the first ``sample_size`` of
    them are selected.

    Parameters
    ----------
    dataset : Dataset
        The dataset to subsample. It must support ``len()`` and ``select()``.
    sample_size : int
        The number of samples to keep.
    seed : int
        The seed to use for sampling the dataset.

    Returns
    -------
    Dataset
        The subsampled dataset, or the unchanged input dataset when it holds
        fewer than ``sample_size`` samples.

    Raises
    ------
    ValueError
        If ``sample_size`` is negative.
    """
    # A negative size would silently slice from the end (indices[:-k]) and
    # return an unexpected number of samples, so reject it explicitly.
    if sample_size < 0:
        raise ValueError(f"Sample size must be non-negative, got {sample_size}.")

    dataset_length = len(dataset)
    if dataset_length < sample_size:
        pruna_logger.warning(
            "Dataset length is less than the size to stratify. "
            f"Using the entire dataset. ({dataset_length} < {sample_size})"
        )
        return dataset

    # A dedicated Random instance makes the draw reproducible for a given seed
    # without touching the global random state.
    indices = list(range(dataset_length))
    random.Random(seed).shuffle(indices)
    return dataset.select(indices[:sample_size])


def define_sample_size_for_dataset(dataset: Dataset, fraction: float, sample_size: int | None = None) -> int:
    """
    Define the sample size for the dataset.

    Either a ``fraction`` below 1.0 or an explicit ``sample_size`` may be
    given, not both. With neither, the full dataset length is returned.

    Parameters
    ----------
    dataset : Dataset
        The dataset to define the sample size for. Only its length is used.
    fraction : float
        The fraction of the dataset to sample. Values of 1.0 or above mean
        that no fractional sampling is applied.
    sample_size : int | None
        The sample size to use. ``None`` (or 0) falls back to the full
        dataset length.

    Returns
    -------
    int
        The sample size for the dataset.

    Raises
    ------
    ValueError
        If ``fraction`` or ``sample_size`` is negative, or if ``fraction`` is
        below 1.0 while ``sample_size`` is also given.
    """
    # Negative inputs would produce a negative size, which downstream slicing
    # interprets as "drop from the end" instead of failing.
    if fraction < 0:
        raise ValueError(f"Fraction must be non-negative, got {fraction}.")
    if sample_size is not None and sample_size < 0:
        raise ValueError(f"Sample size must be non-negative, got {sample_size}.")

    if fraction < 1.0:
        if sample_size is not None:
            raise ValueError("Fraction and sample sizes cannot be used together.")
        # Truncates towards zero, so small datasets may round down to 0.
        return int(len(dataset) * fraction)

    # `or` is kept on purpose: a sample size of 0 has always been treated as
    # "not provided" and callers may rely on getting the full length back.
    return sample_size or len(dataset)
3 changes: 3 additions & 0 deletions tests/data/test_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ def iterate_dataloaders(datamodule: PrunaDataModule) -> None:
pytest.param("LibriSpeech", dict(), marks=pytest.mark.slow),
pytest.param("AIPodcast", dict(), marks=pytest.mark.slow),
("ImageNet", dict(img_size=512)),
("TinyCIFAR10", dict(img_size=32)),
("TinyImageNet", dict(img_size=224)),
("TinyMNIST", dict(img_size=28)),
pytest.param("MNIST", dict(img_size=512), marks=pytest.mark.slow),
("WikiText", dict(tokenizer=bert_tokenizer)),
pytest.param("TinyWikiText", dict(tokenizer=bert_tokenizer), marks=pytest.mark.slow),
Expand Down