diff --git a/docs/mnist_tutorial.ipynb b/docs/mnist_tutorial.ipynb new file mode 100644 index 0000000000..730117a102 --- /dev/null +++ b/docs/mnist_tutorial.ipynb @@ -0,0 +1,451 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "d445c08f", + "metadata": {}, + "source": [ + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/mnist_tutorial.ipynb)\n", + "[![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/mnist_tutorial.ipynb)\n", + "\n", + "# MNIST Tutorial\n", + "\n", + "Welcome to NNX! This tutorial will guide you through building and training a simple convolutional \n", + "neural network (CNN) on the MNIST dataset using the NNX API. NNX is a Python neural network library\n", + "built upon [JAX](https://github.com/google/jax) and currently offered as an experimental module within \n", + "[Flax](https://github.com/google/flax)." + ] + }, + { + "cell_type": "markdown", + "id": "81484083", + "metadata": {}, + "source": [ + "## 1. Install NNX\n", + "\n", + "Since NNX is under active development, we recommend using the latest version from the Flax GitHub repository:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7c2a0f1c", + "metadata": { + "tags": [ + "skip-execution" + ] + }, + "outputs": [], + "source": [ + "# !pip install git+https://github.com/google/flax.git" + ] + }, + { + "cell_type": "markdown", + "id": "0f89f054", + "metadata": {}, + "source": [ + "## 2. Load the MNIST Dataset\n", + "\n", + "First, the MNIST dataset is loaded and prepared for training and testing using \n", + "TensorFlow Datasets. Image values are normalized, the data is shuffled and divided \n", + "into batches, and samples are prefetched to enhance performance."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9b78850e", + "metadata": {}, + "outputs": [], + "source": [ + "import tensorflow_datasets as tfds # TFDS for MNIST\n", + "import tensorflow as tf # TensorFlow operations\n", + "\n", + "tf.random.set_seed(0) # set random seed for reproducibility\n", + "\n", + "num_epochs = 10\n", + "batch_size = 32\n", + "\n", + "train_ds: tf.data.Dataset = tfds.load('mnist', split='train')\n", + "test_ds: tf.data.Dataset = tfds.load('mnist', split='test')\n", + "\n", + "train_ds = train_ds.map(\n", + " lambda sample: {\n", + " 'image': tf.cast(sample['image'], tf.float32) / 255,\n", + " 'label': sample['label'],\n", + " }\n", + ") # normalize train set\n", + "test_ds = test_ds.map(\n", + " lambda sample: {\n", + " 'image': tf.cast(sample['image'], tf.float32) / 255,\n", + " 'label': sample['label'],\n", + " }\n", + ") # normalize test set\n", + "\n", + "# create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from\n", + "train_ds = train_ds.repeat(num_epochs).shuffle(1024)\n", + "# group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency\n", + "train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(1)\n", + "# create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from\n", + "test_ds = test_ds.shuffle(1024)\n", + "# group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency\n", + "test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1)" + ] + }, + { + "cell_type": "markdown", + "id": "1c09fbaf", + "metadata": {}, + "source": [ + "## 3. Define the Network with NNX\n", + "\n", + "Create a convolutional neural network with NNX by subclassing `nnx.Module`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "84a11af8", + "metadata": {}, + "outputs": [], + "source": [ + "from flax.experimental import nnx # NNX API\n", + "from functools import partial\n", + "\n", + "class CNN(nnx.Module):\n", + " \"\"\"A simple CNN model.\"\"\"\n", + "\n", + " def __init__(self, *, rngs: nnx.Rngs):\n", + " self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)\n", + " self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)\n", + " self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))\n", + " self.linear1 = nnx.Linear(3136, 256, rngs=rngs)\n", + " self.linear2 = nnx.Linear(256, 10, rngs=rngs)\n", + "\n", + " def __call__(self, x):\n", + " x = self.avg_pool(nnx.relu(self.conv1(x)))\n", + " x = self.avg_pool(nnx.relu(self.conv2(x)))\n", + " x = x.reshape(x.shape[0], -1) # flatten\n", + " x = nnx.relu(self.linear1(x))\n", + " x = self.linear2(x)\n", + " return x\n", + "\n", + "model = CNN(rngs=nnx.Rngs(0))\n", + "nnx.display(model)" + ] + }, + { + "cell_type": "markdown", + "id": "8db8d961", + "metadata": {}, + "source": [ + "### Run model\n", + "\n", + "Let's put our model to the test! We'll perform a forward pass with arbitrary data and print the results." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "06dbfb6c", + "metadata": { + "outputId": "2c580f41-bf5d-40ec-f1cf-ab7f319a84da" + }, + "outputs": [], + "source": [ + "import jax.numpy as jnp # JAX NumPy\n", + "\n", + "y = model(jnp.ones((1, 28, 28, 1)))\n", + "nnx.display(y)" + ] + }, + { + "cell_type": "markdown", + "id": "06e1cdd5", + "metadata": {}, + "source": [ + "## 4. 
Create Optimizer and Metrics\n", + "\n", + "In NNX, we create an `Optimizer` object to manage the model's parameters and apply gradients during training. `Optimizer` receives the model parameters and an `optax` optimizer that will define the update rules. Additionally, we'll define a `MultiMetric` object to keep track of the `Accuracy` and the `Average` loss." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "079e73b5", + "metadata": {}, + "outputs": [], + "source": [ + "import optax\n", + "\n", + "learning_rate = 0.005\n", + "momentum = 0.9\n", + "\n", + "optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum))\n", + "metrics = nnx.MultiMetric(\n", + " accuracy=nnx.metrics.Accuracy(), \n", + " loss=nnx.metrics.Average('loss'),\n", + ")\n", + "\n", + "nnx.display(optimizer)" + ] + }, + { + "cell_type": "markdown", + "id": "28d3a3bb", + "metadata": {}, + "source": [ + "## 5. Training step\n", + "\n", + "We define a loss function using cross entropy loss (see more details in [`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.softmax_cross_entropy_with_integer_labels)) that our model will optimize over. In addition to the loss, the logits are also returned since they will be used to calculate the accuracy metric during training and testing." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0eb3d49f", + "metadata": {}, + "outputs": [], + "source": [ + "def loss_fn(model: CNN, batch):\n", + " logits = model(batch['image'])\n", + " loss = optax.softmax_cross_entropy_with_integer_labels(\n", + " logits=logits, labels=batch['label']\n", + " ).mean()\n", + " return loss, logits" + ] + }, + { + "cell_type": "markdown", + "id": "601cd558", + "metadata": {}, + "source": [ + "Next, we create the training step function. This function takes the `model` and a data `batch` and does the following:\n", + "\n", + "* Computes the loss, logits and gradients with respect to the loss function using `nnx.value_and_grad`.\n", + "* Updates the training metrics (loss and accuracy) using the loss, logits, and batch labels.\n", + "* Updates model parameters via the optimizer by applying the gradient updates." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "59578de4", + "metadata": {}, + "outputs": [], + "source": [ + "@nnx.jit\n", + "def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):\n", + " \"\"\"Train for a single step.\"\"\"\n", + " grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)\n", + " (loss, logits), grads = grad_fn(model, batch)\n", + " metrics.update(loss=loss, logits=logits, labels=batch['label'])\n", + " optimizer.update(grads)" + ] + }, + { + "cell_type": "markdown", + "id": "a4ab3d87", + "metadata": {}, + "source": [ + "The [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.experimental.nnx/transforms.html#flax.experimental.nnx.jit) decorator traces the `train_step` function for just-in-time compilation with \n", + "[XLA](https://www.tensorflow.org/xla), optimizing performance on \n", + "hardware accelerators. `nnx.jit` is similar to [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit),\n", + "except it can transform functions that contain NNX objects as inputs and outputs.\n", + "\n", + "## 6. Evaluation step\n", + "\n", + "Create a separate function to calculate loss and accuracy metrics for the test batch, since this will be outside the `train_step` function. 
Loss is determined using the `optax.softmax_cross_entropy_with_integer_labels` function, since we're reusing the loss function defined earlier." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14d48cff", + "metadata": {}, + "outputs": [], + "source": [ + "@nnx.jit\n", + "def eval_step(model: CNN, metrics: nnx.MultiMetric, batch):\n", + " loss, logits = loss_fn(model, batch)\n", + " metrics.update(loss=loss, logits=logits, labels=batch['label'])" + ] + }, + { + "cell_type": "markdown", + "id": "3d768c92", + "metadata": {}, + "source": [ + "## 7. Seed randomness\n", + "\n", + "For reproducible dataset shuffling (using `tf.data.Dataset.shuffle`), set the TF random seed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "be7e6548", + "metadata": {}, + "outputs": [], + "source": [ + "tf.random.set_seed(0)" + ] + }, + { + "cell_type": "markdown", + "id": "7f631112", + "metadata": {}, + "source": [ + "## 8. Train and Evaluate\n", + "\n", + "Now we train a model using batches of data for 10 epochs, evaluate its performance \n", + "on the test set after each epoch, and log the training and testing metrics (loss and\n", + "accuracy) throughout the process. Typically this leads to a model with around 99% accuracy." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2a30b37c", + "metadata": { + "outputId": "258a2c76-2c8f-4a9e-d48b-dde57c342a87" + }, + "outputs": [], + "source": [ + "num_steps_per_epoch = train_ds.cardinality().numpy() // num_epochs\n", + "\n", + "metrics_history = {\n", + " 'train_loss': [],\n", + " 'train_accuracy': [],\n", + " 'test_loss': [],\n", + " 'test_accuracy': [],\n", + "}\n", + "\n", + "for step, batch in enumerate(train_ds.as_numpy_iterator()):\n", + " # Run the optimization for one step and make a stateful update to the following:\n", + " # - the train state's model parameters\n", + " # - the optimizer state\n", + " # - the training loss and accuracy batch metrics\n", + " train_step(model, optimizer, metrics, batch)\n", + "\n", + " if (step + 1) % num_steps_per_epoch == 0: # one training epoch has passed\n", + " # Log training metrics\n", + " for metric, value in metrics.compute().items(): # compute metrics\n", + " metrics_history[f'train_{metric}'].append(value) # record metrics\n", + " metrics.reset() # reset metrics for test set\n", + "\n", + " # Compute metrics on the test set after each training epoch\n", + " for test_batch in test_ds.as_numpy_iterator():\n", + " eval_step(model, metrics, test_batch)\n", + "\n", + " # Log test metrics\n", + " for metric, value in metrics.compute().items():\n", + " metrics_history[f'test_{metric}'].append(value)\n", + " metrics.reset() # reset metrics for next training epoch\n", + "\n", + " print(\n", + " f\"train epoch: {(step+1) // num_steps_per_epoch}, \"\n", + " f\"loss: {metrics_history['train_loss'][-1]}, \"\n", + " f\"accuracy: {metrics_history['train_accuracy'][-1] * 100}\"\n", + " )\n", + " print(\n", + " f\"test epoch: {(step+1) // num_steps_per_epoch}, \"\n", + " f\"loss: {metrics_history['test_loss'][-1]}, \"\n", + " f\"accuracy: {metrics_history['test_accuracy'][-1] * 100}\"\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "647f3c63", + "metadata": {}, + "source": [ + "## 9. Visualize Metrics\n", + "\n", + "Use Matplotlib to create plots for loss and accuracy." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38e7769c", + "metadata": { + "outputId": "431a2fcd-44fa-4202-f55a-906555f060ac" + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt # Visualization\n", + "\n", + "# Plot loss and accuracy in subplots\n", + "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))\n", + "ax1.set_title('Loss')\n", + "ax2.set_title('Accuracy')\n", + "for dataset in ('train', 'test'):\n", + " ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss')\n", + " ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy')\n", + "ax1.legend()\n", + "ax2.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "9d7e11c6", + "metadata": {}, + "source": [ + "## 10. Perform inference on test set\n", + "\n", + "Define a jitted inference function, `pred_step`, to generate predictions on the test set using the learned model parameters. This will enable you to visualize test images alongside their predicted labels for a qualitative assessment of model performance." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "58fb4a8c", + "metadata": {}, + "outputs": [], + "source": [ + "@nnx.jit\n", + "def pred_step(model: CNN, batch):\n", + " logits = model(batch['image'])\n", + " return logits.argmax(axis=1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "609a063b", + "metadata": { + "outputId": "1db5a01c-9d70-4f7d-8c0d-0a3ad8252d3e" + }, + "outputs": [], + "source": [ + "test_batch = test_ds.as_numpy_iterator().next()\n", + "pred = pred_step(model, test_batch)\n", + "\n", + "fig, axs = plt.subplots(5, 5, figsize=(12, 12))\n", + "for i, ax in enumerate(axs.flatten()):\n", + " ax.imshow(test_batch['image'][i, ..., 0], cmap='gray')\n", + " ax.set_title(f'label={pred[i]}')\n", + " ax.axis('off')" + ] + }, + { + "cell_type": "markdown", + "id": "feeef7e0", + "metadata": {}, + "source": [ + "Congratulations! You made it to the end of the annotated MNIST example." + ] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,md:myst", + "main_language": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/mnist_tutorial.md b/docs/mnist_tutorial.md new file mode 100644 index 0000000000..113c221940 --- /dev/null +++ b/docs/mnist_tutorial.md @@ -0,0 +1,287 @@ +--- +jupytext: + formats: ipynb,md:myst + main_language: python + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.13.8 +--- + +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/mnist_tutorial.ipynb) +[![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/mnist_tutorial.ipynb) + +# MNIST Tutorial + +Welcome to NNX! This tutorial will guide you through building and training a simple convolutional +neural network (CNN) on the MNIST dataset using the NNX API. NNX is a Python neural network library +built upon [JAX](https://github.com/google/jax) and currently offered as an experimental module within +[Flax](https://github.com/google/flax). + ++++ + +## 1. Install NNX + +Since NNX is under active development, we recommend using the latest version from the Flax GitHub repository: + +```{code-cell} +:tags: [skip-execution] + +# !pip install git+https://github.com/google/flax.git +``` + +## 2. 
Load the MNIST Dataset + +First, the MNIST dataset is loaded and prepared for training and testing using +TensorFlow Datasets. Image values are normalized, the data is shuffled and divided +into batches, and samples are prefetched to enhance performance. + +```{code-cell} +import tensorflow_datasets as tfds # TFDS for MNIST +import tensorflow as tf # TensorFlow operations + +tf.random.set_seed(0) # set random seed for reproducibility + +num_epochs = 10 +batch_size = 32 + +train_ds: tf.data.Dataset = tfds.load('mnist', split='train') +test_ds: tf.data.Dataset = tfds.load('mnist', split='test') + +train_ds = train_ds.map( + lambda sample: { + 'image': tf.cast(sample['image'], tf.float32) / 255, + 'label': sample['label'], + } +) # normalize train set +test_ds = test_ds.map( + lambda sample: { + 'image': tf.cast(sample['image'], tf.float32) / 255, + 'label': sample['label'], + } +) # normalize test set + +# create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from +train_ds = train_ds.repeat(num_epochs).shuffle(1024) +# group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency +train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(1) +# create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from +test_ds = test_ds.shuffle(1024) +# group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency +test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1) +``` + +## 3. Define the Network with NNX + +Create a convolutional neural network with NNX by subclassing `nnx.Module`. + +```{code-cell} +from flax.experimental import nnx # NNX API +from functools import partial + +class CNN(nnx.Module): + """A simple CNN model.""" + + def __init__(self, *, rngs: nnx.Rngs): + self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs) + self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs) + self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2)) + self.linear1 = nnx.Linear(3136, 256, rngs=rngs) + self.linear2 = nnx.Linear(256, 10, rngs=rngs) + + def __call__(self, x): + x = self.avg_pool(nnx.relu(self.conv1(x))) + x = self.avg_pool(nnx.relu(self.conv2(x))) + x = x.reshape(x.shape[0], -1) # flatten + x = nnx.relu(self.linear1(x)) + x = self.linear2(x) + return x + +model = CNN(rngs=nnx.Rngs(0)) +nnx.display(model) +``` + +### Run model + +Let's put our model to the test! We'll perform a forward pass with arbitrary data and print the results. + +```{code-cell} +:outputId: 2c580f41-bf5d-40ec-f1cf-ab7f319a84da + +import jax.numpy as jnp # JAX NumPy + +y = model(jnp.ones((1, 28, 28, 1))) +nnx.display(y) +``` + +## 4. Create Optimizer and Metrics + +In NNX, we create an `Optimizer` object to manage the model's parameters and apply gradients during training. `Optimizer` receives the model parameters and an `optax` optimizer that will define the update rules. Additionally, we'll define a `MultiMetric` object to keep track of the `Accuracy` and the `Average` loss. + +```{code-cell} +import optax + +learning_rate = 0.005 +momentum = 0.9 + +optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum)) +metrics = nnx.MultiMetric( + accuracy=nnx.metrics.Accuracy(), + loss=nnx.metrics.Average('loss'), +) + +nnx.display(optimizer) +``` + +## 5. 
Training step + +We define a loss function using cross entropy loss (see more details in [`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.softmax_cross_entropy_with_integer_labels)) that our model will optimize over. In addition to the loss, the logits are also returned since they will be used to calculate the accuracy metric during training and testing. + +```{code-cell} +def loss_fn(model: CNN, batch): + logits = model(batch['image']) + loss = optax.softmax_cross_entropy_with_integer_labels( + logits=logits, labels=batch['label'] + ).mean() + return loss, logits +``` + +Next, we create the training step function. This function takes the `model` and a data `batch` and does the following: + +* Computes the loss, logits and gradients with respect to the loss function using `nnx.value_and_grad`. +* Updates the training metrics (loss and accuracy) using the loss, logits, and batch labels. +* Updates model parameters via the optimizer by applying the gradient updates. + +```{code-cell} +@nnx.jit +def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch): + """Train for a single step.""" + grad_fn = nnx.value_and_grad(loss_fn, has_aux=True) + (loss, logits), grads = grad_fn(model, batch) + metrics.update(loss=loss, logits=logits, labels=batch['label']) + optimizer.update(grads) +``` + +The [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.experimental.nnx/transforms.html#flax.experimental.nnx.jit) decorator traces the `train_step` function for just-in-time compilation with +[XLA](https://www.tensorflow.org/xla), optimizing performance on +hardware accelerators. `nnx.jit` is similar to [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit), +except it can transform functions that contain NNX objects as inputs and outputs. + +## 6. Evaluation step + +Create a separate function to calculate loss and accuracy metrics for the test batch, since this will be outside the `train_step` function. Loss is determined using the `optax.softmax_cross_entropy_with_integer_labels` function, since we're reusing the loss function defined earlier. + +```{code-cell} +@nnx.jit +def eval_step(model: CNN, metrics: nnx.MultiMetric, batch): + loss, logits = loss_fn(model, batch) + metrics.update(loss=loss, logits=logits, labels=batch['label']) +``` + +## 7. Seed randomness + +For reproducible dataset shuffling (using `tf.data.Dataset.shuffle`), set the TF random seed. + +```{code-cell} +tf.random.set_seed(0) +``` + +## 8. Train and Evaluate + +Now we train a model using batches of data for 10 epochs, evaluate its performance +on the test set after each epoch, and log the training and testing metrics (loss and +accuracy) throughout the process. Typically this leads to a model with around 99% accuracy. 
+ +```{code-cell} +:outputId: 258a2c76-2c8f-4a9e-d48b-dde57c342a87 + +num_steps_per_epoch = train_ds.cardinality().numpy() // num_epochs + +metrics_history = { + 'train_loss': [], + 'train_accuracy': [], + 'test_loss': [], + 'test_accuracy': [], +} + +for step, batch in enumerate(train_ds.as_numpy_iterator()): + # Run the optimization for one step and make a stateful update to the following: + # - the train state's model parameters + # - the optimizer state + # - the training loss and accuracy batch metrics + train_step(model, optimizer, metrics, batch) + + if (step + 1) % num_steps_per_epoch == 0: # one training epoch has passed + # Log training metrics + for metric, value in metrics.compute().items(): # compute metrics + metrics_history[f'train_{metric}'].append(value) # record metrics + metrics.reset() # reset metrics for test set + + # Compute metrics on the test set after each training epoch + for test_batch in test_ds.as_numpy_iterator(): + eval_step(model, metrics, test_batch) + + # Log test metrics + for metric, value in metrics.compute().items(): + metrics_history[f'test_{metric}'].append(value) + metrics.reset() # reset metrics for next training epoch + + print( + f"train epoch: {(step+1) // num_steps_per_epoch}, " + f"loss: {metrics_history['train_loss'][-1]}, " + f"accuracy: {metrics_history['train_accuracy'][-1] * 100}" + ) + print( + f"test epoch: {(step+1) // num_steps_per_epoch}, " + f"loss: {metrics_history['test_loss'][-1]}, " + f"accuracy: {metrics_history['test_accuracy'][-1] * 100}" + ) +``` + +## 9. Visualize Metrics + +Use Matplotlib to create plots for loss and accuracy. + +```{code-cell} +:outputId: 431a2fcd-44fa-4202-f55a-906555f060ac + +import matplotlib.pyplot as plt # Visualization + +# Plot loss and accuracy in subplots +fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5)) +ax1.set_title('Loss') +ax2.set_title('Accuracy') +for dataset in ('train', 'test'): + ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss') + ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy') +ax1.legend() +ax2.legend() +plt.show() +``` + +## 10. Perform inference on test set + +Define a jitted inference function, `pred_step`, to generate predictions on the test set using the learned model parameters. This will enable you to visualize test images alongside their predicted labels for a qualitative assessment of model performance. + +```{code-cell} +@nnx.jit +def pred_step(model: CNN, batch): + logits = model(batch['image']) + return logits.argmax(axis=1) +``` + +```{code-cell} +:outputId: 1db5a01c-9d70-4f7d-8c0d-0a3ad8252d3e + +test_batch = test_ds.as_numpy_iterator().next() +pred = pred_step(model, test_batch) + +fig, axs = plt.subplots(5, 5, figsize=(12, 12)) +for i, ax in enumerate(axs.flatten()): + ax.imshow(test_batch['image'][i, ..., 0], cmap='gray') + ax.set_title(f'label={pred[i]}') + ax.axis('off') +``` + +Congratulations! You made it to the end of the annotated MNIST example.
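+
+As an optional follow-up, you can also display the ground-truth labels next to the
+predictions to spot misclassifications at a glance. The following is a minimal sketch
+(not part of the original example) that reuses the `test_batch` and `pred` variables
+from the cells above:
+
+```{code-cell}
+# Show each test image with its predicted and true label for a quick visual check.
+fig, axs = plt.subplots(5, 5, figsize=(12, 12))
+for i, ax in enumerate(axs.flatten()):
+  ax.imshow(test_batch['image'][i, ..., 0], cmap='gray')
+  ax.set_title(f"pred={pred[i]}, true={test_batch['label'][i]}")
+  ax.axis('off')
+```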