-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset.py
23 lines (15 loc) · 867 Bytes
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torchvision.datasets as datasets
import torchvision.transforms as transforms
# Maps each accepted ``data_name`` to the name of its torchvision dataset
# class. Classes are resolved lazily on ``torchvision.datasets`` inside
# ``get_dataset`` so this table does not touch torchvision by itself.
_DATASET_CLASS_NAMES = {
    "mnist": "MNIST",
    "fashion-mnist": "FashionMNIST",
    "kmnist": "KMNIST",
    "emnist": "EMNIST",
}


def get_dataset(data_name, data_root, image_size, train, *,
                download=True, emnist_split="balanced"):
    """Build an MNIST-family torchvision dataset resized to ``image_size``.

    Args:
        data_name: One of ``"mnist"``, ``"fashion-mnist"``, ``"kmnist"`` or
            ``"emnist"``.
        data_root: Directory forwarded to the dataset as ``root``.
        image_size: Size forwarded to ``transforms.Resize``.
        train: Forwarded as ``train``; torchvision uses it to choose the
            training split (true) or the test split (false).
        download: Forwarded to the dataset constructor. Defaults to ``True``,
            which was previously hard-coded. Pass ``False`` to avoid any
            download attempt, e.g. on offline machines.
        emnist_split: EMNIST ``split`` to load. Ignored for the other
            datasets. Defaults to ``"balanced"``, which was previously
            hard-coded.

    Returns:
        The dataset instance, whose images are passed through
        ``Resize(image_size)`` and then ``ToTensor()``. Returns ``None`` when
        ``data_name`` is not one of the supported names.
    """
    class_name = _DATASET_CLASS_NAMES.get(data_name)
    if class_name is None:
        # Unsupported names yield None rather than raising, so existing
        # callers that compare the result against None keep working.
        return None

    transform = transforms.Compose(
        [transforms.Resize(image_size), transforms.ToTensor()]
    )
    # Only the EMNIST constructor is given a ``split`` argument, matching the
    # per-dataset calls this table replaces.
    # NOTE(review): EMNIST images are commonly reported as transposed
    # relative to MNIST; confirm whether downstream code needs to correct it.
    extra_kwargs = {"split": emnist_split} if data_name == "emnist" else {}
    dataset_cls = getattr(datasets, class_name)
    return dataset_cls(
        root=data_root,
        train=train,
        transform=transform,
        download=download,
        **extra_kwargs,
    )