-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
381 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,381 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Recording flax training on MNIST\n", | ||
"\n", | ||
"The tutorial is taken from https://flax.readthedocs.io/en/latest/experimental/nnx/mnist_tutorial.html\n", | ||
"and adapted using the `papyrus`. " | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Imports" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/Users/konstantinnikolaou/Applications/miniconda3/envs/jax/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", | ||
" from .autonotebook import tqdm as notebook_tqdm\n" | ||
] | ||
}, | ||
{ | ||
"ename": "ModuleNotFoundError", | ||
"evalue": "No module named 'neural_state'", | ||
"output_type": "error", | ||
"traceback": [ | ||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | ||
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", | ||
"Cell \u001b[0;32mIn[1], line 12\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01moptax\u001b[39;00m \u001b[38;5;66;03m# Optimizers\u001b[39;00m\n\u001b[1;32m 9\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtensorflow_datasets\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mtfds\u001b[39;00m \u001b[38;5;66;03m# TFDS for MNIST\u001b[39;00m\n\u001b[0;32m---> 12\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mpapyrus\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mpp\u001b[39;00m\n", | ||
"File \u001b[0;32m~/Repositories/papyrus/papyrus/__init__.py:25\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;124;03mpapyrus: a lightweight Python library to record neural learning.\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[38;5;124;03mpapyrus measurements api.\u001b[39;00m\n\u001b[1;32m 23\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m---> 25\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mpapyrus\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m measurements, neural_state, utils\n\u001b[1;32m 27\u001b[0m __all__ \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m 28\u001b[0m measurements\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m,\n\u001b[1;32m 29\u001b[0m utils\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m,\n\u001b[1;32m 30\u001b[0m neural_state\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m,\n\u001b[1;32m 31\u001b[0m ]\n", | ||
"File \u001b[0;32m~/Repositories/papyrus/papyrus/neural_state/__init__.py:24\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;124;03mpapyrus: a lightweight Python library to record neural learning.\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[38;5;124;03m-------\u001b[39;00m\n\u001b[1;32m 22\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m---> 24\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mneural_state\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mneural_state\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m NeuralState\n\u001b[1;32m 25\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mneural_state\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mneural_state_creator\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m NeuralStateCreator\n\u001b[1;32m 27\u001b[0m __all__ \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m 28\u001b[0m NeuralState\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m,\n\u001b[1;32m 29\u001b[0m NeuralStateCreator\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m,\n\u001b[1;32m 30\u001b[0m ]\n", | ||
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'neural_state'" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import jax\n", | ||
"import jax.numpy as jnp # JAX NumPy\n", | ||
"\n", | ||
"from flax import linen as nn # The Linen API\n", | ||
"from flax.training import train_state # Useful dataclass to keep train state\n", | ||
"\n", | ||
"import numpy as np # Ordinary NumPy\n", | ||
"import optax # Optimizers\n", | ||
"import tensorflow_datasets as tfds # TFDS for MNIST\n", | ||
"\n", | ||
"\n", | ||
"import papyrus as pp" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Preparations" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 10, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"class CNN(nn.Module):\n", | ||
" \"\"\"A simple CNN model.\"\"\"\n", | ||
"\n", | ||
" @nn.compact\n", | ||
" def __call__(self, x):\n", | ||
" x = nn.Conv(features=32, kernel_size=(3, 3))(x)\n", | ||
" x = nn.relu(x)\n", | ||
" x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n", | ||
" x = nn.Conv(features=64, kernel_size=(3, 3))(x)\n", | ||
" x = nn.relu(x)\n", | ||
" x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n", | ||
" x = x.reshape((x.shape[0], -1)) # flatten\n", | ||
" x = nn.Dense(features=256)(x)\n", | ||
" x = nn.relu(x)\n", | ||
" x = nn.Dense(features=10)(x)\n", | ||
" return x" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 11, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def cross_entropy_loss(*, logits, labels):\n", | ||
" labels_onehot = jax.nn.one_hot(labels, num_classes=10)\n", | ||
" return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 12, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def compute_metrics(*, logits, labels):\n", | ||
" loss = cross_entropy_loss(logits=logits, labels=labels)\n", | ||
" accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)\n", | ||
" metrics = {\n", | ||
" 'loss': loss,\n", | ||
" 'accuracy': accuracy,\n", | ||
" }\n", | ||
" return metrics" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 9, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def get_datasets():\n", | ||
" \"\"\"Load MNIST train and test datasets into memory.\"\"\"\n", | ||
" ds_builder = tfds.builder('mnist')\n", | ||
" ds_builder.download_and_prepare()\n", | ||
" train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))\n", | ||
" test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))\n", | ||
" train_ds['image'] = jnp.float32(train_ds['image']) / 255.\n", | ||
" test_ds['image'] = jnp.float32(test_ds['image']) / 255.\n", | ||
" return train_ds, test_ds" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 13, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def create_train_state(rng, learning_rate, momentum):\n", | ||
" \"\"\"Creates initial `TrainState`.\"\"\"\n", | ||
" cnn = CNN()\n", | ||
" params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']\n", | ||
" tx = optax.sgd(learning_rate, momentum)\n", | ||
" return train_state.TrainState.create(\n", | ||
" apply_fn=cnn.apply, params=params, tx=tx)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 15, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"@jax.jit\n", | ||
"def train_step(state, batch):\n", | ||
" \"\"\"Train for a single step.\"\"\"\n", | ||
" def loss_fn(params):\n", | ||
" logits = CNN().apply({'params': params}, batch['image'])\n", | ||
" loss = cross_entropy_loss(logits=logits, labels=batch['label'])\n", | ||
" return loss, logits\n", | ||
" grad_fn = jax.value_and_grad(loss_fn, has_aux=True)\n", | ||
" (_, logits), grads = grad_fn(state.params)\n", | ||
" state = state.apply_gradients(grads=grads)\n", | ||
" metrics = compute_metrics(logits=logits, labels=batch['label'])\n", | ||
" return state, metrics" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 16, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"@jax.jit\n", | ||
"def eval_step(params, batch):\n", | ||
" logits = CNN().apply({'params': params}, batch['image'])\n", | ||
" return compute_metrics(logits=logits, labels=batch['label'])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Include Recorders" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"train_recorder = " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 17, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def train_epoch(state, train_ds, batch_size, epoch, rng):\n", | ||
" \"\"\"Train for a single epoch.\"\"\"\n", | ||
" train_ds_size = len(train_ds['image'])\n", | ||
" steps_per_epoch = train_ds_size // batch_size\n", | ||
"\n", | ||
" perms = jax.random.permutation(rng, train_ds_size)\n", | ||
" perms = perms[:steps_per_epoch * batch_size] # skip incomplete batch\n", | ||
" perms = perms.reshape((steps_per_epoch, batch_size))\n", | ||
" batch_metrics = []\n", | ||
" for perm in perms:\n", | ||
" batch = {k: v[perm, ...] for k, v in train_ds.items()}\n", | ||
" state, metrics = train_step(state, batch)\n", | ||
" batch_metrics.append(metrics)\n", | ||
"\n", | ||
" # compute mean of metrics across each batch in epoch.\n", | ||
" batch_metrics_np = jax.device_get(batch_metrics)\n", | ||
" epoch_metrics_np = {\n", | ||
" k: np.mean([metrics[k] for metrics in batch_metrics_np])\n", | ||
" for k in batch_metrics_np[0]}\n", | ||
"\n", | ||
" print('train epoch: %d, loss: %.4f, accuracy: %.2f' % (\n", | ||
" epoch, epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100))\n", | ||
"\n", | ||
" return state" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 18, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def eval_model(params, test_ds):\n", | ||
" metrics = eval_step(params, test_ds)\n", | ||
" metrics = jax.device_get(metrics)\n", | ||
" summary = jax.tree_map(lambda x: x.item(), metrics)\n", | ||
" return summary['loss'], summary['accuracy']" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 19, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"train_ds, test_ds = get_datasets()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 23, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"rng = jax.random.PRNGKey(0)\n", | ||
"rng, init_rng = jax.random.split(rng)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 24, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"learning_rate = 0.1\n", | ||
"momentum = 0.9" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 25, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"state = create_train_state(init_rng, learning_rate, momentum)\n", | ||
"del init_rng # Must not be used anymore." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 26, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"num_epochs = 10\n", | ||
"batch_size = 32" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 27, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"train epoch: 1, loss: 0.1414, accuracy: 95.80\n" | ||
] | ||
}, | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/var/folders/6j/j_pvcb3n4tn448fvkngcwmvr0000gn/T/ipykernel_22597/2774240930.py:4: DeprecationWarning: jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).\n", | ||
" summary = jax.tree_map(lambda x: x.item(), metrics)\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
" test epoch: 1, loss: 0.06, accuracy: 98.19\n", | ||
"train epoch: 2, loss: 0.0489, accuracy: 98.54\n", | ||
" test epoch: 2, loss: 0.05, accuracy: 98.45\n", | ||
"train epoch: 3, loss: 0.0349, accuracy: 98.94\n", | ||
" test epoch: 3, loss: 0.03, accuracy: 99.11\n", | ||
"train epoch: 4, loss: 0.0251, accuracy: 99.24\n", | ||
" test epoch: 4, loss: 0.03, accuracy: 99.08\n", | ||
"train epoch: 5, loss: 0.0223, accuracy: 99.31\n", | ||
" test epoch: 5, loss: 0.04, accuracy: 98.98\n", | ||
"train epoch: 6, loss: 0.0179, accuracy: 99.43\n", | ||
" test epoch: 6, loss: 0.03, accuracy: 99.10\n", | ||
"train epoch: 7, loss: 0.0169, accuracy: 99.47\n", | ||
" test epoch: 7, loss: 0.03, accuracy: 99.17\n", | ||
"train epoch: 8, loss: 0.0136, accuracy: 99.58\n", | ||
" test epoch: 8, loss: 0.04, accuracy: 98.93\n", | ||
"train epoch: 9, loss: 0.0103, accuracy: 99.69\n", | ||
" test epoch: 9, loss: 0.04, accuracy: 99.05\n", | ||
"train epoch: 10, loss: 0.0100, accuracy: 99.69\n", | ||
" test epoch: 10, loss: 0.03, accuracy: 99.25\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"for epoch in range(1, num_epochs + 1):\n", | ||
" # Use a separate PRNG key to permute image data during shuffling\n", | ||
" rng, input_rng = jax.random.split(rng)\n", | ||
" # Run an optimization step over a training batch\n", | ||
" state = train_epoch(state, train_ds, batch_size, epoch, input_rng)\n", | ||
" # Evaluate on the test set after each training epoch \n", | ||
" test_loss, test_accuracy = eval_model(state.params, test_ds)\n", | ||
" print(' test epoch: %d, loss: %.2f, accuracy: %.2f' % (\n", | ||
" epoch, test_loss, test_accuracy * 100))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "jax_gpu", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.9" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |