1 change: 1 addition & 0 deletions pyproject.toml
@@ -35,6 +35,7 @@
possibly-missing-attribute = "ignore" # mypy is more permissive with attribute access
possibly-missing-import = "ignore" # mypy is more permissive with imports
no-matching-overload = "ignore" # mypy is more permissive with overloads
unresolved-reference = "ignore" # mypy is more permissive with references
missing-argument = "ignore"

[tool.coverage.run]
source = ["src/pruna"]
4 changes: 4 additions & 0 deletions src/pruna/data/__init__.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Any, Callable, Tuple

from pruna.data.datasets.audio import (
@@ -77,6 +78,9 @@
"image_classification_collate",
{"img_size": 32},
),
"TinyCIFAR10": (partial(setup_cifar10_dataset, fraction=0.1), "image_classification_collate", {"img_size": 32}),
"TinyMNIST": (partial(setup_mnist_dataset, fraction=0.1), "image_classification_collate", {"img_size": 28}),
"TinyImageNet": (partial(setup_imagenet_dataset, fraction=0.1), "image_classification_collate", {"img_size": 224}),

Reviewer:

Awesome! Just not 100% sure if a fraction for each of these datasets is small enough, and whether it is clear how many samples we get now. We could also allow a range/number or something. Not sure if that would be better, but otherwise we can keep it like this.

Member Author:

I completely see your point here. I only used fractions since we have limit_datasets in PrunaDataModule, which lets us pass a number to limit the dataset. If you still think having a number rather than a fraction makes more sense here, I am happy to change it. What do you think?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think setting it to fixed numbers is nicer as we have more control and awareness surrounding the number.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right, I have also added this feature 🧡🧡

"DrawBench": (setup_drawbench_dataset, "prompt_collate", {}),
"PartiPrompts": (setup_parti_prompts_dataset, "prompt_collate", {}),
"GenAIBench": (setup_genai_bench_dataset, "prompt_collate", {}),
42 changes: 36 additions & 6 deletions src/pruna/data/datasets/image.py
@@ -17,10 +17,10 @@
from datasets import load_dataset
from torch.utils.data import Dataset

from pruna.data.utils import split_train_into_train_val, split_val_into_val_test
from pruna.data.utils import split_train_into_train_val, split_val_into_val_test, stratify_dataset


def setup_mnist_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
def setup_mnist_dataset(seed: int, fraction: float = 1.0) -> Tuple[Dataset, Dataset, Dataset]:
"""
Setup the MNIST dataset.

@@ -31,17 +31,26 @@ def setup_mnist_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
seed : int
The seed to use.

fraction : float
The fraction of the dataset to use.

Returns
-------
Tuple[Dataset, Dataset, Dataset]
The MNIST dataset.
"""
train_ds, test_ds = load_dataset("ylecun/mnist", split=["train", "test"]) # type: ignore[misc]

train_ds = stratify_dataset(train_ds, "label", fraction, seed)
test_ds = stratify_dataset(test_ds, "label", fraction, seed)

train_ds, val_ds = split_train_into_train_val(train_ds, seed)
val_ds, test_ds = split_val_into_val_test(val_ds, seed)

return train_ds, val_ds, test_ds # type: ignore[return-value]


def setup_imagenet_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
def setup_imagenet_dataset(seed: int, fraction: float = 1.0) -> Tuple[Dataset, Dataset, Dataset]:
"""
Setup the ImageNet dataset.

@@ -52,32 +61,53 @@ def setup_imagenet_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
seed : int
The seed to use.

fraction : float
The fraction of the dataset to use.

Returns
-------
Tuple[Dataset, Dataset, Dataset]
The ImageNet dataset.
"""
train_ds, val = load_dataset("zh-plus/tiny-imagenet", split=["train", "valid"]) # type: ignore[misc]
train_ds = stratify_dataset(train_ds, "label", fraction, seed)
val = stratify_dataset(val, "label", fraction, seed)
val_ds, test_ds = split_val_into_val_test(val, seed)
return train_ds, val_ds, test_ds # type: ignore[return-value]


def setup_cifar10_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
def setup_cifar10_dataset(seed: int, fraction: float = 1.0) -> Tuple[Dataset, Dataset, Dataset]:
"""
Setup the CIFAR-10 dataset.

The original CIFAR-10 dataset from uoft-cs/cifar10 has an 'img' column,
but this function renames it to 'image' to ensure compatibility with
the image_classification_collate function, which expects an 'image' column.

License: unspecified

Parameters
----------
seed : int
The seed to use.

fraction : float
The fraction of the dataset to use.

Returns
-------
Tuple[Dataset, Dataset, Dataset]
The CIFAR-10 dataset.
The CIFAR-10 dataset with columns: 'image' (PIL Image) and 'label' (int).
"""
train_ds, test_ds = load_dataset("uoft-cs/cifar10", split=["train", "test"]) # type: ignore[misc]
train_ds, test_ds = load_dataset("uoft-cs/cifar10", split=["train", "test"])

# Rename the 'img' column to 'image' to match image_classification_collate expectations
train_ds = train_ds.rename_column("img", "image")
test_ds = test_ds.rename_column("img", "image")

train_ds = stratify_dataset(train_ds, "label", fraction, seed)
test_ds = stratify_dataset(test_ds, "label", fraction, seed)

train_ds, val_ds = split_train_into_train_val(train_ds, seed)
return train_ds, val_ds, test_ds # type: ignore[return-value]
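
A quick way to sanity-check the new fraction parameter is the usage sketch below. It assumes the module path follows the file layout and that the uoft-cs/cifar10 download is available; exact split sizes depend on the train/val split helpers.

```python
from pruna.data.datasets.image import setup_cifar10_dataset

# CIFAR-10's train split has 50,000 samples; fraction=0.1 keeps a stratified
# ~5,000 of them before the train/val split, and ~1,000 of the 10,000 test samples.
train_ds, val_ds, test_ds = setup_cifar10_dataset(seed=42, fraction=0.1)
print(len(train_ds), len(val_ds), len(test_ds))
```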
27 changes: 27 additions & 0 deletions src/pruna/data/utils.py
@@ -181,3 +181,30 @@ def recover_text_from_dataloader(dataloader: DataLoader, tokenizer: Any) -> list
raise ValueError()
texts.extend(out)
return texts


def stratify_dataset(dataset: Dataset, column: str, fraction: float, seed: int) -> Dataset:
"""
Subsample the dataset to a stratified fraction of its original size.

Parameters
----------
dataset : Dataset
The dataset to stratify.
column : str
The column to stratify by.
fraction : float
The fraction of the dataset to keep.
seed : int
The seed to use for splitting the dataset.

Returns
-------
Dataset
The stratified dataset.
"""
if fraction < 1.0:
split_result = dataset.train_test_split(test_size=1 - fraction, stratify_by_column=column, seed=seed)
dataset = split_result["train"]

return dataset
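
The Tiny* registry entries above pre-bind fraction=0.1 with functools.partial, so the setup functions keep the call signature the data module expects. A minimal standalone sketch of the same stratified subsampling, assuming a Hugging Face dataset whose label column is a ClassLabel feature (required by stratify_by_column):

```python
from datasets import load_dataset

from pruna.data.utils import stratify_dataset

# Keep a stratified 10% of MNIST's 60,000 training samples; per-class
# proportions are preserved by the stratified train_test_split.
train_ds = load_dataset("ylecun/mnist", split="train")
tiny_train = stratify_dataset(train_ds, column="label", fraction=0.1, seed=0)
print(f"kept {len(tiny_train)} of {len(train_ds)} samples")
```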
3 changes: 3 additions & 0 deletions tests/data/test_datamodule.py
@@ -28,6 +28,9 @@ def iterate_dataloaders(datamodule: PrunaDataModule) -> None:
pytest.param("LibriSpeech", dict(), marks=pytest.mark.slow),
pytest.param("AIPodcast", dict(), marks=pytest.mark.slow),
("ImageNet", dict(img_size=512)),
("TinyCIFAR10", dict(img_size=32)),
("TinyImageNet", dict(img_size=224)),
("TinyMNIST", dict(img_size=28)),
pytest.param("MNIST", dict(img_size=512), marks=pytest.mark.slow),
("WikiText", dict(tokenizer=bert_tokenizer)),
pytest.param("TinyWikiText", dict(tokenizer=bert_tokenizer), marks=pytest.mark.slow),