
Commit

commit for RRC
Exception4U committed Jul 7, 2017
1 parent 0cb4e44 commit f7f9325
Showing 1 changed file with 44 additions and 20 deletions.
64 changes: 44 additions & 20 deletions model.py
@@ -5,6 +5,7 @@
import tensorflow as tf
import numpy as np
from six.moves import xrange
import cv2

from ops import *
from utils import *
@@ -91,7 +92,7 @@ def build_model(self):

self.sampler = self.sampler(self.z)
self.D_, self.D_logits_ = self.discriminator(self.G, reuse=True)


self.d_sum = histogram_summary("d", self.D)
self.d__sum = histogram_summary("d_", self.D_)
@@ -103,7 +104,7 @@ def build_model(self):

self.d_loss_real_sum = scalar_summary("d_loss_real", self.d_loss_real)
self.d_loss_fake_sum = scalar_summary("d_loss_fake", self.d_loss_fake)

self.d_loss = self.d_loss_real + self.d_loss_fake

self.g_loss_sum = scalar_summary("g_loss", self.g_loss)
@@ -119,7 +120,9 @@ def build_model(self):
def train(self, config):
"""Train DCGAN"""
if config.dataset == 'mnist':
data_X, data_y = self.load_mnist()
# data_X, data_y = self.load_mnist()
#import pdb;pdb.set_trace()
data_X, data_y = self.load_mnist_handcrafted()
else:
data = glob(os.path.join("./data", config.dataset, "*.jpg"))
#np.random.shuffle(data)
@@ -137,7 +140,7 @@ def train(self, config):
self.writer = SummaryWriter("./logs", self.sess.graph)

sample_z = np.random.uniform(-1, 1, size=(self.sample_size , self.z_dim))

if config.dataset == 'mnist':
sample_images = data_X[0:self.sample_size]
sample_labels = data_y[0:self.sample_size]
@@ -146,9 +149,10 @@ def train(self, config):
sample = [get_image(sample_file, self.image_size, is_crop=self.is_crop, resize_w=self.output_size, is_grayscale = self.is_grayscale) for sample_file in sample_files]
if (self.is_grayscale):
sample_images = np.array(sample).astype(np.float32)[:, :, :, None]

else:
sample_images = np.array(sample).astype(np.float32)

counter = 1
start_time = time.time()

@@ -160,7 +164,7 @@ def train(self, config):
for epoch in xrange(config.epoch):
if config.dataset == 'mnist':
batch_idxs = min(len(data_X), config.train_size) // config.batch_size
else:
else:
data = glob(os.path.join("./data", config.dataset, "*.jpg"))
batch_idxs = min(len(data), config.train_size) // config.batch_size

@@ -194,7 +198,7 @@ def train(self, config):
_, summary_str = self.sess.run([g_optim, self.g_sum],
feed_dict={ self.z: batch_z, self.y:batch_labels })
self.writer.add_summary(summary_str, counter)

errD_fake = self.d_loss_fake.eval({self.z: batch_z, self.y:batch_labels})
errD_real = self.d_loss_real.eval({self.images: batch_images, self.y:batch_labels})
errG = self.g_loss.eval({self.z: batch_z, self.y:batch_labels})
@@ -213,7 +217,7 @@ def train(self, config):
_, summary_str = self.sess.run([g_optim, self.g_sum],
feed_dict={ self.z: batch_z })
self.writer.add_summary(summary_str, counter)

errD_fake = self.d_loss_fake.eval({self.z: batch_z})
errD_real = self.d_loss_real.eval({self.images: batch_images})
errG = self.g_loss.eval({self.z: batch_z})
@@ -261,14 +265,14 @@ def discriminator(self, image, y=None, reuse=False):
h0 = conv_cond_concat(h0, yb)

h1 = lrelu(self.d_bn1(conv2d(h0, self.df_dim + self.y_dim, name='d_h1_conv')))
h1 = tf.reshape(h1, [self.batch_size, -1])
h1 = tf.reshape(h1, [self.batch_size, -1])
h1 = tf.concat(1, [h1, y])

h2 = lrelu(self.d_bn2(linear(h1, self.dfc_dim, 'd_h2_lin')))
h2 = tf.concat(1, [h2, y])

h3 = linear(h2, 1, 'd_h3_lin')

return tf.nn.sigmoid(h3), h3

def generator(self, z, y=None):
@@ -282,7 +286,7 @@ def generator(self, z, y=None):
self.h0 = tf.reshape(self.z_, [-1, s16, s16, self.gf_dim * 8])
h0 = tf.nn.relu(self.g_bn0(self.h0))

self.h1, self.h1_w, self.h1_b = deconv2d(h0,
self.h1, self.h1_w, self.h1_b = deconv2d(h0,
[self.batch_size, s8, s8, self.gf_dim*4], name='g_h1', with_w=True)
h1 = tf.nn.relu(self.g_bn1(self.h1))

@@ -300,7 +304,7 @@ def generator(self, z, y=None):
return tf.nn.tanh(h4)
else:
s = self.output_size
s2, s4 = int(s/2), int(s/4)
s2, s4 = int(s/2), int(s/4)

# yb = tf.expand_dims(tf.expand_dims(y, 1),2)
yb = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim])
@@ -323,7 +327,7 @@ def sampler(self, z, y=None):
tf.get_variable_scope().reuse_variables()

if not self.y_dim:

s = self.output_size
s2, s4, s8, s16 = int(s/2), int(s/4), int(s/8), int(s/16)

@@ -366,10 +370,11 @@ def sampler(self, z, y=None):

def load_mnist(self):
data_dir = os.path.join("./data", self.dataset_name)

fd = open(os.path.join(data_dir,'train-images-idx3-ubyte'))
loaded = np.fromfile(file=fd,dtype=np.uint8)
trX = loaded[16:].reshape((60000,28,28,1)).astype(np.float)
print trX.shape

fd = open(os.path.join(data_dir,'train-labels-idx1-ubyte'))
loaded = np.fromfile(file=fd,dtype=np.uint8)
@@ -385,22 +390,41 @@ def load_mnist(self):

trY = np.asarray(trY)
teY = np.asarray(teY)

X = np.concatenate((trX, teX), axis=0)
y = np.concatenate((trY, teY), axis=0)

seed = 547
np.random.seed(seed)
np.random.shuffle(X)
np.random.seed(seed)
np.random.shuffle(y)

y_vec = np.zeros((len(y), self.y_dim), dtype=np.float)
for i, label in enumerate(y):
y_vec[i,y[i]] = 1.0

return X/255.,y_vec


def load_mnist_handcrafted(self):
    # Build a training set from six handcrafted binary "five" templates,
    # repeated 5500 times each (33000 samples in total).
    X = []
    for j in range(1, 5500 + 1):
        for i in range(1, 7):
            # Read each template as grayscale and scale it to [0, 1].
            image = cv2.imread("data/mnist_handcrafted/fiveBin" + str(i) + ".png", 0)
            image = image / 255.
            X.append(image)
    X = np.array(X)
    X = X.reshape((X.shape[0], 28, 28, 1)).astype(np.float)
    print X.shape

    # Every sample carries the same one-hot label vector.
    Y = np.array([[0, 0, 0, 0, 1, 0, 0, 0, 0, 0] for _ in range(X.shape[0])]).astype(np.float)
    print Y.shape

    return X, Y

def save(self, checkpoint_dir, step):
model_name = "DCGAN.model"
model_dir = "%s_%s_%s" % (self.dataset_name, self.batch_size, self.output_size)
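For context, here is a minimal standalone sketch of the data-loading idea this commit introduces: read a small set of handcrafted 28x28 template images once, normalise them to [0, 1], tile them to the desired dataset size, and return one-hot labels with the same shapes as load_mnist()'s output. The function name, directory layout, template count, and label index are assumptions for illustration, not part of the commit.

import numpy as np
import cv2

def load_handcrafted(template_dir="data/mnist_handcrafted",  # assumed location
                     n_templates=6, repeats=5500, label_index=4, y_dim=10):
    # Read each grayscale template once and scale it to [0, 1].
    templates = []
    for i in range(1, n_templates + 1):
        img = cv2.imread("%s/fiveBin%d.png" % (template_dir, i), 0)
        templates.append(img.astype(np.float32) / 255.0)
    templates = np.stack(templates).reshape((n_templates, 28, 28, 1))

    # Tile the template block so every template appears `repeats` times,
    # preserving the 1..n_templates ordering within each repetition.
    X = np.tile(templates, (repeats, 1, 1, 1))

    # Every sample carries the same one-hot label.
    Y = np.zeros((X.shape[0], y_dim), dtype=np.float32)
    Y[:, label_index] = 1.0
    return X, Y

Such a loader would plug in where train() calls self.load_mnist_handcrafted(), since the returned (N, 28, 28, 1) images and (N, 10) one-hot labels match what the MNIST branch of train() expects.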
