Skip to content

Commit

Permalink
Create Flax MNIST
Browse files Browse the repository at this point in the history
  • Loading branch information
KonstiNik committed May 14, 2024
1 parent 4c457c4 commit 62773b5
Showing 1 changed file with 381 additions and 0 deletions.
381 changes: 381 additions & 0 deletions examples/mnist_flax.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,381 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Recording flax training on MNIST\n",
"\n",
"The tutorial is taken from https://flax.readthedocs.io/en/latest/experimental/nnx/mnist_tutorial.html\n",
"and adapted using the `papyrus`. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/konstantinnikolaou/Applications/miniconda3/envs/jax/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
},
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'neural_state'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[1], line 12\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01moptax\u001b[39;00m \u001b[38;5;66;03m# Optimizers\u001b[39;00m\n\u001b[1;32m 9\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtensorflow_datasets\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mtfds\u001b[39;00m \u001b[38;5;66;03m# TFDS for MNIST\u001b[39;00m\n\u001b[0;32m---> 12\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mpapyrus\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mpp\u001b[39;00m\n",
"File \u001b[0;32m~/Repositories/papyrus/papyrus/__init__.py:25\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;124;03mpapyrus: a lightweight Python library to record neural learning.\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[38;5;124;03mpapyrus measurements api.\u001b[39;00m\n\u001b[1;32m 23\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m---> 25\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mpapyrus\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m measurements, neural_state, utils\n\u001b[1;32m 27\u001b[0m __all__ \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m 28\u001b[0m measurements\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m,\n\u001b[1;32m 29\u001b[0m utils\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m,\n\u001b[1;32m 30\u001b[0m neural_state\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m,\n\u001b[1;32m 31\u001b[0m ]\n",
"File \u001b[0;32m~/Repositories/papyrus/papyrus/neural_state/__init__.py:24\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;124;03mpapyrus: a lightweight Python library to record neural learning.\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[38;5;124;03m-------\u001b[39;00m\n\u001b[1;32m 22\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m---> 24\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mneural_state\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mneural_state\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m NeuralState\n\u001b[1;32m 25\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mneural_state\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mneural_state_creator\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m NeuralStateCreator\n\u001b[1;32m 27\u001b[0m __all__ \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m 28\u001b[0m NeuralState\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m,\n\u001b[1;32m 29\u001b[0m NeuralStateCreator\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m,\n\u001b[1;32m 30\u001b[0m ]\n",
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'neural_state'"
]
}
],
"source": [
"import jax\n",
"import jax.numpy as jnp # JAX NumPy\n",
"\n",
"from flax import linen as nn # The Linen API\n",
"from flax.training import train_state # Useful dataclass to keep train state\n",
"\n",
"import numpy as np # Ordinary NumPy\n",
"import optax # Optimizers\n",
"import tensorflow_datasets as tfds # TFDS for MNIST\n",
"\n",
"\n",
"import papyrus as pp"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Preparations"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"class CNN(nn.Module):\n",
" \"\"\"A simple CNN model.\"\"\"\n",
"\n",
" @nn.compact\n",
" def __call__(self, x):\n",
" x = nn.Conv(features=32, kernel_size=(3, 3))(x)\n",
" x = nn.relu(x)\n",
" x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n",
" x = nn.Conv(features=64, kernel_size=(3, 3))(x)\n",
" x = nn.relu(x)\n",
" x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n",
" x = x.reshape((x.shape[0], -1)) # flatten\n",
" x = nn.Dense(features=256)(x)\n",
" x = nn.relu(x)\n",
" x = nn.Dense(features=10)(x)\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"def cross_entropy_loss(*, logits, labels):\n",
" labels_onehot = jax.nn.one_hot(labels, num_classes=10)\n",
" return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"def compute_metrics(*, logits, labels):\n",
" loss = cross_entropy_loss(logits=logits, labels=labels)\n",
" accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)\n",
" metrics = {\n",
" 'loss': loss,\n",
" 'accuracy': accuracy,\n",
" }\n",
" return metrics"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def get_datasets():\n",
" \"\"\"Load MNIST train and test datasets into memory.\"\"\"\n",
" ds_builder = tfds.builder('mnist')\n",
" ds_builder.download_and_prepare()\n",
" train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))\n",
" test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))\n",
" train_ds['image'] = jnp.float32(train_ds['image']) / 255.\n",
" test_ds['image'] = jnp.float32(test_ds['image']) / 255.\n",
" return train_ds, test_ds"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"def create_train_state(rng, learning_rate, momentum):\n",
" \"\"\"Creates initial `TrainState`.\"\"\"\n",
" cnn = CNN()\n",
" params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']\n",
" tx = optax.sgd(learning_rate, momentum)\n",
" return train_state.TrainState.create(\n",
" apply_fn=cnn.apply, params=params, tx=tx)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"@jax.jit\n",
"def train_step(state, batch):\n",
" \"\"\"Train for a single step.\"\"\"\n",
" def loss_fn(params):\n",
" logits = CNN().apply({'params': params}, batch['image'])\n",
" loss = cross_entropy_loss(logits=logits, labels=batch['label'])\n",
" return loss, logits\n",
" grad_fn = jax.value_and_grad(loss_fn, has_aux=True)\n",
" (_, logits), grads = grad_fn(state.params)\n",
" state = state.apply_gradients(grads=grads)\n",
" metrics = compute_metrics(logits=logits, labels=batch['label'])\n",
" return state, metrics"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"@jax.jit\n",
"def eval_step(params, batch):\n",
" logits = CNN().apply({'params': params}, batch['image'])\n",
" return compute_metrics(logits=logits, labels=batch['label'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Include Recorders"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_recorder = "
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"def train_epoch(state, train_ds, batch_size, epoch, rng):\n",
" \"\"\"Train for a single epoch.\"\"\"\n",
" train_ds_size = len(train_ds['image'])\n",
" steps_per_epoch = train_ds_size // batch_size\n",
"\n",
" perms = jax.random.permutation(rng, train_ds_size)\n",
" perms = perms[:steps_per_epoch * batch_size] # skip incomplete batch\n",
" perms = perms.reshape((steps_per_epoch, batch_size))\n",
" batch_metrics = []\n",
" for perm in perms:\n",
" batch = {k: v[perm, ...] for k, v in train_ds.items()}\n",
" state, metrics = train_step(state, batch)\n",
" batch_metrics.append(metrics)\n",
"\n",
" # compute mean of metrics across each batch in epoch.\n",
" batch_metrics_np = jax.device_get(batch_metrics)\n",
" epoch_metrics_np = {\n",
" k: np.mean([metrics[k] for metrics in batch_metrics_np])\n",
" for k in batch_metrics_np[0]}\n",
"\n",
" print('train epoch: %d, loss: %.4f, accuracy: %.2f' % (\n",
" epoch, epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100))\n",
"\n",
" return state"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"def eval_model(params, test_ds):\n",
" metrics = eval_step(params, test_ds)\n",
" metrics = jax.device_get(metrics)\n",
" summary = jax.tree_map(lambda x: x.item(), metrics)\n",
" return summary['loss'], summary['accuracy']"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"train_ds, test_ds = get_datasets()"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"rng = jax.random.PRNGKey(0)\n",
"rng, init_rng = jax.random.split(rng)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"learning_rate = 0.1\n",
"momentum = 0.9"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"state = create_train_state(init_rng, learning_rate, momentum)\n",
"del init_rng # Must not be used anymore."
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"num_epochs = 10\n",
"batch_size = 32"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train epoch: 1, loss: 0.1414, accuracy: 95.80\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/var/folders/6j/j_pvcb3n4tn448fvkngcwmvr0000gn/T/ipykernel_22597/2774240930.py:4: DeprecationWarning: jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).\n",
" summary = jax.tree_map(lambda x: x.item(), metrics)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" test epoch: 1, loss: 0.06, accuracy: 98.19\n",
"train epoch: 2, loss: 0.0489, accuracy: 98.54\n",
" test epoch: 2, loss: 0.05, accuracy: 98.45\n",
"train epoch: 3, loss: 0.0349, accuracy: 98.94\n",
" test epoch: 3, loss: 0.03, accuracy: 99.11\n",
"train epoch: 4, loss: 0.0251, accuracy: 99.24\n",
" test epoch: 4, loss: 0.03, accuracy: 99.08\n",
"train epoch: 5, loss: 0.0223, accuracy: 99.31\n",
" test epoch: 5, loss: 0.04, accuracy: 98.98\n",
"train epoch: 6, loss: 0.0179, accuracy: 99.43\n",
" test epoch: 6, loss: 0.03, accuracy: 99.10\n",
"train epoch: 7, loss: 0.0169, accuracy: 99.47\n",
" test epoch: 7, loss: 0.03, accuracy: 99.17\n",
"train epoch: 8, loss: 0.0136, accuracy: 99.58\n",
" test epoch: 8, loss: 0.04, accuracy: 98.93\n",
"train epoch: 9, loss: 0.0103, accuracy: 99.69\n",
" test epoch: 9, loss: 0.04, accuracy: 99.05\n",
"train epoch: 10, loss: 0.0100, accuracy: 99.69\n",
" test epoch: 10, loss: 0.03, accuracy: 99.25\n"
]
}
],
"source": [
"for epoch in range(1, num_epochs + 1):\n",
" # Use a separate PRNG key to permute image data during shuffling\n",
" rng, input_rng = jax.random.split(rng)\n",
" # Run an optimization step over a training batch\n",
" state = train_epoch(state, train_ds, batch_size, epoch, input_rng)\n",
" # Evaluate on the test set after each training epoch \n",
" test_loss, test_accuracy = eval_model(state.params, test_ds)\n",
" print(' test epoch: %d, loss: %.2f, accuracy: %.2f' % (\n",
" epoch, test_loss, test_accuracy * 100))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "jax_gpu",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit 62773b5

Please sign in to comment.