Refactor MNIST example to use Flax
zaxtax authored and rlouf committed Sep 30, 2022
1 parent 58d0076 commit e7b882d
Showing 1 changed file with 28 additions and 54 deletions: examples/SGMCMC.md
@@ -4,7 +4,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
-jupytext_version: 1.14.0
+jupytext_version: 1.14.1
kernelspec:
display_name: Python 3 (ipykernel)
language: python
@@ -21,10 +21,12 @@ This example is inspired by [this notebook](https://github.com/jeremiecoullon/

```{code-cell} ipython3
import jax
-import jax.nn as nn
import jax.numpy as jnp
-import jax.scipy.stats as stats
+import flax.linen as nn
+import distrax
import numpy as np
from functools import partial
```

## Data Preparation
@@ -91,82 +93,54 @@ We will use a very simple (Bayesian) neural network in this example: an MLP with
```

```{code-cell} ipython3
+class NN(nn.Module):
+    @nn.compact
+    def __call__(self, x):
+        x = nn.Dense(features=100)(x)
+        x = nn.softmax(x)
+        x = nn.Dense(features=10)(x)
+        return nn.log_softmax(x)
+
+
+model = NN()


@jax.jit
-def predict_fn(parameters, X):
+def predict_fn(params, X):
    """Returns the probability for the image represented by X
    to be in each category given the MLP's weight values.
    """
-    activations = X
-    for W, b in parameters[:-1]:
-        outputs = jnp.dot(W, activations) + b
-        activations = nn.softmax(outputs)
-
-    final_W, final_b = parameters[-1]
-    logits = jnp.dot(final_W, activations) + final_b
-    return nn.log_softmax(logits)
+    return model.apply(params, X)


-def logprior_fn(parameters):
+def logprior_fn(params):
    """Compute the value of the log-prior density function."""
-    logprob = 0.0
-    for W, b in parameters:
-        logprob += jnp.sum(stats.norm.logpdf(W))
-        logprob += jnp.sum(stats.norm.logpdf(b))
-    return logprob
+    leaves, _ = jax.tree_util.tree_flatten(params)
+    flat_params = jnp.concatenate([jnp.ravel(a) for a in leaves])
+    return jnp.sum(distrax.Normal(0.0, 1.0).log_prob(flat_params))


-def loglikelihood_fn(parameters, data):
+def loglikelihood_fn(params, data):
    """Categorical log-likelihood"""
    X, y = data
-    return jnp.sum(y * predict_fn(parameters, X))
+    return jnp.sum(y * predict_fn(params, X))


@jax.jit
-def compute_accuracy(parameters, X, y):
+def compute_accuracy(params, X, y):
    """Compute the accuracy of the model.
    To make predictions we take the number that corresponds to the highest probability value.
    """
    target_class = jnp.argmax(y, axis=1)
-    predicted_class = jnp.argmax(
-        jax.vmap(predict_fn, in_axes=(None, 0))(parameters, X), axis=1
-    )
+    predicted_class = jnp.argmax(predict_fn(params, X), axis=1)
    return jnp.mean(predicted_class == target_class)
```
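
For readers who are new to Flax, here is a minimal, self-contained sketch of the `flax.linen` `init`/`apply` pattern the refactored code relies on. The 784-dimensional dummy input is an assumption based on MNIST's flattened 28x28 images and the `jnp.ones(X_train.shape[-1])` initialization used further down in the diff.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class NN(nn.Module):
    # Same architecture as in the diff above: one hidden layer of width 100.
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=100)(x)
        x = nn.softmax(x)
        x = nn.Dense(features=10)(x)
        return nn.log_softmax(x)


model = NN()

# `init` builds the parameter pytree (a nested dict of kernels and biases)
# from a dummy input; `apply` evaluates the network for a given pytree.
params = model.init(jax.random.PRNGKey(0), jnp.ones(784))
log_probs = model.apply(params, jnp.ones(784))

# This pytree is what blackjax then carries around as the chain's position.
print(jax.tree_util.tree_map(jnp.shape, params))
```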

## Sample From the Posterior Distribution of the Perceptron's Weights

Now we need to get initial values for the parameters, and we simply sample from their prior distribution:

```{code-cell} ipython3
-def init_parameters(rng_key, sizes):
-    """
-    Parameter
-    ----------
-    rng_key
-        PRNGKey used by JAX to generate pseudo-random numbers
-    sizes
-        List of size for the subsequent layers. The first size must correspond
-        to the size of the input data and the last one to the number of
-        categories.
-    """
-    num_layers = len(sizes)
-    keys = jax.random.split(rng_key, num_layers)
-    return [
-        init_layer(rng_key, m, n) for rng_key, m, n in zip(keys, sizes[:-1], sizes[1:])
-    ]
-
-
-def init_layer(rng_key, m, n, scale=1e-2):
-    """Initialize the weights for a single layer."""
-    key_W, key_b = jax.random.split(rng_key)
-    return (scale * jax.random.normal(key_W, (n, m))), scale * jax.random.normal(
-        key_b, (n,)
-    )
```
+++
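
The refactored example now builds its initial position with `model.init` (see the hunk below) rather than by sampling the weights from the prior. If you wanted to keep the old behaviour with the new parameter format, one hypothetical helper (the name `sample_from_prior` is ours, not the commit's) could reuse the pytree structure that `model.init` returns:

```python
import jax
import jax.numpy as jnp


def sample_from_prior(rng_key, params_template):
    """Draw every leaf of the parameter pytree from the N(0, 1) prior."""
    leaves, treedef = jax.tree_util.tree_flatten(params_template)
    keys = jax.random.split(rng_key, len(leaves))
    samples = [jax.random.normal(k, leaf.shape) for k, leaf in zip(keys, leaves)]
    return jax.tree_util.tree_unflatten(treedef, samples)


# e.g. init_positions = sample_from_prior(rng_key, model.init(rng_key, jnp.ones(784)))
```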

We now sample from the model's posteriors. We discard the first 1000 samples until the sampler has reached the typical set, and then take 2000 samples. We record the model's accuracy with the current values every 100 steps.
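
As a reminder of what each transition computes: SGLD takes a half-step along a minibatch estimate of the gradient of the log-posterior and adds Gaussian noise whose variance equals the step size (Welling and Teh, 2011). A schematic version of the update that `blackjax.sgld` performs is sketched below, written against a generic parameter pytree; `grad_estimator_fn` stands in for the gradient estimator built from `logprior_fn` and `loglikelihood_fn`, and none of these names are blackjax's API.

```python
import jax
import jax.numpy as jnp


def sgld_step(rng_key, position, minibatch, grad_estimator_fn, step_size):
    """One stochastic-gradient Langevin update on a pytree of parameters."""
    grads = grad_estimator_fn(position, minibatch)
    leaves, treedef = jax.tree_util.tree_flatten(position)
    keys = jax.random.split(rng_key, len(leaves))
    noise = jax.tree_util.tree_unflatten(
        treedef, [jax.random.normal(k, leaf.shape) for k, leaf in zip(keys, leaves)]
    )
    # theta <- theta + (step_size / 2) * grad_estimate + Normal(0, step_size)
    return jax.tree_util.tree_map(
        lambda p, g, n: p + 0.5 * step_size * g + jnp.sqrt(step_size) * n,
        position, grads, noise,
    )
```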

@@ -196,7 +170,7 @@ grad_fn = grad_estimator(logprior_fn, loglikelihood_fn, data_size)
sgld = blackjax.sgld(grad_fn, lambda _: step_size)
# Set the initial state
-init_positions = init_parameters(rng_key, layer_sizes)
+init_positions = jax.jit(model.init)(rng_key, jnp.ones(X_train.shape[-1]))
state = sgld.init(init_positions, next(batches))
# Sample from the posterior
@@ -231,7 +205,7 @@ ax.set_xlim([0, num_warmup + num_samples])
ax.set_ylim([0, 1])
ax.set_yticks([0.1, 0.3, 0.5, 0.7, 0.9])
plt.title("Sample from 3-layer MLP posterior (MNIST dataset) with SgLD")
-plt.plot()
+plt.plot();
```

