diff --git a/pyproject.toml b/pyproject.toml index b7fb88e30..03baa6b6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ possibly-missing-attribute = "ignore" # mypy is more permissive with attribute a possibly-missing-import = "ignore" # mypy is more permissive with imports no-matching-overload = "ignore" # mypy is more permissive with overloads unresolved-reference = "ignore" # mypy is more permissive with references +missing-argument = "ignore" [tool.coverage.run] source = ["src/pruna"] diff --git a/src/pruna/data/__init__.py b/src/pruna/data/__init__.py index 195f8f459..3a811868d 100644 --- a/src/pruna/data/__init__.py +++ b/src/pruna/data/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial from typing import Any, Callable, Tuple from pruna.data.datasets.audio import ( @@ -77,6 +78,24 @@ "image_classification_collate", {"img_size": 32}, ), + # our full CIFAR10 has 50k train and 10k test + "TinyCIFAR10": ( + partial(setup_cifar10_dataset, train_sample_size=800, test_sample_size=100), + "image_classification_collate", + {"img_size": 32}, + ), + # our full MNIST has 60k train and 10k test + "TinyMNIST": ( + partial(setup_mnist_dataset, train_sample_size=800, test_sample_size=100), + "image_classification_collate", + {"img_size": 28}, + ), + # our full ImageNet has 100k train and 10k val + "TinyImageNet": ( + partial(setup_imagenet_dataset, train_sample_size=1000, test_sample_size=100), + "image_classification_collate", + {"img_size": 224}, + ), "DrawBench": (setup_drawbench_dataset, "prompt_collate", {}), "PartiPrompts": (setup_parti_prompts_dataset, "prompt_collate", {}), "GenAIBench": (setup_genai_bench_dataset, "prompt_collate", {}), diff --git a/src/pruna/data/datasets/image.py b/src/pruna/data/datasets/image.py index 32b9d957f..cdc02e86e 100644 --- a/src/pruna/data/datasets/image.py +++ b/src/pruna/data/datasets/image.py @@ -12,15 
# ---------------------------------------------------------------------------
# src/pruna/data/utils.py
#
# These two helpers rely on names expected to be in scope at module level in
# that file: `random`, `Dataset` and `pruna_logger`.
# ---------------------------------------------------------------------------


def stratify_dataset(dataset: Dataset, sample_size: int, seed: int = 42) -> Dataset:
    """
    Draw a seeded random subset of the dataset with a specific size.

    Despite the name, no label information is consulted: the indices are
    shuffled with a seeded generator and the first ``sample_size`` are kept.

    Parameters
    ----------
    dataset : Dataset
        The dataset to subsample. It must support ``len()`` and ``select(indices)``.
    sample_size : int
        The number of samples to keep.
    seed : int
        The seed to use for sampling the dataset.

    Returns
    -------
    Dataset
        The subsampled dataset, or the unchanged dataset (with a warning) when it
        holds fewer than ``sample_size`` samples.

    Raises
    ------
    ValueError
        If ``sample_size`` is negative.
    """
    # A negative size would otherwise silently slice `indices[:-k]` and keep
    # "all but k" samples, which is never what the caller meant.
    if sample_size < 0:
        raise ValueError(f"Sample size must be non-negative, got {sample_size}.")

    dataset_length = len(dataset)
    if dataset_length < sample_size:
        # The trailing space matters: the two literals are concatenated.
        pruna_logger.warning(
            "Dataset length is less than the size to stratify. "
            f"Using the entire dataset. ({dataset_length} < {sample_size})"
        )
        return dataset

    # Shuffle-then-slice is kept on purpose (instead of `random.sample`) so the
    # selected indices for a given seed stay exactly the same as before.
    indices = list(range(dataset_length))
    random.Random(seed).shuffle(indices)
    return dataset.select(indices[:sample_size])


def define_sample_size_for_dataset(dataset: Dataset, fraction: float, sample_size: int | None = None) -> int:
    """
    Define the sample size for the dataset.

    Parameters
    ----------
    dataset : Dataset
        The dataset to define the sample size for.
    fraction : float
        The fraction of the dataset to sample, between 0.0 and 1.0.
    sample_size : int | None
        The absolute sample size to use. Only allowed when ``fraction`` is 1.0.

    Returns
    -------
    int
        ``int(len(dataset) * fraction)`` when ``fraction`` is below 1.0, otherwise
        ``sample_size`` if it was given, otherwise ``len(dataset)``.

    Raises
    ------
    ValueError
        If ``fraction`` is outside ``[0.0, 1.0]``, if ``sample_size`` is negative,
        or if both a fraction below 1.0 and a sample size are given.
    """
    # Written as a chained comparison so that NaN is rejected as well.
    if not 0.0 <= fraction <= 1.0:
        raise ValueError(f"Fraction must be between 0.0 and 1.0, got {fraction}.")
    if sample_size is not None and sample_size < 0:
        raise ValueError(f"Sample size must be non-negative, got {sample_size}.")
    if fraction < 1.0 and sample_size is not None:
        raise ValueError("Fraction and sample sizes cannot be used together.")

    if fraction < 1.0:
        return int(len(dataset) * fraction)
    # `is not None` rather than truthiness: an explicit 0 must stay 0 instead
    # of silently falling back to the full dataset.
    if sample_size is not None:
        return sample_size
    return len(dataset)


