diff --git a/docs/guides/flax_fundamentals/flax_basics.ipynb b/docs/guides/flax_fundamentals/flax_basics.ipynb index 2a2198121..b07cc4129 100644 --- a/docs/guides/flax_fundamentals/flax_basics.ipynb +++ b/docs/guides/flax_fundamentals/flax_basics.ipynb @@ -2,9 +2,7 @@ "cells": [ { "cell_type": "markdown", - "metadata": { - "id": "yf-nWLh0naJi" - }, + "metadata": {}, "source": [ "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/guides/flax_fundamentals/flax_basics.ipynb)\n", "[![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/guides/flax_fundamentals/flax_basics.ipynb)\n", @@ -22,9 +20,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "KyANAaZtbs86" - }, + "metadata": {}, "source": [ "## Setting up our environment\n", "\n", @@ -35,7 +31,6 @@ "cell_type": "code", "execution_count": 1, "metadata": { - "id": "qdrEVv9tinJn", "outputId": "e30aa464-fa52-4f35-df96-716c68a4b3ee", "tags": [ "skip-execution" @@ -61,9 +56,7 @@ { "cell_type": "code", "execution_count": 2, - "metadata": { - "id": "kN6bZDaReZO2" - }, + "metadata": {}, "outputs": [], "source": [ "import jax\n", @@ -75,9 +68,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "pCCwAbOLiscA" - }, + "metadata": {}, "source": [ "## Linear regression with Flax\n", "\n", @@ -91,9 +82,7 @@ { "cell_type": "code", "execution_count": 3, - "metadata": { - "id": "zWX2zEtphT4Y" - }, + "metadata": {}, "outputs": [], "source": [ "# We create one dense layer instance (taking 'features' parameter as input)\n", @@ -102,9 +91,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "UmzP1QoQYAAN" - }, + "metadata": {}, "source": [ "Layers (and models in general, we'll use that word from now on) are subclasses of the `linen.Module` class.\n", "\n", @@ -117,7 +104,6 @@ "cell_type": "code", "execution_count": 4, "metadata": { - "id": 
"K529lhzeYtl8", "outputId": "06feb9d2-db50-4f41-c169-6df4336f43a5" }, "outputs": [ @@ -155,9 +141,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "NH7Y9xMEewmO" - }, + "metadata": {}, "source": [ "*Note: JAX and Flax, like NumPy, are row-based systems, meaning that vectors are represented as row vectors and not column vectors. This can be seen in the shape of the kernel here.*\n", "\n", @@ -171,9 +155,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "M1qo9M3_naJo" - }, + "metadata": {}, "source": [ "To conduct a forward pass with the model with a given set of parameters (which are never stored with the model), we just use the `apply` method by providing it the parameters to use as well as the input:" ] @@ -182,7 +164,6 @@ "cell_type": "code", "execution_count": 6, "metadata": { - "id": "J8ietJecWiuK", "outputId": "7bbe6bb4-94d5-4574-fbb5-aa0fcd1c84ae" }, "outputs": [ @@ -205,9 +186,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "lVsjgYzuSBGL" - }, + "metadata": {}, "source": [ "### Gradient descent\n", "\n", @@ -222,7 +201,6 @@ "cell_type": "code", "execution_count": 7, "metadata": { - "id": "bFIiMnL4dl-e", "outputId": "6eae59dc-0632-4f53-eac8-c22a7c646a52" }, "outputs": [ @@ -257,9 +235,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "ZHkioicCiUbx" - }, + "metadata": {}, "source": [ "We copy the same training loop that we used in the JAX pytree linear regression example with `jax.value_and_grad()`, but here we can use `model.apply()` instead of having to define our own feed-forward function (`predict_pytree()` in the [JAX example](https://flax.readthedocs.io/en/latest/guides/jax_for_the_impatient.html#linear-regression-with-pytrees))." 
] @@ -267,9 +243,7 @@ { "cell_type": "code", "execution_count": 8, - "metadata": { - "id": "JqJaVc7BeNyT" - }, + "metadata": {}, "outputs": [], "source": [ "# Same as JAX version but using model.apply().\n", @@ -285,9 +259,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "wGKru__mi15v" - }, + "metadata": {}, "source": [ "And finally perform the gradient descent." ] @@ -296,7 +268,6 @@ "cell_type": "code", "execution_count": 9, "metadata": { - "id": "ePEl1ndse0Jq", "outputId": "50d975b3-4706-4d8a-c4b8-2629ab8e3ac4" }, "outputs": [ @@ -340,9 +311,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "zqEnJ9Poyb6q" - }, + "metadata": {}, "source": [ "### Optimizing with Optax\n", "\n", @@ -372,9 +341,7 @@ { "cell_type": "code", "execution_count": 10, - "metadata": { - "id": "Ce77uDJx1bUF" - }, + "metadata": {}, "outputs": [], "source": [ "import optax\n", @@ -387,7 +354,6 @@ "cell_type": "code", "execution_count": 11, "metadata": { - "id": "PTSv0vx13xPO", "outputId": "eec0c096-1d9e-4b3c-f8e5-942ee63828ec" }, "outputs": [ @@ -420,9 +386,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "0eAPPwtpXYu7" - }, + "metadata": {}, "source": [ "### Serializing the result\n", "\n", @@ -433,7 +397,6 @@ "cell_type": "code", "execution_count": 13, "metadata": { - "id": "BiUPRU93XnAZ", "outputId": "b97e7d83-3e40-4a80-b1fe-1f6ceff30a0c" }, "outputs": [ @@ -479,9 +442,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "eielPo2KZByd" - }, + "metadata": {}, "source": [ "To load the model back, you'll need to use a template of the model parameter structure, like the one you would get from the model initialization. Here, we use the previously generated `params` as a template. 
Note that this will produce a new variable structure, and not mutate in-place.\n", "\n", @@ -492,7 +453,6 @@ "cell_type": "code", "execution_count": 14, "metadata": { - "id": "MOhoBDCOYYJ5", "outputId": "13acc4e1-8757-4554-e2c8-d594ba6e67dc" }, "outputs": [ @@ -531,9 +491,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "8mNu8nuOhDC5" - }, + "metadata": {}, "source": [ "## Defining your own models\n", "\n", @@ -544,9 +502,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "1sllHAdRlpmQ" - }, + "metadata": {}, "source": [ "### Module basics\n", "\n", @@ -557,7 +513,6 @@ "cell_type": "code", "execution_count": 17, "metadata": { - "id": "vbfrfbkxgPhg", "outputId": "b59c679c-d164-4fd6-92db-b50f0d310ec3" }, "outputs": [ @@ -610,9 +565,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "DDITIjXitEZl" - }, + "metadata": {}, "source": [ "As we can see, a `nn.Module` subclass is made of:\n", "\n", @@ -630,7 +583,6 @@ "cell_type": "code", "execution_count": 19, "metadata": { - "id": "DEYrVA6dnaJu", "outputId": "4af16ec5-b52a-43b0-fc47-1f8ab25e7058" }, "outputs": [ @@ -651,9 +603,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "I__UrmShnaJu" - }, + "metadata": {}, "source": [ "Since here we have a very simple model, we could have used an alternative (but equivalent) way of declaring the submodules inline in the `__call__` using the `@nn.compact` annotation like so:" ] @@ -662,7 +612,6 @@ "cell_type": "code", "execution_count": 20, "metadata": { - "id": "ZTCbdpQ4suSK", "outputId": "183a74ef-f54e-4848-99bf-fee4c174ba6d" }, "outputs": [ @@ -712,9 +661,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "es7YHjgexT-L" - }, + "metadata": {}, "source": [ "There are, however, a few differences you should be aware of between the two declaration modes:\n", "\n", @@ -725,9 +672,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "-ykceROJyp7W" - }, + "metadata": {}, "source": [ "### Module parameters\n", "\n", @@ -738,7 +683,6 @@ 
"cell_type": "code", "execution_count": 21, "metadata": { - "id": "wK371Pt_vVfR", "outputId": "83b5fea4-071e-4ea0-8fa8-610e69fb5fd5" }, "outputs": [ @@ -793,9 +737,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "MKyhfzVpzC94" - }, + "metadata": {}, "source": [ "Here, we see how to both declare and assign a parameter to the model using the `self.param` method. It takes as input `(name, init_fn, *init_args)` :\n", "\n", @@ -808,9 +750,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "QmSpxyqLDr58" - }, + "metadata": {}, "source": [ "### Variables and collections of variables\n", "\n", @@ -828,7 +768,6 @@ "cell_type": "code", "execution_count": 22, "metadata": { - "id": "J6_tR-nPzB1i", "outputId": "75465fd6-cdc8-497c-a3ec-7f709b5dde7a" }, "outputs": [ @@ -883,9 +822,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "5OHBbMJng3ic" - }, + "metadata": {}, "source": [ "Here, `updated_state` returns only the state variables that are being mutated by the model while applying it on data. To update the variables and get the new parameters of the model, we can use the following pattern:" ] @@ -894,7 +831,6 @@ "cell_type": "code", "execution_count": 23, "metadata": { - "id": "IbTsCAvZcdBy", "outputId": "09a8bdd1-eaf8-401a-cf7c-386a7a5aa87b" }, "outputs": [ @@ -934,9 +870,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "GuUSOSKegKIM" - }, + "metadata": {}, "source": [ "From this simplified example, you should be able to derive a full BatchNorm implementation, or any layer involving a state. 
To finish, let's add an optimizer to see how to play with both parameters updated by an optimizer and state variables.\n", "\n", @@ -947,7 +881,6 @@ "cell_type": "code", "execution_count": 29, "metadata": { - "id": "TUgAbUPpnaJw", "outputId": "0906fbab-b866-4956-d231-b1374415d448" }, "outputs": [ @@ -1004,9 +937,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "eWUmx5EjtWge" - }, + "metadata": {}, "source": [ "Note that the above function has a quite verbose signature and it would not actually\n", "work with `jax.jit()` because the function arguments are not \"valid JAX types\".\n", @@ -1016,9 +947,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "_GL0PsCwnaJw" - }, + "metadata": {}, "source": [ "### Exporting to Tensorflow's SavedModel with jax2tf\n", "\n", diff --git a/docs/guides/flax_fundamentals/flax_basics.md b/docs/guides/flax_fundamentals/flax_basics.md index d349efc45..437d03c63 100644 --- a/docs/guides/flax_fundamentals/flax_basics.md +++ b/docs/guides/flax_fundamentals/flax_basics.md @@ -8,8 +8,6 @@ jupytext: jupytext_version: 1.13.8 --- -+++ {"id": "yf-nWLh0naJi"} - [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/guides/flax_fundamentals/flax_basics.ipynb) [![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/guides/flax_fundamentals/flax_basics.ipynb) @@ -23,14 +21,13 @@ This notebook will walk you through the following workflow: * Serialization of parameters and other objects. * Creating your own models and managing state. -+++ {"id": "KyANAaZtbs86"} ++++ ## Setting up our environment Here we provide the code needed to set up the environment for our notebook. ```{code-cell} -:id: qdrEVv9tinJn :outputId: e30aa464-fa52-4f35-df96-716c68a4b3ee :tags: [skip-execution] @@ -41,8 +38,6 @@ Here we provide the code needed to set up the environment for our notebook. 
``` ```{code-cell} -:id: kN6bZDaReZO2 - import jax from typing import Any, Callable, Sequence from jax import random, numpy as jnp @@ -50,8 +45,6 @@ import flax from flax import linen as nn ``` -+++ {"id": "pCCwAbOLiscA"} - ## Linear regression with Flax In the previous *JAX for the impatient* notebook, we finished up with a linear regression example. As we know, linear regression can also be written as a single dense neural network layer, which we will show in the following so that we can compare how it's done. @@ -61,14 +54,10 @@ A dense layer is a layer that has a kernel parameter $W\in\mathcal{M}_{m,n}(\mat This dense layer is already provided by Flax in the `flax.linen` module (here imported as `nn`). ```{code-cell} -:id: zWX2zEtphT4Y - # We create one dense layer instance (taking 'features' parameter as input) model = nn.Dense(features=5) ``` -+++ {"id": "UmzP1QoQYAAN"} - Layers (and models in general, we'll use that word from now on) are subclasses of the `linen.Module` class. ### Model parameters & initialization @@ -76,7 +65,6 @@ Layers (and models in general, we'll use that word from now on) are subclasses o Parameters are not stored with the models themselves. You need to initialize parameters by calling the `init` function, using a PRNGKey and dummy input data. ```{code-cell} -:id: K529lhzeYtl8 :outputId: 06feb9d2-db50-4f41-c169-6df4336f43a5 key1, key2 = random.split(random.key(0)) @@ -85,8 +73,6 @@ params = model.init(key2, x) # Initialization call jax.tree_util.tree_map(lambda x: x.shape, params) # Checking output shapes ``` -+++ {"id": "NH7Y9xMEewmO"} - *Note: JAX and Flax, like NumPy, are row-based systems, meaning that vectors are represented as row vectors and not column vectors. This can be seen in the shape of the kernel here.* The result is what we expect: bias and kernel parameters of the correct size. Under the hood: @@ -96,19 +82,16 @@ The result is what we expect: bias and kernel parameters of the correct size. 
Un * Initialization functions are called to generate the initial set of parameters that the model will use. Those are functions that take as arguments `(PRNG Key, shape, dtype)` and return an Array of shape `shape`. * The init function returns the initialized set of parameters (you can also get the output of the forward pass on the dummy input with the same syntax by using the `init_with_output` method instead of `init`. -+++ {"id": "M1qo9M3_naJo"} ++++ To conduct a forward pass with the model with a given set of parameters (which are never stored with the model), we just use the `apply` method by providing it the parameters to use as well as the input: ```{code-cell} -:id: J8ietJecWiuK :outputId: 7bbe6bb4-94d5-4574-fbb5-aa0fcd1c84ae model.apply(params, x) ``` -+++ {"id": "lVsjgYzuSBGL"} - ### Gradient descent If you jumped here directly without going through the JAX part, here is the linear regression formulation we're going to use: from a set of data points $\{(x_i,y_i), i\in \{1,\ldots, k\}, x_i\in\mathbb{R}^n,y_i\in\mathbb{R}^m\}$, we try to find a set of parameters $W\in \mathcal{M}_{m,n}(\mathbb{R}), b\in\mathbb{R}^m$ such that the function $f_{W,b}(x)=Wx+b$ minimizes the mean squared error: @@ -118,7 +101,6 @@ $$\mathcal{L}(W,b)\rightarrow\frac{1}{k}\sum_{i=1}^{k} \frac{1}{2}\|y_i-f_{W,b}( Here, we see that the tuple $(W,b)$ matches the parameters of the Dense layer. We'll perform gradient descent using those. Let's first generate the fake data we'll use. The data is exactly the same as in the JAX part's linear regression pytree example. ```{code-cell} -:id: bFIiMnL4dl-e :outputId: 6eae59dc-0632-4f53-eac8-c22a7c646a52 # Set problem dimensions. 
@@ -141,13 +123,9 @@ y_samples = jnp.dot(x_samples, W) + b + 0.1 * random.normal(key_noise,(n_samples print('x shape:', x_samples.shape, '; y shape:', y_samples.shape) ``` -+++ {"id": "ZHkioicCiUbx"} - We copy the same training loop that we used in the JAX pytree linear regression example with `jax.value_and_grad()`, but here we can use `model.apply()` instead of having to define our own feed-forward function (`predict_pytree()` in the [JAX example](https://flax.readthedocs.io/en/latest/guides/jax_for_the_impatient.html#linear-regression-with-pytrees)). ```{code-cell} -:id: JqJaVc7BeNyT - # Same as JAX version but using model.apply(). @jax.jit def mse(params, x_batched, y_batched): @@ -159,12 +137,9 @@ def mse(params, x_batched, y_batched): return jnp.mean(jax.vmap(squared_error)(x_batched,y_batched), axis=0) ``` -+++ {"id": "wGKru__mi15v"} - And finally perform the gradient descent. ```{code-cell} -:id: ePEl1ndse0Jq :outputId: 50d975b3-4706-4d8a-c4b8-2629ab8e3ac4 learning_rate = 0.3 # Gradient step size. @@ -185,8 +160,6 @@ for i in range(101): print(f'Loss step {i}: ', loss_val) ``` -+++ {"id": "zqEnJ9Poyb6q"} - ### Optimizing with Optax Flax used to use its own `flax.optim` package for optimization, but with @@ -212,8 +185,6 @@ to the [official documentation](https://optax.readthedocs.io/en/latest/). ```{code-cell} -:id: Ce77uDJx1bUF - import optax tx = optax.adam(learning_rate=learning_rate) opt_state = tx.init(params) @@ -221,7 +192,6 @@ loss_grad_fn = jax.value_and_grad(mse) ``` ```{code-cell} -:id: PTSv0vx13xPO :outputId: eec0c096-1d9e-4b3c-f8e5-942ee63828ec for i in range(101): @@ -232,14 +202,11 @@ for i in range(101): print('Loss step {}: '.format(i), loss_val) ``` -+++ {"id": "0eAPPwtpXYu7"} - ### Serializing the result Now that we're happy with the result of our training, we might want to save the model parameters to load them back later. Flax provides a serialization package to enable you to do that. 
```{code-cell} -:id: BiUPRU93XnAZ :outputId: b97e7d83-3e40-4a80-b1fe-1f6ceff30a0c from flax import serialization @@ -251,35 +218,29 @@ print('Bytes output') print(bytes_output) ``` -+++ {"id": "eielPo2KZByd"} - To load the model back, you'll need to use a template of the model parameter structure, like the one you would get from the model initialization. Here, we use the previously generated `params` as a template. Note that this will produce a new variable structure, and not mutate in-place. *The point of enforcing structure through template is to avoid users issues downstream, so you need to first have the right model that generates the parameters structure.* ```{code-cell} -:id: MOhoBDCOYYJ5 :outputId: 13acc4e1-8757-4554-e2c8-d594ba6e67dc serialization.from_bytes(params, bytes_output) ``` -+++ {"id": "8mNu8nuOhDC5"} - ## Defining your own models Flax allows you to define your own models, which should be a bit more complicated than a linear regression. In this section, we'll show you how to build simple models. To do so, you'll need to create subclasses of the base `nn.Module` class. *Keep in mind that we imported* `linen as nn` *and this only works with the new linen API* -+++ {"id": "1sllHAdRlpmQ"} ++++ ### Module basics The base abstraction for models is the `nn.Module` class, and every type of predefined layers in Flax (like the previous `Dense`) is a subclass of `nn.Module`. Let's take a look and start by defining a simple but custom multi-layer perceptron i.e. a sequence of Dense layers interleaved with calls to a non-linear activation function. ```{code-cell} -:id: vbfrfbkxgPhg :outputId: b59c679c-d164-4fd6-92db-b50f0d310ec3 class ExplicitMLP(nn.Module): @@ -310,8 +271,6 @@ print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax. 
print('output:\n', y) ``` -+++ {"id": "DDITIjXitEZl"} - As we can see, a `nn.Module` subclass is made of: * A collection of data fields (`nn.Module` are Python dataclasses) - here we only have the `features` field of type `Sequence[int]`. @@ -324,7 +283,6 @@ As we can see, a `nn.Module` subclass is made of: Since the module structure and its parameters are not tied to each other, you can't directly call `model(x)` on a given input as it will return an error. The `__call__` function is being wrapped up in the `apply` one, which is the one to call on an input: ```{code-cell} -:id: DEYrVA6dnaJu :outputId: 4af16ec5-b52a-43b0-fc47-1f8ab25e7058 try: @@ -333,12 +291,9 @@ except AttributeError as e: print(e) ``` -+++ {"id": "I__UrmShnaJu"} - Since here we have a very simple model, we could have used an alternative (but equivalent) way of declaring the submodules inline in the `__call__` using the `@nn.compact` annotation like so: ```{code-cell} -:id: ZTCbdpQ4suSK :outputId: 183a74ef-f54e-4848-99bf-fee4c174ba6d class SimpleMLP(nn.Module): @@ -366,22 +321,19 @@ print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax. print('output:\n', y) ``` -+++ {"id": "es7YHjgexT-L"} - There are, however, a few differences you should be aware of between the two declaration modes: * In `setup`, you are able to name some sublayers and keep them around for further use (e.g. encoder/decoder methods in autoencoders). * If you want to have multiple methods, then you **need** to declare the module using `setup`, as the `@nn.compact` annotation only allows one method to be annotated. * The last initialization will be handled differently. See these notes for more details (TODO: add notes link). -+++ {"id": "-ykceROJyp7W"} ++++ ### Module parameters In the previous MLP example, we relied only on predefined layers and operators (`Dense`, `relu`). Let's imagine that you didn't have a Dense layer provided by Flax and you wanted to write it on your own. 
Here is what it would look like using the `@nn.compact` way to declare a new modules: ```{code-cell} -:id: wK371Pt_vVfR :outputId: 83b5fea4-071e-4ea0-8fa8-610e69fb5fd5 class SimpleDense(nn.Module): @@ -410,8 +362,6 @@ print('initialized parameters:\n', params) print('output:\n', y) ``` -+++ {"id": "MKyhfzVpzC94"} - Here, we see how to both declare and assign a parameter to the model using the `self.param` method. It takes as input `(name, init_fn, *init_args)` : * `name` is simply the name of the parameter that will end up in the parameter structure. @@ -420,7 +370,7 @@ Here, we see how to both declare and assign a parameter to the model using the ` Such params can also be declared in the `setup` method; it won't be able to use shape inference because Flax is using lazy initialization at the first call site. -+++ {"id": "QmSpxyqLDr58"} ++++ ### Variables and collections of variables @@ -434,7 +384,6 @@ However this is not enough to cover everything that we would need for machine le For demonstration purposes, we'll implement a simplified but similar mechanism to batch normalization: we'll store running averages and subtract those to the input at training time. For proper batchnorm, you should use (and look at) the implementation [here](https://github.com/google/flax/blob/main/flax/linen/normalization.py). ```{code-cell} -:id: J6_tR-nPzB1i :outputId: 75465fd6-cdc8-497c-a3ec-7f709b5dde7a class BiasAdderWithRunningMean(nn.Module): @@ -463,12 +412,9 @@ y, updated_state = model.apply(variables, x, mutable=['batch_stats']) print('updated state:\n', updated_state) ``` -+++ {"id": "5OHBbMJng3ic"} - Here, `updated_state` returns only the state variables that are being mutated by the model while applying it on data. 
To update the variables and get the new parameters of the model, we can use the following pattern: ```{code-cell} -:id: IbTsCAvZcdBy :outputId: 09a8bdd1-eaf8-401a-cf7c-386a7a5aa87b for val in [1.0, 2.0, 3.0]: @@ -479,14 +425,11 @@ for val in [1.0, 2.0, 3.0]: print('updated state:\n', updated_state) # Shows only the mutable part ``` -+++ {"id": "GuUSOSKegKIM"} - From this simplified example, you should be able to derive a full BatchNorm implementation, or any layer involving a state. To finish, let's add an optimizer to see how to play with both parameters updated by an optimizer and state variables. *This example isn't doing anything and is only for demonstration purposes.* ```{code-cell} -:id: TUgAbUPpnaJw :outputId: 0906fbab-b866-4956-d231-b1374415d448 from functools import partial @@ -517,14 +460,12 @@ for _ in range(3): print('Updated state: ', state) ``` -+++ {"id": "eWUmx5EjtWge"} - Note that the above function has a quite verbose signature and it would not actually work with `jax.jit()` because the function arguments are not "valid JAX types". Flax provides a handy wrapper - `TrainState` - that simplifies the above code. Check out [`flax.training.train_state.TrainState`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState) to learn more. 
-+++ {"id": "_GL0PsCwnaJw"} ++++ ### Exporting to Tensorflow's SavedModel with jax2tf diff --git a/docs/guides/parallel_training/flax_on_pjit.ipynb b/docs/guides/parallel_training/flax_on_pjit.ipynb index b93547b02..59adbaf09 100644 --- a/docs/guides/parallel_training/flax_on_pjit.ipynb +++ b/docs/guides/parallel_training/flax_on_pjit.ipynb @@ -2,9 +2,7 @@ "cells": [ { "cell_type": "markdown", - "metadata": { - "id": "2a9f78765c0c" - }, + "metadata": {}, "source": [ "# Scale up Flax Modules on multiple devices\n", "\n", @@ -14,9 +12,7 @@ { "attachments": {}, "cell_type": "markdown", - "metadata": { - "id": "b1e0e5fc8bc1" - }, + "metadata": {}, "source": [ "## Flax and `jax.jit` scaled up\n", "\n", @@ -34,9 +30,7 @@ { "attachments": {}, "cell_type": "markdown", - "metadata": { - "id": "a9601432b448" - }, + "metadata": {}, "source": [ "## Setup\n", "\n", @@ -49,7 +43,6 @@ "cell_type": "code", "execution_count": 1, "metadata": { - "id": "867203db3bef", "tags": [ "skip-execution" ] @@ -63,9 +56,7 @@ { "cell_type": "code", "execution_count": 2, - "metadata": { - "id": "f8f42d1174e5" - }, + "metadata": {}, "outputs": [], "source": [ "import os\n", @@ -75,9 +66,7 @@ { "cell_type": "code", "execution_count": 42, - "metadata": { - "id": "b8da40732f0b" - }, + "metadata": {}, "outputs": [], "source": [ "import functools\n", @@ -98,9 +87,7 @@ { "cell_type": "code", "execution_count": 4, - "metadata": { - "id": "bcc30de1d6eb" - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -117,9 +104,7 @@ { "attachments": {}, "cell_type": "markdown", - "metadata": { - "id": "c0d280def897" - }, + "metadata": {}, "source": [ "The code below shows how to import and set up the JAX-level device API, following JAX's [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) guide:\n", "\n", @@ -135,9 +120,7 @@ { "cell_type": "code", "execution_count": 5, - "metadata": { - "id": "684fe9fe13a0" - 
}, + "metadata": {}, "outputs": [], "source": [ "from jax.sharding import Mesh, PartitionSpec, NamedSharding\n", @@ -148,9 +131,7 @@ { "cell_type": "code", "execution_count": 6, - "metadata": { - "id": "4589d7a6d4bb" - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -178,9 +159,7 @@ { "attachments": {}, "cell_type": "markdown", - "metadata": { - "id": "307d39db6d94" - }, + "metadata": {}, "source": [ "## Define a layer\n", "\n", @@ -198,9 +177,7 @@ { "cell_type": "code", "execution_count": 43, - "metadata": { - "id": "b74c049968dc" - }, + "metadata": {}, "outputs": [], "source": [ "class DotReluDot(nn.Module):\n", @@ -234,9 +211,7 @@ { "attachments": {}, "cell_type": "markdown", - "metadata": { - "id": "cbac5321c08e" - }, + "metadata": {}, "source": [ "Note that device axis names like `'data'`, `'model'` or `None` are passed into both [`flax.linen.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.with_partitioning.html) and [`jax.lax.with_sharding_constraint`](https://github.com/google/jax/blob/main/jax/_src/pjit.py#L1516) API calls. This refers to how each dimension of this data should be sharded — either across one of the device mesh dimensions, or not sharded at all.\n", "\n", @@ -256,9 +231,7 @@ { "attachments": {}, "cell_type": "markdown", - "metadata": { - "id": "b8389c11af79" - }, + "metadata": {}, "source": [ "## Define a model with `flax.linen.scan` lifted transformation\n", "\n", @@ -277,9 +250,7 @@ { "cell_type": "code", "execution_count": 44, - "metadata": { - "id": "a0ea0dcccbc3" - }, + "metadata": {}, "outputs": [], "source": [ "class MLP(nn.Module):\n", @@ -303,9 +274,7 @@ { "attachments": {}, "cell_type": "markdown", - "metadata": { - "id": "44395b62561d" - }, + "metadata": {}, "source": [ "Now, create a `model` instance, and a sample input `x`." 
] @@ -313,9 +282,7 @@ { "cell_type": "code", "execution_count": 45, - "metadata": { - "id": "5686299b4839" - }, + "metadata": {}, "outputs": [], "source": [ "# MLP hyperparameters.\n", @@ -334,9 +301,7 @@ { "attachments": {}, "cell_type": "markdown", - "metadata": { - "id": "5b3abfef359d" - }, + "metadata": {}, "source": [ "## Specify sharding\n", "\n", @@ -350,9 +315,7 @@ { "cell_type": "code", "execution_count": 46, - "metadata": { - "id": "8b913a2e57d3" - }, + "metadata": {}, "outputs": [ { "data": { @@ -397,9 +360,7 @@ { "attachments": {}, "cell_type": "markdown", - "metadata": { - "id": "06d134795ae1" - }, + "metadata": {}, "source": [ "### The output's sharding\n", "\n", @@ -416,9 +377,7 @@ { "cell_type": "code", "execution_count": 47, - "metadata": { - "id": "19094ec63385" - }, + "metadata": {}, "outputs": [], "source": [ "def init_fn(k, x, model, optimizer):\n", @@ -433,9 +392,7 @@ { "cell_type": "code", "execution_count": 48, - "metadata": { - "id": "e49264a3c78e" - }, + "metadata": {}, "outputs": [ { "data": { @@ -542,9 +499,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "2ec24614050b" - }, + "metadata": {}, "source": [ "## Compile the code\n", "\n", @@ -556,9 +511,7 @@ { "cell_type": "code", "execution_count": 49, - "metadata": { - "id": "5b6e699df733" - }, + "metadata": {}, "outputs": [ { "data": { @@ -638,9 +591,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "8f74b009f11f" - }, + "metadata": {}, "source": [ "## Inspect the Module output\n", "\n", @@ -652,9 +603,7 @@ { "cell_type": "code", "execution_count": 14, - "metadata": { - "id": "19243982c892" - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -676,9 +625,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "2beee7d27bdb" - }, + "metadata": {}, "source": [ "You can also check the underlying [`jax.sharding`](https://jax.readthedocs.io/en/latest/jax.sharding.html) of each parameter, which is now more internal than `NamedSharding`. 
Note that numbers like `initialized_state.step` are replicated across all devices." ] @@ -686,9 +633,7 @@ { "cell_type": "code", "execution_count": 15, - "metadata": { - "id": "2067c419a826" - }, + "metadata": {}, "outputs": [ { "data": { @@ -708,9 +653,7 @@ { "cell_type": "code", "execution_count": 16, - "metadata": { - "id": "d7cf0baa334b" - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -737,9 +680,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "273547d3ab89" - }, + "metadata": {}, "source": [ "You can use [`jax.tree_map`](https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.tree_map.html) to perform mass computation on a dict of boxed params, in the same way as on a dict of JAX arrays." ] @@ -747,9 +688,7 @@ { "cell_type": "code", "execution_count": 17, - "metadata": { - "id": "29b3dae156a2" - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -779,9 +718,7 @@ { "attachments": {}, "cell_type": "markdown", - "metadata": { - "id": "f7e1ccb14c6b" - }, + "metadata": {}, "source": [ "## Compile the train step and inference \n", "\n", @@ -791,9 +728,7 @@ { "cell_type": "code", "execution_count": 18, - "metadata": { - "id": "4e3cc300cfee" - }, + "metadata": {}, "outputs": [], "source": [ "@functools.partial(jax.jit, in_shardings=(state_sharding, x_sharding), \n", @@ -815,9 +750,7 @@ { "cell_type": "code", "execution_count": 19, - "metadata": { - "id": "91c6c2662c12" - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -906,9 +839,7 @@ { "attachments": {}, "cell_type": "markdown", - "metadata": { - "id": "2bae79e2e71b" - }, + "metadata": {}, "source": [ "Then, create a compiled inference step. Note that the output is also sharded along `(data, None)`." 
] @@ -916,9 +847,7 @@ { "cell_type": "code", "execution_count": 20, - "metadata": { - "id": "c9264a48b9ee" - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -979,9 +908,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "7daa9e6e6eb4" - }, + "metadata": {}, "source": [ "## Profiling\n", "\n", @@ -991,9 +918,7 @@ { "cell_type": "code", "execution_count": 21, - "metadata": { - "id": "a68d7cb2eb89" - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -1017,9 +942,7 @@ { "attachments": {}, "cell_type": "markdown", - "metadata": { - "id": "51420b514d53" - }, + "metadata": {}, "source": [ "## Logical axis annotation\n", "\n", @@ -1035,9 +958,7 @@ { "cell_type": "code", "execution_count": 50, - "metadata": { - "id": "a26f85a9e772" - }, + "metadata": {}, "outputs": [], "source": [ "class LogicalDotReluDot(nn.Module):\n", @@ -1085,9 +1006,7 @@ { "attachments": {}, "cell_type": "markdown", - "metadata": { - "id": "0de93ec6cbd6" - }, + "metadata": {}, "source": [ "Now, initiate a model and try to figure out what sharding its `state` should have.\n", "\n", @@ -1099,9 +1018,7 @@ { "cell_type": "code", "execution_count": 51, - "metadata": { - "id": "14db7a1e30fd" - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -1133,9 +1050,7 @@ { "attachments": {}, "cell_type": "markdown", - "metadata": { - "id": "58475fffb2de" - }, + "metadata": {}, "source": [ "You can verify that the `logical_state_spec` here has the same content as `state_spec` in the previous (\"non-logical\") example. This allows you to `jax.jit` your Module's [`flax.linen.Module.init`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.init) and [`flax.linen.Module.apply`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.apply) the same way in the above above." 
] @@ -1143,9 +1058,7 @@ { "cell_type": "code", "execution_count": 52, - "metadata": { - "id": "589ff774bb4c" - }, + "metadata": {}, "outputs": [ { "data": { @@ -1165,9 +1078,7 @@ { "cell_type": "code", "execution_count": 53, - "metadata": { - "id": "77e07a0ab309" - }, + "metadata": {}, "outputs": [], "source": [ "logical_jit_init_fn = jax.jit(init_fn, static_argnums=(2, 3),\n", @@ -1180,9 +1091,7 @@ { "cell_type": "code", "execution_count": 54, - "metadata": { - "id": "fb53bc20e0f9" - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -1271,9 +1180,7 @@ { "attachments": {}, "cell_type": "markdown", - "metadata": { - "id": "ae1754a3031d" - }, + "metadata": {}, "source": [ "## When to use device axis / logical axis\n", "\n", @@ -1289,9 +1196,7 @@ { "attachments": {}, "cell_type": "markdown", - "metadata": { - "id": "576bdd5cd782" - }, + "metadata": {}, "source": [ "## Save the data\n", "\n", diff --git a/docs/guides/parallel_training/flax_on_pjit.md b/docs/guides/parallel_training/flax_on_pjit.md index 486871642..fe2e75c51 100644 --- a/docs/guides/parallel_training/flax_on_pjit.md +++ b/docs/guides/parallel_training/flax_on_pjit.md @@ -8,13 +8,11 @@ jupytext: jupytext_version: 1.13.8 --- -+++ {"id": "2a9f78765c0c"} - # Scale up Flax Modules on multiple devices This guide shows how to scale up [Flax Modules](https://flax.readthedocs.io/en/latest/developer_notes/module_lifecycle.html) on multiple devices and hosts using [`jax.jit`](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) (formerly [`experimental.pjit`](https://jax.readthedocs.io/en/latest/jax.experimental.pjit.html#module-jax.experimental.pjit)) and [`flax.linen`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/index.html). 
-+++ {"id": "b1e0e5fc8bc1"} ++++ ## Flax and `jax.jit` scaled up @@ -28,7 +26,7 @@ Flax provides several functionalities that can help you use auto-SPMD on [Flax M You can learn more about `jax.jit` APIs for scaling up in [JAX in multi-process environments](https://jax.readthedocs.io/en/latest/multi_process.html) and [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) on JAX's documentation site. -+++ {"id": "a9601432b448"} ++++ ## Setup @@ -37,7 +35,6 @@ Import some necessary dependencies. **Note:** This guide uses the `--xla_force_host_platform_device_count=8` flag to emulate multiple devices in a CPU environment in a Google Colab/Jupyter Notebook. You don't need this if you are already using a multi-device TPU environment. ```{code-cell} -:id: 867203db3bef :tags: [skip-execution] # Once Flax v0.6.10 is released, there is no need to do this. @@ -45,15 +42,11 @@ Import some necessary dependencies. ``` ```{code-cell} -:id: f8f42d1174e5 - import os os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' ``` ```{code-cell} -:id: b8da40732f0b - import functools from typing import Optional, Callable @@ -70,13 +63,9 @@ import optax # Optax for common losses and optimizers. ``` ```{code-cell} -:id: bcc30de1d6eb - print(f'We have 8 fake JAX devices now: {jax.devices()}') ``` -+++ {"id": "c0d280def897"} - The code below shows how to import and set up the JAX-level device API, following JAX's [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) guide: 1. Start a 2x4 device `mesh` (8 devices) using JAX's `mesh_utils.create_device_mesh`. This layout is the same as the one of a [TPU v3-8](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#single_tpu_board). 
@@ -88,16 +77,12 @@ The code below shows how to import and set up the JAX-level device API, followin 3. Make a simple utility function `mesh_sharding` for generating a sharding object from the mesh and any layout. ```{code-cell} -:id: 684fe9fe13a0 - from jax.sharding import Mesh, PartitionSpec, NamedSharding from jax.lax import with_sharding_constraint from jax.experimental import mesh_utils ``` ```{code-cell} -:id: 4589d7a6d4bb - # Create a mesh and annotate each axis with a name. device_mesh = mesh_utils.create_device_mesh((2, 4)) print(device_mesh) @@ -109,8 +94,6 @@ def mesh_sharding(pspec: PartitionSpec) -> NamedSharding: return NamedSharding(mesh, pspec) ``` -+++ {"id": "307d39db6d94"} - ## Define a layer Before defining a simple model, create an example layer called `DotReluDot` (by subclassing `flax.linen.Module`). The layer creates two parameters `W1` and `W2` for dot product multiplication, and uses the `jax.nn.relu` (ReLU) activation function in-between. @@ -124,8 +107,6 @@ To shard the parameters efficiently, apply the following APIs to annotate the pa * This step is optional, but can sometimes help auto-SPMD to partition efficiently. In the example below, the call is not required, because XLA will figure out the same sharding layout for `y` and `z` regardless. ```{code-cell} -:id: b74c049968dc - class DotReluDot(nn.Module): depth: int dense_init: Callable = nn.initializers.xavier_normal() @@ -154,8 +135,6 @@ class DotReluDot(nn.Module): return z, None ``` -+++ {"id": "cbac5321c08e"} - Note that device axis names like `'data'`, `'model'` or `None` are passed into both [`flax.linen.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.with_partitioning.html) and [`jax.lax.with_sharding_constraint`](https://github.com/google/jax/blob/main/jax/_src/pjit.py#L1516) API calls. This refers to how each dimension of this data should be sharded — either across one of the device mesh dimensions, or not sharded at all. 
For example: @@ -170,7 +149,7 @@ For example: * The first dimension — the batch dimension — will be sharded over the `'data'` axis. This means half of the batch will be processed on devices `0-3` (first four devices), and another half on devices `4-7` (the remaining four devices). * The second dimension — the data depth dimension — will be replicated across all devices. -+++ {"id": "b8389c11af79"} ++++ ## Define a model with `flax.linen.scan` lifted transformation @@ -186,8 +165,6 @@ The code below shows how to apply both methods, and default with the for-loop, s The `flax.linen.scan` code is just to show that this API works with [Flax lifted transforms](https://flax.readthedocs.io/en/latest/developer_notes/lift.html#supported-transformations). ```{code-cell} -:id: a0ea0dcccbc3 - class MLP(nn.Module): num_layers: int depth: int @@ -206,13 +183,9 @@ class MLP(nn.Module): return x ``` -+++ {"id": "44395b62561d"} - Now, create a `model` instance, and a sample input `x`. ```{code-cell} -:id: 5686299b4839 - # MLP hyperparameters. BATCH, LAYERS, DEPTH, USE_SCAN = 8, 4, 1024, False # Create fake inputs. @@ -226,8 +199,6 @@ optimizer = optax.adam(learning_rate=0.001) model = MLP(LAYERS, DEPTH, USE_SCAN) ``` -+++ {"id": "5b3abfef359d"} - ## Specify sharding Next, you need to tell `jax.jit` how to shard our data across devices. @@ -237,15 +208,11 @@ Next, you need to tell `jax.jit` how to shard our data across devices. For data parallelism, you can shard the batched _input_ `x` across the `data` axis by denoting the batch axis as `'data'`. Then, use [`jax.device_put`](https://jax.readthedocs.io/en/latest/_autosummary/jax.device_put.html) to place it onto the correct `device`s. 
```{code-cell} -:id: 8b913a2e57d3 - x_sharding = mesh_sharding(PartitionSpec('data', None)) # dimensions: (batch, length) x = jax.device_put(x, x_sharding) jax.debug.visualize_array_sharding(x) ``` -+++ {"id": "06d134795ae1"} - ### The output's sharding You need to compile `model.init()` (that is, [`flax.linen.Module.init()`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.init)), and its output as a pytree of parameters. Additionally, you may sometimes need wrap it with a [`flax.training.train_state`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState) to track other variables, such as optimizer states, and that would make the output an even more complex pytree. @@ -258,8 +225,6 @@ To achieve this, luckily, you don't have to hardcode the output's sharding by ha * This step utilizes the [`flax.linen.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.with_partitioning.html) annotations in the earlier definition to generate the correct sharding for the parameters. ```{code-cell} -:id: 19094ec63385 - def init_fn(k, x, model, optimizer): variables = model.init(k, x) # Initialize the model. state = train_state.TrainState.create( # Create a `TrainState`. @@ -270,8 +235,6 @@ def init_fn(k, x, model, optimizer): ``` ```{code-cell} -:id: e49264a3c78e - # Create an abstract closure to wrap the function before feeding it in # because `jax.eval_shape` only takes pytrees as arguments. abstract_variables = jax.eval_shape( @@ -283,8 +246,6 @@ state_sharding = nn.get_sharding(abstract_variables, mesh) state_sharding ``` -+++ {"id": "2ec24614050b"} - ## Compile the code Now you can apply [`jax.jit`](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) to your `init_fn`, but with two extra arguments: `in_shardings` and `out_shardings`. 
@@ -292,8 +253,6 @@ Now you can apply [`jax.jit`](https://jax.readthedocs.io/en/latest/jax-101/02-ji Run it to get the `initialized_state`, in which parameters are sharded exactly as instructed: ```{code-cell} -:id: 5b6e699df733 - jit_init_fn = jax.jit(init_fn, static_argnums=(2, 3), in_shardings=(mesh_sharding(None), x_sharding), # PRNG key and x out_shardings=state_sharding) @@ -306,8 +265,6 @@ jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Den jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value) ``` -+++ {"id": "8f74b009f11f"} - ## Inspect the Module output Note that in the output of `initialized_state`, the `params` `W1` and `W2` are of type [`flax.linen.Partitioned`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.Partitioned.html). This is a wrapper around the actual `jax.Array` that allows Flax to record the axis names associated with it. @@ -315,38 +272,26 @@ Note that in the output of `initialized_state`, the `params` `W1` and `W2` are o You can access the raw `jax.Array` by adding `.value` when outside `jit`, or by `.unbox()` when inside. ```{code-cell} -:id: 19243982c892 - print(type(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'])) print(type(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value)) print(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].names) print(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value.shape) ``` -+++ {"id": "2beee7d27bdb"} - You can also check the underlying [`jax.sharding`](https://jax.readthedocs.io/en/latest/jax.sharding.html) of each parameter, which is now more internal than `NamedSharding`. Note that numbers like `initialized_state.step` are replicated across all devices. 
```{code-cell} -:id: 2067c419a826 - initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value.sharding ``` ```{code-cell} -:id: d7cf0baa334b - print(initialized_state.step) initialized_state.step.sharding ``` -+++ {"id": "273547d3ab89"} - You can use [`jax.tree_map`](https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.tree_map.html) to perform mass computation on a dict of boxed params, in the same way as on a dict of JAX arrays. ```{code-cell} -:id: 29b3dae156a2 - diff = jax.tree_map( lambda a, b: a - b, initialized_state.params['DotReluDot_0'], initialized_state.params['DotReluDot_0']) @@ -356,15 +301,11 @@ print(type(diff_array)) print(diff_array.shape) ``` -+++ {"id": "f7e1ccb14c6b"} - ## Compile the train step and inference Create a `jit`ted training step as follows: ```{code-cell} -:id: 4e3cc300cfee - @functools.partial(jax.jit, in_shardings=(state_sharding, x_sharding), out_shardings=state_sharding) def train_step(state, x): @@ -382,21 +323,15 @@ with mesh: ``` ```{code-cell} -:id: 91c6c2662c12 - print(f'Sharding of Weight 1:') jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value) print(f'Sharding of Weight 2:') jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value) ``` -+++ {"id": "2bae79e2e71b"} - Then, create a compiled inference step. Note that the output is also sharded along `(data, None)`. 
```{code-cell} -:id: c9264a48b9ee - @functools.partial(jax.jit, in_shardings=(state_sharding, x_sharding), out_shardings=x_sharding) def apply_fn(state, x): @@ -410,15 +345,11 @@ print(y.shape) jax.debug.visualize_array_sharding(y) ``` -+++ {"id": "7daa9e6e6eb4"} - ## Profiling If you are running on a TPU pod or a pod slice, you can use a custom `block_all` utility function, as defined below, to measure the performance: ```{code-cell} -:id: a68d7cb2eb89 - %%timeit def block_all(xs): @@ -429,8 +360,6 @@ with mesh: new_state = block_all(train_step(initialized_state, x)) ``` -+++ {"id": "51420b514d53"} - ## Logical axis annotation JAX's automatic SPMD encourages users to explore different sharding layouts to find the optimal one. To this end, in Flax you actually can annotate the dimensions of any data with more descriptive axis names (not just device mesh axis names like `'data'` and `'model'`). @@ -442,8 +371,6 @@ The `LogicalDotReluDot` and `LogicalMLP` Module definition below are similar to 2. [`flax.linen.with_logical_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.with_logical_partitioning.html) replaces `flax.linen.with_partitioning`; and [`flax.linen.with_logical_constraint`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.with_logical_constraint.html#flax-linen-with-logical-constraint) replaces `jax.lax.with_sharding_constraint`, to recognize the logical axis names. ```{code-cell} -:id: a26f85a9e772 - class LogicalDotReluDot(nn.Module): depth: int dense_init: Callable = nn.initializers.xavier_normal() @@ -486,8 +413,6 @@ class LogicalMLP(nn.Module): return x ``` -+++ {"id": "0de93ec6cbd6"} - Now, initiate a model and try to figure out what sharding its `state` should have. To allow the device mesh to take your model correctly, you need to decide which of these logical axis names are mapped to the device axis `'data'` or `'model'`. 
This rule is a list of (`logical_axis_name`, `device_axis_name`) tuples, and [`flax.linen.logical_to_mesh_sharding`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.logical_to_mesh_sharding.html#flax-linen-logical-to-mesh-sharding) will convert them to the kind of sharding that the device mesh can understand. @@ -495,8 +420,6 @@ To allow the device mesh to take your model correctly, you need to decide which This allows you to change the rules and try out new partition layouts without modifying the model definition. ```{code-cell} -:id: 14db7a1e30fd - # Unspecified rule means unsharded by default, so no need to specify `('embed', None)` and `('layer', None)`. rules = (('batch', 'data'), ('hidden', 'model')) @@ -514,19 +437,13 @@ print('sharding annotations are mesh-specific: ', logical_state_sharding.params['LogicalDotReluDot_0']['Dense_0']['kernel'].spec) ``` -+++ {"id": "58475fffb2de"} - You can verify that the `logical_state_spec` here has the same content as `state_spec` in the previous ("non-logical") example. This allows you to `jax.jit` your Module's [`flax.linen.Module.init`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.init) and [`flax.linen.Module.apply`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.apply) the same way as above.
```{code-cell} -:id: 589ff774bb4c - state_sharding.params['DotReluDot_0'] == logical_state_sharding.params['LogicalDotReluDot_0'] ``` ```{code-cell} -:id: 77e07a0ab309 - logical_jit_init_fn = jax.jit(init_fn, static_argnums=(2, 3), in_shardings=(mesh_sharding(None), x_sharding), # PRNG key and x out_shardings=logical_state_sharding) @@ -535,16 +452,12 @@ logical_initialized_state = logical_jit_init_fn(k, x, logical_model, optimizer) ``` ```{code-cell} -:id: fb53bc20e0f9 - print(f'Sharding of Weight 1:') jax.debug.visualize_array_sharding(logical_initialized_state.params['LogicalDotReluDot_0']['Dense_0']['kernel'].value) print(f'Sharding of Weight 2:') jax.debug.visualize_array_sharding(logical_initialized_state.params['LogicalDotReluDot_0']['W2'].value) ``` -+++ {"id": "ae1754a3031d"} - ## When to use device axis / logical axis Choosing when to use a device or logical axis depends on how much you want to control the partitioning of your model: @@ -555,7 +468,7 @@ Choosing when to use a device or logical axis depends on how much you want to co * **Device axis names**: In really advanced use cases, you may have more complicated sharding patterns that require annotating *activation* dimension names differently from *parameter* dimension names. If you wish to have more fine-grained control on manual mesh assignments, directly using __device axis names__ could be more helpful. 
-+++ {"id": "576bdd5cd782"} ++++ ## Save the data diff --git a/docs/guides/training_techniques/use_checkpointing.ipynb b/docs/guides/training_techniques/use_checkpointing.ipynb index e9728f429..ba9054feb 100644 --- a/docs/guides/training_techniques/use_checkpointing.ipynb +++ b/docs/guides/training_techniques/use_checkpointing.ipynb @@ -4,9 +4,7 @@ "attachments": {}, "cell_type": "markdown", "id": "6e9134fa", - "metadata": { - "id": "6e9134fa" - }, + "metadata": {}, "source": [ "# Save and load checkpoints\n", "\n", @@ -46,9 +44,7 @@ { "cell_type": "markdown", "id": "5a2f6aae", - "metadata": { - "id": "5a2f6aae" - }, + "metadata": {}, "source": [ "## Setup\n", "\n", @@ -59,9 +55,7 @@ "attachments": {}, "cell_type": "markdown", "id": "-icO30rwmKYj", - "metadata": { - "id": "-icO30rwmKYj" - }, + "metadata": {}, "source": [ "Note: Before running `import jax`, create eight fake devices to mimic a [multi-host environment](https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html?#aside-hosts-and-devices-in-jax) in this notebook. Note that the order of imports is important here. The `os.environ[\"XLA_FLAGS\"] = '--xla_force_host_platform_device_count=8'` command works only with the CPU backend, which means it won't work with GPU/TPU acceleration on if you're running this notebook in Google Colab. If you are already running the code on multiple devices (for example, in a 4x2 TPU environment), you can skip running the next cell." 
] @@ -70,9 +64,7 @@ "cell_type": "code", "execution_count": 1, "id": "ArKLnsyGRxGv", - "metadata": { - "id": "ArKLnsyGRxGv" - }, + "metadata": {}, "outputs": [], "source": [ "import os\n", @@ -83,9 +75,7 @@ "cell_type": "code", "execution_count": 2, "id": "SJT9DTxTytjn", - "metadata": { - "id": "SJT9DTxTytjn" - }, + "metadata": {}, "outputs": [ { "name": "stderr", @@ -128,9 +118,7 @@ { "cell_type": "markdown", "id": "40d434cd", - "metadata": { - "id": "40d434cd" - }, + "metadata": {}, "source": [ "## Save checkpoints\n", "\n", @@ -144,7 +132,6 @@ "execution_count": 4, "id": "56dec3f6", "metadata": { - "id": "56dec3f6", "outputId": "f1856d96-1961-48ed-bb7c-cb63fbaa7567" }, "outputs": [ @@ -220,9 +207,7 @@ { "cell_type": "markdown", "id": "6fc59dfa", - "metadata": { - "id": "6fc59dfa" - }, + "metadata": {}, "source": [ "Save the checkpoint with `orbax.checkpoint.PyTreeCheckpointer`, directly to the `tmp/orbax/single_save` directory.\n", "\n", @@ -233,9 +218,7 @@ "cell_type": "code", "execution_count": 5, "id": "61b12da2", - "metadata": { - "id": "0pp4QtEqW9k7" - }, + "metadata": {}, "outputs": [], "source": [ "from flax.training import orbax_utils\n", @@ -262,7 +245,6 @@ "execution_count": 6, "id": "d3686ea5", "metadata": { - "id": "T6T8V4UBXB1R", "outputId": "b7132933-566d-440d-c34e-c5468d87cbdc" }, "outputs": [ @@ -293,9 +275,7 @@ { "cell_type": "markdown", "id": "8ecbc4cc", - "metadata": { - "id": "OQkUOkHVW_4e" - }, + "metadata": {}, "source": [ "### With the legacy API\n", "\n", @@ -307,7 +287,6 @@ "execution_count": 7, "id": "4cdb35ef", "metadata": { - "id": "4cdb35ef", "outputId": "6d849273-15ce-4480-8864-726d1838ac1f" }, "outputs": [ @@ -336,9 +315,7 @@ { "cell_type": "markdown", "id": "6b658bd1", - "metadata": { - "id": "6b658bd1" - }, + "metadata": {}, "source": [ "## Restore checkpoints\n", "\n", @@ -352,7 +329,6 @@ "execution_count": 8, "id": "a807a9c1", "metadata": { - "id": "WgRJj3wjXIaN", "outputId": "b4af1ef4-f22f-459b-bdca-2e6bfa16c08b" }, "outputs": 
[ @@ -425,9 +401,7 @@ { "cell_type": "markdown", "id": "c7fe3bc8", - "metadata": { - "id": "VKJrfSyLXGrc" - }, + "metadata": {}, "source": [ "### With the legacy API\n", "\n", @@ -441,7 +415,6 @@ "execution_count": 10, "id": "150b20a0", "metadata": { - "id": "150b20a0", "outputId": "85ffceca-f38d-46b8-e567-d9d38b7885f9" }, "outputs": [ @@ -474,9 +447,7 @@ { "cell_type": "markdown", "id": "987b981f", - "metadata": { - "id": "987b981f" - }, + "metadata": {}, "source": [ "## Restore with custom dataclasses\n", "\n", @@ -496,7 +467,6 @@ "execution_count": 11, "id": "58f42513", "metadata": { - "id": "58f42513", "outputId": "110c6b6e-fe42-4179-e5d8-6b92d355e11b" }, "outputs": [ @@ -647,9 +617,7 @@ "attachments": {}, "cell_type": "markdown", "id": "136a300a", - "metadata": { - "id": "136a300a" - }, + "metadata": {}, "source": [ "## Restore when checkpoint structures differ\n", "\n", @@ -667,7 +635,6 @@ "execution_count": 14, "id": "be65d4af", "metadata": { - "id": "be65d4af", "outputId": "4fe776f0-65f8-4fc4-d64a-990520b36dce" }, "outputs": [ @@ -705,9 +672,7 @@ { "cell_type": "markdown", "id": "379c2255", - "metadata": { - "id": "379c2255" - }, + "metadata": {}, "source": [ "It is recommended to keep your checkpoints up-to-date with your pytree dataclass definitions. However, you might be forced to restore the checkpoints with incompatible reference objects at runtime. 
When this happens, the checkpoint restoration will try to respect the structure of the reference when given.\n", "\n", @@ -867,7 +832,6 @@ "execution_count": 17, "id": "29fd1e33", "metadata": { - "id": "29fd1e33", "outputId": "cdbb9247-d1eb-4458-aa83-8db0332af7cb" }, "outputs": [ @@ -986,9 +950,7 @@ { "cell_type": "markdown", "id": "a6b39501", - "metadata": { - "id": "a6b39501" - }, + "metadata": {}, "source": [ "## Asynchronized checkpointing\n", "\n", @@ -1006,7 +968,6 @@ "execution_count": 19, "id": "85be68a6", "metadata": { - "id": "85be68a6", "outputId": "aefce94c-8bae-4355-c142-05f2b61c39e2" }, "outputs": [ @@ -1062,9 +1023,7 @@ { "cell_type": "markdown", "id": "13e93db6", - "metadata": { - "id": "QpuTCeMVXOBn" - }, + "metadata": {}, "source": [ "If you are using Orbax `CheckpointManager`, just pass in the async_checkpointer when initializing it. Then, in practice, call `async_checkpoint_manager.wait_until_finished()` instead." ] @@ -1084,9 +1043,7 @@ { "cell_type": "markdown", "id": "bb0e03cd", - "metadata": { - "id": "13e93db6" - }, + "metadata": {}, "source": [ "## Multi-host/multi-process checkpointing\n", "\n", @@ -1101,9 +1058,7 @@ "cell_type": "code", "execution_count": 21, "id": "ubdUvyMrhD-1", - "metadata": { - "id": "ubdUvyMrhD-1" - }, + "metadata": {}, "outputs": [], "source": [ "from jax.sharding import PartitionSpec, NamedSharding\n", @@ -1184,9 +1139,7 @@ { "cell_type": "markdown", "id": "edc355ce", - "metadata": { - "id": "edc355ce" - }, + "metadata": {}, "source": [ "### With the legacy Flax: use `save_checkpoint_multiprocess`\n", "\n", @@ -1200,7 +1153,6 @@ "execution_count": 24, "id": "5d10039b", "metadata": { - "id": "5d10039b", "outputId": "901bb097-0899-479d-b9ae-61dae79e7057" }, "outputs": [ @@ -1230,7 +1182,6 @@ "execution_count": 25, "id": "a9f9724c", "metadata": { - "id": "a9f9724c", "outputId": "393c4a0e-8a8c-4ca6-c609-93c8bab38e75" }, "outputs": [ diff --git a/docs/guides/training_techniques/use_checkpointing.md 
b/docs/guides/training_techniques/use_checkpointing.md index 10b9d8fca..f6c6b58f0 100644 --- a/docs/guides/training_techniques/use_checkpointing.md +++ b/docs/guides/training_techniques/use_checkpointing.md @@ -10,7 +10,6 @@ jupyter: jupytext_version: 1.13.8 --- - # Save and load checkpoints This guide demonstrates how to save and load Flax checkpoints with [Orbax](https://github.com/google/orbax). @@ -45,24 +44,21 @@ For backward-compatibility, this guide shows the Orbax-equivalent calls in the F If you need to learn more about `orbax.checkpoint`, refer to the [Orbax docs](https://github.com/google/orbax/blob/main/docs/checkpoint.md). - - + ## Setup Install/upgrade Flax and [Orbax](https://github.com/google/orbax). For JAX installation with GPU/TPU support, visit [this section on GitHub](https://github.com/google/jax#installation). - - + Note: Before running `import jax`, create eight fake devices to mimic a [multi-host environment](https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html?#aside-hosts-and-devices-in-jax) in this notebook. Note that the order of imports is important here. The `os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'` command works only with the CPU backend, which means it won't work with GPU/TPU acceleration on if you're running this notebook in Google Colab. If you are already running the code on multiple devices (for example, in a 4x2 TPU environment), you can skip running the next cell. - -```python id="ArKLnsyGRxGv" +```python import os os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' ``` -```python id="SJT9DTxTytjn" +```python from typing import Optional, Any import shutil @@ -86,15 +82,13 @@ if os.path.exists(ckpt_dir): shutil.rmtree(ckpt_dir) # Remove any existing checkpoints from the last notebook run. ``` - ## Save checkpoints In Orbax and Flax, you can save and load any given JAX [pytree](https://jax.readthedocs.io/en/latest/pytrees.html). 
This includes not only typical Python and NumPy containers, but also customized classes extended from [`flax.struct.dataclass`](https://flax.readthedocs.io/en/latest/api_reference/flax.struct.html#flax.struct.dataclass). That means you can store almost any data generated — not only your model parameters, but any arrays/dictionaries, metadata/configs, and so on. First, create a pytree with many data structures and containers, and play with it: - -```python id="56dec3f6" outputId="f1856d96-1961-48ed-bb7c-cb63fbaa7567" +```python outputId="f1856d96-1961-48ed-bb7c-cb63fbaa7567" # A simple model with one linear layer. key1, key2 = random.split(random.key(0)) x1 = random.normal(key1, (5,)) # A simple JAX array. @@ -121,13 +115,12 @@ ckpt ### With Orbax - + Save the checkpoint with `orbax.checkpoint.PyTreeCheckpointer`, directly to the `tmp/orbax/single_save` directory. Note: An optional `save_args` is provided. This is recommended for performance speedups, as it bundles smaller arrays in your pytree to a single large file instead of multiple smaller files. - -```python id="0pp4QtEqW9k7" +```python from flax.training import orbax_utils orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer() @@ -141,7 +134,7 @@ In addition, provide `orbax.checkpoint.CheckpointManagerOptions` that customizes `orbax.checkpoint.CheckpointManager` should be placed at the top-level outside your training steps to manage your saves. 
-```python id="T6T8V4UBXB1R" outputId="b7132933-566d-440d-c34e-c5468d87cbdc" +```python outputId="b7132933-566d-440d-c34e-c5468d87cbdc" options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=2, create=True) checkpoint_manager = orbax.checkpoint.CheckpointManager( '/tmp/flax_ckpt/orbax/managed', orbax_checkpointer, options) @@ -154,13 +147,11 @@ for step in range(5): os.listdir('/tmp/flax_ckpt/orbax/managed') # Because max_to_keep=2, only step 3 and 4 are retained ``` - ### With the legacy API And here's how to save with the legacy Flax checkpointing utilities (note that this provides less management features compared with `orbax.checkpoint.CheckpointManagerOptions`): - -```python id="4cdb35ef" outputId="6d849273-15ce-4480-8864-726d1838ac1f" +```python outputId="6d849273-15ce-4480-8864-726d1838ac1f" # Import Flax Checkpoints. from flax.training import checkpoints @@ -171,15 +162,13 @@ checkpoints.save_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing', keep=2) ``` - ## Restore checkpoints ### With Orbax In Orbax, call `.restore()` for either `orbax.checkpoint.PyTreeCheckpointer` or `orbax.checkpoint.CheckpointManager` to restore your checkpoint in the raw pytree format. - -```python id="WgRJj3wjXIaN" outputId="b4af1ef4-f22f-459b-bdca-2e6bfa16c08b" +```python outputId="b4af1ef4-f22f-459b-bdca-2e6bfa16c08b" raw_restored = orbax_checkpointer.restore('/tmp/flax_ckpt/orbax/single_save') raw_restored ``` @@ -191,20 +180,17 @@ step = checkpoint_manager.latest_step() # step = 4 checkpoint_manager.restore(step) ``` - ### With the legacy API Note that with the migration to Orbax in progress, `flax.training.checkpointing.restore_checkpoint` can automatically identify whether a checkpoint is saved in the legacy Flax format or with an Orbax backend, and restore the pytree correctly. Therefore, adding `flax.config.update('flax_use_orbax_checkpointing', True)` won't hurt your ability to restore old checkpoints. 
Here's how to restore checkpoints using the legacy API: - -```python id="150b20a0" outputId="85ffceca-f38d-46b8-e567-d9d38b7885f9" +```python outputId="85ffceca-f38d-46b8-e567-d9d38b7885f9" raw_restored = checkpoints.restore_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing', target=None) raw_restored ``` - ## Restore with custom dataclasses ### With Orbax @@ -216,9 +202,8 @@ raw_restored This section demonstrates how to set up any custom Flax dataclass explicitly, and have the same structure as a saved checkpoint. Note: Data that was a JAX NumPy array (`jnp.array`) format will be restored as a NumPy array (`numpy.array`). This would not affect your work because JAX will [automatically convert](https://jax.readthedocs.io/en/latest/jax-101/01-jax-basics.html) NumPy arrays to JAX arrays once the computation starts. - -```python id="58f42513" outputId="110c6b6e-fe42-4179-e5d8-6b92d355e11b" +```python outputId="110c6b6e-fe42-4179-e5d8-6b92d355e11b" empty_state = train_state.TrainState.create( apply_fn=model.apply, params=jax.tree_map(np.zeros_like, variables['params']), # values of the tree leaf doesn't matter @@ -244,7 +229,7 @@ checkpoints.restore_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing', tar It's often recommended to refactor out the process of initializing a checkpoint's structure (for example, a [`TrainState`](https://flax.readthedocs.io/en/latest/flip/1009-optimizer-api.html?#train-state)), so that saving/loading is easier and less error-prone. This is because functions and complex objects like `apply_fn` and `tx` (optimizer) cannot be serialized into the checkpoint file and must be initialized by code. - + ## Restore when checkpoint structures differ During your development, your checkpoint structure will change when changing the model, adding/removing fields during tweaking, and so on. @@ -254,9 +239,8 @@ This section explains how to load old data to your new code. 
Below is a simple example — a `CustomTrainState` extended from `flax.training.train_state.TrainState` that contains an extra field called `batch_stats`. When working on a real-world model, you may need this when applying [batch normalization](https://flax.readthedocs.io/en/latest/guides/training_techniques/batch_norm.html). Here, you store the new `CustomTrainState` as step 5, while step 4 contains the old/previous `TrainState`. - -```python id="be65d4af" outputId="4fe776f0-65f8-4fc4-d64a-990520b36dce" +```python outputId="4fe776f0-65f8-4fc4-d64a-990520b36dce" class CustomTrainState(train_state.TrainState): batch_stats: Any = None @@ -276,11 +260,10 @@ custom_save_args = orbax_utils.save_args_from_target(custom_ckpt) checkpoint_manager.save(5, custom_ckpt, save_kwargs={'save_args': custom_save_args}) ``` - It is recommended to keep your checkpoints up-to-date with your pytree dataclass definitions. However, you might be forced to restore the checkpoints with incompatible reference objects at runtime. When this happens, the checkpoint restoration will try to respect the structure of the reference when given. Below are examples of a few common scenarios. - + ### Scenario 1: When a reference object is partial @@ -326,7 +309,7 @@ restored If you have already saved your checkpoints with the Orbax backend, you can use `orbax_transforms` to access this `transforms` argument in the Flax API. -```python id="29fd1e33" outputId="cdbb9247-d1eb-4458-aa83-8db0332af7cb" +```python outputId="cdbb9247-d1eb-4458-aa83-8db0332af7cb" # Save in the "Flax-with-Orbax" backend. 
flax.config.update('flax_use_orbax_checkpointing', True) checkpoints.save_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing', @@ -362,7 +345,6 @@ raw_state_dict['model']['batch_stats'] = np.flip(np.arange(10)) flax.serialization.from_state_dict(custom_target, raw_state_dict) ``` - ## Asynchronized checkpointing Checkpointing is I/O heavy, and if you have a large amount of data to save, it may be worthwhile to put it into a background thread, while continuing with your training. @@ -372,9 +354,8 @@ You can do this by creating an [`orbax.checkpoint.AsyncCheckpointer`](https://gi Note: You should use the same `async_checkpointer` to handle all your async saves across your training steps, so that it can make sure that a previous async save is done before the next one begins. This enables bookkeeping, such as `keep` (the number of checkpoints) and `overwrite` to be consistent across steps. Whenever you want to explicitly wait until an async save is done, you can call `async_checkpointer.wait_until_finished()`. - -```python id="85be68a6" outputId="aefce94c-8bae-4355-c142-05f2b61c39e2" +```python outputId="aefce94c-8bae-4355-c142-05f2b61c39e2" # `orbax.checkpoint.AsyncCheckpointer` needs some multi-process initialization, because it was # originally designed for multi-process large model checkpointing. # For Python notebooks or other single-process settings, just set up with `num_processes=1`. @@ -394,9 +375,7 @@ async_checkpointer.wait_until_finished() # Blocks until the checkpoint saving i async_checkpointer.restore('/tmp/flax_ckpt/orbax/single_save_async', item=target) ``` - If you are using Orbax `CheckpointManager`, just pass in the async_checkpointer when initializing it. Then, in practice, call `async_checkpoint_manager.wait_until_finished()` instead. 
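The background-save idea from this section can be sketched with the Python standard library alone. This is a toy stand-in under stated assumptions, not how `orbax.checkpoint.AsyncCheckpointer` is implemented: the `BackgroundSaver` class, its method names, and the use of `pickle` are all hypothetical choices for illustration.

```python
import os
import pickle
import tempfile
import threading


class BackgroundSaver:
    """Toy async saver (hypothetical): at most one save in flight at a time."""

    def __init__(self):
        self._thread = None

    def save(self, path, state):
        # Finish any previous save first, so that saves never overlap.
        self.wait_until_finished()
        # Snapshot the state now, so later training updates do not leak into the file.
        payload = pickle.dumps(state)
        self._thread = threading.Thread(target=self._write, args=(path, payload))
        self._thread.start()

    @staticmethod
    def _write(path, payload):
        with open(path, 'wb') as f:
            f.write(payload)

    def wait_until_finished(self):
        # Block until the in-flight save (if any) has been written.
        if self._thread is not None:
            self._thread.join()
            self._thread = None


def demo():
    saver = BackgroundSaver()
    ckpt_path = os.path.join(tempfile.mkdtemp(), 'ckpt.pkl')
    state = {'step': 0, 'params': [0.0, 0.0]}
    saver.save(ckpt_path, state)
    state['step'] = 1  # "Training" continues while the write happens.
    saver.wait_until_finished()
    with open(ckpt_path, 'rb') as f:
        return pickle.load(f)
```

Reusing one saver object across steps is what lets it wait for the previous write before starting the next one, which mirrors the reason given above for reusing the same `async_checkpointer`.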
- ```python async_checkpoint_manager = orbax.checkpoint.CheckpointManager( @@ -404,7 +383,6 @@ async_checkpoint_manager = orbax.checkpoint.CheckpointManager( async_checkpoint_manager.wait_until_finished() ``` - ## Multi-host/multi-process checkpointing JAX provides a few ways to scale up your code on multiple hosts at the same time. This usually happens when the number of devices (CPU/GPU/TPU) is so large that different devices are managed by different hosts (CPU). To get started on JAX in multi-process settings, check out [Using JAX in multi-host and multi-process environments](https://jax.readthedocs.io/en/latest/multi_process.html) and the [distributed array guide](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html). @@ -412,9 +390,8 @@ JAX provides a few ways to scale up your code on multiple hosts at the same time In the [Single Program Multi Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) paradigm with JAX `jit`, a large multi-process array can have its data sharded across different devices. (Note that JAX `pjit` and `jit` have been merged into a single unified interface. To learn about compiling and executing JAX functions in multi-host or multi-core environments, refer to [this guide](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) and the [jax.Array migration guide](https://jax.readthedocs.io/en/latest/jax_array_migration.html).) When a multi-process array is serialized, each host dumps its data shards to a single shared storage, such as a Google Cloud bucket. Orbax supports saving and loading pytrees with multi-process arrays in the same fashion as single-process pytrees. 
However, it's recommended to use the asynchronized [`orbax.AsyncCheckpointer`](https://github.com/google/orbax/blob/main/orbax/checkpoint/async_checkpointer.py) to save large multi-process arrays on another thread, so that you can perform computation alongside the saves. With pure Orbax, saving checkpoints in a multi-process context uses the same API as in a single-process context. - -```python id="ubdUvyMrhD-1" +```python from jax.sharding import PartitionSpec, NamedSharding # Create an array sharded across multiple devices. @@ -454,15 +431,13 @@ async_checkpoint_manager.restore( 0, items=ref_ckpt, restore_kwargs={'restore_args': restore_args}) ``` - ### With the legacy Flax: use `save_checkpoint_multiprocess` In legacy Flax, to save multi-process arrays, use [`flax.training.checkpoints.save_checkpoint_multiprocess()`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.checkpoints.save_checkpoint_multiprocess) in place of `save_checkpoint()` and with the same arguments. If your checkpoint is too large, you can specify `timeout_secs` in the manager and give it more time to finish writing. 
- -```python id="5d10039b" outputId="901bb097-0899-479d-b9ae-61dae79e7057" +```python outputId="901bb097-0899-479d-b9ae-61dae79e7057" async_checkpointer = orbax.checkpoint.AsyncCheckpointer(orbax.checkpoint.PyTreeCheckpointHandler(), timeout_secs=50) checkpoints.save_checkpoint_multiprocess(ckpt_dir, mp_ckpt, @@ -472,7 +447,7 @@ checkpoints.save_checkpoint_multiprocess(ckpt_dir, orbax_checkpointer=async_checkpointer) ``` -```python id="a9f9724c" outputId="393c4a0e-8a8c-4ca6-c609-93c8bab38e75" +```python outputId="393c4a0e-8a8c-4ca6-c609-93c8bab38e75" mp_restored = checkpoints.restore_checkpoint(ckpt_dir, target=ref_ckpt, step=3, diff --git a/docs/quick_start.ipynb b/docs/quick_start.ipynb index 62e1b1ea6..5c76c39a0 100644 --- a/docs/quick_start.ipynb +++ b/docs/quick_start.ipynb @@ -3,9 +3,7 @@ { "cell_type": "markdown", "id": "6eea21b3", - "metadata": { - "id": "6eea21b3" - }, + "metadata": {}, "source": [ "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/getting_started.ipynb)\n", "[![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/getting_started.ipynb)\n", @@ -22,9 +20,7 @@ { "cell_type": "markdown", "id": "nwJWKIhdwxDo", - "metadata": { - "id": "nwJWKIhdwxDo" - }, + "metadata": {}, "source": [ "## 1. Install Flax" ] @@ -34,7 +30,6 @@ "execution_count": null, "id": "bb81587e", "metadata": { - "id": "bb81587e", "tags": [ "skip-execution" ] @@ -47,9 +42,7 @@ { "cell_type": "markdown", "id": "b529fbef", - "metadata": { - "id": "b529fbef" - }, + "metadata": {}, "source": [ "## 2. 
Loading data\n", "\n", @@ -62,14 +55,7 @@ "cell_type": "code", "execution_count": 48, "id": "bRlrHqZVXZvk", - "metadata": { - "executionInfo": { - "elapsed": 54, - "status": "ok", - "timestamp": 1673483483044 - }, - "id": "bRlrHqZVXZvk" - }, + "metadata": {}, "outputs": [], "source": [ "import tensorflow_datasets as tfds # TFDS for MNIST\n", @@ -98,9 +84,7 @@ { "cell_type": "markdown", "id": "7057395a", - "metadata": { - "id": "7057395a" - }, + "metadata": {}, "source": [ "## 3. Define network\n", "\n", @@ -117,14 +101,7 @@ "cell_type": "code", "execution_count": 49, "id": "cbc079cd", - "metadata": { - "executionInfo": { - "elapsed": 53, - "status": "ok", - "timestamp": 1673483483208 - }, - "id": "cbc079cd" - }, + "metadata": {}, "outputs": [], "source": [ "from flax import linen as nn # Linen API\n", @@ -150,9 +127,7 @@ { "cell_type": "markdown", "id": "hy7iRu7_zlx-", - "metadata": { - "id": "hy7iRu7_zlx-" - }, + "metadata": {}, "source": [ "### View model layers\n", "\n", @@ -164,12 +139,6 @@ "execution_count": 50, "id": "lDHfog81zLQa", "metadata": { - "executionInfo": { - "elapsed": 103, - "status": "ok", - "timestamp": 1673483483427 - }, - "id": "lDHfog81zLQa", "outputId": "2c580f41-bf5d-40ec-f1cf-ab7f319a84da" }, "outputs": [ @@ -238,9 +207,7 @@ { "cell_type": "markdown", "id": "4b5ac16e", - "metadata": { - "id": "4b5ac16e" - }, + "metadata": {}, "source": [ "## 4. 
Create a `TrainState`\n", "\n", @@ -257,12 +224,6 @@ "execution_count": null, "id": "qXr7JDpIxGNZ", "metadata": { - "executionInfo": { - "elapsed": 52, - "status": "ok", - "timestamp": 1673483483631 - }, - "id": "qXr7JDpIxGNZ", "outputId": "1249b7fb-6787-41eb-b34c-61d736300844" }, "outputs": [], @@ -274,14 +235,7 @@ "cell_type": "code", "execution_count": 52, "id": "CJDaJNijyOji", - "metadata": { - "executionInfo": { - "elapsed": 1, - "status": "ok", - "timestamp": 1673483483754 - }, - "id": "CJDaJNijyOji" - }, + "metadata": {}, "outputs": [], "source": [ "from clu import metrics\n", @@ -293,9 +247,7 @@ { "cell_type": "markdown", "id": "8b86b5f1", - "metadata": { - "id": "8b86b5f1" - }, + "metadata": {}, "source": [ "We will be using the `clu` library for computing metrics. For more information on `clu`, refer to the [repo](https://github.com/google/CommonLoopUtils) and [notebook](https://colab.research.google.com/github/google/CommonLoopUtils/blob/master/clu_synopsis.ipynb#scrollTo=ueom-uBWLbeQ)." ] @@ -304,14 +256,7 @@ "cell_type": "code", "execution_count": 53, "id": "7W0qf7FC9uG5", - "metadata": { - "executionInfo": { - "elapsed": 55, - "status": "ok", - "timestamp": 1673483483958 - }, - "id": "7W0qf7FC9uG5" - }, + "metadata": {}, "outputs": [], "source": [ "@struct.dataclass\n", @@ -323,9 +268,7 @@ { "cell_type": "markdown", "id": "f3ce5e4c", - "metadata": { - "id": "f3ce5e4c" - }, + "metadata": {}, "source": [ "You can then subclass `train_state.TrainState` so that it also contains metrics. This has the advantage that we only need\n", "to pass around a single argument to functions like `train_step()` (see below) to calculate the loss, update the parameters and compute the metrics all at once." 
@@ -335,14 +278,7 @@ "cell_type": "code", "execution_count": 54, "id": "e0102447", - "metadata": { - "executionInfo": { - "elapsed": 54, - "status": "ok", - "timestamp": 1673483484125 - }, - "id": "e0102447" - }, + "metadata": {}, "outputs": [], "source": [ "class TrainState(train_state.TrainState):\n", @@ -360,9 +296,7 @@ { "cell_type": "markdown", "id": "a15de484", - "metadata": { - "id": "a15de484" - }, + "metadata": {}, "source": [ "## 5. Training step\n", "\n", @@ -388,14 +322,7 @@ "cell_type": "code", "execution_count": 55, "id": "9b0af486", - "metadata": { - "executionInfo": { - "elapsed": 52, - "status": "ok", - "timestamp": 1673483484293 - }, - "id": "9b0af486" - }, + "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", @@ -415,9 +342,7 @@ { "cell_type": "markdown", "id": "0ff5145f", - "metadata": { - "id": "0ff5145f" - }, + "metadata": {}, "source": [ "## 6. Metric computation\n", "\n", @@ -428,14 +353,7 @@ "cell_type": "code", "execution_count": 56, "id": "961bf70b", - "metadata": { - "executionInfo": { - "elapsed": 53, - "status": "ok", - "timestamp": 1673483484460 - }, - "id": "961bf70b" - }, + "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", @@ -453,9 +371,7 @@ { "cell_type": "markdown", "id": "497241c3", - "metadata": { - "id": "497241c3" - }, + "metadata": {}, "source": [ "## 7. Download data" ] @@ -464,14 +380,7 @@ "cell_type": "code", "execution_count": 57, "id": "bff5393e", - "metadata": { - "executionInfo": { - "elapsed": 515, - "status": "ok", - "timestamp": 1673483485090 - }, - "id": "bff5393e" - }, + "metadata": {}, "outputs": [], "source": [ "num_epochs = 10\n", @@ -484,9 +393,7 @@ "attachments": {}, "cell_type": "markdown", "id": "809ae1a0", - "metadata": { - "id": "809ae1a0" - }, + "metadata": {}, "source": [ "## 8. 
Seed randomness\n", "\n", @@ -503,14 +410,7 @@ "cell_type": "code", "execution_count": 58, "id": "xC4MFyBsfT-U", - "metadata": { - "executionInfo": { - "elapsed": 59, - "status": "ok", - "timestamp": 1673483485268 - }, - "id": "xC4MFyBsfT-U" - }, + "metadata": {}, "outputs": [], "source": [ "tf.random.set_seed(0)" @@ -520,14 +420,7 @@ "cell_type": "code", "execution_count": 59, "id": "e4f6f4d3", - "metadata": { - "executionInfo": { - "elapsed": 52, - "status": "ok", - "timestamp": 1673483485436 - }, - "id": "e4f6f4d3" - }, + "metadata": {}, "outputs": [], "source": [ "init_rng = jax.random.key(0)" @@ -536,9 +429,7 @@ { "cell_type": "markdown", "id": "80fbb60b", - "metadata": { - "id": "80fbb60b" - }, + "metadata": {}, "source": [ "## 9. Initialize the `TrainState`\n", "\n", @@ -550,14 +441,7 @@ "cell_type": "code", "execution_count": 60, "id": "445fcab0", - "metadata": { - "executionInfo": { - "elapsed": 56, - "status": "ok", - "timestamp": 1673483485606 - }, - "id": "445fcab0" - }, + "metadata": {}, "outputs": [], "source": [ "learning_rate = 0.01\n", @@ -568,14 +452,7 @@ "cell_type": "code", "execution_count": 61, "id": "5221eafd", - "metadata": { - "executionInfo": { - "elapsed": 52, - "status": "ok", - "timestamp": 1673483485777 - }, - "id": "5221eafd" - }, + "metadata": {}, "outputs": [], "source": [ "state = create_train_state(cnn, init_rng, learning_rate, momentum)\n", @@ -585,9 +462,7 @@ { "cell_type": "markdown", "id": "b1c00230", - "metadata": { - "id": "b1c00230" - }, + "metadata": {}, "source": [ "## 10. 
Train and evaluate\n", "\n", @@ -610,14 +485,7 @@ "cell_type": "code", "execution_count": 62, "id": "74295360", - "metadata": { - "executionInfo": { - "elapsed": 55, - "status": "ok", - "timestamp": 1673483485947 - }, - "id": "74295360" - }, + "metadata": {}, "outputs": [], "source": [ "# since train_ds is replicated num_epochs times in get_datasets(), we divide by num_epochs\n", @@ -628,14 +496,7 @@ "cell_type": "code", "execution_count": 63, "id": "cRtnMZuQFlKl", - "metadata": { - "executionInfo": { - "elapsed": 1, - "status": "ok", - "timestamp": 1673483486076 - }, - "id": "cRtnMZuQFlKl" - }, + "metadata": {}, "outputs": [], "source": [ "metrics_history = {'train_loss': [],\n", @@ -649,12 +510,6 @@ "execution_count": 64, "id": "2c40ce90", "metadata": { - "executionInfo": { - "elapsed": 17908, - "status": "ok", - "timestamp": 1673483504133 - }, - "id": "2c40ce90", "outputId": "258a2c76-2c8f-4a9e-d48b-dde57c342a87" }, "outputs": [ @@ -716,9 +571,7 @@ { "cell_type": "markdown", "id": "gfsecJzvzgCT", - "metadata": { - "id": "gfsecJzvzgCT" - }, + "metadata": {}, "source": [ "## 11. Visualize metrics" ] @@ -728,12 +581,6 @@ "execution_count": 65, "id": "Zs5atiqIG9Kz", "metadata": { - "executionInfo": { - "elapsed": 358, - "status": "ok", - "timestamp": 1673483504621 - }, - "id": "Zs5atiqIG9Kz", "outputId": "431a2fcd-44fa-4202-f55a-906555f060ac" }, "outputs": [ @@ -776,9 +623,7 @@ { "cell_type": "markdown", "id": "qQbKS0tV3sZ1", - "metadata": { - "id": "qQbKS0tV3sZ1" - }, + "metadata": {}, "source": [ "## 12. 
Perform inference on test set\n", "\n", @@ -789,14 +634,7 @@ "cell_type": "code", "execution_count": 66, "id": "DFwxgBQf44ks", - "metadata": { - "executionInfo": { - "elapsed": 580, - "status": "ok", - "timestamp": 1673483505350 - }, - "id": "DFwxgBQf44ks" - }, + "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", @@ -813,12 +651,6 @@ "execution_count": 67, "id": "5d5nF3u44JFI", "metadata": { - "executionInfo": { - "elapsed": 1250, - "status": "ok", - "timestamp": 1673483506723 - }, - "id": "5d5nF3u44JFI", "outputId": "1db5a01c-9d70-4f7d-8c0d-0a3ad8252d3e" }, "outputs": [ @@ -844,9 +676,7 @@ { "cell_type": "markdown", "id": "edb528b6", - "metadata": { - "id": "edb528b6" - }, + "metadata": {}, "source": [ "Congratulations! You made it to the end of the annotated MNIST example. You can revisit\n", "the same example, but structured differently as a couple of Python modules, test\n", diff --git a/docs/quick_start.md b/docs/quick_start.md index 0fe3f6312..e12dc8491 100644 --- a/docs/quick_start.md +++ b/docs/quick_start.md @@ -9,8 +9,6 @@ jupytext: jupytext_version: 1.13.8 --- -+++ {"id": "6eea21b3"} - [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/getting_started.ipynb) [![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/getting_started.ipynb) @@ -22,19 +20,16 @@ Flax is an open source Python neural network library built on top of [JAX](https network (CNN) using the [Flax](https://flax.readthedocs.io) Linen API and train the network for image classification on the MNIST dataset. -+++ {"id": "nwJWKIhdwxDo"} ++++ ## 1. Install Flax ```{code-cell} -:id: bb81587e :tags: [skip-execution] !pip install -q flax>=0.7.5 ``` -+++ {"id": "b529fbef"} - ## 2. Loading data Flax can use any @@ -42,13 +37,6 @@ data-loading pipeline and this example demonstrates how to utilize TFDS. 
Define samples to floating-point numbers. ```{code-cell} ---- -executionInfo: - elapsed: 54 - status: ok - timestamp: 1673483483044 -id: bRlrHqZVXZvk ---- import tensorflow_datasets as tfds # TFDS for MNIST import tensorflow as tf # TensorFlow operations @@ -72,8 +60,6 @@ def get_datasets(num_epochs, batch_size): return train_ds, test_ds ``` -+++ {"id": "7057395a"} - ## 3. Define network Create a convolutional neural network with the Linen API by subclassing @@ -85,13 +71,6 @@ stacking layers—you can define the inlined submodules directly within the decorator. To learn more about the Flax Linen `@compact` decorator, refer to the [`setup` vs `compact`](https://flax.readthedocs.io/en/latest/guides/setup_or_nncompact.html) guide. ```{code-cell} ---- -executionInfo: - elapsed: 53 - status: ok - timestamp: 1673483483208 -id: cbc079cd ---- from flax import linen as nn # Linen API class CNN(nn.Module): @@ -112,21 +91,13 @@ class CNN(nn.Module): return x ``` -+++ {"id": "hy7iRu7_zlx-"} - ### View model layers Create an instance of the Flax Module and use the [`Module.tabulate`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.tabulate) method to visualize a table of the model layers by passing an RNG key and template image input. ```{code-cell} ---- -executionInfo: - elapsed: 103 - status: ok - timestamp: 1673483483427 -id: lDHfog81zLQa -outputId: 2c580f41-bf5d-40ec-f1cf-ab7f319a84da ---- +:outputId: 2c580f41-bf5d-40ec-f1cf-ab7f319a84da + import jax import jax.numpy as jnp # JAX NumPy @@ -135,8 +106,6 @@ print(cnn.tabulate(jax.random.key(0), jnp.ones((1, 28, 28, 1)), compute_flops=True, compute_vjp_flops=True)) ``` -+++ {"id": "4b5ac16e"} - ## 4. Create a `TrainState` A common pattern in Flax is to create a single dataclass that represents the @@ -147,62 +116,31 @@ Because this is such a common pattern, Flax provides the class that serves most basic usecases. 
```{code-cell} ---- -executionInfo: - elapsed: 52 - status: ok - timestamp: 1673483483631 -id: qXr7JDpIxGNZ -outputId: 1249b7fb-6787-41eb-b34c-61d736300844 ---- +:outputId: 1249b7fb-6787-41eb-b34c-61d736300844 + !pip install -q clu ``` ```{code-cell} ---- -executionInfo: - elapsed: 1 - status: ok - timestamp: 1673483483754 -id: CJDaJNijyOji ---- from clu import metrics from flax.training import train_state # Useful dataclass to keep train state from flax import struct # Flax dataclasses import optax # Common loss functions and optimizers ``` -+++ {"id": "8b86b5f1"} - We will be using the `clu` library for computing metrics. For more information on `clu`, refer to the [repo](https://github.com/google/CommonLoopUtils) and [notebook](https://colab.research.google.com/github/google/CommonLoopUtils/blob/master/clu_synopsis.ipynb#scrollTo=ueom-uBWLbeQ). ```{code-cell} ---- -executionInfo: - elapsed: 55 - status: ok - timestamp: 1673483483958 -id: 7W0qf7FC9uG5 ---- @struct.dataclass class Metrics(metrics.Collection): accuracy: metrics.Accuracy loss: metrics.Average.from_output('loss') ``` -+++ {"id": "f3ce5e4c"} - You can then subclass `train_state.TrainState` so that it also contains metrics. This has the advantage that we only need to pass around a single argument to functions like `train_step()` (see below) to calculate the loss, update the parameters and compute the metrics all at once. ```{code-cell} ---- -executionInfo: - elapsed: 54 - status: ok - timestamp: 1673483484125 -id: e0102447 ---- class TrainState(train_state.TrainState): metrics: Metrics @@ -215,8 +153,6 @@ def create_train_state(module, rng, learning_rate, momentum): metrics=Metrics.empty()) ``` -+++ {"id": "a15de484"} - ## 5. Training step A function that: @@ -237,13 +173,6 @@ it with [XLA](https://www.tensorflow.org/xla) into fused device operations that run faster and more efficiently on hardware accelerators. 
```{code-cell} ---- -executionInfo: - elapsed: 52 - status: ok - timestamp: 1673483484293 -id: 9b0af486 ---- @jax.jit def train_step(state, batch): """Train for a single step.""" @@ -258,20 +187,11 @@ def train_step(state, batch): return state ``` -+++ {"id": "0ff5145f"} - ## 6. Metric computation Create a separate function for loss and accuracy metrics. Loss is calculated using the `optax.softmax_cross_entropy_with_integer_labels` function, while accuracy is calculated using `clu.metrics`. ```{code-cell} ---- -executionInfo: - elapsed: 53 - status: ok - timestamp: 1673483484460 -id: 961bf70b ---- @jax.jit def compute_metrics(*, state, batch): logits = state.apply_fn({'params': state.params}, batch['image']) @@ -284,26 +204,15 @@ def compute_metrics(*, state, batch): return state ``` -+++ {"id": "497241c3"} - ## 7. Download data ```{code-cell} ---- -executionInfo: - elapsed: 515 - status: ok - timestamp: 1673483485090 -id: bff5393e ---- num_epochs = 10 batch_size = 32 train_ds, test_ds = get_datasets(num_epochs, batch_size) ``` -+++ {"id": "809ae1a0"} - ## 8. Seed randomness - Set the TF random seed to ensure dataset shuffling (with `tf.data.Dataset.shuffle`) is reproducible. @@ -315,60 +224,28 @@ train_ds, test_ds = get_datasets(num_epochs, batch_size) and [PRNG chains](https://flax.readthedocs.io/en/latest/philosophy.html#how-are-parameters-represented-and-how-do-we-handle-general-differentiable-algorithms-that-update-stateful-variables).) ```{code-cell} ---- -executionInfo: - elapsed: 59 - status: ok - timestamp: 1673483485268 -id: xC4MFyBsfT-U ---- tf.random.set_seed(0) ``` ```{code-cell} ---- -executionInfo: - elapsed: 52 - status: ok - timestamp: 1673483485436 -id: e4f6f4d3 ---- init_rng = jax.random.key(0) ``` -+++ {"id": "80fbb60b"} - ## 9. Initialize the `TrainState` Remember that the function `create_train_state` initializes the model parameters, optimizer and metrics and puts them into the training state dataclass that is returned. 
```{code-cell} ---- -executionInfo: - elapsed: 56 - status: ok - timestamp: 1673483485606 -id: 445fcab0 ---- learning_rate = 0.01 momentum = 0.9 ``` ```{code-cell} ---- -executionInfo: - elapsed: 52 - status: ok - timestamp: 1673483485777 -id: 5221eafd ---- state = create_train_state(cnn, init_rng, learning_rate, momentum) del init_rng # Must not be used anymore. ``` -+++ {"id": "b1c00230"} - ## 10. Train and evaluate Create a "shuffled" dataset by: @@ -386,25 +263,11 @@ Define a training loop that: Once the training and testing is done after 10 epochs, the output should show that your model was able to achieve approximately 99% accuracy. ```{code-cell} ---- -executionInfo: - elapsed: 55 - status: ok - timestamp: 1673483485947 -id: '74295360' ---- # since train_ds is replicated num_epochs times in get_datasets(), we divide by num_epochs num_steps_per_epoch = train_ds.cardinality().numpy() // num_epochs ``` ```{code-cell} ---- -executionInfo: - elapsed: 1 - status: ok - timestamp: 1673483486076 -id: cRtnMZuQFlKl ---- metrics_history = {'train_loss': [], 'train_accuracy': [], 'test_loss': [], @@ -412,14 +275,8 @@ metrics_history = {'train_loss': [], ``` ```{code-cell} ---- -executionInfo: - elapsed: 17908 - status: ok - timestamp: 1673483504133 -id: 2c40ce90 -outputId: 258a2c76-2c8f-4a9e-d48b-dde57c342a87 ---- +:outputId: 258a2c76-2c8f-4a9e-d48b-dde57c342a87 + for step,batch in enumerate(train_ds.as_numpy_iterator()): # Run optimization steps over training batches and compute batch metrics @@ -447,19 +304,11 @@ for step,batch in enumerate(train_ds.as_numpy_iterator()): f"accuracy: {metrics_history['test_accuracy'][-1] * 100}") ``` -+++ {"id": "gfsecJzvzgCT"} - ## 11. 
Visualize metrics ```{code-cell} ---- -executionInfo: - elapsed: 358 - status: ok - timestamp: 1673483504621 -id: Zs5atiqIG9Kz -outputId: 431a2fcd-44fa-4202-f55a-906555f060ac ---- +:outputId: 431a2fcd-44fa-4202-f55a-906555f060ac + import matplotlib.pyplot as plt # Visualization # Plot loss and accuracy in subplots @@ -475,20 +324,11 @@ plt.show() plt.clf() ``` -+++ {"id": "qQbKS0tV3sZ1"} - ## 12. Perform inference on test set Define a jitted inference function `pred_step`. Use the learned parameters to do model inference on the test set and visualize the images and their corresponding predicted labels. ```{code-cell} ---- -executionInfo: - elapsed: 580 - status: ok - timestamp: 1673483505350 -id: DFwxgBQf44ks ---- @jax.jit def pred_step(state, batch): logits = state.apply_fn({'params': state.params}, test_batch['image']) @@ -499,14 +339,8 @@ pred = pred_step(state, test_batch) ``` ```{code-cell} ---- -executionInfo: - elapsed: 1250 - status: ok - timestamp: 1673483506723 -id: 5d5nF3u44JFI -outputId: 1db5a01c-9d70-4f7d-8c0d-0a3ad8252d3e ---- +:outputId: 1db5a01c-9d70-4f7d-8c0d-0a3ad8252d3e + fig, axs = plt.subplots(5, 5, figsize=(12, 12)) for i, ax in enumerate(axs.flatten()): ax.imshow(test_batch['image'][i, ..., 0], cmap='gray') @@ -514,8 +348,6 @@ for i, ax in enumerate(axs.flatten()): ax.axis('off') ``` -+++ {"id": "edb528b6"} - Congratulations! You made it to the end of the annotated MNIST example. 
You can revisit the same example, but structured differently as a couple of Python modules, test modules, config files, another Colab, and documentation in Flax's Git repo: diff --git a/examples/imagenet/imagenet.ipynb b/examples/imagenet/imagenet.ipynb index 2614653ef..d271c9d11 100644 --- a/examples/imagenet/imagenet.ipynb +++ b/examples/imagenet/imagenet.ipynb @@ -2,9 +2,7 @@ "cells": [ { "cell_type": "markdown", - "metadata": { - "id": "wquAnPg0p4Y8" - }, + "metadata": {}, "source": [ "# Flax Imagenet Example\n", "\n", @@ -16,9 +14,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "UuqrLz3he_1M" - }, + "metadata": {}, "source": [ "The **Flax Notebook Workflow**:\n", "\n", @@ -40,9 +36,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "2cMTM3W4hcsZ" - }, + "metadata": {}, "source": [ "## Setup" ] @@ -51,7 +45,6 @@ "cell_type": "code", "execution_count": 28, "metadata": { - "id": "ecyWhpr9X6tE", "outputId": "cb862d1a-2f71-444f-9770-9f0d53b11389" }, "outputs": [ @@ -72,7 +65,6 @@ "cell_type": "code", "execution_count": 29, "metadata": { - "id": "xVAH-aWN3NzF", "outputId": "80340396-77c2-4654-cc6d-67040f227eb9" }, "outputs": [ @@ -92,9 +84,7 @@ { "cell_type": "code", "execution_count": 30, - "metadata": { - "id": "SwX8bCNEGhJM" - }, + "metadata": {}, "outputs": [], "source": [ "example_directory = 'examples/imagenet'\n", @@ -108,7 +98,6 @@ "execution_count": 31, "metadata": { "cellView": "form", - "id": "o65RonwHp4Y9", "outputId": "9449a7b4-8a5d-4446-abe0-7886435ebd1c" }, "outputs": [ @@ -249,7 +238,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "xcXZ-F3_zBuJ", "outputId": "acc1f45d-5062-4ff3-e6d4-10b4ffe0f8ef" }, "outputs": [], @@ -260,9 +248,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "Tt0rL4ycp4ZB" - }, + "metadata": {}, "source": [ "## Imports / Helpers" ] @@ -272,7 +258,6 @@ "execution_count": 33, "metadata": { "cellView": "form", - "id": "4EzOChfJeVrU", "outputId": 
"9dc7fb32-331e-44a6-b6e8-830f6a64d845" }, "outputs": [ @@ -303,9 +288,7 @@ { "cell_type": "code", "execution_count": 34, - "metadata": { - "id": "EdzHCJuop4ZB" - }, + "metadata": {}, "outputs": [], "source": [ "import json\n", @@ -324,9 +307,7 @@ { "cell_type": "code", "execution_count": 35, - "metadata": { - "id": "7O2C7AY3p4ZF" - }, + "metadata": {}, "outputs": [], "source": [ "# Helper functions for images.\n", @@ -356,7 +337,6 @@ "cell_type": "code", "execution_count": 36, "metadata": { - "id": "6Y1ru2Ovp4ZI", "outputId": "f943d165-b953-4a70-9f93-96eb857c3d53" }, "outputs": [ @@ -382,9 +362,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "gGi7zcRpp4ZL" - }, + "metadata": {}, "source": [ "## Dataset" ] @@ -393,7 +371,6 @@ "cell_type": "code", "execution_count": 37, "metadata": { - "id": "12KP_4h-_10s", "outputId": "a9b6cfe9-cc1c-451a-f8f7-69356cb7bdd2" }, "outputs": [ @@ -473,9 +450,7 @@ { "cell_type": "code", "execution_count": 38, - "metadata": { - "id": "UnuSCpoYBPKN" - }, + "metadata": {}, "outputs": [], "source": [ "# Utilities to help with Imagenette labels.\n", @@ -514,7 +489,6 @@ "cell_type": "code", "execution_count": 39, "metadata": { - "id": "EBibz3g905qt", "outputId": "78142300-cc8b-4a6c-f781-5ab29578d828" }, "outputs": [ @@ -542,7 +516,6 @@ "cell_type": "code", "execution_count": 40, "metadata": { - "id": "ccF8NVuX1Msk", "outputId": "8b3b9cf2-7649-4953-99bb-32a689fe0a29" }, "outputs": [ @@ -565,9 +538,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "KqW8WP5bp4ZS" - }, + "metadata": {}, "source": [ "## Training from scratch" ] @@ -575,9 +546,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "zzBxSXGGyEfw" - }, + "metadata": {}, "outputs": [], "source": [ "# Get a live update during training - use the \"refresh\" button!\n", @@ -591,7 +560,6 @@ "cell_type": "code", "execution_count": 42, "metadata": { - "id": "GGgHtVhIIuH7", "outputId": "2d0bc789-213d-4a34-a7b1-e7852b40f375" }, "outputs": [ @@ -638,7 
+606,6 @@ "cell_type": "code", "execution_count": 43, "metadata": { - "id": "4bGmMCQd6S8U", "outputId": "de56d320-c336-459b-f258-5d6ae41ce0af" }, "outputs": [ @@ -667,7 +634,6 @@ "cell_type": "code", "execution_count": 44, "metadata": { - "id": "OBSJAvUqGgDq", "outputId": "018da2c5-c6f0-42ac-843f-7ac855a6bf14" }, "outputs": [ @@ -760,8 +726,7 @@ "cell_type": "code", "execution_count": 45, "metadata": { - "cellView": "form", - "id": "mZOKD0Y7p4ZW" + "cellView": "form" }, "outputs": [], "source": [ @@ -776,9 +741,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "GBh-2D-Wp4ZY" - }, + "metadata": {}, "source": [ "## Load pre-trained model" ] @@ -787,7 +750,6 @@ "cell_type": "code", "execution_count": 46, "metadata": { - "id": "uKeJJJ5FJksQ", "outputId": "b06fa3d8-a950-46d2-e03e-fc6c971bdbd0" }, "outputs": [ @@ -815,9 +777,7 @@ { "cell_type": "code", "execution_count": 47, - "metadata": { - "id": "UCBikf4GvGuR" - }, + "metadata": {}, "outputs": [], "source": [ "# Load config that was used to train checkpoint.\n", @@ -829,7 +789,6 @@ "cell_type": "code", "execution_count": 48, "metadata": { - "id": "YfA4OnlyKe5x", "outputId": "57777298-4b4b-4a82-b0f2-4b6ff3b949af" }, "outputs": [ @@ -860,9 +819,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "HeMRgkbGiXo9" - }, + "metadata": {}, "source": [ "## Inference" ] @@ -871,7 +828,6 @@ "cell_type": "code", "execution_count": 49, "metadata": { - "id": "i-7r57EkYJtc", "outputId": "793a656b-f3ad-4596-ad4f-44c686e5e885" }, "outputs": [ @@ -895,9 +851,7 @@ { "cell_type": "code", "execution_count": 50, - "metadata": { - "id": "KNTNZZJKYEHF" - }, + "metadata": {}, "outputs": [], "source": [ "# Evaluate using model trained on imagenet.\n", @@ -908,7 +862,6 @@ "cell_type": "code", "execution_count": 51, "metadata": { - "id": "ti55teFObTZW", "outputId": "6ab4bb0b-2c03-4663-d7ac-e51b979d121f" }, "outputs": [ @@ -934,7 +887,6 @@ "cell_type": "code", "execution_count": 52, "metadata": { - "id": "k5bKo731c98H", "outputId": 
"142c1acf-037e-4ab0-9ca3-bdf0829c51c4" }, "outputs": [ @@ -964,7 +916,6 @@ "cell_type": "code", "execution_count": 53, "metadata": { - "id": "2tEFrztxnh2B", "outputId": "4fae2533-5598-4f2e-c133-50bfba463311" }, "outputs": [ @@ -995,7 +946,6 @@ "cell_type": "code", "execution_count": 54, "metadata": { - "id": "SY3YQbgLgJe1", "outputId": "d01e1993-28ab-4a4a-ac58-01c83b80e6c9" }, "outputs": [ diff --git a/examples/mnist/mnist.ipynb b/examples/mnist/mnist.ipynb index 94b8412fe..3bc16317b 100644 --- a/examples/mnist/mnist.ipynb +++ b/examples/mnist/mnist.ipynb @@ -2,9 +2,7 @@ "cells": [ { "cell_type": "markdown", - "metadata": { - "id": "wquAnPg0p4Y8" - }, + "metadata": {}, "source": [ "# Flax MNIST Example\n", "\n", @@ -16,9 +14,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "UuqrLz3he_1M" - }, + "metadata": {}, "source": [ "The **Flax Notebook Workflow**:\n", "\n", @@ -40,9 +36,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "2cMTM3W4hcsZ" - }, + "metadata": {}, "source": [ "## Setup" ] @@ -51,7 +45,6 @@ "cell_type": "code", "execution_count": 1, "metadata": { - "id": "xVAH-aWN3NzF", "outputId": "8520b2f8-2b9d-4216-ba1f-d96175455bbc" }, "outputs": [ @@ -86,7 +79,6 @@ "cell_type": "code", "execution_count": 2, "metadata": { - "id": "SwX8bCNEGhJM", "tags": [] }, "outputs": [], @@ -102,7 +94,6 @@ "execution_count": 3, "metadata": { "cellView": "form", - "id": "o65RonwHp4Y9", "outputId": "2dfbdfa6-d213-4b5b-dc82-ee1765705255" }, "outputs": [ @@ -226,7 +217,6 @@ "cell_type": "code", "execution_count": 4, "metadata": { - "id": "xcXZ-F3_zBuJ", "outputId": "e9061488-ac3e-4d23-f24f-06e1988e7541" }, "outputs": [ @@ -245,9 +235,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "Tt0rL4ycp4ZB" - }, + "metadata": {}, "source": [ "## Imports / Helpers" ] @@ -255,9 +243,7 @@ { "cell_type": "code", "execution_count": 5, - "metadata": { - "id": "EdzHCJuop4ZB" - }, + "metadata": {}, "outputs": [], "source": [ "from absl import logging\n", @@ -273,9 +259,7 
@@ { "cell_type": "code", "execution_count": 6, - "metadata": { - "id": "7O2C7AY3p4ZF" - }, + "metadata": {}, "outputs": [], "source": [ "# Helper functions for images.\n", @@ -302,7 +286,6 @@ "cell_type": "code", "execution_count": 7, "metadata": { - "id": "6Y1ru2Ovp4ZI", "tags": [] }, "outputs": [], @@ -318,9 +301,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "gGi7zcRpp4ZL" - }, + "metadata": {}, "source": [ "## Dataset" ] @@ -329,7 +310,6 @@ "cell_type": "code", "execution_count": 8, "metadata": { - "id": "BRg0rNsJp4ZL", "outputId": "bb4525f4-8ca4-4e9d-d1cc-48a3e0533645", "tags": [] }, @@ -413,7 +393,6 @@ "cell_type": "code", "execution_count": 9, "metadata": { - "id": "B0LgjT3Vp4ZP", "outputId": "89de05b0-aede-414f-cf43-5e7c71871140" }, "outputs": [ @@ -439,9 +418,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "KqW8WP5bp4ZS" - }, + "metadata": {}, "source": [ "## Training" ] @@ -449,9 +426,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "zzBxSXGGyEfw" - }, + "metadata": {}, "outputs": [], "source": [ "# Get a live update during training - use the \"refresh\" button!\n", @@ -465,7 +440,6 @@ "cell_type": "code", "execution_count": 11, "metadata": { - "id": "RHoBUKSkp4ZS", "outputId": "a0eb78b5-ee73-4f4f-8400-41b521f42b75", "tags": [] }, @@ -532,7 +506,6 @@ "execution_count": 12, "metadata": { "cellView": "form", - "id": "mZOKD0Y7p4ZW", "tags": [] }, "outputs": [], @@ -548,9 +521,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "GBh-2D-Wp4ZY" - }, + "metadata": {}, "source": [ "## Inference" ] @@ -559,7 +530,6 @@ "cell_type": "code", "execution_count": 13, "metadata": { - "id": "Q-45FkBLp4ZY", "outputId": "3af424f7-4433-475d-817c-5c0bbc4599ae" }, "outputs": [ @@ -587,7 +557,6 @@ "cell_type": "code", "execution_count": 14, "metadata": { - "id": "xIIvyz8Jp4Zb", "outputId": "949487f5-8aa2-45c8-9b54-efbf34ab58f1" }, "outputs": [ diff --git a/examples/ogbg_molpcba/ogbg_molpcba.ipynb 
b/examples/ogbg_molpcba/ogbg_molpcba.ipynb index c51a4f7f4..61baa503c 100644 --- a/examples/ogbg_molpcba/ogbg_molpcba.ipynb +++ b/examples/ogbg_molpcba/ogbg_molpcba.ipynb @@ -2,9 +2,7 @@ "cells": [ { "cell_type": "markdown", - "metadata": { - "id": "81wUkzl5gCUr" - }, + "metadata": {}, "source": [ "# Flax ogbg-molpcba Example\n", "\n", @@ -16,9 +14,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "hfbxr1U9eciL" - }, + "metadata": {}, "source": [ "## Setup" ] @@ -27,7 +23,6 @@ "cell_type": "code", "execution_count": 1, "metadata": { - "id": "cKmRTXhHdm_U", "outputId": "6508ab2f-b0e5-4693-f6a0-7bc495ec1344" }, "outputs": [ @@ -61,9 +56,7 @@ { "cell_type": "code", "execution_count": 2, - "metadata": { - "id": "bdI9miDfEk9Y" - }, + "metadata": {}, "outputs": [], "source": [ "example_directory = 'examples/ogbg_molpcba'\n", @@ -77,7 +70,6 @@ "execution_count": 3, "metadata": { "cellView": "form", - "id": "bCKbiylLgURG", "outputId": "8261a349-b41e-4e1b-a2ca-8d23412155be" }, "outputs": [ @@ -231,7 +223,6 @@ "cell_type": "code", "execution_count": 4, "metadata": { - "id": "ifRtigyGgZYk", "outputId": "14b17380-5077-4354-e651-027f3d933cfe" }, "outputs": [ @@ -251,9 +242,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "hvyohKMtelLG" - }, + "metadata": {}, "source": [ "## Imports" ] @@ -261,9 +250,7 @@ { "cell_type": "code", "execution_count": 5, - "metadata": { - "id": "bh18UXRVerEz" - }, + "metadata": {}, "outputs": [], "source": [ "# Base imports\n", @@ -281,9 +268,7 @@ { "cell_type": "code", "execution_count": 6, - "metadata": { - "id": "r-T0M1okfkrA" - }, + "metadata": {}, "outputs": [], "source": [ "# Local imports from current directory - auto reload.\n", @@ -298,18 +283,14 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "NFQxAbTWerTQ" - }, + "metadata": {}, "source": [ "## Dataset" ] }, { "cell_type": "markdown", - "metadata": { - "id": "K6Soh8gkYKQB" - }, + "metadata": {}, "source": [ "TensorFlow Datasets supports customizable 
visualization of the ogbg_molpcba dataset." ] @@ -317,9 +298,7 @@ { "cell_type": "code", "execution_count": 7, - "metadata": { - "id": "0ohDGSB_ZC0q" - }, + "metadata": {}, "outputs": [], "source": [ "# Visualization helpers\n", @@ -364,7 +343,6 @@ "cell_type": "code", "execution_count": 8, "metadata": { - "id": "U5jNKcD3YFsO", "outputId": "d9336190-e685-43e8-e3f1-73e3f1ce1cd2" }, "outputs": [ @@ -594,9 +572,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "CZJz7xoKevYn" - }, + "metadata": {}, "source": [ "## Training" ] @@ -604,9 +580,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "Bha8sF3ne0mg" - }, + "metadata": {}, "outputs": [], "source": [ "# Start TensorBoard\n", @@ -621,7 +595,6 @@ "cell_type": "code", "execution_count": 10, "metadata": { - "id": "iVA3nVAth5Wh", "outputId": "7696aae4-beb5-4df2-b72a-7f391ac30c2e" }, "outputs": [ @@ -750,8 +723,7 @@ "cell_type": "code", "execution_count": 11, "metadata": { - "cellView": "form", - "id": "L9BT3cGoiMNo" + "cellView": "form" }, "outputs": [], "source": [ @@ -767,9 +739,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "541fzDOQeyA0" - }, + "metadata": {}, "source": [ "## Inference" ] @@ -777,9 +747,7 @@ { "cell_type": "code", "execution_count": 12, - "metadata": { - "id": "HMxkT1Lge1h5" - }, + "metadata": {}, "outputs": [], "source": [ "# Create deterministic evaluation model.\n", @@ -791,7 +759,6 @@ "cell_type": "code", "execution_count": 13, "metadata": { - "id": "NyDT6Ayp_s-G", "outputId": "51242dc2-5260-4279-b41a-9d7614b33c97" }, "outputs": [ @@ -822,7 +789,6 @@ "cell_type": "code", "execution_count": 14, "metadata": { - "id": "dhQSaZ5Z2sd6", "outputId": "2a529c9f-fed6-4221-a738-a8ffd9049d7e" }, "outputs": [ @@ -852,9 +818,7 @@ { "cell_type": "code", "execution_count": 15, - "metadata": { - "id": "7AcUjIPN7pE8" - }, + "metadata": {}, "outputs": [], "source": [ "# Helper functions for formatting labels and predictions.\n", @@ -882,9 +846,7 @@ }, { "cell_type": 
"markdown", - "metadata": { - "id": "JNQnsIFQtf92" - }, + "metadata": {}, "source": [ "We can choose one of the 128 different tasks and see how the model predictions\n", "match up with the true labels.\n", @@ -897,9 +859,7 @@ { "cell_type": "code", "execution_count": 16, - "metadata": { - "id": "U-lXmVI4LBPE" - }, + "metadata": {}, "outputs": [], "source": [ "# Define which task to plot labels for.\n", @@ -910,7 +870,6 @@ "cell_type": "code", "execution_count": 17, "metadata": { - "id": "hDKR3yVIwOm3", "outputId": "38934096-13a6-4701-823a-ac83f3b7eaac" }, "outputs": [ diff --git a/examples/seq2seq/seq2seq.ipynb b/examples/seq2seq/seq2seq.ipynb index 6ecce1474..5cea63120 100644 --- a/examples/seq2seq/seq2seq.ipynb +++ b/examples/seq2seq/seq2seq.ipynb @@ -2,9 +2,7 @@ "cells": [ { "cell_type": "markdown", - "metadata": { - "id": "wquAnPg0p4Y8" - }, + "metadata": {}, "source": [ "# Flax seq2seq Example\n", "\n", @@ -16,9 +14,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "UuqrLz3he_1M" - }, + "metadata": {}, "source": [ "The **Flax Notebook Workflow**:\n", "\n", @@ -40,9 +36,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "2cMTM3W4hcsZ" - }, + "metadata": {}, "source": [ "## Setup" ] @@ -51,7 +45,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "xVAH-aWN3NzF", "outputId": "4c0a705c-8d7e-44cc-d851-873a40ac115e" }, "outputs": [ @@ -78,7 +71,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "SwX8bCNEGhJM", "tags": [] }, "outputs": [], @@ -94,7 +86,6 @@ "execution_count": null, "metadata": { "cellView": "form", - "id": "o65RonwHp4Y9", "outputId": "4801432e-4090-4b13-f0f2-d99a3039ce47" }, "outputs": [ @@ -230,7 +221,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "xcXZ-F3_zBuJ", "outputId": "a292a7a2-ae3c-4518-af28-9c2fa0ed2d7b" }, "outputs": [ @@ -249,9 +239,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "Tt0rL4ycp4ZB" - }, + "metadata": {}, "source": [ "## Imports" ] 
@@ -259,9 +247,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "EdzHCJuop4ZB" - }, + "metadata": {}, "outputs": [], "source": [ "from absl import app\n", @@ -277,7 +263,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "6Y1ru2Ovp4ZI", "outputId": "7e1a29ce-9d8b-4715-ce60-9eae100a1df3", "tags": [] }, @@ -303,9 +288,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "gGi7zcRpp4ZL" - }, + "metadata": {}, "source": [ "## Dataset" ] @@ -314,7 +297,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "xce4axo5Y9xp", "outputId": "cb5f7f6e-1e6f-40ff-e0d6-5b428511d75b" }, "outputs": [ @@ -343,7 +325,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "k_ZD70nIYlEq", "outputId": "b58ea813-e757-4cc5-f3ba-3cb0f05d35a6" }, "outputs": [ @@ -376,7 +357,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "UF19Nr2zZRQo", "outputId": "3b33e061-f0b5-42d7-ad49-5058e8fd3b90" }, "outputs": [ @@ -398,9 +378,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "KqW8WP5bp4ZS" - }, + "metadata": {}, "source": [ "## Training" ] @@ -408,9 +386,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "zzBxSXGGyEfw" - }, + "metadata": {}, "outputs": [], "source": [ "# Get a live update during training - use the \"refresh\" button!\n", @@ -423,9 +399,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "LR9apE1dcFy0" - }, + "metadata": {}, "outputs": [], "source": [ "import time\n", @@ -436,7 +410,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "HgjiCPuAbZ5m", "outputId": "e49554e2-9336-4d97-a1e2-82b9e98407da" }, "outputs": [ @@ -464,7 +437,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "az3CUuNacBkS", "outputId": "49396889-35b0-4a11-8b8a-e67624be32a7" }, "outputs": [ @@ -598,7 +570,6 @@ "execution_count": null, "metadata": { "cellView": "form", - "id": "mZOKD0Y7p4ZW", 
"outputId": "2beaf4e9-b10b-4156-d2d9-187777306de0", "tags": [] }, @@ -655,9 +626,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "GBh-2D-Wp4ZY" - }, + "metadata": {}, "source": [ "## Inference" ] @@ -666,7 +635,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "hwi0ylrOgVKT", "outputId": "e22b7208-5413-4a63-abfb-b510af60f340" }, "outputs": [ @@ -690,9 +658,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "hNRtka4Ng61k" - }, + "metadata": {}, "outputs": [], "source": [ "# Using different random seeds generates different samples.\n", @@ -703,7 +669,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "2LWKWLyohTt8", "outputId": "e5cdfd75-2c66-4165-8ab7-9fdecde5062a" }, "outputs": [ diff --git a/examples/sst2/sst2.ipynb b/examples/sst2/sst2.ipynb index 10d05c98f..1a3d0e1bc 100644 --- a/examples/sst2/sst2.ipynb +++ b/examples/sst2/sst2.ipynb @@ -2,9 +2,7 @@ "cells": [ { "cell_type": "markdown", - "metadata": { - "id": "wquAnPg0p4Y8" - }, + "metadata": {}, "source": [ "# Flax SST-2 Example\n", "\n", @@ -16,18 +14,14 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "qYA_sBsVn1SY" - }, + "metadata": {}, "source": [ "**Before you start:** Select Runtime -> Change runtime type -> GPU." 
] }, { "cell_type": "markdown", - "metadata": { - "id": "UuqrLz3he_1M" - }, + "metadata": {}, "source": [ "The **Flax Notebook Workflow**:\n", "\n", @@ -49,9 +43,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "2cMTM3W4hcsZ" - }, + "metadata": {}, "source": [ "## Setup" ] @@ -59,9 +51,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "SwX8bCNEGhJM" - }, + "metadata": {}, "outputs": [], "source": [ "example_directory = 'examples/sst2'\n", @@ -71,9 +61,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "o65RonwHp4Y9" - }, + "metadata": {}, "outputs": [], "source": [ "# (If you run this code in Jupyter[lab], then you're already in the\n", @@ -126,9 +114,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "xcXZ-F3_zBuJ" - }, + "metadata": {}, "outputs": [], "source": [ "# Note: In Colab, above cell changed the working directory.\n", @@ -138,9 +124,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "qgUlFbSy_9q_" - }, + "metadata": {}, "outputs": [], "source": [ "# Install SST-2 dependencies.\n", @@ -149,9 +133,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "Tt0rL4ycp4ZB" - }, + "metadata": {}, "source": [ "## Imports / Helpers" ] @@ -159,9 +141,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "08kWwdKZYZtG" - }, + "metadata": {}, "outputs": [], "source": [ "# If you want to use TPU instead of GPU, you need to run this to make it work.\n", @@ -178,9 +158,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "EdzHCJuop4ZB" - }, + "metadata": {}, "outputs": [], "source": [ "from absl import logging\n", @@ -200,7 +178,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "6Y1ru2Ovp4ZI", "tags": [] }, "outputs": [], @@ -219,9 +196,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "gGi7zcRpp4ZL" - }, + "metadata": {}, "source": [ "## Dataset" ] @@ -230,7 +205,6 @@ 
"cell_type": "code", "execution_count": null, "metadata": { - "id": "BRg0rNsJp4ZL", "tags": [] }, "outputs": [], @@ -243,9 +217,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "KqW8WP5bp4ZS" - }, + "metadata": {}, "source": [ "## Training" ] @@ -253,9 +225,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "zzBxSXGGyEfw" - }, + "metadata": {}, "outputs": [], "source": [ "# Get a live update during training - use the \"refresh\" button!\n", @@ -269,7 +239,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "RHoBUKSkp4ZS", "tags": [] }, "outputs": [], @@ -286,7 +255,6 @@ "execution_count": null, "metadata": { "cellView": "form", - "id": "mZOKD0Y7p4ZW", "tags": [] }, "outputs": [], diff --git a/flax/core/flax_functional_engine.ipynb b/flax/core/flax_functional_engine.ipynb index 8949e0dca..327c92ad6 100644 --- a/flax/core/flax_functional_engine.ipynb +++ b/flax/core/flax_functional_engine.ipynb @@ -4,8 +4,7 @@ "cell_type": "code", "execution_count": 1, "metadata": { - "colab_type": "code", - "id": "x0SPwYS9dtYA" + "colab_type": "code" }, "outputs": [], "source": [ @@ -19,8 +18,7 @@ "cell_type": "code", "execution_count": 2, "metadata": { - "colab_type": "code", - "id": "7n9cxyCzluvI" + "colab_type": "code" }, "outputs": [], "source": [ @@ -31,8 +29,7 @@ "cell_type": "code", "execution_count": 3, "metadata": { - "colab_type": "code", - "id": "0L7YCrobkfzU" + "colab_type": "code" }, "outputs": [], "source": [ @@ -44,12 +41,6 @@ "execution_count": 4, "metadata": { "colab_type": "code", - "executionInfo": { - "elapsed": 1116, - "status": "ok", - "timestamp": 1590673431275 - }, - "id": "aDLGb3iGkjoL", "outputId": "2558605e-e485-407e-b062-74d31cc49f1e", "tags": [] }, @@ -114,12 +105,6 @@ "execution_count": 5, "metadata": { "colab_type": "code", - "executionInfo": { - "elapsed": 526, - "status": "ok", - "timestamp": 1590672865722 - }, - "id": "LTFjZbRmlqZh", "outputId": "5790b763-df4f-47c8-9f4e-53fd1e1eb1fd", "tags": 
[] }, @@ -178,12 +163,6 @@ "execution_count": 6, "metadata": { "colab_type": "code", - "executionInfo": { - "elapsed": 342, - "status": "ok", - "timestamp": 1590673618925 - }, - "id": "TMlae0hem0u5", "outputId": "dd9c5079-10e7-4944-e09a-e9f65573a733" }, "outputs": [ diff --git a/tests/colab_tpu_jax_version.ipynb b/tests/colab_tpu_jax_version.ipynb index 77a65bc81..7f55df00d 100644 --- a/tests/colab_tpu_jax_version.ipynb +++ b/tests/colab_tpu_jax_version.ipynb @@ -3,9 +3,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "6RTgYOrq2Mbp" - }, + "metadata": {}, "outputs": [], "source": [ "# JAX/jaxlib should be both 0.3.25\n", @@ -17,9 +15,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "0TI6hM1oU-y9" - }, + "metadata": {}, "outputs": [], "source": [ "# should show 8 TPU devices\n", @@ -42,9 +38,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "xBlI1bhd0QzN" - }, + "metadata": {}, "outputs": [], "source": [ "# in case JAX version has changed after the '!pip install`, below command should\n", @@ -56,9 +50,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "Xf--m2iHVUgh" - }, + "metadata": {}, "outputs": [], "source": [ "# it's possible to get dependency tree without installing packages, but this\n",