-
Notifications
You must be signed in to change notification settings - Fork 125
/
Copy pathdataset.py
106 lines (81 loc) · 3.12 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""dataset.py"""
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms
def is_power_of_2(num):
    """Return True if ``num`` is a power of two (1, 2, 4, 8, ...).

    A power of two has a single set bit, so clearing the lowest set bit with
    ``num & (num - 1)`` yields zero. Zero passes that bit test as well, which
    is why it is rejected explicitly afterwards.
    """
    lowest_bit_cleared = num & (num - 1)
    if lowest_bit_cleared != 0:
        # More than one bit is set: not a power of two.
        return False
    return num != 0
class CustomImageFolder(ImageFolder):
    """``ImageFolder`` variant whose items are the images alone.

    ``__getitem__`` returns only the (optionally transformed) image; the
    second element of each ``self.imgs`` entry is ignored.
    """

    def __init__(self, root, transform=None):
        """Forward ``root`` and ``transform`` to the ``ImageFolder`` constructor."""
        super(CustomImageFolder, self).__init__(root, transform)

    def __getitem__(self, index):
        """Load the image at position ``index`` and return it.

        The image is passed through ``self.transform`` when one is set,
        otherwise the object produced by ``self.loader`` is returned as-is.
        """
        image_path = self.imgs[index][0]
        image = self.loader(image_path)
        if self.transform is None:
            return image
        return self.transform(image)
class CustomTensorDataset(Dataset):
    """Dataset backed by one tensor, indexed along its first dimension."""

    def __init__(self, data_tensor):
        """Keep a reference to ``data_tensor``; no copy is made."""
        self.data_tensor = data_tensor

    def __getitem__(self, index):
        """Return ``data_tensor[index]``."""
        sample = self.data_tensor[index]
        return sample

    def __len__(self):
        """Return the size of the first dimension of ``data_tensor``."""
        return self.data_tensor.size()[0]
def return_data(args):
    """Build the training ``DataLoader`` for the dataset selected by ``args``.

    Args:
        args: Object exposing the attributes ``dataset`` (one of
            ``'3dchairs'``, ``'celeba'`` or ``'dsprites'``, matched
            case-insensitively), ``dset_dir`` (directory containing the
            datasets), ``batch_size``, ``num_workers`` and ``image_size``.

    Returns:
        A ``DataLoader`` over the selected dataset, created with
        ``shuffle=True``, ``pin_memory=True`` and ``drop_last=True``.

    Raises:
        AssertionError: If ``args.image_size`` is not 64.
        NotImplementedError: If ``args.dataset`` is not a supported name.
    """
    name = args.dataset
    dset_dir = args.dset_dir
    batch_size = args.batch_size
    num_workers = args.num_workers
    image_size = args.image_size
    assert image_size == 64, 'currently only image size of 64 is supported'

    # Image-folder datasets differ only in their sub-directory of dset_dir.
    image_folder_dirs = {'3dchairs': '3DChairs', 'celeba': 'CelebA'}
    key = name.lower()

    if key in image_folder_dirs:
        root = os.path.join(dset_dir, image_folder_dirs[key])
        transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ])
        train_data = CustomImageFolder(root=root, transform=transform)
    elif key == 'dsprites':
        root = os.path.join(
            dset_dir,
            'dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz')
        if not os.path.exists(root):
            import subprocess
            print('Now download dsprites-dataset')
            # The script path is relative to the current working directory.
            subprocess.call(['./download_dsprites.sh'])
            print('Finished')
        # Close the .npz archive once the image array has been read, so the
        # underlying file handle is not left open for the rest of the run.
        with np.load(root, encoding='bytes') as archive:
            imgs = archive['imgs']
        # unsqueeze(1) inserts a singleton dimension at index 1
        # (presumably the channel axis -- confirm against the model input).
        data = torch.from_numpy(imgs).unsqueeze(1).float()
        train_data = CustomTensorDataset(data_tensor=data)
    else:
        raise NotImplementedError(
            'unsupported dataset: {!r}'.format(name))

    return DataLoader(train_data,
                      batch_size=batch_size,
                      shuffle=True,
                      num_workers=num_workers,
                      pin_memory=True,
                      drop_last=True)
if __name__ == '__main__':
    # Manual smoke test: load one batch of CelebA images from data/CelebA and
    # drop into the debugger to inspect it.
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])
    dset = CustomImageFolder('data/CelebA', transform)
    loader = DataLoader(dset,
                        batch_size=32,
                        shuffle=True,
                        num_workers=1,
                        pin_memory=False,
                        drop_last=True)
    # Use the builtin next(): the iterator's ``.next()`` method was only a
    # Python 2 compatibility alias and is not available on current PyTorch
    # DataLoader iterators.
    images1 = next(iter(loader))
    import ipdb; ipdb.set_trace()