feat: add tiny datasets for lightweight experiments #422
base: main
Changes from 4 commits
a7ddb66
afdd4c6
cf09564
213dc34
1ebdd79
3ea62c5
7100732
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -66,6 +66,10 @@ def setup_cifar10_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]: | |
| """ | ||
| Setup the CIFAR-10 dataset. | ||
|
|
||
| The original CIFAR-10 dataset from uoft-cs/cifar10 has an 'img' column, | ||
| but this function renames it to 'image' to ensure compatibility with | ||
| the image_classification_collate function which expects an 'image' column. | ||
|
|
||
| License: unspecified | ||
|
|
||
| Parameters | ||
|
|
@@ -76,8 +80,48 @@ def setup_cifar10_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]: | |
| Returns | ||
| ------- | ||
| Tuple[Dataset, Dataset, Dataset] | ||
| The CIFAR-10 dataset. | ||
| The CIFAR-10 dataset with columns: 'image' (PIL Image) and 'label' (int). | ||
| """ | ||
| train_ds, test_ds = load_dataset("uoft-cs/cifar10", split=["train", "test"]) # type: ignore[misc] | ||
| train_ds, test_ds = load_dataset("uoft-cs/cifar10", split=["train", "test"]) | ||
|
|
||
| # Rename 'img' column to 'image' to match collate function expectations | ||
| # This ensures compatibility with image_classification_collate function | ||
| train_ds = train_ds.rename_column("img", "image") | ||
| test_ds = test_ds.rename_column("img", "image") | ||
|
|
||
| train_ds, val_ds = split_train_into_train_val(train_ds, seed) | ||
| return train_ds, val_ds, test_ds # type: ignore[return-value] | ||
|
|
||
|
|
||
| def setup_tiny_cifar10_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]: | ||
| """ | ||
| Setup the Tiny CIFAR-10 dataset (< 1,000 samples). | ||
|
|
||
| The original CIFAR-10 dataset from uoft-cs/cifar10 has an 'img' column, | ||
| but this function renames it to 'image' to ensure compatibility with | ||
| the image_classification_collate function which expects an 'image' column. | ||
|
|
||
| License: unspecified | ||
|
|
||
| Parameters | ||
| ---------- | ||
| seed : int | ||
| The seed to use. | ||
|
|
||
| Returns | ||
| ------- | ||
| Tuple[Dataset, Dataset, Dataset] | ||
| The Tiny CIFAR-10 dataset with columns: 'image' (PIL Image) and 'label' (int). | ||
| Built from the first 600 training samples (then split into train and validation sets) and the first 200 test samples. | ||
| """ | ||
| train_ds, test_ds = load_dataset("uoft-cs/cifar10", split=["train", "test"]) | ||
|
|
||
| # Rename 'img' column to 'image' to match collate function expectations | ||
| # This ensures compatibility with image_classification_collate function | ||
| train_ds = train_ds.rename_column("img", "image") | ||
| test_ds = test_ds.rename_column("img", "image") | ||
|
|
||
| tiny_train = train_ds.select(range(600)) | ||
|
||
| tiny_test = test_ds.select(range(200)) | ||
| train_ds, val_ds = split_train_into_train_val(tiny_train, seed) | ||
| return train_ds, val_ds, tiny_test | ||
I can see us reusing something like `get_tiny(setup_cifar10_dataset)` or similar.
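One possible shape for that helper, as a rough sketch rather than a proposal for the final API: `get_tiny` wraps any setup function that returns `(train, val, test)` and truncates each split with `select`. The parameter names and default sizes (`max_train`, `max_val`, `max_test`) are hypothetical, and `ListDataset` and `fake_setup` are stand-ins so the example runs without downloading CIFAR-10; the only assumption about the real datasets is that they support `len()` and `select(indices)`, as Hugging Face `Dataset` objects do. Note that this truncates after the train/val split, whereas the code in this PR selects 600 samples first and then splits.

```python
from typing import Callable, List, Sequence, Tuple


class ListDataset:
    """Stand-in for a Hugging Face Dataset: supports len() and select()."""

    def __init__(self, rows: Sequence[dict]) -> None:
        self.rows: List[dict] = list(rows)

    def __len__(self) -> int:
        return len(self.rows)

    def select(self, indices) -> "ListDataset":
        return ListDataset([self.rows[i] for i in indices])


def get_tiny(
    setup_fn: Callable[[int], Tuple],
    max_train: int = 480,  # hypothetical default
    max_val: int = 120,    # hypothetical default
    max_test: int = 200,   # hypothetical default
) -> Callable[[int], Tuple]:
    """Wrap a setup function so that each returned split is truncated."""

    def truncate(ds, limit: int):
        # Never request more rows than the split actually holds.
        return ds.select(range(min(limit, len(ds))))

    def tiny_setup(seed: int) -> Tuple:
        train_ds, val_ds, test_ds = setup_fn(seed)
        return (
            truncate(train_ds, max_train),
            truncate(val_ds, max_val),
            truncate(test_ds, max_test),
        )

    return tiny_setup


def fake_setup(seed: int) -> Tuple[ListDataset, ListDataset, ListDataset]:
    # Hypothetical setup function standing in for setup_cifar10_dataset.
    def make(n: int) -> ListDataset:
        return ListDataset([{"image": None, "label": i % 10} for i in range(n)])

    return make(5000), make(500), make(100)


setup_tiny_fake = get_tiny(fake_setup)
tiny_train, tiny_val, tiny_test = setup_tiny_fake(seed=0)
print(len(tiny_train), len(tiny_val), len(tiny_test))  # 480 120 100
```

With a helper along these lines, `setup_tiny_cifar10_dataset` could reduce to `get_tiny(setup_cifar10_dataset)`, and other datasets could get a tiny variant without duplicating the loading and column-renaming logic.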