An implementation of InfoGAN using JAX, a high-performance NumPy replacement from Google Research with added capabilities for automatic differentiation and just-in-time compilation of Python code. Flax, also from Google Research, is used for the machine learning layers on top of JAX.
InfoGAN is an extension of a standard GAN that adjusts the loss function to maximize the mutual information between a small subset of the Generator's inputs and a new output added to the Discriminator. The idea is that if the Discriminator can predict an aspect of the Generator's input, then that part of the input must correspond to some semantically meaningful property of the image.
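As a rough sketch of what that mutual-information term can look like in JAX (the names `q_logits` and `code_onehot` are placeholders, not identifiers from this repo): the variational bound from the InfoGAN paper reduces, for a categorical code, to a cross-entropy between the Discriminator's auxiliary prediction and the code actually fed to the Generator.

```python
import jax
import jax.numpy as jnp

def info_loss(q_logits, code_onehot):
    # Cross-entropy between the auxiliary head's categorical prediction and
    # the one-hot code fed to the Generator; minimizing it maximizes a
    # variational lower bound on the mutual information I(code; image).
    log_probs = jax.nn.log_softmax(q_logits)
    return -jnp.mean(jnp.sum(code_onehot * log_probs, axis=-1))
```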
In this implementation, only one categorical input to the Generator is analyzed, which ends up correlating with which digit 0-9 is generated.
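A minimal sketch of how the Generator's input can be assembled, assuming a 10-way one-hot code concatenated onto Gaussian noise (the `noise_dim` of 62 follows the InfoGAN paper's MNIST setup and is an assumption, not necessarily what this notebook uses):

```python
import jax
import jax.numpy as jnp

def sample_latent(rng, batch_size, noise_dim=62, num_categories=10):
    # Hypothetical helper: draw continuous noise and a random categorical
    # code, then concatenate them into the Generator's input vector.
    noise_rng, cat_rng = jax.random.split(rng)
    noise = jax.random.normal(noise_rng, (batch_size, noise_dim))
    codes = jax.random.randint(cat_rng, (batch_size,), 0, num_categories)
    onehot = jax.nn.one_hot(codes, num_categories)
    return jnp.concatenate([noise, onehot], axis=-1), codes
```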
Example output is provided in the notebook MNIST_GAN_Jax.ipynb. Running all cells up to the "Run Model Below" section will create the directory ./saved_models/mnist_gan. Copying checkpoint_130 into that directory and then running the first cell in "Run Model Below" will run the pretrained model. To train the model further, increase num_epochs above 130.
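For reference, restoring the checkpoint with Flax's checkpointing utilities might look like the sketch below; the exact pytree layout of the saved training state is an assumption here, so `target=None` is used to return a raw dict.

```python
from flax.training import checkpoints

# Restore the latest checkpoint (e.g. checkpoint_130) from the directory
# the notebook creates; target=None yields the raw parameter dict.
restored = checkpoints.restore_checkpoint(
    ckpt_dir="./saved_models/mnist_gan",
    target=None,
)
```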
The first row is data from the MNIST dataset, while each of the other rows uses a different value 0-9 for the categorical input to the Generator. The result is not perfect, but there is clearly a strong relationship between that input and which digit is displayed. Note also that the 10 possible categorical inputs end up covering all 10 digits in MNIST.
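A hedged sketch of how such a grid can be produced, holding the noise fixed per column and sweeping the categorical code per row (`generator_apply` and `params` stand in for this repo's actual Generator and its trained parameters):

```python
import jax
import jax.numpy as jnp

def category_sweep(generator_apply, params, rng,
                   num_cols=10, noise_dim=62, num_categories=10):
    # One noise vector per column, one categorical code per row, so each
    # row should show the digit that code has come to represent.
    noise = jax.random.normal(rng, (num_cols, noise_dim))
    rows = []
    for c in range(num_categories):
        code = jax.nn.one_hot(jnp.full((num_cols,), c), num_categories)
        z = jnp.concatenate([noise, code], axis=-1)
        rows.append(generator_apply(params, z))
    return jnp.stack(rows)  # shape: (num_categories, num_cols, ...)
```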