Skip to content

Commit

Permalink
fix pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Nov 10, 2023
1 parent e566bb6 commit 0794889
Show file tree
Hide file tree
Showing 15 changed files with 242 additions and 1,184 deletions.
127 changes: 28 additions & 99 deletions docs/guides/flax_fundamentals/flax_basics.ipynb

Large diffs are not rendered by default.

71 changes: 6 additions & 65 deletions docs/guides/flax_fundamentals/flax_basics.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ jupytext:
jupytext_version: 1.13.8
---

+++ {"id": "yf-nWLh0naJi"}

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/guides/flax_fundamentals/flax_basics.ipynb)
[![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/guides/flax_fundamentals/flax_basics.ipynb)

Expand All @@ -23,14 +21,13 @@ This notebook will walk you through the following workflow:
* Serialization of parameters and other objects.
* Creating your own models and managing state.

+++ {"id": "KyANAaZtbs86"}
+++

## Setting up our environment

Here we provide the code needed to set up the environment for our notebook.

```{code-cell}
:id: qdrEVv9tinJn
:outputId: e30aa464-fa52-4f35-df96-716c68a4b3ee
:tags: [skip-execution]
Expand All @@ -41,17 +38,13 @@ Here we provide the code needed to set up the environment for our notebook.
```

```{code-cell}
:id: kN6bZDaReZO2
import jax
from typing import Any, Callable, Sequence
from jax import random, numpy as jnp
import flax
from flax import linen as nn
```

+++ {"id": "pCCwAbOLiscA"}

## Linear regression with Flax

In the previous *JAX for the impatient* notebook, we finished up with a linear regression example. As we know, linear regression can also be written as a single dense neural network layer, which we will show in the following so that we can compare how it's done.
Expand All @@ -61,22 +54,17 @@ A dense layer is a layer that has a kernel parameter $W\in\mathcal{M}_{m,n}(\mat
This dense layer is already provided by Flax in the `flax.linen` module (here imported as `nn`).

```{code-cell}
:id: zWX2zEtphT4Y
# We create one dense layer instance (taking 'features' parameter as input)
model = nn.Dense(features=5)
```

+++ {"id": "UmzP1QoQYAAN"}

Layers (and models in general, we'll use that word from now on) are subclasses of the `linen.Module` class.

### Model parameters & initialization

Parameters are not stored with the models themselves. You need to initialize parameters by calling the `init` function, using a PRNGKey and dummy input data.

```{code-cell}
:id: K529lhzeYtl8
:outputId: 06feb9d2-db50-4f41-c169-6df4336f43a5
key1, key2 = random.split(random.key(0))
Expand All @@ -85,8 +73,6 @@ params = model.init(key2, x) # Initialization call
jax.tree_util.tree_map(lambda x: x.shape, params) # Checking output shapes
```

+++ {"id": "NH7Y9xMEewmO"}

*Note: JAX and Flax, like NumPy, are row-based systems, meaning that vectors are represented as row vectors and not column vectors. This can be seen in the shape of the kernel here.*

The result is what we expect: bias and kernel parameters of the correct size. Under the hood:
Expand All @@ -96,19 +82,16 @@ The result is what we expect: bias and kernel parameters of the correct size. Un
* Initialization functions are called to generate the initial set of parameters that the model will use. Those are functions that take as arguments `(PRNG Key, shape, dtype)` and return an Array of shape `shape`.
* The init function returns the initialized set of parameters (you can also get the output of the forward pass on the dummy input with the same syntax by using the `init_with_output` method instead of `init`.

+++ {"id": "M1qo9M3_naJo"}
+++

To conduct a forward pass with the model with a given set of parameters (which are never stored with the model), we just use the `apply` method by providing it the parameters to use as well as the input:

```{code-cell}
:id: J8ietJecWiuK
:outputId: 7bbe6bb4-94d5-4574-fbb5-aa0fcd1c84ae
model.apply(params, x)
```

+++ {"id": "lVsjgYzuSBGL"}

### Gradient descent

If you jumped here directly without going through the JAX part, here is the linear regression formulation we're going to use: from a set of data points $\{(x_i,y_i), i\in \{1,\ldots, k\}, x_i\in\mathbb{R}^n,y_i\in\mathbb{R}^m\}$, we try to find a set of parameters $W\in \mathcal{M}_{m,n}(\mathbb{R}), b\in\mathbb{R}^m$ such that the function $f_{W,b}(x)=Wx+b$ minimizes the mean squared error:
Expand All @@ -118,7 +101,6 @@ $$\mathcal{L}(W,b)\rightarrow\frac{1}{k}\sum_{i=1}^{k} \frac{1}{2}\|y_i-f_{W,b}(
Here, we see that the tuple $(W,b)$ matches the parameters of the Dense layer. We'll perform gradient descent using those. Let's first generate the fake data we'll use. The data is exactly the same as in the JAX part's linear regression pytree example.

```{code-cell}
:id: bFIiMnL4dl-e
:outputId: 6eae59dc-0632-4f53-eac8-c22a7c646a52
# Set problem dimensions.
Expand All @@ -141,13 +123,9 @@ y_samples = jnp.dot(x_samples, W) + b + 0.1 * random.normal(key_noise,(n_samples
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)
```

+++ {"id": "ZHkioicCiUbx"}

We copy the same training loop that we used in the JAX pytree linear regression example with `jax.value_and_grad()`, but here we can use `model.apply()` instead of having to define our own feed-forward function (`predict_pytree()` in the [JAX example](https://flax.readthedocs.io/en/latest/guides/jax_for_the_impatient.html#linear-regression-with-pytrees)).

```{code-cell}
:id: JqJaVc7BeNyT
# Same as JAX version but using model.apply().
@jax.jit
def mse(params, x_batched, y_batched):
Expand All @@ -159,12 +137,9 @@ def mse(params, x_batched, y_batched):
return jnp.mean(jax.vmap(squared_error)(x_batched,y_batched), axis=0)
```

+++ {"id": "wGKru__mi15v"}

And finally perform the gradient descent.

```{code-cell}
:id: ePEl1ndse0Jq
:outputId: 50d975b3-4706-4d8a-c4b8-2629ab8e3ac4
learning_rate = 0.3 # Gradient step size.
Expand All @@ -185,8 +160,6 @@ for i in range(101):
print(f'Loss step {i}: ', loss_val)
```

+++ {"id": "zqEnJ9Poyb6q"}

### Optimizing with Optax

Flax used to use its own `flax.optim` package for optimization, but with
Expand All @@ -212,16 +185,13 @@ to the
[official documentation](https://optax.readthedocs.io/en/latest/).

```{code-cell}
:id: Ce77uDJx1bUF
import optax
tx = optax.adam(learning_rate=learning_rate)
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(mse)
```

```{code-cell}
:id: PTSv0vx13xPO
:outputId: eec0c096-1d9e-4b3c-f8e5-942ee63828ec
for i in range(101):
Expand All @@ -232,14 +202,11 @@ for i in range(101):
print('Loss step {}: '.format(i), loss_val)
```

+++ {"id": "0eAPPwtpXYu7"}

### Serializing the result

Now that we're happy with the result of our training, we might want to save the model parameters to load them back later. Flax provides a serialization package to enable you to do that.

```{code-cell}
:id: BiUPRU93XnAZ
:outputId: b97e7d83-3e40-4a80-b1fe-1f6ceff30a0c
from flax import serialization
Expand All @@ -251,35 +218,29 @@ print('Bytes output')
print(bytes_output)
```

+++ {"id": "eielPo2KZByd"}

To load the model back, you'll need to use a template of the model parameter structure, like the one you would get from the model initialization. Here, we use the previously generated `params` as a template. Note that this will produce a new variable structure, and not mutate in-place.

*The point of enforcing structure through template is to avoid users issues downstream, so you need to first have the right model that generates the parameters structure.*

```{code-cell}
:id: MOhoBDCOYYJ5
:outputId: 13acc4e1-8757-4554-e2c8-d594ba6e67dc
serialization.from_bytes(params, bytes_output)
```

+++ {"id": "8mNu8nuOhDC5"}

## Defining your own models

Flax allows you to define your own models, which should be a bit more complicated than a linear regression. In this section, we'll show you how to build simple models. To do so, you'll need to create subclasses of the base `nn.Module` class.

*Keep in mind that we imported* `linen as nn` *and this only works with the new linen API*

+++ {"id": "1sllHAdRlpmQ"}
+++

### Module basics

The base abstraction for models is the `nn.Module` class, and every type of predefined layers in Flax (like the previous `Dense`) is a subclass of `nn.Module`. Let's take a look and start by defining a simple but custom multi-layer perceptron i.e. a sequence of Dense layers interleaved with calls to a non-linear activation function.

```{code-cell}
:id: vbfrfbkxgPhg
:outputId: b59c679c-d164-4fd6-92db-b50f0d310ec3
class ExplicitMLP(nn.Module):
Expand Down Expand Up @@ -310,8 +271,6 @@ print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.
print('output:\n', y)
```

+++ {"id": "DDITIjXitEZl"}

As we can see, a `nn.Module` subclass is made of:

* A collection of data fields (`nn.Module` are Python dataclasses) - here we only have the `features` field of type `Sequence[int]`.
Expand All @@ -324,7 +283,6 @@ As we can see, a `nn.Module` subclass is made of:
Since the module structure and its parameters are not tied to each other, you can't directly call `model(x)` on a given input as it will return an error. The `__call__` function is being wrapped up in the `apply` one, which is the one to call on an input:

```{code-cell}
:id: DEYrVA6dnaJu
:outputId: 4af16ec5-b52a-43b0-fc47-1f8ab25e7058
try:
Expand All @@ -333,12 +291,9 @@ except AttributeError as e:
print(e)
```

+++ {"id": "I__UrmShnaJu"}

Since here we have a very simple model, we could have used an alternative (but equivalent) way of declaring the submodules inline in the `__call__` using the `@nn.compact` annotation like so:

```{code-cell}
:id: ZTCbdpQ4suSK
:outputId: 183a74ef-f54e-4848-99bf-fee4c174ba6d
class SimpleMLP(nn.Module):
Expand Down Expand Up @@ -366,22 +321,19 @@ print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.
print('output:\n', y)
```

+++ {"id": "es7YHjgexT-L"}

There are, however, a few differences you should be aware of between the two declaration modes:

* In `setup`, you are able to name some sublayers and keep them around for further use (e.g. encoder/decoder methods in autoencoders).
* If you want to have multiple methods, then you **need** to declare the module using `setup`, as the `@nn.compact` annotation only allows one method to be annotated.
* The last initialization will be handled differently. See these notes for more details (TODO: add notes link).

+++ {"id": "-ykceROJyp7W"}
+++

### Module parameters

In the previous MLP example, we relied only on predefined layers and operators (`Dense`, `relu`). Let's imagine that you didn't have a Dense layer provided by Flax and you wanted to write it on your own. Here is what it would look like using the `@nn.compact` way to declare a new modules:

```{code-cell}
:id: wK371Pt_vVfR
:outputId: 83b5fea4-071e-4ea0-8fa8-610e69fb5fd5
class SimpleDense(nn.Module):
Expand Down Expand Up @@ -410,8 +362,6 @@ print('initialized parameters:\n', params)
print('output:\n', y)
```

+++ {"id": "MKyhfzVpzC94"}

Here, we see how to both declare and assign a parameter to the model using the `self.param` method. It takes as input `(name, init_fn, *init_args)` :

* `name` is simply the name of the parameter that will end up in the parameter structure.
Expand All @@ -420,7 +370,7 @@ Here, we see how to both declare and assign a parameter to the model using the `

Such params can also be declared in the `setup` method; it won't be able to use shape inference because Flax is using lazy initialization at the first call site.

+++ {"id": "QmSpxyqLDr58"}
+++

### Variables and collections of variables

Expand All @@ -434,7 +384,6 @@ However this is not enough to cover everything that we would need for machine le
For demonstration purposes, we'll implement a simplified but similar mechanism to batch normalization: we'll store running averages and subtract those to the input at training time. For proper batchnorm, you should use (and look at) the implementation [here](https://github.com/google/flax/blob/main/flax/linen/normalization.py).

```{code-cell}
:id: J6_tR-nPzB1i
:outputId: 75465fd6-cdc8-497c-a3ec-7f709b5dde7a
class BiasAdderWithRunningMean(nn.Module):
Expand Down Expand Up @@ -463,12 +412,9 @@ y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
print('updated state:\n', updated_state)
```

+++ {"id": "5OHBbMJng3ic"}

Here, `updated_state` returns only the state variables that are being mutated by the model while applying it on data. To update the variables and get the new parameters of the model, we can use the following pattern:

```{code-cell}
:id: IbTsCAvZcdBy
:outputId: 09a8bdd1-eaf8-401a-cf7c-386a7a5aa87b
for val in [1.0, 2.0, 3.0]:
Expand All @@ -479,14 +425,11 @@ for val in [1.0, 2.0, 3.0]:
print('updated state:\n', updated_state) # Shows only the mutable part
```

+++ {"id": "GuUSOSKegKIM"}

From this simplified example, you should be able to derive a full BatchNorm implementation, or any layer involving a state. To finish, let's add an optimizer to see how to play with both parameters updated by an optimizer and state variables.

*This example isn't doing anything and is only for demonstration purposes.*

```{code-cell}
:id: TUgAbUPpnaJw
:outputId: 0906fbab-b866-4956-d231-b1374415d448
from functools import partial
Expand Down Expand Up @@ -517,14 +460,12 @@ for _ in range(3):
print('Updated state: ', state)
```

+++ {"id": "eWUmx5EjtWge"}

Note that the above function has a quite verbose signature and it would not actually
work with `jax.jit()` because the function arguments are not "valid JAX types".

Flax provides a handy wrapper - `TrainState` - that simplifies the above code. Check out [`flax.training.train_state.TrainState`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState) to learn more.

+++ {"id": "_GL0PsCwnaJw"}
+++

### Exporting to Tensorflow's SavedModel with jax2tf

Expand Down
Loading

0 comments on commit 0794889

Please sign in to comment.