diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 1b862d1ccb..3eed5f2ec4 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -54,7 +54,7 @@ jobs: # one too high when executing this as a Github Action. if (( $diff > 6)); then echo "ERROR! More than 5 commits in PR -- please squash your commits." - url=https://flax.readthedocs.io/en/latest/contributing.html#too-many-commits-in-a-pr + url=https://flax.readthedocs.io/en/latest/contributing.html#too-many-commits-in-a-pull-request echo "See $url for help on how to resolve this." exit 1 fi diff --git a/CHANGELOG.md b/CHANGELOG.md index 1be1757f7e..43f91d44d3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -343,7 +343,7 @@ Other Improvements: - Add Adadelta optimizer - Fully deprecate all "pre-Linen" `flax.nn` classes and methods. - Some Module arguments can now be passed either as dataclass attribute or - as argument to `__call__`. See [design note](https://flax.readthedocs.io/en/latest/design_notes/arguments.html) + as argument to `__call__`. See [design note](https://flax.readthedocs.io/en/latest/guides/arguments.html) - Add `sow` method to `Module` and `capture_intermediates` argument to `Module.apply`. See [howto](https://flax.readthedocs.io/en/latest/howtos/extracting_intermediates.html) for usage patterns. - Support passing in modules directly as attributes to other modules, and diff --git a/docs/advanced_topics/arguments.md b/docs/advanced_topics/arguments.md deleted file mode 100644 index 40212e9aca..0000000000 --- a/docs/advanced_topics/arguments.md +++ /dev/null @@ -1,93 +0,0 @@ -# Dealing with Module Arguments - -## Introduction - -In Linen we can define `Module` arguments either as dataclass attributes or as arguments to methods (usually `__call__`). -Typically the distinction is clear: -* Completely fixed properties, such as the choice of kernel initializer or number of output features, are hyperparameters and should be defined as dataclass attributes. Typically two Module instances with different hyperparamaters cannot share in a meaningful way. -* Dynamic properties, such as input data and top-level "mode switches" like `train=True/False`, should be passed as arguments to `__call__` or another method. - -Some cases are however less clear cut. Take for example the `Dropout` module. -We have a number of clear hyperparameters: - -1. The dropout rate -2. The axes for which a dropout mask is generated - -And some clear call time arguments: - -1. The input that should be masked using dropout -2. The (optional) rng used to sample the random mask - -There is however one property that is ambiguous -- the `deterministic` property in a Dropout module. - -If `deterministic` is `True` no dropout mask is sampled. This is typically used during model evaluation. -However, if we pass `eval=True` or `train=False` to a top-level Module. The `deterministic` argument needs -to be applied everywhere and the boolean argument needs to be passed down to all the layers that might use `Dropout`. -If instead `deterministic` is a dataclass attribute, we might do the following: - -```python -from functools import partial -from flax import linen as nn - -class ResidualModel(nn.Module): - drop_rate: float - - @nn.compact - def __call__(self, x, *, train): - dropout = partial(nn.Dropout, rate=self.drop_rate, deterministic=not train) - for i in range(10): - x += ResidualBlock(dropout=dropout, ...)(x) -``` - -It makes sense to pass `determinstic` to the constructor here because this way we can pass the dropout template to the sub-modules. -Now the sub-module no longer needs to take care of train vs eval mode and can simply use the `dropout` argument. -Note that because the dropout layer can only be constructed in the sub-module we can only partially apply `deterministic` to the constructor but not to `__call__`. - -However, if `deterministic` is a dataclass attribute we run into trouble when using the setup pattern. We would **want** to write our module code like this: - -```python -class SomeModule(nn.Module): - drop_rate: float - - def setup(self): - self.dropout = nn.Dropout(rate=self.drop_rate) - - @nn.compact - def __call__(self, x, *, train): - # ... - x = self.dropout(x, deterministic=not train) - # ... -``` - -But, as defined above, `deterministic` would be an attribute, so this doesn't work. -Here it makes sense to pass `deterministic` during `__call__` because it depends on the `train` argument. - -## Solution - -We can support both use cases described before by allowing certain properties to be passed -as dataclass attributes or as method argument (but not both!). -This can be implemented as follows: -```python -class MyDropout(nn.Module): - drop_rate: float - deterministic: Optional[bool] = None - - @nn.compact - def __call__(self, x, deterministic=None): - deterministic = nn.merge_param('deterministic', self.deterministic, deterministic) - # ... -``` - -In this example `nn.merge_param` will ensure that either `self.deterministic` or `deterministic` is set but not both. -An error is raised if both values are `None` or both values are not `None`. -This avoids confusing behavior where 2 different parts of the code set the same parameter and one is overruled by the other. -It also avoids a default value which would probably cause either the train step or eval step of a training procedure to be broken by default. - - - -## Functional Core - -Functional core defines functions rather than classes. -Therefore, there is no clear distinction between hyperparameters and call-time arguments. -The only way to pre-determine the hyperparameters is by using `partial`. -On the upside, there are no ambiguous cases where method arguments could also be attributes. diff --git a/docs/advanced_topics/contributing.md b/docs/advanced_topics/contributing.md deleted file mode 100644 index 366770fd47..0000000000 --- a/docs/advanced_topics/contributing.md +++ /dev/null @@ -1,262 +0,0 @@ -# How to contribute - -Everyone can contribute to Flax, and we value everyone's contributions. -You can contribute in many more ways than just writing code. Answering questions -on our [Discussions page](https://github.com/google/flax/discussions), helping -each other, and improving our documentation are extremely valuable to our -ecosystem. - -We also appreciate if you spread the word, for instance by starring our GitHub -repo, or referencing Flax in blog posts of projects that used it. - -This project follows -[Google's Open Source Community Guidelines](https://opensource.google/conduct/). - -## Ways to contribute - -We welcome pull requests (PRs), in particular for those issues -[marked as PR-ready](https://github.com/google/flax/issues?q=is%3Aopen+is%3Aissue+label%3A%22Status%3A+pull+requests+welcome%22). -For other proposals, you should first open a GitHub Issue or a GitHub Discussion to -start a conversation about your planned contribution. - -## Contributing code using Pull Requests - -We do all of our development using git, so basic knowledge is assumed. - -Follow these steps to contribute code: - -### Create a Pull Request in your own branch - -1. Fork the Flax repository by clicking the 'Fork' button on the - [repository page](http://www.github.com/google/flax). This creates a copy - of the Flax repository in your own account. - -2. Install [Python >=3.6](https://www.python.org/downloads/). - -3. (Optional) Create a virtual environment or a Docker container. See - [`dev/README.md`](https://github.com/google/flax/blob/main/dev/README.md) - for details on how to set up a Docker Container. To set up a virtual environment, - run the following: - - ```bash - python3 -m virtualenv env - . env/bin/activate - ``` - - This ensures all your dependencies are installed in this environment. - -4. Clone your local forked Flax repo, then install the required packages with [PyPi](https://pip.pypa.io/en/stable/cli/pip_install/). - This enables you to immediately test the code after modifying it: - - ```bash - git clone https://github.com/YOUR_USERNAME/flax - cd flax - pip install -e . - pip install ".[testing]" - pip install -r docs/requirements.txt - # install in editable mode again because docs/requirements.txt - # reinstalls project in non-editable mode - pip install -e . - ``` - -5. Set up pre-commit hooks, this will run some automated checks during each `git` commit and - possibly update some files that require changes. - - ```bash - pip install pre-commit - pre-commit install - ``` - -6. Add the Google Flax repo (not your fork) as an upstream remote, so you can use it to sync your - changes. - - ```bash - git remote add upstream http://www.github.com/google/flax - ``` - - -7. Create a branch where you will develop from: - - ```bash - git checkout -b name-of-change - ``` - -8. Implement your changes using your favorite editor (we recommend - [Visual Studio Code](https://code.visualstudio.com/)). - - Make sure the tests pass by running the following command from the top of - the repository: - - ```bash - ./tests/run_all_tests.sh - ``` - -9. Once your change is done, create a commit as follows - ([how to write a commit message](https://chris.beams.io/posts/git-commit/)): - - ```bash - git add file1.py file2.py ... - git commit -m "Your commit message" - ``` - - Then sync your code with the main repo: - - ```bash - git rebase upstream/main - ``` - -10. Finally push your commit on your development branch and create a remote - branch in your fork that you can use to create a Pull Request form: - - ```bash - git push --set-upstream origin name-of-change - ``` - - After running the command, you should see a GitHub link in your terminal output that you can click on to create a Pull Request. - If you do not see this link in the terminal after doing a `git push`, go to the GitHub web UI; there should be a button there that lets you turn the commit into a Pull Request yourself. - -11. Make sure your PR passes the - [PR checklist](https://github.com/google/flax/blob/main/.github/pull_request_template.md#checklist). - If so, create a Pull Request from the Flax repository and send it for review. - Consult [GitHub Help](https://help.github.com/articles/about-pull-requests/) - for more information on using Pull Requests. - -### Update notebooks - -We use [jupytext](https://jupytext.readthedocs.io/) to maintain two synced copies of docs -in `docs/notebooks`: one in the Jupyter Notebook (`.ipynb`) format, and one in Markdown (`.md`). - -The former can be opened and executed directly in [Google Colab](https://colab.research.google.com/). -Markdown makes it easier to track changes/diffs within version control and, for example, GitHub -web UI, since `.ipynb` files are based on JSON. - -**NOTE**: If your notebook contains a cell that uses `pip` to install a package -you must add a `skip-execution` tag to that cell so `myst-nb` will skip the cell -when testing the notebooks. - -#### Editing Jupyter Notebooks (`.ipynb`) - -For making large changes that substantially modify code and outputs, it's recommended to edit -the notebooks in [Jupyter](https://jupyter.org/install) or in [Colab](https://colab.research.google.com/). - -If you choose to work in Colab, go to **File** and click **Upload notebook**, then pick your file. -After loading it into Colab and editing it, make sure you run the cells, and that there aren't any errors. -Click on **Runtime**, then select **Run all**. After you finish, click **File** > **Download** > **Download ipynb**. -You may also want to test that the file executes properly by using `sphinx-build`, as explained above. - -#### Editing Markdown files (`.md`) - -For making smaller changes to the text content of the notebooks, it is easiest to edit the -`.md` versions using a text editor. - -#### Syncing notebooks - -After editing either the `.ipynb` or `.md` versions of the docs, sync the two versions -using [jupytext](https://jupytext.readthedocs.io/) by running `jupytext --sync` on the updated -notebooks - -First, make sure you have jupytext (version 1.13.8) installed. The jupytext version should match -the one specified in [.pre-commit-config.yaml](https://github.com/google/flax/blob/main/.pre-commit-config.yaml). - -``` -pip install jupytext==1.13.8 -``` - -Then, if you worked on a Jupyter Notebook document, sync the contents with its Markdown-equivalent -file by running the following command: - -``` -jupytext --sync path/to/the/file.ipynb -``` - -Similarly, to sync your Markdown file with its Jupyter Notebook version, run: - -``` -jupytext --sync path/to/the/file.md -``` - -To check that the `.md` and `.ipynb` files are properly synced, you can also use the -[pre-commit](https://pre-commit.com/) framework to perform the same checks used -in the GitHub CI: - -``` -git add docs -u # pre-commit runs on files in git staging. -pre-commit run jupytext -``` - -#### Creating new notebooks - -If you are adding a new Jupyter Notebook to the documentation, you can use `jupytext --set-formats`. -It can set up both the Jupyter Notebook (`.ipynb`) and Markdown (`.md`) versions of the file: - -``` -jupytext --set-formats ipynb,md:myst path/to/the/notebook.ipynb -``` - -This works by adding a `"jupytext"` metadata field to the notebook file which specifies the -desired formats. The `jupytext --sync` command can then recognize them when invoked. - -After you make changes in your file(s), follow the steps from the _Syncing notebooks_ -section above to keep the contents of both Markdown and Jupyter Notebook files in sync. - -#### Notebooks within the sphinx build - -Some of the notebooks are built automatically as part of the pre-submit checks and -as part of the [Read the Docs](https://flax.readthedocs.io/en/latest) build. -The build will fail if cells raise errors. If the errors are intentional, you can either catch them, -or tag the cell with `raises-exceptions` metadata ([example PR](https://github.com/google/jax/pull/2402/files)). -You have to add this metadata by hand in the `.ipynb` file. It will be preserved when somebody else -re-saves the notebook. - -We exclude some notebooks from the build because, for example, they contain long computations. -See `exclude_patterns` in [`conf.py`](https://github.com/google/flax/blob/main/docs/conf.py). - -### Updating the Pull Request contents - -Every Pull Request should ideally be limited to just one commit, so if you have multiple commits please squash them. - -Assuming you now have only one commit in your Pull Request, and want to add changes requested during review: - -1. Make the changes locally in your editor. -2. Run `git commit -a --amend`. This updates the commit contents and allows you to edit the commit message. -3. At this point, `git push` alone will result in an error. Instead, use `git push --force`. -4. Check that it's done: The changes to your commit should be immediately reflected in the Github web UI. - -## Troubleshooting - -### Too many commits in a PR - -If your PR has too many commits associated with it, then our build process may -fail with an error message. This is because of two reasons: - -* We prefer to keep our commit history clean. - -* Our source sync process will fail if our commit tree is too large. - -If you encounter this error message, you should squash your commits. To -rebase your branch to `main` and create a new commit containing all your -changes, run the following command: - -```bash -git rebase main && git reset --soft main && git commit -``` - -This will apply all your changes to the main branch. Note that if you had to -resolve any conflicts while working on your change (for instance, you did a -`pull upstream main` which led to conflict), then you will have to resolve these -conflicts again. - -After you have successfully rebased your branch, you should push your changes. -And because you changed the commit history, you may have to use `git push --force`. - -## Contributor License Agreement - -Contributions to this project must be accompanied by a Contributor License -Agreement. You (or your employer) retain the copyright to your contribution; -this simply gives us permission to use and redistribute your contributions as -part of the project. Head over to to see -your current agreements on file or to sign a new one. - -You generally only need to submit a CLA once, so if you've already submitted one -(even if it was for a different project), you probably don't need to do it -again. diff --git a/docs/advanced_topics/convert_pytorch_to_flax.rst b/docs/advanced_topics/convert_pytorch_to_flax.rst deleted file mode 100644 index d7b2f04599..0000000000 --- a/docs/advanced_topics/convert_pytorch_to_flax.rst +++ /dev/null @@ -1,276 +0,0 @@ -Convert PyTorch Models to Flax -============================== - -.. testsetup:: - - import numpy as np - import jax - from jax import random, numpy as jnp - import flax - - from flax import linen as nn - - import torch - -We will show how to convert PyTorch models to Flax. We will cover convolutions, fc layers, batch norm, and average pooling. - - -FC Layers --------------------------------- - -Let's start with fc layers. The only thing to be aware of here is that the PyTorch kernel has shape [outC, inC] -and the Flax kernel has shape [inC, outC]. Transposing the kernel will do the trick. - -.. testcode:: - - t_fc = torch.nn.Linear(in_features=3, out_features=4) - - kernel = t_fc.weight.detach().cpu().numpy() - bias = t_fc.bias.detach().cpu().numpy() - - # [outC, inC] -> [inC, outC] - kernel = jnp.transpose(kernel, (1, 0)) - - key = random.PRNGKey(0) - x = random.normal(key, (1, 3)) - - variables = {'params': {'kernel': kernel, 'bias': bias}} - j_fc = nn.Dense(features=4) - j_out = j_fc.apply(variables, x) - - t_x = torch.from_numpy(np.array(x)) - t_out = t_fc(t_x) - t_out = t_out.detach().cpu().numpy() - - np.testing.assert_almost_equal(j_out, t_out) - - -Convolutions --------------------------------- - -Let's now look at 2D convolutions. PyTorch uses the NCHW format and Flax uses NHWC. -Consequently, the kernels will have different shapes. The kernel in PyTorch has shape [outC, inC, kH, kW] -and the Flax kernel has shape [kH, kW, inC, outC]. Transposing the kernel will do the trick. - -.. testcode:: - - t_conv = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding='valid') - - kernel = t_conv.weight.detach().cpu().numpy() - bias = t_conv.bias.detach().cpu().numpy() - - # [outC, inC, kH, kW] -> [kH, kW, inC, outC] - kernel = jnp.transpose(kernel, (2, 3, 1, 0)) - - key = random.PRNGKey(0) - x = random.normal(key, (1, 6, 6, 3)) - - variables = {'params': {'kernel': kernel, 'bias': bias}} - j_conv = nn.Conv(features=4, kernel_size=(2, 2), padding='valid') - j_out = j_conv.apply(variables, x) - - # [N, H, W, C] -> [N, C, H, W] - t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2))) - t_out = t_conv(t_x) - # [N, C, H, W] -> [N, H, W, C] - t_out = np.transpose(t_out.detach().cpu().numpy(), (0, 2, 3, 1)) - - np.testing.assert_almost_equal(j_out, t_out, decimal=6) - - - -Convolutions and FC Layers --------------------------------- - -We have to be careful, when we have a model that uses convolutions followed by fc layers (ResNet, VGG, etc). -In PyTorch, the activations will have shape [N, C, H, W] after the convolutions and are then -reshaped to [N, C * H * W] before being fed to the fc layers. -When we port our weights from PyToch to Flax, the activations after the convolutions will be of shape [N, H, W, C] in Flax. -Before we reshape the activations for the fc layers, we have to transpose them to [N, C, H, W]. - -Consider this PyTorch model: - -.. testcode:: - - class TModel(torch.nn.Module): - - def __init__(self): - super(TModel, self).__init__() - self.conv = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding='valid') - self.fc = torch.nn.Linear(in_features=100, out_features=2) - - def forward(self, x): - x = self.conv(x) - x = x.reshape(x.shape[0], -1) - x = self.fc(x) - return x - - - t_model = TModel() - - - -Now, if you want to use the weights from this model in Flax, the corresponding Flax model has to look like this: - - -.. testcode:: - - class JModel(nn.Module): - - @nn.compact - def __call__(self, x): - x = nn.Conv(features=4, kernel_size=(2, 2), padding='valid', name='conv')(x) - # [N, H, W, C] -> [N, C, H, W] - x = jnp.transpose(x, (0, 3, 1, 2)) - x = jnp.reshape(x, (x.shape[0], -1)) - x = nn.Dense(features=2, name='fc')(x) - return x - - - j_model = JModel() - - - -The model looks very similar to the PyTorch model, except that we included a transpose operation before -reshaping our activations for the fc layer. -We can omit the transpose operation if we apply pooling before reshaping such that the spatial dimensions are 1x1. - -Other than the transpose operation before reshaping, we can convert the weights the same way as we did before: - - -.. testcode:: - - conv_kernel = t_model.state_dict()['conv.weight'].detach().cpu().numpy() - conv_bias = t_model.state_dict()['conv.bias'].detach().cpu().numpy() - fc_kernel = t_model.state_dict()['fc.weight'].detach().cpu().numpy() - fc_bias = t_model.state_dict()['fc.bias'].detach().cpu().numpy() - - # [outC, inC, kH, kW] -> [kH, kW, inC, outC] - conv_kernel = jnp.transpose(conv_kernel, (2, 3, 1, 0)) - - # [outC, inC] -> [inC, outC] - fc_kernel = jnp.transpose(fc_kernel, (1, 0)) - - variables = {'params': {'conv': {'kernel': conv_kernel, 'bias': conv_bias}, - 'fc': {'kernel': fc_kernel, 'bias': fc_bias}}} - - key = random.PRNGKey(0) - x = random.normal(key, (1, 6, 6, 3)) - - j_out = j_model.apply(variables, x) - - # [N, H, W, C] -> [N, C, H, W] - t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2))) - t_out = t_model(t_x) - t_out = t_out.detach().cpu().numpy() - - np.testing.assert_almost_equal(j_out, t_out, decimal=6) - - - -Batch Norm --------------------------------- - -``torch.nn.BatchNorm2d`` uses ``0.1`` as the default value for the ``momentum`` parameter while -|nn.BatchNorm|_ uses ``0.9``. However, this corresponds to the same computation, because PyTorch multiplies -the estimated statistic with ``(1 − momentum)`` and the new observed value with ``momentum``, -while Flax multiplies the estimated statistic with ``momentum`` and the new observed value with ``(1 − momentum)``. - -.. |nn.BatchNorm| replace:: ``nn.BatchNorm`` -.. _nn.BatchNorm: https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.BatchNorm.html - -.. testcode:: - - t_bn = torch.nn.BatchNorm2d(num_features=3, momentum=0.1) - t_bn.eval() - - scale = t_bn.weight.detach().cpu().numpy() - bias = t_bn.bias.detach().cpu().numpy() - mean = t_bn.running_mean.detach().cpu().numpy() - var = t_bn.running_var.detach().cpu().numpy() - - variables = {'params': {'scale': scale, 'bias': bias}, - 'batch_stats': {'mean': mean, 'var': var}} - - key = random.PRNGKey(0) - x = random.normal(key, (1, 6, 6, 3)) - - j_bn = nn.BatchNorm(momentum=0.9, use_running_average=True) - - j_out = j_bn.apply(variables, x) - - # [N, H, W, C] -> [N, C, H, W] - t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2))) - t_out = t_bn(t_x) - # [N, C, H, W] -> [N, H, W, C] - t_out = np.transpose(t_out.detach().cpu().numpy(), (0, 2, 3, 1)) - - np.testing.assert_almost_equal(j_out, t_out) - - - -Average Pooling --------------------------------- - -``torch.nn.AvgPool2d`` and |nn.avg_pool()|_ are compatible when using default parameters. -However, ``torch.nn.AvgPool2d`` has a parameter ``count_include_pad``. When ``count_include_pad=False``, -the zero-padding will not be considered for the average calculation. There does not exist a similar -parameter for |nn.avg_pool()|_. However, we can easily implement a wrapper around the pooling -operation. ``nn.pool()`` is the core function behind |nn.avg_pool()|_ and |nn.max_pool()|_. - -.. |nn.avg_pool()| replace:: ``nn.avg_pool()`` -.. _nn.avg_pool(): https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.avg_pool.html - -.. |nn.max_pool()| replace:: ``nn.max_pool()`` -.. _nn.max_pool(): https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.max_pool.html - - -.. testcode:: - - def avg_pool(inputs, window_shape, strides=None, padding='VALID'): - """ - Pools the input by taking the average over a window. - In comparison to nn.avg_pool(), this pooling operation does not - consider the padded zero's for the average computation. - """ - assert len(window_shape) == 2 - - y = nn.pool(inputs, 0., jax.lax.add, window_shape, strides, padding) - counts = nn.pool(jnp.ones_like(inputs), 0., jax.lax.add, window_shape, strides, padding) - y = y / counts - return y - - - key = random.PRNGKey(0) - x = random.normal(key, (1, 6, 6, 3)) - - j_out = avg_pool(x, window_shape=(2, 2), strides=(1, 1), padding=((1, 1), (1, 1))) - t_pool = torch.nn.AvgPool2d(kernel_size=2, stride=1, padding=1, count_include_pad=False) - - # [N, H, W, C] -> [N, C, H, W] - t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2))) - t_out = t_pool(t_x) - # [N, C, H, W] -> [N, H, W, C] - t_out = np.transpose(t_out.detach().cpu().numpy(), (0, 2, 3, 1)) - - np.testing.assert_almost_equal(j_out, t_out) - - - -Transposed Convolutions --------------------------------- - -``torch.nn.ConvTranspose2d`` and |nn.ConvTranspose|_ are not compatible. -|nn.ConvTranspose|_ is a wrapper around |jax.lax.conv_transpose|_ which computes a fractionally strided convolution, -while ``torch.nn.ConvTranspose2d`` computes a gradient based transposed convolution. Currently, there is no -implementation of a gradient based transposed convolution is ``Jax``. However, there is a pending `pull request`_ -that contains an implementation. - -.. _`pull request`: https://github.com/google/jax/pull/5772 - -.. |nn.ConvTranspose| replace:: ``nn.ConvTranspose`` -.. _nn.ConvTranspose: https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.ConvTranspose.html - -.. |jax.lax.conv_transpose| replace:: ``jax.lax.conv_transpose`` -.. _jax.lax.conv_transpose: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv_transpose.html - diff --git a/docs/advanced_topics/index.rst b/docs/advanced_topics/index.rst deleted file mode 100644 index 212aeb3d19..0000000000 --- a/docs/advanced_topics/index.rst +++ /dev/null @@ -1,22 +0,0 @@ -Advanced Topics -=============== - -.. toctree:: - :maxdepth: 1 - - arguments - module_lifecycle - lift - convert_pytorch_to_flax - optax_update_guide - linen_upgrade_guide - linen_design_principles - - -.. toctree:: - :maxdepth: 1 - :caption: Contributing - - contributing - philosophy - FLIPs diff --git a/docs/advanced_topics/lift.md b/docs/advanced_topics/lift.md deleted file mode 100644 index 47f92f7398..0000000000 --- a/docs/advanced_topics/lift.md +++ /dev/null @@ -1,365 +0,0 @@ -# Lifted Transformations - -⚠️ Advanced topic ⚠️ - -This design note explains the underlying implementation of `flax.linen.transform`, which enables JAX transformations inside `Module`s. - - -## Introduction - -JAX uses a functional API meaning that it only guarantees correct behavior when using functions without side effects ([JAX docs](https://jax.readthedocs.io/en/latest/jax-101/01-jax-basics.html#differences-from-numpy)). -Typically, these side effects are the result of mutating an object that lives outside the function. - -The functional paradigm has some advantages like the ability to explicitly reason about state and stochasticity. -The function output only changes when an input argument changes. -Therefore, a function is guaranteed to behave deterministically. - -But pure functions offer another big advantage to JAX: specifically, they enable functional transformations. -For example `jax.vmap(f)` will vectorize a function `f`. -Because `f` cannot have side effects the vectorized/parallel version of `f` is well-defined. To see why we need this restriction, consider what happens if `f` would increment a counter or draw a random number. -Would `f` draw the same or a different random number for each item in the vector? -Would each item in the batch have its own counter or is the counter shared among the items? -And in what order is the counter incremented if `f` is computed in parallel? -The answer to all these questions is "it depends". -The behavior is ambiguous and the functional constraint elegantly avoids this problem. - -Flax introduces a safe way to have limited randomness and stateful variables in a JAX-compatible form. -The reason why the state in Flax is not problematic is because it is local: inside a Flax `Module` there are variables and PRNG sequences, -but on the outside there are only JAX Arrays and PRNG keys. - -For most use cases, Flax is used to define models in a stateful way. -Because a `Module` behaves like a pure function externally, we can fully utilize JAX with all of its transformations. -There are, however, cases when we want to have the best of both worlds by using transformations and `Module` together. -This design note explains how we extend JAX's functional transformation to work on `Module`s that have internal state and randomness. - - -## Functionalization - -Before we jump into the details let's consider a simple example where we would like to use `vmap` inside a `Module`. - -First, we define a simple MLP without any transformations: - -```python -import jax -from jax import random, numpy as jnp -from flax import linen as nn - -class MLP(nn.Module): - @nn.compact - def __call__(self, xs): - h = nn.Dense(4, name='hidden')(xs) - h = nn.relu(h) - return nn.Dense(1, name='out')(h) -``` - -Now what if we want to have separate MLP parameters for each item in `xs`? -If this were "vanilla JAX" we could imagine writing something like `jax.vmap(apply_mlp)(mlp_params, xs)`. -But doing something like this in Linen will actually fail: - -```python -class NaiveVmapMLP(nn.Module): - @nn.compact - def __call__(self, xs): - mlp = MLP() - return jax.vmap(lambda mlp, x: mlp(x))(mlp, xs) # fails -``` - -JAX will raise an error when `vmap` is used on `mlp` because it's not a JAX array or a simple container of arrays. -We can not really blame JAX for refusing to perform this under-specified job. -After all, it's not even clear what should happen here. -The parameters inside the MLP are not even initialized yet and we will need a separate PRNG key for each group of parameters. -`jax.vmap` can only broadcast or map over an axis but it cannot automatically split an PRNG key. -Therefore, we have to call `jax.random.split` manually. - -We can fix this problem by first turning `MLP` into a pure init and apply function. -Afterwards, we use the `param` method to store the parameters: - -```python -class ManualVmapMLP(nn.Module): - @nn.compact - def __call__(self, xs): - mlp = MLP(parent=None) - init_fn = lambda rng, xs: jax.vmap(mlp.init, in_axes=0)(random.split(rng, xs.shape[0]), xs)['params'] - apply_fn = jax.vmap(mlp.apply, in_axes=0) - mlp_params = self.param('mlp', init_fn, xs) - return apply_fn({'params': mlp_params}, xs) - -xs = jnp.ones((3, 4)) -variables = ManualVmapMLP().init(random.PRNGKey(0), xs) -print(jax.tree_util.tree_map(jnp.shape, variables['params'])) -"""==> -{ - mlp: { - hidden: { - bias: (3, 4), - kernel: (3, 4, 4), - }, - out: { - bias: (3, 1), - kernel: (3, 4, 1), - }, - }, -} -""" -``` - -Here, `MLP(parent=None)` creates a detached instance of `MLP`. -This avoids reserving a name for the submodule inside the current module. -Although not strictly necessary, this also ensures we cannot accidentally use the MLP instance in a stateful way and we are forced to use it through either `.init` or `.apply`. - -This example is still relatively concise but it already takes a few extra "bookkeeping" statements to make it work. -However, this implementation has a number of limitations: -1. During initialization, we call the submodule twice through `init_fn` and `apply_fn`. If the submodule used the same trick to do - functional transformation we will end up executing a lot of code as the number of module calls grows like 2^d where d is the number of - nested function transformations. -2. The implementation assumes the submodule only requires the parameter RNG sequence. -3. The implementation assumes we only create variables in the "params" collection during `init`. However, it does not support other variable collections and creating/updating variables in `apply`. - -Point 3 in particular makes manual functionalization cumbersome. -Feel free to try and extend the above example with a `nn.BatchNorm` layer in the `MLP` module. -This will require dealing with some additional complexity like storing the updated batch stats and making sure the batch stats are not mutable inside `vmap` when it should be immutable (e.g.: eval mode). - - -We call the process of transforming a stateful Module into a pure function "functionalization". -By temporarily turning a stateful `Module` into a function we make it compatible with JAX's functional transformations. - -## Lifting - -Flax provides an alternative for manual functionalization which we call lifted transformation. -Lifted transformations are defined in `flax.core.lift`. -All the lifted JAX transformations are defined with a single generic lifting API called `pack`. - -A number of decisions had to be made in order to define `pack`. The implementation -of `pack` controls how variables and rngs are lifted and how fine-grained the user control is. -It must also decide whether lifting decisions are made at variable or transformation definition. - - -### Lifting granularity - - -With the Linen API, users can define arbitrary variable collections and PRNG sequences. -Each variable in a collection is lifted in the same way. - -Collections are typically given a semantically meaningful name like "params" or "batch_stats" rather than a general purpose name like "state". -Because collections carry semantic meaning we can decide at the transformation level how each collection should be lifted. -For example, we want to share all parameter variables when we add a batch dimension to a model. - -At the same time we can write generic code that uses transformations without knowing exactly what kind of variables the submodules will create. -Collections thus strike a balance between fine-grained control and generality. -We also avoid brittle string matching code that loops over all variables and tries to split up collections in an ad-hoc way based on -naming conventions like: target all variables with the name prefix "kernel". -If more fine-grained control is necessary a user can simply split up a set of variables over multiple collections that should be handled differently. - - -### Transformation vs variable control - - -Lifting behavior could be defined either at the transformation level or during variable definition. -We use transformation level definitions of lifting behavior. -The reason for this choice is that there are many different transformations with various behaviors. -For example: `vmap` has broadcasted and vectorized arguments, while `scan` has scan, carry, and broadcast arguments. -A variable would have to define its behavior for all these transformations otherwise a `Module` would not be compatible with -these transformations. Alternatively, we would have to make default decisions for how transformations are handled. -However, this could lead to silent bugs because the behavior might not actually be valid given the users intent. - -The lift package also provides a general purpose `transform`, which allows an arbitrary function to transform a variable collection. -For example, this can be used to tie the weights in a tied auto-encoder by transposing the weights. -It is unclear whether a similar general purpose transform could be defined if lifting decisions were made at variable definition. - - -### Linen - -The lifting module does not know about the Linen `Module` API. -Instead it operates directly on instances of `flax.core.Scope`. -A `Scope` instance contains the variables and PRNG sequences of a `Module`. -Each `Module` instance has a `Scope` instance in the `.scope` field if it has a parent or it was created using `init` or `apply`. -Typically, the top-level `Module` instance — on which you call `init` or `apply` — is the only `Module` instance that does not have a `Scope` bound to it. - -When a `Module` is transformed, we use the `flax.core.lift` APIs to lift the scope and use `Module.clone()` to create a new `Module` instance with the lifted scope bound to it. - -`flax.linen.transforms` exposes wrappers for the transformations in `flax.core.lift`. The core lifting APIs operate on functions while -the Linen wrappers can transform either a `Module` class or a `Module` method. - -Thus, lifting is implemented independently from the Linen API. This separation of concern simplifies the implementation, while potentially allowing alternative `Module` abstractions to build upon a common core for lifting and state management. - - -### Implementation - -The `pack(fn, in_vars, out_vars, rngs)` API goes through the following stages: - - -1. *Scope de-duplication* - - This stage is only relevant if multiple Scopes are lifted together. - In this case we must first find the set of root scopes. - A scope is a root if none of its ancestors are in the set of scopes that need to be lifted. - - By only lifting roots we avoid lifting the same variables twice. - - For non-root scopes we store a reference to its ancestor scope and a path such that we can later reconstruct it (stage 4). - -2. *Filter stage* - - Variables and PRNG sequences are split up into groups. This way `fn` can lift each group into the transformation separately. - A group is defined by a filter specified as: - - a list of collections/prng names - - `True` (match everything) - - `False` (match nothing) - - `DenyList(filter)` (match everything but the specified collections (e.g.: `DenyList(['params'])` matches everything except the 'params' collection.)). - - A collection or PRNG sequence can only be put into a single group. If a collection matches multiple filters, it will be put into the first group with a matching filter. - If a collection or PRNG sequence does not match any filter it will not be lifted. - This means that it cannot be used inside the transformation and attempting to do this will cause an error to be raised. - For example, `in_vars = (["params"], True)` will cause the "params" collection to be put in the first group and all other collection to be put in the second group. - - For each PRNG sequence that is matched we seed a new PRNG sequence by calling `make_rng`. - This avoids the need to update the PRNG state after the lifted transformation is complete. - -3. *Transform-specific lifting* - - `fn` is called with the variable and PRNG groups. - JAX transforms have varying signatures and lifting options. Arguably the cleanest example is `vmap`. - In the case of vmap the function arguments, PRNGs and variable collections are passed into a `jax.vmap` wrapped function. - -4. *Scope reconstruction* - - Now that the variables and PRNGs are lifted inside the transformation, we want to recreate the lifted scopes. Pack calls - `fn` with a `scope_fn` that takes the lifted variables and PRNGs and returns the reconstructed scopes with the lifted variables and rng sequences. - -5. *Repack stage* - - After we have used the lifted scopes we have to retrieve the updated variables (PRNG sequences can simply be discarded). - pack passes the `repack_fn` to support this. - This stage is similar to stage 2 except that we only lift variables and immutable variables are ignored. - Immutable variables cannot be updated. Therefore, they should not be returned from the transformed function. - -6. *Commit stage* - - `pack` expects `fn` to return a pair where the first item will simply be returned from pack and the second item should be the repacked variables. - The updated variables are stored in the original/un-lifted scopes such that the mutations that happen inside the transformation survive after the transformation completes. - - -### Using pack example - - -A minimal example of using `pack` to transpose each matrix in a variable collection: - -```python -from flax.core import lift -from flax.core import Scope, init, apply, nn as core_nn - -def lift_transpose(fn, target='params', variables=True, rngs=True): - # by default we transpose 'params' and simply pass through all other variables. - def wrapper(scope_fn, repack_fn, variable_groups, rng_groups, *args): - # normally we would first call into a JAX transformed function here... - target, rest = variable_groups - def trans(x): - if x.ndim == 2: - return x.T - return x - target = jax.tree_util.tree_map(trans, target) - variable_groups = (target, rest) - scope = scope_fn(variable_groups, rng_groups) - y = fn(scope, *args) - out_variables = repack_fn(scope) - return y, out_variables - return lift.pack( - wrapper, - in_variable_filters=(target, variables), - out_variable_filters=(variables,), - rng_filters=(rngs,)) - -x = jnp.ones((3, 2)) -y, params = init(lift_transpose(core_nn.dense))(random.PRNGKey(0), x, 4) -``` - -NOTE that most users should not need to interact with `pack` directly. -Please open a GitHub issue when you find a use case that is not supported yet by the existing lifted transformations. - -### Supported transformations - -| Jax Transform | Supported in Linen? | Comments | -|-|-|-| -| vmap | ✅ | | -| scan | ✅ | Carry variables cannot be initialized inside the scan body. | -| remat | ✅ | | -| jit | ✅ | Current implementation might cause unnecessary recompilation. | -| jvp | ✅ | | -| vjp | ✅ | | -| custom_vjp | ✅ | | -| custom_jvp | ❌ | | -| while_loop | ✅ | Carry variables cannot be initialized inside the while_loop body. | -| cond | ✅ | Variable initialization / mutation must structurally match across branches. | -| switch | ✅ | Variable initialization / mutation must structurally match across branches. | -| pmap | ❌ | | -| xmap | ❌ | | - -References: -- [Linen transforms documentation](https://flax.readthedocs.io/en/latest/flax.linen.html#module-flax.linen.transforms). -- [Linen transforms source code](https://github.com/google/flax/blob/main/flax/linen/transforms.py) -- [Core lifting source code](https://github.com/google/flax/blob/main/flax/core/lift.py) - -### Linen examples - -Going back to our original example, we can now use `nn.vmap` to simplify our implementation: - -```python -class LinenVmapMLP(nn.Module): - @nn.compact - def __call__(self, xs): - VmapMLP = nn.vmap(MLP, variable_axes={'params': 0}, split_rngs={'params': True}, in_axes=0) - return VmapMLP(name='mlp')(xs) - -variables = LinenVmapMLP().init(random.PRNGKey(0), xs) -print(jax.tree_util.tree_map(jnp.shape, variables['params'])) -"""==> -{ - mlp: { - Dense_0: { - bias: (3, 4), - kernel: (3, 2, 4), - }, - Dense_1: { - bias: (3, 1), - kernel: (3, 4, 1), - }, - }, -} -""" -``` - -Here we use `variable_axes={'params': 0}` to indicate that parameters are vectorized rather than shared and `split_rngs={'params': True}` means each set of parameters is initialized independently. - -We can also extend the example with some inner state by adding a `BatchNorm` layer: - -```python -class StatefulMLP(nn.Module): - @nn.compact - def __call__(self, x, *, train): - h = nn.Dense(4, name='hidden')(x) - h = nn.BatchNorm(axis_name='batch')(h, use_running_average=not train) - h = nn.relu(h) - return nn.Dense(1, name='out')(h) - -class LinenStatefulVmapMLP(nn.Module): - @nn.compact - def __call__(self, xs, *, train): - VmapMLP = nn.vmap(StatefulMLP, variable_axes={'params': 0, 'batch_stats': 0}, split_rngs={'params': True}, in_axes=0) - return VmapMLP(name='mlp')(xs, train=train) -variables = LinenStatefulVmapMLP().init(random.PRNGKey(0), xs) -``` - -All we had to add to `nn.vmap` is `'batch_stats': 0`, indicating that the batch stats are vectorized rather than shared along the first axis. - - -## Alternatives - -Other numerical computation frameworks consider variables a first-class citizen. -An alternative to functionalization would be to use a variable system either integrated or on top of JAX. -An advantage of this is that per-variable lifting becomes easier. -If variables are part of the JAX IR (JAXPR), we could inspect which variables have to be lifted in a certain computation. -Optionally, they could be annotated with a collection tag to decide on various lifting options. - -The downside of this approach is that a variable system is more complicated. -Variables are related references and break a core assumption of Functional Programming (see [referential transparency](https://en.wikipedia.org/wiki/Referential_transparency)) -Other APIs that currently have a functional interface would probably require integration as well (e.g.: checkpointing and optimization APIs). diff --git a/docs/advanced_topics/linen_design_principles.rst b/docs/advanced_topics/linen_design_principles.rst deleted file mode 100644 index 48ace0048a..0000000000 --- a/docs/advanced_topics/linen_design_principles.rst +++ /dev/null @@ -1,101 +0,0 @@ -Linen Design Principles -======================= - -Flax is a neural network library built on JAX that has been adopted by a -growing set of users, most notably in the JAX submissions for the MLPerf -0.7 benchmark. Our experience over the last year (and many conversations -with users and JAX core devs) has guided a redesign of the API called -Linen in response to the following basic design questions. - -How does a neural network library benefit from being built on JAX and leverage JAX’s unique strengths? ------------------------------------------------------------------------------------------------------- - -The world already has TensorFlow and PyTorch, and there’s little need to -build a clone of either. We believe that the composable -function-transformation approach that JAX takes opens up new frontiers -for making neural net code more maintainable, more scalable and more -performant than existing libraries. While we strive to offer an API -familiar to those experienced with Keras/Sonnet/PyTorch, Linen is -fundamentally a functional system for defining neural nets in JAX. Just -a few examples of what we believe a JAX-targeted library can enable: - -- write models as “single-example” code and introduce batching - automatically with vmap -- automatically handle ragged batches in NLP and other masking issues -- create efficient compile-time and runtime models by utilizing - rematerialized scan for massive conv-nets. -- remove memory headaches by enabling easy rematerialization, - reversibility, and model-parallel data sharding. - -How does one interoperate with JAX transformations? ---------------------------------------------------- - -Arguably the entire point of a neural net library is to offer an -implicit variable management API to save the user from having to -manually thread thousands of variables through a complex tree of -functions. However, JAX operates on pure functions. To handle both -current and future JAX transforms (configured and composed in any way), -Linen Modules are directly “functionalized”, that is, automatically cast -in-place as explicit functions of the form: - -.. math:: f(v_{in}, x) \rightarrow v_{out}, y - -Where :math:`v_{in}` is the variable collections and PRNG state used by -the model, :math:`v_{out}` the mutated output variable collections, -:math:`x` the input data and :math:`y` the output data. Applying JAX -transformations then simply reduces to specifying any argument-specific -transform options to the various variable collections and PRNG state. -This unleashes the flexibility and strength of JAX transformations – for -example, one can achieve either device-parallel training or per-device -ensembling by using ``pmap`` in different ways, without any explicit -library support. Moreover, **within Modules**, we expose lightweight -wrappers around the complex JAX transforms such as ``vmap`` and ``scan`` -that annotate how each variable collection is to be transformed by JAX. -Importantly, we handle the nontrivial cases of creating new variables -and transformed variables under mapping and loop transforms correctly -for initialization and application. - -How are parameters represented, and how do we handle general “differentiable algorithms” that update stateful variables? ------------------------------------------------------------------------------------------------------------------------- - -We follow the JAX functional conventions of storing data in “pytrees”: -JAX arrays contained in nested tuples, lists, dictionaries. Because -researchers inevitably manually interact with this data, we use nested -dictionaries with meaningful default keys and offer several utilities -(traversals, etc.) for handling them directly. Linen uses an accelerated -version of a Python frozen dictionary that caches its JAX-flattened form -to speed up jitted function call overheads. - -Flax generalizes the operation of a neural net by allowing models to -accept collections of several different “kinds”: parameters, batch-norm -stats, autoregressive caches, debug information, fine-grained -hyperparameters, etc. Each collection is stored in a nested dictionary -of the same structure as the model. Importantly, we do *not* conflate -these various kinds under the single vague rubric of “state”, but keep -different logical types of variables separate that can be treated -differently under JAX transformations and under mutations (e.g. training -vs prediction). Similarly, we allow for multiple separate named PRNG -chains inside Modules for separate treatment of randomness for different -applications such as initialization, dropout, sampling, etc. - -At every stage the data associated with a neural net is not kept in a -custom object hierarchy, but left in an explicit, Python and JAX native -form that is easy to introspect and modify. Users have utilized this to -map TF and PyTorch checkpoints to Flax, to implement submodel-specific -loss terms, and to perform fast model surgery, etc. For saving this -data, most Flax examples store these nested dictionaries via the -efficient “msgpack” binary format – but as variables are simply Python -dicts, you can use any (non-JAX-aware) serialization library directly. - -How does one interoperate with purely functional JAX code? ----------------------------------------------------------- - -To be broadly useful to the JAX ecosystem, users shouldn’t need to -heavily refactor their code in order to add “trainability” for a given -numerical task. “The library should not get in the way.” Utilizing -purely functional code from within Linen is trivial: Module -implementations are just JAX code with named variables. Using Linen -modules inside otherwise purely functional code can be as simple as -using a single top-level module transformation to allow initialization -and pure application of any JAX program that might contain various -trainable sections. diff --git a/docs/advanced_topics/linen_upgrade_guide.rst b/docs/advanced_topics/linen_upgrade_guide.rst deleted file mode 100644 index 364af9ace2..0000000000 --- a/docs/advanced_topics/linen_upgrade_guide.rst +++ /dev/null @@ -1,518 +0,0 @@ -Upgrading my Codebase to Linen -============================== - -As of Flax v0.4.0, ``flax.nn`` no longer exists, and is replaced with the new -Linen API at ``flax.linen``. If your codebase is still using the old API, you -can use this upgrade guide to upgrade it to Linen. - -.. testsetup:: - - from flax.training import train_state - from jax import random - import optax - import jax - from flax.linen import initializers - - from jax import lax - import jax.numpy as jnp - import numpy as np - from typing import Any, Callable, Sequence, Tuple - - PRNGKey = Any - Shape = Tuple[int, ...] - Dtype = Any - Array = Any - - default_kernel_init = initializers.lecun_normal() - -Defining Simple Modules --------------------------------- - -.. codediff:: - :title_left: Old Flax - :title_right: Linen - :sync: - - from flax import nn - - class Dense(base.Module): - def apply(self, - inputs, - features, - use_bias=True, - kernel_init=default_kernel_init, - bias_init=initializers.zeros_init()): - - kernel = self.param('kernel', - (inputs.shape[-1], features), kernel_init) - y = jnp.dot(inputs, kernel) - if use_bias: - bias = self.param( - 'bias', (features,), bias_init) - y = y + bias - return y - - return new_state, metrics - --- - from flax import linen as nn # [1] #! - - class Dense(nn.Module): - features: int # [2] #! - use_bias: bool = True - kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init - bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros_init() - - @nn.compact - def __call__(self, inputs): # [3] #! - kernel = self.param('kernel', - self.kernel_init, (inputs.shape[-1], self.features)) # [4] #! - y = jnp.dot(inputs, kernel) - if self.use_bias: - bias = self.param( - 'bias', self.bias_init, (self.features,)) # [5] #! - y = y + bias - return y - -1. Replace from ``flax import nn`` with from ``flax import linen as nn``. - -2. Move arguments to ``apply`` into dataclass attributes. Add type annotations - (or use type ``Any`` to bypass). - -3. Rename method ``apply`` to ``__call__`` and (optionally) wrap with - |@compact|_. Methods wrapped in |@compact|_ can define submodules directly - within the method (like in old Flax). You can only wrap a single method with - |@compact|_. Alternatively, you can define a ``setup`` method. For more - details, please see our other HOWTO `Should I use setup or nn.compact?`_. - -4. Access dataclass attributes values by ``self.`` inside methods, e.g. - ``self.features``. - -5. Move shape to the end of the arguments to |self.param|_ (initializer functions - can take arbitrary argument lists). - - -Using Modules inside other Modules --------------------------------- - -.. codediff:: - :title_left: Old Flax - :title_right: Linen - :sync: - - class Encoder(nn.Module): - - def apply(self, x): - x = nn.Dense(x, 500) - x = nn.relu(x) - z = nn.Dense(x, 500, name="latents") - return z - --- - class Encoder(nn.Module): - @nn.compact - def __call__(self, x): - x = nn.Dense(500)(x) # [1] #! - x = nn.relu(x) - z = nn.Dense(500, name='latents')(x) # [2] #! - return z - -1. Module constructors no longer return the outputs. Instead, they work like - normal constructors and return module instances. These instances can be - shared like in normal Python (instead of using ``.shared()`` in old Flax). - Since most modules implement ``__call__``, you can retain the conciseness of - old Flax. - -2. Names can be optionally passed to all module constructors. - -Sharing submodules and defining multiple methods --------------------------------- - -.. codediff:: - :title_left: Old Flax - :title_right: Linen - :sync: - - class AutoEncoder(nn.Module): - def _create_submodules(self): - return Decoder.shared(name="encoder") - - def apply(self, x, z_rng, latents=20): - decoder = self._create_decoder() - z = Encoder(x, latents, name="encoder") - return decoder(z) - - @nn.module_method - def generate(self, z, **unused_kwargs): - decoder = self._create_decoder() - return nn.sigmoid(decoder(z)) - --- - class AutoEncoder(nn.Module): - latents: int = 20 - - def setup(self): # [1] #! - self.encoder = Encoder(self.latents) # [2] #! - self.decoder = Decoder() - - def __call__(self, x): # [3] #! - z = self.encoder(x) - return self.decoder(z) - - def generate(self, z): # [4] #! - return nn.sigmoid(self.decoder(z)) - - -1. Use |setup|_ instead of ``__init__``, which is already defined in - the dataclasses library. Flax calls setup right after modules are ready to be - used. (You can do this for all modules if you like instead of using - |@compact|, but we like how |@compact| co-locates where modules are defined - and used, especially if you have loops or conditionals). - -2. Like regular Python, share submodules by assigning to self during - initialization. Similar to PyTorch, ``self.encoder`` automatically has the - name ``"encoder"``. - -3. We don't use |@compact|_ here because we're not defining any inline - submodules (all submodules are defined in setup). - -4. Define additional methods just like in regular Python. - -``Module.partial`` inside other modules --------------------------------- - -.. codediff:: - :title_left: Old Flax - :title_right: Linen - :sync: - - # no import #! - - class ResNet(nn.Module): - """ResNetV1.""" - - - def apply(self, x, - stage_sizes, - num_filters=64, - train=True): - conv = nn.Conv.partial(bias=False) - norm = nn.BatchNorm.partial( - use_running_average=not train, - momentum=0.9, epsilon=1e-5) - - x = conv(x, num_filters, (7, 7), (2, 2), - padding=[(3, 3), (3, 3)], - name='conv_init') - x = norm(x, name='bn_init') - - # [...] - return x - --- - from functools import partial #! - - class ResNet(nn.Module): - """ResNetV1.""" - stage_sizes: Sequence[int] - num_filters: int = 64 - train: bool = True - - @nn.compact - def __call__(self, x): - conv = partial(nn.Conv, use_bias=False) #! - norm = partial(nn.BatchNorm, #! - use_running_average=not self.train, #! - momentum=0.9, epsilon=1e-5) #! - - x = conv(self.num_filters, (7, 7), (2, 2), - padding=[(3, 3), (3, 3)], - name='conv_init')(x) - x = norm(name='bn_init')(x) - - # [...] - return x - -Use normal ``functools.partial`` instead of ``Module.partial``. The rest stays -the same. - -Top-level training code patterns --------------------------------- - -.. codediff:: - :title_left: Old Flax - :title_right: Linen - :sync: - - def create_model(key): - _, initial_params = CNN.init_by_shape( - key, [((1, 28, 28, 1), jnp.float32)]) - model = nn.Model(CNN, initial_params) - return model - - def create_optimizer(model, learning_rate): - optimizer_def = optim.Momentum(learning_rate=learning_rate) - optimizer = optimizer_def.create(model) - return optimizer - - def cross_entropy_loss(*, logits, labels): - one_hot_labels = jax.nn.one_hot(labels, num_classes=10) - return -jnp.mean(jnp.sum(one_hot_labels * logits, axis=-1)) - - def loss_fn(model): - logits = model(batch['image']) - one_hot = jax.nn.one_hot(batch['label'], num_classes=10) - loss = -jnp.mean(jnp.sum(one_hot_labels * batch['label'], - axis=-1)) - return loss, logits - --- - def create_train_state(rng, config): # [1] #! - variables = CNN().init(rng, jnp.ones([1, 28, 28, 1])) # [2] #! - params = variables['params'] # [3] #! - tx = optax.sgd(config.learning_rate, config.momentum) # [4] #! - return train_state.TrainState.create( - apply_fn=CNN.apply, params=params, tx=tx) - - - - - def loss_fn(params): - logits = CNN().apply({'params': params}, batch['image']) # [5] #! - one_hot = jax.nn.one_hot(batch['label'], 10) - loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, - labels=one_hot)) - return loss, logits - - -1. We no longer use the ``Model`` abstraction -- instead we pass parameters - around directly, usually encapsulated in a `Train State`_ object, which can - directly be passed to JAX transformations. - -2. To compute initial parameters, construct a module instance and call |init|_ - or |init_with_output|_. We haven't ported over ``init_by_shape`` because this - function did some magic we did not like (it evaluated the function by shape. - but returned real values anyway). Therefore, you should now pass concrete - values to the initializer functions, and you can optimize the initialization - by wrapping it with |jax.jit|_, which is highly recommended to avoid running - a full forward pass. - -3. Linen generalizes parameters into variables. Parameters are one - "collection" of variables. Variables are nested dicts, where the top-level - keys reflect the different variable collections, of which "param" is one of. - See the `Variables documentation`_ for more details. - -4. We recommend using Optax optimizers. See our separate HOWTO called - `Upgrading my Codebase to Optax`_ for more details. - -5. To make predictions with your model, make an instance at the top level (this - is free -- just a wrapper around constructor attributes) and call the - ``apply`` method (which will call ``__call__`` internally). - -Non-trainable variables ("state"): Use within Modules --------------------------------- - -.. codediff:: - :title_left: Old Flax - :title_right: Linen - :sync: - - class BatchNorm(nn.Module): - def apply(self, x, ...): - # [...] - ra_mean = self.state( - 'mean', (x.shape[-1], ), initializers.zeros_init()) - ra_var = self.state( - 'var', (x.shape[-1], ), initializers.ones_init()) - # [...] - --- - class BatchNorm(nn.Module): - def __call__(self, x): - # [...] - ra_mean = self.variable( #! - 'batch_stats', 'mean', initializers.zeros_init(), (x.shape[-1], )) - ra_var = self.variable( - 'batch_stats', 'var', initializers.ones_init(), (x.shape[-1], )) - # [...] - -The first argument is the name of the variable collection ("param" is the only -variable collection that's always available). Some colllections may be treated -as mutable, and others as immutable at top-level training code (see next section -for details). Flax also lets you treat each variable collection differently when -using JAX transformations inside modules. - -Non-trainable variables ("state"): Top-level training code patterns --------------------------------- - -.. codediff:: - :title_left: Old Flax - :title_right: Linen - :sync: - - # initial params and state - def initial_model(key, init_batch): - with nn.stateful() as initial_state: - _, initial_params = ResNet.init(key, init_batch) - model = nn.Model(ResNet, initial_params) - return model, init_state - - - # updates batch statistics during training - def loss_fn(model, model_state): - with nn.stateful(model_state) as new_model_state: - logits = model(batch['image']) - # [...] - - - - - # reads immutable batch statistics during evaluation - def eval_step(model, model_state, batch): - with nn.stateful(model_state, mutable=False): - logits = model(batch['image'], train=False) - return compute_metrics(logits, batch['label']) - --- - # initial variables ({"param": ..., "batch_stats": ...}) - def initial_variables(key, init_batch): - return ResNet().init(key, init_batch) # [1] #! - - - - - - # updates batch statistics during training - def loss_fn(params, batch_stats): - variables = {'params': params, 'batch_stats': batch_stats} # [2] #! - logits, new_variables = ResNet(train=true).apply( - variables, batch['image'], mutable=['batch_stats']) # [3] #! - new_batch_stats = new_variables['batch_stats'] - # [...] - - - # reads immutable batch statistics during evaluation - def eval_step(params, batch_stats, batch): - variables = {'params': params, 'batch_stats': batch_stats} - logits = ResNet(train=False).apply( - variables, batch['image'], mutable=False) # [4] #! - return compute_metrics(logits, batch['label']) - -1. |init|_ returns a variable dict, e.g. ``{"param": ..., "batch_stats": ...}`` - (see `Variable documentation`_). - -2. Combine the different variable collections into a variable dict. - -3. During training, the ``batch_stats`` variable collection changes. Since we - specify that in the mutable argument, the return value from ``module.apply`` - becomes an ordered pair of ``output, new_variables``. - -4. During evaluation, we want to raise an error if we're accidentally applying - Batch Norm in training mode. By passing ``mutable=False`` into - ``module.apply`` we enforce that. Since no variables are mutated, the return - value is once again just the output. - -Loading pre-Linen checkpoints --------------------------------- - -While most Linen modules should be able to use pre-Linen weights without any -modification, there is one catch: In pre-Linen API submodules were numbered -incrementally, independent of the submodule class. With Linen this behavior has -changed to keep separate submodule counts per module class. - -In pre-Linen, params have the following structure: - -``{'Conv_0': { ... }, 'Dense_1': { ... } }`` - -In Linen this is instead: - -``{'Conv_0': { ... }, 'Dense_0': { ... } }`` - -TODO: Add an example here how to load a new ``TrainState`` object. - -Randomness --------------------------------- - -.. codediff:: - :title_left: Old Flax - :title_right: Linen - :sync: - - def dropout(inputs, rate, deterministic=False): - keep_prob = 1. - rate - if deterministic: - return inputs - else: - mask = random.bernoulli( - make_rng(), p=keep_prob, shape=inputs.shape) - return lax.select( - mask, inputs / keep_prob, jnp.zeros_like(inputs)) - - - def loss_fn(model, dropout_rng): - with nn.stochastic(dropout_rng): - logits = model(inputs) - --- - class Dropout(nn.Module): - rate: float - - @nn.compact - def __call__(self, inputs, deterministic=False): - keep_prob = 1. - self.rate - if deterministic: - return inputs - else: - mask = random.bernoulli( - self.make_rng('dropout'), p=keep_prob, shape=inputs.shape) # [1] #! - return lax.select( - mask, inputs / keep_prob, jnp.zeros_like(inputs)) - - - def loss_fn(params, dropout_rng): - logits = Transformer().apply( - {'params': params}, inputs, rngs={'dropout': dropout_rng}) # [2] #! - -1. RNGs in Linen have "kinds" -- in this case "dropout". Different kinds can be - treated different in JAX transformations (for example -- do you want the same - dropout mask for each timestep in a sequence model or a different one?) - -2. Instead of using the ``nn.stochastic`` context manager, you pass in RNGs - explicitly to ``module.apply``. During evaluation you wouldn't pass any RNGs - -- then if you accidentally use dropout in non-deterministic mode, - ``self.make_rng('dropout')`` would raise an error. - - -Lifted Transforms --------------------------------- - -In Linen, rather than using JAX transformation directly, we are using -"lifted transforms", which are JAX transformations applied to Flax Modules. - -For more information, please see the design note on `Lifted Transformations`_. - -TODO: Given an example of ``jax.scan_in_dim`` (pre-Linen) vs. ``nn.scan`` -(Linen). - -.. _`Should I use setup or nn.compact?`: https://flax.readthedocs.io/en/latest/design_notes/setup_or_nncompact.html -.. _`Variables documentation`: https://flax.readthedocs.io/en/latest/flax.linen.html#module-flax.core.variables -.. _`TrainState`: https://flax.readthedocs.io/en/latest/flax.training.html#train-state -.. _`Upgrading my Codebase to Optax`: https://flax.readthedocs.io/en/latest/howtos/optax_update_guide.html -.. _`Lifted Transformations`: https://flax.readthedocs.io/en/latest/design_notes/lift.html - - -.. |@compact| replace:: ``@compact`` -.. _@compact: https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.compact - -.. |init| replace:: ``init`` -.. _init: https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.init - -.. |init_with_output| replace:: ``init_with_output`` -.. _init_with_output: https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.init_with_output - -.. |jax.jit| replace:: ``jax.jit`` -.. _jax.jit: https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit - -.. |self.param| replace:: ``self.param`` -.. _self.param: https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.param - -.. |setup| replace:: ``setup`` -.. _setup: https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.setup - -.. |@flax.struct.dataclass| replace:: ``@flax.struct.dataclass`` -.. _@flax.struct.dataclass: https://flax.readthedocs.io/en/latest/flax.struct.html#flax.struct.dataclass - -.. |checkpoints.convert_pre_linen()| replace:: ``checkpoints.convert_pre_linen()`` -.. _checkpoints.convert_pre_linen(): https://flax.readthedocs.io/en/latest/flax.training.html#flax.training.checkpoints.convert_pre_linen diff --git a/docs/advanced_topics/module_lifecycle.rst b/docs/advanced_topics/module_lifecycle.rst deleted file mode 100644 index 86422f6a91..0000000000 --- a/docs/advanced_topics/module_lifecycle.rst +++ /dev/null @@ -1,381 +0,0 @@ -The Module lifecycle -###################### - -.. testsetup:: - - from typing import Any, Callable, Iterable - import flax - from flax import linen as nn - from jax import random - import jax - - -This design note is intended for users who are already familiar with linen Modules but want to understand more about the design principles behind the abstraction. This note should give you a good understanding of the assumptions and guarantees the Module API is built upon. If you have no practical experience with Modules yet, check out the `MNIST Tutorial `_. - -linen Modules offer a Pythonic abstraction on top of Flax core. The `Module `_ abstraction allows you to create classes that have state, parameters and randomness on top of JAX. This is a practical guide to the design and behavior of the ``Module`` class. By the end, you should feel comfortable to go off the beaten track and use Modules in new ways. - - -Overview -*********** - -Definition -============= - -Let's start with a high-level overview of the Module lifecycle. First, define a simple Module: - - -.. testcode:: - - class MLP(nn.Module): - # 1. Attribute annotations - hidden_size: int - out_size: int - - # 2. The ``setup`` method - def setup(self): - self.hidden = nn.Dense(self.hidden_size) - self.out = nn.Dense(self.out_size) - - # 3. User methods - def __call__(self, x): - a = self.hidden(x) - h = nn.relu(a) - return self.out(h) - - -This Module consists of: - -#. **Attribute annotations**, defined as `dataclass `_ fields. These annotations automatically define a constructor. -#. **The ``setup`` method**, which creates submodules and assigns them to attributes. -#. **User methods**. By convention, most Modules have just one ``__call__`` method, but you can define multiple methods or use different method names. - -Construction/initialization -============================= - -Now we want to construct and use the ``MLP`` Module: - - -.. testcode:: - - mlp = MLP(hidden_size=5, out_size=3) - x = jax.numpy.ones((1, 2)) - variables = mlp.init(random.PRNGKey(0), x) - y = mlp.apply(variables, x) - - -First, we construct an instance of ``MLP`` and pass the construction attributes. Note that construction here is different from what you might expect if you are not used to Functional Programming patterns. The ``MLP`` constructor does not actually create variables or any internal state whatsoever. It's best to think of it as a specification or template of the Module that contains functionality but no data. - -Let's take a closer look at initialization. Surprisingly, there actually is no separate initialization path in Flax. Calling ``init`` is just a special case of ``apply``, which you can also write as: - - -.. testcode:: - - # equivalent to: variables = mlp.init(random.PRNGKey(0), x) - _, variables = mlp.apply({}, x, rngs={"params": random.PRNGKey(0)}, mutable=True) - - -Thus, ``init`` is nothing more than a wrapper around ``apply`` where: - -#. We call a Module without any initial variables (an empty dict). -#. A PRNG generator named ``"params"`` is always passed for randomly initializing parameters (using the parameter initialization function). -#. All variable collections are set to mutable (``mutable=True``). When a collection is mutable, existing variables can be updated and new variables can be created. Thus, inside ``init`` variables can be initialized in any variable collection and they are all added to the returned variable dictionary. - -Lifecycle -============= - - -Now that you have learned about ``init`` being a special case of ``apply``, let's look at ``.apply(...)`` in more detail. In fact, most of the complexity of Modules resides in the ``apply`` method. The "Module lifecycle" consists of constructing and ``apply``-ing a Module. We can summarize the Module lifecycle as follows: - - -#. We construct ``mlp = MLP(hidden_size=5, out_size=3)``, such that ``mlp.hidden_size=5`` and ``mlp.out_size=3``. - -#. Then, call ``mlp.apply``, which: - - #. Makes a clone of ``mlp``, let's call it ``mlp_copy``. - - #. Calls ``mlp_copy.setup()``. - - #. Returns the output of ``mlp_copy.__call__()`` and optionally the variable collections that were specified as mutable using the keyword argument ``mutable=``. - -Notice that the lifecycle includes cloning the Module instance. This is done to ensure that ``apply`` can be treated as a pure function (i.e., if you pass the same arguments in, it will return the same outputs). You will learn about this in more detail later in the :ref:`Top-level Modules` section. - -Variables -========== - -The word “variable” is ubiquitous in programming and math. However, it's important to have a good understanding of what variables are in the context of JAX and Flax. Inside Flax Modules, `variables `_ act like you expect from Python. They are initialized once, read, and perhaps even updated every so often. However, JAX has no concept of variables. Instead, values are stored in arrays similar to NumPy arrays - with one important difference: they are immutable. - -The ``init`` and ``apply`` methods return the variables as a nested dictionary with string keys and JAX arrays at the leaves. At the top level each key corresponds to a variable collection. Inside each collection the nested dict structure corresponds with the ``Module`` hierarchy. The variable dict is immutable and therefore really just a snapshot of state the variables are in. When ``apply`` is called again, the variable dict is passed as an argument. Such that the variables are in the same state as when the previous ``init`` / ``apply`` call finished. - - -.. note:: - Module fields are declared using the `field_name: TypeHint` syntax (same as dataclasses). Without a type hint, an attribute is considered a static property of the class. In case you cannot specify the type you can use ``typing.Any`` as a wildcard type. - - -Compact Modules -****************** - -Linen provides an alternative API for defining modules more compactly. This is especially useful for the common case where the Module consists of only one method that uses parameters and/or sub-modules. Using the compact API the MLP can be rewritten as follows: - - -.. testcode:: - - class CompactMLP(nn.Module): - hidden_size: int - out_size: int - - @nn.compact - def __call__(self, x): - a = nn.Dense(self.hidden_size)(x) - h = nn.relu(a) - return nn.Dense(self.out_size)(h) - - -A compact ``Module`` is similar in spirit to a function. It offers a concise notation and restricts external interaction to the inputs and return values of the function. In this case the concise notation might make it easier for others to understand what the Module does. There is no need to jump back and forth between the ``setup`` and ``__call__`` method to understand what the submodules are doing. Instead, simply reading the ``__call__`` method from top to bottom once should provide a concise overview. This can make a significant difference if you are implementing complex Modules with many hyperparameters. See `setup or compact `_ for a practical guide on deciding between setup and compact. - -Another benefit of defining submodules and/or variables inline is that you can add arguments to your method when constructing variables. The most common example of this is using shape information to determine the shape of a parameter like this: - - -.. testcode:: - - class CompactScaledMLP(nn.Module): - hidden_size: int - out_size: int - - @nn.compact - def __call__(self, x): - scale = self.param("scale", nn.initializers.ones_init(), x.shape[-1:]) - x *= scale[None] - a = nn.Dense(self.hidden_size)(x) - h = nn.relu(a) - return nn.Dense(self.out_size)(h) - - -.. testcode:: - :hide: - - mdl = CompactScaledMLP(hidden_size=4, out_size=5) - x = jax.numpy.ones((3, 2)) - vars = mdl.init(random.PRNGKey(0), x) - assert vars["params"]["scale"].shape == (2,) - -Many of the standard Linen Modules like ``nn.Dense`` use shape inference already to avoid the need to specify input shapes (like the number of input features to a Dense layer). - -Compact control flow -===================== - -The order in which you define submodules determines the name of a submodule if none is provided explicitly (using the ``name=`` keyword argument passed to the Module's constructor). Because the ``name`` determines how parameters are mapped to submodules, you must be careful about mixing control flow with auto-generated names. Using control flow can change the order or remove certain submodules altogether. This is useful in case a submodule should only exist depending on some construction argument. However, when control flow depends on the input arguments to the Module, you should be careful. For example, the following Module will break: - - -.. testcode:: - - class WrongModule(nn.Module): - @nn.compact - def __call__(self, x, mode): - if mode == "encode": - return nn.Dense(features=8)(x) - elif mode == "decode": - return nn.Dense(features=4)(x) - - -The above Module will break because either the encoder or decoder path will construct a Module named "Dense_0". This means the two Modules will share parameters which is not intended here. Actually, the two Modules cannot share parameters because they each have a different number of features. - -This problem can be solved in various ways: - - Provide explicit names - - create the modules in ``setup`` - - or move the constructor out of the control flow. - -The latter is done as follows: - -.. testcode:: - - class CorrectModule(nn.Module): - @nn.compact - def __call__(self, x, mode): - encoder = nn.Dense(8) - decoder = nn.Dense(4) - if mode == "encode": - return encoder(x) - elif mode == "decode": - return decoder(x) - -.. testcode:: - :hide: - - def init_fn(mdl): - x = jax.numpy.ones((3, 2)) - z = mdl(x, "encode") - return mdl(z, "decode") - - mdl = CorrectModule() - vars = nn.init(init_fn, mdl)(random.PRNGKey(0)) - assert vars["params"]["Dense_0"]["kernel"].shape == (2, 8) - assert vars["params"]["Dense_1"]["kernel"].shape == (8, 4) - - -In the above example the construction order is fixed. After construction the submodules can be used in an arbitrary order. - -.. note:: - compact modules show a strong resemblance to `React hooks `_. - - -Top-level Modules -***************** - -When a Module instance is created at the "top-level", it will be in an "unbound" state - that is, it has no variables attached. "Top-level" means it is not constructed as a sub-Module inside another Module class. Apart from calling ``init`` and ``apply``, there is not much you can do with an unbound Module. Note also that ``setup`` is not called on unbound Modules, so you can only access the construction arguments. Refer to the :ref:`Future work` section to learn how this might change in the future. - -Why are top-level Modules always unbound? -=============================================== - -When we call ``apply``, a copy of the top-level Module is created which will actually hold the variables and PRNG sequences. This stateful, "bound", clone only exists while we are executing the apply method. The reason for this is that if you create a stateful object and destroy it before the apply function returns, the ``apply`` function itself behaves like a pure function. A pure function has two constraints: - -#. If you put the same arguments in, it will return the same outputs -#. It does not change anything outside the function. This means you cannot manipulate stateful objects that are accessible outside the pure function. - - -Pure functions have many advantages but when using JAX they are often essential. For example, most code requires compilation using ``jax.jit`` to be fast and once you created a Module you probably want to optimize its parameters using ``jax.grad``. However, these APIs expect a pure function and don't work on stateful bound ``Module`` instances directly. Moreover, pure functions allow for flexible interoperability with other libraries. For example, We recommend `Optax `_ for optimizing parameters. The optimizers in Optax expect and return a PyTree of JAX arrays to optimize, just like the ``apply`` function of a Linen Module. - -Cloning -=============================================== - -To make this approach work reliably we need well-defined cloning behavior. Rather than relying on a complex nested cloning procedure like Python's ``deepcopy``, Flax enforces that a ``Module`` is exactly defined by its construction arguments. Therefore cloning a Module reduces to calling the constructor with its original construction arguments. Because ``Module`` acts as an immutable dataclass, the construction arguments are mapped directly to instance attributes. Non-construction attributes that are computed in ``setup`` or ``__post_init__`` should also depend only on the construciton arguments to ensure a well-defined clone. - -Bind -=============================================== - -Sometimes it's useful to have a bound, top-level Module without having to wrap the code in a function. For example: to interact with a Module inside a Jupyter notebook. The `bind `_ method returns a bound clone with an unlimited lifetime. The downside of this is that you cannot combine it with JAX transformations or integrate it into a vanilla JAX codebase that expects stateless code. For example, `Optax `_ can optimze a Pytree of parameters but it cannot directly optimize a bound ``Module`` instance created with ``.bind`` (because that's not a Pytree). Thus, you cannot combine the ``bind`` API with a functional optimizer API like Optax. - - -Setup -********** - -The ``setup`` method is often used like the constructor hook (``__init__``) in normal Python classes. However, for more advanced use cases it's good to realize that it is not quite the same as a constructor. - -``setup`` is only called after a Module becomes bound. Normally, this is not an issue because most Modules are bound (almost) immediately (as part of ``init`` and ``apply``). Inside ``setup``, sub-modules become bound when they are assigned to an attribute. Inside an ``nn.compact`` decorated method, sub-modules are bound immediately when constructed. As explained in the previous section, top-level Modules are never bound and thus setup is not called when they are constructed. This means you cannot access attributes assigned in setup from an unbound, top-level module. - -.. testcode:: - - class TopLevelAccess(nn.Module): - - def setup(self): - self.foo = nn.Dense(2) - - mdl = TopLevelAccess() - assert not hasattr(mdl, "foo") # foo is not defined because setup is not called - -The ``setup`` method is not called immediately after the ``Module`` becomes bound but only when you interact with the ``Module`` instance (e.g.: call a method or access an attribute). This should not impact the behavior of a ``Module`` but the lazy execution does sometimes affect log statements and stack traces during debugging. The section on functionalization will explain why we need ``setup`` to be lazy in the first place. - - -Functionalization -****************** - -So far we had a pure ``apply`` function that is typically transformed with some JAX transformations and inside ``apply`` we have a stateful Module instance to work with. In other words: Outside of a Module we are in a functional world where we have the power of JAX's functional transformations and inside the Module we get the power of Flax's stateful variables and PRNG sequence, and the ``apply`` method is our bridge between these two worlds. - -But what if we want to use JAX transformations **inside** Modules? The answer to this is functionalization. - -This procedure itself is tedious and error-prone but handled internally by Flax. At a high-level we can summarize it as follows. For a method ``fn`` defined within a Module: - -#. Collect the state (variables & PRNG sequences) of the Module(s) that should be available inside the JAX transformation and take a snapshot of it. - -#. Call the JAX transformation with the original arguments and the collected state. Then inside the transformation: - - #. Unpack the state and recreate the Modules - - #. Call the user code ``fn`` - - #. Collect the updated variables and rng and return it together with the original return values from ``fn`` - -#. Update the original state with the updated state returned from the transformation. - -A more in-depth explanation of functionalization and lifting can be found in the `Lifted Transformation `_ design note. - -Practical consequences -========================== - -For the most part functionalization is something that is handled automatically for you. Still there are some constraints that you must take into account. Most importantly, Flax only handles the stateful primitives (Linen variables and RNGs) and not arbitrary stateful Python code. Most importantly: You cannot close over stateful objects and ``Module`` objects because they are invisible to Flax's internals (and to JAX in general). - - -.. testcode:: - - class Foo(nn.Module): - @nn.compact - def __call__(self, x): - dense = nn.Dense(x.shape[-1]) - fn = lambda x: dense(x) + 1 - # simply calling inner works fine - # return self.inner(x, fn) - # but applying a transformation doesn't: - vmap_inner = nn.vmap(Foo.inner, in_axes=0, variable_axes={"params": 0}, split_rngs={"params": True}) - return vmap_inner(self, x, fn) - - def inner(self, x, fn): - for i in range(3): - x = fn(x) - return x - -Here ``inner`` takes a function that closes over a Module instance. In this example, that works fine because we are not transforming the inner method with a lifted transformation. Most methods are not transformed but it is good to know how to make Module methods transformable. - -The main obstacle for transformability are types that JAX does not recognize. JAX only understands `Pytree `_ arguments. That's arbitrarily nested Python containers (dict, list, tuple) of (Jax) numpy ndarrays and Python numbers/bools. Flax allows to define dataclasses which are Pytree compatible using the `flax.struct `_ API. - -Function closure is the most common way to accidentally hide a JAX array or Linen Module from a transformation. There is however an easy workaround if you want to pass closures that are also compatible with JAX and Linen transformations: - - -.. testcode:: - - class Partial(flax.struct.PyTreeNode): - fn: Callable = flax.struct.field(pytree_node=False) - args: Iterable[Any] - - def __call__(self, *args, **kwargs): - return self.fn(*(tuple(self.args) + args), **kwargs) - - class Foo(nn.Module): - - @nn.compact - def __call__(self, x): - dense = nn.Dense(x.shape[-1]) - fn = lambda mdl, x: mdl(x) + 1 - vmap_inner = nn.vmap(Foo.inner, in_axes=0, variable_axes={"params": 0}, split_rngs={"params": True}) - return vmap_inner(self, x, Partial(fn, [dense])) - - def inner(self, x, fn): - for i in range(3): - x = fn(x) - return x - - -.. testcode:: - :hide: - - x = jax.numpy.ones((3, 2)) - mdl = Foo() - vars = mdl.init(random.PRNGKey(0), x) - assert vars['params']['Dense_0']['kernel'].shape == (3, 2, 2) - - - -Here the closure is implemented using a Flax dataclass. The function itself is annotated with ``flax.struct.field(pytree_node=False)`` to indicate that it does not contain JAX Arrays or Linen Modules. The partially applied ``args`` on the other hand is treated as a pytree container. We rewrite the closure to use Partial. Now the inner method can be transformed using lifted transformations. - - -Future work -************* - - -Setup for unbound Modules -=========================== - -The current Module abstraction is particularly restrictive when it comes to initializing fields after construction. In the current Module API, the ``setup`` method is the place to initialize the fields of the Module instance. Because ``setup`` is only called on a bound Module, the full Module API is available inside ``setup``, including variable declaration. However, oftentimes we don't actually require any stateful API's to initialize a field. In fact, most commonly we simply want to declare a submodule. More importantly, it's often useful to inspect submodules for debugging or to partially run the model. Consider for example: - - -.. testcode:: - - class AutoEncoder(nn.Module): - def setup(self): - self.encoder = Encoder(...) - self.decoder = Decoder(...) - - -Imagine we want to call just the decoder using `auto_encoder.decoder.apply(decoder_variables, x)`. With the current setup API this does not work because we must first bind the variables before setup is called and the decoder attribute is defined. Of course we can manually construct the Decoder Module with the same attributes as in setup but this is not ideal in many cases. - -There are two possible solutions to make this use case more ergonomic. First, setup could be made to run immediately after construction before it becomes bound. This means you can still create sub modules but you can no longer define or manipulate variables. Therefore, this would be a breaking change and it would require a new API for defining variables lazily - -Alternatively, an additional special method could be introduced that runs right away after Module construction and before it becomes bound. In this case, the ``setup`` method would preserve its original semantics. diff --git a/docs/advanced_topics/optax_update_guide.rst b/docs/advanced_topics/optax_update_guide.rst deleted file mode 100644 index bb50364ed2..0000000000 --- a/docs/advanced_topics/optax_update_guide.rst +++ /dev/null @@ -1,286 +0,0 @@ -.. image:: https://colab.research.google.com/assets/colab-badge.svg - :target: https://colab.research.google.com/github/google/flax/blob/main/docs/notebooks/optax_update_guide.ipynb - -Upgrading my Codebase to Optax -============================== - -We have proposed to replace :py:mod:`flax.optim` with `Optax -`_ in 2021 with `FLIP #1009 -`_ and -the Flax optimizers have been removed in v0.6.0 - this guide is targeted -towards :py:mod:`flax.optim` users to help them update their code to Optax. - -See also Optax's quick start documentation: -https://optax.readthedocs.io/en/latest/optax-101.html - -.. testsetup:: - - import flax - import jax - import jax.numpy as jnp - import flax.linen as nn - import optax - - # Note: this is the minimal code required to make below code run. See in the - # Colab linked above for a more meaningful definition of datasets etc. - batch = {'image': jnp.ones([1, 28, 28, 1]), 'label': jnp.array([0])} - ds_train = [batch] - get_ds_train = lambda: [batch] - model = nn.Dense(1) - variables = model.init(jax.random.PRNGKey(0), batch['image']) - learning_rate, momentum, weight_decay, grad_clip_norm = .1, .9, 1e-3, 1. - loss = lambda params, batch: jnp.array(0.) - -Replacing ``flax.optim`` with ``optax`` ---------------------------------------- - -Optax has drop-in replacements for all of Flax's optimizers. Refer to Optax's -documentation `Common Optimizers `_ -for API details. - -The usage is very similar, with the difference that ``optax`` does not keep a -copy of the ``params``, so they need to be passed around separately. Flax -provides the utility :py:class:`~flax.training.train_state.TrainState` to store -optimizer state, parameters, and other associated data in a single dataclass -(not used in code below). - -.. codediff:: - :title_left: flax.optim - :title_right: optax - :sync: - - @jax.jit - def train_step(optimizer, batch): - grads = jax.grad(loss)(optimizer.target, batch) - - - return optimizer.apply_gradient(grads) - - optimizer_def = flax.optim.Momentum( - learning_rate, momentum) - optimizer = optimizer_def.create(variables['params']) - - for batch in get_ds_train(): - optimizer = train_step(optimizer, batch) - - --- - - @jax.jit - def train_step(params, opt_state, batch): - grads = jax.grad(loss)(params, batch) - updates, opt_state = tx.update(grads, opt_state) - params = optax.apply_updates(params, updates) - return params, opt_state - - tx = optax.sgd(learning_rate, momentum) - params = variables['params'] - opt_state = tx.init(params) - - for batch in ds_train: - params, opt_state = train_step(params, opt_state, batch) - - -Composable Gradient Transformations ------------------------------------ - -The function |optax.sgd()|_ used in the code snippet above is simply a wrapper -for the sequential application of two gradient transformations. Instead of using -this alias, it is common to use |optax.chain()|_ to combine multiple of these -generic building blocks. - -.. |optax.sgd()| replace:: ``optax.sgd()`` -.. _optax.sgd(): https://optax.readthedocs.io/en/latest/api.html#optax.sgd -.. |optax.chain()| replace:: ``optax.chain()`` -.. _optax.chain(): https://optax.readthedocs.io/en/latest/api.html#chain - -.. codediff:: - :title_left: Pre-defined alias - :title_right: Combining transformations - - # Note that the aliases follow the convention to use positive - # values for the learning rate by default. - tx = optax.sgd(learning_rate, momentum) - - --- - - # - - tx = optax.chain( - # 1. Step: keep a trace of past updates and add to gradients. - optax.trace(decay=momentum), - # 2. Step: multiply result from step 1 with negative learning rate. - # Note that `optax.apply_updates()` simply adds the final updates to the - # parameters, so we must make sure to flip the sign here for gradient - # descent. - optax.scale(-learning_rate), - ) - -Weight Decay ------------- - -Some of Flax's optimizers also include a weight decay. In Optax, some optimizers -also have a weight decay parameter (such as |optax.adamw()|_), and to others the -weight decay can be added as another "gradient transformation" -|optax.add_decayed_weights()|_ that adds an update derived from the parameters. - -.. |optax.adamw()| replace:: ``optax.adamw()`` -.. _optax.adamw(): https://optax.readthedocs.io/en/latest/api.html#optax.adamw -.. |optax.add_decayed_weights()| replace:: ``optax.add_decayed_weights()`` -.. _optax.add_decayed_weights(): https://optax.readthedocs.io/en/latest/api.html#optax.add_decayed_weights - -.. codediff:: - :title_left: flax.optim - :title_right: optax - :sync: - - optimizer_def = flax.optim.Adam( - learning_rate, weight_decay=weight_decay) - optimizer = optimizer_def.create(variables['params']) - - --- - - # (Note that you could also use `optax.adamw()` in this case) - tx = optax.chain( - optax.scale_by_adam(), - optax.add_decayed_weights(weight_decay), - # params -= learning_rate * (adam(grads) + params * weight_decay) - optax.scale(-learning_rate), - ) - # Note that you'll need to specify `params` when computing the udpates: - # tx.update(grads, opt_state, params) - -Gradient Clipping ------------------ - -Training can be stabilized by clipping gradients to a global norm (`Pascanu et -al, 2012 `_). In Flax this is often done by -processing the gradients before passing them to the optimizer. With Optax this -becomes just another gradient transformation |optax.clip_by_global_norm()|_. - -.. |optax.clip_by_global_norm()| replace:: ``optax.clip_by_global_norm()`` -.. _optax.clip_by_global_norm(): https://optax.readthedocs.io/en/latest/api.html#optax.clip_by_global_norm - -.. codediff:: - :title_left: flax.optim - :title_right: optax - :sync: - - def train_step(optimizer, batch): - grads = jax.grad(loss)(optimizer.target, batch) - grads_flat, _ = jax.tree_util.tree_flatten(grads) - global_l2 = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads_flat])) - g_factor = jnp.minimum(1.0, grad_clip_norm / global_l2) - grads = jax.tree_util.tree_map(lambda g: g * g_factor, grads) - return optimizer.apply_gradient(grads) - - --- - - tx = optax.chain( - optax.clip_by_global_norm(grad_clip_norm), - optax.trace(decay=momentum), - optax.scale(-learning_rate), - ) - -Learning Rate Schedules ------------------------ - -For learning rate schedules, Flax allows overwriting hyper parameters when -applying the gradients. Optax maintains a step counter and provides this as an -argument to a function for scaling the updates added with -|optax.scale_by_schedule()|_. Optax also allows specifying functions to -inject arbitrary scalar values for other gradient updates via -|optax.inject_hyperparams()|_. - -Read more about learning rate schedules in the :doc:`lr_schedule` guide. - -Read more about schedules defined in Optax under `Optimizer Schedules -`_. the -standard optimizers (like ``optax.adam()``, ``optax.sgd()`` etc.) also accept a -learning rate schedule as a parameter for ``learning_rate``. - - -.. |optax.scale_by_schedule()| replace:: ``optax.scale_by_schedule()`` -.. _optax.scale_by_schedule(): https://optax.readthedocs.io/en/latest/api.html#optax.scale_by_schedule -.. |optax.inject_hyperparams()| replace:: ``optax.inject_hyperparams()`` -.. _optax.inject_hyperparams(): https://optax.readthedocs.io/en/latest/api.html#optax.inject_hyperparams - -.. codediff:: - :title_left: flax.optim - :title_right: optax - :sync: - - def train_step(step, optimizer, batch): - grads = jax.grad(loss)(optimizer.target, batch) - return step + 1, optimizer.apply_gradient(grads, learning_rate=schedule(step)) - - --- - - tx = optax.chain( - optax.trace(decay=momentum), - # Note that we still want a negative value for scaling the updates! - optax.scale_by_schedule(lambda step: -schedule(step)), - ) - -Multiple Optimizers / Updating a Subset of Parameters ------------------------------------------------------ - -In Flax, traversals are used to specify which parameters should be updated by an -optimizer. And you can combine traversals using -:py:class:`flax.optim.MultiOptimizer` to apply different optimizers on different -parameters. The equivalent in Optax is |optax.masked()|_ and |optax.chain()|_. - -Note that the example below is using :py:mod:`flax.traverse_util` to create the -boolean masks required by |optax.masked()|_ - alternatively you could also -create them manually, or use |optax.multi_transform()|_ that takes a -multivalent pytree to specify gradient transformations. - -Beware that |optax.masked()|_ flattens the pytree internally and the inner -gradient transformations will only be called with that partial flattened view of -the params/gradients. This is not a problem usually, but it makes it hard to -nest multiple levels of masked gradient transformations (because the inner -masks will expect the mask to be defined in terms of the partial flattened view -that is not readily available outside the outer mask). - -.. |optax.masked()| replace:: ``optax.masked()`` -.. _optax.masked(): https://optax.readthedocs.io/en/latest/api.html#optax.masked -.. |optax.multi_transform()| replace:: ``optax.multi_transform()`` -.. _optax.multi_transform(): https://optax.readthedocs.io/en/latest/api.html#optax.multi_transform - -.. codediff:: - :title_left: flax.optim - :title_right: optax - :sync: - - kernels = flax.traverse_util.ModelParamTraversal(lambda p, _: 'kernel' in p) - biases = flax.traverse_util.ModelParamTraversal(lambda p, _: 'bias' in p) - - kernel_opt = flax.optim.Momentum(learning_rate, momentum) - bias_opt = flax.optim.Momentum(learning_rate * 0.1, momentum) - - - optimizer = flax.optim.MultiOptimizer( - (kernels, kernel_opt), - (biases, bias_opt) - ).create(variables['params']) - - --- - - kernels = flax.traverse_util.ModelParamTraversal(lambda p, _: 'kernel' in p) - biases = flax.traverse_util.ModelParamTraversal(lambda p, _: 'bias' in p) - - all_false = jax.tree_util.tree_map(lambda _: False, params) - kernels_mask = kernels.update(lambda _: True, all_false) - biases_mask = biases.update(lambda _: True, all_false) - - tx = optax.chain( - optax.trace(decay=momentum), - optax.masked(optax.scale(-learning_rate), kernels_mask), - optax.masked(optax.scale(-learning_rate * 0.1), biases_mask), - ) - -Final Words ------------ - -All above patterns can of course also be mixed and Optax makes it possible to -encapsulate all these transformations into a single place outside the main -training loop, which makes testing much easier. diff --git a/docs/advanced_topics/philosophy.md b/docs/advanced_topics/philosophy.md deleted file mode 100644 index d475f221d7..0000000000 --- a/docs/advanced_topics/philosophy.md +++ /dev/null @@ -1,37 +0,0 @@ -# The Flax Philosophy - -(in no particular order) - -* Library code should be easy to read and understand. - -* Prefer duplicating code over a bad abstraction. - -* Generally, prefer duplicating code over adding options to functions. - -* Comment-driven design: If it's hard to document your code, consider - changing the design. - -* Unit test-driven design: If it's hard to test your code, consider - changing the design. - -* People start projects by copying an existing implementation -- make - base implementations excellent. - -* If we expose an abstraction to our developers, we own the mental - overhead. - -* Developer-facing functional programming abstractions confuse some users, - expose them where the benefit is high. - -* "Read the manual" is not an appropriate response to developer confusion. - The framework should guide developers - towards good solutions, e.g. through assertions and error messages. - -* An unhelpful error message is a bug. - -* "Debugging is twice as hard as writing the code in the first - place. Therefore, if you write the code as cleverly as possible, you - are, by definition, not smart enough to debug it." -Brian Kernighan - - - diff --git a/docs/developer_notes/module_lifecycle.rst b/docs/developer_notes/module_lifecycle.rst index 1c9f4ed8af..799c58fa46 100644 --- a/docs/developer_notes/module_lifecycle.rst +++ b/docs/developer_notes/module_lifecycle.rst @@ -130,7 +130,7 @@ Linen provides an alternative API for defining modules more compactly. This is e return nn.Dense(self.out_size)(h) -A compact ``Module`` is similar in spirit to a function. It offers a concise notation and restricts external interaction to the inputs and return values of the function. In this case the concise notation might make it easier for others to understand what the Module does. There is no need to jump back and forth between the ``setup`` and ``__call__`` method to understand what the submodules are doing. Instead, simply reading the ``__call__`` method from top to bottom once should provide a concise overview. This can make a significant difference if you are implementing complex Modules with many hyperparameters. See `setup or compact `_ for a practical guide on deciding between setup and compact. +A compact ``Module`` is similar in spirit to a function. It offers a concise notation and restricts external interaction to the inputs and return values of the function. In this case the concise notation might make it easier for others to understand what the Module does. There is no need to jump back and forth between the ``setup`` and ``__call__`` method to understand what the submodules are doing. Instead, simply reading the ``__call__`` method from top to bottom once should provide a concise overview. This can make a significant difference if you are implementing complex Modules with many hyperparameters. See `setup or compact `_ for a practical guide on deciding between setup and compact. Another benefit of defining submodules and/or variables inline is that you can add arguments to your method when constructing variables. The most common example of this is using shape information to determine the shape of a parameter like this: @@ -286,7 +286,7 @@ This procedure itself is tedious and error-prone but handled internally by Flax. #. Update the original state with the updated state returned from the transformation. -A more in depth explanation of functionalization and lifting can be found in the `Lifted Transformation `_ design note. +A more in depth explanation of functionalization and lifting can be found in the `Lifted Transformation `_ design note. Practical consequences ========================== diff --git a/docs/getting_started.ipynb b/docs/getting_started.ipynb index 4b63cceb64..aeb08ca641 100644 --- a/docs/getting_started.ipynb +++ b/docs/getting_started.ipynb @@ -472,6 +472,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "809ae1a0", "metadata": { @@ -486,7 +487,7 @@ " and use it for parameter initialization. (Learn\n", " more about\n", " [JAX PRNG design](https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html)\n", - " and [PRNG chains](https://flax.readthedocs.io/en/latest/design_notes/linen_design_principles.html#how-are-parameters-represented-and-how-do-we-handle-general-differentiable-algorithms-that-update-stateful-variables).)" + " and [PRNG chains](https://flax.readthedocs.io/en/latest/philosophy.html#how-are-parameters-represented-and-how-do-we-handle-general-differentiable-algorithms-that-update-stateful-variables).)" ] }, { @@ -852,7 +853,8 @@ "main_language": "python" }, "language_info": { - "name": "python" + "name": "python", + "version": "3.9.6" } }, "nbformat": 4, diff --git a/docs/getting_started.md b/docs/getting_started.md index 67fba8e4d4..32dab82e5b 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -311,7 +311,7 @@ train_ds, test_ds = get_datasets(num_epochs, batch_size) and use it for parameter initialization. (Learn more about [JAX PRNG design](https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html) - and [PRNG chains](https://flax.readthedocs.io/en/latest/design_notes/linen_design_principles.html#how-are-parameters-represented-and-how-do-we-handle-general-differentiable-algorithms-that-update-stateful-variables).) + and [PRNG chains](https://flax.readthedocs.io/en/latest/philosophy.html#how-are-parameters-represented-and-how-do-we-handle-general-differentiable-algorithms-that-update-stateful-variables).) ```{code-cell} --- diff --git a/docs/glossary.rst b/docs/glossary.rst index 2903872961..7b76422baa 100644 --- a/docs/glossary.rst +++ b/docs/glossary.rst @@ -15,7 +15,7 @@ For additional terms, refer to the `Jax glossary `__. + `module lifecycle `__. Compact / Non-compact Module Modules with a single method are able to declare submodules and variables inline by @@ -53,7 +53,7 @@ For additional terms, refer to the `Jax glossary `__. + Refer to the `Flax docs `__. Module A dataclass allowing the definition and initialization of parameters in a @@ -74,7 +74,7 @@ For additional terms, refer to the `Jax glossary `__. + `lifting transformations `__. Scope A container class for holding the variables and PRNG keys for each layer. diff --git a/docs/guides/linen_upgrade_guide.rst b/docs/guides/linen_upgrade_guide.rst index ab179a5e3e..b758a785a5 100644 --- a/docs/guides/linen_upgrade_guide.rst +++ b/docs/guides/linen_upgrade_guide.rst @@ -481,11 +481,11 @@ For more information, please see the design note on `Lifted transformations`_. TODO: Given an example of ``jax.scan_in_dim`` (pre-Linen) vs. ``nn.scan`` (Linen). -.. _`Should I use setup or nn.compact?`: https://flax.readthedocs.io/en/latest/design_notes/setup_or_nncompact.html +.. _`Should I use setup or nn.compact?`: https://flax.readthedocs.io/en/latest/guides/setup_or_nncompact.html .. _`Variables documentation`: https://flax.readthedocs.io/en/latest/flax.linen.html#module-flax.core.variables .. _`TrainState`: https://flax.readthedocs.io/en/latest/flax.training.html#train-state .. _`Upgrading my codebase to Optax`: https://flax.readthedocs.io/en/latest/guides/optax_update_guide.html -.. _`Lifted transformations`: https://flax.readthedocs.io/en/latest/design_notes/lift.html +.. _`Lifted transformations`: https://flax.readthedocs.io/en/latest/developer_notes/lift.html .. |@compact| replace:: ``@compact`` diff --git a/flax/linen/recurrent.py b/flax/linen/recurrent.py index 035ecab8ad..c88443d1cb 100644 --- a/flax/linen/recurrent.py +++ b/flax/linen/recurrent.py @@ -15,7 +15,7 @@ """Recurrent neural network modules. THe RNNCell modules can be scanned using lifted transforms. For more information -see: https://flax.readthedocs.io/en/latest/advanced_topics/lift.html. +see: https://flax.readthedocs.io/en/latest/developer_notes/lift.html. """ from functools import partial # pylint: disable=g-importing-member diff --git a/flax/traverse_util.py b/flax/traverse_util.py index 61ef2203db..e12b27be4b 100644 --- a/flax/traverse_util.py +++ b/flax/traverse_util.py @@ -198,7 +198,7 @@ def __new__(cls, *args, **kwargs): warnings.warn( '`flax.traverse_util.Traversal` will be deprecated. If you are using ' 'it for `flax.optim`, use `optax` instead. Refer to the update guide ' - 'https://flax.readthedocs.io/en/latest/advanced_topics/optax_update_guide.html ' + 'https://flax.readthedocs.io/en/latest/guides/optax_update_guide.html ' 'for detailed instructions.', DeprecationWarning) return super().__new__(cls)