Commit a88e335
GAN testing

1 parent 133f8c1

4 files changed: 284 additions, 0 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -4,3 +4,4 @@ __pycache__
 *.lock
 xcuserdata
 *.py.cfg
+.python-version

gan/.gitignore

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+data
+dataset
+__pycache__
+runs
+

gan/gan.py

Lines changed: 142 additions & 0 deletions
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from utils import Logger

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def mnist_data():
    # MNIST images are single-channel, so Normalize takes one mean/std pair.
    compose = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5,), (0.5,))])
    out_dir = './dataset'
    return datasets.MNIST(root=out_dir, train=True, transform=compose, download=True)

def images_to_vectors(images):
    return images.view(images.size(0), 784)

def vectors_to_images(vectors):
    return vectors.view(vectors.size(0), 1, 28, 28)

GEN_N_FEATURES = 100

def noise(size):
    return torch.randn(size, GEN_N_FEATURES).to(DEVICE)

def ones_target(size):
    return torch.ones(size, 1).to(DEVICE)

def zeros_target(size):
    return torch.zeros(size, 1).to(DEVICE)

class DiscriminatorNet(nn.Module):
    def __init__(self):
        super(DiscriminatorNet, self).__init__()
        n_features = 784
        n_out = 1
        self.hidden0 = nn.Sequential(
            nn.Linear(n_features, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3))
        self.hidden1 = nn.Sequential(
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3))
        self.hidden2 = nn.Sequential(
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3))
        self.out = nn.Sequential(
            nn.Linear(256, n_out),
            nn.Sigmoid())

    def forward(self, x):
        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.out(x)
        return x

class GeneratorNet(nn.Module):
    def __init__(self):
        super(GeneratorNet, self).__init__()
        n_out = 784
        self.hidden0 = nn.Sequential(
            nn.Linear(GEN_N_FEATURES, 256),
            nn.LeakyReLU(0.2))
        self.hidden1 = nn.Sequential(
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2))
        self.hidden2 = nn.Sequential(
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2))
        self.out = nn.Sequential(
            nn.Linear(1024, n_out),
            nn.Tanh())

    def forward(self, x):
        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.out(x)
        return x

def train_discriminator(generator, discriminator, loss, optimizer, real_data):
    N = real_data.size(0)
    optimizer.zero_grad()
    # Train on real data
    prediction_real = discriminator(real_data)
    error_real = loss(prediction_real, ones_target(N))
    error_real.backward()
    # Train on fake data (detach so gradients do not flow into the generator)
    fake_data = generator(noise(N)).detach()
    prediction_fake = discriminator(fake_data)
    error_fake = loss(prediction_fake, zeros_target(N))
    error_fake.backward()

    optimizer.step()
    return error_real + error_fake, prediction_real, prediction_fake

def train_generator(generator, discriminator, loss, optimizer, N):
    optimizer.zero_grad()
    # The generator is rewarded when the discriminator labels its samples as real
    fake_data = generator(noise(N))
    prediction = discriminator(fake_data)
    error = loss(prediction, ones_target(N))
    error.backward()
    optimizer.step()
    return error

if __name__ == '__main__':
    data = mnist_data()
    data_loader = DataLoader(data, batch_size=100, shuffle=True)
    num_batches = len(data_loader)
    discriminator = DiscriminatorNet().to(DEVICE)
    generator = GeneratorNet().to(DEVICE)
    d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002)
    g_optimizer = optim.Adam(generator.parameters(), lr=0.0002)
    loss = nn.BCELoss()

    num_test_samples = 16
    test_noise = noise(num_test_samples)

    logger = Logger(model_name='VGAN', data_name='MNIST')
    num_epochs = 200

    for epoch in range(num_epochs):
        for n_batch, (real_batch, _) in enumerate(data_loader):
            N = real_batch.size(0)
            # Train discriminator
            real_data = images_to_vectors(real_batch).to(DEVICE)
            d_error, d_pred_real, d_pred_fake = train_discriminator(
                generator, discriminator, loss, d_optimizer, real_data)
            # Train generator
            g_error = train_generator(generator, discriminator, loss, g_optimizer, N)
            logger.log(d_error, g_error, epoch, n_batch, num_batches)
            if n_batch % 100 == 0:
                test_images = vectors_to_images(generator(test_noise)).cpu()
                logger.log_images(test_images.data, num_test_samples, epoch, n_batch, num_batches)
                logger.display_status(epoch, num_epochs, n_batch, num_batches,
                                      d_error, g_error, d_pred_real, d_pred_fake)
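Note that the training loop above logs losses and image grids but never checkpoints the networks or closes the writer, even though the Logger in gan/utils.py (below) provides save_models and close. A minimal sketch of how a checkpoint-and-reload flow could look; the per-epoch calls, the checkpoint path and the epoch number 199 are assumptions for illustration, not part of this commit:

# Hypothetical additions, not in this commit:
#   at the end of each epoch:   logger.save_models(generator, discriminator, epoch)
#   after the training loop:    logger.close()

# Reloading a saved generator for sampling. The path follows Logger's
# './data/models/<model>/<dataset>/G_epoch_<n>' layout; epoch 199 is only an example.
import torch
from gan import GeneratorNet, DEVICE, noise, vectors_to_images

generator = GeneratorNet().to(DEVICE)
generator.load_state_dict(torch.load('./data/models/VGAN/MNIST/G_epoch_199',
                                     map_location=DEVICE))
generator.eval()
with torch.no_grad():
    samples = vectors_to_images(generator(noise(16))).cpu()  # shape [16, 1, 28, 28]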

gan/utils.py

Lines changed: 136 additions & 0 deletions
import os
import numpy as np
import errno
import torchvision.utils as vutils
from tensorboardX import SummaryWriter
from IPython import display
from matplotlib import pyplot as plt
import torch

'''
TensorBoard data will be stored in the './runs' path
'''

class Logger:
    def __init__(self, model_name, data_name):
        self.model_name = model_name
        self.data_name = data_name

        self.comment = '{}_{}'.format(model_name, data_name)
        self.data_subdir = '{}/{}'.format(model_name, data_name)

        # TensorBoard
        self.writer = SummaryWriter(comment=self.comment)

    def log(self, d_error, g_error, epoch, n_batch, num_batches):
        if isinstance(d_error, torch.autograd.Variable):
            d_error = d_error.data.cpu().numpy()
        if isinstance(g_error, torch.autograd.Variable):
            g_error = g_error.data.cpu().numpy()

        step = Logger._step(epoch, n_batch, num_batches)
        self.writer.add_scalar(
            '{}/D_error'.format(self.comment), d_error, step)
        self.writer.add_scalar(
            '{}/G_error'.format(self.comment), g_error, step)

    def log_images(self, images, num_images, epoch, n_batch, num_batches, format='NCHW', normalize=True):
        '''
        Input images are expected in NCHW format.
        '''
        if type(images) == np.ndarray:
            images = torch.from_numpy(images)

        if format == 'NHWC':
            images = images.transpose(1, 3)

        step = Logger._step(epoch, n_batch, num_batches)
        img_name = '{}/images{}'.format(self.comment, '')

        # Make horizontal grid from image tensor
        horizontal_grid = vutils.make_grid(
            images, normalize=normalize, scale_each=True)
        # Make square grid from image tensor
        nrows = int(np.sqrt(num_images))
        grid = vutils.make_grid(
            images, nrow=nrows, normalize=True, scale_each=True)

        # Add horizontal grid to TensorBoard
        self.writer.add_image(img_name, horizontal_grid, step)

        # Save plots
        self.save_torch_images(horizontal_grid, grid, epoch, n_batch)

    def save_torch_images(self, horizontal_grid, grid, epoch, n_batch, plot_horizontal=True):
        out_dir = './data/images/{}'.format(self.data_subdir)
        Logger._make_dir(out_dir)

        # Plot and save horizontal grid
        fig = plt.figure(figsize=(16, 16))
        plt.imshow(np.moveaxis(horizontal_grid.numpy(), 0, -1))
        plt.axis('off')
        if plot_horizontal:
            display.display(plt.gcf())
        self._save_images(fig, epoch, n_batch, 'hori')
        plt.close()

        # Save square grid
        fig = plt.figure()
        plt.imshow(np.moveaxis(grid.numpy(), 0, -1))
        plt.axis('off')
        self._save_images(fig, epoch, n_batch)
        plt.close()

    def _save_images(self, fig, epoch, n_batch, comment=''):
        out_dir = './data/images/{}'.format(self.data_subdir)
        Logger._make_dir(out_dir)
        fig.savefig('{}/{}_epoch_{}_batch_{}.png'.format(
            out_dir, comment, epoch, n_batch))

    def display_status(self, epoch, num_epochs, n_batch, num_batches, d_error, g_error, d_pred_real, d_pred_fake):
        if isinstance(d_error, torch.autograd.Variable):
            d_error = d_error.data.cpu().numpy()
        if isinstance(g_error, torch.autograd.Variable):
            g_error = g_error.data.cpu().numpy()
        if isinstance(d_pred_real, torch.autograd.Variable):
            d_pred_real = d_pred_real.data
        if isinstance(d_pred_fake, torch.autograd.Variable):
            d_pred_fake = d_pred_fake.data

        print('Epoch: [{}/{}], Batch Num: [{}/{}]'.format(
            epoch, num_epochs, n_batch, num_batches))
        print('Discriminator Loss: {:.4f}, Generator Loss: {:.4f}'.format(d_error, g_error))
        print('D(x): {:.4f}, D(G(z)): {:.4f}'.format(d_pred_real.mean(), d_pred_fake.mean()))

    def save_models(self, generator, discriminator, epoch):
        out_dir = './data/models/{}'.format(self.data_subdir)
        Logger._make_dir(out_dir)
        torch.save(
            generator.state_dict(),
            '{}/G_epoch_{}'.format(out_dir, epoch))
        torch.save(
            discriminator.state_dict(),
            '{}/D_epoch_{}'.format(out_dir, epoch))

    def close(self):
        self.writer.close()

    # Private functionality

    @staticmethod
    def _step(epoch, n_batch, num_batches):
        return epoch * num_batches + n_batch

    @staticmethod
    def _make_dir(directory):
        try:
            os.makedirs(directory)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise
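SummaryWriter(comment=...) writes its event files under ./runs/ (which the new gan/.gitignore excludes), and the PNG grids are saved under ./data/images/<model>/<dataset>/, so the scalars and images can typically be inspected with "tensorboard --logdir runs". A minimal standalone sketch of the Logger calls that gan.py makes; the dummy tensors and the num_batches value are illustrative only:

import torch
from utils import Logger

logger = Logger(model_name='VGAN', data_name='MNIST')

# Scalars: logged against a global step of epoch * num_batches + n_batch.
logger.log(d_error=torch.tensor(0.7), g_error=torch.tensor(1.2),
           epoch=0, n_batch=0, num_batches=600)

# Images: a batch of 16 fake "MNIST" samples in NCHW layout, written to
# TensorBoard and saved as PNG grids under ./data/images/VGAN/MNIST/.
fake_images = torch.rand(16, 1, 28, 28)
logger.log_images(fake_images, num_images=16, epoch=0, n_batch=0, num_batches=600)

logger.close()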
