Skip to content

Commit 7100732

Browse files
committed
feat: add stratifying by sample size for image classification datasets
1 parent 3ea62c5 commit 7100732

File tree

3 files changed

+103
-27
lines changed

3 files changed

+103
-27
lines changed

src/pruna/data/__init__.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,24 @@
7878
"image_classification_collate",
7979
{"img_size": 32},
8080
),
81-
"TinyCIFAR10": (partial(setup_cifar10_dataset, fraction=0.1), "image_classification_collate", {"img_size": 32}),
82-
"TinyMNIST": (partial(setup_mnist_dataset, fraction=0.1), "image_classification_collate", {"img_size": 28}),
83-
"TinyImageNet": (partial(setup_imagenet_dataset, fraction=0.1), "image_classification_collate", {"img_size": 224}),
81+
# our full CIFAR10 has 50k train and 10k test
82+
"TinyCIFAR10": (
83+
partial(setup_cifar10_dataset, train_sample_size=800, test_sample_size=100),
84+
"image_classification_collate",
85+
{"img_size": 32},
86+
),
87+
# our full MNIST has 60k train and 10k test
88+
"TinyMNIST": (
89+
partial(setup_mnist_dataset, train_sample_size=800, test_sample_size=100),
90+
"image_classification_collate",
91+
{"img_size": 28},
92+
),
93+
# our full ImageNet has 100k train and 10k val
94+
"TinyImageNet": (
95+
partial(setup_imagenet_dataset, train_sample_size=1000, test_sample_size=100),
96+
"image_classification_collate",
97+
{"img_size": 224},
98+
),
8499
"DrawBench": (setup_drawbench_dataset, "prompt_collate", {}),
85100
"PartiPrompts": (setup_parti_prompts_dataset, "prompt_collate", {}),
86101
"GenAIBench": (setup_genai_bench_dataset, "prompt_collate", {}),

src/pruna/data/datasets/image.py

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,24 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
1517
from typing import Tuple
1618

1719
from datasets import load_dataset
1820
from torch.utils.data import Dataset
1921

20-
from pruna.data.utils import split_train_into_train_val, split_val_into_val_test, stratify_dataset
22+
from pruna.data.utils import (
23+
define_sample_size_for_dataset,
24+
split_train_into_train_val,
25+
split_val_into_val_test,
26+
stratify_dataset,
27+
)
2128

2229

23-
def setup_mnist_dataset(seed: int, fraction: float = 1.0) -> Tuple[Dataset, Dataset, Dataset]:
30+
def setup_mnist_dataset(
31+
seed: int, fraction: float = 1.0, train_sample_size: int | None = None, test_sample_size: int | None = None
32+
) -> Tuple[Dataset, Dataset, Dataset]:
2433
"""
2534
Setup the MNIST dataset.
2635
@@ -30,9 +39,12 @@ def setup_mnist_dataset(seed: int, fraction: float = 1.0) -> Tuple[Dataset, Data
3039
----------
3140
seed : int
3241
The seed to use.
33-
3442
fraction : float
3543
The fraction of the dataset to use.
44+
train_sample_size : int | None
45+
The sample size to use for the train dataset.
46+
test_sample_size : int | None
47+
The sample size to use for the test dataset.
3648
3749
Returns
3850
-------
@@ -41,16 +53,21 @@ def setup_mnist_dataset(seed: int, fraction: float = 1.0) -> Tuple[Dataset, Data
4153
"""
4254
train_ds, test_ds = load_dataset("ylecun/mnist", split=["train", "test"]) # type: ignore[misc]
4355

44-
train_ds = stratify_dataset(train_ds, "label", fraction, seed)
45-
test_ds = stratify_dataset(test_ds, "label", fraction, seed)
56+
train_sample_size = define_sample_size_for_dataset(train_ds, fraction, train_sample_size)
57+
test_sample_size = define_sample_size_for_dataset(test_ds, fraction, test_sample_size)
58+
59+
train_ds = stratify_dataset(train_ds, train_sample_size, seed)
60+
test_ds = stratify_dataset(test_ds, test_sample_size, seed)
4661

4762
train_ds, val_ds = split_train_into_train_val(train_ds, seed)
4863
val_ds, test_ds = split_val_into_val_test(val_ds, seed)
4964

5065
return train_ds, val_ds, test_ds # type: ignore[return-value]
5166

5267

53-
def setup_imagenet_dataset(seed: int, fraction: float = 1.0) -> Tuple[Dataset, Dataset, Dataset]:
68+
def setup_imagenet_dataset(
69+
seed: int, fraction: float = 1.0, train_sample_size: int | None = None, test_sample_size: int | None = None
70+
) -> Tuple[Dataset, Dataset, Dataset]:
5471
"""
5572
Setup the ImageNet dataset.
5673
@@ -60,23 +77,30 @@ def setup_imagenet_dataset(seed: int, fraction: float = 1.0) -> Tuple[Dataset, D
6077
----------
6178
seed : int
6279
The seed to use.
63-
6480
fraction : float
6581
The fraction of the dataset to use.
82+
train_sample_size : int | None
83+
The sample size to use for the train dataset.
84+
test_sample_size : int | None
85+
The sample size to use for the test dataset.
6686
6787
Returns
6888
-------
6989
Tuple[Dataset, Dataset, Dataset]
7090
The ImageNet dataset.
7191
"""
7292
train_ds, val = load_dataset("zh-plus/tiny-imagenet", split=["train", "valid"]) # type: ignore[misc]
73-
train_ds = stratify_dataset(train_ds, "label", fraction, seed)
74-
val = stratify_dataset(val, "label", fraction, seed)
93+
train_sample_size = define_sample_size_for_dataset(train_ds, fraction, train_sample_size)
94+
train_ds = stratify_dataset(train_ds, train_sample_size, seed)
7595
val_ds, test_ds = split_val_into_val_test(val, seed)
96+
test_sample_size = define_sample_size_for_dataset(test_ds, fraction, test_sample_size)
97+
test_ds = stratify_dataset(test_ds, test_sample_size, seed)
7698
return train_ds, val_ds, test_ds # type: ignore[return-value]
7799

