Skip to content

Commit

Permalink
smapler error fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
carpedm20 committed Jan 2, 2016
1 parent 00031e9 commit ac59257
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 18 deletions.
File renamed without changes
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ DCGAN in Tensorflow

Tensorflow implementation of [Deep Convolutional Generative Adversarial Networks](http://arxiv.org/abs/1511.06434). The referenced torch code can be found [here](https://github.com/soumith/dcgan.torch).

![alt tag](model.png)
![alt tag](DCGAN.png)


Prerequisites
Expand Down
33 changes: 17 additions & 16 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def build_model(self):
self.z = tf.placeholder(tf.float32, [None, self.z_dim])

self.image_ = self.generator(self.z)
self.sampler = self.generator(self.z)
self.sampler = self.sampler(self.z)

self.D = self.discriminator(self.image)
self.D_ = self.discriminator(self.image_, reuse=True)
Expand Down Expand Up @@ -207,36 +207,37 @@ def generator(self, z, y=None):
def sampler(self, z, y=None):
tf.get_variable_scope().reuse_variables()

if y:
if self.y_dim:
yb = tf.reshape(y, [None, 1, 1, self.y_dim])
z = tf.concat(1, [z, y])

h0 = tf.nn.relu(self.bn0(linear(z, self.gfc_dim, 's_h0_lin')))
h0 = tf.nn.relu(self.bn0(linear(z, self.gfc_dim, 'g_h0_lin')))
h0 = tf.concat(1, [h0, y])

h1 = tf.nn.relu(self.g_bn1(linear(z, self.gf_dim*2*7*7, 's_h1_lin')))
h1 = tf.nn.relu(self.g_bn1(linear(z, self.gf_dim*2*7*7, 'g_h1_lin')))
h1 = tf.reshape(h1, [None, 7, 7, self.gf_dim * 2])
h1 = conv_cond_concat(h1, yb)

h2 = tf.nn.relu(self.bn2(deconv2d(h1, self.gf_dim, name='h2')))
h2 = tf.nn.relu(self.bn2(deconv2d(h1, self.gf_dim, name='g_h2')))
h2 = conv_cond_concat(h2, yb)

return tf.nn.sigmoid(deconv2d(h2, self.c_dim, name='h3'))
return tf.nn.sigmoid(deconv2d(h2, self.c_dim, name='g_h3'))
else:
h0 = tf.nn.relu(self.g_bn0(linear(z, self.gf_dim*8*4*4, 's_h0_lin'),
train=False))
h0 = tf.reshape(h1, [None, 4, 4, self.gf_dim * 8])
# project `z` and reshape
h0 = tf.reshape(linear(z, self.gf_dim*8*4*4, 'g_h0_lin'),
[-1, 4, 4, self.gf_dim * 8])
h0 = tf.nn.relu(self.g_bn0(h0, train=False))

h1 = deconv2d(h0, [None, 8, 8, self.gf_dim*4], name='h1')
h1 = tf.relu(self.g_bn1(h1, train=False))
h1 = deconv2d(h0, [self.batch_size, 8, 8, self.gf_dim*4], name='g_h1')
h1 = tf.nn.relu(self.g_bn1(h1, train=False))

h2 = deconv2d(h1, [None, 16, 16, self.gf_dim*2], name='h2')
h2 = tf.relu(self.g_bn2(h2, train=False))
h2 = deconv2d(h1, [self.batch_size, 16, 16, self.gf_dim*2], name='g_h2')
h2 = tf.nn.relu(self.g_bn2(h2, train=False))

h3 = deconv2d(h2, [None, 16, 16, self.gf_dim*1], name='h3')
h3 = tf.relu(self.g_bn3(h3, train=False))
h3 = deconv2d(h2, [self.batch_size, 16, 16, self.gf_dim*1], name='g_h3')
h3 = tf.nn.relu(self.g_bn3(h3, train=False))

h4 = deconv2d(h3, [None, 64, 64, 3], name='h4')
h4 = deconv2d(h3, [None, 64, 64, 3], name='g_h4')

return tf.nn.tanh(h4)

Expand Down
3 changes: 2 additions & 1 deletion ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ def __call__(self, x, train=True):
return tf.nn.batch_norm_with_global_normalization(x,
mean,
variance,
self.beta, self.gamma,
self.beta,
self.gamma,
self.epsilon, True)

def binary_cross_entropy_with_logits(logits, targets, name=None):
Expand Down
10 changes: 10 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,16 @@ def get_image(image_path):
def imread(path):
return scipy.misc.imread(path).astype(np.float)

def imsave(images, size, path):
h, w = images.shape[1], images.shape[2]
img = np.zeros((h * size[0], w * size[1], 3))

for idx, image in enumerate(images):
i = idx % size[1]
j = idx / size[1]
img[j*h:j*h+h, i*w:i*w+w, :] = image
return scipy.misc.imsave(path, img)

def center_crop(x, ph, pw=None):
if pw is None:
pw = ph
Expand Down

0 comments on commit ac59257

Please sign in to comment.