Skip to content

Commit

Permalink
Added support for grayscale images
Browse files Browse the repository at this point in the history
  • Loading branch information
larry-he committed Aug 20, 2016
1 parent da53549 commit cb0761e
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 14 deletions.
3 changes: 2 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")
flags.DEFINE_integer("image_size", 108, "The size of image to use (will be center cropped) [108]")
flags.DEFINE_integer("output_size", 64, "The size of the output images to produce [64]")
flags.DEFINE_integer("c_dim", 3, "Dimension of image color. [3]")
flags.DEFINE_string("dataset", "celebA", "The name of dataset [celebA, mnist, lsun]")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]")
Expand All @@ -36,7 +37,7 @@ def main(_):
dcgan = DCGAN(sess, image_size=FLAGS.image_size, batch_size=FLAGS.batch_size, y_dim=10, output_size=28, c_dim=1,
dataset_name=FLAGS.dataset, is_crop=FLAGS.is_crop, checkpoint_dir=FLAGS.checkpoint_dir)
else:
dcgan = DCGAN(sess, image_size=FLAGS.image_size, batch_size=FLAGS.batch_size, output_size=FLAGS.output_size,
dcgan = DCGAN(sess, image_size=FLAGS.image_size, batch_size=FLAGS.batch_size, output_size=FLAGS.output_size, c_dim=FLAGS.c_dim
dataset_name=FLAGS.dataset, is_crop=FLAGS.is_crop, checkpoint_dir=FLAGS.checkpoint_dir)

if FLAGS.is_train:
Expand Down
22 changes: 13 additions & 9 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,13 @@ def __init__(self, sess, image_size=108, is_crop=True,
z_dim: (optional) Dimension of dim for Z. [100]
gf_dim: (optional) Dimension of gen filters in first conv layer. [64]
df_dim: (optional) Dimension of discrim filters in first conv layer. [64]
gfc_dim: (optional) Dimension of gen untis for for fully connected layer. [1024]
gfc_dim: (optional) Dimension of gen units for for fully connected layer. [1024]
dfc_dim: (optional) Dimension of discrim units for fully connected layer. [1024]
c_dim: (optional) Dimension of image color. [3]
Note that changing c_dim from its default value may require changing the
parameters of scipy.misc.imread() in the function imread() in utils.py and
reshaping batch_images in the train() function in model.py, among other things.
c_dim: (optional) Dimension of image color. For grayscale input, set to 1. [3]
"""
self.sess = sess
self.is_crop = is_crop
self.is_grayscale = (c_dim == 1)
self.batch_size = batch_size
self.image_size = image_size
self.sample_size = sample_size
Expand Down Expand Up @@ -145,8 +143,11 @@ def train(self, config):
sample_labels = data_y[0:self.sample_size]
else:
sample_files = data[0:self.sample_size]
sample = [get_image(sample_file, self.image_size, is_crop=self.is_crop, resize_w=self.output_size) for sample_file in sample_files]
sample_images = np.array(sample).astype(np.float32)
sample = [get_image(sample_file, self.image_size, is_crop=self.is_crop, resize_w=self.output_size, is_grayscale = self.is_grayscale) for sample_file in sample_files]
if (self.is_grayscale):
sample_images = np.array(sample).astype(np.float32)[:, :, :, None]
else:
sample_images = np.array(sample).astype(np.float32)

counter = 1
start_time = time.time()
Expand All @@ -169,8 +170,11 @@ def train(self, config):
batch_labels = data_y[idx*config.batch_size:(idx+1)*config.batch_size]
else:
batch_files = data[idx*config.batch_size:(idx+1)*config.batch_size]
batch = [get_image(batch_file, self.image_size, is_crop=self.is_crop, resize_w=self.output_size) for batch_file in batch_files]
batch_images = np.array(batch).astype(np.float32)
batch = [get_image(batch_file, self.image_size, is_crop=self.is_crop, resize_w=self.output_size, is_grayscale = self.is_grayscale) for batch_file in batch_files]
if (self.is_grayscale):
batch_images = np.array(batch).astype(np.float32)[:, :, :, None]
else:
batch_images = np.array(batch).astype(np.float32)

batch_z = np.random.uniform(-1, 1, [config.batch_size, self.z_dim]) \
.astype(np.float32)
Expand Down
11 changes: 7 additions & 4 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@

get_stddev = lambda x, k_h, k_w: 1/math.sqrt(k_w*k_h*x.get_shape()[-1])

def get_image(image_path, image_size, is_crop=True, resize_w=64):
return transform(imread(image_path), image_size, is_crop, resize_w)
def get_image(image_path, image_size, is_crop=True, resize_w=64, is_grayscale = False):
return transform(imread(image_path, is_grayscale), image_size, is_crop, resize_w)

def save_images(images, size, image_path):
return imsave(inverse_transform(images), size, image_path)

def imread(path):
return scipy.misc.imread(path).astype(np.float)
def imread(path, is_grayscale = False):
if (is_grayscale):
return scipy.misc.imread(path, flatten = True).astype(np.float)
else:
return scipy.misc.imread(path).astype(np.float)

def merge_images(images, size):
return inverse_transform(images)
Expand Down

0 comments on commit cb0761e

Please sign in to comment.