From e7b882d0682944524f053c52d3bc7691ef9d1c2b Mon Sep 17 00:00:00 2001
From: Rob Zinkov
Date: Thu, 29 Sep 2022 20:06:43 +0200
Subject: [PATCH] Refactor MNIST example to use Flax

---
 examples/SGMCMC.md | 82 ++++++++++++++++------------------------------
 1 file changed, 28 insertions(+), 54 deletions(-)

diff --git a/examples/SGMCMC.md b/examples/SGMCMC.md
index 94c29224e..4ed1fa1b2 100644
--- a/examples/SGMCMC.md
+++ b/examples/SGMCMC.md
@@ -4,7 +4,7 @@ jupytext:
     extension: .md
     format_name: myst
     format_version: 0.13
-    jupytext_version: 1.14.0
+    jupytext_version: 1.14.1
 kernelspec:
   display_name: Python 3 (ipykernel)
   language: python
@@ -21,10 +21,12 @@ This example is inspired from [this notebook](https://github.com/jeremiecoullon/
 
 ```{code-cell} ipython3
 import jax
-import jax.nn as nn
 import jax.numpy as jnp
 import jax.scipy.stats as stats
+import flax.linen as nn
+import distrax
 import numpy as np
+from functools import partial
 ```
 
 ## Data Preparation
@@ -91,47 +93,46 @@ We will use a very simple (bayesian) neural network in this example: A MLP with
 ```
 
 ```{code-cell} ipython3
+class NN(nn.Module):
+    @nn.compact
+    def __call__(self, x):
+        x = nn.Dense(features=100)(x)
+        x = nn.softmax(x)
+        x = nn.Dense(features=10)(x)
+        return nn.log_softmax(x)
+
+model = NN()
+
+
 @jax.jit
-def predict_fn(parameters, X):
+def predict_fn(params, X):
     """Returns the probability for the image represented by X to be in
     each category given the MLP's weights values.
-
     """
-    activations = X
-    for W, b in parameters[:-1]:
-        outputs = jnp.dot(W, activations) + b
-        activations = nn.softmax(outputs)
+    return model.apply(params, X)
 
-    final_W, final_b = parameters[-1]
-    logits = jnp.dot(final_W, activations) + final_b
-    return nn.log_softmax(logits)
-
 
-def logprior_fn(parameters):
+def logprior_fn(params):
     """Compute the value of the log-prior density function."""
-    logprob = 0.0
-    for W, b in parameters:
-        logprob += jnp.sum(stats.norm.logpdf(W))
-        logprob += jnp.sum(stats.norm.logpdf(b))
-    return logprob
+    leaves, _ = jax.tree_util.tree_flatten(params)
+    flat_params = jnp.concatenate([jnp.ravel(a) for a in leaves])
+    return jnp.sum(distrax.Normal(0.0, 1.0).log_prob(flat_params))
 
 
-def loglikelihood_fn(parameters, data):
+def loglikelihood_fn(params, data):
     """Categorical log-likelihood"""
     X, y = data
-    return jnp.sum(y * predict_fn(parameters, X))
+    return jnp.sum(y * predict_fn(params, X))
 
 
 @jax.jit
-def compute_accuracy(parameters, X, y):
+def compute_accuracy(params, X, y):
     """Compute the accuracy of the model.
 
     To make predictions we take the number that corresponds to the highest
    probability value.
     """
     target_class = jnp.argmax(y, axis=1)
-    predicted_class = jnp.argmax(
-        jax.vmap(predict_fn, in_axes=(None, 0))(parameters, X), axis=1
-    )
+    predicted_class = jnp.argmax(predict_fn(params, X), axis=1)
     return jnp.mean(predicted_class == target_class)
 ```
@@ -139,34 +140,7 @@
 
 Now we need to get initial values for the parameters, and we simply sample from their prior distribution:
 
-```{code-cell} ipython3
-def init_parameters(rng_key, sizes):
-    """
-
-    Parameter
-    ---------
-    rng_key
-        PRNGKey used by JAX to generate pseudo-random numbers
-    sizes
-        List of size for the subsequent layers. The first size must correspond
-        to the size of the input data and the last one to the number of
-        categories.
-
-    """
-    num_layers = len(sizes)
-    keys = jax.random.split(rng_key, num_layers)
-    return [
-        init_layer(rng_key, m, n) for rng_key, m, n in zip(keys, sizes[:-1], sizes[1:])
-    ]
-
-
-def init_layer(rng_key, m, n, scale=1e-2):
-    """Initialize the weights for a single layer."""
-    key_W, key_b = jax.random.split(rng_key)
-    return (scale * jax.random.normal(key_W, (n, m))), scale * jax.random.normal(
-        key_b, (n,)
-    )
-```
++++
 
 We now sample from the model's posteriors. We discard the first 1000 samples until the sampler has reached the typical set, and then take 2000 samples. We record the model's accuracy with the current values every 100 steps.
 
@@ -196,7 +170,7 @@ grad_fn = grad_estimator(logprior_fn, loglikelihood_fn, data_size)
 sgld = blackjax.sgld(grad_fn, lambda _: step_size)
 
 # Set the initial state
-init_positions = init_parameters(rng_key, layer_sizes)
+init_positions = jax.jit(model.init)(rng_key, jnp.ones(X_train.shape[-1]))
 state = sgld.init(init_positions, next(batches))
 
 # Sample from the posterior
@@ -231,7 +205,7 @@ ax.set_xlim([0, num_warmup + num_samples])
 ax.set_ylim([0, 1])
 ax.set_yticks([0.1, 0.3, 0.5, 0.7, 0.9])
 plt.title("Sample from 3-layer MLP posterior (MNIST dataset) with SgLD")
-plt.plot()
+plt.plot();
 ```
 
 ```{code-cell} ipython3