
MNIST-GAN-JAX

An implementation of InfoGAN using JAX, a highly performant NumPy replacement from Google Research with added capabilities for automatic differentiation and just-in-time compilation of Python code. Flax, also from Google Research, provides the machine-learning layers on top of JAX.
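
Not code from this repo, but a minimal sketch of the two JAX features mentioned above, automatic differentiation via jax.grad and just-in-time compilation via jax.jit (the function and variable names here are illustrative):

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Simple squared error, just to have something to differentiate.
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

# jax.grad differentiates `loss` with respect to its first argument (w),
# and jax.jit compiles the resulting gradient function with XLA.
grad_fn = jax.jit(jax.grad(loss))

w = jnp.zeros(3)
x = jnp.ones((8, 3))
y = jnp.ones(8)
print(grad_fn(w, x, y))  # gradient array of shape (3,)
```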

InfoGAN

InfoGAN is an extension of a standard GAN that adjusts the loss function to maximize the mutual information between a small subset of the Generator's inputs and a new output added to the Discriminator. The idea is that if the Discriminator can predict an aspect of the Generator's input, then that part of the input must correspond to some semantically meaningful property of the image.

In this implementation, only one of the Generator's input dimensions is given this treatment, and it ends up correlating with which digit 0-9 is generated.
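
The sketch below shows one common way to write the objective described above; it is not this repo's exact code, and the names (q_logits, cat_code, NUM_CATEGORIES, lam) are illustrative. The Discriminator's extra head (the "Q" head) outputs logits over the categorical code, and minimizing the cross-entropy between those logits and the code actually fed to the Generator maximizes a variational lower bound on their mutual information:

```python
import jax
import jax.numpy as jnp

NUM_CATEGORIES = 10  # one categorical latent dimension with 10 values

def info_loss(q_logits, cat_code):
    # Cross-entropy between the Q-head's prediction and the actual code;
    # minimizing it maximizes a variational lower bound on mutual information.
    log_probs = jax.nn.log_softmax(q_logits)
    one_hot = jax.nn.one_hot(cat_code, NUM_CATEGORIES)
    return -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))

def generator_loss(d_fake_logits, q_logits, cat_code, lam=1.0):
    # Standard non-saturating GAN term plus the weighted info term.
    gan_term = -jnp.mean(jax.nn.log_sigmoid(d_fake_logits))
    return gan_term + lam * info_loss(q_logits, cat_code)
```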

Loading Pre-trained Model

Example output is provided in the notebook MNIST_GAN_Jax.ipynb. Running all cells up to the "Run Model Below" section will create a directory ./saved_models/mnist_gan. Copy checkpoint_130 into that directory, then run the first cell in "Run Model Below" to run the pretrained model. To train the model further, increase num_epochs above 130.
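
As a rough sketch of what that first cell does, the checkpoint can be restored with Flax's checkpoint utilities; this assumes the notebook saves state via flax.training.checkpoints, which matches the default "checkpoint_" filename prefix seen in checkpoint_130:

```python
from flax.training import checkpoints

# Assumption: the notebook's saved state lives in ./saved_models/mnist_gan.
state = checkpoints.restore_checkpoint(
    ckpt_dir='./saved_models/mnist_gan',
    target=None,  # None returns the raw checkpoint dict; pass a train state to restore into it
    step=130,     # loads checkpoint_130 specifically
)
```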

Example Output

The first row is data from the MNIST dataset, while each of the other rows uses a different value 0-9 for the categorical input to the Generator. The results are not perfect, but there is clearly a strong relationship between that input and which digit is displayed. Note also that the 10 possible categorical inputs end up covering all 10 digits in MNIST.

[Example output image: top row is real MNIST data; each remaining row is generated with one categorical code 0-9]
