From e566bb6145c7a3449ba7fb3988335b2499f97b92 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Fri, 10 Nov 2023 14:54:41 -0500 Subject: [PATCH 1/2] improve why + readme --- .pre-commit-config.yaml | 2 +- flax/experimental/nnx/README.md | 11 - flax/experimental/nnx/docs/why.ipynb | 514 ++++++++++++++++++++--- flax/experimental/nnx/docs/why.md | 404 ++++++++++++++++++ flax/experimental/nnx/tests/test_rngs.py | 4 +- 5 files changed, 851 insertions(+), 84 deletions(-) create mode 100644 flax/experimental/nnx/docs/why.md diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9e64dcb476..57decb1562 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -32,7 +32,7 @@ repos: --keep-output, --keep-count, --extra-keys, - "metadata.kernelspec metadata.vscode metadata.colab cell.metadata.executionInfo.user cell.metadata.executionInfo.user_tz cell.metadata.colab", + "cell.metadata.executionInfo cell.metadata.id metadata.kernelspec metadata.vscode metadata.colab cell.metadata.executionInfo.user cell.metadata.executionInfo.user_tz cell.metadata.colab", ] - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. diff --git a/flax/experimental/nnx/README.md b/flax/experimental/nnx/README.md index 5cc033c3aa..7fb5972da1 100644 --- a/flax/experimental/nnx/README.md +++ b/flax/experimental/nnx/README.md @@ -133,17 +133,6 @@ NNX takes the best features that allow Flax to scale to large projects and integ One place in which NNX strongly deviates from Flax is that (currently) it avoids shape inference in favor of static initialization. It is not a technical limitation but rather a design choice. This design both simplifies the internal implementation and makes it easier to reason about the code for the user, at the cost of being more verbose at times. On the other hand, Pytorch users will feel right at home. -### How is it different from Equinox? -While they might look similar at a surface-level, NNX's Module system is more powerful and flexible than Equinox's, it contains the following additional features: - -* Uses regular python classes (no mandatory dataclass behavior). -* Modules are mutable -* Reference sharing between Modules is allowed -* Mutable state lives inside the Module (no need for a separate [State container](https://docs.kidger.site/equinox/examples/stateful/)). -* Supports node metadata and semantic partitioning. - -One major difference between the two frameworks is that, by design, NNX Modules are not Pytrees. This adds a safety layer as it prevents state updates from being lost by accident due to referential transparency. It also removes the need of threading a separate [State container](https://docs.kidger.site/equinox/examples/stateful/) throughout the code in order to propagate state. In NNX state updates are either always preserved or explicitly discarded by the user. - ## User Guide ### Modules diff --git a/flax/experimental/nnx/docs/why.ipynb b/flax/experimental/nnx/docs/why.ipynb index 54269e982e..0fc10c0168 100644 --- a/flax/experimental/nnx/docs/why.ipynb +++ b/flax/experimental/nnx/docs/why.ipynb @@ -13,12 +13,28 @@ " - Automatic and efficient [PRNG management](https://flax.readthedocs.io/en/latest/glossary.html#term-RNG-sequences) (with support for splitting/broadcast control across map transforms)\n", " - [Variable Metadata](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/_autosummary/flax.linen.with_partitioning.html#flax.linen.with_partitioning) for SPMD annotations, optimizer metadata, and other uses.\n", "\n", - "One choice we made was to use functional \"define by call\" semantics for NN programming via the lazy (ie just in time) initialization of parameters. This made for concise (`compact`) implementation code and allowed for a single specification when transforming a layer. It also aligned our API to be closer to Haiku. However that lazy-init meant that the semantics of variables in Flax were non-pythonic and often surprising. It also led to implementation complexity and obscured the core ideas of transformations on neural nets.\n", + "However, one choice we made was to use functional \"define by call\" semantics for NN programming via the lazy initialization of parameters. This made for concise (`compact`) implementation code, allowed for a single specification when transforming a layer, and aligned our API with Haiku. Lazy initialization meant that the semantics of modules and variables in Flax were non-pythonic and often surprising. It also led to implementation complexity and obscured the core ideas of transformations on neural nets.\n", "\n", - "NNX is an attempt to keep the features that made Linen great while introducing some new principles:\n", + "NNX is an attempt to keep the features that made Linen useful while introducing some new principles:\n", "\n", "- Regular Python semantics for Modules, including (within JIT boundaries) support for mutability and shared references.\n", - "- A simple API to interact directly with the JAX, this includes the ability to easily implement custom lifted Modules and other purely functional tricks." + "- A simple API to interact directly with the JAX, this includes the ability to easily implement custom lifted Modules and other purely functional tricks.\n", + "\n", + "We'd love to hear from any of our users about their thoughts on these ideas.\n", + "\n", + "[[nnx on github](https://github.com/google/flax/tree/main/flax/experimental/nnx)]\n", + "[[this doc on github](https://github.com/google/flax/blob/main/flax/experimental/nnx/docs/why.ipynb)]" + ] + }, + { + "cell_type": "code", + "execution_count": 108, + "metadata": {}, + "outputs": [], + "source": [ + "from functools import partial\n", + "import jax\n", + "from jax import random, numpy as jnp" ] }, { @@ -36,16 +52,11 @@ }, { "cell_type": "code", - "execution_count": 1, - "metadata": {}, + "execution_count": 2, + "metadata": { + "outputId": "d8ef66d5-6866-4d5c-94c2-d22512bfe718" + }, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "A Google TPU may be present on this machine, but either a TPU-enabled jaxlib or libtpu is not installed. Falling back to cpu.\n" - ] - }, { "name": "stdout", "output_type": "stream", @@ -58,20 +69,18 @@ " dtype=None,\n", " param_dtype=,\n", " precision=None,\n", - " kernel_init=.init at 0x7f5d3c57baf0>,\n", - " bias_init=,\n", - " dot_general=\n", + " kernel_init=.init at 0x7f3dc9ad3370>,\n", + " bias_init=,\n", + " dot_general=\n", " )\n", ")\n" ] } ], "source": [ - "from flax.experimental import nnx\n", - "import jax\n", - "from jax import random, numpy as jnp\n", + "class Count(nnx.Variable): # custom Variable types define the \"collections\"\n", + " pass\n", "\n", - "class Count(nnx.Variable): pass\n", "\n", "class CounterLinear(nnx.Module):\n", " def __init__(self, din, dout, *, rngs): # explicit RNG threading\n", @@ -79,7 +88,7 @@ " self.count = Count(jnp.zeros((), jnp.int32)) # typed Variable collections\n", "\n", " def __call__(self, x):\n", - " self.count += 1 # inplace stateful updates\n", + " self.count += 1 # in-place stateful updates\n", " return self.linear(x)\n", "\n", "\n", @@ -98,8 +107,10 @@ }, { "cell_type": "code", - "execution_count": 2, - "metadata": {}, + "execution_count": 3, + "metadata": { + "outputId": "10a46b0f-2993-4677-c26d-36a4ddf33449" + }, "outputs": [ { "name": "stdout", @@ -129,33 +140,77 @@ }, { "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], + "execution_count": 23, + "metadata": { + "outputId": "e6f86be8-3537-4c48-f471-316ee0fb6c45" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([[1.7531997, 1.6318591, 2.1417565, 3.120555 ],\n", + " [1.7531997, 1.6318591, 2.1417565, 3.120555 ]], dtype=float32)" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "def load_pretrained():\n", - " return nnx.Linear(4, 4, rngs=nnx.Rngs(42)) # pretend this is pretrained\n", - "\n", - "model.linear = load_pretrained() # you can replace modules\n", + "# pretend this came from a checkpoint or elsewhere:\n", + "pretrained_weight = random.uniform(random.key(0), (4, 4))\n", "\n", - "y = model(jnp.ones((2, 4)))" + "# you can replace weights directly\n", + "model.linear.kernel = pretrained_weight\n", + "y = model(jnp.ones((2, 4)))\n", + "y" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": { + "outputId": "5190ac7b-12f7-4400-d5bb-f91b97a557b6" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([[1.624419 , 0.8313738 , 0.37612876, 1.9937458 ],\n", + " [1.624419 , 0.8313738 , 0.37612876, 1.9937458 ]], dtype=float32)" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def load_pretrained_fragment():\n", + " # pretend this inits / loads some fragment of a model\n", + " replacement = nnx.Linear(4, 4, rngs=nnx.Rngs(1))\n", + " return replacement\n", + "\n", + "# you can replace modules directly\n", + "model.linear = load_pretrained_fragment()\n", + "y = model(jnp.ones((2, 4)))\n", + "y" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "The benefit of this is not only that its easier than messing with dictionary structures, but can even replace a field with a completely different Module type, or even change the architecture (e.g. share two Modules that were not shared before)." + "Not only is this easier than messing with dictionary structures and aligning that with code changes, but one can even replace a field with a completely different Module type, or even change the architecture (e.g. share two Modules that were not shared before)." ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 28, "metadata": {}, "outputs": [], "source": [ - "from functools import partial\n", - "\n", "rngs = nnx.Rngs(0)\n", "model = nnx.Sequence(\n", " [\n", @@ -170,6 +225,7 @@ "\n", "y = model(jnp.ones((2, 28, 28, 1)))\n", "\n", + "# Do some weird surgery of the stack:\n", "for i, layer in enumerate(model):\n", " if isinstance(layer, nnx.Conv):\n", " model[i] = nnx.Linear(layer.in_features, layer.out_features, rngs=rngs)\n", @@ -190,15 +246,19 @@ "source": [ "### Interacting with JAX is easy\n", "\n", - "While NNX Modules inherently follow reference semantics, they can be easily converted into a pure functional representation that can be used with JAX transformations. NNX has two very simple APIs to interact with JAX: `split` and `merge`.\n", + "While NNX Modules inherently follow reference semantics, they can be easily converted into a pure functional representation that can be used with JAX transformations and other value-based, functional code.\n", + "\n", + "NNX has two very simple APIs to interact with JAX: `split` and `merge`.\n", "\n", "The `Module.split` method allows you to convert into a `State` dict-like object that contains the dynamic state of the Module, and a `ModuleDef` object that contains the static structure of the Module." ] }, { "cell_type": "code", - "execution_count": 5, - "metadata": {}, + "execution_count": 96, + "metadata": { + "outputId": "9a3f378b-739e-4f45-9968-574651200ede" + }, "outputs": [ { "name": "stdout", @@ -211,7 +271,29 @@ " [ 0.38802508, 0.5655534 , 0.4870657 , 0.2267774 ],\n", " [-0.9015767 , 0.24465278, -0.5844707 , 0.18421966],\n", " [-0.06992685, -0.64693886, 0.20232596, 1.1200062 ]], dtype=float32)\n", - "})\n" + "})\n", + "\n", + "static = ModuleDef(\n", + " type=CounterLinear,\n", + " index=0,\n", + " static_fields=(),\n", + " variables=(('count', Count(\n", + " value=Empty\n", + " )),),\n", + " submodules=(\n", + " ('linear', ModuleDef(\n", + " type=Linear,\n", + " index=1,\n", + " static_fields=(('bias_init', ), ('dot_general', ), ('dtype', None), ('in_features', 4), ('kernel_init', .init at 0x7f3dc9ad3370>), ('out_features', 4), ('param_dtype', ), ('precision', None), ('use_bias', True)),\n", + " variables=(('bias', Param(\n", + " value=Empty\n", + " )), ('kernel', Param(\n", + " value=Empty\n", + " ))),\n", + " submodules=()\n", + " ))\n", + " )\n", + ")\n" ] } ], @@ -220,22 +302,28 @@ "\n", "state, static = model.split()\n", "\n", - "print(f'{state = }')" + "# state is a dictionary-like JAX pytree\n", + "print(f'{state = }')\n", + "\n", + "# static is also a JAX pytree, but containing no data, just metadata\n", + "print(f'\\n{static = }')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "The `ModuleDef.merge` method allows you to take a `ModuleDef` and one or more `State` objects and merge them back into a `Module` object. \n", + "The `ModuleDef.merge` method allows you to take a `ModuleDef` and one or more `State` objects and merge them back into a `Module` object.\n", "\n", "Using `split` and `merge` in conjunction allows you to carry your Module in and out of any JAX transformation. Here is a simple jitted `forward` function as an example:" ] }, { "cell_type": "code", - "execution_count": 6, - "metadata": {}, + "execution_count": 97, + "metadata": { + "outputId": "0007d357-152a-449e-bcb9-b1b5a91d2d8d" + }, "outputs": [ { "name": "stdout", @@ -248,14 +336,14 @@ ], "source": [ "@jax.jit\n", - "def forward(state: nnx.State, x: jax.Array):\n", + "def forward(static: nnx.ModuleDef, state: nnx.State, x: jax.Array):\n", " model = static.merge(state)\n", " y = model(x)\n", " state, _ = model.split()\n", " return y, state\n", "\n", "x = jnp.ones((2, 4))\n", - "y, state = forward(state, x)\n", + "y, state = forward(static,state, x)\n", "\n", "print(f'{y.shape = }')\n", "print(f'{state[\"count\"] = }')" @@ -265,17 +353,23 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "#### Custom lifted Modules\n", + "#### Custom lifting and transformation\n", + "\n", + "By using the same mechanism inside Module methods you can implement lifted Modules, that is, Modules that use a JAX transformation to have a distinct behavior.\n", + "\n", + "One of Linen's current pain points is that it is not easy to interact with JAX transformations that are not currently supported by the framework. NNX makes it very easy to implement custom lifted Modules or bespoke custom functional transforms for specific use cases.\n", "\n", - "By using the same mechanism inside Module methods you can implement lifted Modules, that is, Modules that use a JAX transformation to have a distinct behavior. One of Linen's current pain points is that it is not easy to interact with JAX transformations that are not currently supported by the framework. NNX makes this so easy that its realistic to implement custom lifted Modules for specific use cases.\n", + "As an example here we will create a `LinearEnsemble` Module that uses `jax.vmap` both during `__init__` and `__call__` to vectorize the computation over multiple `CounterLinear` models (defined above). The example is a little bit longer, but notice how each method conceptually very simple.\n", "\n", - "As an example here we will create a `LinearEnsemble` Module that uses `jax.vmap` both during `__init__` and `__call__` to vectorize the computation over multiple `CounterLinear` models (defined above). The example is a little bit longer, but notice how each method conceptually very simple:" + "It uses the single additional method `update` to locally modify model state." ] }, { "cell_type": "code", - "execution_count": 7, - "metadata": {}, + "execution_count": 98, + "metadata": { + "outputId": "fdd212d7-4994-4fa5-d922-5a7d7cfad3e3" + }, "outputs": [ { "name": "stdout", @@ -302,29 +396,34 @@ " return CounterLinear(din, dout, rngs=nnx.Rngs(keys)).split(\n", " nnx.Param, Count\n", " )\n", - "\n", " params, counts, static = jax.vmap(\n", " vmap_init, in_axes=(0,), out_axes=(0, None, None)\n", " )(keys)\n", + "\n", " # update wrapped submodule reference\n", " self.models = static.merge(params, counts)\n", "\n", " def __call__(self, x):\n", - " # get module values, define pure fn\n", + " # get module values, define pure fn,\n", + " # notice that we split the data into two collections by their types.\n", " params, counts, static = self.models.split(nnx.Param, Count)\n", "\n", + " # define pure init fn and vmap\n", " def vmap_apply(x, params, counts, static):\n", " model = static.merge(params, counts)\n", " y = model(x)\n", " params, counts, static = model.split(nnx.Param, Count)\n", " return y, params, counts, static\n", "\n", - " # vmap and call\n", " y, params, counts, static = jax.vmap(\n", - " vmap_apply, in_axes=(None, 0, None, None), out_axes=(0, 0, None, None)\n", + " vmap_apply,\n", + " in_axes=(None, 0, None, None),\n", + " out_axes=(0, 0, None, None)\n", " )(x, params, counts, static)\n", + "\n", " # update wrapped module\n", - " self.models.update(params, counts, static) # use `update` to integrate the new state\n", + " # uses `update` to integrate the new state\n", + " self.models.update(params, counts, static)\n", " return y\n", "\n", "x = jnp.ones((4,))\n", @@ -342,16 +441,125 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Why Modules are not Pytrees?\n", + "#### Convenience lifted transforms" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Like linen, for convenience we still provide simple lifted transforms for standard JAX transforms, usable as class transforms and decorators. We've endeavored to simplify the API for scan and vmap compared to the flax specifications." + ] + }, + { + "cell_type": "code", + "execution_count": 112, + "metadata": { + "outputId": "c4800a49-efd1-4ee5-e703-6e63e18da4cb" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "State({\n", + " 'scan_module/bias': Array([[0., 0.],\n", + " [0., 0.],\n", + " [0., 0.],\n", + " [0., 0.]], dtype=float32),\n", + " 'scan_module/kernel': Array([[[-0.32325608, 0.16164146],\n", + " [ 0.46505648, -0.34060344]],\n", + " \n", + " [[-1.1558908 , 1.2445341 ],\n", + " [-1.3710847 , -0.1787171 ]],\n", + " \n", + " [[-0.68510336, 0.25847596],\n", + " [ 1.0730107 , -0.11857361]],\n", + " \n", + " [[-0.01770882, 0.5472832 ],\n", + " [-0.84826714, 0.17867221]]], dtype=float32)\n", + "})" + ] + }, + "execution_count": 112, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# class transform:\n", + "ScannedLinear = nnx.Scan(nnx.Linear, variable_axes={nnx.Param: 0}, length=4)\n", + "\n", + "scanned = ScannedLinear(2, 2, rngs=nnx.Rngs(0))\n", + "scanned.get_state()" + ] + }, + { + "cell_type": "code", + "execution_count": 113, + "metadata": { + "outputId": "9efd6e71-d180-4674-ade0-2b02057a400b" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "State({\n", + " 'model/bias': Array([[0., 0.],\n", + " [0., 0.],\n", + " [0., 0.],\n", + " [0., 0.]], dtype=float32),\n", + " 'model/kernel': Array([[[-0.32325608, 0.16164146],\n", + " [ 0.46505648, -0.34060344]],\n", + " \n", + " [[-1.1558908 , 1.2445341 ],\n", + " [-1.3710847 , -0.1787171 ]],\n", + " \n", + " [[-0.68510336, 0.25847596],\n", + " [ 1.0730107 , -0.11857361]],\n", + " \n", + " [[-0.01770882, 0.5472832 ],\n", + " [-0.84826714, 0.17867221]]], dtype=float32)\n", + "})" + ] + }, + "execution_count": 113, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# method decorators:\n", + "\n", + "class ScannedLinear(nnx.Module):\n", + "\n", + " @partial(nnx.scan, variable_axes={nnx.Param: 0}, length=4)\n", + " def __init__(self, din, dout, *, rngs: nnx.Rngs):\n", + " self.model = nnx.Linear(din, dout, rngs=nnx.Rngs(rngs))\n", + "\n", + " @partial(nnx.scan, variable_axes={nnx.Param: 0}, length=4)\n", + " def __call__(self, x):\n", + " return self.model(x)\n", + "\n", + "scanned = ScannedLinear(2, 2, rngs=nnx.Rngs(0))\n", + "scanned.get_state()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Aside: Why aren't Modules Pytrees?\n", + "\n", + "A common questions is why aren't NNX Modules registered as Pytrees? (in the style of Equinox, Treex, PytreeClass, etc.) It _is_ trivial to define a pytree registration in terms of `split`/`merge`.\n", "\n", - "Finally one of the most common questions we get is why NNX Modules are not Pytrees? Given the existance of Pytree-based NN frameworks like Equinox, Treex, [PytreeClass](https://github.com/ASEM000/PyTreeClass), it is a fair question.\n", + "The problem is that Pytrees impose value semantics (referencial transparency) while Modules assume reference semantics, and therefore it is dangerous in general to automatically treat Modules as Pytrees.\n", "\n", - "The answer is that Pytrees assume value semantics (referencial transparency) while Modules assume reference semantics, and therefore its not a good idea for Modules to be Pytrees. As an example, lets take a look at what would happen if we allowed this very simple program to be valid:" + "As an example, lets take a look at what would happen if we allowed this very simple program to be valid:" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -367,25 +575,191 @@ "Here we are just creating a jitted function `f` that takes in two Modules `(m1, m2)` and returns them as is. What could go wrong?\n", "\n", "There are two main problems with this:\n", - "* Shared references are not maintained, that is, if `m1.shared is m2.shared` outside `f`, this will NOT be true both inside `f`, and at the output of `f`.\n", - "* Even if you accept this fact and added code to compensate for this, `f` would now behave differently depending on whether its being `jit`ted or not, this is an undisired asymmetry and `jit` would no longer be a no-op." + "* Shared references are not maintained, that is, if `m1.shared` is the same as `m2.shared` outside `f`, this will NOT be true both inside `f`, and at the output of `f`.\n", + "* Even if you accept this fact and added code to compensate for this, `f` would now behave differently depending on whether its being `jit`ted or not, this is an undesirable asymmetry and `jit` would no longer be a no-op." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Standardized \"Hooks\"\n", + "\n", + "NNX introduces a standard getter/setter/creator interface for custom variables (similar to Haiku hooks). This is used internally to support SPMD metadata for managing sharding information, but is available for user-defined applications." + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": { + "outputId": "c4e6586a-bfe0-4f26-d05b-8c9e395971b2" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "self.kernel.shape = (4, 8)\n", + "outer kernel shape = (8, 4)\n" + ] + } + ], + "source": [ + "class TransposedParam(nnx.Variable):\n", + " def create_value(self, value):\n", + " return value.T # called on variable creation to transform initial value\n", + " def get_value(self):\n", + " return self.value.T # called when value fetched via module getattr\n", + " def set_value(self, value):\n", + " return self.replace(value=value.T) # called when setting value from module setattr\n", + "\n", + "\n", + "class OddLinear(nnx.Module):\n", + " def __init__(self, din, dout, *, rngs):\n", + " self.kernel = TransposedParam(random.uniform(rngs.params(), (din, dout)))\n", + " self.bias = nnx.Param(jnp.zeros((dout,)))\n", + "\n", + " def __call__(self, x):\n", + " print(f'{self.kernel.shape = }')\n", + " return x @ self.kernel + self.bias\n", + "\n", + "\n", + "model = OddLinear(4, 8, rngs=nnx.Rngs(0))\n", + "y = model(jnp.ones((2, 4)))\n", + "\n", + "print(f'outer kernel shape = {model.split()[0][\"kernel\"].shape}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "SPMD metadata is handled using `nnx.with_partitioning` helpers, but it's easy to add one's own metadata schema:" + ] + }, + { + "cell_type": "code", + "execution_count": 114, + "metadata": { + "outputId": "ef312738-0f56-4c0e-9aaf-3319d131f1a2" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "state.variables['kernel'].meta='foo'\n", + "state.variables['kernel'].other_meta=0\n", + "state.variables['bias'].meta='bar'\n", + "state.variables['bias'].other_meta=1\n" + ] + } + ], + "source": [ + "class MetadataParam(nnx.Param):\n", + " def __init__(self, *args, **kwargs):\n", + " for key in kwargs:\n", + " setattr(self, key, kwargs[key])\n", + " super().__init__(*args)\n", + "\n", + "\n", + "class AnnotatedLinear(nnx.Module):\n", + " def __init__(self, din, dout, *, rngs):\n", + " self.kernel = TransposedParam(random.uniform(rngs.params(), (din, dout)), meta='foo', other_meta=0)\n", + " self.bias = TransposedParam(jnp.zeros((dout,)), meta='bar', other_meta=1)\n", + "\n", + " def __call__(self, x):\n", + " return x @ self.kernel + self.bias\n", + "\n", + "\n", + "model = AnnotatedLinear(4, 8, rngs=nnx.Rngs(0))\n", + "y = model(jnp.ones((2, 4)))\n", + "\n", + "state, static = model.split()\n", + "\n", + "print(f\"{state.variables['kernel'].meta=}\\n{state.variables['kernel'].other_meta=}\")\n", + "print(f\"{state.variables['bias'].meta=}\\n{state.variables['bias'].other_meta=}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Shape Inference\n", + "\n", + "Shape inference is still possible in NNX using abstract evaluation when it's really needed, it just isn't automatic." + ] + }, + { + "cell_type": "code", + "execution_count": 129, + "metadata": { + "outputId": "942a3788-bcbf-426d-87e6-c5a041172c64" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "State({\n", + " 'encoder/bias': (4,),\n", + " 'encoder/kernel': (3, 3, 3, 4),\n", + " 'linear/bias': (4,),\n", + " 'linear/kernel': (144, 4)\n", + "})" + ] + }, + "execution_count": 129, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def batched_flatten(x):\n", + " return jnp.reshape(x, (x.shape[0], -1))\n", + "\n", + "class Example(nnx.Module):\n", + " def __init__(self, *,\n", + " in_filters=3,\n", + " out_filters=4,\n", + " input_shape=None, # provide an example input size\n", + " rngs):\n", + " self.encoder = nnx.Conv(in_filters, out_filters,\n", + " kernel_size=(3, 3),\n", + " strides=(1, 1),\n", + " padding=\"SAME\",\n", + " rngs=rngs)\n", + " # calculate the flattened shape post-conv using jax.eval_shape\n", + " encoded_shape = jax.eval_shape(\n", + " lambda x: batched_flatten(self.encoder(x)),\n", + " jax.ShapeDtypeStruct(input_shape, jnp.float32)\n", + " ).shape\n", + " # use this shape information to continue initializing\n", + " self.linear = nnx.Linear(encoded_shape[-1], 4, rngs=rngs)\n", + "\n", + " def __call__(self, x):\n", + " x = self.encoder(x)\n", + " x = batched_flatten(x)\n", + " return self.linear(x)\n", + "\n", + "model = Example(in_filters=3,\n", + " out_filters=4,\n", + " input_shape=(2, 6, 6, 3),\n", + " rngs=nnx.Rngs(0))\n", + "\n", + "state, static = model.split()\n", + "jax.tree_map(jnp.shape, state)" ] } ], "metadata": { + "jupytext": { + "formats": "ipynb,md:myst" + }, "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.16" + "version": "3.10.13" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 0 } diff --git a/flax/experimental/nnx/docs/why.md b/flax/experimental/nnx/docs/why.md new file mode 100644 index 0000000000..2c262204b2 --- /dev/null +++ b/flax/experimental/nnx/docs/why.md @@ -0,0 +1,404 @@ +--- +jupytext: + formats: ipynb,md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.13.8 +--- + +# Why NNX? + +Four years ago we developed the Flax "Linen" API to support modeling research on JAX, with a focus on scaling scaling and performance. We've learned a lot from our users over these years. + +We introduced some ideas that have proven to be good: + - Organizing variables into [collections](https://flax.readthedocs.io/en/latest/glossary.html#term-Variable-collections) or types to support JAX transforms and segregation of different data types in training loops. + - Automatic and efficient [PRNG management](https://flax.readthedocs.io/en/latest/glossary.html#term-RNG-sequences) (with support for splitting/broadcast control across map transforms) + - [Variable Metadata](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/_autosummary/flax.linen.with_partitioning.html#flax.linen.with_partitioning) for SPMD annotations, optimizer metadata, and other uses. + +However, one choice we made was to use functional "define by call" semantics for NN programming via the lazy initialization of parameters. This made for concise (`compact`) implementation code, allowed for a single specification when transforming a layer, and aligned our API with Haiku. Lazy initialization meant that the semantics of modules and variables in Flax were non-pythonic and often surprising. It also led to implementation complexity and obscured the core ideas of transformations on neural nets. + +NNX is an attempt to keep the features that made Linen useful while introducing some new principles: + +- Regular Python semantics for Modules, including (within JIT boundaries) support for mutability and shared references. +- A simple API to interact directly with the JAX, this includes the ability to easily implement custom lifted Modules and other purely functional tricks. + +We'd love to hear from any of our users about their thoughts on these ideas. + +[[nnx on github](https://github.com/google/flax/tree/main/flax/experimental/nnx)] +[[this doc on github](https://github.com/google/flax/blob/main/flax/experimental/nnx/docs/why.ipynb)] + +```{code-cell} +from functools import partial +import jax +from jax import random, numpy as jnp +``` + +### NNX is Pythonic +The main feature of NNX Module is that it adheres to Python semantics. This means that: + +* fields are mutable so you can perform inplace updates +* Module references can be shared between multiple Modules +* Module construction implies parameter initialization +* Module methods can be called directly + +```{code-cell} +:outputId: d8ef66d5-6866-4d5c-94c2-d22512bfe718 + +class Count(nnx.Variable): # custom Variable types define the "collections" + pass + + +class CounterLinear(nnx.Module): + def __init__(self, din, dout, *, rngs): # explicit RNG threading + self.linear = nnx.Linear(din, dout, rngs=rngs) + self.count = Count(jnp.zeros((), jnp.int32)) # typed Variable collections + + def __call__(self, x): + self.count += 1 # in-place stateful updates + return self.linear(x) + + +model = CounterLinear(4, 4, rngs=nnx.Rngs(0)) # no special `init` method +y = model(jnp.ones((2, 4))) # call methods directly + +print(f'{model = }') +``` + +Because NNX Modules contain their own state, they are very easily to inspect: + +```{code-cell} +:outputId: 10a46b0f-2993-4677-c26d-36a4ddf33449 + +print(f'{model.count = }') +print(f'{model.linear.kernel = }') +``` + +#### Intuitive Surgery + +In NNX surgery can be done at the Module level by simply updating / replacing existing fields. + +```{code-cell} +:outputId: e6f86be8-3537-4c48-f471-316ee0fb6c45 + +# pretend this came from a checkpoint or elsewhere: +pretrained_weight = random.uniform(random.key(0), (4, 4)) + +# you can replace weights directly +model.linear.kernel = pretrained_weight +y = model(jnp.ones((2, 4))) +y +``` + +```{code-cell} +:outputId: 5190ac7b-12f7-4400-d5bb-f91b97a557b6 + +def load_pretrained_fragment(): + # pretend this inits / loads some fragment of a model + replacement = nnx.Linear(4, 4, rngs=nnx.Rngs(1)) + return replacement + +# you can replace modules directly +model.linear = load_pretrained_fragment() +y = model(jnp.ones((2, 4))) +y +``` + +Not only is this easier than messing with dictionary structures and aligning that with code changes, but one can even replace a field with a completely different Module type, or even change the architecture (e.g. share two Modules that were not shared before). + +```{code-cell} +rngs = nnx.Rngs(0) +model = nnx.Sequence( + [ + nnx.Conv(1, 16, [3, 3], padding='SAME', rngs=rngs), + partial(nnx.max_pool, window_shape=(2, 2), strides=(2, 2)), + nnx.Conv(16, 32, [3, 3], padding='SAME', rngs=rngs), + partial(nnx.max_pool, window_shape=(2, 2), strides=(2, 2)), + lambda x: x.reshape((x.shape[0], -1)), # flatten + nnx.Linear(32 * 7 * 7, 10, rngs=rngs), + ] +) + +y = model(jnp.ones((2, 28, 28, 1))) + +# Do some weird surgery of the stack: +for i, layer in enumerate(model): + if isinstance(layer, nnx.Conv): + model[i] = nnx.Linear(layer.in_features, layer.out_features, rngs=rngs) + +y = model(jnp.ones((2, 28, 28, 1))) +``` + +Note that here we are replacing `Conv` with `Linear` as a silly example, but in reality you would do things like replacing a layer with its quantized version, or changing a layer with an optimized version, etc. + ++++ + +### Interacting with JAX is easy + +While NNX Modules inherently follow reference semantics, they can be easily converted into a pure functional representation that can be used with JAX transformations and other value-based, functional code. + +NNX has two very simple APIs to interact with JAX: `split` and `merge`. + +The `Module.split` method allows you to convert into a `State` dict-like object that contains the dynamic state of the Module, and a `ModuleDef` object that contains the static structure of the Module. + +```{code-cell} +:outputId: 9a3f378b-739e-4f45-9968-574651200ede + +model = CounterLinear(4, 4, rngs=nnx.Rngs(0)) + +state, static = model.split() + +# state is a dictionary-like JAX pytree +print(f'{state = }') + +# static is also a JAX pytree, but containing no data, just metadata +print(f'\n{static = }') +``` + +The `ModuleDef.merge` method allows you to take a `ModuleDef` and one or more `State` objects and merge them back into a `Module` object. + +Using `split` and `merge` in conjunction allows you to carry your Module in and out of any JAX transformation. Here is a simple jitted `forward` function as an example: + +```{code-cell} +:outputId: 0007d357-152a-449e-bcb9-b1b5a91d2d8d + +@jax.jit +def forward(static: nnx.ModuleDef, state: nnx.State, x: jax.Array): + model = static.merge(state) + y = model(x) + state, _ = model.split() + return y, state + +x = jnp.ones((2, 4)) +y, state = forward(static,state, x) + +print(f'{y.shape = }') +print(f'{state["count"] = }') +``` + +#### Custom lifting and transformation + +By using the same mechanism inside Module methods you can implement lifted Modules, that is, Modules that use a JAX transformation to have a distinct behavior. + +One of Linen's current pain points is that it is not easy to interact with JAX transformations that are not currently supported by the framework. NNX makes it very easy to implement custom lifted Modules or bespoke custom functional transforms for specific use cases. + +As an example here we will create a `LinearEnsemble` Module that uses `jax.vmap` both during `__init__` and `__call__` to vectorize the computation over multiple `CounterLinear` models (defined above). The example is a little bit longer, but notice how each method conceptually very simple. + +It uses the single additional method `update` to locally modify model state. + +```{code-cell} +:outputId: fdd212d7-4994-4fa5-d922-5a7d7cfad3e3 + +class LinearEnsemble(nnx.Module): + def __init__(self, din, dout, *, num_models, rngs: nnx.Rngs): + # get raw rng seeds + keys = rngs.fork(num_models) # split all keys into `num_models` + + # define pure init fn and vmap + def vmap_init(keys): + return CounterLinear(din, dout, rngs=nnx.Rngs(keys)).split( + nnx.Param, Count + ) + params, counts, static = jax.vmap( + vmap_init, in_axes=(0,), out_axes=(0, None, None) + )(keys) + + # update wrapped submodule reference + self.models = static.merge(params, counts) + + def __call__(self, x): + # get module values, define pure fn, + # notice that we split the data into two collections by their types. + params, counts, static = self.models.split(nnx.Param, Count) + + # define pure init fn and vmap + def vmap_apply(x, params, counts, static): + model = static.merge(params, counts) + y = model(x) + params, counts, static = model.split(nnx.Param, Count) + return y, params, counts, static + + y, params, counts, static = jax.vmap( + vmap_apply, + in_axes=(None, 0, None, None), + out_axes=(0, 0, None, None) + )(x, params, counts, static) + + # update wrapped module + # uses `update` to integrate the new state + self.models.update(params, counts, static) + return y + +x = jnp.ones((4,)) +ensemble = LinearEnsemble(4, 4, num_models=8, rngs=nnx.Rngs(0)) + +# forward pass +y = ensemble(x) + +print(f'{y.shape = }') +print(f'{ensemble.models.count = }') +print(f'state = {jax.tree_map(jnp.shape, ensemble.get_state())}') +``` + +#### Convenience lifted transforms + ++++ + +Like linen, for convenience we still provide simple lifted transforms for standard JAX transforms, usable as class transforms and decorators. We've endeavored to simplify the API for scan and vmap compared to the flax specifications. + +```{code-cell} +:outputId: c4800a49-efd1-4ee5-e703-6e63e18da4cb + +# class transform: +ScannedLinear = nnx.Scan(nnx.Linear, variable_axes={nnx.Param: 0}, length=4) + +scanned = ScannedLinear(2, 2, rngs=nnx.Rngs(0)) +scanned.get_state() +``` + +```{code-cell} +:outputId: 9efd6e71-d180-4674-ade0-2b02057a400b + +# method decorators: + +class ScannedLinear(nnx.Module): + + @partial(nnx.scan, variable_axes={nnx.Param: 0}, length=4) + def __init__(self, din, dout, *, rngs: nnx.Rngs): + self.model = nnx.Linear(din, dout, rngs=nnx.Rngs(rngs)) + + @partial(nnx.scan, variable_axes={nnx.Param: 0}, length=4) + def __call__(self, x): + return self.model(x) + +scanned = ScannedLinear(2, 2, rngs=nnx.Rngs(0)) +scanned.get_state() +``` + +#### Aside: Why aren't Modules Pytrees? + +A common questions is why aren't NNX Modules registered as Pytrees? (in the style of Equinox, Treex, PytreeClass, etc.) It _is_ trivial to define a pytree registration in terms of `split`/`merge`. + +The problem is that Pytrees impose value semantics (referencial transparency) while Modules assume reference semantics, and therefore it is dangerous in general to automatically treat Modules as Pytrees. + +As an example, lets take a look at what would happen if we allowed this very simple program to be valid: + +```{code-cell} +@jax.jit +def f(m1: nnx.Module, m2: nnx.Module): + return m1, m2 +``` + +Here we are just creating a jitted function `f` that takes in two Modules `(m1, m2)` and returns them as is. What could go wrong? + +There are two main problems with this: +* Shared references are not maintained, that is, if `m1.shared` is the same as `m2.shared` outside `f`, this will NOT be true both inside `f`, and at the output of `f`. +* Even if you accept this fact and added code to compensate for this, `f` would now behave differently depending on whether its being `jit`ted or not, this is an undesirable asymmetry and `jit` would no longer be a no-op. + ++++ + +### Standardized "Hooks" + +NNX introduces a standard getter/setter/creator interface for custom variables (similar to Haiku hooks). This is used internally to support SPMD metadata for managing sharding information, but is available for user-defined applications. + +```{code-cell} +:outputId: c4e6586a-bfe0-4f26-d05b-8c9e395971b2 + +class TransposedParam(nnx.Variable): + def create_value(self, value): + return value.T # called on variable creation to transform initial value + def get_value(self): + return self.value.T # called when value fetched via module getattr + def set_value(self, value): + return self.replace(value=value.T) # called when setting value from module setattr + + +class OddLinear(nnx.Module): + def __init__(self, din, dout, *, rngs): + self.kernel = TransposedParam(random.uniform(rngs.params(), (din, dout))) + self.bias = nnx.Param(jnp.zeros((dout,))) + + def __call__(self, x): + print(f'{self.kernel.shape = }') + return x @ self.kernel + self.bias + + +model = OddLinear(4, 8, rngs=nnx.Rngs(0)) +y = model(jnp.ones((2, 4))) + +print(f'outer kernel shape = {model.split()[0]["kernel"].shape}') +``` + +SPMD metadata is handled using `nnx.with_partitioning` helpers, but it's easy to add one's own metadata schema: + +```{code-cell} +:outputId: ef312738-0f56-4c0e-9aaf-3319d131f1a2 + +class MetadataParam(nnx.Param): + def __init__(self, *args, **kwargs): + for key in kwargs: + setattr(self, key, kwargs[key]) + super().__init__(*args) + + +class AnnotatedLinear(nnx.Module): + def __init__(self, din, dout, *, rngs): + self.kernel = TransposedParam(random.uniform(rngs.params(), (din, dout)), meta='foo', other_meta=0) + self.bias = TransposedParam(jnp.zeros((dout,)), meta='bar', other_meta=1) + + def __call__(self, x): + return x @ self.kernel + self.bias + + +model = AnnotatedLinear(4, 8, rngs=nnx.Rngs(0)) +y = model(jnp.ones((2, 4))) + +state, static = model.split() + +print(f"{state.variables['kernel'].meta=}\n{state.variables['kernel'].other_meta=}") +print(f"{state.variables['bias'].meta=}\n{state.variables['bias'].other_meta=}") +``` + +## Shape Inference + +Shape inference is still possible in NNX using abstract evaluation when it's really needed, it just isn't automatic. + +```{code-cell} +:outputId: 942a3788-bcbf-426d-87e6-c5a041172c64 + +def batched_flatten(x): + return jnp.reshape(x, (x.shape[0], -1)) + +class Example(nnx.Module): + def __init__(self, *, + in_filters=3, + out_filters=4, + input_shape=None, # provide an example input size + rngs): + self.encoder = nnx.Conv(in_filters, out_filters, + kernel_size=(3, 3), + strides=(1, 1), + padding="SAME", + rngs=rngs) + # calculate the flattened shape post-conv using jax.eval_shape + encoded_shape = jax.eval_shape( + lambda x: batched_flatten(self.encoder(x)), + jax.ShapeDtypeStruct(input_shape, jnp.float32) + ).shape + # use this shape information to continue initializing + self.linear = nnx.Linear(encoded_shape[-1], 4, rngs=rngs) + + def __call__(self, x): + x = self.encoder(x) + x = batched_flatten(x) + return self.linear(x) + +model = Example(in_filters=3, + out_filters=4, + input_shape=(2, 6, 6, 3), + rngs=nnx.Rngs(0)) + +state, static = model.split() +jax.tree_map(jnp.shape, state) +``` diff --git a/flax/experimental/nnx/tests/test_rngs.py b/flax/experimental/nnx/tests/test_rngs.py index adede26d52..c3fb5614aa 100644 --- a/flax/experimental/nnx/tests/test_rngs.py +++ b/flax/experimental/nnx/tests/test_rngs.py @@ -49,12 +49,12 @@ def test_rng_stream(self): key1 = rngs.params() assert rngs._rngs['params'].counts[-1] == 1 assert rngs._rngs['params'].key is key0 - assert not np.equal(key0, key1).all() + assert not jnp.allclose(key0, key1) key2 = rngs.params() assert rngs._rngs['params'].counts[-1] == 2 assert rngs._rngs['params'].key is key0 - assert not np.equal(key1, key2).all() + assert not jnp.allclose(key1, key2) def test_rng_fork(self): key0 = jax.random.key(0) From 079488999a02d1d6af838c69aff2b0cefc433a53 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Fri, 10 Nov 2023 14:55:15 -0500 Subject: [PATCH 2/2] fix pre-commit --- .../flax_fundamentals/flax_basics.ipynb | 127 +++------- docs/guides/flax_fundamentals/flax_basics.md | 71 +----- .../parallel_training/flax_on_pjit.ipynb | 189 ++++---------- docs/guides/parallel_training/flax_on_pjit.md | 95 +------ .../use_checkpointing.ipynb | 87 ++----- .../training_techniques/use_checkpointing.md | 65 ++--- docs/quick_start.ipynb | 234 +++--------------- docs/quick_start.md | 190 +------------- examples/imagenet/imagenet.ipynb | 82 ++---- examples/mnist/mnist.ipynb | 51 +--- examples/ogbg_molpcba/ogbg_molpcba.ipynb | 75 ++---- examples/seq2seq/seq2seq.ipynb | 57 +---- examples/sst2/sst2.ipynb | 60 ++--- flax/core/flax_functional_engine.ipynb | 27 +- tests/colab_tpu_jax_version.ipynb | 16 +- 15 files changed, 242 insertions(+), 1184 deletions(-) diff --git a/docs/guides/flax_fundamentals/flax_basics.ipynb b/docs/guides/flax_fundamentals/flax_basics.ipynb index 2a21981211..b07cc41291 100644 --- a/docs/guides/flax_fundamentals/flax_basics.ipynb +++ b/docs/guides/flax_fundamentals/flax_basics.ipynb @@ -2,9 +2,7 @@ "cells": [ { "cell_type": "markdown", - "metadata": { - "id": "yf-nWLh0naJi" - }, + "metadata": {}, "source": [ "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/guides/flax_fundamentals/flax_basics.ipynb)\n", "[![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/guides/flax_fundamentals/flax_basics.ipynb)\n", @@ -22,9 +20,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "KyANAaZtbs86" - }, + "metadata": {}, "source": [ "## Setting up our environment\n", "\n", @@ -35,7 +31,6 @@ "cell_type": "code", "execution_count": 1, "metadata": { - "id": "qdrEVv9tinJn", "outputId": "e30aa464-fa52-4f35-df96-716c68a4b3ee", "tags": [ "skip-execution" @@ -61,9 +56,7 @@ { "cell_type": "code", "execution_count": 2, - "metadata": { - "id": "kN6bZDaReZO2" - }, + "metadata": {}, "outputs": [], "source": [ "import jax\n", @@ -75,9 +68,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "pCCwAbOLiscA" - }, + "metadata": {}, "source": [ "## Linear regression with Flax\n", "\n", @@ -91,9 +82,7 @@ { "cell_type": "code", "execution_count": 3, - "metadata": { - "id": "zWX2zEtphT4Y" - }, + "metadata": {}, "outputs": [], "source": [ "# We create one dense layer instance (taking 'features' parameter as input)\n", @@ -102,9 +91,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "UmzP1QoQYAAN" - }, + "metadata": {}, "source": [ "Layers (and models in general, we'll use that word from now on) are subclasses of the `linen.Module` class.\n", "\n", @@ -117,7 +104,6 @@ "cell_type": "code", "execution_count": 4, "metadata": { - "id": "K529lhzeYtl8", "outputId": "06feb9d2-db50-4f41-c169-6df4336f43a5" }, "outputs": [ @@ -155,9 +141,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "NH7Y9xMEewmO" - }, + "metadata": {}, "source": [ "*Note: JAX and Flax, like NumPy, are row-based systems, meaning that vectors are represented as row vectors and not column vectors. This can be seen in the shape of the kernel here.*\n", "\n", @@ -171,9 +155,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "M1qo9M3_naJo" - }, + "metadata": {}, "source": [ "To conduct a forward pass with the model with a given set of parameters (which are never stored with the model), we just use the `apply` method by providing it the parameters to use as well as the input:" ] @@ -182,7 +164,6 @@ "cell_type": "code", "execution_count": 6, "metadata": { - "id": "J8ietJecWiuK", "outputId": "7bbe6bb4-94d5-4574-fbb5-aa0fcd1c84ae" }, "outputs": [ @@ -205,9 +186,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "lVsjgYzuSBGL" - }, + "metadata": {}, "source": [ "### Gradient descent\n", "\n", @@ -222,7 +201,6 @@ "cell_type": "code", "execution_count": 7, "metadata": { - "id": "bFIiMnL4dl-e", "outputId": "6eae59dc-0632-4f53-eac8-c22a7c646a52" }, "outputs": [ @@ -257,9 +235,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "ZHkioicCiUbx" - }, + "metadata": {}, "source": [ "We copy the same training loop that we used in the JAX pytree linear regression example with `jax.value_and_grad()`, but here we can use `model.apply()` instead of having to define our own feed-forward function (`predict_pytree()` in the [JAX example](https://flax.readthedocs.io/en/latest/guides/jax_for_the_impatient.html#linear-regression-with-pytrees))." ] @@ -267,9 +243,7 @@ { "cell_type": "code", "execution_count": 8, - "metadata": { - "id": "JqJaVc7BeNyT" - }, + "metadata": {}, "outputs": [], "source": [ "# Same as JAX version but using model.apply().\n", @@ -285,9 +259,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "wGKru__mi15v" - }, + "metadata": {}, "source": [ "And finally perform the gradient descent." ] @@ -296,7 +268,6 @@ "cell_type": "code", "execution_count": 9, "metadata": { - "id": "ePEl1ndse0Jq", "outputId": "50d975b3-4706-4d8a-c4b8-2629ab8e3ac4" }, "outputs": [ @@ -340,9 +311,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "zqEnJ9Poyb6q" - }, + "metadata": {}, "source": [ "### Optimizing with Optax\n", "\n", @@ -372,9 +341,7 @@ { "cell_type": "code", "execution_count": 10, - "metadata": { - "id": "Ce77uDJx1bUF" - }, + "metadata": {}, "outputs": [], "source": [ "import optax\n", @@ -387,7 +354,6 @@ "cell_type": "code", "execution_count": 11, "metadata": { - "id": "PTSv0vx13xPO", "outputId": "eec0c096-1d9e-4b3c-f8e5-942ee63828ec" }, "outputs": [ @@ -420,9 +386,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "0eAPPwtpXYu7" - }, + "metadata": {}, "source": [ "### Serializing the result\n", "\n", @@ -433,7 +397,6 @@ "cell_type": "code", "execution_count": 13, "metadata": { - "id": "BiUPRU93XnAZ", "outputId": "b97e7d83-3e40-4a80-b1fe-1f6ceff30a0c" }, "outputs": [ @@ -479,9 +442,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "eielPo2KZByd" - }, + "metadata": {}, "source": [ "To load the model back, you'll need to use a template of the model parameter structure, like the one you would get from the model initialization. Here, we use the previously generated `params` as a template. Note that this will produce a new variable structure, and not mutate in-place.\n", "\n", @@ -492,7 +453,6 @@ "cell_type": "code", "execution_count": 14, "metadata": { - "id": "MOhoBDCOYYJ5", "outputId": "13acc4e1-8757-4554-e2c8-d594ba6e67dc" }, "outputs": [ @@ -531,9 +491,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "8mNu8nuOhDC5" - }, + "metadata": {}, "source": [ "## Defining your own models\n", "\n", @@ -544,9 +502,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "1sllHAdRlpmQ" - }, + "metadata": {}, "source": [ "### Module basics\n", "\n", @@ -557,7 +513,6 @@ "cell_type": "code", "execution_count": 17, "metadata": { - "id": "vbfrfbkxgPhg", "outputId": "b59c679c-d164-4fd6-92db-b50f0d310ec3" }, "outputs": [ @@ -610,9 +565,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "DDITIjXitEZl" - }, + "metadata": {}, "source": [ "As we can see, a `nn.Module` subclass is made of:\n", "\n", @@ -630,7 +583,6 @@ "cell_type": "code", "execution_count": 19, "metadata": { - "id": "DEYrVA6dnaJu", "outputId": "4af16ec5-b52a-43b0-fc47-1f8ab25e7058" }, "outputs": [ @@ -651,9 +603,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "I__UrmShnaJu" - }, + "metadata": {}, "source": [ "Since here we have a very simple model, we could have used an alternative (but equivalent) way of declaring the submodules inline in the `__call__` using the `@nn.compact` annotation like so:" ] @@ -662,7 +612,6 @@ "cell_type": "code", "execution_count": 20, "metadata": { - "id": "ZTCbdpQ4suSK", "outputId": "183a74ef-f54e-4848-99bf-fee4c174ba6d" }, "outputs": [ @@ -712,9 +661,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "es7YHjgexT-L" - }, + "metadata": {}, "source": [ "There are, however, a few differences you should be aware of between the two declaration modes:\n", "\n", @@ -725,9 +672,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "-ykceROJyp7W" - }, + "metadata": {}, "source": [ "### Module parameters\n", "\n", @@ -738,7 +683,6 @@ "cell_type": "code", "execution_count": 21, "metadata": { - "id": "wK371Pt_vVfR", "outputId": "83b5fea4-071e-4ea0-8fa8-610e69fb5fd5" }, "outputs": [ @@ -793,9 +737,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "MKyhfzVpzC94" - }, + "metadata": {}, "source": [ "Here, we see how to both declare and assign a parameter to the model using the `self.param` method. It takes as input `(name, init_fn, *init_args)` :\n", "\n", @@ -808,9 +750,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "QmSpxyqLDr58" - }, + "metadata": {}, "source": [ "### Variables and collections of variables\n", "\n", @@ -828,7 +768,6 @@ "cell_type": "code", "execution_count": 22, "metadata": { - "id": "J6_tR-nPzB1i", "outputId": "75465fd6-cdc8-497c-a3ec-7f709b5dde7a" }, "outputs": [ @@ -883,9 +822,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "5OHBbMJng3ic" - }, + "metadata": {}, "source": [ "Here, `updated_state` returns only the state variables that are being mutated by the model while applying it on data. To update the variables and get the new parameters of the model, we can use the following pattern:" ] @@ -894,7 +831,6 @@ "cell_type": "code", "execution_count": 23, "metadata": { - "id": "IbTsCAvZcdBy", "outputId": "09a8bdd1-eaf8-401a-cf7c-386a7a5aa87b" }, "outputs": [ @@ -934,9 +870,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "GuUSOSKegKIM" - }, + "metadata": {}, "source": [ "From this simplified example, you should be able to derive a full BatchNorm implementation, or any layer involving a state. To finish, let's add an optimizer to see how to play with both parameters updated by an optimizer and state variables.\n", "\n", @@ -947,7 +881,6 @@ "cell_type": "code", "execution_count": 29, "metadata": { - "id": "TUgAbUPpnaJw", "outputId": "0906fbab-b866-4956-d231-b1374415d448" }, "outputs": [ @@ -1004,9 +937,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "eWUmx5EjtWge" - }, + "metadata": {}, "source": [ "Note that the above function has a quite verbose signature and it would not actually\n", "work with `jax.jit()` because the function arguments are not \"valid JAX types\".\n", @@ -1016,9 +947,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "_GL0PsCwnaJw" - }, + "metadata": {}, "source": [ "### Exporting to Tensorflow's SavedModel with jax2tf\n", "\n", diff --git a/docs/guides/flax_fundamentals/flax_basics.md b/docs/guides/flax_fundamentals/flax_basics.md index d349efc451..437d03c631 100644 --- a/docs/guides/flax_fundamentals/flax_basics.md +++ b/docs/guides/flax_fundamentals/flax_basics.md @@ -8,8 +8,6 @@ jupytext: jupytext_version: 1.13.8 --- -+++ {"id": "yf-nWLh0naJi"} - [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/guides/flax_fundamentals/flax_basics.ipynb) [![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/guides/flax_fundamentals/flax_basics.ipynb) @@ -23,14 +21,13 @@ This notebook will walk you through the following workflow: * Serialization of parameters and other objects. * Creating your own models and managing state. -+++ {"id": "KyANAaZtbs86"} ++++ ## Setting up our environment Here we provide the code needed to set up the environment for our notebook. ```{code-cell} -:id: qdrEVv9tinJn :outputId: e30aa464-fa52-4f35-df96-716c68a4b3ee :tags: [skip-execution] @@ -41,8 +38,6 @@ Here we provide the code needed to set up the environment for our notebook. ``` ```{code-cell} -:id: kN6bZDaReZO2 - import jax from typing import Any, Callable, Sequence from jax import random, numpy as jnp @@ -50,8 +45,6 @@ import flax from flax import linen as nn ``` -+++ {"id": "pCCwAbOLiscA"} - ## Linear regression with Flax In the previous *JAX for the impatient* notebook, we finished up with a linear regression example. As we know, linear regression can also be written as a single dense neural network layer, which we will show in the following so that we can compare how it's done. @@ -61,14 +54,10 @@ A dense layer is a layer that has a kernel parameter $W\in\mathcal{M}_{m,n}(\mat This dense layer is already provided by Flax in the `flax.linen` module (here imported as `nn`). ```{code-cell} -:id: zWX2zEtphT4Y - # We create one dense layer instance (taking 'features' parameter as input) model = nn.Dense(features=5) ``` -+++ {"id": "UmzP1QoQYAAN"} - Layers (and models in general, we'll use that word from now on) are subclasses of the `linen.Module` class. ### Model parameters & initialization @@ -76,7 +65,6 @@ Layers (and models in general, we'll use that word from now on) are subclasses o Parameters are not stored with the models themselves. You need to initialize parameters by calling the `init` function, using a PRNGKey and dummy input data. ```{code-cell} -:id: K529lhzeYtl8 :outputId: 06feb9d2-db50-4f41-c169-6df4336f43a5 key1, key2 = random.split(random.key(0)) @@ -85,8 +73,6 @@ params = model.init(key2, x) # Initialization call jax.tree_util.tree_map(lambda x: x.shape, params) # Checking output shapes ``` -+++ {"id": "NH7Y9xMEewmO"} - *Note: JAX and Flax, like NumPy, are row-based systems, meaning that vectors are represented as row vectors and not column vectors. This can be seen in the shape of the kernel here.* The result is what we expect: bias and kernel parameters of the correct size. Under the hood: @@ -96,19 +82,16 @@ The result is what we expect: bias and kernel parameters of the correct size. Un * Initialization functions are called to generate the initial set of parameters that the model will use. Those are functions that take as arguments `(PRNG Key, shape, dtype)` and return an Array of shape `shape`. * The init function returns the initialized set of parameters (you can also get the output of the forward pass on the dummy input with the same syntax by using the `init_with_output` method instead of `init`. -+++ {"id": "M1qo9M3_naJo"} ++++ To conduct a forward pass with the model with a given set of parameters (which are never stored with the model), we just use the `apply` method by providing it the parameters to use as well as the input: ```{code-cell} -:id: J8ietJecWiuK :outputId: 7bbe6bb4-94d5-4574-fbb5-aa0fcd1c84ae model.apply(params, x) ``` -+++ {"id": "lVsjgYzuSBGL"} - ### Gradient descent If you jumped here directly without going through the JAX part, here is the linear regression formulation we're going to use: from a set of data points $\{(x_i,y_i), i\in \{1,\ldots, k\}, x_i\in\mathbb{R}^n,y_i\in\mathbb{R}^m\}$, we try to find a set of parameters $W\in \mathcal{M}_{m,n}(\mathbb{R}), b\in\mathbb{R}^m$ such that the function $f_{W,b}(x)=Wx+b$ minimizes the mean squared error: @@ -118,7 +101,6 @@ $$\mathcal{L}(W,b)\rightarrow\frac{1}{k}\sum_{i=1}^{k} \frac{1}{2}\|y_i-f_{W,b}( Here, we see that the tuple $(W,b)$ matches the parameters of the Dense layer. We'll perform gradient descent using those. Let's first generate the fake data we'll use. The data is exactly the same as in the JAX part's linear regression pytree example. ```{code-cell} -:id: bFIiMnL4dl-e :outputId: 6eae59dc-0632-4f53-eac8-c22a7c646a52 # Set problem dimensions. @@ -141,13 +123,9 @@ y_samples = jnp.dot(x_samples, W) + b + 0.1 * random.normal(key_noise,(n_samples print('x shape:', x_samples.shape, '; y shape:', y_samples.shape) ``` -+++ {"id": "ZHkioicCiUbx"} - We copy the same training loop that we used in the JAX pytree linear regression example with `jax.value_and_grad()`, but here we can use `model.apply()` instead of having to define our own feed-forward function (`predict_pytree()` in the [JAX example](https://flax.readthedocs.io/en/latest/guides/jax_for_the_impatient.html#linear-regression-with-pytrees)). ```{code-cell} -:id: JqJaVc7BeNyT - # Same as JAX version but using model.apply(). @jax.jit def mse(params, x_batched, y_batched): @@ -159,12 +137,9 @@ def mse(params, x_batched, y_batched): return jnp.mean(jax.vmap(squared_error)(x_batched,y_batched), axis=0) ``` -+++ {"id": "wGKru__mi15v"} - And finally perform the gradient descent. ```{code-cell} -:id: ePEl1ndse0Jq :outputId: 50d975b3-4706-4d8a-c4b8-2629ab8e3ac4 learning_rate = 0.3 # Gradient step size. @@ -185,8 +160,6 @@ for i in range(101): print(f'Loss step {i}: ', loss_val) ``` -+++ {"id": "zqEnJ9Poyb6q"} - ### Optimizing with Optax Flax used to use its own `flax.optim` package for optimization, but with @@ -212,8 +185,6 @@ to the [official documentation](https://optax.readthedocs.io/en/latest/). ```{code-cell} -:id: Ce77uDJx1bUF - import optax tx = optax.adam(learning_rate=learning_rate) opt_state = tx.init(params) @@ -221,7 +192,6 @@ loss_grad_fn = jax.value_and_grad(mse) ``` ```{code-cell} -:id: PTSv0vx13xPO :outputId: eec0c096-1d9e-4b3c-f8e5-942ee63828ec for i in range(101): @@ -232,14 +202,11 @@ for i in range(101): print('Loss step {}: '.format(i), loss_val) ``` -+++ {"id": "0eAPPwtpXYu7"} - ### Serializing the result Now that we're happy with the result of our training, we might want to save the model parameters to load them back later. Flax provides a serialization package to enable you to do that. ```{code-cell} -:id: BiUPRU93XnAZ :outputId: b97e7d83-3e40-4a80-b1fe-1f6ceff30a0c from flax import serialization @@ -251,35 +218,29 @@ print('Bytes output') print(bytes_output) ``` -+++ {"id": "eielPo2KZByd"} - To load the model back, you'll need to use a template of the model parameter structure, like the one you would get from the model initialization. Here, we use the previously generated `params` as a template. Note that this will produce a new variable structure, and not mutate in-place. *The point of enforcing structure through template is to avoid users issues downstream, so you need to first have the right model that generates the parameters structure.* ```{code-cell} -:id: MOhoBDCOYYJ5 :outputId: 13acc4e1-8757-4554-e2c8-d594ba6e67dc serialization.from_bytes(params, bytes_output) ``` -+++ {"id": "8mNu8nuOhDC5"} - ## Defining your own models Flax allows you to define your own models, which should be a bit more complicated than a linear regression. In this section, we'll show you how to build simple models. To do so, you'll need to create subclasses of the base `nn.Module` class. *Keep in mind that we imported* `linen as nn` *and this only works with the new linen API* -+++ {"id": "1sllHAdRlpmQ"} ++++ ### Module basics The base abstraction for models is the `nn.Module` class, and every type of predefined layers in Flax (like the previous `Dense`) is a subclass of `nn.Module`. Let's take a look and start by defining a simple but custom multi-layer perceptron i.e. a sequence of Dense layers interleaved with calls to a non-linear activation function. ```{code-cell} -:id: vbfrfbkxgPhg :outputId: b59c679c-d164-4fd6-92db-b50f0d310ec3 class ExplicitMLP(nn.Module): @@ -310,8 +271,6 @@ print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax. print('output:\n', y) ``` -+++ {"id": "DDITIjXitEZl"} - As we can see, a `nn.Module` subclass is made of: * A collection of data fields (`nn.Module` are Python dataclasses) - here we only have the `features` field of type `Sequence[int]`. @@ -324,7 +283,6 @@ As we can see, a `nn.Module` subclass is made of: Since the module structure and its parameters are not tied to each other, you can't directly call `model(x)` on a given input as it will return an error. The `__call__` function is being wrapped up in the `apply` one, which is the one to call on an input: ```{code-cell} -:id: DEYrVA6dnaJu :outputId: 4af16ec5-b52a-43b0-fc47-1f8ab25e7058 try: @@ -333,12 +291,9 @@ except AttributeError as e: print(e) ``` -+++ {"id": "I__UrmShnaJu"} - Since here we have a very simple model, we could have used an alternative (but equivalent) way of declaring the submodules inline in the `__call__` using the `@nn.compact` annotation like so: ```{code-cell} -:id: ZTCbdpQ4suSK :outputId: 183a74ef-f54e-4848-99bf-fee4c174ba6d class SimpleMLP(nn.Module): @@ -366,22 +321,19 @@ print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax. print('output:\n', y) ``` -+++ {"id": "es7YHjgexT-L"} - There are, however, a few differences you should be aware of between the two declaration modes: * In `setup`, you are able to name some sublayers and keep them around for further use (e.g. encoder/decoder methods in autoencoders). * If you want to have multiple methods, then you **need** to declare the module using `setup`, as the `@nn.compact` annotation only allows one method to be annotated. * The last initialization will be handled differently. See these notes for more details (TODO: add notes link). -+++ {"id": "-ykceROJyp7W"} ++++ ### Module parameters In the previous MLP example, we relied only on predefined layers and operators (`Dense`, `relu`). Let's imagine that you didn't have a Dense layer provided by Flax and you wanted to write it on your own. Here is what it would look like using the `@nn.compact` way to declare a new modules: ```{code-cell} -:id: wK371Pt_vVfR :outputId: 83b5fea4-071e-4ea0-8fa8-610e69fb5fd5 class SimpleDense(nn.Module): @@ -410,8 +362,6 @@ print('initialized parameters:\n', params) print('output:\n', y) ``` -+++ {"id": "MKyhfzVpzC94"} - Here, we see how to both declare and assign a parameter to the model using the `self.param` method. It takes as input `(name, init_fn, *init_args)` : * `name` is simply the name of the parameter that will end up in the parameter structure. @@ -420,7 +370,7 @@ Here, we see how to both declare and assign a parameter to the model using the ` Such params can also be declared in the `setup` method; it won't be able to use shape inference because Flax is using lazy initialization at the first call site. -+++ {"id": "QmSpxyqLDr58"} ++++ ### Variables and collections of variables @@ -434,7 +384,6 @@ However this is not enough to cover everything that we would need for machine le For demonstration purposes, we'll implement a simplified but similar mechanism to batch normalization: we'll store running averages and subtract those to the input at training time. For proper batchnorm, you should use (and look at) the implementation [here](https://github.com/google/flax/blob/main/flax/linen/normalization.py). ```{code-cell} -:id: J6_tR-nPzB1i :outputId: 75465fd6-cdc8-497c-a3ec-7f709b5dde7a class BiasAdderWithRunningMean(nn.Module): @@ -463,12 +412,9 @@ y, updated_state = model.apply(variables, x, mutable=['batch_stats']) print('updated state:\n', updated_state) ``` -+++ {"id": "5OHBbMJng3ic"} - Here, `updated_state` returns only the state variables that are being mutated by the model while applying it on data. To update the variables and get the new parameters of the model, we can use the following pattern: ```{code-cell} -:id: IbTsCAvZcdBy :outputId: 09a8bdd1-eaf8-401a-cf7c-386a7a5aa87b for val in [1.0, 2.0, 3.0]: @@ -479,14 +425,11 @@ for val in [1.0, 2.0, 3.0]: print('updated state:\n', updated_state) # Shows only the mutable part ``` -+++ {"id": "GuUSOSKegKIM"} - From this simplified example, you should be able to derive a full BatchNorm implementation, or any layer involving a state. To finish, let's add an optimizer to see how to play with both parameters updated by an optimizer and state variables. *This example isn't doing anything and is only for demonstration purposes.* ```{code-cell} -:id: TUgAbUPpnaJw :outputId: 0906fbab-b866-4956-d231-b1374415d448 from functools import partial @@ -517,14 +460,12 @@ for _ in range(3): print('Updated state: ', state) ``` -+++ {"id": "eWUmx5EjtWge"} - Note that the above function has a quite verbose signature and it would not actually work with `jax.jit()` because the function arguments are not "valid JAX types". Flax provides a handy wrapper - `TrainState` - that simplifies the above code. Check out [`flax.training.train_state.TrainState`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState) to learn more. -+++ {"id": "_GL0PsCwnaJw"} ++++ ### Exporting to Tensorflow's SavedModel with jax2tf diff --git a/docs/guides/parallel_training/flax_on_pjit.ipynb b/docs/guides/parallel_training/flax_on_pjit.ipynb index b93547b022..59adbaf098 100644 --- a/docs/guides/parallel_training/flax_on_pjit.ipynb +++ b/docs/guides/parallel_training/flax_on_pjit.ipynb @@ -2,9 +2,7 @@ "cells": [ { "cell_type": "markdown", - "metadata": { - "id": "2a9f78765c0c" - }, + "metadata": {}, "source": [ "# Scale up Flax Modules on multiple devices\n", "\n", @@ -14,9 +12,7 @@ { "attachments": {}, "cell_type": "markdown", - "metadata": { - "id": "b1e0e5fc8bc1" - }, + "metadata": {}, "source": [ "## Flax and `jax.jit` scaled up\n", "\n", @@ -34,9 +30,7 @@ { "attachments": {}, "cell_type": "markdown", - "metadata": { - "id": "a9601432b448" - }, + "metadata": {}, "source": [ "## Setup\n", "\n", @@ -49,7 +43,6 @@ "cell_type": "code", "execution_count": 1, "metadata": { - "id": "867203db3bef", "tags": [ "skip-execution" ] @@ -63,9 +56,7 @@ { "cell_type": "code", "execution_count": 2, - "metadata": { - "id": "f8f42d1174e5" - }, + "metadata": {}, "outputs": [], "source": [ "import os\n", @@ -75,9 +66,7 @@ { "cell_type": "code", "execution_count": 42, - "metadata": { - "id": "b8da40732f0b" - }, + "metadata": {}, "outputs": [], "source": [ "import functools\n", @@ -98,9 +87,7 @@ { "cell_type": "code", "execution_count": 4, - "metadata": { - "id": "bcc30de1d6eb" - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -117,9 +104,7 @@ { "attachments": {}, "cell_type": "markdown", - "metadata": { - "id": "c0d280def897" - }, + "metadata": {}, "source": [ "The code below shows how to import and set up the JAX-level device API, following JAX's [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) guide:\n", "\n", @@ -135,9 +120,7 @@ { "cell_type": "code", "execution_count": 5, - "metadata": { - "id": "684fe9fe13a0" - }, + "metadata": {}, "outputs": [], "source": [ "from jax.sharding import Mesh, PartitionSpec, NamedSharding\n", @@ -148,9 +131,7 @@ { "cell_type": "code", "execution_count": 6, - "metadata": { - "id": "4589d7a6d4bb" - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -178,9 +159,7 @@ { "attachments": {}, "cell_type": "markdown", - "metadata": { - "id": "307d39db6d94" - }, + "metadata": {}, "source": [ "## Define a layer\n", "\n", @@ -198,9 +177,7 @@ { "cell_type": "code", "execution_count": 43, - "metadata": { - "id": "b74c049968dc" - }, + "metadata": {}, "outputs": [], "source": [ "class DotReluDot(nn.Module):\n", @@ -234,9 +211,7 @@ { "attachments": {}, "cell_type": "markdown", - "metadata": { - "id": "cbac5321c08e" - }, + "metadata": {}, "source": [ "Note that device axis names like `'data'`, `'model'` or `None` are passed into both [`flax.linen.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.with_partitioning.html) and [`jax.lax.with_sharding_constraint`](https://github.com/google/jax/blob/main/jax/_src/pjit.py#L1516) API calls. This refers to how each dimension of this data should be sharded — either across one of the device mesh dimensions, or not sharded at all.\n", "\n", @@ -256,9 +231,7 @@ { "attachments": {}, "cell_type": "markdown", - "metadata": { - "id": "b8389c11af79" - }, + "metadata": {}, "source": [ "## Define a model with `flax.linen.scan` lifted transformation\n", "\n", @@ -277,9 +250,7 @@ { "cell_type": "code", "execution_count": 44, - "metadata": { - "id": "a0ea0dcccbc3" - }, + "metadata": {}, "outputs": [], "source": [ "class MLP(nn.Module):\n", @@ -303,9 +274,7 @@ { "attachments": {}, "cell_type": "markdown", - "metadata": { - "id": "44395b62561d" - }, + "metadata": {}, "source": [ "Now, create a `model` instance, and a sample input `x`." ] @@ -313,9 +282,7 @@ { "cell_type": "code", "execution_count": 45, - "metadata": { - "id": "5686299b4839" - }, + "metadata": {}, "outputs": [], "source": [ "# MLP hyperparameters.\n", @@ -334,9 +301,7 @@ { "attachments": {}, "cell_type": "markdown", - "metadata": { - "id": "5b3abfef359d" - }, + "metadata": {}, "source": [ "## Specify sharding\n", "\n", @@ -350,9 +315,7 @@ { "cell_type": "code", "execution_count": 46, - "metadata": { - "id": "8b913a2e57d3" - }, + "metadata": {}, "outputs": [ { "data": { @@ -397,9 +360,7 @@ { "attachments": {}, "cell_type": "markdown", - "metadata": { - "id": "06d134795ae1" - }, + "metadata": {}, "source": [ "### The output's sharding\n", "\n", @@ -416,9 +377,7 @@ { "cell_type": "code", "execution_count": 47, - "metadata": { - "id": "19094ec63385" - }, + "metadata": {}, "outputs": [], "source": [ "def init_fn(k, x, model, optimizer):\n", @@ -433,9 +392,7 @@ { "cell_type": "code", "execution_count": 48, - "metadata": { - "id": "e49264a3c78e" - }, + "metadata": {}, "outputs": [ { "data": { @@ -542,9 +499,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "2ec24614050b" - }, + "metadata": {}, "source": [ "## Compile the code\n", "\n", @@ -556,9 +511,7 @@ { "cell_type": "code", "execution_count": 49, - "metadata": { - "id": "5b6e699df733" - }, + "metadata": {}, "outputs": [ { "data": { @@ -638,9 +591,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "8f74b009f11f" - }, + "metadata": {}, "source": [ "## Inspect the Module output\n", "\n", @@ -652,9 +603,7 @@ { "cell_type": "code", "execution_count": 14, - "metadata": { - "id": "19243982c892" - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -676,9 +625,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "2beee7d27bdb" - }, + "metadata": {}, "source": [ "You can also check the underlying [`jax.sharding`](https://jax.readthedocs.io/en/latest/jax.sharding.html) of each parameter, which is now more internal than `NamedSharding`. Note that numbers like `initialized_state.step` are replicated across all devices." ] @@ -686,9 +633,7 @@ { "cell_type": "code", "execution_count": 15, - "metadata": { - "id": "2067c419a826" - }, + "metadata": {}, "outputs": [ { "data": { @@ -708,9 +653,7 @@ { "cell_type": "code", "execution_count": 16, - "metadata": { - "id": "d7cf0baa334b" - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -737,9 +680,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "273547d3ab89" - }, + "metadata": {}, "source": [ "You can use [`jax.tree_map`](https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.tree_map.html) to perform mass computation on a dict of boxed params, in the same way as on a dict of JAX arrays." ] @@ -747,9 +688,7 @@ { "cell_type": "code", "execution_count": 17, - "metadata": { - "id": "29b3dae156a2" - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -779,9 +718,7 @@ { "attachments": {}, "cell_type": "markdown", - "metadata": { - "id": "f7e1ccb14c6b" - }, + "metadata": {}, "source": [ "## Compile the train step and inference \n", "\n", @@ -791,9 +728,7 @@ { "cell_type": "code", "execution_count": 18, - "metadata": { - "id": "4e3cc300cfee" - }, + "metadata": {}, "outputs": [], "source": [ "@functools.partial(jax.jit, in_shardings=(state_sharding, x_sharding), \n", @@ -815,9 +750,7 @@ { "cell_type": "code", "execution_count": 19, - "metadata": { - "id": "91c6c2662c12" - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -906,9 +839,7 @@ { "attachments": {}, "cell_type": "markdown", - "metadata": { - "id": "2bae79e2e71b" - }, + "metadata": {}, "source": [ "Then, create a compiled inference step. Note that the output is also sharded along `(data, None)`." ] @@ -916,9 +847,7 @@ { "cell_type": "code", "execution_count": 20, - "metadata": { - "id": "c9264a48b9ee" - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -979,9 +908,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "7daa9e6e6eb4" - }, + "metadata": {}, "source": [ "## Profiling\n", "\n", @@ -991,9 +918,7 @@ { "cell_type": "code", "execution_count": 21, - "metadata": { - "id": "a68d7cb2eb89" - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -1017,9 +942,7 @@ { "attachments": {}, "cell_type": "markdown", - "metadata": { - "id": "51420b514d53" - }, + "metadata": {}, "source": [ "## Logical axis annotation\n", "\n", @@ -1035,9 +958,7 @@ { "cell_type": "code", "execution_count": 50, - "metadata": { - "id": "a26f85a9e772" - }, + "metadata": {}, "outputs": [], "source": [ "class LogicalDotReluDot(nn.Module):\n", @@ -1085,9 +1006,7 @@ { "attachments": {}, "cell_type": "markdown", - "metadata": { - "id": "0de93ec6cbd6" - }, + "metadata": {}, "source": [ "Now, initiate a model and try to figure out what sharding its `state` should have.\n", "\n", @@ -1099,9 +1018,7 @@ { "cell_type": "code", "execution_count": 51, - "metadata": { - "id": "14db7a1e30fd" - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -1133,9 +1050,7 @@ { "attachments": {}, "cell_type": "markdown", - "metadata": { - "id": "58475fffb2de" - }, + "metadata": {}, "source": [ "You can verify that the `logical_state_spec` here has the same content as `state_spec` in the previous (\"non-logical\") example. This allows you to `jax.jit` your Module's [`flax.linen.Module.init`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.init) and [`flax.linen.Module.apply`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.apply) the same way in the above above." ] @@ -1143,9 +1058,7 @@ { "cell_type": "code", "execution_count": 52, - "metadata": { - "id": "589ff774bb4c" - }, + "metadata": {}, "outputs": [ { "data": { @@ -1165,9 +1078,7 @@ { "cell_type": "code", "execution_count": 53, - "metadata": { - "id": "77e07a0ab309" - }, + "metadata": {}, "outputs": [], "source": [ "logical_jit_init_fn = jax.jit(init_fn, static_argnums=(2, 3),\n", @@ -1180,9 +1091,7 @@ { "cell_type": "code", "execution_count": 54, - "metadata": { - "id": "fb53bc20e0f9" - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -1271,9 +1180,7 @@ { "attachments": {}, "cell_type": "markdown", - "metadata": { - "id": "ae1754a3031d" - }, + "metadata": {}, "source": [ "## When to use device axis / logical axis\n", "\n", @@ -1289,9 +1196,7 @@ { "attachments": {}, "cell_type": "markdown", - "metadata": { - "id": "576bdd5cd782" - }, + "metadata": {}, "source": [ "## Save the data\n", "\n", diff --git a/docs/guides/parallel_training/flax_on_pjit.md b/docs/guides/parallel_training/flax_on_pjit.md index 486871642c..fe2e75c517 100644 --- a/docs/guides/parallel_training/flax_on_pjit.md +++ b/docs/guides/parallel_training/flax_on_pjit.md @@ -8,13 +8,11 @@ jupytext: jupytext_version: 1.13.8 --- -+++ {"id": "2a9f78765c0c"} - # Scale up Flax Modules on multiple devices This guide shows how to scale up [Flax Modules](https://flax.readthedocs.io/en/latest/developer_notes/module_lifecycle.html) on multiple devices and hosts using [`jax.jit`](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) (formerly [`experimental.pjit`](https://jax.readthedocs.io/en/latest/jax.experimental.pjit.html#module-jax.experimental.pjit)) and [`flax.linen`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/index.html). -+++ {"id": "b1e0e5fc8bc1"} ++++ ## Flax and `jax.jit` scaled up @@ -28,7 +26,7 @@ Flax provides several functionalities that can help you use auto-SPMD on [Flax M You can learn more about `jax.jit` APIs for scaling up in [JAX in multi-process environments](https://jax.readthedocs.io/en/latest/multi_process.html) and [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) on JAX's documentation site. -+++ {"id": "a9601432b448"} ++++ ## Setup @@ -37,7 +35,6 @@ Import some necessary dependencies. **Note:** This guide uses the `--xla_force_host_platform_device_count=8` flag to emulate multiple devices in a CPU environment in a Google Colab/Jupyter Notebook. You don't need this if you are already using a multi-device TPU environment. ```{code-cell} -:id: 867203db3bef :tags: [skip-execution] # Once Flax v0.6.10 is released, there is no need to do this. @@ -45,15 +42,11 @@ Import some necessary dependencies. ``` ```{code-cell} -:id: f8f42d1174e5 - import os os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' ``` ```{code-cell} -:id: b8da40732f0b - import functools from typing import Optional, Callable @@ -70,13 +63,9 @@ import optax # Optax for common losses and optimizers. ``` ```{code-cell} -:id: bcc30de1d6eb - print(f'We have 8 fake JAX devices now: {jax.devices()}') ``` -+++ {"id": "c0d280def897"} - The code below shows how to import and set up the JAX-level device API, following JAX's [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) guide: 1. Start a 2x4 device `mesh` (8 devices) using JAX's `mesh_utils.create_device_mesh`. This layout is the same as the one of a [TPU v3-8](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#single_tpu_board). @@ -88,16 +77,12 @@ The code below shows how to import and set up the JAX-level device API, followin 3. Make a simple utility function `mesh_sharding` for generating a sharding object from the mesh and any layout. ```{code-cell} -:id: 684fe9fe13a0 - from jax.sharding import Mesh, PartitionSpec, NamedSharding from jax.lax import with_sharding_constraint from jax.experimental import mesh_utils ``` ```{code-cell} -:id: 4589d7a6d4bb - # Create a mesh and annotate each axis with a name. device_mesh = mesh_utils.create_device_mesh((2, 4)) print(device_mesh) @@ -109,8 +94,6 @@ def mesh_sharding(pspec: PartitionSpec) -> NamedSharding: return NamedSharding(mesh, pspec) ``` -+++ {"id": "307d39db6d94"} - ## Define a layer Before defining a simple model, create an example layer called `DotReluDot` (by subclassing `flax.linen.Module`). The layer creates two parameters `W1` and `W2` for dot product multiplication, and uses the `jax.nn.relu` (ReLU) activation function in-between. @@ -124,8 +107,6 @@ To shard the parameters efficiently, apply the following APIs to annotate the pa * This step is optional, but can sometimes help auto-SPMD to partition efficiently. In the example below, the call is not required, because XLA will figure out the same sharding layout for `y` and `z` regardless. ```{code-cell} -:id: b74c049968dc - class DotReluDot(nn.Module): depth: int dense_init: Callable = nn.initializers.xavier_normal() @@ -154,8 +135,6 @@ class DotReluDot(nn.Module): return z, None ``` -+++ {"id": "cbac5321c08e"} - Note that device axis names like `'data'`, `'model'` or `None` are passed into both [`flax.linen.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.with_partitioning.html) and [`jax.lax.with_sharding_constraint`](https://github.com/google/jax/blob/main/jax/_src/pjit.py#L1516) API calls. This refers to how each dimension of this data should be sharded — either across one of the device mesh dimensions, or not sharded at all. For example: @@ -170,7 +149,7 @@ For example: * The first dimension — the batch dimension — will be sharded over the `'data'` axis. This means half of the batch will be processed on devices `0-3` (first four devices), and another half on devices `4-7` (the remaining four devices). * The second dimension — the data depth dimension — will be replicated across all devices. -+++ {"id": "b8389c11af79"} ++++ ## Define a model with `flax.linen.scan` lifted transformation @@ -186,8 +165,6 @@ The code below shows how to apply both methods, and default with the for-loop, s The `flax.linen.scan` code is just to show that this API works with [Flax lifted transforms](https://flax.readthedocs.io/en/latest/developer_notes/lift.html#supported-transformations). ```{code-cell} -:id: a0ea0dcccbc3 - class MLP(nn.Module): num_layers: int depth: int @@ -206,13 +183,9 @@ class MLP(nn.Module): return x ``` -+++ {"id": "44395b62561d"} - Now, create a `model` instance, and a sample input `x`. ```{code-cell} -:id: 5686299b4839 - # MLP hyperparameters. BATCH, LAYERS, DEPTH, USE_SCAN = 8, 4, 1024, False # Create fake inputs. @@ -226,8 +199,6 @@ optimizer = optax.adam(learning_rate=0.001) model = MLP(LAYERS, DEPTH, USE_SCAN) ``` -+++ {"id": "5b3abfef359d"} - ## Specify sharding Next, you need to tell `jax.jit` how to shard our data across devices. @@ -237,15 +208,11 @@ Next, you need to tell `jax.jit` how to shard our data across devices. For data parallelism, you can shard the batched _input_ `x` across the `data` axis by denoting the batch axis as `'data'`. Then, use [`jax.device_put`](https://jax.readthedocs.io/en/latest/_autosummary/jax.device_put.html) to place it onto the correct `device`s. ```{code-cell} -:id: 8b913a2e57d3 - x_sharding = mesh_sharding(PartitionSpec('data', None)) # dimensions: (batch, length) x = jax.device_put(x, x_sharding) jax.debug.visualize_array_sharding(x) ``` -+++ {"id": "06d134795ae1"} - ### The output's sharding You need to compile `model.init()` (that is, [`flax.linen.Module.init()`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.init)), and its output as a pytree of parameters. Additionally, you may sometimes need wrap it with a [`flax.training.train_state`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState) to track other variables, such as optimizer states, and that would make the output an even more complex pytree. @@ -258,8 +225,6 @@ To achieve this, luckily, you don't have to hardcode the output's sharding by ha * This step utilizes the [`flax.linen.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.with_partitioning.html) annotations in the earlier definition to generate the correct sharding for the parameters. ```{code-cell} -:id: 19094ec63385 - def init_fn(k, x, model, optimizer): variables = model.init(k, x) # Initialize the model. state = train_state.TrainState.create( # Create a `TrainState`. @@ -270,8 +235,6 @@ def init_fn(k, x, model, optimizer): ``` ```{code-cell} -:id: e49264a3c78e - # Create an abstract closure to wrap the function before feeding it in # because `jax.eval_shape` only takes pytrees as arguments. abstract_variables = jax.eval_shape( @@ -283,8 +246,6 @@ state_sharding = nn.get_sharding(abstract_variables, mesh) state_sharding ``` -+++ {"id": "2ec24614050b"} - ## Compile the code Now you can apply [`jax.jit`](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) to your `init_fn`, but with two extra arguments: `in_shardings` and `out_shardings`. @@ -292,8 +253,6 @@ Now you can apply [`jax.jit`](https://jax.readthedocs.io/en/latest/jax-101/02-ji Run it to get the `initialized_state`, in which parameters are sharded exactly as instructed: ```{code-cell} -:id: 5b6e699df733 - jit_init_fn = jax.jit(init_fn, static_argnums=(2, 3), in_shardings=(mesh_sharding(None), x_sharding), # PRNG key and x out_shardings=state_sharding) @@ -306,8 +265,6 @@ jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Den jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value) ``` -+++ {"id": "8f74b009f11f"} - ## Inspect the Module output Note that in the output of `initialized_state`, the `params` `W1` and `W2` are of type [`flax.linen.Partitioned`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.Partitioned.html). This is a wrapper around the actual `jax.Array` that allows Flax to record the axis names associated with it. @@ -315,38 +272,26 @@ Note that in the output of `initialized_state`, the `params` `W1` and `W2` are o You can access the raw `jax.Array` by adding `.value` when outside `jit`, or by `.unbox()` when inside. ```{code-cell} -:id: 19243982c892 - print(type(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'])) print(type(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value)) print(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].names) print(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value.shape) ``` -+++ {"id": "2beee7d27bdb"} - You can also check the underlying [`jax.sharding`](https://jax.readthedocs.io/en/latest/jax.sharding.html) of each parameter, which is now more internal than `NamedSharding`. Note that numbers like `initialized_state.step` are replicated across all devices. ```{code-cell} -:id: 2067c419a826 - initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value.sharding ``` ```{code-cell} -:id: d7cf0baa334b - print(initialized_state.step) initialized_state.step.sharding ``` -+++ {"id": "273547d3ab89"} - You can use [`jax.tree_map`](https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.tree_map.html) to perform mass computation on a dict of boxed params, in the same way as on a dict of JAX arrays. ```{code-cell} -:id: 29b3dae156a2 - diff = jax.tree_map( lambda a, b: a - b, initialized_state.params['DotReluDot_0'], initialized_state.params['DotReluDot_0']) @@ -356,15 +301,11 @@ print(type(diff_array)) print(diff_array.shape) ``` -+++ {"id": "f7e1ccb14c6b"} - ## Compile the train step and inference Create a `jit`ted training step as follows: ```{code-cell} -:id: 4e3cc300cfee - @functools.partial(jax.jit, in_shardings=(state_sharding, x_sharding), out_shardings=state_sharding) def train_step(state, x): @@ -382,21 +323,15 @@ with mesh: ``` ```{code-cell} -:id: 91c6c2662c12 - print(f'Sharding of Weight 1:') jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value) print(f'Sharding of Weight 2:') jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value) ``` -+++ {"id": "2bae79e2e71b"} - Then, create a compiled inference step. Note that the output is also sharded along `(data, None)`. ```{code-cell} -:id: c9264a48b9ee - @functools.partial(jax.jit, in_shardings=(state_sharding, x_sharding), out_shardings=x_sharding) def apply_fn(state, x): @@ -410,15 +345,11 @@ print(y.shape) jax.debug.visualize_array_sharding(y) ``` -+++ {"id": "7daa9e6e6eb4"} - ## Profiling If you are running on a TPU pod or a pod slice, you can use a custom `block_all` utility function, as defined below, to measure the performance: ```{code-cell} -:id: a68d7cb2eb89 - %%timeit def block_all(xs): @@ -429,8 +360,6 @@ with mesh: new_state = block_all(train_step(initialized_state, x)) ``` -+++ {"id": "51420b514d53"} - ## Logical axis annotation JAX's automatic SPMD encourages users to explore different sharding layouts to find the optimal one. To this end, in Flax you actually can annotate the dimensions of any data with more descriptive axis names (not just device mesh axis names like `'data'` and `'model'`). @@ -442,8 +371,6 @@ The `LogicalDotReluDot` and `LogicalMLP` Module definition below are similar to 2. [`flax.linen.with_logical_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.with_logical_partitioning.html) replaces `flax.linen.with_partitioning`; and [`flax.linen.with_logical_constraint`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.with_logical_constraint.html#flax-linen-with-logical-constraint) replaces `jax.lax.with_sharding_constraint`, to recognize the logical axis names. ```{code-cell} -:id: a26f85a9e772 - class LogicalDotReluDot(nn.Module): depth: int dense_init: Callable = nn.initializers.xavier_normal() @@ -486,8 +413,6 @@ class LogicalMLP(nn.Module): return x ``` -+++ {"id": "0de93ec6cbd6"} - Now, initiate a model and try to figure out what sharding its `state` should have. To allow the device mesh to take your model correctly, you need to decide which of these logical axis names are mapped to the device axis `'data'` or `'model'`. This rule is a list of (`logical_axis_name`, `device_axis_name`) tuples, and [`flax.linen.logical_to_mesh_sharding`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.logical_to_mesh_sharding.html#flax-linen-logical-to-mesh-sharding) will convert them to the kind of sharding that the device mesh can understand. @@ -495,8 +420,6 @@ To allow the device mesh to take your model correctly, you need to decide which This allows you to change the rules and try out new partition layouts without modifying the model definition. ```{code-cell} -:id: 14db7a1e30fd - # Unspecified rule means unsharded by default, so no need to specify `('embed', None)` and `('layer', None)`. rules = (('batch', 'data'), ('hidden', 'model')) @@ -514,19 +437,13 @@ print('sharding annotations are mesh-specific: ', logical_state_sharding.params['LogicalDotReluDot_0']['Dense_0']['kernel'].spec) ``` -+++ {"id": "58475fffb2de"} - You can verify that the `logical_state_spec` here has the same content as `state_spec` in the previous ("non-logical") example. This allows you to `jax.jit` your Module's [`flax.linen.Module.init`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.init) and [`flax.linen.Module.apply`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.apply) the same way in the above above. ```{code-cell} -:id: 589ff774bb4c - state_sharding.params['DotReluDot_0'] == logical_state_sharding.params['LogicalDotReluDot_0'] ``` ```{code-cell} -:id: 77e07a0ab309 - logical_jit_init_fn = jax.jit(init_fn, static_argnums=(2, 3), in_shardings=(mesh_sharding(None), x_sharding), # PRNG key and x out_shardings=logical_state_sharding) @@ -535,16 +452,12 @@ logical_initialized_state = logical_jit_init_fn(k, x, logical_model, optimizer) ``` ```{code-cell} -:id: fb53bc20e0f9 - print(f'Sharding of Weight 1:') jax.debug.visualize_array_sharding(logical_initialized_state.params['LogicalDotReluDot_0']['Dense_0']['kernel'].value) print(f'Sharding of Weight 2:') jax.debug.visualize_array_sharding(logical_initialized_state.params['LogicalDotReluDot_0']['W2'].value) ``` -+++ {"id": "ae1754a3031d"} - ## When to use device axis / logical axis Choosing when to use a device or logical axis depends on how much you want to control the partitioning of your model: @@ -555,7 +468,7 @@ Choosing when to use a device or logical axis depends on how much you want to co * **Device axis names**: In really advanced use cases, you may have more complicated sharding patterns that require annotating *activation* dimension names differently from *parameter* dimension names. If you wish to have more fine-grained control on manual mesh assignments, directly using __device axis names__ could be more helpful. -+++ {"id": "576bdd5cd782"} ++++ ## Save the data diff --git a/docs/guides/training_techniques/use_checkpointing.ipynb b/docs/guides/training_techniques/use_checkpointing.ipynb index e9728f429f..ba9054feba 100644 --- a/docs/guides/training_techniques/use_checkpointing.ipynb +++ b/docs/guides/training_techniques/use_checkpointing.ipynb @@ -4,9 +4,7 @@ "attachments": {}, "cell_type": "markdown", "id": "6e9134fa", - "metadata": { - "id": "6e9134fa" - }, + "metadata": {}, "source": [ "# Save and load checkpoints\n", "\n", @@ -46,9 +44,7 @@ { "cell_type": "markdown", "id": "5a2f6aae", - "metadata": { - "id": "5a2f6aae" - }, + "metadata": {}, "source": [ "## Setup\n", "\n", @@ -59,9 +55,7 @@ "attachments": {}, "cell_type": "markdown", "id": "-icO30rwmKYj", - "metadata": { - "id": "-icO30rwmKYj" - }, + "metadata": {}, "source": [ "Note: Before running `import jax`, create eight fake devices to mimic a [multi-host environment](https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html?#aside-hosts-and-devices-in-jax) in this notebook. Note that the order of imports is important here. The `os.environ[\"XLA_FLAGS\"] = '--xla_force_host_platform_device_count=8'` command works only with the CPU backend, which means it won't work with GPU/TPU acceleration on if you're running this notebook in Google Colab. If you are already running the code on multiple devices (for example, in a 4x2 TPU environment), you can skip running the next cell." ] @@ -70,9 +64,7 @@ "cell_type": "code", "execution_count": 1, "id": "ArKLnsyGRxGv", - "metadata": { - "id": "ArKLnsyGRxGv" - }, + "metadata": {}, "outputs": [], "source": [ "import os\n", @@ -83,9 +75,7 @@ "cell_type": "code", "execution_count": 2, "id": "SJT9DTxTytjn", - "metadata": { - "id": "SJT9DTxTytjn" - }, + "metadata": {}, "outputs": [ { "name": "stderr", @@ -128,9 +118,7 @@ { "cell_type": "markdown", "id": "40d434cd", - "metadata": { - "id": "40d434cd" - }, + "metadata": {}, "source": [ "## Save checkpoints\n", "\n", @@ -144,7 +132,6 @@ "execution_count": 4, "id": "56dec3f6", "metadata": { - "id": "56dec3f6", "outputId": "f1856d96-1961-48ed-bb7c-cb63fbaa7567" }, "outputs": [ @@ -220,9 +207,7 @@ { "cell_type": "markdown", "id": "6fc59dfa", - "metadata": { - "id": "6fc59dfa" - }, + "metadata": {}, "source": [ "Save the checkpoint with `orbax.checkpoint.PyTreeCheckpointer`, directly to the `tmp/orbax/single_save` directory.\n", "\n", @@ -233,9 +218,7 @@ "cell_type": "code", "execution_count": 5, "id": "61b12da2", - "metadata": { - "id": "0pp4QtEqW9k7" - }, + "metadata": {}, "outputs": [], "source": [ "from flax.training import orbax_utils\n", @@ -262,7 +245,6 @@ "execution_count": 6, "id": "d3686ea5", "metadata": { - "id": "T6T8V4UBXB1R", "outputId": "b7132933-566d-440d-c34e-c5468d87cbdc" }, "outputs": [ @@ -293,9 +275,7 @@ { "cell_type": "markdown", "id": "8ecbc4cc", - "metadata": { - "id": "OQkUOkHVW_4e" - }, + "metadata": {}, "source": [ "### With the legacy API\n", "\n", @@ -307,7 +287,6 @@ "execution_count": 7, "id": "4cdb35ef", "metadata": { - "id": "4cdb35ef", "outputId": "6d849273-15ce-4480-8864-726d1838ac1f" }, "outputs": [ @@ -336,9 +315,7 @@ { "cell_type": "markdown", "id": "6b658bd1", - "metadata": { - "id": "6b658bd1" - }, + "metadata": {}, "source": [ "## Restore checkpoints\n", "\n", @@ -352,7 +329,6 @@ "execution_count": 8, "id": "a807a9c1", "metadata": { - "id": "WgRJj3wjXIaN", "outputId": "b4af1ef4-f22f-459b-bdca-2e6bfa16c08b" }, "outputs": [ @@ -425,9 +401,7 @@ { "cell_type": "markdown", "id": "c7fe3bc8", - "metadata": { - "id": "VKJrfSyLXGrc" - }, + "metadata": {}, "source": [ "### With the legacy API\n", "\n", @@ -441,7 +415,6 @@ "execution_count": 10, "id": "150b20a0", "metadata": { - "id": "150b20a0", "outputId": "85ffceca-f38d-46b8-e567-d9d38b7885f9" }, "outputs": [ @@ -474,9 +447,7 @@ { "cell_type": "markdown", "id": "987b981f", - "metadata": { - "id": "987b981f" - }, + "metadata": {}, "source": [ "## Restore with custom dataclasses\n", "\n", @@ -496,7 +467,6 @@ "execution_count": 11, "id": "58f42513", "metadata": { - "id": "58f42513", "outputId": "110c6b6e-fe42-4179-e5d8-6b92d355e11b" }, "outputs": [ @@ -647,9 +617,7 @@ "attachments": {}, "cell_type": "markdown", "id": "136a300a", - "metadata": { - "id": "136a300a" - }, + "metadata": {}, "source": [ "## Restore when checkpoint structures differ\n", "\n", @@ -667,7 +635,6 @@ "execution_count": 14, "id": "be65d4af", "metadata": { - "id": "be65d4af", "outputId": "4fe776f0-65f8-4fc4-d64a-990520b36dce" }, "outputs": [ @@ -705,9 +672,7 @@ { "cell_type": "markdown", "id": "379c2255", - "metadata": { - "id": "379c2255" - }, + "metadata": {}, "source": [ "It is recommended to keep your checkpoints up-to-date with your pytree dataclass definitions. However, you might be forced to restore the checkpoints with incompatible reference objects at runtime. When this happens, the checkpoint restoration will try to respect the structure of the reference when given.\n", "\n", @@ -867,7 +832,6 @@ "execution_count": 17, "id": "29fd1e33", "metadata": { - "id": "29fd1e33", "outputId": "cdbb9247-d1eb-4458-aa83-8db0332af7cb" }, "outputs": [ @@ -986,9 +950,7 @@ { "cell_type": "markdown", "id": "a6b39501", - "metadata": { - "id": "a6b39501" - }, + "metadata": {}, "source": [ "## Asynchronized checkpointing\n", "\n", @@ -1006,7 +968,6 @@ "execution_count": 19, "id": "85be68a6", "metadata": { - "id": "85be68a6", "outputId": "aefce94c-8bae-4355-c142-05f2b61c39e2" }, "outputs": [ @@ -1062,9 +1023,7 @@ { "cell_type": "markdown", "id": "13e93db6", - "metadata": { - "id": "QpuTCeMVXOBn" - }, + "metadata": {}, "source": [ "If you are using Orbax `CheckpointManager`, just pass in the async_checkpointer when initializing it. Then, in practice, call `async_checkpoint_manager.wait_until_finished()` instead." ] @@ -1084,9 +1043,7 @@ { "cell_type": "markdown", "id": "bb0e03cd", - "metadata": { - "id": "13e93db6" - }, + "metadata": {}, "source": [ "## Multi-host/multi-process checkpointing\n", "\n", @@ -1101,9 +1058,7 @@ "cell_type": "code", "execution_count": 21, "id": "ubdUvyMrhD-1", - "metadata": { - "id": "ubdUvyMrhD-1" - }, + "metadata": {}, "outputs": [], "source": [ "from jax.sharding import PartitionSpec, NamedSharding\n", @@ -1184,9 +1139,7 @@ { "cell_type": "markdown", "id": "edc355ce", - "metadata": { - "id": "edc355ce" - }, + "metadata": {}, "source": [ "### With the legacy Flax: use `save_checkpoint_multiprocess`\n", "\n", @@ -1200,7 +1153,6 @@ "execution_count": 24, "id": "5d10039b", "metadata": { - "id": "5d10039b", "outputId": "901bb097-0899-479d-b9ae-61dae79e7057" }, "outputs": [ @@ -1230,7 +1182,6 @@ "execution_count": 25, "id": "a9f9724c", "metadata": { - "id": "a9f9724c", "outputId": "393c4a0e-8a8c-4ca6-c609-93c8bab38e75" }, "outputs": [ diff --git a/docs/guides/training_techniques/use_checkpointing.md b/docs/guides/training_techniques/use_checkpointing.md index 10b9d8fca0..f6c6b58f08 100644 --- a/docs/guides/training_techniques/use_checkpointing.md +++ b/docs/guides/training_techniques/use_checkpointing.md @@ -10,7 +10,6 @@ jupyter: jupytext_version: 1.13.8 --- - # Save and load checkpoints This guide demonstrates how to save and load Flax checkpoints with [Orbax](https://github.com/google/orbax). @@ -45,24 +44,21 @@ For backward-compatibility, this guide shows the Orbax-equivalent calls in the F If you need to learn more about `orbax.checkpoint`, refer to the [Orbax docs](https://github.com/google/orbax/blob/main/docs/checkpoint.md). - - + ## Setup Install/upgrade Flax and [Orbax](https://github.com/google/orbax). For JAX installation with GPU/TPU support, visit [this section on GitHub](https://github.com/google/jax#installation). - - + Note: Before running `import jax`, create eight fake devices to mimic a [multi-host environment](https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html?#aside-hosts-and-devices-in-jax) in this notebook. Note that the order of imports is important here. The `os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'` command works only with the CPU backend, which means it won't work with GPU/TPU acceleration on if you're running this notebook in Google Colab. If you are already running the code on multiple devices (for example, in a 4x2 TPU environment), you can skip running the next cell. - -```python id="ArKLnsyGRxGv" +```python import os os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' ``` -```python id="SJT9DTxTytjn" +```python from typing import Optional, Any import shutil @@ -86,15 +82,13 @@ if os.path.exists(ckpt_dir): shutil.rmtree(ckpt_dir) # Remove any existing checkpoints from the last notebook run. ``` - ## Save checkpoints In Orbax and Flax, you can save and load any given JAX [pytree](https://jax.readthedocs.io/en/latest/pytrees.html). This includes not only typical Python and NumPy containers, but also customized classes extended from [`flax.struct.dataclass`](https://flax.readthedocs.io/en/latest/api_reference/flax.struct.html#flax.struct.dataclass). That means you can store almost any data generated — not only your model parameters, but any arrays/dictionaries, metadata/configs, and so on. First, create a pytree with many data structures and containers, and play with it: - -```python id="56dec3f6" outputId="f1856d96-1961-48ed-bb7c-cb63fbaa7567" +```python outputId="f1856d96-1961-48ed-bb7c-cb63fbaa7567" # A simple model with one linear layer. key1, key2 = random.split(random.key(0)) x1 = random.normal(key1, (5,)) # A simple JAX array. @@ -121,13 +115,12 @@ ckpt ### With Orbax - + Save the checkpoint with `orbax.checkpoint.PyTreeCheckpointer`, directly to the `tmp/orbax/single_save` directory. Note: An optional `save_args` is provided. This is recommended for performance speedups, as it bundles smaller arrays in your pytree to a single large file instead of multiple smaller files. - -```python id="0pp4QtEqW9k7" +```python from flax.training import orbax_utils orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer() @@ -141,7 +134,7 @@ In addition, provide `orbax.checkpoint.CheckpointManagerOptions` that customizes `orbax.checkpoint.CheckpointManager` should be placed at the top-level outside your training steps to manage your saves. -```python id="T6T8V4UBXB1R" outputId="b7132933-566d-440d-c34e-c5468d87cbdc" +```python outputId="b7132933-566d-440d-c34e-c5468d87cbdc" options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=2, create=True) checkpoint_manager = orbax.checkpoint.CheckpointManager( '/tmp/flax_ckpt/orbax/managed', orbax_checkpointer, options) @@ -154,13 +147,11 @@ for step in range(5): os.listdir('/tmp/flax_ckpt/orbax/managed') # Because max_to_keep=2, only step 3 and 4 are retained ``` - ### With the legacy API And here's how to save with the legacy Flax checkpointing utilities (note that this provides less management features compared with `orbax.checkpoint.CheckpointManagerOptions`): - -```python id="4cdb35ef" outputId="6d849273-15ce-4480-8864-726d1838ac1f" +```python outputId="6d849273-15ce-4480-8864-726d1838ac1f" # Import Flax Checkpoints. from flax.training import checkpoints @@ -171,15 +162,13 @@ checkpoints.save_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing', keep=2) ``` - ## Restore checkpoints ### With Orbax In Orbax, call `.restore()` for either `orbax.checkpoint.PyTreeCheckpointer` or `orbax.checkpoint.CheckpointManager` to restore your checkpoint in the raw pytree format. - -```python id="WgRJj3wjXIaN" outputId="b4af1ef4-f22f-459b-bdca-2e6bfa16c08b" +```python outputId="b4af1ef4-f22f-459b-bdca-2e6bfa16c08b" raw_restored = orbax_checkpointer.restore('/tmp/flax_ckpt/orbax/single_save') raw_restored ``` @@ -191,20 +180,17 @@ step = checkpoint_manager.latest_step() # step = 4 checkpoint_manager.restore(step) ``` - ### With the legacy API Note that with the migration to Orbax in progress, `flax.training.checkpointing.restore_checkpoint` can automatically identify whether a checkpoint is saved in the legacy Flax format or with an Orbax backend, and restore the pytree correctly. Therefore, adding `flax.config.update('flax_use_orbax_checkpointing', True)` won't hurt your ability to restore old checkpoints. Here's how to restore checkpoints using the legacy API: - -```python id="150b20a0" outputId="85ffceca-f38d-46b8-e567-d9d38b7885f9" +```python outputId="85ffceca-f38d-46b8-e567-d9d38b7885f9" raw_restored = checkpoints.restore_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing', target=None) raw_restored ``` - ## Restore with custom dataclasses ### With Orbax @@ -216,9 +202,8 @@ raw_restored This section demonstrates how to set up any custom Flax dataclass explicitly, and have the same structure as a saved checkpoint. Note: Data that was a JAX NumPy array (`jnp.array`) format will be restored as a NumPy array (`numpy.array`). This would not affect your work because JAX will [automatically convert](https://jax.readthedocs.io/en/latest/jax-101/01-jax-basics.html) NumPy arrays to JAX arrays once the computation starts. - -```python id="58f42513" outputId="110c6b6e-fe42-4179-e5d8-6b92d355e11b" +```python outputId="110c6b6e-fe42-4179-e5d8-6b92d355e11b" empty_state = train_state.TrainState.create( apply_fn=model.apply, params=jax.tree_map(np.zeros_like, variables['params']), # values of the tree leaf doesn't matter @@ -244,7 +229,7 @@ checkpoints.restore_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing', tar It's often recommended to refactor out the process of initializing a checkpoint's structure (for example, a [`TrainState`](https://flax.readthedocs.io/en/latest/flip/1009-optimizer-api.html?#train-state)), so that saving/loading is easier and less error-prone. This is because functions and complex objects like `apply_fn` and `tx` (optimizer) cannot be serialized into the checkpoint file and must be initialized by code. - + ## Restore when checkpoint structures differ During your development, your checkpoint structure will change when changing the model, adding/removing fields during tweaking, and so on. @@ -254,9 +239,8 @@ This section explains how to load old data to your new code. Below is a simple example — a `CustomTrainState` extended from `flax.training.train_state.TrainState` that contains an extra field called `batch_stats`. When working on a real-world model, you may need this when applying [batch normalization](https://flax.readthedocs.io/en/latest/guides/training_techniques/batch_norm.html). Here, you store the new `CustomTrainState` as step 5, while step 4 contains the old/previous `TrainState`. - -```python id="be65d4af" outputId="4fe776f0-65f8-4fc4-d64a-990520b36dce" +```python outputId="4fe776f0-65f8-4fc4-d64a-990520b36dce" class CustomTrainState(train_state.TrainState): batch_stats: Any = None @@ -276,11 +260,10 @@ custom_save_args = orbax_utils.save_args_from_target(custom_ckpt) checkpoint_manager.save(5, custom_ckpt, save_kwargs={'save_args': custom_save_args}) ``` - It is recommended to keep your checkpoints up-to-date with your pytree dataclass definitions. However, you might be forced to restore the checkpoints with incompatible reference objects at runtime. When this happens, the checkpoint restoration will try to respect the structure of the reference when given. Below are examples of a few common scenarios. - + ### Scenario 1: When a reference object is partial @@ -326,7 +309,7 @@ restored If you have already saved your checkpoints with the Orbax backend, you can use `orbax_transforms` to access this `transforms` argument in the Flax API. -```python id="29fd1e33" outputId="cdbb9247-d1eb-4458-aa83-8db0332af7cb" +```python outputId="cdbb9247-d1eb-4458-aa83-8db0332af7cb" # Save in the "Flax-with-Orbax" backend. flax.config.update('flax_use_orbax_checkpointing', True) checkpoints.save_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing', @@ -362,7 +345,6 @@ raw_state_dict['model']['batch_stats'] = np.flip(np.arange(10)) flax.serialization.from_state_dict(custom_target, raw_state_dict) ``` - ## Asynchronized checkpointing Checkpointing is I/O heavy, and if you have a large amount of data to save, it may be worthwhile to put it into a background thread, while continuing with your training. @@ -372,9 +354,8 @@ You can do this by creating an [`orbax.checkpoint.AsyncCheckpointer`](https://gi Note: You should use the same `async_checkpointer` to handle all your async saves across your training steps, so that it can make sure that a previous async save is done before the next one begins. This enables bookkeeping, such as `keep` (the number of checkpoints) and `overwrite` to be consistent across steps. Whenever you want to explicitly wait until an async save is done, you can call `async_checkpointer.wait_until_finished()`. - -```python id="85be68a6" outputId="aefce94c-8bae-4355-c142-05f2b61c39e2" +```python outputId="aefce94c-8bae-4355-c142-05f2b61c39e2" # `orbax.checkpoint.AsyncCheckpointer` needs some multi-process initialization, because it was # originally designed for multi-process large model checkpointing. # For Python notebooks or other single-process settings, just set up with `num_processes=1`. @@ -394,9 +375,7 @@ async_checkpointer.wait_until_finished() # Blocks until the checkpoint saving i async_checkpointer.restore('/tmp/flax_ckpt/orbax/single_save_async', item=target) ``` - If you are using Orbax `CheckpointManager`, just pass in the async_checkpointer when initializing it. Then, in practice, call `async_checkpoint_manager.wait_until_finished()` instead. - ```python async_checkpoint_manager = orbax.checkpoint.CheckpointManager( @@ -404,7 +383,6 @@ async_checkpoint_manager = orbax.checkpoint.CheckpointManager( async_checkpoint_manager.wait_until_finished() ``` - ## Multi-host/multi-process checkpointing JAX provides a few ways to scale up your code on multiple hosts at the same time. This usually happens when the number of devices (CPU/GPU/TPU) is so large that different devices are managed by different hosts (CPU). To get started on JAX in multi-process settings, check out [Using JAX in multi-host and multi-process environments](https://jax.readthedocs.io/en/latest/multi_process.html) and the [distributed array guide](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html). @@ -412,9 +390,8 @@ JAX provides a few ways to scale up your code on multiple hosts at the same time In the [Single Program Multi Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) paradigm with JAX `jit`, a large multi-process array can have its data sharded across different devices. (Note that JAX `pjit` and `jit` have been merged into a single unified interface. To learn about compiling and executing JAX functions in multi-host or multi-core environments, refer to [this guide](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) and the [jax.Array migration guide](https://jax.readthedocs.io/en/latest/jax_array_migration.html).) When a multi-process array is serialized, each host dumps its data shards to a single shared storage, such as a Google Cloud bucket. Orbax supports saving and loading pytrees with multi-process arrays in the same fashion as single-process pytrees. However, it's recommended to use the asynchronized [`orbax.AsyncCheckpointer`](https://github.com/google/orbax/blob/main/orbax/checkpoint/async_checkpointer.py) to save large multi-process arrays on another thread, so that you can perform computation alongside the saves. With pure Orbax, saving checkpoints in a multi-process context uses the same API as in a single-process context. - -```python id="ubdUvyMrhD-1" +```python from jax.sharding import PartitionSpec, NamedSharding # Create an array sharded across multiple devices. @@ -454,15 +431,13 @@ async_checkpoint_manager.restore( 0, items=ref_ckpt, restore_kwargs={'restore_args': restore_args}) ``` - ### With the legacy Flax: use `save_checkpoint_multiprocess` In legacy Flax, to save multi-process arrays, use [`flax.training.checkpoints.save_checkpoint_multiprocess()`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.checkpoints.save_checkpoint_multiprocess) in place of `save_checkpoint()` and with the same arguments. If your checkpoint is too large, you can specify `timeout_secs` in the manager and give it more time to finish writing. - -```python id="5d10039b" outputId="901bb097-0899-479d-b9ae-61dae79e7057" +```python outputId="901bb097-0899-479d-b9ae-61dae79e7057" async_checkpointer = orbax.checkpoint.AsyncCheckpointer(orbax.checkpoint.PyTreeCheckpointHandler(), timeout_secs=50) checkpoints.save_checkpoint_multiprocess(ckpt_dir, mp_ckpt, @@ -472,7 +447,7 @@ checkpoints.save_checkpoint_multiprocess(ckpt_dir, orbax_checkpointer=async_checkpointer) ``` -```python id="a9f9724c" outputId="393c4a0e-8a8c-4ca6-c609-93c8bab38e75" +```python outputId="393c4a0e-8a8c-4ca6-c609-93c8bab38e75" mp_restored = checkpoints.restore_checkpoint(ckpt_dir, target=ref_ckpt, step=3, diff --git a/docs/quick_start.ipynb b/docs/quick_start.ipynb index 62e1b1ea69..5c76c39a03 100644 --- a/docs/quick_start.ipynb +++ b/docs/quick_start.ipynb @@ -3,9 +3,7 @@ { "cell_type": "markdown", "id": "6eea21b3", - "metadata": { - "id": "6eea21b3" - }, + "metadata": {}, "source": [ "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/getting_started.ipynb)\n", "[![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/getting_started.ipynb)\n", @@ -22,9 +20,7 @@ { "cell_type": "markdown", "id": "nwJWKIhdwxDo", - "metadata": { - "id": "nwJWKIhdwxDo" - }, + "metadata": {}, "source": [ "## 1. Install Flax" ] @@ -34,7 +30,6 @@ "execution_count": null, "id": "bb81587e", "metadata": { - "id": "bb81587e", "tags": [ "skip-execution" ] @@ -47,9 +42,7 @@ { "cell_type": "markdown", "id": "b529fbef", - "metadata": { - "id": "b529fbef" - }, + "metadata": {}, "source": [ "## 2. Loading data\n", "\n", @@ -62,14 +55,7 @@ "cell_type": "code", "execution_count": 48, "id": "bRlrHqZVXZvk", - "metadata": { - "executionInfo": { - "elapsed": 54, - "status": "ok", - "timestamp": 1673483483044 - }, - "id": "bRlrHqZVXZvk" - }, + "metadata": {}, "outputs": [], "source": [ "import tensorflow_datasets as tfds # TFDS for MNIST\n", @@ -98,9 +84,7 @@ { "cell_type": "markdown", "id": "7057395a", - "metadata": { - "id": "7057395a" - }, + "metadata": {}, "source": [ "## 3. Define network\n", "\n", @@ -117,14 +101,7 @@ "cell_type": "code", "execution_count": 49, "id": "cbc079cd", - "metadata": { - "executionInfo": { - "elapsed": 53, - "status": "ok", - "timestamp": 1673483483208 - }, - "id": "cbc079cd" - }, + "metadata": {}, "outputs": [], "source": [ "from flax import linen as nn # Linen API\n", @@ -150,9 +127,7 @@ { "cell_type": "markdown", "id": "hy7iRu7_zlx-", - "metadata": { - "id": "hy7iRu7_zlx-" - }, + "metadata": {}, "source": [ "### View model layers\n", "\n", @@ -164,12 +139,6 @@ "execution_count": 50, "id": "lDHfog81zLQa", "metadata": { - "executionInfo": { - "elapsed": 103, - "status": "ok", - "timestamp": 1673483483427 - }, - "id": "lDHfog81zLQa", "outputId": "2c580f41-bf5d-40ec-f1cf-ab7f319a84da" }, "outputs": [ @@ -238,9 +207,7 @@ { "cell_type": "markdown", "id": "4b5ac16e", - "metadata": { - "id": "4b5ac16e" - }, + "metadata": {}, "source": [ "## 4. Create a `TrainState`\n", "\n", @@ -257,12 +224,6 @@ "execution_count": null, "id": "qXr7JDpIxGNZ", "metadata": { - "executionInfo": { - "elapsed": 52, - "status": "ok", - "timestamp": 1673483483631 - }, - "id": "qXr7JDpIxGNZ", "outputId": "1249b7fb-6787-41eb-b34c-61d736300844" }, "outputs": [], @@ -274,14 +235,7 @@ "cell_type": "code", "execution_count": 52, "id": "CJDaJNijyOji", - "metadata": { - "executionInfo": { - "elapsed": 1, - "status": "ok", - "timestamp": 1673483483754 - }, - "id": "CJDaJNijyOji" - }, + "metadata": {}, "outputs": [], "source": [ "from clu import metrics\n", @@ -293,9 +247,7 @@ { "cell_type": "markdown", "id": "8b86b5f1", - "metadata": { - "id": "8b86b5f1" - }, + "metadata": {}, "source": [ "We will be using the `clu` library for computing metrics. For more information on `clu`, refer to the [repo](https://github.com/google/CommonLoopUtils) and [notebook](https://colab.research.google.com/github/google/CommonLoopUtils/blob/master/clu_synopsis.ipynb#scrollTo=ueom-uBWLbeQ)." ] @@ -304,14 +256,7 @@ "cell_type": "code", "execution_count": 53, "id": "7W0qf7FC9uG5", - "metadata": { - "executionInfo": { - "elapsed": 55, - "status": "ok", - "timestamp": 1673483483958 - }, - "id": "7W0qf7FC9uG5" - }, + "metadata": {}, "outputs": [], "source": [ "@struct.dataclass\n", @@ -323,9 +268,7 @@ { "cell_type": "markdown", "id": "f3ce5e4c", - "metadata": { - "id": "f3ce5e4c" - }, + "metadata": {}, "source": [ "You can then subclass `train_state.TrainState` so that it also contains metrics. This has the advantage that we only need\n", "to pass around a single argument to functions like `train_step()` (see below) to calculate the loss, update the parameters and compute the metrics all at once." @@ -335,14 +278,7 @@ "cell_type": "code", "execution_count": 54, "id": "e0102447", - "metadata": { - "executionInfo": { - "elapsed": 54, - "status": "ok", - "timestamp": 1673483484125 - }, - "id": "e0102447" - }, + "metadata": {}, "outputs": [], "source": [ "class TrainState(train_state.TrainState):\n", @@ -360,9 +296,7 @@ { "cell_type": "markdown", "id": "a15de484", - "metadata": { - "id": "a15de484" - }, + "metadata": {}, "source": [ "## 5. Training step\n", "\n", @@ -388,14 +322,7 @@ "cell_type": "code", "execution_count": 55, "id": "9b0af486", - "metadata": { - "executionInfo": { - "elapsed": 52, - "status": "ok", - "timestamp": 1673483484293 - }, - "id": "9b0af486" - }, + "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", @@ -415,9 +342,7 @@ { "cell_type": "markdown", "id": "0ff5145f", - "metadata": { - "id": "0ff5145f" - }, + "metadata": {}, "source": [ "## 6. Metric computation\n", "\n", @@ -428,14 +353,7 @@ "cell_type": "code", "execution_count": 56, "id": "961bf70b", - "metadata": { - "executionInfo": { - "elapsed": 53, - "status": "ok", - "timestamp": 1673483484460 - }, - "id": "961bf70b" - }, + "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", @@ -453,9 +371,7 @@ { "cell_type": "markdown", "id": "497241c3", - "metadata": { - "id": "497241c3" - }, + "metadata": {}, "source": [ "## 7. Download data" ] @@ -464,14 +380,7 @@ "cell_type": "code", "execution_count": 57, "id": "bff5393e", - "metadata": { - "executionInfo": { - "elapsed": 515, - "status": "ok", - "timestamp": 1673483485090 - }, - "id": "bff5393e" - }, + "metadata": {}, "outputs": [], "source": [ "num_epochs = 10\n", @@ -484,9 +393,7 @@ "attachments": {}, "cell_type": "markdown", "id": "809ae1a0", - "metadata": { - "id": "809ae1a0" - }, + "metadata": {}, "source": [ "## 8. Seed randomness\n", "\n", @@ -503,14 +410,7 @@ "cell_type": "code", "execution_count": 58, "id": "xC4MFyBsfT-U", - "metadata": { - "executionInfo": { - "elapsed": 59, - "status": "ok", - "timestamp": 1673483485268 - }, - "id": "xC4MFyBsfT-U" - }, + "metadata": {}, "outputs": [], "source": [ "tf.random.set_seed(0)" @@ -520,14 +420,7 @@ "cell_type": "code", "execution_count": 59, "id": "e4f6f4d3", - "metadata": { - "executionInfo": { - "elapsed": 52, - "status": "ok", - "timestamp": 1673483485436 - }, - "id": "e4f6f4d3" - }, + "metadata": {}, "outputs": [], "source": [ "init_rng = jax.random.key(0)" @@ -536,9 +429,7 @@ { "cell_type": "markdown", "id": "80fbb60b", - "metadata": { - "id": "80fbb60b" - }, + "metadata": {}, "source": [ "## 9. Initialize the `TrainState`\n", "\n", @@ -550,14 +441,7 @@ "cell_type": "code", "execution_count": 60, "id": "445fcab0", - "metadata": { - "executionInfo": { - "elapsed": 56, - "status": "ok", - "timestamp": 1673483485606 - }, - "id": "445fcab0" - }, + "metadata": {}, "outputs": [], "source": [ "learning_rate = 0.01\n", @@ -568,14 +452,7 @@ "cell_type": "code", "execution_count": 61, "id": "5221eafd", - "metadata": { - "executionInfo": { - "elapsed": 52, - "status": "ok", - "timestamp": 1673483485777 - }, - "id": "5221eafd" - }, + "metadata": {}, "outputs": [], "source": [ "state = create_train_state(cnn, init_rng, learning_rate, momentum)\n", @@ -585,9 +462,7 @@ { "cell_type": "markdown", "id": "b1c00230", - "metadata": { - "id": "b1c00230" - }, + "metadata": {}, "source": [ "## 10. Train and evaluate\n", "\n", @@ -610,14 +485,7 @@ "cell_type": "code", "execution_count": 62, "id": "74295360", - "metadata": { - "executionInfo": { - "elapsed": 55, - "status": "ok", - "timestamp": 1673483485947 - }, - "id": "74295360" - }, + "metadata": {}, "outputs": [], "source": [ "# since train_ds is replicated num_epochs times in get_datasets(), we divide by num_epochs\n", @@ -628,14 +496,7 @@ "cell_type": "code", "execution_count": 63, "id": "cRtnMZuQFlKl", - "metadata": { - "executionInfo": { - "elapsed": 1, - "status": "ok", - "timestamp": 1673483486076 - }, - "id": "cRtnMZuQFlKl" - }, + "metadata": {}, "outputs": [], "source": [ "metrics_history = {'train_loss': [],\n", @@ -649,12 +510,6 @@ "execution_count": 64, "id": "2c40ce90", "metadata": { - "executionInfo": { - "elapsed": 17908, - "status": "ok", - "timestamp": 1673483504133 - }, - "id": "2c40ce90", "outputId": "258a2c76-2c8f-4a9e-d48b-dde57c342a87" }, "outputs": [ @@ -716,9 +571,7 @@ { "cell_type": "markdown", "id": "gfsecJzvzgCT", - "metadata": { - "id": "gfsecJzvzgCT" - }, + "metadata": {}, "source": [ "## 11. Visualize metrics" ] @@ -728,12 +581,6 @@ "execution_count": 65, "id": "Zs5atiqIG9Kz", "metadata": { - "executionInfo": { - "elapsed": 358, - "status": "ok", - "timestamp": 1673483504621 - }, - "id": "Zs5atiqIG9Kz", "outputId": "431a2fcd-44fa-4202-f55a-906555f060ac" }, "outputs": [ @@ -776,9 +623,7 @@ { "cell_type": "markdown", "id": "qQbKS0tV3sZ1", - "metadata": { - "id": "qQbKS0tV3sZ1" - }, + "metadata": {}, "source": [ "## 12. Perform inference on test set\n", "\n", @@ -789,14 +634,7 @@ "cell_type": "code", "execution_count": 66, "id": "DFwxgBQf44ks", - "metadata": { - "executionInfo": { - "elapsed": 580, - "status": "ok", - "timestamp": 1673483505350 - }, - "id": "DFwxgBQf44ks" - }, + "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", @@ -813,12 +651,6 @@ "execution_count": 67, "id": "5d5nF3u44JFI", "metadata": { - "executionInfo": { - "elapsed": 1250, - "status": "ok", - "timestamp": 1673483506723 - }, - "id": "5d5nF3u44JFI", "outputId": "1db5a01c-9d70-4f7d-8c0d-0a3ad8252d3e" }, "outputs": [ @@ -844,9 +676,7 @@ { "cell_type": "markdown", "id": "edb528b6", - "metadata": { - "id": "edb528b6" - }, + "metadata": {}, "source": [ "Congratulations! You made it to the end of the annotated MNIST example. You can revisit\n", "the same example, but structured differently as a couple of Python modules, test\n", diff --git a/docs/quick_start.md b/docs/quick_start.md index 0fe3f63129..e12dc84910 100644 --- a/docs/quick_start.md +++ b/docs/quick_start.md @@ -9,8 +9,6 @@ jupytext: jupytext_version: 1.13.8 --- -+++ {"id": "6eea21b3"} - [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/getting_started.ipynb) [![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/getting_started.ipynb) @@ -22,19 +20,16 @@ Flax is an open source Python neural network library built on top of [JAX](https network (CNN) using the [Flax](https://flax.readthedocs.io) Linen API and train the network for image classification on the MNIST dataset. -+++ {"id": "nwJWKIhdwxDo"} ++++ ## 1. Install Flax ```{code-cell} -:id: bb81587e :tags: [skip-execution] !pip install -q flax>=0.7.5 ``` -+++ {"id": "b529fbef"} - ## 2. Loading data Flax can use any @@ -42,13 +37,6 @@ data-loading pipeline and this example demonstrates how to utilize TFDS. Define samples to floating-point numbers. ```{code-cell} ---- -executionInfo: - elapsed: 54 - status: ok - timestamp: 1673483483044 -id: bRlrHqZVXZvk ---- import tensorflow_datasets as tfds # TFDS for MNIST import tensorflow as tf # TensorFlow operations @@ -72,8 +60,6 @@ def get_datasets(num_epochs, batch_size): return train_ds, test_ds ``` -+++ {"id": "7057395a"} - ## 3. Define network Create a convolutional neural network with the Linen API by subclassing @@ -85,13 +71,6 @@ stacking layers—you can define the inlined submodules directly within the decorator. To learn more about the Flax Linen `@compact` decorator, refer to the [`setup` vs `compact`](https://flax.readthedocs.io/en/latest/guides/setup_or_nncompact.html) guide. ```{code-cell} ---- -executionInfo: - elapsed: 53 - status: ok - timestamp: 1673483483208 -id: cbc079cd ---- from flax import linen as nn # Linen API class CNN(nn.Module): @@ -112,21 +91,13 @@ class CNN(nn.Module): return x ``` -+++ {"id": "hy7iRu7_zlx-"} - ### View model layers Create an instance of the Flax Module and use the [`Module.tabulate`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.tabulate) method to visualize a table of the model layers by passing an RNG key and template image input. ```{code-cell} ---- -executionInfo: - elapsed: 103 - status: ok - timestamp: 1673483483427 -id: lDHfog81zLQa -outputId: 2c580f41-bf5d-40ec-f1cf-ab7f319a84da ---- +:outputId: 2c580f41-bf5d-40ec-f1cf-ab7f319a84da + import jax import jax.numpy as jnp # JAX NumPy @@ -135,8 +106,6 @@ print(cnn.tabulate(jax.random.key(0), jnp.ones((1, 28, 28, 1)), compute_flops=True, compute_vjp_flops=True)) ``` -+++ {"id": "4b5ac16e"} - ## 4. Create a `TrainState` A common pattern in Flax is to create a single dataclass that represents the @@ -147,62 +116,31 @@ Because this is such a common pattern, Flax provides the class that serves most basic usecases. ```{code-cell} ---- -executionInfo: - elapsed: 52 - status: ok - timestamp: 1673483483631 -id: qXr7JDpIxGNZ -outputId: 1249b7fb-6787-41eb-b34c-61d736300844 ---- +:outputId: 1249b7fb-6787-41eb-b34c-61d736300844 + !pip install -q clu ``` ```{code-cell} ---- -executionInfo: - elapsed: 1 - status: ok - timestamp: 1673483483754 -id: CJDaJNijyOji ---- from clu import metrics from flax.training import train_state # Useful dataclass to keep train state from flax import struct # Flax dataclasses import optax # Common loss functions and optimizers ``` -+++ {"id": "8b86b5f1"} - We will be using the `clu` library for computing metrics. For more information on `clu`, refer to the [repo](https://github.com/google/CommonLoopUtils) and [notebook](https://colab.research.google.com/github/google/CommonLoopUtils/blob/master/clu_synopsis.ipynb#scrollTo=ueom-uBWLbeQ). ```{code-cell} ---- -executionInfo: - elapsed: 55 - status: ok - timestamp: 1673483483958 -id: 7W0qf7FC9uG5 ---- @struct.dataclass class Metrics(metrics.Collection): accuracy: metrics.Accuracy loss: metrics.Average.from_output('loss') ``` -+++ {"id": "f3ce5e4c"} - You can then subclass `train_state.TrainState` so that it also contains metrics. This has the advantage that we only need to pass around a single argument to functions like `train_step()` (see below) to calculate the loss, update the parameters and compute the metrics all at once. ```{code-cell} ---- -executionInfo: - elapsed: 54 - status: ok - timestamp: 1673483484125 -id: e0102447 ---- class TrainState(train_state.TrainState): metrics: Metrics @@ -215,8 +153,6 @@ def create_train_state(module, rng, learning_rate, momentum): metrics=Metrics.empty()) ``` -+++ {"id": "a15de484"} - ## 5. Training step A function that: @@ -237,13 +173,6 @@ it with [XLA](https://www.tensorflow.org/xla) into fused device operations that run faster and more efficiently on hardware accelerators. ```{code-cell} ---- -executionInfo: - elapsed: 52 - status: ok - timestamp: 1673483484293 -id: 9b0af486 ---- @jax.jit def train_step(state, batch): """Train for a single step.""" @@ -258,20 +187,11 @@ def train_step(state, batch): return state ``` -+++ {"id": "0ff5145f"} - ## 6. Metric computation Create a separate function for loss and accuracy metrics. Loss is calculated using the `optax.softmax_cross_entropy_with_integer_labels` function, while accuracy is calculated using `clu.metrics`. ```{code-cell} ---- -executionInfo: - elapsed: 53 - status: ok - timestamp: 1673483484460 -id: 961bf70b ---- @jax.jit def compute_metrics(*, state, batch): logits = state.apply_fn({'params': state.params}, batch['image']) @@ -284,26 +204,15 @@ def compute_metrics(*, state, batch): return state ``` -+++ {"id": "497241c3"} - ## 7. Download data ```{code-cell} ---- -executionInfo: - elapsed: 515 - status: ok - timestamp: 1673483485090 -id: bff5393e ---- num_epochs = 10 batch_size = 32 train_ds, test_ds = get_datasets(num_epochs, batch_size) ``` -+++ {"id": "809ae1a0"} - ## 8. Seed randomness - Set the TF random seed to ensure dataset shuffling (with `tf.data.Dataset.shuffle`) is reproducible. @@ -315,60 +224,28 @@ train_ds, test_ds = get_datasets(num_epochs, batch_size) and [PRNG chains](https://flax.readthedocs.io/en/latest/philosophy.html#how-are-parameters-represented-and-how-do-we-handle-general-differentiable-algorithms-that-update-stateful-variables).) ```{code-cell} ---- -executionInfo: - elapsed: 59 - status: ok - timestamp: 1673483485268 -id: xC4MFyBsfT-U ---- tf.random.set_seed(0) ``` ```{code-cell} ---- -executionInfo: - elapsed: 52 - status: ok - timestamp: 1673483485436 -id: e4f6f4d3 ---- init_rng = jax.random.key(0) ``` -+++ {"id": "80fbb60b"} - ## 9. Initialize the `TrainState` Remember that the function `create_train_state` initializes the model parameters, optimizer and metrics and puts them into the training state dataclass that is returned. ```{code-cell} ---- -executionInfo: - elapsed: 56 - status: ok - timestamp: 1673483485606 -id: 445fcab0 ---- learning_rate = 0.01 momentum = 0.9 ``` ```{code-cell} ---- -executionInfo: - elapsed: 52 - status: ok - timestamp: 1673483485777 -id: 5221eafd ---- state = create_train_state(cnn, init_rng, learning_rate, momentum) del init_rng # Must not be used anymore. ``` -+++ {"id": "b1c00230"} - ## 10. Train and evaluate Create a "shuffled" dataset by: @@ -386,25 +263,11 @@ Define a training loop that: Once the training and testing is done after 10 epochs, the output should show that your model was able to achieve approximately 99% accuracy. ```{code-cell} ---- -executionInfo: - elapsed: 55 - status: ok - timestamp: 1673483485947 -id: '74295360' ---- # since train_ds is replicated num_epochs times in get_datasets(), we divide by num_epochs num_steps_per_epoch = train_ds.cardinality().numpy() // num_epochs ``` ```{code-cell} ---- -executionInfo: - elapsed: 1 - status: ok - timestamp: 1673483486076 -id: cRtnMZuQFlKl ---- metrics_history = {'train_loss': [], 'train_accuracy': [], 'test_loss': [], @@ -412,14 +275,8 @@ metrics_history = {'train_loss': [], ``` ```{code-cell} ---- -executionInfo: - elapsed: 17908 - status: ok - timestamp: 1673483504133 -id: 2c40ce90 -outputId: 258a2c76-2c8f-4a9e-d48b-dde57c342a87 ---- +:outputId: 258a2c76-2c8f-4a9e-d48b-dde57c342a87 + for step,batch in enumerate(train_ds.as_numpy_iterator()): # Run optimization steps over training batches and compute batch metrics @@ -447,19 +304,11 @@ for step,batch in enumerate(train_ds.as_numpy_iterator()): f"accuracy: {metrics_history['test_accuracy'][-1] * 100}") ``` -+++ {"id": "gfsecJzvzgCT"} - ## 11. Visualize metrics ```{code-cell} ---- -executionInfo: - elapsed: 358 - status: ok - timestamp: 1673483504621 -id: Zs5atiqIG9Kz -outputId: 431a2fcd-44fa-4202-f55a-906555f060ac ---- +:outputId: 431a2fcd-44fa-4202-f55a-906555f060ac + import matplotlib.pyplot as plt # Visualization # Plot loss and accuracy in subplots @@ -475,20 +324,11 @@ plt.show() plt.clf() ``` -+++ {"id": "qQbKS0tV3sZ1"} - ## 12. Perform inference on test set Define a jitted inference function `pred_step`. Use the learned parameters to do model inference on the test set and visualize the images and their corresponding predicted labels. ```{code-cell} ---- -executionInfo: - elapsed: 580 - status: ok - timestamp: 1673483505350 -id: DFwxgBQf44ks ---- @jax.jit def pred_step(state, batch): logits = state.apply_fn({'params': state.params}, test_batch['image']) @@ -499,14 +339,8 @@ pred = pred_step(state, test_batch) ``` ```{code-cell} ---- -executionInfo: - elapsed: 1250 - status: ok - timestamp: 1673483506723 -id: 5d5nF3u44JFI -outputId: 1db5a01c-9d70-4f7d-8c0d-0a3ad8252d3e ---- +:outputId: 1db5a01c-9d70-4f7d-8c0d-0a3ad8252d3e + fig, axs = plt.subplots(5, 5, figsize=(12, 12)) for i, ax in enumerate(axs.flatten()): ax.imshow(test_batch['image'][i, ..., 0], cmap='gray') @@ -514,8 +348,6 @@ for i, ax in enumerate(axs.flatten()): ax.axis('off') ``` -+++ {"id": "edb528b6"} - Congratulations! You made it to the end of the annotated MNIST example. You can revisit the same example, but structured differently as a couple of Python modules, test modules, config files, another Colab, and documentation in Flax's Git repo: diff --git a/examples/imagenet/imagenet.ipynb b/examples/imagenet/imagenet.ipynb index 2614653ef8..d271c9d117 100644 --- a/examples/imagenet/imagenet.ipynb +++ b/examples/imagenet/imagenet.ipynb @@ -2,9 +2,7 @@ "cells": [ { "cell_type": "markdown", - "metadata": { - "id": "wquAnPg0p4Y8" - }, + "metadata": {}, "source": [ "# Flax Imagenet Example\n", "\n", @@ -16,9 +14,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "UuqrLz3he_1M" - }, + "metadata": {}, "source": [ "The **Flax Notebook Workflow**:\n", "\n", @@ -40,9 +36,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "2cMTM3W4hcsZ" - }, + "metadata": {}, "source": [ "## Setup" ] @@ -51,7 +45,6 @@ "cell_type": "code", "execution_count": 28, "metadata": { - "id": "ecyWhpr9X6tE", "outputId": "cb862d1a-2f71-444f-9770-9f0d53b11389" }, "outputs": [ @@ -72,7 +65,6 @@ "cell_type": "code", "execution_count": 29, "metadata": { - "id": "xVAH-aWN3NzF", "outputId": "80340396-77c2-4654-cc6d-67040f227eb9" }, "outputs": [ @@ -92,9 +84,7 @@ { "cell_type": "code", "execution_count": 30, - "metadata": { - "id": "SwX8bCNEGhJM" - }, + "metadata": {}, "outputs": [], "source": [ "example_directory = 'examples/imagenet'\n", @@ -108,7 +98,6 @@ "execution_count": 31, "metadata": { "cellView": "form", - "id": "o65RonwHp4Y9", "outputId": "9449a7b4-8a5d-4446-abe0-7886435ebd1c" }, "outputs": [ @@ -249,7 +238,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "xcXZ-F3_zBuJ", "outputId": "acc1f45d-5062-4ff3-e6d4-10b4ffe0f8ef" }, "outputs": [], @@ -260,9 +248,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "Tt0rL4ycp4ZB" - }, + "metadata": {}, "source": [ "## Imports / Helpers" ] @@ -272,7 +258,6 @@ "execution_count": 33, "metadata": { "cellView": "form", - "id": "4EzOChfJeVrU", "outputId": "9dc7fb32-331e-44a6-b6e8-830f6a64d845" }, "outputs": [ @@ -303,9 +288,7 @@ { "cell_type": "code", "execution_count": 34, - "metadata": { - "id": "EdzHCJuop4ZB" - }, + "metadata": {}, "outputs": [], "source": [ "import json\n", @@ -324,9 +307,7 @@ { "cell_type": "code", "execution_count": 35, - "metadata": { - "id": "7O2C7AY3p4ZF" - }, + "metadata": {}, "outputs": [], "source": [ "# Helper functions for images.\n", @@ -356,7 +337,6 @@ "cell_type": "code", "execution_count": 36, "metadata": { - "id": "6Y1ru2Ovp4ZI", "outputId": "f943d165-b953-4a70-9f93-96eb857c3d53" }, "outputs": [ @@ -382,9 +362,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "gGi7zcRpp4ZL" - }, + "metadata": {}, "source": [ "## Dataset" ] @@ -393,7 +371,6 @@ "cell_type": "code", "execution_count": 37, "metadata": { - "id": "12KP_4h-_10s", "outputId": "a9b6cfe9-cc1c-451a-f8f7-69356cb7bdd2" }, "outputs": [ @@ -473,9 +450,7 @@ { "cell_type": "code", "execution_count": 38, - "metadata": { - "id": "UnuSCpoYBPKN" - }, + "metadata": {}, "outputs": [], "source": [ "# Utilities to help with Imagenette labels.\n", @@ -514,7 +489,6 @@ "cell_type": "code", "execution_count": 39, "metadata": { - "id": "EBibz3g905qt", "outputId": "78142300-cc8b-4a6c-f781-5ab29578d828" }, "outputs": [ @@ -542,7 +516,6 @@ "cell_type": "code", "execution_count": 40, "metadata": { - "id": "ccF8NVuX1Msk", "outputId": "8b3b9cf2-7649-4953-99bb-32a689fe0a29" }, "outputs": [ @@ -565,9 +538,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "KqW8WP5bp4ZS" - }, + "metadata": {}, "source": [ "## Training from scratch" ] @@ -575,9 +546,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "zzBxSXGGyEfw" - }, + "metadata": {}, "outputs": [], "source": [ "# Get a live update during training - use the \"refresh\" button!\n", @@ -591,7 +560,6 @@ "cell_type": "code", "execution_count": 42, "metadata": { - "id": "GGgHtVhIIuH7", "outputId": "2d0bc789-213d-4a34-a7b1-e7852b40f375" }, "outputs": [ @@ -638,7 +606,6 @@ "cell_type": "code", "execution_count": 43, "metadata": { - "id": "4bGmMCQd6S8U", "outputId": "de56d320-c336-459b-f258-5d6ae41ce0af" }, "outputs": [ @@ -667,7 +634,6 @@ "cell_type": "code", "execution_count": 44, "metadata": { - "id": "OBSJAvUqGgDq", "outputId": "018da2c5-c6f0-42ac-843f-7ac855a6bf14" }, "outputs": [ @@ -760,8 +726,7 @@ "cell_type": "code", "execution_count": 45, "metadata": { - "cellView": "form", - "id": "mZOKD0Y7p4ZW" + "cellView": "form" }, "outputs": [], "source": [ @@ -776,9 +741,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "GBh-2D-Wp4ZY" - }, + "metadata": {}, "source": [ "## Load pre-trained model" ] @@ -787,7 +750,6 @@ "cell_type": "code", "execution_count": 46, "metadata": { - "id": "uKeJJJ5FJksQ", "outputId": "b06fa3d8-a950-46d2-e03e-fc6c971bdbd0" }, "outputs": [ @@ -815,9 +777,7 @@ { "cell_type": "code", "execution_count": 47, - "metadata": { - "id": "UCBikf4GvGuR" - }, + "metadata": {}, "outputs": [], "source": [ "# Load config that was used to train checkpoint.\n", @@ -829,7 +789,6 @@ "cell_type": "code", "execution_count": 48, "metadata": { - "id": "YfA4OnlyKe5x", "outputId": "57777298-4b4b-4a82-b0f2-4b6ff3b949af" }, "outputs": [ @@ -860,9 +819,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "HeMRgkbGiXo9" - }, + "metadata": {}, "source": [ "## Inference" ] @@ -871,7 +828,6 @@ "cell_type": "code", "execution_count": 49, "metadata": { - "id": "i-7r57EkYJtc", "outputId": "793a656b-f3ad-4596-ad4f-44c686e5e885" }, "outputs": [ @@ -895,9 +851,7 @@ { "cell_type": "code", "execution_count": 50, - "metadata": { - "id": "KNTNZZJKYEHF" - }, + "metadata": {}, "outputs": [], "source": [ "# Evaluate using model trained on imagenet.\n", @@ -908,7 +862,6 @@ "cell_type": "code", "execution_count": 51, "metadata": { - "id": "ti55teFObTZW", "outputId": "6ab4bb0b-2c03-4663-d7ac-e51b979d121f" }, "outputs": [ @@ -934,7 +887,6 @@ "cell_type": "code", "execution_count": 52, "metadata": { - "id": "k5bKo731c98H", "outputId": "142c1acf-037e-4ab0-9ca3-bdf0829c51c4" }, "outputs": [ @@ -964,7 +916,6 @@ "cell_type": "code", "execution_count": 53, "metadata": { - "id": "2tEFrztxnh2B", "outputId": "4fae2533-5598-4f2e-c133-50bfba463311" }, "outputs": [ @@ -995,7 +946,6 @@ "cell_type": "code", "execution_count": 54, "metadata": { - "id": "SY3YQbgLgJe1", "outputId": "d01e1993-28ab-4a4a-ac58-01c83b80e6c9" }, "outputs": [ diff --git a/examples/mnist/mnist.ipynb b/examples/mnist/mnist.ipynb index 94b8412fe4..3bc16317ba 100644 --- a/examples/mnist/mnist.ipynb +++ b/examples/mnist/mnist.ipynb @@ -2,9 +2,7 @@ "cells": [ { "cell_type": "markdown", - "metadata": { - "id": "wquAnPg0p4Y8" - }, + "metadata": {}, "source": [ "# Flax MNIST Example\n", "\n", @@ -16,9 +14,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "UuqrLz3he_1M" - }, + "metadata": {}, "source": [ "The **Flax Notebook Workflow**:\n", "\n", @@ -40,9 +36,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "2cMTM3W4hcsZ" - }, + "metadata": {}, "source": [ "## Setup" ] @@ -51,7 +45,6 @@ "cell_type": "code", "execution_count": 1, "metadata": { - "id": "xVAH-aWN3NzF", "outputId": "8520b2f8-2b9d-4216-ba1f-d96175455bbc" }, "outputs": [ @@ -86,7 +79,6 @@ "cell_type": "code", "execution_count": 2, "metadata": { - "id": "SwX8bCNEGhJM", "tags": [] }, "outputs": [], @@ -102,7 +94,6 @@ "execution_count": 3, "metadata": { "cellView": "form", - "id": "o65RonwHp4Y9", "outputId": "2dfbdfa6-d213-4b5b-dc82-ee1765705255" }, "outputs": [ @@ -226,7 +217,6 @@ "cell_type": "code", "execution_count": 4, "metadata": { - "id": "xcXZ-F3_zBuJ", "outputId": "e9061488-ac3e-4d23-f24f-06e1988e7541" }, "outputs": [ @@ -245,9 +235,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "Tt0rL4ycp4ZB" - }, + "metadata": {}, "source": [ "## Imports / Helpers" ] @@ -255,9 +243,7 @@ { "cell_type": "code", "execution_count": 5, - "metadata": { - "id": "EdzHCJuop4ZB" - }, + "metadata": {}, "outputs": [], "source": [ "from absl import logging\n", @@ -273,9 +259,7 @@ { "cell_type": "code", "execution_count": 6, - "metadata": { - "id": "7O2C7AY3p4ZF" - }, + "metadata": {}, "outputs": [], "source": [ "# Helper functions for images.\n", @@ -302,7 +286,6 @@ "cell_type": "code", "execution_count": 7, "metadata": { - "id": "6Y1ru2Ovp4ZI", "tags": [] }, "outputs": [], @@ -318,9 +301,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "gGi7zcRpp4ZL" - }, + "metadata": {}, "source": [ "## Dataset" ] @@ -329,7 +310,6 @@ "cell_type": "code", "execution_count": 8, "metadata": { - "id": "BRg0rNsJp4ZL", "outputId": "bb4525f4-8ca4-4e9d-d1cc-48a3e0533645", "tags": [] }, @@ -413,7 +393,6 @@ "cell_type": "code", "execution_count": 9, "metadata": { - "id": "B0LgjT3Vp4ZP", "outputId": "89de05b0-aede-414f-cf43-5e7c71871140" }, "outputs": [ @@ -439,9 +418,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "KqW8WP5bp4ZS" - }, + "metadata": {}, "source": [ "## Training" ] @@ -449,9 +426,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "zzBxSXGGyEfw" - }, + "metadata": {}, "outputs": [], "source": [ "# Get a live update during training - use the \"refresh\" button!\n", @@ -465,7 +440,6 @@ "cell_type": "code", "execution_count": 11, "metadata": { - "id": "RHoBUKSkp4ZS", "outputId": "a0eb78b5-ee73-4f4f-8400-41b521f42b75", "tags": [] }, @@ -532,7 +506,6 @@ "execution_count": 12, "metadata": { "cellView": "form", - "id": "mZOKD0Y7p4ZW", "tags": [] }, "outputs": [], @@ -548,9 +521,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "GBh-2D-Wp4ZY" - }, + "metadata": {}, "source": [ "## Inference" ] @@ -559,7 +530,6 @@ "cell_type": "code", "execution_count": 13, "metadata": { - "id": "Q-45FkBLp4ZY", "outputId": "3af424f7-4433-475d-817c-5c0bbc4599ae" }, "outputs": [ @@ -587,7 +557,6 @@ "cell_type": "code", "execution_count": 14, "metadata": { - "id": "xIIvyz8Jp4Zb", "outputId": "949487f5-8aa2-45c8-9b54-efbf34ab58f1" }, "outputs": [ diff --git a/examples/ogbg_molpcba/ogbg_molpcba.ipynb b/examples/ogbg_molpcba/ogbg_molpcba.ipynb index c51a4f7f4e..61baa503c0 100644 --- a/examples/ogbg_molpcba/ogbg_molpcba.ipynb +++ b/examples/ogbg_molpcba/ogbg_molpcba.ipynb @@ -2,9 +2,7 @@ "cells": [ { "cell_type": "markdown", - "metadata": { - "id": "81wUkzl5gCUr" - }, + "metadata": {}, "source": [ "# Flax ogbg-molpcba Example\n", "\n", @@ -16,9 +14,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "hfbxr1U9eciL" - }, + "metadata": {}, "source": [ "## Setup" ] @@ -27,7 +23,6 @@ "cell_type": "code", "execution_count": 1, "metadata": { - "id": "cKmRTXhHdm_U", "outputId": "6508ab2f-b0e5-4693-f6a0-7bc495ec1344" }, "outputs": [ @@ -61,9 +56,7 @@ { "cell_type": "code", "execution_count": 2, - "metadata": { - "id": "bdI9miDfEk9Y" - }, + "metadata": {}, "outputs": [], "source": [ "example_directory = 'examples/ogbg_molpcba'\n", @@ -77,7 +70,6 @@ "execution_count": 3, "metadata": { "cellView": "form", - "id": "bCKbiylLgURG", "outputId": "8261a349-b41e-4e1b-a2ca-8d23412155be" }, "outputs": [ @@ -231,7 +223,6 @@ "cell_type": "code", "execution_count": 4, "metadata": { - "id": "ifRtigyGgZYk", "outputId": "14b17380-5077-4354-e651-027f3d933cfe" }, "outputs": [ @@ -251,9 +242,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "hvyohKMtelLG" - }, + "metadata": {}, "source": [ "## Imports" ] @@ -261,9 +250,7 @@ { "cell_type": "code", "execution_count": 5, - "metadata": { - "id": "bh18UXRVerEz" - }, + "metadata": {}, "outputs": [], "source": [ "# Base imports\n", @@ -281,9 +268,7 @@ { "cell_type": "code", "execution_count": 6, - "metadata": { - "id": "r-T0M1okfkrA" - }, + "metadata": {}, "outputs": [], "source": [ "# Local imports from current directory - auto reload.\n", @@ -298,18 +283,14 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "NFQxAbTWerTQ" - }, + "metadata": {}, "source": [ "## Dataset" ] }, { "cell_type": "markdown", - "metadata": { - "id": "K6Soh8gkYKQB" - }, + "metadata": {}, "source": [ "TensorFlow Datasets supports customizable visualization of the ogbg_molpcba dataset." ] @@ -317,9 +298,7 @@ { "cell_type": "code", "execution_count": 7, - "metadata": { - "id": "0ohDGSB_ZC0q" - }, + "metadata": {}, "outputs": [], "source": [ "# Visualization helpers\n", @@ -364,7 +343,6 @@ "cell_type": "code", "execution_count": 8, "metadata": { - "id": "U5jNKcD3YFsO", "outputId": "d9336190-e685-43e8-e3f1-73e3f1ce1cd2" }, "outputs": [ @@ -594,9 +572,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "CZJz7xoKevYn" - }, + "metadata": {}, "source": [ "## Training" ] @@ -604,9 +580,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "Bha8sF3ne0mg" - }, + "metadata": {}, "outputs": [], "source": [ "# Start TensorBoard\n", @@ -621,7 +595,6 @@ "cell_type": "code", "execution_count": 10, "metadata": { - "id": "iVA3nVAth5Wh", "outputId": "7696aae4-beb5-4df2-b72a-7f391ac30c2e" }, "outputs": [ @@ -750,8 +723,7 @@ "cell_type": "code", "execution_count": 11, "metadata": { - "cellView": "form", - "id": "L9BT3cGoiMNo" + "cellView": "form" }, "outputs": [], "source": [ @@ -767,9 +739,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "541fzDOQeyA0" - }, + "metadata": {}, "source": [ "## Inference" ] @@ -777,9 +747,7 @@ { "cell_type": "code", "execution_count": 12, - "metadata": { - "id": "HMxkT1Lge1h5" - }, + "metadata": {}, "outputs": [], "source": [ "# Create deterministic evaluation model.\n", @@ -791,7 +759,6 @@ "cell_type": "code", "execution_count": 13, "metadata": { - "id": "NyDT6Ayp_s-G", "outputId": "51242dc2-5260-4279-b41a-9d7614b33c97" }, "outputs": [ @@ -822,7 +789,6 @@ "cell_type": "code", "execution_count": 14, "metadata": { - "id": "dhQSaZ5Z2sd6", "outputId": "2a529c9f-fed6-4221-a738-a8ffd9049d7e" }, "outputs": [ @@ -852,9 +818,7 @@ { "cell_type": "code", "execution_count": 15, - "metadata": { - "id": "7AcUjIPN7pE8" - }, + "metadata": {}, "outputs": [], "source": [ "# Helper functions for formatting labels and predictions.\n", @@ -882,9 +846,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "JNQnsIFQtf92" - }, + "metadata": {}, "source": [ "We can choose one of the 128 different tasks and see how the model predictions\n", "match up with the true labels.\n", @@ -897,9 +859,7 @@ { "cell_type": "code", "execution_count": 16, - "metadata": { - "id": "U-lXmVI4LBPE" - }, + "metadata": {}, "outputs": [], "source": [ "# Define which task to plot labels for.\n", @@ -910,7 +870,6 @@ "cell_type": "code", "execution_count": 17, "metadata": { - "id": "hDKR3yVIwOm3", "outputId": "38934096-13a6-4701-823a-ac83f3b7eaac" }, "outputs": [ diff --git a/examples/seq2seq/seq2seq.ipynb b/examples/seq2seq/seq2seq.ipynb index 6ecce14748..5cea631200 100644 --- a/examples/seq2seq/seq2seq.ipynb +++ b/examples/seq2seq/seq2seq.ipynb @@ -2,9 +2,7 @@ "cells": [ { "cell_type": "markdown", - "metadata": { - "id": "wquAnPg0p4Y8" - }, + "metadata": {}, "source": [ "# Flax seq2seq Example\n", "\n", @@ -16,9 +14,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "UuqrLz3he_1M" - }, + "metadata": {}, "source": [ "The **Flax Notebook Workflow**:\n", "\n", @@ -40,9 +36,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "2cMTM3W4hcsZ" - }, + "metadata": {}, "source": [ "## Setup" ] @@ -51,7 +45,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "xVAH-aWN3NzF", "outputId": "4c0a705c-8d7e-44cc-d851-873a40ac115e" }, "outputs": [ @@ -78,7 +71,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "SwX8bCNEGhJM", "tags": [] }, "outputs": [], @@ -94,7 +86,6 @@ "execution_count": null, "metadata": { "cellView": "form", - "id": "o65RonwHp4Y9", "outputId": "4801432e-4090-4b13-f0f2-d99a3039ce47" }, "outputs": [ @@ -230,7 +221,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "xcXZ-F3_zBuJ", "outputId": "a292a7a2-ae3c-4518-af28-9c2fa0ed2d7b" }, "outputs": [ @@ -249,9 +239,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "Tt0rL4ycp4ZB" - }, + "metadata": {}, "source": [ "## Imports" ] @@ -259,9 +247,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "EdzHCJuop4ZB" - }, + "metadata": {}, "outputs": [], "source": [ "from absl import app\n", @@ -277,7 +263,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "6Y1ru2Ovp4ZI", "outputId": "7e1a29ce-9d8b-4715-ce60-9eae100a1df3", "tags": [] }, @@ -303,9 +288,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "gGi7zcRpp4ZL" - }, + "metadata": {}, "source": [ "## Dataset" ] @@ -314,7 +297,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "xce4axo5Y9xp", "outputId": "cb5f7f6e-1e6f-40ff-e0d6-5b428511d75b" }, "outputs": [ @@ -343,7 +325,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "k_ZD70nIYlEq", "outputId": "b58ea813-e757-4cc5-f3ba-3cb0f05d35a6" }, "outputs": [ @@ -376,7 +357,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "UF19Nr2zZRQo", "outputId": "3b33e061-f0b5-42d7-ad49-5058e8fd3b90" }, "outputs": [ @@ -398,9 +378,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "KqW8WP5bp4ZS" - }, + "metadata": {}, "source": [ "## Training" ] @@ -408,9 +386,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "zzBxSXGGyEfw" - }, + "metadata": {}, "outputs": [], "source": [ "# Get a live update during training - use the \"refresh\" button!\n", @@ -423,9 +399,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "LR9apE1dcFy0" - }, + "metadata": {}, "outputs": [], "source": [ "import time\n", @@ -436,7 +410,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "HgjiCPuAbZ5m", "outputId": "e49554e2-9336-4d97-a1e2-82b9e98407da" }, "outputs": [ @@ -464,7 +437,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "az3CUuNacBkS", "outputId": "49396889-35b0-4a11-8b8a-e67624be32a7" }, "outputs": [ @@ -598,7 +570,6 @@ "execution_count": null, "metadata": { "cellView": "form", - "id": "mZOKD0Y7p4ZW", "outputId": "2beaf4e9-b10b-4156-d2d9-187777306de0", "tags": [] }, @@ -655,9 +626,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "GBh-2D-Wp4ZY" - }, + "metadata": {}, "source": [ "## Inference" ] @@ -666,7 +635,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "hwi0ylrOgVKT", "outputId": "e22b7208-5413-4a63-abfb-b510af60f340" }, "outputs": [ @@ -690,9 +658,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "hNRtka4Ng61k" - }, + "metadata": {}, "outputs": [], "source": [ "# Using different random seeds generates different samples.\n", @@ -703,7 +669,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "2LWKWLyohTt8", "outputId": "e5cdfd75-2c66-4165-8ab7-9fdecde5062a" }, "outputs": [ diff --git a/examples/sst2/sst2.ipynb b/examples/sst2/sst2.ipynb index 10d05c98f4..1a3d0e1bc2 100644 --- a/examples/sst2/sst2.ipynb +++ b/examples/sst2/sst2.ipynb @@ -2,9 +2,7 @@ "cells": [ { "cell_type": "markdown", - "metadata": { - "id": "wquAnPg0p4Y8" - }, + "metadata": {}, "source": [ "# Flax SST-2 Example\n", "\n", @@ -16,18 +14,14 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "qYA_sBsVn1SY" - }, + "metadata": {}, "source": [ "**Before you start:** Select Runtime -> Change runtime type -> GPU." ] }, { "cell_type": "markdown", - "metadata": { - "id": "UuqrLz3he_1M" - }, + "metadata": {}, "source": [ "The **Flax Notebook Workflow**:\n", "\n", @@ -49,9 +43,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "2cMTM3W4hcsZ" - }, + "metadata": {}, "source": [ "## Setup" ] @@ -59,9 +51,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "SwX8bCNEGhJM" - }, + "metadata": {}, "outputs": [], "source": [ "example_directory = 'examples/sst2'\n", @@ -71,9 +61,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "o65RonwHp4Y9" - }, + "metadata": {}, "outputs": [], "source": [ "# (If you run this code in Jupyter[lab], then you're already in the\n", @@ -126,9 +114,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "xcXZ-F3_zBuJ" - }, + "metadata": {}, "outputs": [], "source": [ "# Note: In Colab, above cell changed the working directory.\n", @@ -138,9 +124,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "qgUlFbSy_9q_" - }, + "metadata": {}, "outputs": [], "source": [ "# Install SST-2 dependencies.\n", @@ -149,9 +133,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "Tt0rL4ycp4ZB" - }, + "metadata": {}, "source": [ "## Imports / Helpers" ] @@ -159,9 +141,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "08kWwdKZYZtG" - }, + "metadata": {}, "outputs": [], "source": [ "# If you want to use TPU instead of GPU, you need to run this to make it work.\n", @@ -178,9 +158,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "EdzHCJuop4ZB" - }, + "metadata": {}, "outputs": [], "source": [ "from absl import logging\n", @@ -200,7 +178,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "6Y1ru2Ovp4ZI", "tags": [] }, "outputs": [], @@ -219,9 +196,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "gGi7zcRpp4ZL" - }, + "metadata": {}, "source": [ "## Dataset" ] @@ -230,7 +205,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "BRg0rNsJp4ZL", "tags": [] }, "outputs": [], @@ -243,9 +217,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "KqW8WP5bp4ZS" - }, + "metadata": {}, "source": [ "## Training" ] @@ -253,9 +225,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "zzBxSXGGyEfw" - }, + "metadata": {}, "outputs": [], "source": [ "# Get a live update during training - use the \"refresh\" button!\n", @@ -269,7 +239,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "RHoBUKSkp4ZS", "tags": [] }, "outputs": [], @@ -286,7 +255,6 @@ "execution_count": null, "metadata": { "cellView": "form", - "id": "mZOKD0Y7p4ZW", "tags": [] }, "outputs": [], diff --git a/flax/core/flax_functional_engine.ipynb b/flax/core/flax_functional_engine.ipynb index 8949e0dca5..327c92ad63 100644 --- a/flax/core/flax_functional_engine.ipynb +++ b/flax/core/flax_functional_engine.ipynb @@ -4,8 +4,7 @@ "cell_type": "code", "execution_count": 1, "metadata": { - "colab_type": "code", - "id": "x0SPwYS9dtYA" + "colab_type": "code" }, "outputs": [], "source": [ @@ -19,8 +18,7 @@ "cell_type": "code", "execution_count": 2, "metadata": { - "colab_type": "code", - "id": "7n9cxyCzluvI" + "colab_type": "code" }, "outputs": [], "source": [ @@ -31,8 +29,7 @@ "cell_type": "code", "execution_count": 3, "metadata": { - "colab_type": "code", - "id": "0L7YCrobkfzU" + "colab_type": "code" }, "outputs": [], "source": [ @@ -44,12 +41,6 @@ "execution_count": 4, "metadata": { "colab_type": "code", - "executionInfo": { - "elapsed": 1116, - "status": "ok", - "timestamp": 1590673431275 - }, - "id": "aDLGb3iGkjoL", "outputId": "2558605e-e485-407e-b062-74d31cc49f1e", "tags": [] }, @@ -114,12 +105,6 @@ "execution_count": 5, "metadata": { "colab_type": "code", - "executionInfo": { - "elapsed": 526, - "status": "ok", - "timestamp": 1590672865722 - }, - "id": "LTFjZbRmlqZh", "outputId": "5790b763-df4f-47c8-9f4e-53fd1e1eb1fd", "tags": [] }, @@ -178,12 +163,6 @@ "execution_count": 6, "metadata": { "colab_type": "code", - "executionInfo": { - "elapsed": 342, - "status": "ok", - "timestamp": 1590673618925 - }, - "id": "TMlae0hem0u5", "outputId": "dd9c5079-10e7-4944-e09a-e9f65573a733" }, "outputs": [ diff --git a/tests/colab_tpu_jax_version.ipynb b/tests/colab_tpu_jax_version.ipynb index 77a65bc812..7f55df00d9 100644 --- a/tests/colab_tpu_jax_version.ipynb +++ b/tests/colab_tpu_jax_version.ipynb @@ -3,9 +3,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "6RTgYOrq2Mbp" - }, + "metadata": {}, "outputs": [], "source": [ "# JAX/jaxlib should be both 0.3.25\n", @@ -17,9 +15,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "0TI6hM1oU-y9" - }, + "metadata": {}, "outputs": [], "source": [ "# should show 8 TPU devices\n", @@ -42,9 +38,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "xBlI1bhd0QzN" - }, + "metadata": {}, "outputs": [], "source": [ "# in case JAX version has changed after the '!pip install`, below command should\n", @@ -56,9 +50,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "Xf--m2iHVUgh" - }, + "metadata": {}, "outputs": [], "source": [ "# it's possible to get dependency tree without installing packages, but this\n",