# ---------------------------------------------------------------------------
# src/pruna/data/datasets/image.py
#
# These functions rely on names expected to be in scope at module level in
# that file: `Tuple`, `Dataset`, `load_dataset`, `split_train_into_train_val`,
# `split_val_into_val_test`, `define_sample_size_for_dataset` and
# `stratify_dataset`.
# ---------------------------------------------------------------------------


def _subsample(dataset: Dataset, fraction: float, sample_size: int | None, seed: int) -> Dataset:
    """Resolve the target size for ``dataset`` and draw the seeded subset of that size."""
    resolved_size = define_sample_size_for_dataset(dataset, fraction, sample_size)
    return stratify_dataset(dataset, resolved_size, seed)


def setup_mnist_dataset(
    seed: int, fraction: float = 1.0, train_sample_size: int | None = None, test_sample_size: int | None = None
) -> Tuple[Dataset, Dataset, Dataset]:
    """
    Setup the MNIST dataset.

    Parameters
    ----------
    seed : int
        The seed to use.
    fraction : float
        The fraction of the dataset to use.
    train_sample_size : int | None
        The sample size to use for the train dataset.
    test_sample_size : int | None
        The sample size to use for the test dataset.

    Returns
    -------
    Tuple[Dataset, Dataset, Dataset]
        The MNIST dataset.
    """
    # NOTE(review): the original docstring has a few lines (license notice) that
    # fall between the visible diff hunks - keep them when applying this change.
    train_ds, test_ds = load_dataset("ylecun/mnist", split=["train", "test"])  # type: ignore[misc]

    train_ds = _subsample(train_ds, fraction, train_sample_size, seed)
    test_ds = _subsample(test_ds, fraction, test_sample_size, seed)

    # Only the validation set is carved out of train. The official test split
    # (already subsampled above) is returned as-is; re-deriving `test_ds` from
    # `val_ds` here would discard it together with `test_sample_size`.
    train_ds, val_ds = split_train_into_train_val(train_ds, seed)

    return train_ds, val_ds, test_ds  # type: ignore[return-value]


def setup_imagenet_dataset(
    seed: int, fraction: float = 1.0, train_sample_size: int | None = None, test_sample_size: int | None = None
) -> Tuple[Dataset, Dataset, Dataset]:
    """
    Setup the ImageNet dataset.

    Parameters
    ----------
    seed : int
        The seed to use.
    fraction : float
        The fraction of the dataset to use.
    train_sample_size : int | None
        The sample size to use for the train dataset.
    test_sample_size : int | None
        The sample size to use for the test dataset.

    Returns
    -------
    Tuple[Dataset, Dataset, Dataset]
        The ImageNet dataset.
    """
    # NOTE(review): the original docstring has a few lines (license notice) that
    # fall between the visible diff hunks - keep them when applying this change.
    train_ds, val = load_dataset("zh-plus/tiny-imagenet", split=["train", "valid"])  # type: ignore[misc]
    train_ds = _subsample(train_ds, fraction, train_sample_size, seed)

    # Only "train" and "valid" are loaded, so the test set is carved out of
    # "valid". Subsampling is applied to the test part only; the validation
    # part is returned at whatever size the split produces.
    val_ds, test_ds = split_val_into_val_test(val, seed)
    test_ds = _subsample(test_ds, fraction, test_sample_size, seed)

    return train_ds, val_ds, test_ds  # type: ignore[return-value]


def setup_cifar10_dataset(
    seed: int, fraction: float = 1.0, train_sample_size: int | None = None, test_sample_size: int | None = None
) -> Tuple[Dataset, Dataset, Dataset]:
    """
    Setup the CIFAR-10 dataset.

    The original CIFAR-10 dataset from uoft-cs/cifar10 has an 'img' column,
    but this function renames it to 'image' to ensure compatibility with
    the image_classification_collate function which expects an 'image' column.

    License: unspecified

    Parameters
    ----------
    seed : int
        The seed to use.
    fraction : float
        The fraction of the dataset to use.
    train_sample_size : int | None
        The sample size to use for the train dataset.
    test_sample_size : int | None
        The sample size to use for the test dataset.

    Returns
    -------
    Tuple[Dataset, Dataset, Dataset]
        The CIFAR-10 dataset with columns: 'image' (PIL Image) and 'label' (int).
    """
    train_ds, test_ds = load_dataset("uoft-cs/cifar10", split=["train", "test"])

    # Rename 'img' column to 'image' to match collate function expectations.
    train_ds = train_ds.rename_column("img", "image")
    test_ds = test_ds.rename_column("img", "image")

    train_ds = _subsample(train_ds, fraction, train_sample_size, seed)
    test_ds = _subsample(test_ds, fraction, test_sample_size, seed)

    # The validation set comes out of the (already subsampled) train split;
    # the official test split is kept as the test set.
    train_ds, val_ds = split_train_into_train_val(train_ds, seed)

    return train_ds, val_ds, test_ds  # type: ignore[return-value]