1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ from __future__ import annotations
16+
1517from typing import Tuple
1618
1719from datasets import load_dataset
1820from torch .utils .data import Dataset
1921
20- from pruna .data .utils import split_train_into_train_val , split_val_into_val_test , stratify_dataset
22+ from pruna .data .utils import (
23+ define_sample_size_for_dataset ,
24+ split_train_into_train_val ,
25+ split_val_into_val_test ,
26+ stratify_dataset ,
27+ )
2128
2229
23- def setup_mnist_dataset (seed : int , fraction : float = 1.0 ) -> Tuple [Dataset , Dataset , Dataset ]:
30+ def setup_mnist_dataset (
31+ seed : int , fraction : float = 1.0 , train_sample_size : int | None = None , test_sample_size : int | None = None
32+ ) -> Tuple [Dataset , Dataset , Dataset ]:
2433 """
2534 Setup the MNIST dataset.
2635
@@ -30,9 +39,12 @@ def setup_mnist_dataset(seed: int, fraction: float = 1.0) -> Tuple[Dataset, Data
3039 ----------
3140 seed : int
3241 The seed to use.
33-
3442 fraction : float
3543 The fraction of the dataset to use.
44+ train_sample_size : int | None
45+ The sample size to use for the train dataset.
46+ test_sample_size : int | None
47+ The sample size to use for the test dataset.
3648
3749 Returns
3850 -------
@@ -41,16 +53,21 @@ def setup_mnist_dataset(seed: int, fraction: float = 1.0) -> Tuple[Dataset, Data
4153 """
4254 train_ds , test_ds = load_dataset ("ylecun/mnist" , split = ["train" , "test" ]) # type: ignore[misc]
4355
44- train_ds = stratify_dataset (train_ds , "label" , fraction , seed )
45- test_ds = stratify_dataset (test_ds , "label" , fraction , seed )
56+ train_sample_size = define_sample_size_for_dataset (train_ds , fraction , train_sample_size )
57+ test_sample_size = define_sample_size_for_dataset (test_ds , fraction , test_sample_size )
58+
59+ train_ds = stratify_dataset (train_ds , train_sample_size , seed )
60+ test_ds = stratify_dataset (test_ds , test_sample_size , seed )
4661
4762 train_ds , val_ds = split_train_into_train_val (train_ds , seed )
4863 val_ds , test_ds = split_val_into_val_test (val_ds , seed )
4964
5065 return train_ds , val_ds , test_ds # type: ignore[return-value]
5166
5267
53- def setup_imagenet_dataset (seed : int , fraction : float = 1.0 ) -> Tuple [Dataset , Dataset , Dataset ]:
68+ def setup_imagenet_dataset (
69+ seed : int , fraction : float = 1.0 , train_sample_size : int | None = None , test_sample_size : int | None = None
70+ ) -> Tuple [Dataset , Dataset , Dataset ]:
5471 """
5572 Setup the ImageNet dataset.
5673
@@ -60,23 +77,30 @@ def setup_imagenet_dataset(seed: int, fraction: float = 1.0) -> Tuple[Dataset, D
6077 ----------
6178 seed : int
6279 The seed to use.
63-
6480 fraction : float
6581 The fraction of the dataset to use.
82+ train_sample_size : int | None
83+ The sample size to use for the train dataset.
84+ test_sample_size : int | None
85+ The sample size to use for the test dataset.
6686
6787 Returns
6888 -------
6989 Tuple[Dataset, Dataset, Dataset]
7090 The ImageNet dataset.
7191 """
7292 train_ds , val = load_dataset ("zh-plus/tiny-imagenet" , split = ["train" , "valid" ]) # type: ignore[misc]
73- train_ds = stratify_dataset (train_ds , "label" , fraction , seed )
74- val = stratify_dataset (val , "label" , fraction , seed )
93+ train_sample_size = define_sample_size_for_dataset (train_ds , fraction , train_sample_size )
94+ train_ds = stratify_dataset (train_ds , train_sample_size , seed )
7595 val_ds , test_ds = split_val_into_val_test (val , seed )
96+ test_sample_size = define_sample_size_for_dataset (test_ds , fraction , test_sample_size )
97+ test_ds = stratify_dataset (test_ds , test_sample_size , seed )
7698 return train_ds , val_ds , test_ds # type: ignore[return-value]
7799
78100
79- def setup_cifar10_dataset (seed : int , fraction : float = 1.0 ) -> Tuple [Dataset , Dataset , Dataset ]:
101+ def setup_cifar10_dataset (
102+ seed : int , fraction : float = 1.0 , train_sample_size : int | None = None , test_sample_size : int | None = None
103+ ) -> Tuple [Dataset , Dataset , Dataset ]:
80104 """
81105 Setup the CIFAR-10 dataset.
82106
@@ -90,9 +114,12 @@ def setup_cifar10_dataset(seed: int, fraction: float = 1.0) -> Tuple[Dataset, Da
90114 ----------
91115 seed : int
92116 The seed to use.
93-
94117 fraction : float
95118 The fraction of the dataset to use.
119+ train_sample_size : int | None
120+ The sample size to use for the train dataset.
121+ test_sample_size : int | None
122+ The sample size to use for the test dataset.
96123
97124 Returns
98125 -------
@@ -106,8 +133,11 @@ def setup_cifar10_dataset(seed: int, fraction: float = 1.0) -> Tuple[Dataset, Da
106133 train_ds = train_ds .rename_column ("img" , "image" )
107134 test_ds = test_ds .rename_column ("img" , "image" )
108135
109- train_ds = stratify_dataset (train_ds , "label" , fraction , seed )
110- test_ds = stratify_dataset (test_ds , "label" , fraction , seed )
136+ train_sample_size = define_sample_size_for_dataset (train_ds , fraction , train_sample_size )
137+ test_sample_size = define_sample_size_for_dataset (test_ds , fraction , test_sample_size )
138+
139+ train_ds = stratify_dataset (train_ds , train_sample_size , seed )
140+ test_ds = stratify_dataset (test_ds , test_sample_size , seed )
111141
112142 train_ds , val_ds = split_train_into_train_val (train_ds , seed )
113143 return train_ds , val_ds , test_ds # type: ignore[return-value]
0 commit comments