78100

79-
def setup_cifar10_dataset(seed: int, fraction: float = 1.0) -> Tuple[Dataset, Dataset, Dataset]:
101+
def setup_cifar10_dataset(
102+
seed: int, fraction: float = 1.0, train_sample_size: int | None = None, test_sample_size: int | None = None
103+
) -> Tuple[Dataset, Dataset, Dataset]:
80104
"""
81105
Setup the CIFAR-10 dataset.
82106
@@ -90,9 +114,12 @@ def setup_cifar10_dataset(seed: int, fraction: float = 1.0) -> Tuple[Dataset, Da
90114
----------
91115
seed : int
92116
The seed to use.
93-
94117
fraction : float
95118
The fraction of the dataset to use.
119+
train_sample_size : int | None
120+
The sample size to use for the train dataset.
121+
test_sample_size : int | None
122+
The sample size to use for the test dataset.
96123
97124
Returns
98125
-------
@@ -106,8 +133,11 @@ def setup_cifar10_dataset(seed: int, fraction: float = 1.0) -> Tuple[Dataset, Da
106133
train_ds = train_ds.rename_column("img", "image")
107134
test_ds = test_ds.rename_column("img", "image")
108135

109-
train_ds = stratify_dataset(train_ds, "label", fraction, seed)
110-
test_ds = stratify_dataset(test_ds, "label", fraction, seed)
136+
train_sample_size = define_sample_size_for_dataset(train_ds, fraction, train_sample_size)
137+
test_sample_size = define_sample_size_for_dataset(test_ds, fraction, test_sample_size)
138+
139+
train_ds = stratify_dataset(train_ds, train_sample_size, seed)
140+
test_ds = stratify_dataset(test_ds, test_sample_size, seed)
111141

112142
train_ds, val_ds = split_train_into_train_val(train_ds, seed)
113143
return train_ds, val_ds, test_ds # type: ignore[return-value]

src/pruna/data/utils.py

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import random
1718
from typing import Any, Tuple, Union
1819

1920
import torch
@@ -183,28 +184,58 @@ def recover_text_from_dataloader(dataloader: DataLoader, tokenizer: Any) -> list
183184
return texts
184185

185186

186-
def stratify_dataset(dataset: Dataset, column: str, fraction: float, seed: int) -> Dataset:
187+
def stratify_dataset(dataset: Dataset, sample_size: int, seed: int = 42) -> Dataset:
187188
"""
188-
Stratify the dataset into a fraction of the dataset.
189+
Stratify the dataset into a specific size.
189190
190191
Parameters
191192
----------
192193
dataset : Dataset
193194
The dataset to stratify.
194-
column : str
195-
The column to stratify by.
196-
fraction : float
197-
The fraction of the dataset to stratify.
195+
sample_size : int
196+
The size to stratify.
198197
seed : int
199-
The seed to use for splitting the dataset.
198+
The seed to use for sampling the dataset.
200199
201200
Returns
202201
-------
203202
Dataset
204203
The stratified dataset.
205204
"""
206-
if fraction < 1.0:
207-
split_result = dataset.train_test_split(test_size=1 - fraction, stratify_by_column="label", seed=seed)
208-
dataset = split_result["train"]
209-
205+
dataset_length = len(dataset)
206+
if dataset_length < sample_size:
207+
pruna_logger.warning(
208+
"Dataset length is less than the size to stratify."
209+
f"Using the entire dataset. ({dataset_length} < {sample_size})"
210+
)
211+
return dataset
212+
213+
indices = list(range(dataset_length))
214+
random.Random(seed).shuffle(indices)
215+
selected_indices = indices[:sample_size]
216+
dataset = dataset.select(selected_indices)
210217
return dataset
218+
219+
220+
def define_sample_size_for_dataset(dataset: Dataset, fraction: float, sample_size: int | None = None) -> int:
221+
"""
222+
Define the sample size for the dataset.
223+
224+
Parameters
225+
----------
226+
dataset: Dataset
227+
The dataset to define the sample size for.
228+
fraction: float
229+
The fraction of the dataset to sample.
230+
sample_size: int | None
231+
The sample size to use.
232+
233+
Returns
234+
-------
235+
int
236+
The sample size for the dataset.
237+
"""
238+
if fraction < 1.0 and (sample_size is not None):
239+
raise ValueError("Fraction and sample sizes cannot be used together.")
240+
sample_size = int(len(dataset) * fraction) if fraction < 1.0 else sample_size or len(dataset)
241+
return sample_size

0 commit comments

Comments
 (0)