sbavon/Deep-Feature-Consistent-Variational-AutoEncoder-in-Tensorflow

Deep Feature Consistent Variational Autoencoder in Tensorflow

This repository aims to implement the Deep Feature Consistent Variational Autoencoder (DFC-VAE) as described in the paper Deep Feature Consistent Variational Autoencoder. TensorFlow and Python 3 are used for development, and the pre-trained VGG16 model is adapted from VGG in TensorFlow. The training data is the CelebA dataset.

To follow this note, I recommend being familiar with the concepts of Variational Autoencoders and generative models.

Results

[Figure: generated images and random images]

Figure 3: Interpolated image
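One common way to produce interpolations like Figure 3 is to walk in a straight line between the latent codes of two images and decode each intermediate point. The NumPy sketch below covers only the latent-space part; `interpolate_latents` is a hypothetical helper, not a function from this repository.

```python
import numpy as np


def interpolate_latents(z_start, z_end, num_steps):
    """Linearly interpolate between two latent vectors (hypothetical helper).

    Returns an array of shape (num_steps, latent_dim) whose first row is
    z_start and whose last row is z_end.
    """
    z_start = np.asarray(z_start, dtype=np.float64)
    z_end = np.asarray(z_end, dtype=np.float64)
    t = np.linspace(0.0, 1.0, num_steps)[:, None]  # (num_steps, 1) mixing weights
    return (1.0 - t) * z_start + t * z_end


path = interpolate_latents([0.0, 0.0], [1.0, 2.0], 5)
print(path.shape)  # (5, 2)
```

With a trained model, one would encode two faces, use their latent means as `z_start` and `z_end`, decode each row of the returned array, and could then write the decoded frames to a .gif with `imageio.mimsave`, which matches the imageio entry in the Dependencies section.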

Problem Statement

It is well known that one major problem of the plain Variational Autoencoder (Plain-VAE) is that the images it generates are blurry. This is because the plain model's loss function is a pixel-wise comparison between input images and generated images. As a consequence, it is difficult to optimize the model to perform well, because slightly shifting or distorting an image can result in a very high loss. In other words, even when two images look almost identical to human eyes, the computer treats them as very different!
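To make the blurriness argument concrete, here is a small NumPy sketch (not code from this repository) on a synthetic 8x8 image with a sharp vertical edge. A one-pixel shift already produces a large pixel-wise error, and if the model is unsure where the edge belongs, predicting the blurry average of the candidates gives a lower expected pixel-wise loss than committing to either sharp image.

```python
import numpy as np


def mse(a, b):
    """Pixel-wise mean squared error, the kind of loss a plain VAE uses."""
    return float(np.mean((a - b) ** 2))


# An 8x8 grayscale image with a sharp vertical edge starting at column 4.
sharp = np.zeros((8, 8))
sharp[:, 4:] = 1.0

# The same edge shifted right by one pixel: visually almost identical.
shifted = np.zeros((8, 8))
shifted[:, 5:] = 1.0

print(mse(sharp, shifted))  # 0.125: every pixel in column 4 counts as wrong

# Suppose the edge is equally likely to be at column 4 or column 5.
# Compare the expected loss of a sharp guess with that of the blurry average.
blurry = (sharp + shifted) / 2.0
expected_sharp = 0.5 * mse(sharp, sharp) + 0.5 * mse(sharp, shifted)
expected_blurry = 0.5 * mse(blurry, sharp) + 0.5 * mse(blurry, shifted)
print(expected_sharp, expected_blurry)  # 0.0625 0.03125
```

Under a pixel-wise loss the blurry prediction wins, which is why a model trained this way tends to hedge by averaging plausible outputs.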

[Figure: distorted image]

However, DFC-VAE leverages the perceptual loss used in Neural Style Transfer. According to that paper, the internal representations of a convolutional neural network capture the content of the input image. This finding leads to the concept of perceptual loss, which compares the content (the hidden representations) of two images instead of calculating the Euclidean distance between their pixels.
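The structure of a perceptual (feature) loss can be sketched without TensorFlow: pass both images through a fixed convolutional network and compare the activations layer by layer. In this NumPy sketch the frozen, pre-trained VGG16 is replaced by two layers of random, fixed 3x3 filters purely as a stand-in; the helper names, layer count, and sizes are illustrative assumptions, and the VGG layers this repository actually uses are not shown here. `sliding_window_view` requires NumPy 1.20 or later.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def conv3x3_relu(x, kernels):
    """Valid 3x3 convolution (cross-correlation) followed by ReLU.

    x: (C, H, W) feature map; kernels: (K, C, 3, 3).
    Returns a (K, H - 2, W - 2) feature map.
    """
    windows = sliding_window_view(x, (3, 3), axis=(1, 2))  # (C, H-2, W-2, 3, 3)
    out = np.einsum("chwij,kcij->khw", windows, kernels)
    return np.maximum(out, 0.0)


def extract_features(x, layer_kernels):
    """Return the activation of every layer of a small fixed conv net."""
    feats, h = [], x
    for kernels in layer_kernels:
        h = conv3x3_relu(h, kernels)
        feats.append(h)
    return feats


def perceptual_loss(x, y, layer_kernels):
    """Sum over layers of the mean squared difference between feature maps."""
    return float(sum(
        np.mean((fx - fy) ** 2)
        for fx, fy in zip(extract_features(x, layer_kernels),
                          extract_features(y, layer_kernels))
    ))


# Stand-in for a frozen, pre-trained VGG16: two layers of random fixed filters.
rng = np.random.default_rng(0)
layer_kernels = [rng.normal(size=(4, 1, 3, 3)), rng.normal(size=(8, 4, 3, 3))]

x = rng.random((1, 16, 16))  # a 16x16 single-channel "image"
y = rng.random((1, 16, 16))
print(perceptual_loss(x, x, layer_kernels))  # 0.0
print(perceptual_loss(x, y, layer_kernels))  # some positive value
```

In the real model the feature maps would come from selected layers of the pre-trained VGG16 applied to RGB images, and this term would take the place of the pixel-wise comparison in the VAE objective.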

[Figure: model architecture]

Implementation

The solution contains four files:

File Name          Description
dfc_vae_model.py   builds the VAE model, including the encoder, decoder, VGG, loss function, and optimizer
train_dfc_vae.py   trains the DFC_VAE model and tests interpolation
vgg16.py           builds the pre-trained VGG16 model
util.py            contains supporting functions, such as data preprocessing
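A DFC-VAE objective combines a reconstruction term (the perceptual loss) with the usual VAE KL-divergence term. The NumPy sketch below shows the standard reparameterization trick and the closed-form KL divergence between a diagonal Gaussian and N(0, I). It is not the code in dfc_vae_model.py: `total_loss`, `kl_weight`, and `rec_weight` are hypothetical names, and this README does not say which of the two terms its beta value scales.

```python
import numpy as np


def reparameterize(mu, log_var, rng):
    """Sample z = mu + sigma * eps with eps ~ N(0, I) (reparameterization trick)."""
    eps = rng.standard_normal(np.shape(mu))
    return mu + np.exp(0.5 * log_var) * eps


def kl_divergence(mu, log_var):
    """KL(N(mu, sigma^2) || N(0, I)), summed over latent dims, averaged over the batch."""
    kl_per_example = -0.5 * np.sum(1.0 + log_var - mu ** 2 - np.exp(log_var), axis=1)
    return float(np.mean(kl_per_example))


def total_loss(reconstruction_loss, mu, log_var, kl_weight=1.0, rec_weight=1.0):
    """Hypothetical weighted objective: kl_weight * KL + rec_weight * reconstruction."""
    return kl_weight * kl_divergence(mu, log_var) + rec_weight * reconstruction_loss


rng = np.random.default_rng(0)
mu = np.ones((2, 3))        # batch of 2 examples, latent dimension 3
log_var = np.zeros((2, 3))  # sigma = 1 in every dimension
z = reparameterize(mu, log_var, rng)
print(z.shape)                     # (2, 3)
print(kl_divergence(mu, log_var))  # 1.5
```

Sampling through `mu + sigma * eps` keeps the randomness in `eps`, so gradients can flow to `mu` and `log_var` during training; the relative weights of the two terms are exactly the kind of value the Tips section says needs careful tuning.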

Step-by-Step execution

Download and preprocess data

  1. Download the pre-trained VGG weights from VGG in TensorFlow
  2. Download the CelebA dataset from CelebA dataset
  3. Compress the data into a Zip file
  4. Process the images (crop and resize) and convert them to TFRecord format (refer to the write_tfrecord() function in util.py)
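Steps 3 and 4 can be sketched with the standard-library zipfile module and plain NumPy. This is not the repository's write_tfrecord() function: the helper names, the nearest-neighbour resize, and the crop and output sizes in the example are illustrative assumptions, and the TFRecord-writing step is omitted.

```python
import zipfile

import numpy as np


def list_images_in_zip(zip_path, suffix=".jpg"):
    """Return the sorted names of image entries stored inside a Zip archive."""
    with zipfile.ZipFile(zip_path) as zf:
        return sorted(n for n in zf.namelist() if n.lower().endswith(suffix))


def center_crop(img, crop_h, crop_w):
    """Crop the central crop_h x crop_w region of an (H, W, C) image."""
    h, w = img.shape[:2]
    top = (h - crop_h) // 2
    left = (w - crop_w) // 2
    return img[top:top + crop_h, left:left + crop_w]


def resize_nearest(img, out_h, out_w):
    """Nearest-neighbour resize of an (H, W, C) image using only NumPy."""
    h, w = img.shape[:2]
    rows = np.arange(out_h) * h // out_h  # source row for each output row
    cols = np.arange(out_w) * w // out_w  # source column for each output column
    return img[rows[:, None], cols[None, :]]


# Placeholder for one decoded image; the sizes here are illustrative only.
face = np.zeros((218, 178, 3), dtype=np.uint8)
small = resize_nearest(center_crop(face, 178, 178), 64, 64)
print(small.shape)  # (64, 64, 3)
```

Inside a `zipfile.ZipFile` context, `zf.read(name)` returns an entry's raw bytes, which an image decoder such as imageio can turn into an array before cropping and resizing; a smoother interpolation than nearest-neighbour would usually give better-looking training images.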

Train the model

  1. Run train_dfc_vae.py

Dependencies

  • scipy.misc
  • zipfile (used for reading content inside Zip file)
  • imageio (used for generating .gif file)

Tips

  • The beta value is extremely important. You need to tune it carefully to make sure the model produces good results
  • Save generated images in .png format for better image quality

References