1 change: 1 addition & 0 deletions pyproject.toml
@@ -35,6 +35,7 @@ possibly-missing-attribute = "ignore" # mypy is more permissive with attribute a
possibly-missing-import = "ignore" # mypy is more permissive with imports
no-matching-overload = "ignore" # mypy is more permissive with overloads
unresolved-reference = "ignore" # mypy is more permissive with references
missing-argument = "ignore" # mypy is more permissive with missing arguments

[tool.coverage.run]
source = ["src/pruna"]
2 changes: 2 additions & 0 deletions src/pruna/data/__init__.py
@@ -23,6 +23,7 @@
setup_cifar10_dataset,
setup_imagenet_dataset,
setup_mnist_dataset,
setup_tiny_cifar10_dataset,
)
from pruna.data.datasets.prompt import (
setup_drawbench_dataset,
@@ -77,6 +78,7 @@
"image_classification_collate",
{"img_size": 32},
),
"TinyCIFAR10": (setup_tiny_cifar10_dataset, "image_classification_collate", {"img_size": 32}),

I can see us re-using something like get_tiny(setup_cifar10_dataset) here.
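A rough sketch of that get_tiny idea (the helper name and subset sizes are hypothetical, not part of this PR; unlike the function added below, it subsets each split after setup rather than before the train/val split):

from typing import Callable, Tuple

from datasets import Dataset


def get_tiny(
    setup_fn: Callable[[int], Tuple[Dataset, Dataset, Dataset]],
    n_train: int = 600,
    n_val: int = 100,
    n_test: int = 200,
) -> Callable[[int], Tuple[Dataset, Dataset, Dataset]]:
    """Wrap a dataset setup function so it returns tiny subsets of each split."""

    def tiny_setup(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
        train_ds, val_ds, test_ds = setup_fn(seed)
        # select() keeps the first n rows; min() guards against splits
        # smaller than the requested subset size.
        return (
            train_ds.select(range(min(n_train, len(train_ds)))),
            val_ds.select(range(min(n_val, len(val_ds)))),
            test_ds.select(range(min(n_test, len(test_ds)))),
        )

    return tiny_setup

The registry entry could then read "TinyCIFAR10": (get_tiny(setup_cifar10_dataset), "image_classification_collate", {"img_size": 32}).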

"DrawBench": (setup_drawbench_dataset, "prompt_collate", {}),
"PartiPrompts": (setup_parti_prompts_dataset, "prompt_collate", {}),
"GenAIBench": (setup_genai_bench_dataset, "prompt_collate", {}),
48 changes: 46 additions & 2 deletions src/pruna/data/datasets/image.py
@@ -66,6 +66,10 @@ def setup_cifar10_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
"""
Setup the CIFAR-10 dataset.

The original CIFAR-10 dataset from uoft-cs/cifar10 has an 'img' column,
but this function renames it to 'image' to ensure compatibility with
the image_classification_collate function which expects an 'image' column.

License: unspecified

Parameters
@@ -76,8 +80,48 @@ def setup_cifar10_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
Returns
-------
Tuple[Dataset, Dataset, Dataset]
The CIFAR-10 dataset.
The CIFAR-10 dataset with columns: 'image' (PIL Image) and 'label' (int).
"""
- train_ds, test_ds = load_dataset("uoft-cs/cifar10", split=["train", "test"]) # type: ignore[misc]
+ train_ds, test_ds = load_dataset("uoft-cs/cifar10", split=["train", "test"])

# Rename 'img' column to 'image' to match collate function expectations
# This ensures compatibility with image_classification_collate function
train_ds = train_ds.rename_column("img", "image")
test_ds = test_ds.rename_column("img", "image")

train_ds, val_ds = split_train_into_train_val(train_ds, seed)
return train_ds, val_ds, test_ds # type: ignore[return-value]


def setup_tiny_cifar10_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]:
"""
Setup the Tiny CIFAR-10 dataset (< 1,000 samples).

The original CIFAR-10 dataset from uoft-cs/cifar10 has an 'img' column,
but this function renames it to 'image' to ensure compatibility with
the image_classification_collate function which expects an 'image' column.

License: unspecified

Parameters
----------
seed : int
The seed to use.

Returns
-------
Tuple[Dataset, Dataset, Dataset]
The Tiny CIFAR-10 dataset with columns: 'image' (PIL Image) and 'label' (int).
Contains 600 samples split into train and validation, and 200 test samples.
"""
train_ds, test_ds = load_dataset("uoft-cs/cifar10", split=["train", "test"])

# Rename 'img' column to 'image' to match collate function expectations
# This ensures compatibility with image_classification_collate function
train_ds = train_ds.rename_column("img", "image")
test_ds = test_ds.rename_column("img", "image")

tiny_train = train_ds.select(range(600))

Why are we just getting this specific smaller subset? Can't we generalise this approach across all datasets and create general logic for getting tiny versions? Perhaps to be tackled in a separate PR?

Member Author: Yes, this makes a lot of sense actually!

tiny_test = test_ds.select(range(200))
train_ds, val_ds = split_train_into_train_val(tiny_train, seed)
return train_ds, val_ds, tiny_test
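A quick sanity check of the resulting sizes (a hypothetical snippet, not part of the PR; the exact train/val proportions depend on split_train_into_train_val):

train_ds, val_ds, test_ds = setup_tiny_cifar10_dataset(seed=42)
# 600 samples are selected before the train/val split, 200 for test.
assert len(train_ds) + len(val_ds) == 600
assert len(test_ds) == 200
assert "image" in train_ds.column_names  # renamed from 'img' for the collate fn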
1 change: 1 addition & 0 deletions tests/data/test_datamodule.py
@@ -28,6 +28,7 @@ def iterate_dataloaders(datamodule: PrunaDataModule) -> None:
pytest.param("LibriSpeech", dict(), marks=pytest.mark.slow),
pytest.param("AIPodcast", dict(), marks=pytest.mark.slow),
("ImageNet", dict(img_size=512)),
("TinyCIFAR10", dict(img_size=32)),
pytest.param("MNIST", dict(img_size=512), marks=pytest.mark.slow),
("WikiText", dict(tokenizer=bert_tokenizer)),
pytest.param("TinyWikiText", dict(tokenizer=bert_tokenizer), marks=pytest.mark.slow),