diff --git a/docs/guides/converting_and_upgrading/optax_update_guide.ipynb b/docs/guides/converting_and_upgrading/optax_update_guide.ipynb
new file mode 100644
index 0000000000..e81bf29bec
--- /dev/null
+++ b/docs/guides/converting_and_upgrading/optax_update_guide.ipynb
@@ -0,0 +1,430 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "004a425ab000"
+ },
+ "source": [
+ "# Upgrading my codebase to Optax from `flax.optim`\n",
+ "\n",
+ "In 2021, [FLIP #1009](https://github.com/google/flax/blob/main/docs/flip/1009-optimizer-api.md) proposed to replace `flax.optim` with [Optax](https://optax.readthedocs.io). Since Flax v0.6.0, Optax has been the default Flax optimizer library. This guide shows how to update your `flax.optim` code to Optax.\n",
+ "\n",
+ "You can also refer to the [Optax 101 (quick start)](https://optax.readthedocs.io/en/latest/optax-101.html)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "c47f55048d1e"
+ },
+ "source": [
+ "## Setup and imports\n",
+ "\n",
+ "Install/upgrade Flax in Colab (the Flax package comes with Optax):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "3d36c6e9da4c"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install --upgrade -q pip jax jaxlib flax"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "f3e05fb841c9"
+ },
+ "outputs": [],
+ "source": [
+ "# Import the necessary libraries\n",
+ "\n",
+ "import flax\n",
+ "import jax\n",
+ "import jax.numpy as jnp\n",
+ "import flax.linen as nn\n",
+ "import optax\n",
+ "\n",
+ "# Note: this is the minimal code required to make the code below run.\n",
+ "\n",
+ "batch = {'image': jnp.ones([1, 28, 28, 1]), 'label': jnp.array([0])}\n",
+ "ds_train = [batch]\n",
+ "get_ds_train = lambda: [batch]\n",
+ "model = nn.Dense(1)\n",
+ "variables = model.init(jax.random.key(0), batch['image'])\n",
+ "learning_rate, momentum, weight_decay, grad_clip_norm = .1, .9, 1e-3, 1.\n",
+ "loss = lambda params, batch: jnp.array(0.)\n",
+ "# Placeholder schedule so the learning rate schedule cells below run.\n",
+ "schedule = lambda step: learning_rate\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "65090e843841"
+ },
+ "source": [
+ "## Replacing ``flax.optim`` with ``optax``"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "301328d793ba"
+ },
+ "source": [
+ "Optax has drop-in replacements for all of Flax's optimizers. Refer to the Optax [Common Optimizers API docs](https://optax.readthedocs.io/en/latest/api.html) for more details.\n",
+ "\n",
+ "The usage is very similar, with some differences that include:\n",
+ "\n",
+ "- Optax (`optax`) does not keep a copy of the `params`, so they need to be passed around separately.\n",
+ "- Flax provides the utility `flax.training.train_state.TrainState` to store the optimizer state, parameters, and other associated data in a single dataclass (not used in the code example below; a short sketch follows this list).\n",
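+ "\n",
+ "For reference, a minimal sketch of the `TrainState` version of the update step could look like this (the rest of this guide keeps `params` and `opt_state` separate):\n",
+ "\n",
+ "```python\n",
+ "from flax.training import train_state\n",
+ "\n",
+ "state = train_state.TrainState.create(\n",
+ "    apply_fn=model.apply,\n",
+ "    params=variables['params'],\n",
+ "    tx=optax.sgd(learning_rate, momentum),\n",
+ ")\n",
+ "\n",
+ "@jax.jit\n",
+ "def train_step(state, batch):\n",
+ "  # `TrainState` bundles params, optimizer state and apply_fn in one dataclass.\n",
+ "  grads = jax.grad(loss)(state.params, batch)\n",
+ "  return state.apply_gradients(grads=grads)\n",
+ "```"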
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "bb2b17689f4f"
+ },
+ "outputs": [],
+ "source": [
+ "# # Code with `flax.optim`\n",
+ "# \n",
+ "# @jax.jit\n",
+ "# def train_step(optimizer, batch):\n",
+ "#   grads = jax.grad(loss)(optimizer.target, batch)\n",
+ "#   return optimizer.apply_gradient(grads)\n",
+ "# \n",
+ "# optimizer_def = flax.optim.Momentum(\n",
+ "#     learning_rate, momentum)\n",
+ "# optimizer = optimizer_def.create(variables['params'])\n",
+ "# \n",
+ "# for batch in get_ds_train():\n",
+ "#   optimizer = train_step(optimizer, batch)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "fb5fea94b9a1"
+ },
+ "outputs": [],
+ "source": [
+ "# Code with Optax\n",
+ "\n",
+ "@jax.jit\n",
+ "def train_step(params, opt_state, batch):\n",
+ "  grads = jax.grad(loss)(params, batch)\n",
+ "  updates, opt_state = tx.update(grads, opt_state)\n",
+ "  params = optax.apply_updates(params, updates)\n",
+ "  return params, opt_state\n",
+ "\n",
+ "tx = optax.sgd(learning_rate, momentum)\n",
+ "params = variables['params']\n",
+ "opt_state = tx.init(params)\n",
+ "\n",
+ "for batch in ds_train:\n",
+ "  params, opt_state = train_step(params, opt_state, batch)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "f48d6324bb90"
+ },
+ "source": [
+ "## Composable gradient transformations\n",
+ "\n",
+ "The function `optax.sgd()` used in the code snippet above is simply a wrapper for the sequential application of two gradient transformations. Instead of using this alias, it is common to use `optax.chain()` to combine multiple of these generic building blocks."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "f982f184943e"
+ },
+ "outputs": [],
+ "source": [
+ "# # Code with Optax (pre-defined alias)\n",
+ "#\n",
+ "# # Note that the aliases follow the convention to use positive\n",
+ "# # values for the learning rate by default.\n",
+ "# tx = optax.sgd(learning_rate, momentum)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "14d116acebe2"
+ },
+ "outputs": [],
+ "source": [
+ "# Code with Optax (combining transformations)\n",
+ "\n",
+ "tx = optax.chain(\n",
+ "    # 1. Step: keep a trace of past updates and add to gradients.\n",
+ "    optax.trace(decay=momentum),\n",
+ "    # 2. Step: multiply result from step 1 with negative learning rate.\n",
+ "    # Note that `optax.apply_updates()` simply adds the final updates to the\n",
+ "    # parameters, so we must make sure to flip the sign here for gradient\n",
+ "    # descent.\n",
+ "    optax.scale(-learning_rate),\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "60bd02c4be2d"
+ },
+ "source": [
+ "## Weight decay\n",
+ "\n",
+ "- Some of the `flax.optim` optimizers include a weight decay parameter.\n",
+ "- In Optax, some optimizers also have a weight decay parameter (such as `optax.adamw()`). For other optimizers that don't have it by default, the weight decay can be added as another \"gradient transformation\" `optax.add_decayed_weights()` that adds an update derived from the parameters."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "be35feb00b46"
+ },
+ "outputs": [],
+ "source": [
+ "# # Code with `flax.optim`\n",
+ "#\n",
+ "# optimizer_def = flax.optim.Adam(\n",
+ "#     learning_rate, weight_decay=weight_decay)\n",
+ "# optimizer = optimizer_def.create(variables['params'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "aca35545407a"
+ },
+ "outputs": [],
+ "source": [
+ "# Code with Optax\n",
+ "\n",
+ "# (Note that you could also use `optax.adamw()` in this case)\n",
+ "tx = optax.chain(\n",
+ "    optax.scale_by_adam(),\n",
+ "    optax.add_decayed_weights(weight_decay),\n",
+ "    # params -= learning_rate * (adam(grads) + params * weight_decay)\n",
+ "    optax.scale(-learning_rate),\n",
+ ")\n",
+ "# Note that you'll need to specify `params` when computing the updates:\n",
+ "# tx.update(grads, opt_state, params)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "347f546f7e21"
+ },
+ "source": [
+ "## Gradient clipping\n",
+ "\n",
+ "Training can be stabilized by clipping gradients to a global norm ([Pascanu et al., 2012](https://arxiv.org/abs/1211.5063)).\n",
+ "\n",
+ "- In Flax this is often done by processing the gradients before passing them to the optimizer.\n",
+ "- In Optax this becomes just another gradient transformation `optax.clip_by_global_norm()`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "106017c5f5cc"
+ },
+ "outputs": [],
+ "source": [
+ "# # Code with `flax.optim`\n",
+ "#\n",
+ "# def train_step(optimizer, batch):\n",
+ "#   grads = jax.grad(loss)(optimizer.target, batch)\n",
+ "#   grads_flat, _ = jax.tree_util.tree_flatten(grads)\n",
+ "#   global_l2 = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads_flat]))\n",
+ "#   g_factor = jnp.minimum(1.0, grad_clip_norm / global_l2)\n",
+ "#   grads = jax.tree_util.tree_map(lambda g: g * g_factor, grads)\n",
+ "#   return optimizer.apply_gradient(grads)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "5b0e77e9e7d2"
+ },
+ "outputs": [],
+ "source": [
+ "# Code with Optax\n",
+ "\n",
+ "tx = optax.chain(\n",
+ "    optax.clip_by_global_norm(grad_clip_norm),\n",
+ "    optax.trace(decay=momentum),\n",
+ "    optax.scale(-learning_rate),\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "63edf0d78788"
+ },
+ "source": [
+ "## Learning rate schedules\n",
+ "\n",
+ "For learning rate schedules:\n",
+ "\n",
+ "- Flax allows overwriting hyperparameters when applying the gradients.\n",
+ "- Optax maintains a step counter and provides this as an argument to a function for scaling the updates added with `optax.scale_by_schedule()`. Optax also allows specifying functions to inject arbitrary scalar values for other gradient updates via `optax.inject_hyperparams()` (a short sketch follows below).\n",
+ "\n",
+ "You can learn more in the [Learning rate scheduling](https://flax.readthedocs.io/en/latest/guides/training_techniques/lr_schedule.html) guide and the Optax [Optimizer schedules](https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules) API docs. Note that the standard optimizers (like `optax.adam()`, `optax.sgd()` etc.) also accept a learning rate schedule as a parameter for `learning_rate`.\n",
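+ "\n",
+ "As a minimal sketch of `optax.inject_hyperparams()` (treat the exact state fields as an assumption and check the Optax API docs), wrapping an optimizer factory moves its hyperparameters into the optimizer state, where they can be inspected or overridden between steps:\n",
+ "\n",
+ "```python\n",
+ "tx = optax.inject_hyperparams(optax.sgd)(\n",
+ "    learning_rate=learning_rate, momentum=momentum)\n",
+ "opt_state = tx.init(variables['params'])\n",
+ "\n",
+ "# The injected value now lives in the optimizer state and can be read or\n",
+ "# overridden, e.g. to mimic flax.optim's hyper-parameter overriding.\n",
+ "opt_state.hyperparams['learning_rate'] = schedule(0)\n",
+ "```"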
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "cc927742dd25"
+ },
+ "outputs": [],
+ "source": [
+ "# # Code with `flax.optim`\n",
+ "#\n",
+ "# def train_step(step, optimizer, batch):\n",
+ "#   grads = jax.grad(loss)(optimizer.target, batch)\n",
+ "#   return step + 1, optimizer.apply_gradient(grads, learning_rate=schedule(step))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "1d3ea2eb883b"
+ },
+ "outputs": [],
+ "source": [
+ "# Code with Optax\n",
+ "\n",
+ "tx = optax.chain(\n",
+ "    optax.trace(decay=momentum),\n",
+ "    # Note that we still want a negative value for scaling the updates!\n",
+ "    optax.scale_by_schedule(lambda step: -schedule(step)),\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "b922b3b678fc"
+ },
+ "source": [
+ "## Multiple optimizers / Updating a subset of parameters\n",
+ "\n",
+ "In Flax, traversals are used to specify which parameters should be updated by an optimizer.\n",
+ "\n",
+ "- In `flax.optim`, you could combine traversals with `flax.optim.MultiOptimizer` to apply different optimizers to different parameters.\n",
+ "- In Optax, the equivalent is `optax.masked()` combined with `optax.chain()`.\n",
+ "\n",
+ "Note that the example below uses `flax.traverse_util` to create the boolean masks required by `optax.masked()`. Alternatively, you could also create them manually, or use `optax.multi_transform()`, which takes a multivalent pytree to specify gradient transformations (a short sketch follows below).\n",
+ "\n",
+ "Beware that `optax.masked()` flattens the pytree internally, and the inner gradient transformations will only be called with that partial flattened view of the params/gradients. This is usually not a problem, but it makes it hard to nest multiple levels of masked gradient transformations (because the inner masks will expect the mask to be defined in terms of the partial flattened view that is not readily available outside the outer mask).\n",
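+ "\n",
+ "For completeness, here is a small sketch of the `optax.multi_transform()` alternative mentioned above. It assumes `params` is a plain nested dict (as returned by `model.init()` in recent Flax versions), so that the label pytree built with `flax.traverse_util` matches its structure:\n",
+ "\n",
+ "```python\n",
+ "# Label every parameter leaf, then map each label to its own optimizer.\n",
+ "flat_params = flax.traverse_util.flatten_dict(params)\n",
+ "param_labels = flax.traverse_util.unflatten_dict(\n",
+ "    {path: ('kernel' if path[-1] == 'kernel' else 'bias') for path in flat_params})\n",
+ "\n",
+ "tx = optax.multi_transform(\n",
+ "    {'kernel': optax.sgd(learning_rate, momentum),\n",
+ "     'bias': optax.sgd(learning_rate * 0.1, momentum)},\n",
+ "    param_labels)\n",
+ "```"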
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "f97d8d4db0e0"
+ },
+ "outputs": [],
+ "source": [
+ "# # Code with `flax.optim`\n",
+ "# \n",
+ "# kernels = flax.traverse_util.ModelParamTraversal(lambda p, _: 'kernel' in p)\n",
+ "# biases = flax.traverse_util.ModelParamTraversal(lambda p, _: 'bias' in p)\n",
+ "# \n",
+ "# kernel_opt = flax.optim.Momentum(learning_rate, momentum)\n",
+ "# bias_opt = flax.optim.Momentum(learning_rate * 0.1, momentum)\n",
+ "# \n",
+ "# optimizer = flax.optim.MultiOptimizer(\n",
+ "#     (kernels, kernel_opt),\n",
+ "#     (biases, bias_opt)\n",
+ "# ).create(variables['params'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "bdc911d53101"
+ },
+ "outputs": [],
+ "source": [
+ "# Code with Optax\n",
+ "\n",
+ "kernels = flax.traverse_util.ModelParamTraversal(lambda p, _: 'kernel' in p)\n",
+ "biases = flax.traverse_util.ModelParamTraversal(lambda p, _: 'bias' in p)\n",
+ "\n",
+ "all_false = jax.tree_util.tree_map(lambda _: False, params)\n",
+ "kernels_mask = kernels.update(lambda _: True, all_false)\n",
+ "biases_mask = biases.update(lambda _: True, all_false)\n",
+ "\n",
+ "tx = optax.chain(\n",
+ "    optax.trace(decay=momentum),\n",
+ "    optax.masked(optax.scale(-learning_rate), kernels_mask),\n",
+ "    optax.masked(optax.scale(-learning_rate * 0.1), biases_mask),\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "dabfd735ffab"
+ },
+ "source": [
+ "## Final words\n",
+ "\n",
+ "The patterns described in this guide can be mixed together, and Optax makes it possible to encapsulate the transformations mentioned here into a single place outside of the main training loop, which makes testing much easier."
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "name": "optax_update_guide.ipynb",
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+ }
+ 
\ No newline at end of file
diff --git a/docs/guides/converting_and_upgrading/optax_update_guide.rst b/docs/guides/converting_and_upgrading/optax_update_guide.rst
index a1a07df09d..1d821c20b6 100644
--- a/docs/guides/converting_and_upgrading/optax_update_guide.rst
+++ b/docs/guides/converting_and_upgrading/optax_update_guide.rst
@@ -1,14 +1,14 @@
-Upgrading my codebase to Optax
-==============================
+Upgrading my codebase to Optax from ``flax.optim``
+==================================================
-We have proposed to replace :py:mod:`flax.optim` with `Optax
-<https://optax.readthedocs.io>`_ in 2021 with `FLIP #1009
-<https://github.com/google/flax/blob/main/docs/flip/1009-optimizer-api.md>`_ and
-the Flax optimizers have been removed in v0.6.0 - this guide is targeted
-towards :py:mod:`flax.optim` users to help them update their code to Optax.
+`Open in Colab <https://colab.research.google.com/github/google/flax/blob/main/docs/guides/converting_and_upgrading/optax_update_guide.ipynb>`_
-See also Optax's quick start documentation:
-https://optax.readthedocs.io/en/latest/optax-101.html
+In 2021, `FLIP #1009 <https://github.com/google/flax/blob/main/docs/flip/1009-optimizer-api.md>`_
+proposed to replace :py:mod:`flax.optim` with `Optax <https://optax.readthedocs.io>`_.
+Since Flax v0.6.0, Optax has been the default Flax optimizer library.
+This guide shows how to update your :py:mod:`flax.optim` code to Optax.
+
+You can also refer to the `Optax 101 (quick start) <https://optax.readthedocs.io/en/latest/optax-101.html>`_.
 .. testsetup::
@@ -31,15 +31,14 @@ https://optax.readthedocs.io/en/latest/optax-101.html
 Replacing ``flax.optim`` with ``optax``
 ---------------------------------------
-Optax has drop-in replacements for all of Flax's optimizers. Refer to Optax's
-documentation `Common Optimizers <https://optax.readthedocs.io/en/latest/api.html>`_
-for API details.
+Optax has drop-in replacements for all of Flax's optimizers. Refer to the Optax
+`Common Optimizers API docs <https://optax.readthedocs.io/en/latest/api.html>`_
+for more details.
+
+The usage is very similar, with some differences that include:
-The usage is very similar, with the difference that ``optax`` does not keep a
-copy of the ``params``, so they need to be passed around separately. Flax
-provides the utility :py:class:`~flax.training.train_state.TrainState` to store
-optimizer state, parameters, and other associated data in a single dataclass
-(not used in code below).
+- Optax (``optax``) does not keep a copy of the ``params``, so they need to be passed around separately.
+- Flax provides the utility :py:class:`~flax.training.train_state.TrainState` to store the optimizer state, parameters, and other associated data in a single dataclass (not used in the code example below).
 .. codediff::
 :title_left: flax.optim
@@ -77,7 +76,7 @@ optimizer state, parameters, and other associated data in a single dataclass
 params, opt_state = train_step(params, opt_state, batch)
-Composable Gradient Transformations
+Composable gradient transformations
 -----------------------------------
 The function |optax.sgd()|_ used in the code snippet above is simply a wrapper
@@ -112,13 +111,11 @@ generic building blocks.
 optax.scale(-learning_rate),
 )
-Weight Decay
+Weight decay
 ------------
-Some of Flax's optimizers also include a weight decay. In Optax, some optimizers
-also have a weight decay parameter (such as |optax.adamw()|_), and to others the
-weight decay can be added as another "gradient transformation"
-|optax.add_decayed_weights()|_ that adds an update derived from the parameters.
+- Some of the :py:mod:`flax.optim` optimizers include a weight decay parameter.
+- In Optax, some optimizers also have a weight decay parameter (such as |optax.adamw()|_). For other optimizers that don't have it by default, the weight decay can be added as another "gradient transformation" |optax.add_decayed_weights()|_ that adds an update derived from the parameters.
 .. |optax.adamw()| replace:: ``optax.adamw()``
 .. _optax.adamw(): https://optax.readthedocs.io/en/latest/api.html#optax.adamw
@@ -146,13 +143,14 @@ weight decay can be added as another "gradient transformation"
 # Note that you'll need to specify `params` when computing the udpates:
 # tx.update(grads, opt_state, params)
-Gradient Clipping
+Gradient clipping
 -----------------
 Training can be stabilized by clipping gradients to a global norm (`Pascanu et
-al, 2012 <https://arxiv.org/abs/1211.5063>`_). In Flax this is often done by
-processing the gradients before passing them to the optimizer. With Optax this
-becomes just another gradient transformation |optax.clip_by_global_norm()|_.
+al, 2012 <https://arxiv.org/abs/1211.5063>`_).
+
+- In Flax this is often done by processing the gradients before passing them to the optimizer.
+- With Optax this becomes just another gradient transformation |optax.clip_by_global_norm()|_.
 .. |optax.clip_by_global_norm()| replace:: ``optax.clip_by_global_norm()``
 .. _optax.clip_by_global_norm(): https://optax.readthedocs.io/en/latest/api.html#optax.clip_by_global_norm
@@ -178,20 +176,17 @@ becomes just another gradient transformation |optax.clip_by_global_norm()|_.
 optax.scale(-learning_rate),
 )
-Learning Rate Schedules
+Learning rate schedules
 -----------------------
-For learning rate schedules, Flax allows overwriting hyper parameters when
-applying the gradients. Optax maintains a step counter and provides this as an
-argument to a function for scaling the updates added with
-|optax.scale_by_schedule()|_. Optax also allows specifying a functions to
-inject arbitrary scalar values for other gradient updates via
-|optax.inject_hyperparams()|_.
+For learning rate schedules:
-Read more about learning rate schedules in the :doc:`lr_schedule` guide.
+- Flax allows overwriting hyperparameters when applying the gradients.
+- Optax maintains a step counter and provides this as an argument to a function for scaling the updates added with |optax.scale_by_schedule()|_. Optax also allows specifying functions to inject arbitrary scalar values for other gradient updates via |optax.inject_hyperparams()|_.
-Read more about schedules defined in Optax under `Optimizer Schedules
-<https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules>`_. the
+You can learn more in the `Learning rate scheduling
+<https://flax.readthedocs.io/en/latest/guides/training_techniques/lr_schedule.html>`_ guide and the Optax `Optimizer schedules
+<https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules>`_ API docs. Note that the
 standard optimizers (like ``optax.adam()``, ``optax.sgd()`` etc.) also accept a
 learning rate schedule as a parameter for ``learning_rate``.
@@ -218,20 +213,21 @@ learning rate schedule as a parameter for ``learning_rate``.
 optax.scale_by_schedule(lambda step: -schedule(step)),
 )
-Multiple Optimizers / Updating a Subset of Parameters
+Multiple optimizers / Updating a subset of parameters
 -----------------------------------------------------
 In Flax, traversals are used to specify which parameters should be updated by an
-optimizer. And you can combine traversals using
-:py:class:`flax.optim.MultiOptimizer` to apply different optimizers on different
-parameters. The equivalent in Optax is |optax.masked()|_ and |optax.chain()|_.
+optimizer.
+
+- In :py:mod:`flax.optim`, you could combine traversals with :py:class:`flax.optim.MultiOptimizer` to apply different optimizers to different parameters.
+- In Optax, the equivalent is |optax.masked()|_ combined with |optax.chain()|_.
 Note that the example below is using :py:mod:`flax.traverse_util` to create the
-boolean masks required by |optax.masked()|_ - alternatively you could also
+boolean masks required by |optax.masked()|_. Alternatively, you could also
 create them manually, or use |optax.multi_transform()|_ that takes a multivalent
 pytree to specify gradient transformations.
-Beware that |optax.masked()|_ flattens the pytree internally and the inner
+Beware that |optax.masked()|_ flattens the pytree internally, and the inner
 gradient transformations will only be called with that partial flattened view of
 the params/gradients. This is not a problem usually, but it makes it hard to
 nest multiple levels of masked gradient transformations (because the inner
 masks will expect the mask to be defined in terms of the partial flattened view
 that is not readily available outside the outer mask).
@@ -275,9 +271,9 @@ that is not readily available outside the outer mask).
 optax.masked(optax.scale(-learning_rate * 0.1), biases_mask),
 )
-Final Words
+Final words
 -----------
-All above patterns can of course also be mixed and Optax makes it possible to
-encapsulate all these transformations into a single place outside the main
-training loop, which makes testing much easier.
+The patterns described in this guide can be mixed together, and Optax makes it possible to
+encapsulate the transformations mentioned here into a single place outside of the main
+training loop, which makes testing much easier.
\ No newline at end of file