diff --git a/.github/workflows/flax_test.yml b/.github/workflows/flax_test.yml index 1420f39a..4bed8d81 100644 --- a/.github/workflows/flax_test.yml +++ b/.github/workflows/flax_test.yml @@ -75,7 +75,7 @@ jobs: uv-version: "0.3.0" - name: Install standalone dependencies only run: | - uv sync --locked --extra all + uv sync --extra all - name: Test importing Flax run: | uv run python -c "import flax" @@ -108,13 +108,10 @@ jobs: uses: astral-sh/setup-uv@v2 with: version: "0.3.0" - - name: Setup Rust (flaxlib) - uses: actions-rust-lang/setup-rust-toolchain@v1 - name: Install dependencies run: | - uv sync --locked --extra all --extra testing --extra docs - uv pip install ./flaxlib + uv sync --extra all --extra testing --extra docs - name: Install JAX run: | if [[ "${{ matrix.jax-version }}" == "newest" ]]; then diff --git a/.gitignore b/.gitignore index 2d436c71..0bc7f3cb 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,10 @@ build/ docs*/**/_autosummary docs*/_build docs*/**/tmp +flaxlib_src/build +flaxlib_src/builddir +flaxlib_src/dist +flaxlib_src/subprojects # used by direnv .envrc diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 776f5c3d..93f25601 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -23,7 +23,7 @@ repos: hooks: - id: check-toml - id: trailing-whitespace - exclude: ^docs*/.*\.md$ + exclude: ^docs.*\.md$ - repo: https://github.com/kynan/nbstripout rev: 0.6.1 hooks: diff --git a/CHANGELOG.md b/CHANGELOG.md index 6a298007..7f6aeccf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,11 +8,11 @@ vNext - - - +- removed GeGLU simplistic activation, it should be implemented manually. - - - -- -- +- removed FLAX_LAZY_RNG flag support for old non-lazy PRNG derivation mode - - - diff --git a/README.md b/README.md index 92073c71..63b5592f 100644 --- a/README.md +++ b/README.md @@ -6,39 +6,40 @@ ![Build](https://github.com/google/flax/workflows/Build/badge.svg?branch=main) [![coverage](https://badgen.net/codecov/c/gh/google/flax)](https://codecov.io/gh/google/flax) - [**Overview**](#overview) | [**Quick install**](#quick-install) | [**What does Flax look like?**](#what-does-flax-look-like) | [**Documentation**](https://flax.readthedocs.io/) -**📣 NEW**: Check out the [**NNX**](https://flax.readthedocs.io/en/latest/nnx/index.html) API! +Released in 2024, Flax NNX is a new simplified Flax API that is designed to make +it easier to create, inspect, debug, and analyze neural networks in +[JAX](https://jax.readthedocs.io/). It achieves this by adding first class support +for Python reference semantics. This allows users to express their models using +regular Python objects, enabling reference sharing and mutability. + +Flax NNX evolved from the [Flax Linen API](https://flax-linen.readthedocs.io/), which +was released in 2020 by engineers and researchers at Google Brain in close collaboration +with the JAX team. -This README is a very short intro. **To learn everything you need to know about Flax, refer to our [full documentation](https://flax.readthedocs.io/).** +You can learn more about Flax NNX on the [dedicated Flax documentation site](https://flax.readthedocs.io/). Make sure you check out: -Flax was originally started by engineers and researchers within the Brain Team in Google Research (in close collaboration with the JAX team), and is now developed jointly with the open source community. 
+* [Flax NNX basics](https://flax.readthedocs.io/en/latest/nnx_basics.html) +* [MNIST tutorial](https://flax.readthedocs.io/en/latest/mnist_tutorial.html) +* [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html) +* [Evolution from Flax Linen to Flax NNX](https://flax.readthedocs.io/en/latest/guides/linen_to_nnx.html) -Flax is being used by a growing -community of hundreds of folks in various Alphabet research departments -for their daily work, as well as a [growing community -of open source -projects](https://github.com/google/flax/network/dependents?dependent_type=REPOSITORY). +**Note:** Flax Linen's [documentation has its own site](https://flax-linen.readthedocs.io/). The Flax team's mission is to serve the growing JAX neural network -research ecosystem -- both within Alphabet and with the broader community, +research ecosystem - both within Alphabet and with the broader community, and to explore the use-cases where JAX shines. We use GitHub for almost all of our coordination and planning, as well as where we discuss upcoming design changes. We welcome feedback on any of our discussion, -issue and pull request threads. We are in the process of moving some -remaining internal design docs and conversation threads to GitHub -discussions, issues and pull requests. We hope to increasingly engage -with the needs and clarifications of the broader ecosystem. Please let -us know how we can help! +issue and pull request threads. -Please report any feature requests, -issues, questions or concerns in our [discussion -forum](https://github.com/google/flax/discussions), or just let us -know what you're working on! +You can make feature requests, let us know what you are working on, +report issues, ask questions in our [Flax GitHub discussion +forum](https://github.com/google/flax/discussions). We expect to improve Flax, but we don't anticipate significant breaking changes to the core API. We use [Changelog](https://github.com/google/flax/tree/main/CHANGELOG.md) @@ -51,31 +52,22 @@ In case you want to reach us directly, we're at flax-dev@google.com. Flax is a high-performance neural network library and ecosystem for JAX that is **designed for flexibility**: Try new forms of training by forking an example and by modifying the training -loop, not by adding features to a framework. +loop, not adding features to a framework. 
Flax is being developed in close collaboration with the JAX team and comes with everything you need to start your research, including: -* **Neural network API** (`flax.linen`): Dense, Conv, {Batch|Layer|Group} Norm, Attention, Pooling, {LSTM|GRU} Cell, Dropout - -* **Utilities and patterns**: replicated training, serialization and checkpointing, metrics, prefetching on device +* **Neural network API** (`flax.nnx`): Including [`Linear`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/linear.html#flax.nnx.Linear), [`Conv`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/linear.html#flax.nnx.Conv), [`BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm), [`LayerNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.LayerNorm), [`GroupNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.GroupNorm), [Attention](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/attention.html) ([`MultiHeadAttention`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/attention.html#flax.nnx.MultiHeadAttention)), [`LSTMCell`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/recurrent.html#flax.nnx.nn.recurrent.LSTMCell), [`GRUCell`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/recurrent.html#flax.nnx.nn.recurrent.GRUCell), [`Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout). -* **Educational examples** that work out of the box: MNIST, LSTM seq2seq, Graph Neural Networks, Sequence Tagging +* **Utilities and patterns**: replicated training, serialization and checkpointing, metrics, prefetching on device. -* **Fast, tuned large-scale end-to-end examples**: CIFAR10, ResNet on ImageNet, Transformer LM1b +* **Educational examples**: [MNIST](https://flax.readthedocs.io/en/latest/mnist_tutorial.html), [Inference/sampling with the Gemma language model (transformer)](https://github.com/google/flax/tree/main/examples/gemma), [Transformer LM1B](https://github.com/google/flax/tree/main/examples/lm1b_nnx). ## Quick install -You will need Python 3.6 or later, and a working [JAX](https://github.com/google/jax/blob/main/README.md) -installation (with or without GPU support - refer to [the instructions](https://github.com/google/jax/blob/main/README.md)). -For a CPU-only version of JAX: - -``` -pip install --upgrade pip # To support manylinux2010 wheels. -pip install --upgrade jax jaxlib # CPU-only -``` +Flax uses JAX, so do check out [JAX installation instructions on CPUs, GPUs and TPUs](https://jax.readthedocs.io/en/latest/installation.html). -Then, install Flax from PyPi: +You will need Python 3.8 or later. Install Flax from PyPi: ``` pip install flax @@ -86,6 +78,7 @@ To upgrade to the latest version of Flax, you can use: ``` pip install --upgrade git+https://github.com/google/flax.git ``` + To install some additional dependencies (like `matplotlib`) that are required but not included by some dependencies, you can use: @@ -101,95 +94,60 @@ To learn more about the `Module` abstraction, check out our [docs](https://flax. [guides](https://flax.readthedocs.io/en/latest/guides/index.html) and [developer notes](https://flax.readthedocs.io/en/latest/developer_notes/index.html). 
+Example of an MLP: + ```py -from typing import Sequence +class MLP(nnx.Module): + def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs): + self.linear1 = Linear(din, dmid, rngs=rngs) + self.dropout = nnx.Dropout(rate=0.1, rngs=rngs) + self.bn = nnx.BatchNorm(dmid, rngs=rngs) + self.linear2 = Linear(dmid, dout, rngs=rngs) + + def __call__(self, x: jax.Array): + x = nnx.gelu(self.dropout(self.bn(self.linear1(x)))) + return self.linear2(x) +``` -import numpy as np -import jax -import jax.numpy as jnp -import flax.linen as nn +Example of a CNN: -class MLP(nn.Module): - features: Sequence[int] +```py +class CNN(nnx.Module): + def __init__(self, *, rngs: nnx.Rngs): + self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs) + self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs) + self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2)) + self.linear1 = nnx.Linear(3136, 256, rngs=rngs) + self.linear2 = nnx.Linear(256, 10, rngs=rngs) - @nn.compact def __call__(self, x): - for feat in self.features[:-1]: - x = nn.relu(nn.Dense(feat)(x)) - x = nn.Dense(self.features[-1])(x) + x = self.avg_pool(nnx.relu(self.conv1(x))) + x = self.avg_pool(nnx.relu(self.conv2(x))) + x = x.reshape(x.shape[0], -1) # flatten + x = nnx.relu(self.linear1(x)) + x = self.linear2(x) return x - -model = MLP([12, 8, 4]) -batch = jnp.ones((32, 10)) -variables = model.init(jax.random.key(0), batch) -output = model.apply(variables, batch) ``` -```py -class CNN(nn.Module): - @nn.compact - def __call__(self, x): - x = nn.Conv(features=32, kernel_size=(3, 3))(x) - x = nn.relu(x) - x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) - x = nn.Conv(features=64, kernel_size=(3, 3))(x) - x = nn.relu(x) - x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) - x = x.reshape((x.shape[0], -1)) # flatten - x = nn.Dense(features=256)(x) - x = nn.relu(x) - x = nn.Dense(features=10)(x) - x = nn.log_softmax(x) - return x +Example of an autoencoder: -model = CNN() -batch = jnp.ones((32, 64, 64, 10)) # (N, H, W, C) format -variables = model.init(jax.random.key(0), batch) -output = model.apply(variables, batch) -``` ```py -class AutoEncoder(nn.Module): - encoder_widths: Sequence[int] - decoder_widths: Sequence[int] - input_shape: Sequence[int] - - def setup(self): - input_dim = np.prod(self.input_shape) - self.encoder = MLP(self.encoder_widths) - self.decoder = MLP(self.decoder_widths + (input_dim,)) +Encoder = lambda rngs: nnx.Linear(2, 10, rngs=rngs) +Decoder = lambda rngs: nnx.Linear(10, 2, rngs=rngs) - def __call__(self, x): - return self.decode(self.encode(x)) +class AutoEncoder(nnx.Module): + def __init__(self, rngs): + self.encoder = Encoder(rngs) + self.decoder = Decoder(rngs) - def encode(self, x): - assert x.shape[1:] == self.input_shape - return self.encoder(jnp.reshape(x, (x.shape[0], -1))) + def __call__(self, x) -> jax.Array: + return self.decoder(self.encoder(x)) - def decode(self, z): - z = self.decoder(z) - x = nn.sigmoid(z) - x = jnp.reshape(x, (x.shape[0],) + self.input_shape) - return x - -model = AutoEncoder(encoder_widths=[20, 10, 5], - decoder_widths=[5, 10, 20], - input_shape=(12,)) -batch = jnp.ones((16, 12)) -variables = model.init(jax.random.key(0), batch) -encoded = model.apply(variables, batch, method=model.encode) -decoded = model.apply(variables, encoded, method=model.decode) + def encode(self, x) -> jax.Array: + return self.encoder(x) ``` -## 🤗 Hugging Face - -In-detail examples to train and evaluate a variety of Flax models for -Natural Language Processing, Computer 
Vision, and Speech Recognition are -actively maintained in the [🤗 Transformers repository](https://github.com/huggingface/transformers/tree/main/examples/flax). - -As of October 2021, the [19 most-used Transformer architectures](https://huggingface.co/transformers/#supported-frameworks) are supported in Flax -and over 5000 pretrained checkpoints in Flax have been uploaded to the [🤗 Hub](https://huggingface.co/models?library=jax&sort=downloads). - ## Citing Flax To cite this repository: @@ -199,7 +157,7 @@ To cite this repository: author = {Jonathan Heek and Anselm Levskaya and Avital Oliver and Marvin Ritter and Bertrand Rondepierre and Andreas Steiner and Marc van {Z}ee}, title = {{F}lax: A neural network library and ecosystem for {JAX}}, url = {http://github.com/google/flax}, - version = {0.9.0}, + version = {0.10.2}, year = {2024}, } ``` @@ -209,4 +167,4 @@ is intended to be that from [flax/version.py](https://github.com/google/flax/blo ## Note -Flax is an open source project maintained by a dedicated team in Google Research, but is not an official Google product. +Flax is an open source project maintained by a dedicated team at Google DeepMind, but is not an official Google product. diff --git a/benchmarks/nnx_graph_overhead.py b/benchmarks/nnx_graph_overhead.py new file mode 100644 index 00000000..73cff6d6 --- /dev/null +++ b/benchmarks/nnx_graph_overhead.py @@ -0,0 +1,118 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# %% +import jax +import jax.numpy as jnp +import numpy as np +import optax +from time import time + +from flax import nnx + +from absl import flags +from absl import app + +FLAGS = flags.FLAGS +flags.DEFINE_enum('mode', 'all', ['all', 'nnx', 'jax'], 'Mode to run the script in') +flags.DEFINE_integer('total_steps', 100, 'Total number of training steps') +flags.DEFINE_integer('width', 32, 'Hidden layer size') +flags.DEFINE_integer('depth', 5, 'Depth of the model') + + + +class Linear(nnx.Module): + def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): + self.list = [ + nnx.Param(jax.random.uniform(rngs.params(), (din, dout))), + nnx.Param(jnp.zeros((dout,))), + ] + self.dict = { + 'w': nnx.Param(jax.random.uniform(rngs.params(), (din, dout))), + 'b': nnx.Param(jnp.zeros((dout,))), + } + + + +class MLP(nnx.Module): + def __init__(self, depth, *, rngs: nnx.Rngs): + self.intermediates = [ + Linear(10, 10, rngs=rngs) for _ in range(depth) + ] + + +def main(argv): + print(argv) + mode: str = FLAGS.mode + total_steps: int = FLAGS.total_steps + width: int = FLAGS.width + depth: int = FLAGS.depth + + print(f'{mode=}, {total_steps=}, {width=}') + + X = np.linspace(0, 1, 100)[:, None] + Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) + + model = MLP(depth=depth, rngs=nnx.Rngs(0)) + tx = optax.sgd(1e-3) + optimizer = nnx.Optimizer(model, tx) + + #------------------------------------------------------------ + # NNX + #------------------------------------------------------------ + if mode in ['all', 'nnx']: + @nnx.jit + def step_nnx(model: MLP, optimizer: nnx.Optimizer): + pass + + t0 = time() + for _ in range(total_steps): + step_nnx(model, optimizer) + + total_time = time() - t0 + time_per_step = total_time / total_steps + time_per_layer = time_per_step / depth + print("### NNX ###") + print('total time:', total_time) + print(f'time per step: {time_per_step * 1e6:.2f} µs') + print(f'time per layer: {time_per_layer * 1e6:.2f} µs') + + + #------------------------------------------------------------ + # JAX + #------------------------------------------------------------ + + if mode in ['all', 'jax']: + @jax.jit + def step_jax(graphdef, state): + return graphdef, state + + graphdef, state = nnx.split((model, optimizer)) + t0 = time() + for _ in range(total_steps): + graphdef, state = step_jax(graphdef, state) + + total_time = time() - t0 + time_per_step = total_time / total_steps + time_per_layer = time_per_step / depth + print("### JAX ###") + print('total time:', total_time) + print(f'time per step: {time_per_step * 1e6:.2f} µs') + print(f'time per layer: {time_per_layer * 1e6:.2f} µs') + print() + + + +if __name__ == '__main__': + app.run(main) diff --git a/benchmarks/nnx_simple_training.py b/benchmarks/nnx_simple_training.py new file mode 100644 index 00000000..0cb08066 --- /dev/null +++ b/benchmarks/nnx_simple_training.py @@ -0,0 +1,168 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# %% +import jax +import jax.numpy as jnp +import numpy as np +import optax +from time import time + +from flax import nnx + +from absl import flags +from absl import app + +FLAGS = flags.FLAGS +flags.DEFINE_enum('mode', 'nnx', ['nnx', 'jax'], 'Mode to run the script in') +flags.DEFINE_integer('total_steps', 10_000, 'Total number of training steps') +flags.DEFINE_integer('batch_size', 32, 'Batch size') +flags.DEFINE_integer('width', 32, 'Hidden layer size') +flags.DEFINE_integer('depth', 5, 'Depth of the model') + + +def dataset(X, Y, batch_size): + while True: + idx = np.random.choice(len(X), size=batch_size) + yield X[idx], Y[idx] + + +class Linear(nnx.Module): + def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): + self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout))) + self.b = nnx.Param(jnp.zeros((dout,))) + + def __call__(self, x): + return x @ self.w + self.b + + +class Count(nnx.Variable): + pass + + +class MLP(nnx.Module): + def __init__(self, din, dhidden, dout, depth, *, rngs: nnx.Rngs): + self.count = Count(jnp.array(0)) + self.linear_in = Linear(din, dhidden, rngs=rngs) + self.intermediates = [ + Linear(dhidden, dhidden, rngs=rngs) for _ in range(depth - 2) + ] + self.linear_out = Linear(dhidden, dout, rngs=rngs) + + def __call__(self, x): + self.count.value += 1 + x = nnx.relu(self.linear_in(x)) + for layer in self.intermediates: + x = nnx.relu(layer(x)) + x = self.linear_out(x) + return x + + +def main(argv): + print(argv) + mode: str = FLAGS.mode + total_steps: int = FLAGS.total_steps + batch_size: int = FLAGS.batch_size + width: int = FLAGS.width + depth: int = FLAGS.depth + + print(f'{mode=}, {total_steps=}, {batch_size=}, {width=}') + + if mode not in ['nnx', 'jax']: + raise ValueError(f'Invalid mode: {mode}') + + X = np.linspace(0, 1, 100)[:, None] + Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) + + model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0)) + tx = optax.sgd(1e-3) + optimizer = nnx.Optimizer(model, tx) + t0 = time() + + if mode == 'nnx': + + @nnx.jit + def train_step_nnx(model: MLP, optimizer: nnx.Optimizer, batch): + x, y = batch + + def loss_fn(model: MLP): + y_pred = model(x) + return jnp.mean((y - y_pred) ** 2) + + grads: nnx.State = nnx.grad(loss_fn)(model) + optimizer.update(grads) + + @nnx.jit + def test_step_nnx(model: MLP, batch): + x, y = batch + y_pred = model(x) + loss = jnp.mean((y - y_pred) ** 2) + return {'loss': loss} + + for step, batch in enumerate(dataset(X, Y, batch_size)): + train_step_nnx(model, optimizer, batch) + + if step % 1000 == 0: + logs = test_step_nnx(model, (X, Y)) + print(f"step: {step}, loss: {logs['loss']}") + + if step >= total_steps - 1: + break + else: + + @jax.jit + def train_step_jax(graphdef, state, batch): + model, optimizer = nnx.merge(graphdef, state) + x, y = batch + + def loss_fn(model: MLP): + y_pred = model(x) + return jnp.mean((y - y_pred) ** 2) + + grads = nnx.grad(loss_fn)(model) + optimizer.update(grads) + + return nnx.state((model, optimizer)) + + @jax.jit + def test_step_jax(graphdef, state, batch): + model, optimizer = nnx.merge(graphdef, state) + x, y = batch + y_pred = model(x) + loss = jnp.mean((y - y_pred) ** 2) + state = nnx.state((model, optimizer)) + return state, {'loss': loss} + + graphdef, state = nnx.split((model, optimizer)) + + for step, batch in enumerate(dataset(X, Y, batch_size)): + state = train_step_jax(graphdef, state, batch) + + if step % 1000 == 0: + state, logs = test_step_jax(graphdef, state, (X, Y)) + print(f"step: {step}, 
loss: {logs['loss']}") + + if step >= total_steps - 1: + break + + model, optimizer = nnx.merge(graphdef, state) + + total_time = time() - t0 + print('total time:', total_time) + print(f'time per step: {total_time / total_steps * 1e6:.2f} µs') + print('times called:', model.count.value) + + +if __name__ == '__main__': + app.run(main) diff --git a/dev/.devcontainer/Dockerfile b/dev/.devcontainer/Dockerfile index 11fb57e0..6f48300f 100644 --- a/dev/.devcontainer/Dockerfile +++ b/dev/.devcontainer/Dockerfile @@ -3,7 +3,7 @@ # Licensed under the MIT License. See https://go.microsoft.com/fwlink/?linkid=2090316 for license information. #------------------------------------------------------------------------------------------------------------- -FROM python:3.7 +FROM python:3.12 # Avoid warnings by switching to noninteractive ENV DEBIAN_FRONTEND=noninteractive @@ -52,7 +52,7 @@ RUN apt-get update \ RUN pip install numpy jaxlib tensorflow tensorflow-datasets matplotlib msgpack \ jupyter pytest pytest-xdist \ twine \ - sphinx sphinx_rtd_theme ipykernel nbsphinx recommonmark sklearn + sphinx sphinx_rtd_theme ipykernel nbsphinx recommonmark scikit-learn # Switch back to dialog for any ad-hoc use of apt-get ENV DEBIAN_FRONTEND=dialog diff --git a/dev/.devcontainer/devcontainer.json b/dev/.devcontainer/devcontainer.json index 32b87efa..2b6cde28 100644 --- a/dev/.devcontainer/devcontainer.json +++ b/dev/.devcontainer/devcontainer.json @@ -22,7 +22,7 @@ // Uncomment the next line to run commands after the container is created. // "postCreateCommand": "pip install -r requirements.txt", "runArgs": ["-v", "${env:HOME}${env:USERPROFILE}/.ssh:/root/.ssh-localhost:ro"], - "postCreateCommand": "sudo cp -r /root/.ssh-localhost ~/.ssh && sudo chown -R $(id -u):$(id -g) ~/.ssh && chmod 700 ~/.ssh && chmod 600 ~/.ssh/* && pip install -e ./flax", + "postCreateCommand": "sudo cp -r /root/.ssh-localhost ~/.ssh && sudo chown -R $(id -u):$(id -g) ~/.ssh && chmod 700 ~/.ssh && chmod 600 ~/.ssh/* && pip install -e .", // Uncomment the next line to have VS Code connect as an existing non-root user in the container. // On Linux, by default, the container user's UID/GID will be updated to match your local user. See diff --git a/dev/README.md b/dev/README.md index 5d8f081b..56f3484b 100644 --- a/dev/README.md +++ b/dev/README.md @@ -13,4 +13,7 @@ the environments used by contributors and maintainers. 5. Re-open the folder workspace using the remote containers extension. VSCode should recommend this action in a popup. Alternatively, use the green button in the bottom left container to control the - remote extension. \ No newline at end of file + remote extension. + +## Troubleshoot: +If you have the following error `~/.docker/buildx/current: permission denied`, try running `sudo chown -R $(whoami) ~/.docker` \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index e65c142b..cd90590d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -113,7 +113,7 @@ # href with no underline and white bold text color announcement = """ This site covers the old Flax Linen API. 
[Explore the new Flax NNX API ✨] diff --git a/docs/index.rst b/docs/index.rst index 2f0cfee6..c286d4f0 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -326,4 +326,4 @@ Notable examples in Flax include: philosophy contributing api_reference/index - Flax NNX + Flax NNX diff --git a/docs_nnx/api_reference/flax.nnx/graph.rst b/docs_nnx/api_reference/flax.nnx/graph.rst index d944e3c7..2cf65c94 100644 --- a/docs_nnx/api_reference/flax.nnx/graph.rst +++ b/docs_nnx/api_reference/flax.nnx/graph.rst @@ -10,6 +10,8 @@ graph .. autofunction:: update .. autofunction:: pop .. autofunction:: state +.. autofunction:: variables +.. autofunction:: graph .. autofunction:: graphdef .. autofunction:: iter_graph .. autofunction:: clone @@ -22,4 +24,4 @@ graph :members: .. autofunction:: update_context -.. autofunction:: current_update_context \ No newline at end of file +.. autofunction:: current_update_context diff --git a/docs_nnx/api_reference/flax.nnx/nn/dtypes.rst b/docs_nnx/api_reference/flax.nnx/nn/dtypes.rst new file mode 100644 index 00000000..74edaf8d --- /dev/null +++ b/docs_nnx/api_reference/flax.nnx/nn/dtypes.rst @@ -0,0 +1,8 @@ +Dtypes +------------------------ + +.. automodule:: flax.nnx.nn.dtypes +.. currentmodule:: flax.nnx.nn.dtypes + +.. autofunction:: canonicalize_dtype +.. autofunction:: promote_dtype \ No newline at end of file diff --git a/docs_nnx/api_reference/flax.nnx/nn/index.rst b/docs_nnx/api_reference/flax.nnx/nn/index.rst index 4b7600b0..e42d5842 100644 --- a/docs_nnx/api_reference/flax.nnx/nn/index.rst +++ b/docs_nnx/api_reference/flax.nnx/nn/index.rst @@ -9,8 +9,11 @@ See the `NNX page `__ for activations attention + dtypes initializers linear + lora normalization + recurrent stochastic diff --git a/docs_nnx/api_reference/flax.nnx/nn/lora.rst b/docs_nnx/api_reference/flax.nnx/nn/lora.rst new file mode 100644 index 00000000..43461027 --- /dev/null +++ b/docs_nnx/api_reference/flax.nnx/nn/lora.rst @@ -0,0 +1,15 @@ +LoRA +------------------------ + +NNX LoRA classes. + +.. automodule:: flax.nnx +.. currentmodule:: flax.nnx + +.. flax_module:: + :module: flax.nnx + :class: LoRA + +.. flax_module:: + :module: flax.nnx + :class: LoRALinear diff --git a/docs_nnx/api_reference/flax.nnx/nn/recurrent.rst b/docs_nnx/api_reference/flax.nnx/nn/recurrent.rst new file mode 100644 index 00000000..b3270d95 --- /dev/null +++ b/docs_nnx/api_reference/flax.nnx/nn/recurrent.rst @@ -0,0 +1,32 @@ +Recurrent +------------------------ + +.. automodule:: flax.nnx.nn.recurrent +.. currentmodule:: flax.nnx.nn.recurrent + +.. flax_module:: + :module: flax.nnx.nn.recurrent + :class: LSTMCell + +.. flax_module:: + :module: flax.nnx.nn.recurrent + :class: OptimizedLSTMCell + +.. flax_module:: + :module: flax.nnx.nn.recurrent + :class: SimpleCell + +.. flax_module:: + :module: flax.nnx.nn.recurrent + :class: GRUCell + +.. flax_module:: + :module: flax.nnx.nn.recurrent + :class: RNN + +.. flax_module:: + :module: flax.nnx.nn.recurrent + :class: Bidirectional + + +.. autofunction:: flip_sequences \ No newline at end of file diff --git a/docs_nnx/api_reference/flax.nnx/transforms.rst b/docs_nnx/api_reference/flax.nnx/transforms.rst index 4aaef0da..54ba3399 100644 --- a/docs_nnx/api_reference/flax.nnx/transforms.rst +++ b/docs_nnx/api_reference/flax.nnx/transforms.rst @@ -20,4 +20,8 @@ transforms .. autofunction:: value_and_grad .. autofunction:: vmap .. autofunction:: eval_shape +.. autofunction:: custom_vjp .. autofunction:: cond +.. autofunction:: switch +.. autofunction:: while_loop +.. 
autofunction:: fori_loop diff --git a/docs_nnx/conf.py b/docs_nnx/conf.py index 344010ac..7eee4706 100644 --- a/docs_nnx/conf.py +++ b/docs_nnx/conf.py @@ -144,13 +144,11 @@ # files that will not be executed. myst_enable_extensions = ['dollarmath'] nb_execution_excludepatterns = [ - 'quick_start.ipynb', # <-- times out + 'mnist_tutorial.ipynb', # <-- times out 'transfer_learning.ipynb', # <-- transformers requires flax<=0.7.0 'flax/nnx', # exclude nnx 'guides/demo.ipynb', # TODO(cgarciae): broken, remove or update - 'guides/why.ipynb', # TODO(cgarciae): broken, remove in favor on the new guide - 'guides/flax_gspmd.ipynb', # TODO(IvyZX): broken, needs to be updated - 'guides/surgery.ipynb', # TODO(IvyZX): broken, needs to be updated + 'guides/gemma.ipynb', ] # raise exceptions on execution so CI can catch errors nb_execution_allow_errors = False diff --git a/docs_nnx/glossary.rst b/docs_nnx/glossary.rst deleted file mode 100644 index 1ed754a0..00000000 --- a/docs_nnx/glossary.rst +++ /dev/null @@ -1,50 +0,0 @@ -********* -Glossary -********* - -For additional terms, refer to the `JAX glossary `__. - -.. glossary:: - - Filter - A way to extract only certain :term:`Variables` out of a :term:`Module`. Usually done via calling :meth:`nnx.split ` upon the module. See the `Filter guide `__ to learn more. - - `Folding in `__ - Generating a new PRNG key given an input PRNG key and integer. Typically used when you want to - generate a new key but still be able to use the original rng key afterwards. You can also do this with - `jax.random.split `__ - but this will effectively create two RNG keys, which is slower. See how Flax generates new PRNG keys - automatically in our - `RNG guide `__. - - GraphDef - :class:`nnx.GraphDef`, a class that represents all the static, stateless, Pythonic part of an :class:`nnx.Module` definition. - - Lifted transformation - A wrapped version of the `JAX transformations `__ that allows the transformed function to take Flax :term:`Modules` as input or output. For example, a lifted version of `jax.jit `__ will be :meth:`flax.nnx.jit `. See the `lifted transforms guide `__. - - Merge - See :term:`Split and merge`. - - Module - :class:`nnx.Module `, a dataclass allowing the definition and initialization of parameters in a - referentially-transparent form. This is responsible for storing and updating variables - and parameters within itself. - - Params / parameters - :class:`nnx.Param `, a particular subclass of :class:`nnx.Variable ` that generally contains the trainable weights. - - RNG states - A Flax :class:`module ` can keep a reference of an :class:`RNG state object ` that can generate new JAX `PRNG `__ keys. They keys are used to generate random JAX arrays through `JAX's functional random number generators `__. - You can use an RNG state with different seeds to make more fine-grained control on your model (e.g., independent random numbers for parameters and dropout masks). - See the `RNG guide `__ - for more details. - - Split and merge - :meth:`nnx.split `, a way to represent an `nnx.Module` by two parts - a static :term:`GraphDef ` that captures its Pythonic, static information, and one or more :term:`Variable state(s)` that captures its JAX arrays in the form of pytrees. They can be merged back to the original module with :meth:`nnx.merge `. - - Variable - The `weights / parameters / data / arrays `__ residing in a Flax :term:`Module`. Variables are defined inside modules as :class:`nnx.Variable ` or its subclasses. 
- - Variable state - :class:`nnx.VariableState `, a purely functional pytree of all the :term:`Variables` inside a :term:`Module`. Since it's pure, it can be an input or output of a JAX transformation function. Obtained by using :term:`splitting` the module. \ No newline at end of file diff --git a/docs_nnx/guides/bridge_guide.ipynb b/docs_nnx/guides/bridge_guide.ipynb index e41836a9..c24b76a5 100644 --- a/docs_nnx/guides/bridge_guide.ipynb +++ b/docs_nnx/guides/bridge_guide.ipynb @@ -17,9 +17,9 @@ "\n", "**Note**:\n", "\n", - "This guide is about glueing Linen and NNX modules. To migrate an existing Linen module to NNX, check out the [Migrate from Flax Linen to Flax NNX](https://flax-nnx.readthedocs.io/en/latest/guides/linen_to_nnx.html) guide. \n", + "This guide is about glueing Linen and NNX modules. To migrate an existing Linen module to NNX, check out the [Migrate from Flax Linen to Flax NNX](https://flax.readthedocs.io/en/latest/guides/linen_to_nnx.html) guide. \n", "\n", - "And all built-in Linen layers should have equivalent NNX versions! Check out the list of [Built-in NNX layers](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/nn/index.html)." + "And all built-in Linen layers should have equivalent NNX versions! Check out the list of [Built-in NNX layers](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/index.html)." ] }, { diff --git a/docs_nnx/guides/bridge_guide.md b/docs_nnx/guides/bridge_guide.md index 3f243ae2..3e2c9b4a 100644 --- a/docs_nnx/guides/bridge_guide.md +++ b/docs_nnx/guides/bridge_guide.md @@ -11,9 +11,9 @@ We hope this allows you to move and try out NNX at your own pace, and leverage t **Note**: -This guide is about glueing Linen and NNX modules. To migrate an existing Linen module to NNX, check out the [Migrate from Flax Linen to Flax NNX](https://flax-nnx.readthedocs.io/en/latest/guides/linen_to_nnx.html) guide. +This guide is about glueing Linen and NNX modules. To migrate an existing Linen module to NNX, check out the [Migrate from Flax Linen to Flax NNX](https://flax.readthedocs.io/en/latest/guides/linen_to_nnx.html) guide. -And all built-in Linen layers should have equivalent NNX versions! Check out the list of [Built-in NNX layers](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/nn/index.html). +And all built-in Linen layers should have equivalent NNX versions! Check out the list of [Built-in NNX layers](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/index.html). ```python diff --git a/docs_nnx/guides/checkpointing.ipynb b/docs_nnx/guides/checkpointing.ipynb new file mode 100644 index 00000000..449f8a77 --- /dev/null +++ b/docs_nnx/guides/checkpointing.ipynb @@ -0,0 +1,448 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Save and load checkpoints\n", + "\n", + "This guide demonstrates how to save and load Flax NNX model checkpoints with [Orbax](https://orbax.readthedocs.io/).\n", + "\n", + "> **Note:** The Flax team does not actively maintain a library for saving and loading model checkpoints to disk. Therefore, it is recommended you use external libraries like [Orbax](https://orbax.readthedocs.io/en/latest/index.html) to do it.\n", + "\n", + "In this guide you will learn how to:\n", + "\n", + "* Save checkpoints.\n", + "* Restore checkpoints.\n", + "* Restore checkpoints if checkpoint structures differ. \n", + "* Perform multi-process checkpointing. 
\n", + "\n", + "The Orbax API examples used throughout the guide are for demonstration purposes, and for the most up-to-date recommended APIs refer to the [Orbax website](https://orbax.readthedocs.io/).\n", + "\n", + "> **Note:** The Flax team recommends using [Orbax](https://orbax.readthedocs.io/en/latest/index.html) for saving and loading checkpoints to disk, as we do not actively maintain a library for these functionalities.\n", + "\n", + "> **Note:** If you are looking for Flax Linen's legacy `flax.training.checkpoints` package, it was deprecated in 2023 in favor of Orbax. The documentation resides on the [Flax Linen site](https://flax-linen.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Setup\n", + "\n", + "Import the necessary dependencies, set up a checkpoint directory and an example Flax NNX model - `TwoLayerMLP` - by subclassing [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from flax import nnx\n", + "import orbax.checkpoint as ocp\n", + "import jax\n", + "from jax import numpy as jnp\n", + "import numpy as np\n", + "\n", + "ckpt_dir = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "class TwoLayerMLP(nnx.Module):\n", + " def __init__(self, dim, rngs: nnx.Rngs):\n", + " self.linear1 = nnx.Linear(dim, dim, rngs=rngs, use_bias=False)\n", + " self.linear2 = nnx.Linear(dim, dim, rngs=rngs, use_bias=False)\n", + "\n", + " def __call__(self, x):\n", + " x = self.linear1(x)\n", + " return self.linear2(x)\n", + "\n", + "# Instantiate the model and show we can run it.\n", + "model = TwoLayerMLP(4, rngs=nnx.Rngs(0))\n", + "x = jax.random.normal(jax.random.key(42), (3, 4))\n", + "assert model(x).shape == (3, 4)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Save checkpoints\n", + "\n", + "JAX checkpointing libraries, such as Orbax, can save and load any given JAX [pytree](https://jax.readthedocs.io/en/latest/pytrees.html), which is a pure, possibly nested container of [`jax.Array`s)](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) (or, \"tensors\" as some other frameworks would put it). In the context of machine learning, the checkpoint is usually a pytree of model parameters and other data, such as optimizer states.\n", + "\n", + "In Flax NNX, you can obtain such a pytree from an [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) by calling [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), and picking up the returned [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "_, state = nnx.split(model)\n", + "nnx.display(state)\n", + "\n", + "checkpointer = ocp.StandardCheckpointer()\n", + "checkpointer.save(ckpt_dir / 'state', state)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "\n", + "\n", + "
\n", + "\n", + "\n", + "## Restore checkpoints\n", + "\n", + "Note that you saved the checkpoint as a Flax class of [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State), which is also nested with [`nnx.VariableState`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.VariableState) and [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) classes.\n", + "\n", + "At checkpoint restoration time, you need to have these classes ready in your runtime, and instruct the checkpointing library (Orbax) to restore your pytree back to that structure. This can be achieved as follows:\n", + "- First, create an abstract Flax NNX model (without allocating any memory for arrays), and show its abstract variable state to the checkpointing library.\n", + "- Once you have the state, use [`nnx.merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge) to obtain your Flax NNX model, and use it as usual." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The abstract NNX state (all leaves are abstract arrays):\n" + ] + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "NNX State restored: \n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/ivyzheng/envs/flax-head/lib/python3.11/site-packages/orbax/checkpoint/type_handlers.py:1439: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Restore the checkpoint back to its `nnx.State` structure - need an abstract reference.\n", + "abstract_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0)))\n", + "graphdef, abstract_state = nnx.split(abstract_model)\n", + "print('The abstract NNX state (all leaves are abstract arrays):')\n", + "nnx.display(abstract_state)\n", + "\n", + "state_restored = checkpointer.restore(ckpt_dir / 'state', abstract_state)\n", + "jax.tree.map(np.testing.assert_array_equal, state, state_restored)\n", + "print('NNX State restored: ')\n", + "nnx.display(state_restored)\n", + "\n", + "# The model is now good to use!\n", + "model = nnx.merge(graphdef, state_restored)\n", + "assert model(x).shape == (3, 4)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " The abstract NNX state (all leaves are abstract arrays):\n", + "\n", + "\n", + "\n", + "
\n", + "\n", + "\n", + " NNX State restored: \n", + "\n", + "\n", + " /Users/ivyzheng/envs/flax-head/lib/python3.11/site-packages/orbax/checkpoint/type_handlers.py:1439: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.\n", + " warnings.warn(\n", + "\n", + "\n", + "\n", + "
\n", + "\n", + "\n", + "\n", + "
\n", + "\n", + "\n", + "## Save and restore as pure dictionaries\n", + "\n", + "When interacting with checkpoint libraries (like Orbax), you may prefer to work with Python built-in container types. In this case, you can use the [`nnx.State.to_pure_dict`](https://github.com/google/flax/blob/764e1732dcd3b8bf178b9ba73ddecf125709b5d7/flax/nnx/statelib.py#L170) and [`nnx.State.replace_by_pure_dict`](https://github.com/google/flax/blob/764e1732dcd3b8bf178b9ba73ddecf125709b5d7/flax/nnx/statelib.py#L179) API to convert an [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) to and from pure nested dictionaries." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.\n" + ] + } + ], + "source": [ + "# Save as pure dict\n", + "pure_dict_state = state.to_pure_dict()\n", + "nnx.display(pure_dict_state)\n", + "checkpointer.save(ckpt_dir / 'pure_dict', pure_dict_state)\n", + "\n", + "# Restore as a pure dictionary.\n", + "restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict')\n", + "abstract_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0)))\n", + "graphdef, abstract_state = nnx.split(abstract_model)\n", + "abstract_state.replace_by_pure_dict(restored_pure_dict)\n", + "model = nnx.merge(graphdef, abstract_state)\n", + "assert model(x).shape == (3, 4) # The model still works!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "\n", + "\n", + "
\n", + "\n", + "\n", + " WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.\n", + "\n", + "\n", + "## Restore when checkpoint structures differ\n", + "\n", + "The ability to load a checkpoint as a pure nested dictionary can come in handy when you want to load some outdated checkpoints that no longer match with your current model code. Check out this simple example below.\n", + "\n", + "This pattern also works if you save the checkpoint as an [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) instead of a pure dictionary. Check out the [Checkpoint surgery section](https://flax.readthedocs.io/en/latest/guides/surgery.html#checkpoint-surgery) of the [Model Surgery](https://flax.readthedocs.io/en/latest/guides/surgery.html) guide for an example with code. The only difference is you need to reprocess your raw dictionary a bit before calling [`nnx.State.replace_by_pure_dict`](https://github.com/google/flax/blob/764e1732dcd3b8bf178b9ba73ddecf125709b5d7/flax/nnx/statelib.py#L179)." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.\n" + ] + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "class ModifiedTwoLayerMLP(nnx.Module):\n", + " \"\"\"A modified version of TwoLayerMLP, which requires bias arrays.\"\"\"\n", + " def __init__(self, dim, rngs: nnx.Rngs):\n", + " self.linear1 = nnx.Linear(dim, dim, rngs=rngs, use_bias=True) # We need bias now!\n", + " self.linear2 = nnx.Linear(dim, dim, rngs=rngs, use_bias=True) # We need bias now!\n", + "\n", + " def __call__(self, x):\n", + " x = self.linear1(x)\n", + " return self.linear2(x)\n", + "\n", + "# Accommodate your old checkpoint to the new code.\n", + "restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict')\n", + "restored_pure_dict['linear1']['bias'] = jnp.zeros((4,))\n", + "restored_pure_dict['linear2']['bias'] = jnp.zeros((4,))\n", + "\n", + "# Same restore code as above.\n", + "abstract_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))\n", + "graphdef, abstract_state = nnx.split(abstract_model)\n", + "abstract_state.replace_by_pure_dict(restored_pure_dict)\n", + "model = nnx.merge(graphdef, abstract_state)\n", + "assert model(x).shape == (3, 4) # The new model works!\n", + "\n", + "nnx.display(model.linear1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.\n", + "\n", + "\n", + "\n", + "
\n", + "\n", + "\n", + "\n", + "
\n", + "\n", + "\n", + "## Multi-process checkpointing\n", + "\n", + "In a multi-host/multi-process environment, you would want to restore your checkpoint as sharded across multiple devices. Check out the [Load sharded model from a checkpoint](https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html#load-sharded-model-from-a-checkpoint) section in the Flax [Scale up on multiple devices](https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html) guide to learn how to derive a sharding pytree and use it to load your checkpoint.\n", + "\n", + "> **Note:** JAX provides several ways to scale up your code on multiple hosts at the same time. This usually happens when the number of devices (CPU/GPU/TPU) is so large that different devices are managed by different hosts (CPU). Check out JAX’s [Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html), [Using JAX in multi-host and multi-process environments](https://jax.readthedocs.io/en/latest/multi_process.html), [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), and [Manual parallelism with `shard_map`](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Other checkpointing features\n", + "\n", + "This guide only uses the simplest [`orbax.checkpoint.StandardCheckpointer`](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.checkpointers.html#standardcheckpointer) API to show how to save and load on the Flax modeling side. Feel free to use other tools or libraries as you see fit.\n", + "\n", + "In addition, check out [the Orbax website](https://orbax.readthedocs.io/en/latest/index.html) for other commonly used features, such as:\n", + "\n", + "* [`CheckpointManager`](https://orbax.readthedocs.io/en/latest/guides/checkpoint/api_refactor.html) to track checkpoints from different steps.\n", + "\n", + "* [Asynchronous checkpointing](https://orbax.readthedocs.io/en/latest/guides/checkpoint/async_checkpointing.html).\n", + "\n", + "* [Orbax transformations](https://orbax.readthedocs.io/en/latest/guides/checkpoint/transformations.html): A way to modify pytree structure during loading time, instead of after loading time, which is demonstrated in this guide." + ] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,md:myst" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs_nnx/guides/checkpointing.md b/docs_nnx/guides/checkpointing.md new file mode 100644 index 00000000..fa98e6db --- /dev/null +++ b/docs_nnx/guides/checkpointing.md @@ -0,0 +1,220 @@ +--- +jupytext: + formats: ipynb,md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.13.8 +--- + +# Save and load checkpoints + +This guide demonstrates how to save and load Flax NNX model checkpoints with [Orbax](https://orbax.readthedocs.io/). + +> **Note:** The Flax team does not actively maintain a library for saving and loading model checkpoints to disk. Therefore, it is recommended you use external libraries like [Orbax](https://orbax.readthedocs.io/en/latest/index.html) to do it. 
+ +In this guide you will learn how to: + +* Save checkpoints. +* Restore checkpoints. +* Restore checkpoints if checkpoint structures differ. +* Perform multi-process checkpointing. + +The Orbax API examples used throughout the guide are for demonstration purposes, and for the most up-to-date recommended APIs refer to the [Orbax website](https://orbax.readthedocs.io/). + +> **Note:** The Flax team recommends using [Orbax](https://orbax.readthedocs.io/en/latest/index.html) for saving and loading checkpoints to disk, as we do not actively maintain a library for these functionalities. + +> **Note:** If you are looking for Flax Linen's legacy `flax.training.checkpoints` package, it was deprecated in 2023 in favor of Orbax. The documentation resides on the [Flax Linen site](https://flax-linen.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html). + ++++ + +### Setup + +Import the necessary dependencies, set up a checkpoint directory and an example Flax NNX model - `TwoLayerMLP` - by subclassing [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html). + +```{code-cell} ipython3 +from flax import nnx +import orbax.checkpoint as ocp +import jax +from jax import numpy as jnp +import numpy as np + +ckpt_dir = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/') +``` + +```{code-cell} ipython3 +class TwoLayerMLP(nnx.Module): + def __init__(self, dim, rngs: nnx.Rngs): + self.linear1 = nnx.Linear(dim, dim, rngs=rngs, use_bias=False) + self.linear2 = nnx.Linear(dim, dim, rngs=rngs, use_bias=False) + + def __call__(self, x): + x = self.linear1(x) + return self.linear2(x) + +# Instantiate the model and show we can run it. +model = TwoLayerMLP(4, rngs=nnx.Rngs(0)) +x = jax.random.normal(jax.random.key(42), (3, 4)) +assert model(x).shape == (3, 4) +``` + +## Save checkpoints + +JAX checkpointing libraries, such as Orbax, can save and load any given JAX [pytree](https://jax.readthedocs.io/en/latest/pytrees.html), which is a pure, possibly nested container of [`jax.Array`s)](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) (or, "tensors" as some other frameworks would put it). In the context of machine learning, the checkpoint is usually a pytree of model parameters and other data, such as optimizer states. + +In Flax NNX, you can obtain such a pytree from an [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) by calling [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), and picking up the returned [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State). + +```{code-cell} ipython3 +_, state = nnx.split(model) +nnx.display(state) + +checkpointer = ocp.StandardCheckpointer() +checkpointer.save(ckpt_dir / 'state', state) +``` + +
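+If you checkpoint repeatedly during training, you can also let Orbax manage checkpoints across steps for you. The snippet below is only an illustrative sketch (not an executed cell of this guide); it assumes the Orbax `CheckpointManager` API names (`ocp.CheckpointManagerOptions`, `ocp.args.StandardSave` / `ocp.args.StandardRestore`) and reuses `model`, `ckpt_dir` and the imports defined above - consult the Orbax documentation for the currently recommended usage.
+
+```python
+# Illustrative sketch: keep only the 3 most recent checkpoints while training.
+options = ocp.CheckpointManagerOptions(max_to_keep=3, save_interval_steps=1)
+mngr = ocp.CheckpointManager(ckpt_dir / 'managed', options=options)
+
+for step in range(5):
+  # ... run a training step that updates `model` here ...
+  _, state = nnx.split(model)
+  mngr.save(step, args=ocp.args.StandardSave(state))
+mngr.wait_until_finished()  # saving may happen asynchronously
+
+# Later, restore the newest step against an abstract state, exactly as
+# described in the "Restore checkpoints" section below, e.g.:
+# state_restored = mngr.restore(mngr.latest_step(),
+#                               args=ocp.args.StandardRestore(abstract_state))
+```
+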
+ + + +
+ + +## Restore checkpoints + +Note that you saved the checkpoint as a Flax class of [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State), which is also nested with [`nnx.VariableState`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.VariableState) and [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) classes. + +At checkpoint restoration time, you need to have these classes ready in your runtime, and instruct the checkpointing library (Orbax) to restore your pytree back to that structure. This can be achieved as follows: +- First, create an abstract Flax NNX model (without allocating any memory for arrays), and show its abstract variable state to the checkpointing library. +- Once you have the state, use [`nnx.merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge) to obtain your Flax NNX model, and use it as usual. + +```{code-cell} ipython3 +# Restore the checkpoint back to its `nnx.State` structure - need an abstract reference. +abstract_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0))) +graphdef, abstract_state = nnx.split(abstract_model) +print('The abstract NNX state (all leaves are abstract arrays):') +nnx.display(abstract_state) + +state_restored = checkpointer.restore(ckpt_dir / 'state', abstract_state) +jax.tree.map(np.testing.assert_array_equal, state, state_restored) +print('NNX State restored: ') +nnx.display(state_restored) + +# The model is now good to use! +model = nnx.merge(graphdef, state_restored) +assert model(x).shape == (3, 4) +``` + + The abstract NNX state (all leaves are abstract arrays): + + + +
+ + + NNX State restored: + + + /Users/ivyzheng/envs/flax-head/lib/python3.11/site-packages/orbax/checkpoint/type_handlers.py:1439: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with. + warnings.warn( + + + +
+ + + +
+ + +## Save and restore as pure dictionaries + +When interacting with checkpoint libraries (like Orbax), you may prefer to work with Python built-in container types. In this case, you can use the [`nnx.State.to_pure_dict`](https://github.com/google/flax/blob/764e1732dcd3b8bf178b9ba73ddecf125709b5d7/flax/nnx/statelib.py#L170) and [`nnx.State.replace_by_pure_dict`](https://github.com/google/flax/blob/764e1732dcd3b8bf178b9ba73ddecf125709b5d7/flax/nnx/statelib.py#L179) API to convert an [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) to and from pure nested dictionaries. + +```{code-cell} ipython3 +# Save as pure dict +pure_dict_state = state.to_pure_dict() +nnx.display(pure_dict_state) +checkpointer.save(ckpt_dir / 'pure_dict', pure_dict_state) + +# Restore as a pure dictionary. +restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict') +abstract_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0))) +graphdef, abstract_state = nnx.split(abstract_model) +abstract_state.replace_by_pure_dict(restored_pure_dict) +model = nnx.merge(graphdef, abstract_state) +assert model(x).shape == (3, 4) # The model still works! +``` + +
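+Because the result is made of plain nested Python dictionaries keyed by attribute names, you can inspect or post-process it with ordinary Python before saving. For example (illustrative only, reusing `pure_dict_state` from the cell above):
+
+```python
+# Top-level keys follow the module's attribute names.
+print(list(pure_dict_state.keys()))                # ['linear1', 'linear2']
+print(pure_dict_state['linear1']['kernel'].shape)  # (4, 4)
+```
+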
+ + + +
+ + + WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under. + + +## Restore when checkpoint structures differ + +The ability to load a checkpoint as a pure nested dictionary can come in handy when you want to load some outdated checkpoints that no longer match with your current model code. Check out this simple example below. + +This pattern also works if you save the checkpoint as an [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) instead of a pure dictionary. Check out the [Checkpoint surgery section](https://flax.readthedocs.io/en/latest/guides/surgery.html#checkpoint-surgery) of the [Model Surgery](https://flax.readthedocs.io/en/latest/guides/surgery.html) guide for an example with code. The only difference is you need to reprocess your raw dictionary a bit before calling [`nnx.State.replace_by_pure_dict`](https://github.com/google/flax/blob/764e1732dcd3b8bf178b9ba73ddecf125709b5d7/flax/nnx/statelib.py#L179). + +```{code-cell} ipython3 +class ModifiedTwoLayerMLP(nnx.Module): + """A modified version of TwoLayerMLP, which requires bias arrays.""" + def __init__(self, dim, rngs: nnx.Rngs): + self.linear1 = nnx.Linear(dim, dim, rngs=rngs, use_bias=True) # We need bias now! + self.linear2 = nnx.Linear(dim, dim, rngs=rngs, use_bias=True) # We need bias now! + + def __call__(self, x): + x = self.linear1(x) + return self.linear2(x) + +# Accommodate your old checkpoint to the new code. +restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict') +restored_pure_dict['linear1']['bias'] = jnp.zeros((4,)) +restored_pure_dict['linear2']['bias'] = jnp.zeros((4,)) + +# Same restore code as above. +abstract_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0))) +graphdef, abstract_state = nnx.split(abstract_model) +abstract_state.replace_by_pure_dict(restored_pure_dict) +model = nnx.merge(graphdef, abstract_state) +assert model(x).shape == (3, 4) # The new model works! + +nnx.display(model.linear1) +``` + + WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under. + + + +
+ + + +
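
Key renames are handled the same way as added entries. The `RenamedTwoLayerMLP` below is a hypothetical variant (not part of the original guide) whose first layer attribute is called `hidden` instead of `linear1`; moving the old entry under the new key is all the surgery the raw dictionary needs before `replace_by_pure_dict`:

```{code-cell} ipython3
class RenamedTwoLayerMLP(nnx.Module):
  """Same architecture as TwoLayerMLP, but `linear1` was renamed to `hidden`."""
  def __init__(self, dim, rngs: nnx.Rngs):
    self.hidden = nnx.Linear(dim, dim, rngs=rngs, use_bias=False)
    self.linear2 = nnx.Linear(dim, dim, rngs=rngs, use_bias=False)

  def __call__(self, x):
    return self.linear2(self.hidden(x))

# Load the raw dictionary and move the old key to its new name.
restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict')
restored_pure_dict['hidden'] = restored_pure_dict.pop('linear1')

# Same restore code as above.
abstract_model = nnx.eval_shape(lambda: RenamedTwoLayerMLP(4, rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)
abstract_state.replace_by_pure_dict(restored_pure_dict)
model = nnx.merge(graphdef, abstract_state)
assert model(x).shape == (3, 4)  # The renamed model works with the old weights.
```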
+ + +## Multi-process checkpointing + +In a multi-host/multi-process environment, you would want to restore your checkpoint as sharded across multiple devices. Check out the [Load sharded model from a checkpoint](https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html#load-sharded-model-from-a-checkpoint) section in the Flax [Scale up on multiple devices](https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html) guide to learn how to derive a sharding pytree and use it to load your checkpoint. + +> **Note:** JAX provides several ways to scale up your code on multiple hosts at the same time. This usually happens when the number of devices (CPU/GPU/TPU) is so large that different devices are managed by different hosts (CPU). Check out JAX’s [Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html), [Using JAX in multi-host and multi-process environments](https://jax.readthedocs.io/en/latest/multi_process.html), [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), and [Manual parallelism with `shard_map`](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html). + ++++ + +## Other checkpointing features + +This guide only uses the simplest [`orbax.checkpoint.StandardCheckpointer`](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.checkpointers.html#standardcheckpointer) API to show how to save and load on the Flax modeling side. Feel free to use other tools or libraries as you see fit. + +In addition, check out [the Orbax website](https://orbax.readthedocs.io/en/latest/index.html) for other commonly used features, such as: + +* [`CheckpointManager`](https://orbax.readthedocs.io/en/latest/guides/checkpoint/api_refactor.html) to track checkpoints from different steps. + +* [Asynchronous checkpointing](https://orbax.readthedocs.io/en/latest/guides/checkpoint/async_checkpointing.html). + +* [Orbax transformations](https://orbax.readthedocs.io/en/latest/guides/checkpoint/transformations.html): A way to modify pytree structure during loading time, instead of after loading time, which is demonstrated in this guide. diff --git a/docs_nnx/guides/flax_gspmd.ipynb b/docs_nnx/guides/flax_gspmd.ipynb index 168ae2bd..44dcfd51 100644 --- a/docs_nnx/guides/flax_gspmd.ipynb +++ b/docs_nnx/guides/flax_gspmd.ipynb @@ -6,7 +6,7 @@ "source": [ "# Scale up on multiple devices\n", "\n", - "This guide demonstrates how to scale up [Flax NNX `Module`s](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module) on multiple devices and hosts, such as GPUs, Google TPUs, and CPUs, using [JAX just-in-time compilation machinery (`jax.jit`)](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html)." + "This guide demonstrates how to scale up [Flax NNX `Module`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module) on [multiple devices and hosts](Multi-host and multi-process environments) - such as GPUs, Google TPUs, and CPUs - using the [JAX just-in-time compilation machinery (`jax.jit`)](https://jax.readthedocs.io/en/latest/jit-compilation.html) and [`flax.nnx.spmd`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html)." ] }, { @@ -16,13 +16,15 @@ "source": [ "## Overview\n", "\n", - "Flax relies on [JAX](https://jax.readthedocs.io) for numeric computations and scaling the computations up across multiple devices, such as GPU and TPUs. 
At the core of scaling up is the [JAX just-in-time compiler `jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html). Throughout this guide, you will be using Flax’s own [`nnx.jit`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) which wraps around [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) and works more conveniently with Flax NNX `Module`s.\n", + "Flax relies on [JAX](https://jax.readthedocs.io) for numeric computations and scaling the computations up across multiple devices, such as GPU and Google TPUs. At the core of scaling up is the [JAX just-in-time (`jax.jit`) compiler `jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html). Throughout this guide, you will be using Flax’s own [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) transform, which wraps around [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) and works more conveniently with Flax NNX `Module`s.\n", "\n", - "JAX compilation follows the [Single Program Multi Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) paradigm. This means you write Python code as if it runs only on one device, and [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) will automatically compile and run it on multiple devices.\n", + "> **Note:** To learn more about Flax’s transformations, such as `nnx.jit` and `nnx.vmap`, go to [Why Flax NNX? - Transforms](https://flax.readthedocs.io/en/latest/why.html#transforms), [Transformations](https://flax.readthedocs.io/en/latest/guides/transforms.html), and [Flax NNX vs JAX Transformations](https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html).\n", "\n", - "To ensure the compilation performance, you often need to instruct JAX how your model's variables need to be sharded across devices. This is where Flax NNX's Sharding Metadata API - [`flax.nnx.spmd`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html) - comes in. It helps you annotate your model variables with this information.\n", + "JAX compilation follows the [Single Program Multi Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) paradigm. This means you write Python code as if it runs only on one device, and [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) will [automatically compile](https://jax.readthedocs.io/en/latest/jit-compilation.html#jit-compilation) and [run it](https://jax.readthedocs.io/en/latest/sharded-computation.html) on [multiple devices](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html).\n", "\n", - "> **Note to Flax Linen users**: The [`flax.nnx.spmd`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html) API is similar to what is described in [the Linen Flax on `(p)jit` guide](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html) on the model definition level. However, the top-level code in Flax NNX is simpler due to the benefits brought by Flax NNX, and some text explanations will be more updated and clearer.\n", + "To ensure the compilation performance, you often need to instruct JAX how your model's variables need to be sharded across devices. 
This is where Flax NNX's Sharding Metadata API - [`flax.nnx.spmd`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html) - comes in. It helps you annotate your model variables with this information.\n", + "\n", + "> **Note to Flax Linen users**: The [`flax.nnx.spmd`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html) API is similar to what is described in [the Linen Flax on `(p)jit` guide](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html) on the model definition level. However, the top-level code in Flax NNX is simpler due to the benefits brought by Flax NNX, and some text explanations will be more updated and clearer.\n", "\n", "If you are new parallelization in JAX, you can learn more about its APIs for scaling up in the following tutorials:\n", "\n", @@ -46,7 +48,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -56,7 +58,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ @@ -74,14 +76,14 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "We have 8 fake JAX devices now: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]\n" + "You have 8 “fake” JAX devices now: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]\n" ] } ], @@ -106,7 +108,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -131,16 +133,16 @@ "source": [ "## Define a model with specified sharding\n", "\n", - "Next, create an example layer called `DotReluDot` that subclasses Flax [`nnx.Module`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module). This layer carries out two dot product multiplication upon the input `x`, and uses the `jax.nn.relu` (ReLU) activation function in-between.\n", - "\n", - "To annotate a model variable with their ideal sharding, you can use [`flax.nnx.with_partitioning`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning) to wrap over its initializer function. Essentially, this calls [`flax.nnx.with_metadata`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.with_metadata) which adds a `.sharding` attribute field to the corresponding [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable).\n", + "Next, create an example layer called `DotReluDot` that subclasses Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module).\n", + "- This layer carries out two dot product multiplications upon the input `x`, and uses the `jax.nn.relu` (ReLU) activation function in-between.\n", + "- To annotate a model variable with their ideal sharding, you can use [`flax.nnx.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning) to wrap over its initializer function. 
Essentially, this calls [`flax.nnx.with_metadata`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.with_metadata) which adds a `.sharding` attribute field to the corresponding [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable).\n", "\n", - "> **Note:** This annotation will be [preserved and adjusted accordingly across lifted transformations in Flax NNX](https://flax-nnx.readthedocs.io/en/latest/guides/transforms.html#axes-metadata). This means if you use sharding annotations along with any transform that modifies axes (like [`nnx.vmap`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html), [`nnx.scan`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html)), you need to provide sharding of that additional axis via the `transform_metadata` arg. Check out the [Flax NNX transformations (transforms) guide](https://flax-nnx.readthedocs.io/en/latest/guides/transforms.html) to learn more." + "> **Note:** This annotation will be [preserved and adjusted accordingly across lifted transformations in Flax NNX](https://flax.readthedocs.io/en/latest/guides/transforms.html#axes-metadata). This means if you use sharding annotations along with any transform that modifies axes (like [`nnx.vmap`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html), [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html)), you need to provide sharding of that additional axis via the `transform_metadata` arg. Check out the [Flax NNX transformations (transforms) guide](https://flax.readthedocs.io/en/latest/guides/transforms.html) to learn more." ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ @@ -197,12 +199,12 @@ "source": [ "## Initialize a sharded model\n", "\n", - "Now, you have annotations attached to the [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable), but the actual weights haven't been sharded yet. If you just go ahead and create this model, all JAX arrays are still stuck in device `0`. In practice, you'd want to avoid this, because a large model will \"OOM\" (will cause the device to run out of memory) in this situation, while all the other devices are not utilized." + "Now, you have annotations attached to the Flax [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable), but the actual weights have not been sharded yet. If you just go ahead and create this model, all [`jax.Array`s](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) are still stuck in device `0`. In practice, you'd want to avoid this, because a large model will \"OOM\" (will cause the device to run out of memory) in this situation, while all the other devices are not utilized." 
] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -219,7 +221,7 @@ "source": [ "unsharded_model = DotReluDot(1024, rngs=nnx.Rngs(0))\n", "\n", - "# You have annotations sticked there, yay!\n", + "# You have annotations stuck there, yay!\n", "print(unsharded_model.dot1.kernel.sharding) # (None, 'model')\n", "print(unsharded_model.w2.sharding) # ('model', None)\n", "\n", @@ -232,15 +234,15 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Here, you should leverage JAX's compilation mechanism, via [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit), to create the sharded model. The key is to initialize a model and assign shardings upon the model state within a `jit`ted function:\n", + "Here, you should leverage JAX's compilation mechanism via Flax’s [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) to create the sharded model. The key is to initialize a model and assign shardings upon the model state within a `jit`ted function:\n", "\n", - "1. Use [`nnx.get_partition_spec`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_partition_spec) to strip out the `.sharding` annotations attached upon model variables.\n", + "1. Use [`nnx.get_partition_spec`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_partition_spec) to strip out the `.sharding` annotations attached upon model variables.\n", "\n", "1. Call [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) to bind the model state with the sharding annotations. This API tells the top-level `jit` how to shard a variable!\n", "\n", "1. Throw away the unsharded state and return the model based upon the sharded state.\n", "\n", - "1. Compile the whole function with `nnx.jit`, which allows the output to be a stateful NNX module.\n", + "1. Compile the whole function with `nnx.jit`, which allows the output to be a stateful Flax NNX `Module`.\n", "\n", "1. Run it under a device mesh context so that JAX knows which devices to shard it to.\n", "\n", @@ -249,7 +251,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 21, "metadata": {}, "outputs": [ { @@ -291,12 +293,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "You can view the sharding of any 1-D or 2-D array with `jax.debug.visualize_array_sharding`:" + "You can view the sharding of any 1-D or 2-D array with [`jax.debug.visualize_array_sharding`](https://jax.readthedocs.io/en/latest/_autosummary/jax.debug.visualize_array_sharding.html):" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -309,31 +311,31 @@ { "data": { "text/html": [ - "
                                    \n",
-       "                                    \n",
-       "                                    \n",
-       "                                    \n",
-       "                                    \n",
-       " CPU 0,4  CPU 1,5  CPU 2,6  CPU 3,7 \n",
-       "                                    \n",
-       "                                    \n",
-       "                                    \n",
-       "                                    \n",
-       "                                    \n",
+       "
┌───────┬───────┬───────┬───────┐\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "│CPU 0,4│CPU 1,5│CPU 2,6│CPU 3,7│\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "└───────┴───────┴───────┴───────┘\n",
        "
\n" ], "text/plain": [ - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mCPU 0,4\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mCPU 1,5\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74mCPU 2,6\u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107mCPU 3,7\u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n" + "┌───────┬───────┬───────┬───────┐\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "│CPU \u001b[1;36m0\u001b[0m,\u001b[1;36m4\u001b[0m│CPU \u001b[1;36m1\u001b[0m,\u001b[1;36m5\u001b[0m│CPU \u001b[1;36m2\u001b[0m,\u001b[1;36m6\u001b[0m│CPU \u001b[1;36m3\u001b[0m,\u001b[1;36m7\u001b[0m│\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "└───────┴───────┴───────┴───────┘\n" ] }, "metadata": {}, @@ -349,33 +351,27 @@ { "data": { "text/html": [ - "
                         \n",
-       "         CPU 0,4         \n",
-       "                         \n",
-       "                         \n",
-       "         CPU 1,5         \n",
-       "                         \n",
-       "                         \n",
-       "         CPU 2,6         \n",
-       "                         \n",
-       "                         \n",
-       "         CPU 3,7         \n",
-       "                         \n",
+       "
┌───────────────────────┐\n",
+       "│        CPU 0,4        │\n",
+       "├───────────────────────┤\n",
+       "│        CPU 1,5        │\n",
+       "├───────────────────────┤\n",
+       "│        CPU 2,6        │\n",
+       "├───────────────────────┤\n",
+       "│        CPU 3,7        │\n",
+       "└───────────────────────┘\n",
        "
\n" ], "text/plain": [ - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mCPU 0,4\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mCPU 1,5\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74mCPU 2,6\u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\n", - "\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", - "\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107mCPU 3,7\u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", - "\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n" + "┌───────────────────────┐\n", + "│ CPU \u001b[1;36m0\u001b[0m,\u001b[1;36m4\u001b[0m │\n", + "├───────────────────────┤\n", + "│ CPU \u001b[1;36m1\u001b[0m,\u001b[1;36m5\u001b[0m │\n", + "├───────────────────────┤\n", + "│ CPU \u001b[1;36m2\u001b[0m,\u001b[1;36m6\u001b[0m │\n", + "├───────────────────────┤\n", + "│ CPU \u001b[1;36m3\u001b[0m,\u001b[1;36m7\u001b[0m │\n", + "└───────────────────────┘\n" ] }, "metadata": {}, @@ -399,7 +395,7 @@ "\n", "> **Note:** Both [Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html) and [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) in the JAX documentation cover automatic parallelization with [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html), and semi-automatic parallelization with `jax.jit` and [`jax.lax.with_sharding_constraint](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) in greater detail.\n", "\n", - "You may have noticed you also used [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) once in the model definition too, to constraint the sharding of an intermediate value. This is just to show that you can always use it orthogonally with the Flax NNX API, if you want to explicitly shard values that are not model variables.\n", + "You may have noticed you also used [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) once in the model definition to constraint the sharding of an intermediate value. This is just to show that you can always use it orthogonally with the Flax NNX API if you want to explicitly shard values that are not model variables.\n", "\n", "This brings a question: Why use the Flax NNX Annotation API then? Why not just add JAX sharding constraints inside the model definition? The most important reason is that you still need the explicit annotations to load a sharded model from an on-disk checkpoint. This is described in the next section." 
] @@ -408,48 +404,48 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Load sharded model from a checkpoint\n", + "## Load a sharded model from a checkpoint\n", "\n", - "Now you can initialize a sharded model without OOM, but what about loading it from a checkpoint on disk? JAX checkpointing libraries, such as [Orbax](https://orbax.readthedocs.io/en/latest/), usually support loading it sharded if a sharding pytree is given.\n", + "Now you learned how to initialize a sharded model without OOM, but what about loading it from a checkpoint on disk? JAX checkpointing libraries, such as [Orbax](https://orbax.readthedocs.io/en/latest/), usually support loading a model sharded if a sharding pytree is provided.\n", "\n", - "You can generate such as sharding pytree with [`nnx.get_named_sharding`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_named_sharding). To avoid any real memory allocation, use the [`nnx.eval_shape`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.eval_shape) transform to generate a model of abstract JAX arrays, and only use its `.sharding` annotations to obtain the sharding tree.\n", + "You can generate such a sharding pytree with Flax’s [`nnx.get_named_sharding`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_named_sharding). To avoid any real memory allocation, use the [`nnx.eval_shape`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.eval_shape) transform to generate a model of abstract JAX arrays, and only use its `.sharding` annotations to obtain the sharding tree.\n", "\n", - "Below is an example demonstration using Orbax's `StandardCheckpointer` API. Check out [Orbax website](https://orbax.readthedocs.io/en/latest/) to learn their latest updates and recommended APIs." + "Below is an example that demonstrates using Orbax's `StandardCheckpointer` API. (Go to the [Orbax documentation site](https://orbax.readthedocs.io/en/latest/) to learn about their latest and most recommended APIs.)" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
                                    \n",
-       "                                    \n",
-       "                                    \n",
-       "                                    \n",
-       "                                    \n",
-       " CPU 0,4  CPU 1,5  CPU 2,6  CPU 3,7 \n",
-       "                                    \n",
-       "                                    \n",
-       "                                    \n",
-       "                                    \n",
-       "                                    \n",
+       "
┌───────┬───────┬───────┬───────┐\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "│CPU 0,4│CPU 1,5│CPU 2,6│CPU 3,7│\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "└───────┴───────┴───────┴───────┘\n",
        "
\n" ], "text/plain": [ - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mCPU 0,4\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mCPU 1,5\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74mCPU 2,6\u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107mCPU 3,7\u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n" + "┌───────┬───────┬───────┬───────┐\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "│CPU \u001b[1;36m0\u001b[0m,\u001b[1;36m4\u001b[0m│CPU \u001b[1;36m1\u001b[0m,\u001b[1;36m5\u001b[0m│CPU \u001b[1;36m2\u001b[0m,\u001b[1;36m6\u001b[0m│CPU \u001b[1;36m3\u001b[0m,\u001b[1;36m7\u001b[0m│\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "└───────┴───────┴───────┴───────┘\n" ] }, "metadata": {}, @@ -458,33 +454,27 @@ { "data": { "text/html": [ - "
                         \n",
-       "         CPU 0,4         \n",
-       "                         \n",
-       "                         \n",
-       "         CPU 1,5         \n",
-       "                         \n",
-       "                         \n",
-       "         CPU 2,6         \n",
-       "                         \n",
-       "                         \n",
-       "         CPU 3,7         \n",
-       "                         \n",
+       "
┌───────────────────────┐\n",
+       "│        CPU 0,4        │\n",
+       "├───────────────────────┤\n",
+       "│        CPU 1,5        │\n",
+       "├───────────────────────┤\n",
+       "│        CPU 2,6        │\n",
+       "├───────────────────────┤\n",
+       "│        CPU 3,7        │\n",
+       "└───────────────────────┘\n",
        "
\n" ], "text/plain": [ - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mCPU 0,4\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mCPU 1,5\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74mCPU 2,6\u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\n", - "\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", - "\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107mCPU 3,7\u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", - "\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n" + "┌───────────────────────┐\n", + "│ CPU \u001b[1;36m0\u001b[0m,\u001b[1;36m4\u001b[0m │\n", + "├───────────────────────┤\n", + "│ CPU \u001b[1;36m1\u001b[0m,\u001b[1;36m5\u001b[0m │\n", + "├───────────────────────┤\n", + "│ CPU \u001b[1;36m2\u001b[0m,\u001b[1;36m6\u001b[0m │\n", + "├───────────────────────┤\n", + "│ CPU \u001b[1;36m3\u001b[0m,\u001b[1;36m7\u001b[0m │\n", + "└───────────────────────┘\n" ] }, "metadata": {}, @@ -510,7 +500,7 @@ " abs_state, nnx.get_named_sharding(abs_state, mesh)\n", ")\n", "loaded_sharded = checkpointer.restore(path / 'checkpoint_name',\n", - " args=ocp.args.StandardRestore(abs_state))\n", + " target=abs_state)\n", "jax.debug.visualize_array_sharding(loaded_sharded.dot1.kernel.value)\n", "jax.debug.visualize_array_sharding(loaded_sharded.w2.value)" ] @@ -521,16 +511,18 @@ "source": [ "## Compile the training loop\n", "\n", - "Now, from initialization or from checkpoint, you have a sharded model. To carry out the compiled, scaled up training, you need to shard the inputs as well. In this data parallelism example, the training data has its batch dimension sharded across the `data` device axis, so you should put your data in sharding `('data', None)`. You can use [`jax.device_put`](https://jax.readthedocs.io/en/latest/_autosummary/jax.device_put.html#jax.device_put) for this.\n", + "Now, after either initialization or loading the checkpoint, you have a sharded model. To carry out the compiled scaled up training, you need to shard the inputs as well.\n", "\n", - "Note that with the correct sharding for all inputs, the output will be sharded in the most natural way even without JIT compilation. In the example below even without [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) on the output `y`, it was still sharded as `('data', None)`.\n", + "- In the data parallelism example, the training data has its batch dimension sharded across the `data` device axis, so you should put your data in sharding `('data', None)`. You can use [`jax.device_put`](https://jax.readthedocs.io/en/latest/_autosummary/jax.device_put.html#jax.device_put) for this.\n", + "- Note that with the correct sharding for all inputs, the output will be sharded in the most natural way even without `jit` compilation. 
\n", + "- In the example below, even without [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) on the output `y`, it was still sharded as `('data', None)`.\n", "\n", - "> If you are interested in why: The second matmul of `DotReluDot.__call__` has two inputs of sharding `('data', 'model')` and `('model', None)`, in which both inputs' contraction axis are `model`. So a reduce-scatter matmul happened and will naturally shard the output as `('data', None)`. Check out the [JAX shard map collective guide](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#example-2-psum-scatter-the-result) and its examples if you want to learn mathematically how it happens at a low level." + "> If you are interested in why: The second matmul of `DotReluDot.__call__` has two inputs of sharding `('data', 'model')` and `('model', None)`, in which both inputs' contraction axis are `model`. So a reduce-scatter matmul happened and will naturally shard the output as `('data', None)`. Check out the [JAX shard map collective guide](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#example-2-psum-scatter-the-result) and its examples if you want to learn mathematically how it happens at a low-level." ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 24, "metadata": {}, "outputs": [ { @@ -543,33 +535,31 @@ { "data": { "text/html": [ - "
                                                                                \n",
-       "                                                                                \n",
-       "                                  CPU 0,1,2,3                                   \n",
-       "                                                                                \n",
-       "                                                                                \n",
-       "                                                                                \n",
-       "                                                                                \n",
-       "                                                                                \n",
-       "                                  CPU 4,5,6,7                                   \n",
-       "                                                                                \n",
-       "                                                                                \n",
-       "                                                                                \n",
+       "
┌──────────────────────────────────────────────────────────────────────────────┐\n",
+       "│                                                                              │\n",
+       "│                                 CPU 0,1,2,3                                  │\n",
+       "│                                                                              │\n",
+       "│                                                                              │\n",
+       "├──────────────────────────────────────────────────────────────────────────────┤\n",
+       "│                                                                              │\n",
+       "│                                 CPU 4,5,6,7                                  │\n",
+       "│                                                                              │\n",
+       "│                                                                              │\n",
+       "└──────────────────────────────────────────────────────────────────────────────┘\n",
        "
\n" ], "text/plain": [ - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mCPU 0,1,2,3\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mCPU 4,5,6,7\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n" + "┌──────────────────────────────────────────────────────────────────────────────┐\n", + "│ │\n", + "│ CPU \u001b[1;36m0\u001b[0m,\u001b[1;36m1\u001b[0m,\u001b[1;36m2\u001b[0m,\u001b[1;36m3\u001b[0m │\n", + "│ │\n", + "│ │\n", + "├──────────────────────────────────────────────────────────────────────────────┤\n", + "│ │\n", + "│ CPU \u001b[1;36m4\u001b[0m,\u001b[1;36m5\u001b[0m,\u001b[1;36m6\u001b[0m,\u001b[1;36m7\u001b[0m │\n", + "│ │\n", + "│ │\n", + "└──────────────────────────────────────────────────────────────────────────────┘\n" ] }, "metadata": {}, @@ -577,28 +567,28 @@ } ], "source": [ - "# In data parallelism, the first dimension (batch) will be sharded on `data` axis.\n", + "# In data parallelism, the first dimension (batch) will be sharded on the `data` axis.\n", "data_sharding = NamedSharding(mesh, PartitionSpec('data', None))\n", "input = jax.device_put(jnp.ones((8, 1024)), data_sharding)\n", "\n", "with mesh:\n", " output = sharded_model(input)\n", "print(output.shape)\n", - "jax.debug.visualize_array_sharding(output) # Also sharded as ('data', None)" + "jax.debug.visualize_array_sharding(output) # Also sharded as `('data', None)`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Now the rest of the training loop is pretty conventional - almost the same as the example in [NNX Basics](https://flax-nnx.readthedocs.io/en/latest/nnx_basics.html#transforms), except that the inputs and labels are also explicitly sharded.\n", - "\n", - "[`nnx.jit`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) will adjust and automatically choose the best layout based on how its inputs are already sharded, so try out different shardings for your own model and inputs." + "Now the rest of the training loop is pretty conventional - it is almost the same as the example in [Flax NNX Basics](https://flax.readthedocs.io/en/latest/nnx_basics.html#transforms):\n", + "- Except that the inputs and labels are also explicitly sharded.\n", + "- [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) will adjust and automatically choose the best layout based on how its inputs are already sharded, so try out different shardings for your own model and inputs." 
] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 25, "metadata": {}, "outputs": [ { @@ -642,19 +632,19 @@ "source": [ "## Profiling\n", "\n", - "If you are running on a TPU pod or a pod slice, you can create a custom `block_all()` utility function, as defined below, to measure the performance:" + "If you are using a Google TPU pod or a pod slice, you can create a custom `block_all()` utility function, as defined below, to measure the performance:" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "7.09 ms ± 390 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "7.89 ms ± 486 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], @@ -678,12 +668,12 @@ "\n", "JAX's [automatic](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) [SPMD]((https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD)) encourages users to explore different sharding layouts to find the optimal one. To this end, in Flax you have the option to annotate with more descriptive axis names (not just device mesh axis names like `'data'` and `'model'`), as long as you provide a mapping from your alias to the device mesh axes.\n", "\n", - "You can provide the mapping along with the annotation as another metadata of the corresponding [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable), or overwrite it at top-level. Check out the `LogicalDotReluDot` example below." + "You can provide the mapping along with the annotation as another metadata of the corresponding [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable), or overwrite it at top-level. Check out the `LogicalDotReluDot()` example below." ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 27, "metadata": {}, "outputs": [], "source": [ @@ -723,42 +713,42 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "If you didn't provide all `sharding_rule` annotations in model definition, you can write a few lines to add it to [`nnx.State`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) of the model, before the call of [`nnx.get_partition_spec`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_partition_spec) or [`nnx.get_named_sharding`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_named_sharding)." + "If you didn't provide all `sharding_rule` annotations in the model definition, you can write a few lines to add it to Flax’s [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) of the model, before the call of [`nnx.get_partition_spec`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_partition_spec) or [`nnx.get_named_sharding`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_named_sharding)." ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
                                    \n",
-       "                                    \n",
-       "                                    \n",
-       "                                    \n",
-       "                                    \n",
-       " CPU 0,4  CPU 1,5  CPU 2,6  CPU 3,7 \n",
-       "                                    \n",
-       "                                    \n",
-       "                                    \n",
-       "                                    \n",
-       "                                    \n",
+       "
┌───────┬───────┬───────┬───────┐\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "│CPU 0,4│CPU 1,5│CPU 2,6│CPU 3,7│\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "└───────┴───────┴───────┴───────┘\n",
        "
\n" ], "text/plain": [ - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mCPU 0,4\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mCPU 1,5\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74mCPU 2,6\u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107mCPU 3,7\u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n" + "┌───────┬───────┬───────┬───────┐\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "│CPU \u001b[1;36m0\u001b[0m,\u001b[1;36m4\u001b[0m│CPU \u001b[1;36m1\u001b[0m,\u001b[1;36m5\u001b[0m│CPU \u001b[1;36m2\u001b[0m,\u001b[1;36m6\u001b[0m│CPU \u001b[1;36m3\u001b[0m,\u001b[1;36m7\u001b[0m│\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "└───────┴───────┴───────┴───────┘\n" ] }, "metadata": {}, @@ -767,33 +757,27 @@ { "data": { "text/html": [ - "
                         \n",
-       "         CPU 0,4         \n",
-       "                         \n",
-       "                         \n",
-       "         CPU 1,5         \n",
-       "                         \n",
-       "                         \n",
-       "         CPU 2,6         \n",
-       "                         \n",
-       "                         \n",
-       "         CPU 3,7         \n",
-       "                         \n",
+       "
┌───────────────────────┐\n",
+       "│        CPU 0,4        │\n",
+       "├───────────────────────┤\n",
+       "│        CPU 1,5        │\n",
+       "├───────────────────────┤\n",
+       "│        CPU 2,6        │\n",
+       "├───────────────────────┤\n",
+       "│        CPU 3,7        │\n",
+       "└───────────────────────┘\n",
        "
\n" ], "text/plain": [ - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mCPU 0,4\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mCPU 1,5\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74mCPU 2,6\u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\n", - "\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", - "\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107mCPU 3,7\u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", - "\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n" + "┌───────────────────────┐\n", + "│ CPU \u001b[1;36m0\u001b[0m,\u001b[1;36m4\u001b[0m │\n", + "├───────────────────────┤\n", + "│ CPU \u001b[1;36m1\u001b[0m,\u001b[1;36m5\u001b[0m │\n", + "├───────────────────────┤\n", + "│ CPU \u001b[1;36m2\u001b[0m,\u001b[1;36m6\u001b[0m │\n", + "├───────────────────────┤\n", + "│ CPU \u001b[1;36m3\u001b[0m,\u001b[1;36m7\u001b[0m │\n", + "└───────────────────────┘\n" ] }, "metadata": {}, @@ -850,9 +834,9 @@ "\n", " * For a simpler model, this can save you a few extra lines of code of converting the logical naming back to the device naming.\n", "\n", - " * Shardings of intermediate *activation* values can only be done via [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) and device mesh axis. So if you want super fine-grained control over your model's sharding, directly using device mesh axis names everywhere might be less confusing.\n", + " * Shardings of intermediate *activation* values can only be done via [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) and device mesh axis. Therefore, if you want super fine-grained control over your model's sharding, directly using device mesh axis names everywhere might be less confusing.\n", "\n", - "* **Logical naming**: Helpful if you want to experiment around and find the most optimal partition layout for your *model weights*." + "* **Logical naming**: This is helpful if you want to experiment around and find the most optimal partition layout for your *model weights*." ] } ], @@ -870,7 +854,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.9" + "version": "3.11.10" } }, "nbformat": 4, diff --git a/docs_nnx/guides/flax_gspmd.md b/docs_nnx/guides/flax_gspmd.md index 342774d9..50441f94 100644 --- a/docs_nnx/guides/flax_gspmd.md +++ b/docs_nnx/guides/flax_gspmd.md @@ -10,19 +10,21 @@ jupytext: # Scale up on multiple devices -This guide demonstrates how to scale up [Flax NNX `Module`s](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module) on multiple devices and hosts, such as GPUs, Google TPUs, and CPUs, using [JAX just-in-time compilation machinery (`jax.jit`)](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html). 
+This guide demonstrates how to scale up [Flax NNX `Module`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module) on [multiple devices and hosts](Multi-host and multi-process environments) - such as GPUs, Google TPUs, and CPUs - using the [JAX just-in-time compilation machinery (`jax.jit`)](https://jax.readthedocs.io/en/latest/jit-compilation.html) and [`flax.nnx.spmd`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html). +++ ## Overview -Flax relies on [JAX](https://jax.readthedocs.io) for numeric computations and scaling the computations up across multiple devices, such as GPU and TPUs. At the core of scaling up is the [JAX just-in-time compiler `jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html). Throughout this guide, you will be using Flax’s own [`nnx.jit`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) which wraps around [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) and works more conveniently with Flax NNX `Module`s. +Flax relies on [JAX](https://jax.readthedocs.io) for numeric computations and scaling the computations up across multiple devices, such as GPU and Google TPUs. At the core of scaling up is the [JAX just-in-time (`jax.jit`) compiler `jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html). Throughout this guide, you will be using Flax’s own [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) transform, which wraps around [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) and works more conveniently with Flax NNX `Module`s. -JAX compilation follows the [Single Program Multi Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) paradigm. This means you write Python code as if it runs only on one device, and [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) will automatically compile and run it on multiple devices. +> **Note:** To learn more about Flax’s transformations, such as `nnx.jit` and `nnx.vmap`, go to [Why Flax NNX? - Transforms](https://flax.readthedocs.io/en/latest/why.html#transforms), [Transformations](https://flax.readthedocs.io/en/latest/guides/transforms.html), and [Flax NNX vs JAX Transformations](https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html). -To ensure the compilation performance, you often need to instruct JAX how your model's variables need to be sharded across devices. This is where Flax NNX's Sharding Metadata API - [`flax.nnx.spmd`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html) - comes in. It helps you annotate your model variables with this information. +JAX compilation follows the [Single Program Multi Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) paradigm. This means you write Python code as if it runs only on one device, and [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) will [automatically compile](https://jax.readthedocs.io/en/latest/jit-compilation.html#jit-compilation) and [run it](https://jax.readthedocs.io/en/latest/sharded-computation.html) on [multiple devices](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html). 
-> **Note to Flax Linen users**: The [`flax.nnx.spmd`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html) API is similar to what is described in [the Linen Flax on `(p)jit` guide](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html) on the model definition level. However, the top-level code in Flax NNX is simpler due to the benefits brought by Flax NNX, and some text explanations will be more updated and clearer. +To ensure the compilation performance, you often need to instruct JAX how your model's variables need to be sharded across devices. This is where Flax NNX's Sharding Metadata API - [`flax.nnx.spmd`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html) - comes in. It helps you annotate your model variables with this information. + +> **Note to Flax Linen users**: The [`flax.nnx.spmd`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html) API is similar to what is described in [the Linen Flax on `(p)jit` guide](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html) on the model definition level. However, the top-level code in Flax NNX is simpler due to the benefits brought by Flax NNX, and some text explanations will be more updated and clearer. If you are new to parallelization in JAX, you can learn more about its APIs for scaling up in the following tutorials: @@ -79,11 +81,11 @@ print(mesh) ## Define a model with specified sharding -Next, create an example layer called `DotReluDot` that subclasses Flax [`nnx.Module`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module). This layer carries out two dot product multiplication upon the input `x`, and uses the `jax.nn.relu` (ReLU) activation function in-between. - -To annotate a model variable with their ideal sharding, you can use [`flax.nnx.with_partitioning`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning) to wrap over its initializer function. Essentially, this calls [`flax.nnx.with_metadata`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.with_metadata) which adds a `.sharding` attribute field to the corresponding [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable). +Next, create an example layer called `DotReluDot` that subclasses Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module). +- This layer carries out two dot product multiplications upon the input `x`, and uses the `jax.nn.relu` (ReLU) activation function in-between. +- To annotate a model variable with its ideal sharding, you can use [`flax.nnx.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning) to wrap over its initializer function. Essentially, this calls [`flax.nnx.with_metadata`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.with_metadata) which adds a `.sharding` attribute field to the corresponding [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable). -> **Note:** This annotation will be [preserved and adjusted accordingly across lifted transformations in Flax NNX](https://flax-nnx.readthedocs.io/en/latest/guides/transforms.html#axes-metadata).
This means if you use sharding annotations along with any transform that modifies axes (like [`nnx.vmap`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html), [`nnx.scan`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html)), you need to provide sharding of that additional axis via the `transform_metadata` arg. Check out the [Flax NNX transformations (transforms) guide](https://flax-nnx.readthedocs.io/en/latest/guides/transforms.html) to learn more. +> **Note:** This annotation will be [preserved and adjusted accordingly across lifted transformations in Flax NNX](https://flax.readthedocs.io/en/latest/guides/transforms.html#axes-metadata). This means if you use sharding annotations along with any transform that modifies axes (like [`nnx.vmap`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html), [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html)), you need to provide sharding of that additional axis via the `transform_metadata` arg. Check out the [Flax NNX transformations (transforms) guide](https://flax.readthedocs.io/en/latest/guides/transforms.html) to learn more. ```{code-cell} ipython3 class DotReluDot(nnx.Module): @@ -130,12 +132,12 @@ JAX's [Distributed arrays and automatic parallelization](https://jax.readthedocs ## Initialize a sharded model -Now, you have annotations attached to the [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable), but the actual weights haven't been sharded yet. If you just go ahead and create this model, all JAX arrays are still stuck in device `0`. In practice, you'd want to avoid this, because a large model will "OOM" (will cause the device to run out of memory) in this situation, while all the other devices are not utilized. +Now, you have annotations attached to the Flax [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable), but the actual weights have not been sharded yet. If you just go ahead and create this model, all [`jax.Array`s](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) are still stuck in device `0`. In practice, you'd want to avoid this, because a large model will "OOM" (will cause the device to run out of memory) in this situation, while all the other devices are not utilized. ```{code-cell} ipython3 unsharded_model = DotReluDot(1024, rngs=nnx.Rngs(0)) -# You have annotations sticked there, yay! +# You have annotations stuck there, yay! print(unsharded_model.dot1.kernel.sharding) # (None, 'model') print(unsharded_model.w2.sharding) # ('model', None) @@ -144,15 +146,15 @@ print(unsharded_model.dot1.kernel.value.sharding) # SingleDeviceSharding print(unsharded_model.w2.value.sharding) # SingleDeviceSharding ``` -Here, you should leverage JAX's compilation mechanism, via [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit), to create the sharded model. The key is to initialize a model and assign shardings upon the model state within a `jit`ted function: +Here, you should leverage JAX's compilation mechanism via Flax’s [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) to create the sharded model. The key is to initialize a model and assign shardings upon the model state within a `jit`ted function: -1. 
Use [`nnx.get_partition_spec`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_partition_spec) to strip out the `.sharding` annotations attached upon model variables. +1. Use [`nnx.get_partition_spec`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_partition_spec) to strip out the `.sharding` annotations attached upon model variables. 1. Call [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) to bind the model state with the sharding annotations. This API tells the top-level `jit` how to shard a variable! 1. Throw away the unsharded state and return the model based upon the sharded state. -1. Compile the whole function with `nnx.jit`, which allows the output to be a stateful NNX module. +1. Compile the whole function with `nnx.jit`, which allows the output to be a stateful Flax NNX `Module`. 1. Run it under a device mesh context so that JAX knows which devices to shard it to. @@ -184,7 +186,7 @@ assert sharded_model.w2.value.sharding.is_equivalent_to( ) ``` -You can view the sharding of any 1-D or 2-D array with `jax.debug.visualize_array_sharding`: +You can view the sharding of any 1-D or 2-D array with [`jax.debug.visualize_array_sharding`](https://jax.readthedocs.io/en/latest/_autosummary/jax.debug.visualize_array_sharding.html): ```{code-cell} ipython3 print("sharded_model.dot1.kernel (None, 'model') :") @@ -199,19 +201,19 @@ The key to shard a JAX array is to call [`jax.lax.with_sharding_constraint`](htt > **Note:** Both [Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html) and [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) in the JAX documentation cover automatic parallelization with [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html), and semi-automatic parallelization with `jax.jit` and [`jax.lax.with_sharding_constraint](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) in greater detail. -You may have noticed you also used [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) once in the model definition too, to constraint the sharding of an intermediate value. This is just to show that you can always use it orthogonally with the Flax NNX API, if you want to explicitly shard values that are not model variables. +You may have noticed you also used [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) once in the model definition to constraint the sharding of an intermediate value. This is just to show that you can always use it orthogonally with the Flax NNX API if you want to explicitly shard values that are not model variables. This brings a question: Why use the Flax NNX Annotation API then? Why not just add JAX sharding constraints inside the model definition? The most important reason is that you still need the explicit annotations to load a sharded model from an on-disk checkpoint. This is described in the next section. +++ -## Load sharded model from a checkpoint +## Load a sharded model from a checkpoint -Now you can initialize a sharded model without OOM, but what about loading it from a checkpoint on disk? 
JAX checkpointing libraries, such as [Orbax](https://orbax.readthedocs.io/en/latest/), usually support loading it sharded if a sharding pytree is given. +Now you learned how to initialize a sharded model without OOM, but what about loading it from a checkpoint on disk? JAX checkpointing libraries, such as [Orbax](https://orbax.readthedocs.io/en/latest/), usually support loading a model sharded if a sharding pytree is provided. -You can generate such as sharding pytree with [`nnx.get_named_sharding`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_named_sharding). To avoid any real memory allocation, use the [`nnx.eval_shape`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.eval_shape) transform to generate a model of abstract JAX arrays, and only use its `.sharding` annotations to obtain the sharding tree. +You can generate such a sharding pytree with Flax’s [`nnx.get_named_sharding`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_named_sharding). To avoid any real memory allocation, use the [`nnx.eval_shape`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.eval_shape) transform to generate a model of abstract JAX arrays, and only use its `.sharding` annotations to obtain the sharding tree. -Below is an example demonstration using Orbax's `StandardCheckpointer` API. Check out [Orbax website](https://orbax.readthedocs.io/en/latest/) to learn their latest updates and recommended APIs. +Below is an example that demonstrates using Orbax's `StandardCheckpointer` API. (Go to the [Orbax documentation site](https://orbax.readthedocs.io/en/latest/) to learn about their latest and most recommended APIs.) ```{code-cell} ipython3 import orbax.checkpoint as ocp @@ -232,33 +234,35 @@ abs_state = jax.tree.map( abs_state, nnx.get_named_sharding(abs_state, mesh) ) loaded_sharded = checkpointer.restore(path / 'checkpoint_name', - args=ocp.args.StandardRestore(abs_state)) + target=abs_state) jax.debug.visualize_array_sharding(loaded_sharded.dot1.kernel.value) jax.debug.visualize_array_sharding(loaded_sharded.w2.value) ``` ## Compile the training loop -Now, from initialization or from checkpoint, you have a sharded model. To carry out the compiled, scaled up training, you need to shard the inputs as well. In this data parallelism example, the training data has its batch dimension sharded across the `data` device axis, so you should put your data in sharding `('data', None)`. You can use [`jax.device_put`](https://jax.readthedocs.io/en/latest/_autosummary/jax.device_put.html#jax.device_put) for this. +Now, after either initialization or loading the checkpoint, you have a sharded model. To carry out the compiled scaled up training, you need to shard the inputs as well. -Note that with the correct sharding for all inputs, the output will be sharded in the most natural way even without JIT compilation. In the example below even without [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) on the output `y`, it was still sharded as `('data', None)`. +- In the data parallelism example, the training data has its batch dimension sharded across the `data` device axis, so you should put your data in sharding `('data', None)`. You can use [`jax.device_put`](https://jax.readthedocs.io/en/latest/_autosummary/jax.device_put.html#jax.device_put) for this. 
+- Note that with the correct sharding for all inputs, the output will be sharded in the most natural way even without `jit` compilation. +- In the example below, even without [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) on the output `y`, it was still sharded as `('data', None)`. -> If you are interested in why: The second matmul of `DotReluDot.__call__` has two inputs of sharding `('data', 'model')` and `('model', None)`, in which both inputs' contraction axis are `model`. So a reduce-scatter matmul happened and will naturally shard the output as `('data', None)`. Check out the [JAX shard map collective guide](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#example-2-psum-scatter-the-result) and its examples if you want to learn mathematically how it happens at a low level. +> If you are interested in why: The second matmul of `DotReluDot.__call__` has two inputs of sharding `('data', 'model')` and `('model', None)`, in which both inputs' contraction axis are `model`. So a reduce-scatter matmul happened and will naturally shard the output as `('data', None)`. Check out the [JAX shard map collective guide](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#example-2-psum-scatter-the-result) and its examples if you want to learn mathematically how it happens at a low-level. ```{code-cell} ipython3 -# In data parallelism, the first dimension (batch) will be sharded on `data` axis. +# In data parallelism, the first dimension (batch) will be sharded on the `data` axis. data_sharding = NamedSharding(mesh, PartitionSpec('data', None)) input = jax.device_put(jnp.ones((8, 1024)), data_sharding) with mesh: output = sharded_model(input) print(output.shape) -jax.debug.visualize_array_sharding(output) # Also sharded as ('data', None) +jax.debug.visualize_array_sharding(output) # Also sharded as `('data', None)`. ``` -Now the rest of the training loop is pretty conventional - almost the same as the example in [NNX Basics](https://flax-nnx.readthedocs.io/en/latest/nnx_basics.html#transforms), except that the inputs and labels are also explicitly sharded. - -[`nnx.jit`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) will adjust and automatically choose the best layout based on how its inputs are already sharded, so try out different shardings for your own model and inputs. +Now the rest of the training loop is pretty conventional - it is almost the same as the example in [Flax NNX Basics](https://flax.readthedocs.io/en/latest/nnx_basics.html#transforms): +- Except that the inputs and labels are also explicitly sharded. +- [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) will adjust and automatically choose the best layout based on how its inputs are already sharded, so try out different shardings for your own model and inputs. 
```{code-cell} ipython3 optimizer = nnx.Optimizer(sharded_model, optax.adam(1e-3)) # reference sharing @@ -285,7 +289,7 @@ with mesh: ## Profiling -If you are running on a TPU pod or a pod slice, you can create a custom `block_all()` utility function, as defined below, to measure the performance: +If you are using a Google TPU pod or a pod slice, you can create a custom `block_all()` utility function, as defined below, to measure the performance: ```{code-cell} ipython3 %%timeit @@ -302,7 +306,7 @@ with mesh: JAX's [automatic](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) [SPMD]((https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD)) encourages users to explore different sharding layouts to find the optimal one. To this end, in Flax you have the option to annotate with more descriptive axis names (not just device mesh axis names like `'data'` and `'model'`), as long as you provide a mapping from your alias to the device mesh axes. -You can provide the mapping along with the annotation as another metadata of the corresponding [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable), or overwrite it at top-level. Check out the `LogicalDotReluDot` example below. +You can provide the mapping along with the annotation as another metadata of the corresponding [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable), or overwrite it at top-level. Check out the `LogicalDotReluDot()` example below. ```{code-cell} ipython3 # The mapping from alias annotation to the device mesh. @@ -337,7 +341,7 @@ class LogicalDotReluDot(nnx.Module): return z ``` -If you didn't provide all `sharding_rule` annotations in model definition, you can write a few lines to add it to [`nnx.State`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) of the model, before the call of [`nnx.get_partition_spec`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_partition_spec) or [`nnx.get_named_sharding`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_named_sharding). +If you didn't provide all `sharding_rule` annotations in the model definition, you can write a few lines to add it to Flax’s [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) of the model, before the call of [`nnx.get_partition_spec`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_partition_spec) or [`nnx.get_named_sharding`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_named_sharding). ```{code-cell} ipython3 def add_sharding_rule(vs: nnx.VariableState) -> nnx.VariableState: @@ -384,6 +388,6 @@ Choosing when to use a device or logical axis depends on how much you want to co * For a simpler model, this can save you a few extra lines of code of converting the logical naming back to the device naming. - * Shardings of intermediate *activation* values can only be done via [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) and device mesh axis. So if you want super fine-grained control over your model's sharding, directly using device mesh axis names everywhere might be less confusing. 
+ * Shardings of intermediate *activation* values can only be done via [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) and device mesh axis. Therefore, if you want super fine-grained control over your model's sharding, directly using device mesh axis names everywhere might be less confusing. -* **Logical naming**: Helpful if you want to experiment around and find the most optimal partition layout for your *model weights*. +* **Logical naming**: This is helpful if you want to experiment around and find the most optimal partition layout for your *model weights*. diff --git a/docs_nnx/guides/gemma.ipynb b/docs_nnx/guides/gemma.ipynb new file mode 100644 index 00000000..1c59c951 --- /dev/null +++ b/docs_nnx/guides/gemma.ipynb @@ -0,0 +1,386 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Example: Using Pretrained Gemma\n", + "\n", + "You will find in this colab a detailed tutorial explaining how to use NNX to load a Gemma checkpoint and sample from it." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Installation" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "! pip install --no-deps -U flax\n", + "! pip install jaxtyping kagglehub treescope" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Downloading the checkpoint\n", + "\n", + "\"To use Gemma's checkpoints, you'll need a Kaggle account and API key. Here's how to get them:\n", + "\n", + "1. Visit https://www.kaggle.com/ and create an account.\n", + "2. Go to your account settings, then the 'API' section.\n", + "3. Click 'Create new token' to download your key.\n", + "\n", + "Then run the cell below." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2e7cf9f0345845f1a3edc72fa4411eb4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(HTML(value='
(()=>{ if (customElements.get('treescope-container') === undefined) { class TreescopeContainer extends HTMLElement { constructor() { super(); this.attachShadow({mode: \"open\"}); this.defns = {}; this.state = {}; } } customElements.define(\"treescope-container\", TreescopeContainer); } if (customElements.get('treescope-run-here') === undefined) { class RunHere extends HTMLElement { constructor() { super() } connectedCallback() { const run = child => { const fn = new Function(child.textContent); child.textContent = \"\"; fn.call(this); this.remove(); }; const child = this.querySelector(\"script\"); if (child) { run(child); } else { new MutationObserver(()=>{ run(this.querySelector(\"script\")); }).observe(this, {childList: true}); } } } customElements.define(\"treescope-run-here\", RunHere); } })();
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "transformer = transformer_lib.Transformer.from_params(params)\n", + "nnx.display(transformer)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, build a sampler on top of your model and your tokenizer." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "cellView": "form" + }, + "outputs": [], + "source": [ + "# Create a sampler with the right param shapes.\n", + "sampler = sampler_lib.Sampler(\n", + " transformer=transformer,\n", + " vocab=vocab,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You're ready to start sampling ! This sampler uses just-in-time compilation, so changing the input shape triggers recompilation, which can slow things down. For the fastest and most efficient results, keep your batch size consistent." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "cellView": "form" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Prompt:\n", + "\n", + "# Python program for implementation of Bubble Sort\n", + "\n", + "def bubbleSort(arr):\n", + "Output:\n", + "\n", + " for i in range(len(arr)):\n", + " for j in range(len(arr) - i - 1):\n", + " if arr[j] > arr[j + 1]:\n", + " swap(arr, j, j + 1)\n", + "\n", + "\n", + "def swap(arr, i, j):\n", + " temp = arr[i]\n", + " arr[i] = arr[j]\n", + " arr[j] = temp\n", + "\n", + "\n", + "# Driver code\n", + "arr = [5, 2, 8, 3, 1, 9]\n", + "print(\"Unsorted array:\")\n", + "print(arr)\n", + "bubbleSort(arr)\n", + "print(\"Sorted array:\")\n", + "print(arr)\n", + "\n", + "\n", + "# Time complexity of Bubble sort O(n^2)\n", + "# where n is the length of the array\n", + "\n", + "\n", + "# Space complexity of Bubble sort O(1)\n", + "# as it only requires constant extra space for the swap operation\n", + "\n", + "\n", + "# This program uses the bubble sort algorithm to sort the given array in ascending order.\n", + "\n", + "```python\n", + "# This program uses the bubble sort algorithm to sort the given array in ascending order.\n", + "\n", + "def bubbleSort(arr):\n", + " for i in range(len(arr)):\n", + " for j in range(len(arr) - i - 1):\n", + " if arr[j] > arr[j + 1]:\n", + " swap(arr, j, j + 1)\n", + "\n", + "\n", + "def swap(\n", + "\n", + "##########\n" + ] + } + ], + "source": [ + "input_batch = [\n", + " \"\\n# Python program for implementation of Bubble Sort\\n\\ndef bubbleSort(arr):\",\n", + "]\n", + "\n", + "out_data = sampler(\n", + " input_strings=input_batch,\n", + " total_generation_steps=300, # number of steps performed when generating\n", + " )\n", + "\n", + "for input_string, out_string in zip(input_batch, out_data.text):\n", + " print(f\"Prompt:\\n{input_string}\\nOutput:\\n{out_string}\")\n", + " print()\n", + " print(10*'#')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You should get an implementation of bubble sort." 
+ ] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,md:myst" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/docs_nnx/guides/gemma.md b/docs_nnx/guides/gemma.md new file mode 100644 index 00000000..e479201a --- /dev/null +++ b/docs_nnx/guides/gemma.md @@ -0,0 +1,141 @@ +--- +jupytext: + formats: ipynb,md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.13.8 +--- + +# Example: Using Pretrained Gemma + +You will find in this colab a detailed tutorial explaining how to use NNX to load a Gemma checkpoint and sample from it. + ++++ + +## Installation + +```{code-cell} ipython3 +! pip install --no-deps -U flax +! pip install jaxtyping kagglehub treescope +``` + +## Downloading the checkpoint + +"To use Gemma's checkpoints, you'll need a Kaggle account and API key. Here's how to get them: + +1. Visit https://www.kaggle.com/ and create an account. +2. Go to your account settings, then the 'API' section. +3. Click 'Create new token' to download your key. + +Then run the cell below. + +```{code-cell} ipython3 +import kagglehub +kagglehub.login() +``` + +If everything went well, you should see: +``` +Kaggle credentials set. +Kaggle credentials successfully validated. +``` + +Now select and download the checkpoint you want to try. Note that you will need an A100 runtime for the 7b models. + +```{code-cell} ipython3 +from IPython.display import clear_output + +VARIANT = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it'] {type:"string"} +weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}') +ckpt_path = f'{weights_dir}/{VARIANT}' +vocab_path = f'{weights_dir}/tokenizer.model' + +clear_output() +``` + +## Python imports + +```{code-cell} ipython3 +from flax import nnx +import sentencepiece as spm +``` + +Flax examples are not exposed as packages so you need to use the workaround in the next cells to import from NNX's Gemma example. + +```{code-cell} ipython3 +import sys +import tempfile + +with tempfile.TemporaryDirectory() as tmp: + # Here we create a temporary directory and clone the flax repo + # Then we append the examples/gemma folder to the path to load the gemma modules + ! git clone https://github.com/google/flax.git {tmp}/flax + sys.path.append(f"{tmp}/flax/examples/gemma") + import params as params_lib + import sampler as sampler_lib + import transformer as transformer_lib + sys.path.pop(); +``` + +## Start Generating with Your Model + +Load and prepare your LLM's checkpoint for use with Flax. + +```{code-cell} ipython3 +:cellView: form + +# Load parameters +params = params_lib.load_and_format_params(ckpt_path) +``` + +Load your tokenizer, which we'll construct using the [SentencePiece](https://github.com/google/sentencepiece) library. + +```{code-cell} ipython3 +:cellView: form + +vocab = spm.SentencePieceProcessor() +vocab.Load(vocab_path) +``` + +Use the `transformer_lib.TransformerConfig.from_params` function to automatically load the correct configuration from a checkpoint. Note that the vocabulary size is smaller than the number of input embeddings due to unused tokens in this release. 
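A minimal sketch of inspecting that configuration directly, assuming `params` from the cell above and that the cloned `transformer_lib` exposes `TransformerConfig.from_params` as described:

```python
# Illustrative sketch only: look at the configuration derived from the
# checkpoint parameters before building the model.
config = transformer_lib.TransformerConfig.from_params(params)
print(config)  # e.g. number of layers, embedding size, vocabulary size
```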
+ +```{code-cell} ipython3 +transformer = transformer_lib.Transformer.from_params(params) +nnx.display(transformer) +``` + +Finally, build a sampler on top of your model and your tokenizer. + +```{code-cell} ipython3 +:cellView: form + +# Create a sampler with the right param shapes. +sampler = sampler_lib.Sampler( + transformer=transformer, + vocab=vocab, +) +``` + +You're ready to start sampling ! This sampler uses just-in-time compilation, so changing the input shape triggers recompilation, which can slow things down. For the fastest and most efficient results, keep your batch size consistent. + +```{code-cell} ipython3 +:cellView: form + +input_batch = [ + "\n# Python program for implementation of Bubble Sort\n\ndef bubbleSort(arr):", +] + +out_data = sampler( + input_strings=input_batch, + total_generation_steps=300, # number of steps performed when generating + ) + +for input_string, out_string in zip(input_batch, out_data.text): + print(f"Prompt:\n{input_string}\nOutput:\n{out_string}") + print() + print(10*'#') +``` + +You should get an implementation of bubble sort. diff --git a/docs_nnx/guides/haiku_to_flax.rst b/docs_nnx/guides/haiku_to_flax.rst index 16dcd865..8fdb48db 100644 --- a/docs_nnx/guides/haiku_to_flax.rst +++ b/docs_nnx/guides/haiku_to_flax.rst @@ -1,41 +1,43 @@ - Migrating from Haiku to Flax -================================= +############################ + +This guide demonstrates the differences between Haiku and Flax NNX models, providing side-by-side example code to help you migrate to the Flax NNX API from Haiku. -This guide will showcase the differences between Haiku, Flax Linen and Flax NNX. -Both Haiku and Linen enforce a functional paradigm with stateless modules, -while NNX is a new, next-generation API that embraces the python language to -provide a more intuitive development experience. +If you are new to Flax NNX, make sure you become familiarized with `Flax NNX basics `__, which covers the :class:`nnx.Module` system, `Flax transformations `__, and the `Functional API `__ with examples. -.. testsetup:: Haiku, Linen, NNX +Let’s start with some imports. + +.. testsetup:: Haiku, Flax NNX import jax import jax.numpy as jnp - from jax import random import optax - import flax.linen as nn from typing import Any - # TODO: double check the params output match the rendered tab-set - # TODO: change filename to haiku_linen_upgrade.rst and update other .rst file references - # TODO: make sure code lines are not too long - # TODO: make sure all code diffs are aligned -Basic Example ------------------ +Basic Module definition +======================= + +Both Haiku and Flax use the ``Module`` class as the default unit to express a neural network library layer. For example, to create a one-layer network with dropout and a ReLU activation function, you: + +* First, create a ``Block`` (by subclassing ``Module``) composed of one linear layer with dropout and a ReLU activation function. +* Then, use ``Block`` as a sub-``Module`` when creating a ``Model`` (also by subclassing ``Module``), which is made up of ``Block`` and a linear layer. + +There are two fundamental differences between Haiku and Flax ``Module`` objects: + +* **Stateless vs. stateful**: + + * A ``haiku.Module`` instance is stateless. This means, the variables are returned from a purely functional ``Module.init()`` call and managed separately. + * A :class:`flax.nnx.Module`, however, owns its variables as attributes of this Python object. 
-To create custom Modules you subclass from a ``Module`` base class in -both Haiku and Flax. Modules can be defined inline in Haiku and Flax -Linen (using the ``@nn.compact`` decorator), whereas modules can't be -defined inline in NNX and must be defined in ``__init__``. +* **Lazy vs. eager**: + + * A ``haiku.Module`` only allocates space to create variables when they actually see the input when the user calls the model (lazy). + * A ``flax.nnx.Module`` instance creates variables the moment they are instantiated, before seeing a sample input (eager). -Linen requires a ``deterministic`` argument to control whether or -not dropout is used. NNX also uses a ``deterministic`` argument -but the value can be set later using ``.eval()`` and ``.train()`` methods -that will be shown in a later code snippet. .. codediff:: - :title: Haiku, Linen, NNX + :title: Haiku, Flax NNX :sync: import haiku as hk @@ -64,33 +66,7 @@ that will be shown in a later code snippet. --- - import flax.linen as nn - - class Block(nn.Module): - features: int - - - @nn.compact - def __call__(self, x, training: bool): - x = nn.Dense(self.features)(x) - x = nn.Dropout(0.5, deterministic=not training)(x) - x = jax.nn.relu(x) - return x - - class Model(nn.Module): - dmid: int - dout: int - - - @nn.compact - def __call__(self, x, training: bool): - x = Block(self.dmid)(x, training) - x = nn.Dense(self.dout)(x) - return x - - --- - - from flax.experimental import nnx + from flax import nnx class Block(nnx.Module): def __init__(self, in_features: int , out_features: int, rngs: nnx.Rngs): @@ -114,182 +90,79 @@ that will be shown in a later code snippet. x = self.linear(x) return x -Since modules are defined inline in Haiku and Linen, the parameters -are lazily initialized, by inferring the shape of a sample input. In Flax -NNX, the module is stateful and is initialized eagerly. This means that the -input shape must be explicitly passed during module instantiation since there -is no shape inference in NNX. - -.. codediff:: - :title: Haiku, Linen, NNX - :sync: - - def forward(x, training: bool): - return Model(256, 10)(x, training) - - model = hk.transform(forward) - - --- - - ... - - - model = Model(256, 10) - - --- - ... +Variable creation +================= +This section is about instantiating a model and initializing its parameters. - model = Model(784, 256, 10, rngs=nnx.Rngs(0)) +* To generate model parameters for a Haiku model, you need to put it inside a forward function and use ``haiku.transform`` to make it purely functional. This results in a nested dictionary of `JAX Arrays `__ (``jax.Array`` data types) to be carried around and maintained separately. -To get the model parameters in both Haiku and Linen, you use the ``init`` method -with a ``random.key`` plus some inputs to run the model. +* In Flax NNX, the model parameters are automatically initialized when you instantiate the model, and the variables (:class:`nnx.Variable` objects) are stored inside the :class:`nnx.Module` (or its sub-Module) as attributes. You still need to provide it with a `pseudorandom number generator (PRNG) `__ key, but that key will be wrapped inside an :class:`nnx.Rngs` class and stored inside, generating more PRNG keys when needed. -In NNX, the model parameters are automatically initialized when the user -instantiates the model because the input shapes are already explicitly passed at -instantiation time. 
+If you want to access Flax model parameters in the stateless, dictionary-like fashion for checkpoint saving or model surgery, check out the `Flax NNX split/merge API `__ (:func:`nnx.split` / :func:`nnx.merge`). -Since NNX is eager and the module is bound upon instantiation, the user can access -the parameters (and other fields defined in ``__init__`` via dot-access). On the other -hand, Haiku and Linen use lazy initialization and so the parameters can only be accessed -once the module is initialized with a sample input and both frameworks do not support -dot-access of their attributes. .. codediff:: - :title: Haiku, Linen, NNX + :title: Haiku, Flax NNX :sync: + def forward(x, training: bool): + return Model(256, 10)(x, training) + + model = hk.transform(forward) sample_x = jnp.ones((1, 784)) - params = model.init( - random.key(0), - sample_x, training=False # <== inputs - ) + params = model.init(jax.random.key(0), sample_x, training=False) assert params['model/linear']['b'].shape == (10,) assert params['model/block/linear']['w'].shape == (784, 256) - --- - - sample_x = jnp.ones((1, 784)) - variables = model.init( - random.key(0), - sample_x, training=False # <== inputs - ) - params = variables["params"] - - assert params['Dense_0']['bias'].shape == (10,) - assert params['Block_0']['Dense_0']['kernel'].shape == (784, 256) --- ... + model = Model(784, 256, 10, rngs=nnx.Rngs(0)) - # parameters were already initialized during model instantiation + # Parameters were already initialized during model instantiation. assert model.linear.bias.value.shape == (10,) assert model.block.linear.kernel.value.shape == (784, 256) -Let's take a look at the parameter structure. In Haiku and Linen, we can -simply inspect the ``params`` object returned from ``.init()``. - -To see the parameter structure in NNX, the user can call ``nnx.split`` to -generate ``Graphdef`` and ``State`` objects. The ``Graphdef`` is a static pytree -denoting the structure of the model (for example usages, see -`NNX Basics `__). -``State`` objects contains all the module variables (i.e. any class that sub-classes -``nnx.Variable``). If we filter for ``nnx.Param``, we will generate a ``State`` object -of all the learnable module parameters. - -.. tab-set:: - - .. tab-item:: Haiku - :sync: Haiku - - .. code-block:: python - - ... - - - { - 'model/block/linear': { - 'b': (256,), - 'w': (784, 256), - }, - 'model/linear': { - 'b': (10,), - 'w': (256, 10), - } - } - - ... - +Training step and compilation +============================= - .. tab-item:: Linen - :sync: Linen +This section covers writing a training step and compiling it using the `JAX just-in-time compilation `__. - .. code-block:: python +When compiling the training step: - ... +* Haiku uses ``@jax.jit`` - a `JAX transformation `__ - to compile a purely functional training step. +* Flax NNX uses :meth:`@nnx.jit` - a `Flax NNX transformation `__ (one of several transform APIs that behave similarly to JAX transforms, but also `work well with Flax objects `__). While ``jax.jit`` only accepts functions with pure stateless arguments, ``flax.nnx.jit`` allows the arguments to be stateful Modules. This greatly reduces the number of lines needed for a train step. +When taking gradients: - FrozenDict({ - Block_0: { - Dense_0: { - bias: (256,), - kernel: (784, 256), - }, - }, - Dense_0: { - bias: (10,), - kernel: (256, 10), - }, - }) +* Similarly, Haiku uses ``jax.grad`` (a JAX transformation for `automatic differentiation `__) to return a raw dictionary of gradients. 
+* Meanwhile, Flax NNX uses :meth:`flax.nnx.grad` (a Flax NNX transformation) to return the gradients of Flax NNX Modules as :class:`flax.nnx.State` dictionaries. If you want to use regular ``jax.grad`` with Flax NNX, you need to use the `split/merge API `__. +For optimizers: - .. tab-item:: NNX - :sync: NNX +* If you are already using `Optax `__ optimizers like ``optax.adamw`` (instead of the raw ``jax.tree.map`` computation shown here) with Haiku, check out the :class:`flax.nnx.Optimizer` example in the `Flax basics `__ guide for a much more concise way of training and updating your model. - .. code-block:: python +Model updates during each training step: - graphdef, params, rngs = nnx.split(model, nnx.Param, nnx.RngState) +* The Haiku training step needs to return a `JAX pytree `__ of parameters as the input of the next step. +* The Flax NNX training step does not need to return anything, because the ``model`` was already updated in-place within :meth:`nnx.jit`. +* In addition, :class:`nnx.Module` objects are stateful, and ``Module`` automatically tracks several things within it, such as PRNG keys and ``flax.nnx.BatchNorm`` stats. That is why you don't need to explicitly pass a PRNG key in at every step. Also note that you can use :meth:`flax.nnx.reseed` to reset its underlying PRNG state. - params - State({ - 'block': { - 'linear': { - 'bias': VariableState(type=Param, value=(256,)), - 'kernel': VariableState(type=Param, value=(784, 256)) - } - }, - 'linear': { - 'bias': VariableState(type=Param, value=(10,)), - 'kernel': VariableState(type=Param, value=(256, 10)) - } - }) +The dropout behavior: -During training in Haiku and Linen, you pass the parameters structure to the -``apply`` method to run the forward pass. To use dropout, we must pass in -``training=True`` and provide a ``key`` to ``apply`` in order to generate the -random dropout masks. To use dropout in NNX, we first call ``model.train()``, -which will set the dropout layer's ``deterministic`` attribute to ``False`` -(conversely, calling ``model.eval()`` would set ``deterministic`` to ``True``). -Since the stateful NNX module already contains both the parameters and RNG key -(used for dropout), we simply need to call the module to run the forward pass. We -use ``nnx.split`` to extract the learnable parameters (all learnable parameters -subclass the NNX class ``nnx.Param``) and then apply the gradients and statefully -update the model using ``nnx.update``. - -To compile ``train_step``, we decorate the function using ``@jax.jit`` for Haiku -and Linen, and ``@nnx.jit`` for NNX. Similar to ``@jax.jit``, ``@nnx.jit`` also -compiles functions, with the additional feature of allowing the user to compile -functions that take in NNX modules as arguments. +* In Haiku, you need to explicitly define and pass in the ``training`` argument to toggle ``haiku.dropout`` and make sure that random dropout only happens if ``training=True``. +* In Flax NNX, you can call ``model.train()`` (:meth:`flax.nnx.Module.train`) to automatically switch :class:`flax.nnx.Dropout` to the training mode. Conversely, you can call ``model.eval()`` (:meth:`flax.nnx.Module.eval`) to turn off the training mode. You can learn more about what ``flax.nnx.Module.train`` does in its `API reference `__. .. codediff:: - :title: Haiku, Linen, NNX + :title: Haiku, Flax NNX :sync: ... @@ -313,27 +186,6 @@ functions that take in NNX modules as arguments. --- - ... 
- - @jax.jit - def train_step(key, params, inputs, labels): - def loss_fn(params): - logits = model.apply( - {'params': params}, - inputs, training=True, # <== inputs - rngs={'dropout': key} - ) - return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean() - - grads = jax.grad(loss_fn)(params) - - - params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads) - - return params - - --- - model.train() # set deterministic=False @nnx.jit @@ -347,163 +199,34 @@ functions that take in NNX modules as arguments. return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean() grads = nnx.grad(loss_fn)(model) - # we can use Ellipsis to filter out the rest of the variables - _, params, _ = nnx.split(model, nnx.Param, ...) - params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads) - - nnx.update(model, params) + _, params, rest = nnx.split(model, nnx.Param, ...) + params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads) + nnx.update(model, nnx.GraphState.merge(params, rest)) -.. testcode:: Haiku, Linen +.. testcode:: Haiku :hide: - train_step(random.key(0), params, sample_x, jnp.ones((1,), dtype=jnp.int32)) + train_step(jax.random.key(0), params, sample_x, jnp.ones((1,), dtype=jnp.int32)) -.. testcode:: NNX +.. testcode:: Flax NNX :hide: sample_x = jnp.ones((1, 784)) train_step(model, sample_x, jnp.ones((1,), dtype=jnp.int32)) -Flax also offers a convenient ``TrainState`` dataclass to bundle the model, -parameters and optimizer, to simplify training and updating the model. In Haiku -and Linen, we simply pass in the ``model.apply`` function, initialized parameters -and optimizer as arguments to the ``TrainState`` constructor. - -In NNX, we must first call ``nnx.split`` on the model to get the -separated ``GraphDef`` and ``State`` objects. We can pass in ``nnx.Param`` to filter -all trainable parameters into a single ``State``, and pass in ``...`` for the remaining -variables. We also need to subclass ``TrainState`` to add a field for the other variables. -We can then pass in ``GraphDef.apply`` as the apply function, ``State`` as the parameters -and other variables and an optimizer as arguments to the ``TrainState`` constructor. -One thing to note is that ``GraphDef.apply`` will take in ``State``'s as arguments and -return a callable function. This function can be called on the inputs to output the -model's logits, as well as updated ``GraphDef`` and ``State`` objects. This isn't needed -for our current example with dropout, but in the next section, you will see that using -these updated objects are relevant with layers like batch norm. Notice we also use -``@jax.jit`` since we aren't passing in NNX modules into ``train_step``. - -.. 
codediff:: - :title: Haiku, Linen, NNX - :sync: - - from flax.training import train_state - - - - - - - - state = train_state.TrainState.create( - apply_fn=model.apply, - params=params, - - tx=optax.adam(1e-3) - ) - - @jax.jit - def train_step(key, state, inputs, labels): - def loss_fn(params): - logits = state.apply_fn( - params, key, - inputs, training=True # <== inputs - - ) - return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean() - - grads = jax.grad(loss_fn)(state.params) - - - state = state.apply_gradients(grads=grads) - - return state - - --- - - from flax.training import train_state - - - - - - - - state = train_state.TrainState.create( - apply_fn=model.apply, - params=params, - - tx=optax.adam(1e-3) - ) - - @jax.jit - def train_step(key, state, inputs, labels): - def loss_fn(params): - logits = state.apply_fn( - {'params': params}, - inputs, training=True, # <== inputs - rngs={'dropout': key} - ) - return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean() - - grads = jax.grad(loss_fn)(state.params) - - - state = state.apply_gradients(grads=grads) - - return state - - --- - - from flax.training import train_state - - model.train() # set deterministic=False - graphdef, params, other_variables = nnx.split(model, nnx.Param, ...) - - class TrainState(train_state.TrainState): - other_variables: nnx.State - - state = TrainState.create( - apply_fn=graphdef.apply, - params=params, - other_variables=other_variables, - tx=optax.adam(1e-3) - ) - - @jax.jit - def train_step(state, inputs, labels): - def loss_fn(params, other_variables): - logits, (graphdef, new_state) = state.apply_fn( - params, - other_variables - - )(inputs) # <== inputs - return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean() - - grads = jax.grad(loss_fn)(state.params, state.other_variables) - state = state.apply_gradients(grads=grads) +Handling non-parameter states +============================= - return state +Haiku makes a distinction between trainable parameters and all other data ("states") that the model tracks. For example, the batch stats used in batch norm is considered a state. Models with states needs to be transformed with ``hk.transform_with_state`` so that their ``.init()`` returns both params and states. -.. testcode:: Haiku, Linen - :hide: - - train_step(random.key(0), state, sample_x, jnp.ones((1,), dtype=jnp.int32)) - -.. testcode:: NNX - :hide: +In Flax, there isn't such a strong distinction - they are all subclasses of ``nnx.Variable`` and seen by a module as its attributes. Parameters are instances of a subclass called ``nnx.Param``, and batch stats can be of another subclass called ``nnx.BatchStat``. You can use :func:`nnx.split` to quickly extract all data of a certain variable type. - train_step(state, sample_x, jnp.ones((1,), dtype=jnp.int32)) - -Handling State ------------------ - -Now let's see how mutable state is handled in all three frameworks. We will take -the same model as before, but now we will replace Dropout with BatchNorm. +Let's see an example of this by taking the ``Block`` definition above but replace dropout with ``BatchNorm``. .. codediff:: - :title: Haiku, Linen, NNX + :title: Haiku, Flax NNX :sync: class Block(hk.Module): @@ -521,22 +244,12 @@ the same model as before, but now we will replace Dropout with BatchNorm. 
x = jax.nn.relu(x) return x - --- - - class Block(nn.Module): - features: int - - - + def forward(x, training: bool): + return Model(256, 10)(x, training) + model = hk.transform_with_state(forward) - @nn.compact - def __call__(self, x, training: bool): - x = nn.Dense(self.features)(x) - x = nn.BatchNorm( - momentum=0.99 - )(x, use_running_average=not training) - x = jax.nn.relu(x) - return x + sample_x = jnp.ones((1, 784)) + params, batch_stats = model.init(jax.random.key(0), sample_x, training=True) --- @@ -555,314 +268,32 @@ the same model as before, but now we will replace Dropout with BatchNorm. x = jax.nn.relu(x) return x -Haiku requires an ``is_training`` argument and Linen requires a -``use_running_average`` argument to control whether or not to update the -running statistics. NNX also uses a ``use_running_average`` argument -but the value can be set later using ``.eval()`` and ``.train()`` methods -that will be shown in later code snippets. - -As before, you need to pass in the input shape to construct the Module -eagerly in NNX. - -.. codediff:: - :title: Haiku, Linen, NNX - :sync: - - def forward(x, training: bool): - return Model(256, 10)(x, training) - - model = hk.transform_with_state(forward) - - --- - - ... - - - model = Model(256, 10) - - --- - - ... - - - model = Model(784, 256, 10, rngs=nnx.Rngs(0)) - - -To initialize both the parameters and state in Haiku and Linen, you just -call the ``init`` method as before. However, in Haiku you now get ``batch_stats`` -as a second return value, and in Linen you get a new ``batch_stats`` collection -in the ``variables`` dictionary. -Note that since ``hk.BatchNorm`` only initializes batch statistics when -``is_training=True``, we must set ``training=True`` when initializing parameters -of a Haiku model with an ``hk.BatchNorm`` layer. In Linen, we can set -``training=False`` as usual. - -In NNX, the parameters and state are already initialized upon module -instantiation. The batch statistics are of class ``nnx.BatchStat`` which -subclasses the ``nnx.Variable`` class (not ``nnx.Param`` since they aren't -learnable parameters). Calling ``nnx.split`` with no additional filter arguments -will return a state containing all ``nnx.Variable``'s by default. - -.. codediff:: - :title: Haiku, Linen, NNX - :sync: - - sample_x = jnp.ones((1, 784)) - params, batch_stats = model.init( - random.key(0), - sample_x, training=True # <== inputs - ) - ... - - --- - - sample_x = jnp.ones((1, 784)) - variables = model.init( - random.key(0), - sample_x, training=False # <== inputs - ) - params, batch_stats = variables["params"], variables["batch_stats"] - - --- - - ... - - - - - graphdef, params, batch_stats = nnx.split(model, nnx.Param, nnx.BatchStat) - - -Now, training looks very similar in Haiku and Linen as you use the same -``apply`` method to run the forward pass. In Haiku, now pass the ``batch_stats`` -as the second argument to ``apply``, and get the newly updated ``batch_stats`` -as the second return value. In Linen, you instead add ``batch_stats`` as a new -key to the input dictionary, and get the ``updates`` variables dictionary as the -second return value. To update the batch statistics, we must pass in -``training=True`` to ``apply``. - -In NNX, the training code is identical to the earlier example as the -batch statistics (which are bounded to the stateful NNX module) are updated -statefully. 
To update batch statistics in NNX, we first call ``model.train()``, -which will set the batchnorm layer's ``use_running_average`` attribute to ``False`` -(conversely, calling ``model.eval()`` would set ``use_running_average`` to ``True``). -Since the stateful NNX module already contains the parameters and batch statistics, -we simply need to call the module to run the forward pass. We use ``nnx.split`` to -extract the learnable parameters (all learnable parameters subclass the NNX class -``nnx.Param``) and then apply the gradients and statefully update the model using -``nnx.update``. - -.. codediff:: - :title: Haiku, Linen, NNX - :sync: - - ... - - @jax.jit - def train_step(params, batch_stats, inputs, labels): - def loss_fn(params, batch_stats): - logits, batch_stats = model.apply( - params, batch_stats, - None, # <== rng - inputs, training=True # <== inputs - ) - loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean() - return loss, batch_stats - - grads, batch_stats = jax.grad(loss_fn, has_aux=True)(params, batch_stats) - - params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads) - - return params, batch_stats - --- - - ... - - @jax.jit - def train_step(params, batch_stats, inputs, labels): - def loss_fn(params, batch_stats): - logits, updates = model.apply( - {'params': params, 'batch_stats': batch_stats}, - inputs, training=True, # <== inputs - mutable='batch_stats', - ) - loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean() - return loss, updates["batch_stats"] - - grads, batch_stats = jax.grad(loss_fn, has_aux=True)(params, batch_stats) - - params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads) - - return params, batch_stats - - --- - - model.train() # set use_running_average=False - - @nnx.jit - def train_step(model, inputs, labels): - def loss_fn(model): - logits = model( - - inputs, # <== inputs - - ) # batch statistics are updated statefully in this step - loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean() - return loss - - grads = nnx.grad(loss_fn)(model) - _, params, _ = nnx.split(model, nnx.Param, ...) - params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads) - - nnx.update(model, params) - -.. testcode:: Haiku, Linen - :hide: - - train_step(params, batch_stats, sample_x, jnp.ones((1,), dtype=jnp.int32)) - -.. testcode:: NNX - :hide: - - train_step(model, sample_x, jnp.ones((1,), dtype=jnp.int32)) - -To use ``TrainState``, we subclass to add an additional field that can store -the batch statistics: - -.. codediff:: - :title: Haiku, Linen, NNX - :sync: - - ... - - - class TrainState(train_state.TrainState): - batch_stats: Any - - state = TrainState.create( - apply_fn=model.apply, - params=params, - batch_stats=batch_stats, - tx=optax.adam(1e-3) - ) - - @jax.jit - def train_step(state, inputs, labels): - def loss_fn(params, batch_stats): - logits, batch_stats = state.apply_fn( - params, batch_stats, - None, # <== rng - inputs, training=True # <== inputs - ) - loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean() - return loss, batch_stats - - grads, batch_stats = jax.grad( - loss_fn, has_aux=True - )(state.params, state.batch_stats) - state = state.apply_gradients(grads=grads) - state = state.replace(batch_stats=batch_stats) - - return state - - --- - - ... 
- class TrainState(train_state.TrainState): - batch_stats: Any + model = Block(4, 4, rngs=nnx.Rngs(0)) - state = TrainState.create( - apply_fn=model.apply, - params=params, - batch_stats=batch_stats, - tx=optax.adam(1e-3) - ) + model.linear.kernel # Param(value=...) + model.batchnorm.mean # BatchStat(value=...) - @jax.jit - def train_step(state, inputs, labels): - def loss_fn(params, batch_stats): - logits, updates = state.apply_fn( - {'params': params, 'batch_stats': batch_stats}, - inputs, training=True, # <== inputs - mutable='batch_stats' - ) - loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean() - return loss, updates['batch_stats'] - - grads, batch_stats = jax.grad( - loss_fn, has_aux=True - )(state.params, state.batch_stats) - state = state.apply_gradients(grads=grads) - state = state.replace(batch_stats=batch_stats) - - return state - - --- - - model.train() # set deterministic=False - graphdef, params, batch_stats = nnx.split(model, nnx.Param, nnx.BatchStat) - - class TrainState(train_state.TrainState): - batch_stats: Any - - state = TrainState.create( - apply_fn=graphdef.apply, - params=params, - batch_stats=batch_stats, - tx=optax.adam(1e-3) - ) - - @jax.jit - def train_step(state, inputs, labels): - def loss_fn(params, batch_stats): - logits, (graphdef, new_state) = state.apply_fn( - params, batch_stats - )(inputs) # <== inputs - - _, batch_stats = new_state.split(nnx.Param, nnx.BatchStat) - loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean() - return loss, batch_stats - - grads, batch_stats = jax.grad( - loss_fn, has_aux=True - )(state.params, state.batch_stats) - state = state.apply_gradients(grads=grads) - state = state.replace(batch_stats=batch_stats) - - return state - -.. testcode:: Haiku, Linen - :hide: - - train_step(state, sample_x, jnp.ones((1,), dtype=jnp.int32)) -.. testcode:: NNX - :hide: +Flax takes the difference of trainable params and other data into account. ``nnx.grad`` will only take gradients on the ``nnx.Param`` variables, thus skipping the ``batchnorm`` arrays automatically. Therefore, the training step will look the same for Flax NNX with this model. - train_step(state, sample_x, jnp.ones((1,), dtype=jnp.int32)) +Using multiple methods +====================== -Using Multiple Methods ------------------------ +In this section you will learn how to use multiple methods in Haiku and Flax. As an example, you will implement an auto-encoder model with three methods: ``encode``, ``decode``, and ``__call__``. -In this section we will take a look at how to use multiple methods in all three -frameworks. As an example, we will implement an auto-encoder model with three methods: -``encode``, ``decode``, and ``__call__``. +In Haiku, you need to use ``hk.multi_transform`` to explicitly define how the model shall be initialized and what methods (``encode`` and ``decode`` here) it can call. Note that you still need to define a ``__call__`` that activates both layers for the lazy initialization of all model parameters. -As before, we define the encoder and decoder layers without having to pass in the -input shape, since the module parameters will be initialized lazily using shape -inference in Haiku and Linen. In NNX, we must pass in the input shape -since the module parameters will be initialized eagerly without shape inference. +In Flax, it's simpler as you initialized parameters in ``__init__`` and the :class:`nnx.Module` methods ``encode`` and ``decode`` can be used directly. .. 
codediff:: - :title: Haiku, Linen, NNX + :title: Haiku, Flax NNX :sync: class AutoEncoder(hk.Module): - def __init__(self, embed_dim: int, output_dim: int, name=None): super().__init__(name=name) self.encoder = hk.Linear(embed_dim, name="encoder") @@ -879,34 +310,20 @@ since the module parameters will be initialized eagerly without shape inference. x = self.decode(x) return x - --- - - class AutoEncoder(nn.Module): - embed_dim: int - output_dim: int - - def setup(self): - self.encoder = nn.Dense(self.embed_dim) - self.decoder = nn.Dense(self.output_dim) - - def encode(self, x): - return self.encoder(x) - - def decode(self, x): - return self.decoder(x) + def forward(): + module = AutoEncoder(256, 784) + init = lambda x: module(x) + return init, (module.encode, module.decode) - def __call__(self, x): - x = self.encode(x) - x = self.decode(x) - return x + model = hk.multi_transform(forward) + params = model.init(jax.random.key(0), x=jnp.ones((1, 784))) --- class AutoEncoder(nnx.Module): - - def __init__(self, in_dim: int, embed_dim: int, output_dim: int, rngs): + self.encoder = nnx.Linear(in_dim, embed_dim, rngs=rngs) self.decoder = nnx.Linear(embed_dim, output_dim, rngs=rngs) @@ -916,72 +333,20 @@ since the module parameters will be initialized eagerly without shape inference. def decode(self, x): return self.decoder(x) - def __call__(self, x): - x = self.encode(x) - x = self.decode(x) - return x - -As before, we pass in the input shape when instantiating the NNX module. - -.. codediff:: - :title: Haiku, Linen, NNX - :sync: - - def forward(): - module = AutoEncoder(256, 784) - init = lambda x: module(x) - return init, (module.encode, module.decode) - - model = hk.multi_transform(forward) - - --- - - ... - model = AutoEncoder(256, 784) - --- - ... model = AutoEncoder(784, 256, 784, rngs=nnx.Rngs(0)) - - -For Haiku and Linen, ``init`` can be used to trigger the -``__call__`` method to initialize the parameters of our model, -which uses both the ``encode`` and ``decode`` method. This will -create all the necessary parameters for the model. In NNX, -the parameters are already initialized upon module instantiation. - -.. codediff:: - :title: Haiku, Linen, NNX - :sync: - - params = model.init( - random.key(0), - x=jnp.ones((1, 784)), - ) - - --- - - params = model.init( - random.key(0), - x=jnp.ones((1, 784)), - )['params'] - - --- - - # parameters were already initialized during model instantiation - - ... + The parameter structure is as follows: .. tab-set:: @@ -1005,27 +370,8 @@ The parameter structure is as follows: } } - .. tab-item:: Linen - :sync: Linen - - .. code-block:: python - - ... - - - FrozenDict({ - decoder: { - bias: (784,), - kernel: (256, 784), - }, - encoder: { - bias: (256,), - kernel: (784, 256), - }, - }) - - .. tab-item:: NNX - :sync: NNX + .. tab-item:: Flax NNX + :sync: Flax NNX .. code-block:: python @@ -1044,32 +390,17 @@ The parameter structure is as follows: }) -Finally, let's explore how we can employ the forward pass. In Haiku -and Linen, we use the ``apply`` function to invoke the ``encode`` -method. In NNX, we simply can simply call the ``encode`` method -directly. +To call those custom methods: + +* In Haiku, you need to decouple the `.apply` function to extract your method before calling it. +* In Flax, you can simply call the method directly. .. codediff:: - :title: Haiku, Linen, NNX + :title: Haiku, Flax NNX :sync: encode, decode = model.apply - z = encode( - params, - None, # <== rng - x=jnp.ones((1, 784)), - - ) - - --- - - ... 
- z = model.apply( - {"params": params}, - - x=jnp.ones((1, 784)), - method="encode", - ) + z = encode(params, None, x=jnp.ones((1, 784))) --- @@ -1078,26 +409,21 @@ directly. +Transformations +======================= - ... +Both Haiku and `Flax transformations `__ provide their own set of transforms that wrap `JAX transforms `__ in a way that they can be used with ``Module`` objects. +For more information on Flax transforms, check out the `Transforms guide `__. -Lifted Transforms ------------------ +Let's start with an example: -Both Flax and Haiku provide a set of transforms, which we will refer to as lifted transforms, -that wrap JAX transformations in such a way that they can be used with Modules and sometimes -provide additional functionality. In this section we will take a look at how to use the -lifted version of ``scan`` in both Flax and Haiku to implement a simple RNN layer. +* First, define an ``RNNCell`` ``Module`` that will contain the logic for a single step of the RNN. +* Define a ``initial_state`` method that will be used to initialize the state (a.k.a. ``carry``) of the RNN. Like with ``jax.lax.scan`` (`API doc `__), the ``RNNCell.__call__`` method will be a function that takes the carry and input, and returns the new carry and output. In this case, the carry and the output are the same. -To begin, we will first define a ``RNNCell`` module that will contain the logic for a single -step of the RNN. We will also define a ``initial_state`` method that will be used to initialize -the state (a.k.a. ``carry``) of the RNN. Like with ``jax.lax.scan``, the ``RNNCell.__call__`` -method will be a function that takes the carry and input, and returns the new -carry and output. In this case, the carry and the output are the same. .. codediff:: - :title: Haiku, Linen, NNX + :title: Haiku, Flax NNX :sync: class RNNCell(hk.Module): @@ -1116,22 +442,6 @@ carry and output. In this case, the carry and the output are the same. --- - class RNNCell(nn.Module): - hidden_size: int - - - @nn.compact - def __call__(self, carry, x): - x = jnp.concatenate([carry, x], axis=-1) - x = nn.Dense(self.hidden_size)(x) - x = jax.nn.relu(x) - return x, x - - def initial_state(self, batch_size: int): - return jnp.zeros((batch_size, self.hidden_size)) - - --- - class RNNCell(nnx.Module): def __init__(self, input_size, hidden_size, rngs): self.linear = nnx.Linear(hidden_size + input_size, hidden_size, rngs=rngs) @@ -1146,24 +456,12 @@ carry and output. In this case, the carry and the output are the same. def initial_state(self, batch_size: int): return jnp.zeros((batch_size, self.hidden_size)) -Next, we will define a ``RNN`` Module that will contain the logic for the entire RNN. -In Haiku, we will first initialze the ``RNNCell``, then use it to construct the ``carry``, -and finally use ``hk.scan`` to run the ``RNNCell`` over the input sequence. - -In Linen, we will use ``nn.scan`` to define a new temporary type that wraps -``RNNCell``. During this process we will also specify instruct ``nn.scan`` to broadcast -the ``params`` collection (all steps share the same parameters) and to not split the -``params`` rng stream (so all steps intialize with the same parameters), and finally -we will specify that we want scan to run over the second axis of the input and stack -the outputs along the second axis as well. We will then use this temporary type immediately -to create an instance of the lifted ``RNNCell`` and use it to create the ``carry`` and -the run the ``__call__`` method which will ``scan`` over the sequence. 
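Before wiring the cell into a full RNN, a single step of the Flax NNX ``RNNCell`` defined above can be exercised on its own to make the carry/output contract concrete. This is only an illustrative sketch; the batch size and feature sizes below are arbitrary.

.. code-block:: python

  cell = RNNCell(input_size=32, hidden_size=64, rngs=nnx.Rngs(0))

  carry = cell.initial_state(batch_size=8)   # zeros of shape (8, 64)
  x_t = jnp.ones((8, 32))                    # a single time step of input
  carry, y_t = cell(carry, x_t)              # returns (new_carry, output); here they are the same array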
+Next, we will define a ``RNN`` Module that will contain the logic for the entire RNN. In both cases, we use the library's ``scan`` call to run the ``RNNCell`` over the input sequence. -In NNX, we define a scan function ``scan_fn`` that will use the ``RNNCell`` defined -in ``__init__`` to scan over the sequence. +The only difference is that Flax ``nnx.scan`` allows you to specify which axis to repeat over in arguments ``in_axes`` and ``out_axes``, which will be forwarded to the underlying `jax.lax.scan`__, whereas in Haiku you need to transpose the input and output explicitly. .. codediff:: - :title: Haiku, Linen, NNX + :title: Haiku, Flax NNX :sync: class RNN(hk.Module): @@ -1183,23 +481,6 @@ in ``__init__`` to scan over the sequence. --- - class RNN(nn.Module): - hidden_size: int - - - @nn.compact - def __call__(self, x): - rnn = nn.scan( - RNNCell, variable_broadcast='params', - split_rngs={'params': False}, in_axes=1, out_axes=1 - )(self.hidden_size) - carry = rnn.initial_state(x.shape[0]) - carry, y = rnn(carry, x) - - return y - - --- - class RNN(nnx.Module): def __init__(self, input_size: int, hidden_size: int, rngs: nnx.Rngs): self.hidden_size = hidden_size @@ -1214,107 +495,22 @@ in ``__init__`` to scan over the sequence. return y -In general, the main difference between lifted transforms between Flax and Haiku is that -in Haiku the lifted transforms don't operate over the state, that is, Haiku will handle the -``params`` and ``state`` in such a way that it keeps the same shape inside and outside of the -transform. In Flax, the lifted transforms can operate over both variable collections and rng -streams, the user must define how different collections are treated by each transform -according to the transform's semantics. - -As before, the parameters must be initialized via ``.init()`` and passed into ``.apply()`` -to conduct a forward pass in Haiku and Linen. In NNX, the parameters are already -eagerly initialized and bound to the stateful module, and the module can be simply called -on the input to conduct a forward pass. - -.. codediff:: - :title: Haiku, Linen, NNX - :sync: - - x = jnp.ones((3, 12, 32)) - - def forward(x): - return RNN(64)(x) - - model = hk.without_apply_rng(hk.transform(forward)) - - params = model.init( - random.key(0), - x=jnp.ones((3, 12, 32)), - ) - - y = model.apply( - params, - x=jnp.ones((3, 12, 32)), - ) - - --- - - x = jnp.ones((3, 12, 32)) - - - - - model = RNN(64) - - params = model.init( - random.key(0), - x=jnp.ones((3, 12, 32)), - )['params'] - - y = model.apply( - {'params': params}, - x=jnp.ones((3, 12, 32)), - ) - - --- - - x = jnp.ones((3, 12, 32)) - - - - - model = RNN(x.shape[2], 64, rngs=nnx.Rngs(0)) - - +Scan over layers +======================= +Most Haiku transforms should look similar with Flax, since they all wraps their JAX counterparts, but the scan-over-layers use case is an exception. - - y = model(x) - - - ... - -The only notable change with respect to the examples in the previous sections is that -this time around we used ``hk.without_apply_rng`` in Haiku so we didn't have to -pass the ``rng`` argument as ``None`` to the ``apply`` method. - -Scan over layers ----------------- -One very important application of ``scan`` is apply a sequence of layers iteratively -over an input, passing the output of each layer as the input to the next layer. This -is very useful to reduce compilation time for big models. 
As an example we will create -a simple ``Block`` Module, and then use it inside an ``MLP`` Module that will apply -the ``Block`` Module ``num_layers`` times. +Scan-over-layers is a technique where you run an input through a sequence of N repeated layers, passing the output of each layer as the input to the next layer. This pattern can significantly reduce compilation time for large models. In the example below, you will repeat the ``Block`` ``Module`` 5 times in the top-level ``MLP`` ``Module``. In Haiku, we define the ``Block`` Module as usual, and then inside ``MLP`` we will use ``hk.experimental.layer_stack`` over a ``stack_block`` function to create a stack -of ``Block`` Modules. +of ``Block`` Modules. The same code will create 5 layers of parameters in initialization time, and run the input through them in call time. -In Linen, the definition of ``Block`` is a little different, -``__call__`` will accept and return a second dummy input/output that in both cases will -be ``None``. In ``MLP``, we will use ``nn.scan`` as in the previous example, but -by setting ``split_rngs={'params': True}`` and ``variable_axes={'params': 0}`` -we are telling ``nn.scan`` create different parameters for each step and slice the -``params`` collection along the first axis, effectively implementing a stack of -``Block`` Modules as in Haiku. - -In NNX, we use ``nnx.Scan.constructor()`` to define a stack of ``Block`` modules. -We can then simply call the stack of ``Block``'s, ``self.blocks``, on the input and -carry to get the forward pass output. +In Flax, model initialization and calling code are completely decoupled, so we use the :func:`nnx.vmap` transform to initialize the underlying ``Block`` parameters, and the :func:`nnx.scan` transform to run the model input through them. .. codediff:: - :title: Haiku, Linen, NNX + :title: Haiku, Flax NNX :sync: class Block(hk.Module): @@ -1336,7 +532,10 @@ carry to get the forward pass output. + + def __call__(self, x, training: bool): + @hk.experimental.layer_stack(self.num_layers) def stack_block(x): return Block(self.features)(x, training) @@ -1344,34 +543,12 @@ carry to get the forward pass output. stack = hk.experimental.layer_stack(self.num_layers) return stack_block(x) - --- - - class Block(nn.Module): - features: int - training: bool - - @nn.compact - def __call__(self, x, _): - x = nn.Dense(self.features)(x) - x = nn.Dropout(0.5)(x, deterministic=not self.training) - x = jax.nn.relu(x) - return x, None - - class MLP(nn.Module): - features: int - num_layers: int - - - - - @nn.compact - def __call__(self, x, training: bool): - ScanBlock = nn.scan( - Block, variable_axes={'params': 0}, split_rngs={'params': True}, - length=self.num_layers) + def forward(x, training: bool): + return MLP(64, num_layers=5)(x, training) + model = hk.transform(forward) - y, _ = ScanBlock(self.features, training)(x, None) - return y + sample_x = jnp.ones((1, 64)) + params = model.init(jax.random.key(0), sample_x, training=False) --- @@ -1380,81 +557,53 @@ carry to get the forward pass output. self.linear = nnx.Linear(input_dim, features, rngs=rngs) self.dropout = nnx.Dropout(0.5, rngs=rngs) - def __call__(self, x: jax.Array, _): + def __call__(self, x: jax.Array): # No need to require a second input! x = self.linear(x) x = self.dropout(x) x = jax.nn.relu(x) - return x, None + return x # No need to return a second output! 
class MLP(nnx.Module): - def __init__(self, input_dim, features, num_layers, rngs): - self.blocks = nnx.Scan.constructor( - Block, length=num_layers - )(input_dim, features, rngs=rngs) - + def __init__(self, features, num_layers, rngs): + @nnx.split_rngs(splits=num_layers) + @nnx.vmap(in_axes=(0,), out_axes=0) + def create_block(rngs: nnx.Rngs): + return Block(features, features, rngs=rngs) + self.blocks = create_block(rngs) + self.num_layers = num_layers def __call__(self, x): + @nnx.split_rngs(splits=self.num_layers) + @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry) + def forward(x, model): + x = model(x) + return x + return forward(x, self.blocks) - y, _ = self.blocks(x, None) - return y - -Notice how in Flax we pass ``None`` as the second argument to ``ScanBlock`` and ignore -its second output. These represent the inputs/outputs per-step but they are ``None`` -because in this case we don't have any. - -Initializing each model is the same as in previous examples. In this case, -we will be specifying that we want to use ``5`` layers each with ``64`` features. -As before, we also pass in the input shape for NNX. - -.. codediff:: - :title: Haiku, Linen, NNX - :sync: - - def forward(x, training: bool): - return MLP(64, num_layers=5)(x, training) - - model = hk.transform(forward) - - sample_x = jnp.ones((1, 64)) - params = model.init( - random.key(0), - sample_x, training=False # <== inputs - ) - - --- - - ... - - - model = MLP(64, num_layers=5) - - sample_x = jnp.ones((1, 64)) - params = model.init( - random.key(0), - sample_x, training=False # <== inputs - )['params'] + model = MLP(64, num_layers=5, rngs=nnx.Rngs(0)) - --- - ... +There are a few other details to explain in the Flax example above: +* **The `@nnx.split_rngs` decorator:** Flax transforms, like their JAX counterparts, are completely agnostic of the PRNG state and rely on input for PRNG keys. The ``nnx.split_rngs`` decorator allows you to split the ``nnx.Rngs`` before passing them to the decorated function and 'lower' them afterwards, so they can be used outside. - model = MLP(64, 64, num_layers=5, rngs=nnx.Rngs(0)) + * Here, you split the PRNG keys because ``jax.vmap`` and ``jax.lax.scan`` require a list of PRNG keys if each of its internal operations needs its own key. So for the 5 layers inside the ``MLP``, you split and provide 5 different PRNG keys from its arguments before going down to the JAX transform. + * Note that actually ``create_block()`` knows it needs to create 5 layers *precisely because* it sees 5 PRNG keys, because ``in_axes=(0,)`` indicates that ``vmap`` will look into the first argument's first dimension to know the size it will map over. + * Same goes for ``forward()``, which looks at the variables inside the first argument (aka. ``model``) to find out how many times it needs to scan. ``nnx.split_rngs`` here actually splits the PRNG state inside the ``model``. (If the ``Block`` ``Module`` doesn't have dropout, you don't need the :meth:`nnx.split_rngs` line as it would not consume any PRNG key anyway.) +* **Why the Block Module in Flax doesn't need to take and return that extra dummy value:** ``jax.lax.scan`` `(API doc `__ requires its function to return two inputs - the carry and the stacked output. In this case, we didn't use the latter. Flax simplifies this, so that you can now choose to ignore the second output if you set ``out_axes=nnx.Carry`` instead of the default ``(nnx.Carry, 0)``. + * This is one of the rare cases where Flax NNX transforms diverge from the `JAX transforms `__ APIs. - ... 
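As a quick illustration of what ``nnx.vmap`` produced here (the attribute path follows the ``MLP`` definition above, and the shapes match the parameter pytree inspected just below), every ``Block`` parameter gains a leading layer axis of size 5, and the model is called like any other ``nnx.Module``:

.. code-block:: python

  model = MLP(64, num_layers=5, rngs=nnx.Rngs(0))

  # The stacked Block parameters carry a leading "layer" dimension.
  model.blocks.linear.kernel.value.shape  # (5, 64, 64)
  model.blocks.linear.bias.value.shape    # (5, 64)

  # Calling the model scans the input through the 5 Blocks in sequence.
  y = model(jnp.ones((1, 64)))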
+There are more lines of code in the Flax example above, but they express what happens at each time more precisely. Since Flax transforms become way closer to the JAX transform APIs, it is recommended to have a good understanding of the underlying `JAX transforms `__ before using their `Flax NNX equivalents `__ -When using scan over layers the one thing you should notice is that all layers -are fused into a single layer whose parameters have an extra "layer" dimension on -the first axis. In this case, the shape of all parameters will start with ``(5, ...)`` -as we are using ``5`` layers. +Now inspect the variable pytree on both sides: .. tab-set:: @@ -1477,27 +626,8 @@ as we are using ``5`` layers. ... - .. tab-item:: Linen - :sync: Linen - - .. code-block:: python - - ... - - - FrozenDict({ - ScanBlock_0: { - Dense_0: { - bias: (5, 64), - kernel: (5, 64, 64), - }, - }, - }) - - ... - - .. tab-item:: NNX - :sync: NNX + .. tab-item:: Flax NNX + :sync: Flax NNX .. code-block:: python @@ -1506,30 +636,27 @@ as we are using ``5`` layers. params State({ 'blocks': { - 'scan_module': { - 'linear': { - 'bias': VariableState(type=Param, value=(5, 64)), - 'kernel': VariableState(type=Param, value=(5, 64, 64)) - } + 'linear': { + 'bias': VariableState(type=Param, value=(5, 64)), + 'kernel': VariableState(type=Param, value=(5, 64, 64)) } } }) + Top-level Haiku functions vs top-level Flax modules ------------------------------------ +======================= In Haiku, it is possible to write the entire model as a single function by using the raw ``hk.{get,set}_{parameter,state}`` to define/access model parameters and -states. It very common to write the top-level "Module" as a function instead. +states. It is very common to write the top-level "Module" as a function instead. The Flax team recommends a more Module-centric approach that uses ``__call__`` to -define the forward function. In Linen, the corresponding accessor will be -``Module.param`` and ``Module.variable`` (go to `Handling State <#handling-state>`__ -for an explanation on collections). In NNX, the parameters and variables can +define the forward function. In Flax modules, the parameters and variables can be set and accessed as normal using regular Python class semantics. .. codediff:: - :title: Haiku, Linen, NNX + :title: Haiku, Flax NNX :sync: ... @@ -1550,29 +677,7 @@ be set and accessed as normal using regular Python class semantics. model = hk.transform_with_state(forward) - params, state = model.init(random.key(0), jnp.ones((1, 64))) - - --- - - ... - - - class FooModule(nn.Module): - @nn.compact - def __call__(self, x): - counter = self.variable('counter', 'count', lambda: jnp.ones((), jnp.int32)) - multiplier = self.param( - 'multiplier', nn.initializers.ones_init(), [1,], x.dtype - ) - - output = x + multiplier * counter.value - if not self.is_initializing(): # otherwise model.init() also increases it - counter.value += 1 - return output - - model = FooModule() - variables = model.init(random.key(0), jnp.ones((1, 64))) - params, counter = variables['params'], variables['counter'] + params, state = model.init(jax.random.key(0), jnp.ones((1, 64))) --- @@ -1594,4 +699,8 @@ be set and accessed as normal using regular Python class semantics. 
model = FooModule(rngs=nnx.Rngs(0)) - _, params, counter = nnx.split(model, nnx.Param, Counter) \ No newline at end of file + _, params, counter = nnx.split(model, nnx.Param, Counter) + + + + diff --git a/docs_nnx/guides/images/performance-graph.png b/docs_nnx/guides/images/performance-graph.png new file mode 100644 index 00000000..34fb134e Binary files /dev/null and b/docs_nnx/guides/images/performance-graph.png differ diff --git a/docs_nnx/guides/index.rst b/docs_nnx/guides/index.rst index 247db3c6..58c9e80f 100644 --- a/docs_nnx/guides/index.rst +++ b/docs_nnx/guides/index.rst @@ -8,8 +8,11 @@ Guides flax_gspmd filters_guide randomness + performance linen_to_nnx bridge_guide surgery + checkpointing jax_and_nnx_transforms haiku_to_flax + gemma diff --git a/docs_nnx/guides/linen_to_nnx.rst b/docs_nnx/guides/linen_to_nnx.rst index d0c20fd0..2c895021 100644 --- a/docs_nnx/guides/linen_to_nnx.rst +++ b/docs_nnx/guides/linen_to_nnx.rst @@ -1,13 +1,11 @@ -Evolution from Linen to NNX -########## +Evolution from Flax Linen to NNX +################################ -This guide will walk you through the differences between Flax Linen and Flax NNX -models, and side-by-side comparisions to help you migrate your code from the Linen API to NNX. +This guide demonstrates the differences between Flax Linen and Flax NNX models, providing side-by-side example code to help you migrate to the Flax NNX API from Flax Linen. -Before this guide, it's highly recommended to read through `The Basics of Flax NNX `__ to learn about the core concepts and code examples of Flax NNX. - -This guide mainly covers converting arbitratry Linen code to NNX. If you want to play it safe and convert your codebase iteratively, check out the guide that allows you to `use NNX and Linen code together `__ +This document mainly teaches how to convert arbitrary Flax Linen code to Flax NNX. If you want to play it “safe” and convert your codebase iteratively, check out the `Use Flax NNX and Linen together via nnx.bridge `__ guide. +To get the most out of this guide, it is highly recommended to get go through `Flax NNX basics `__ document, which covers the :class:`nnx.Module` system, `Flax transformations `__, and the `Functional API `__ with examples. .. testsetup:: Linen, NNX @@ -17,18 +15,18 @@ This guide mainly covers converting arbitratry Linen code to NNX. If you want to import flax.linen as nn from typing import Any -Basic Module Definition -========== +Basic ``Module`` definition +=========================== + +Both Flax Linen and Flax NNX use the ``Module`` class as the default unit to express a neural network library layer. In the example below, you first create a ``Block`` (by subclassing ``Module``) composed of one linear layer with dropout and a ReLU activation function; then you use it as a sub-``Module`` when creating a ``Model`` (also by subclassing ``Module``), which is made up of ``Block`` and a linear layer. -Both Linen and NNX uses the ``Module`` as the default way to express a neural -library layer. There are two fundamental difference between Linen and NNX -modules: +There are two fundamental differences between Flax Linen and Flax NNX ``Module`` objects: -* **Stateless vs. stateful**: Linen module instances are stateless: variables are returned from a purely functional ``.init()`` call and managed separately. NNX modules, however, owns its variables as attributes of this Python object. +* **Stateless vs. 
stateful**: A ``flax.linen.Module`` (``nn.Module``) instance is stateless - the variables are returned from a purely functional ``Module.init()`` call and managed separately. A :class:`flax.nnx.Module`, however, owns its variables as attributes of this Python object. -* **Lazy vs. eager**: Linen modules only allocate space to create variables when they actually see their input. Whereas NNX module instances create their variables the moment they are instantiated, without seeing a sample input. +* **Lazy vs. eager**: A ``flax.linen.Module`` only allocates space to create variables when they actually see their input (lazy). A :class:`flax.nnx.Module` instance creates variables the moment they are instantiated before seeing a sample input (eager). - * Linen can use the ``@nn.compact`` decorator to define the model in a single method and use shape inference from the input sample, whereas NNX modules generally requests additional shape information to create all parameters during ``__init__`` and separately define the computation in ``__call__``. +* Flax Linen can use the ``@nn.compact`` decorator to define the model in a single method, and use shape inference from the input sample. A Flax NNX ``Module`` generally requests additional shape information to create all parameters during ``__init__`` , and separately defines the computation in the ``__call__`` method. .. codediff:: :title: Linen, NNX @@ -83,14 +81,16 @@ modules: return x -Variable Creation -========== +Variable creation +================= + +Next, let’s discuss instantiating the model and initializing its parameters: -To generate the model parameters for a Linen model, you call the ``init`` method with a ``jax.random.key`` plus some sample inputs that the model shall take. The result is a nested dictionary of JAX arrays to be carried around and maintained separately. +* To generate model parameters for a Flax Linen model, you call the ``flax.linen.Module.init`` (``nn.Module.init``) method with a ``jax.random.key`` (`doc `__) plus some sample inputs that the model shall take. This results in a nested dictionary of `JAX Arrays `__ (``jax.Array`` data types) to be carried around and maintained separately. -In NNX, the model parameters are automatically initialized when the user instantiates the model, and the variables are stored inside the module (or its submodule) as attributes. You still need to give it an RNG key, but the key will be wrapped inside a ``nnx.Rngs`` class and will be stored inside, generating more RNG keys when needed. +* In Flax NNX, the model parameters are automatically initialized when you instantiate the model, and the variables (:class:`nnx.Variable` objects) are stored inside the :class:`nnx.Module` (or its sub-``Module``) as attributes. You still need to provide it with a `pseudorandom number generator (PRNG) `__ key, but that key will be wrapped inside an :class:`nnx.Rngs` class and stored inside, generating more PRNG keys when needed. -If you want to access NNX model parameters in the stateless, dictionary-like fashion for checkpoint saving or model surgery, check out the `NNX split/merge API `__. +If you want to access Flax NNX model parameters in the stateless, dictionary-like fashion for checkpoint saving or model surgery, check out the `Flax NNX split/merge API `__ (:func:`nnx.split` / :func:`nnx.merge`). .. 
codediff:: :title: Linen, NNX @@ -109,28 +109,41 @@ If you want to access NNX model parameters in the stateless, dictionary-like fas model = Model(784, 256, 10, rngs=nnx.Rngs(0)) - # parameters were already initialized during model instantiation + # Parameters were already initialized during model instantiation. assert model.linear.bias.value.shape == (10,) assert model.block.linear.kernel.value.shape == (784, 256) -Training Step and Compilation -========== +Training step and compilation +============================= + +Now, let’s proceed to writing a training step and compiling it using `JAX just-in-time compilation `__. Below are certain differences between Flax Linen and Flax NNX approaches. + +Compiling the training step: + +* Flax Linen uses ``@jax.jit`` - a `JAX transform `__ - to compile the training step. +* Flax NNX uses :meth:`@nnx.jit` - a `Flax NNX transform `__ (one of several transform APIs that behave similarly to JAX transforms, but also `work well with Flax NNX objects `__). So, while ``jax.jit`` only accepts functions pure stateless arguments, ``nnx.jit`` allows the arguments to be stateful NNX Modules. This greatly reduced the number of lines needed for a train step. -Now we write a training step and compile it using JAX just-in-time compilation. Note a few differences here: +Taking gradients: -* Linen uses ``@jax.jit`` to compile the training step, whereas NNX uses ``@nnx.jit``. ``jax.jit`` only accepts pure stateless arguments, but ``nnx.jit`` allows the arguments to be stateful NNX modules. This greatly reduced the number of lines needed for a train step. +* Similarly, Flax Linen uses ``jax.grad`` (a JAX transform for `automatic differentiation `__) to return a raw dictionary of gradients. +* Flax NNX uses :meth:`nnx.grad` (a Flax NNX transform) to return the gradients of NNX Modules as :class:`nnx.State` dictionaries. If you want to use regular ``jax.grad`` with Flax NNX you need to use the `Flax NNX split/merge API `__. -* Similarly, Linen uses ``jax.grad()`` to return a raw dictionary of gradients, wheras NNX can use ``nnx.grad`` to return the gradients of Modules as NNX ``State`` dictionaries. To use regular ``jax.grad`` with NNX you need to use the `NNX split/merge API `__. +Optimizers: - * If you are already using Optax optimizers like ``optax.adamw`` (instead of the raw ``jax.tree.map`` computation shown here), check out `nnx.Optimizer example `__ for a much more concise way of training and updating your model. +* If you are already using `Optax `__ optimizers like ``optax.adamw`` (instead of the raw ``jax.tree.map`` computation shown here) with Flax Linen, check out the :class:`nnx.Optimizer` example in the `Flax NNX basics `__ guide for a much more concise way of training and updating your model. -* The Linen train step needs to return a tree of parameters, as the input of the next step. On the other hand, NNX's step doesn't need to return anything, because the ``model`` was already in-place-updated within ``nnx.jit``. +Model updates during each training step: -* NNX modules are stateful and automatically tracks a few things within, such as RNG keys and BatchNorm stats. That's why you don't need to explicitly pass an RNG key in on every step. Note that you can use `nnx.reseed `__ to reset its underlying RNG state. +* The Flax Linen training step needs to return a `pytree `__ of parameters as the input of the next step. +* The Flax NNX training step doesn't need to return anything, because the ``model`` was already updated in-place within :meth:`nnx.jit`. 
+* In addition, :class:`nnx.Module` objects are stateful, and ``Module`` automatically tracks several things within it, such as PRNG keys and ``BatchNorm`` stats. That is why you don't need to explicitly pass an PRNG key in on every step. Also note that you can use :meth:`nnx.reseed` to reset its underlying PRNG state. -* In Linen, you need to explicitly define and pass in an argument ``training`` to control the behavior of ``nn.Dropout`` (namely, its ``deterministic`` flag, which means random dropout only happens if ``training=True``). In NNX, you can call ``model.train()`` to automatically switch ``nnx.Dropout`` to training mode. Conversely, call ``model.eval()`` to turn off training mode. You can learn more about what this API does at its `API reference `__. +Dropout behavior: + +* In Flax Linen, you need to explicitly define and pass in the ``training`` argument to control the behavior of ``flax.linen.Dropout`` (``nn.Dropout``), namely, its ``deterministic`` flag, which means random dropout only happens if ``training=True``. +* In Flax NNX, you can call ``model.train()`` (:meth:`flax.nnx.Module.train`) to automatically switch :class:`nnx.Dropout` to the training mode. Conversely, you can call ``model.eval()`` (:meth:`flax.nnx.Module.eval`) to turn off the training mode. You can learn more about what ``nnx.Module.train`` does in its `API reference `__. .. codediff:: @@ -185,22 +198,19 @@ Now we write a training step and compile it using JAX just-in-time compilation. train_step(model, sample_x, jnp.ones((1,), dtype=jnp.int32)) -Collections and Variable Types -========== - -One key difference between Linen and NNX APIs is how we group variables into categories. In Linen, we use different collections; in NNX, since all variables shall be top-level Python attributes, you use different variable types. - -You can freely create your own variable types as subclasses of ``nnx.Variable``. +Collections and variable types +============================== -For all the built-in Flax Linen layers and collections, NNX already created the corresponding layers and variable type. For example: +One key difference between Flax Linen and NNX APIs is how they group variables into categories. Flax Linen uses different collections, while Flax NNX, since all variables shall be top-level Python attributes, you use different variable types. - * ``nn.Dense`` creates ``params`` -> ``nnx.Linear`` creates ``nnx.Param``. +In Flax NNX, you can freely create your own variable types as subclasses of ``nnx.Variable``. - * ``nn.BatchNorm`` creates ``batch_stats`` -> ``nnx.BatchNorm`` creates ``nnx.BatchStats``. +For all the built-in Flax Linen layers and collections, Flax NNX already creates the corresponding layers and variable types. For example: - * ``linen.Module.sow()`` creates ``intermediates`` -> ``nnx.Module.sow()`` creates ``nnx.Intermediates``. - - * You can also simply get the intermediates by assigning it to a module attribute, like ``self.sowed = nnx.Intermediates(x)``. This will be similar to Linen's ``self.variable('intermediates' 'sowed', lambda: x)``. +* ``flax.linen.Dense`` (``nn.Dense``) creates ``params`` -> :class:`nnx.Linear` creates :class:nnx.Param`. +* ``flax.linen.BatchNorm`` (``nn.BatchNorm``) creates ``batch_stats`` -> :class:`nnx.BatchNorm` creates :class:`nnx.BatchStats`. +* ``flax.linen.Module.sow()`` creates ``intermediates`` -> :class:`nnx.Module.sow()` creates :class:`nnx.Intermediaries`. 
+* In Flax NNX, you can also simply obtain the intermediates by assigning it to an ``nnx.Module`` attribute - for example, ``self.sowed = nnx.Intermediates(x)``. This will be similar to Flax Linen's ``self.variable('intermediates' 'sowed', lambda: x)``. .. codediff:: :title: Linen, NNX @@ -257,7 +267,10 @@ For all the built-in Flax Linen layers and collections, NNX already created the model.batchnorm.mean # BatchStat(value=...) model.count # Counter(value=...) -If you want to extract certain arrays from the tree of variables, you can access the specific dictionary path in Linen, or use ``nnx.split`` to distinguish the types apart in NNX. The code below is an easier example, and check out `Filter API Guide `__ for more sophisticated filtering expressions. +If you want to extract certain arrays from the pytree of variables: + +* In Flax Linen, you can access the specific dictionary path. +* In Flax NNX, you can use :func:`nnx.split` to distinguish the types apart in Flax NNX. The code below is a simple example that splits up the variables by their types - check out the `Flax NNX Filters `__ guide for more sophisticated filtering expressions. .. codediff:: :title: Linen, NNX @@ -287,17 +300,15 @@ If you want to extract certain arrays from the tree of variables, you can access -Using Multiple Methods -========== +Using multiple methods +====================== -In this section we will take a look at how to use multiple methods in both -frameworks. As an example, we will implement an auto-encoder model with three methods: -``encode``, ``decode``, and ``__call__``. +In this section you will learn how to use multiple methods in both Flax Linen and Flax NNX. As an example, you will implement an auto-encoder model with three methods: ``encode``, ``decode``, and ``__call__``. -As before, we define the encoder and decoder layers without having to pass in the -input shape, since the module parameters will be initialized lazily using shape -inference in Linen. In NNX, we must pass in the input shape -since the module parameters will be initialized eagerly without shape inference. +Defining the encoder and decoder layers: + +* In Flax Linen, as before, define the layers without having to pass in the input shape, since the ``flax.linen.Module`` parameters will be initialized lazily using shape inference. +* In Flax NNX, you must pass in the input shape since the :class:`nnx.Module` parameters will be initialized eagerly without shape inference. .. codediff:: :title: Linen, NNX @@ -389,7 +400,10 @@ The variable structure is as follows: } }) -To call methods other than ``__call__``, in Linen you still need to use the ``apply`` API, wheras in NNX you can simply call the method directly. +To call methods other than ``__call__``: + +* In Flax Linen, you still need to use the ``apply`` API. +* In Flax NNX, you can simply call the method directly. .. codediff:: :title: Linen, NNX @@ -403,18 +417,17 @@ To call methods other than ``__call__``, in Linen you still need to use the ``ap -Lifted Transforms -========== +Transformations +=============== + +Both Flax Linen and `Flax NNX transformations `__ provide their own set of transforms that wrap `JAX transforms `__ in a way that they can be used with ``Module`` objects. -Flax APIs provide a set of transforms, which we will refer to as lifted transforms, that wrap JAX transforms in such a way that they can be used with Modules. +Most of the transforms in Flax Linen, such as ``grad`` or ``jit``, don't change much in Flax NNX. 
But, for example, if you try to do ``scan`` over layers, as described in the next section, the code differs by a lot. -Most of the transforms in Linen doesn't change much in NNX. See the next section (Scan over Layers) for a case in which the code differs a lot more. +Let’s start with an example: -To begin, we will first define a ``RNNCell`` module that will contain the logic for a single -step of the RNN. We will also define a ``initial_state`` method that will be used to initialize -the state (a.k.a. ``carry``) of the RNN. Like with ``jax.lax.scan``, the ``RNNCell.__call__`` -method will be a function that takes the carry and input, and returns the new -carry and output. In this case, the carry and the output are the same. +* First, define an ``RNNCell`` ``Module`` that will contain the logic for a single step of the RNN. +* Define a ``initial_state`` method that will be used to initialize the state (a.k.a. ``carry``) of the RNN. Like with ``jax.lax.scan`` (`API doc `__), the ``RNNCell.__call__`` method will be a function that takes the carry and input, and returns the new carry and output. In this case, the carry and the output are the same. .. codediff:: :title: Linen, NNX @@ -450,21 +463,16 @@ carry and output. In this case, the carry and the output are the same. def initial_state(self, batch_size: int): return jnp.zeros((batch_size, self.hidden_size)) -Next, we will define a ``RNN`` Module that will contain the logic for the entire RNN. +Next, define an ``RNN`` ``Module`` that will contain the logic for the entire RNN. -In Linen, we will use ``nn.scan`` to define a new temporary type that wraps -``RNNCell``. During this process we will also specify instruct ``nn.scan`` to broadcast -the ``params`` collection (all steps share the same parameters) and to not split the -``params`` rng stream (so all steps intialize with the same parameters), and finally -we will specify that we want scan to run over the second axis of the input and stack -the outputs along the second axis as well. We will then use this temporary type immediately -to create an instance of the lifted ``RNNCell`` and use it to create the ``carry`` and -the run the ``__call__`` method which will ``scan`` over the sequence. +In Flax Linen: -In NNX, we define a scan function ``scan_fn`` that will use the ``RNNCell`` defined -in ``__init__`` to scan over the sequence, and explicitly set ``in_axes=(nnx.Carry, None, 1)``, -``Carry`` means that the ``carry`` argument will be the carry, ``None`` means that ``cell`` will -be broadcasted to all steps, and ``1`` means ``x`` will be scanned across axis 1. +* You will use ``flax.linen.scan`` (``nn.scan``) to define a new temporary type that wraps ``RNNCell``. During this process you will also: 1) instruct ``nn.scan`` to broadcast the ``params`` collection (all steps share the same parameters) and to not split the ``params`` PRNG stream (so that all steps initialize with the same parameters); and, finally, 2) specify that you want scan to run over the second axis of the input and stack outputs along the second axis as well. +* You will then use this temporary type immediately to create an instance of the “lifted” ``RNNCell`` and use it to create the ``carry``, and the run the ``__call__`` method, which will ``scan`` over the sequence. + +In Flax NNX: + +* You will create a ``scan`` function (``scan_fn``) that will use the ``RNNCell`` defined in ``__init__`` to scan over the sequence, and explicitly set ``in_axes=(nnx.Carry, None, 1)``. 
``nnx.Carry`` means that the ``carry`` argument will be the carry, ``None`` means that ``cell`` will be broadcasted to all steps, and ``1`` means ``x`` will be scanned across axis `1`. .. codediff:: :title: Linen, NNX @@ -512,20 +520,18 @@ be broadcasted to all steps, and ``1`` means ``x`` will be scanned across axis 1 -Scan over Layers -========== - -In general, lifted transforms of Linen and NNX should look the same. However, NNX lifted transforms are designed to be closer to their lower level JAX counterparts, and thus we throw away some assumptions in certain Linen lifted transforms. This scan-over-layers use case will be a good example to showcase it. +Scan over layers +================ -Scan-over-layers is a technique in which, we want run an input through a sequence of N repeated layers, passing the output of each layer as the input to the next layer. This pattern can significantly reduce compilation time for big models. In this example, we will repeat the module ``Block`` for 5 times in a top-level module ``MLP``. +In general, transforms of Flax Linen and Flax NNX should look the same. However, `Flax NNX transforms `__ are designed to be closer to their lower-level `JAX counterparts `__, and thus we throw away some assumptions in certain Linen lifted transforms. This scan-over-layers use case will be a good example to showcase it. -In Linen, we apply a ``nn.scan`` upon the module ``Block`` to create a larger module ``ScanBlock`` that contains 5 ``Block``. It will automatically create a large parameter of shape ``(5, 64, 64)`` at initialization time, and at call time iterate over every ``(64, 64)`` slice for a total of 5 times, like a ``jax.lax.scan`` would. +Scan-over-layers is a technique where you run an input through a sequence of N repeated layers, passing the output of each layer as the input to the next layer. This pattern can significantly reduce compilation time for large models. In the example below, you will repeat the ``Block`` ``Module`` 5 times in the top-level ``MLP`` ``Module``. -But if you think closely, there actually isn't any need for ``jax.lax.scan`` operation at initialization time. What happened there is more like a ``jax.vmap`` operation - you are given a ``Block`` that accepts ``(in_dim, out_dim)``, and you "vmap" it over ``num_layers`` of times to create a larger array. +* In Flax Linen, you apply the ``flax.linen.scan`` (``nn.scan``) transforms upon the ``Block`` ``nn.Module`` to create a larger ``ScanBlock`` ``nn.Module`` that contains 5 ``Block`` ``nn.Module`` objects. It will automatically create a large parameter of shape ``(5, 64, 64)`` at initialization time, and iterate over at call time every ``(64, 64)`` slice for a total of 5 times, like a ``jax.lax.scan`` (`API doc `__) would. +* Up close, in the logic of this model there actually is no need for the ``jax.lax.scan`` operation at initialization time. What happens there is more like a ``jax.vmap`` operation - you are given a ``Block`` sub-``Module`` that accepts ``(in_dim, out_dim)``, and you "vmap" it over ``num_layers`` of times to create a larger array. +* In Flax NNX, you take advantage of the fact that model initialization and running code are completely decoupled, and instead use the :func:`nnx.vmap` transform to initialize the underlying ``Block`` parameters, and the :func:`nnx.scan` transform to run the model input through them. 
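Viewed outside of a ``Module``, the two-step Flax NNX pattern described above (``nnx.vmap`` to create a stack of layers, ``nnx.scan`` to run data through it) can be sketched roughly as follows. This is only a sketch: it assumes a ``Block`` layer that takes ``(input_dim, features, rngs)`` like the one in the code comparison below, and the layer count and sizes are illustrative.

.. code-block:: python

  # Create 5 stacked Blocks: nnx.vmap maps Block construction over 5 PRNG keys.
  @nnx.split_rngs(splits=5)
  @nnx.vmap(in_axes=(0,), out_axes=0)
  def create_block(rngs: nnx.Rngs):
    return Block(64, 64, rngs=rngs)

  blocks = create_block(nnx.Rngs(0))   # parameters are stacked along a leading axis of size 5

  # Run an input through the stack: nnx.scan feeds each Block's output to the next.
  @nnx.split_rngs(splits=5)
  @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
  def forward(x, block):
    return block(x)

  y = forward(jnp.ones((1, 64)), blocks)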
-In NNX we take advantage of the fact that model initialization and running code are completely decoupled, and instead use ``nnx.vmap`` to initialize the underlying blocks, and ``nnx.scan`` to run the model input through them. - -For more information on NNX transforms, check out the `Transforms Guide `__. +For more information on Flax NNX transforms, check out the `Transforms guide `__. .. codediff:: :title: Linen, NNX @@ -595,23 +601,23 @@ For more information on NNX transforms, check out the `Transforms Guide ` line as it would not consume any PRNG key anyway.) -* **Why the `Block` in NNX doesn't need to take and return that extra dummy value?** This is a requirement from `jax.lax.scan `__. NNX simplifies this so that now you can choose to ignore the second input/output if you set ``out_axes=nnx.Carry`` instead of the default ``(nnx.Carry, 0)``. +* **Why the Block Module in Flax NNX doesn't need to take and return that extra dummy value:** This is a requirement from ``jax.lax.scan`` `(API doc `__. Flax NNX simplifies this, so that you can now choose to ignore the second output if you set ``out_axes=nnx.Carry`` instead of the default ``(nnx.Carry, 0)``. - * This is one of the rare cases in which NNX transforms diverge from JAX transforms API. + * This is one of the rare cases where Flax NNX transforms diverge from the `JAX transforms `__ APIs. -This is more lines of code, but it expresses what happened at each time more precisely. Since NNX lifted transforms become way closer to JAX APIs, it's recommended to have a good understanding of the underlying JAX transform before using their NNX versions. +There are more lines of code in the Flax NNX example above, but they express what happens at each time more precisely. Since Flax NNX transforms become way closer to the JAX transform APIs, it is recommended to have a good understanding of the underlying `JAX transforms `__ before using their `Flax NNX equivalents `__ -Now take a look at the variable tree on both sides: +Now inspect the variable pytree on both sides: .. tab-set:: @@ -648,22 +654,32 @@ Now take a look at the variable tree on both sides: }) -Using ``TrainState`` in NNX -========== +Using ``TrainState`` in Flax NNX +================================ + +Flax Linen has a convenient ``TrainState`` data class to bundle the model, +parameters and optimizer. In Flax NNX, this is not really necessary. In this section, +you will learn how to construct your Flax NNX code around ``TrainState`` for any backward +compatibility needs. + +In Flax NNX: -Flax offered a convenient ``TrainState`` dataclass to bundle the model, -parameters and optimizer. This is not really necessary in NNX era, but this section we would show how to construct your NNX code around it, for any backward compatibility needs. +* You must first call :meth:`nnx.split` on the model to get the + separate :class:`nnx.GraphDef` and :class:`nnx.State` + objects. +* You can pass in :class:`nnx.Param` to filter all trainable parameters + into a single :class:`nnx.State`, and pass in ``...`` for the remaining + variables. +* You also need to subclass ``TrainState`` to add a field for the other variables. +* Then, you can pass in :meth:`nnx.GraphDef.apply` as the ``apply`` function, + :class:`nnx.State` as the parameters and other variables, and an optimizer as arguments to the + ``TrainState`` constructor. -In NNX, we must first call ``nnx.split`` on the model to get the -separated ``GraphDef`` and ``State`` objects. 
We can pass in ``nnx.Param`` to filter -all trainable parameters into a single ``State``, and pass in ``...`` for the remaining -variables. We also need to subclass ``TrainState`` to add a field for the other variables. -We can then pass in ``GraphDef.apply`` as the apply function, ``State`` as the parameters -and other variables and an optimizer as arguments to the ``TrainState`` constructor. -One thing to note is that ``GraphDef.apply`` will take in ``State``'s as arguments and +Note that :class:`nnx.GraphDef.apply` will take in :class:`nnx.State` objects as arguments and return a callable function. This function can be called on the inputs to output the -model's logits, as well as updated ``GraphDef`` and ``State`` objects. Notice we also use -``@jax.jit`` since we aren't passing in NNX modules into ``train_step``. +model's logits, as well as the updated :class:`nnx.GraphDef` and :class:`nnx.State` objects. +Notice below the use of ``@jax.jit`` since you aren't passing in Flax NNX Modules into +the ``train_step``. .. codediff:: :title: Linen, NNX @@ -746,4 +762,6 @@ model's logits, as well as updated ``GraphDef`` and ``State`` objects. Notice we :hide: sample_x = jnp.ones((1, 784)) - train_step(state, sample_x, jnp.ones((1,), dtype=jnp.int32)) \ No newline at end of file + train_step(state, sample_x, jnp.ones((1,), dtype=jnp.int32)) + + diff --git a/docs_nnx/guides/performance.ipynb b/docs_nnx/guides/performance.ipynb new file mode 100644 index 00000000..a7451671 --- /dev/null +++ b/docs_nnx/guides/performance.ipynb @@ -0,0 +1,151 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Performance Considerations\n", + "Currently `nnx.jit` traverses the object graph in pure Python, this is slow and adds overhead. To solve this in general we will be developing a Rust extension called `flaxlib` (see first steps in #4196) to speedup some of the traversal logic in `graph.py`, similar to how JAX solved the same issue with `jaxlib` for standard pytrees. However, there's two things to consider:\n", + "\n", + "* The overhead is only relevant for small models. See [Asynchronous dispatch](#asynchronous-dispatch).\n", + "* You can remove the overhead by using `jax.jit` + `nnx.split` / `nnx.merge` to stage out the traversal logic. See [Lowering the Python Overhead](#lowering-the-python-overhead).\n", + "\n", + "\n", + "## Asynchronous dispatch\n", + "In [benchmarks/nnx_simple_training.py](https://github.com/google/flax/blob/main/benchmarks/nnx_simple_training.py) we are increasing the layer width (features per layer) and measuring the total training time for the same model trained both with `nnx.jit` and `jax.jit`. As you can see in the graph below, after a certain model size the time spent in the traversal is completely absorbed by async dispatch. This happens when Python is able to finish the current for loop step, and reach the next `train_step` and JAX is still not done with the previous `train_step`. \n", + "\n", + "![performance-graph](images/performance-graph.png)\n", + "\n", + "This means that you only need to worry about the `nnx.jit` overhead for small models. If you are working with a small model, check out the next section to see how you can remove the overhead.\n", + "\n", + "## Lowering the Python Overhead\n", + "To remove the python overhead you can use regular `jax.jit` in combination with `nnx.split` and `nnx.merge` to stage out the traversal logic. 
To learn how to do this, lets first create this simple model:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from flax import nnx\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import optax\n", + "\n", + "class Model(nnx.Module):\n", + " def __init__(self, din, dmid, dout, rngs: nnx.Rngs):\n", + " self.linear = nnx.Linear(din, dmid, rngs=rngs)\n", + " self.bn = nnx.BatchNorm(dmid, rngs=rngs)\n", + " self.dropout = nnx.Dropout(0.2, rngs=rngs)\n", + " self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)\n", + "\n", + " def __call__(self, x):\n", + " x = nnx.relu(self.dropout(self.bn(self.linear(x))))\n", + " return self.linear_out(x)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Lets say we have this `train_step` function that is using `nnx.jit` and takes in a `model`, `optimizer`, and `metrics`, all of which are Flax NNX objects:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization\n", + "optimizer = nnx.Optimizer(model, optax.adam(1e-3)) # reference sharing\n", + "metrics = nnx.MultiMetric(\n", + " loss=nnx.metrics.Average('loss'),\n", + ")\n", + "\n", + "@nnx.jit # <== currently slow\n", + "def train_step(model, optimizer, metrics, x, y):\n", + " def loss_fn(model):\n", + " y_pred = model(x) # call methods directly\n", + " return ((y_pred - y) ** 2).mean()\n", + "\n", + " loss, grads = nnx.value_and_grad(loss_fn)(model)\n", + " optimizer.update(grads) # in-place updates\n", + " metrics.update(loss=loss)\n", + "\n", + " return loss\n", + " \n", + "for _ in range(10):\n", + " x, y = jnp.ones((32, 2)), jnp.zeros((32, 3))\n", + " loss = train_step(model, optimizer, metrics, x, y)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To speed it up, before starting the training loop we can use `nnx.split` over the all the Flax NNX objects that are inputs to `train_step` to create a `graphdef` and `state` pytrees that are fast to traverse. Next we change `train_step` so accept `graphdef` and `state` and use `nnx.merge` and `nnx.split` at the beginning and end of `train_step` to switch back and forth between the objects and their pytree representations. Even though `nnx.split` and `nnx.merge` are slow it doesn't matter because they will only run once during tracing. 
With this in place, we can change the `train_step` function to use `jax.jit` instead of `nnx.jit`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization\n", + "optimizer = nnx.Optimizer(model, optax.adamw(1e-3)) # reference sharing\n", + "metrics = nnx.MultiMetric(\n", + " loss=nnx.metrics.Average('loss'),\n", + ")\n", + "# split before training loop\n", + "graphdef, state = nnx.split((model, optimizer, metrics))\n", + "\n", + "@jax.jit # regular JAX\n", + "def train_step(graphdef, state, x, y):\n", + " # merge at the beginning of the function\n", + " model, optimizer, metrics = nnx.merge(graphdef, state)\n", + "\n", + " def loss_fn(model):\n", + " y_pred = model(x) # call methods directly\n", + " return ((y_pred - y) ** 2).mean()\n", + "\n", + " loss, grads = nnx.value_and_grad(loss_fn)(model)\n", + " optimizer.update(grads)\n", + " metrics.update(loss=loss)\n", + "\n", + " # split at the end of the function\n", + " _, state = nnx.split((model, optimizer, metrics))\n", + "\n", + " # return new state\n", + " return state, loss\n", + "\n", + "for _ in range(10):\n", + " x, y = jnp.ones((32, 2)), jnp.zeros((32, 3))\n", + " state, loss = train_step(graphdef, state, x, y)\n", + "\n", + "# update objects after training\n", + "nnx.update((model, optimizer, metrics), state)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice that we only do this for `jit`, you can still use other transforms like `nnx.value_and_grad` shown in the example since their overhead is already absorbed by the outer `jit`. Also, after the training loop is done (or whenever need) `nnx.update` can be used to update Flax NNX objects like `model`, `optimizer`, and `metrics` to a new `state`." + ] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,md:myst" + }, + "language_info": { + "name": "python", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs_nnx/guides/performance.md b/docs_nnx/guides/performance.md new file mode 100644 index 00000000..5d13b54b --- /dev/null +++ b/docs_nnx/guides/performance.md @@ -0,0 +1,110 @@ +--- +jupytext: + formats: ipynb,md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.13.8 +--- + +# Performance Considerations +Currently `nnx.jit` traverses the object graph in pure Python, this is slow and adds overhead. To solve this in general we will be developing a Rust extension called `flaxlib` (see first steps in #4196) to speedup some of the traversal logic in `graph.py`, similar to how JAX solved the same issue with `jaxlib` for standard pytrees. However, there's two things to consider: + +* The overhead is only relevant for small models. See [Asynchronous dispatch](#asynchronous-dispatch). +* You can remove the overhead by using `jax.jit` + `nnx.split` / `nnx.merge` to stage out the traversal logic. See [Lowering the Python Overhead](#lowering-the-python-overhead). + + +## Asynchronous dispatch +In [benchmarks/nnx_simple_training.py](https://github.com/google/flax/blob/main/benchmarks/nnx_simple_training.py) we are increasing the layer width (features per layer) and measuring the total training time for the same model trained both with `nnx.jit` and `jax.jit`. As you can see in the graph below, after a certain model size the time spent in the traversal is completely absorbed by async dispatch. 
Asynchronous dispatch absorbs the overhead when Python is able to finish the current loop step and reach the next `train_step` call while JAX has not yet finished the previous `train_step`. + +![performance-graph](images/performance-graph.png) + +This means that you only need to worry about the `nnx.jit` overhead for small models. If you are working with a small model, check out the next section to see how you can remove the overhead. + +## Lowering the Python Overhead +To remove the Python overhead, you can use regular `jax.jit` in combination with `nnx.split` and `nnx.merge` to stage out the traversal logic. To learn how to do this, let's first create this simple model: + +```{code-cell} +from flax import nnx +import jax +import jax.numpy as jnp +import optax + +class Model(nnx.Module): + def __init__(self, din, dmid, dout, rngs: nnx.Rngs): + self.linear = nnx.Linear(din, dmid, rngs=rngs) + self.bn = nnx.BatchNorm(dmid, rngs=rngs) + self.dropout = nnx.Dropout(0.2, rngs=rngs) + self.linear_out = nnx.Linear(dmid, dout, rngs=rngs) + + def __call__(self, x): + x = nnx.relu(self.dropout(self.bn(self.linear(x)))) + return self.linear_out(x) +``` + +Let's say we have this `train_step` function that uses `nnx.jit` and takes in a `model`, `optimizer`, and `metrics`, all of which are Flax NNX objects: + +```{code-cell} +model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization +optimizer = nnx.Optimizer(model, optax.adam(1e-3)) # reference sharing +metrics = nnx.MultiMetric( + loss=nnx.metrics.Average('loss'), +) + +@nnx.jit # <== currently slow +def train_step(model, optimizer, metrics, x, y): + def loss_fn(model): + y_pred = model(x) # call methods directly + return ((y_pred - y) ** 2).mean() + + loss, grads = nnx.value_and_grad(loss_fn)(model) + optimizer.update(grads) # in-place updates + metrics.update(loss=loss) + + return loss + +for _ in range(10): + x, y = jnp.ones((32, 2)), jnp.zeros((32, 3)) + loss = train_step(model, optimizer, metrics, x, y) +``` + +To speed it up, before starting the training loop we can use `nnx.split` on all the Flax NNX objects that are inputs to `train_step` to create `graphdef` and `state` pytrees that are fast to traverse. Next, we change `train_step` so that it accepts `graphdef` and `state`, and use `nnx.merge` and `nnx.split` at the beginning and end of `train_step` to switch back and forth between the objects and their pytree representations. Even though `nnx.split` and `nnx.merge` are slow, it doesn't matter because they only run once during tracing. 
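+As a quick aside, the round trip on its own looks like this; the cell below is only an illustrative sketch using the `model` defined above, not an extra step you need in the training loop:
+
+```{code-cell}
+# nnx.split separates the object into a static graph definition and a State pytree;
+# nnx.merge reconstructs an equivalent object from the two.
+graphdef, state = nnx.split(model)
+restored_model = nnx.merge(graphdef, state)
+assert isinstance(restored_model, Model)
+```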
With this in place, we can change the `train_step` function to use `jax.jit` instead of `nnx.jit`: + +```{code-cell} +model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization +optimizer = nnx.Optimizer(model, optax.adamw(1e-3)) # reference sharing +metrics = nnx.MultiMetric( + loss=nnx.metrics.Average('loss'), +) +# split before training loop +graphdef, state = nnx.split((model, optimizer, metrics)) + +@jax.jit # regular JAX +def train_step(graphdef, state, x, y): + # merge at the beginning of the function + model, optimizer, metrics = nnx.merge(graphdef, state) + + def loss_fn(model): + y_pred = model(x) # call methods directly + return ((y_pred - y) ** 2).mean() + + loss, grads = nnx.value_and_grad(loss_fn)(model) + optimizer.update(grads) + metrics.update(loss=loss) + + # split at the end of the function + _, state = nnx.split((model, optimizer, metrics)) + + # return new state + return state, loss + +for _ in range(10): + x, y = jnp.ones((32, 2)), jnp.zeros((32, 3)) + state, loss = train_step(graphdef, state, x, y) + +# update objects after training +nnx.update((model, optimizer, metrics), state) +``` + +Notice that we only do this for `jit`; you can still use other transforms like `nnx.value_and_grad` shown in the example, since their overhead is already absorbed by the outer `jit`. Also, after the training loop is done (or whenever it is needed), `nnx.update` can be used to update Flax NNX objects like `model`, `optimizer`, and `metrics` with a new `state`. diff --git a/docs_nnx/guides/quick_start.ipynb b/docs_nnx/guides/quick_start.ipynb deleted file mode 100644 index 1c8f2977..00000000 --- a/docs_nnx/guides/quick_start.ipynb +++ /dev/null @@ -1,568 +0,0 @@ -{ - "cells": [ - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# NNX\n", - "\n", - "Welcome to NNX!\n", - "\n", - "NNX is an open source Python library for **N**eural **N**etwork in JA**X**. Its main feature is, much like Pytorch, allowing Python object semantics and reference sharing, which brings simplicty and familiarity, and easily crossing over into the functional world with through a set of simple APIs.\n", - "\n", - "This tutorial demonstrates how to construct a simple convolutional neural network (CNN) using NNX and train the network for image classification on the MNIST dataset." - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Installation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "! pip install -q nnx" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Load the MNIST dataset\n", - "We will use the `datasets` library to load MNIST and convert it to NumPy arrays." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/cris/nnx/.venv/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n", - "Found cached dataset mnist (/home/cris/.cache/huggingface/datasets/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332)\n", - "100%|██████████| 2/2 [00:00<00:00, 499.95it/s]\n" - ] - } - ], - "source": [ - "import datasets\n", - "import numpy as np\n", - "\n", - "dataset = datasets.load_dataset(\"mnist\")\n", - "X_train = np.array(np.stack(dataset[\"train\"][\"image\"]), dtype=np.uint8)[\n", - " ..., None\n", - "]\n", - "y_train = np.array(dataset[\"train\"][\"label\"], dtype=np.uint8)\n", - "X_test = np.array(np.stack(dataset[\"test\"][\"image\"]), dtype=np.uint8)[..., None]\n", - "y_test = np.array(dataset[\"test\"][\"label\"], dtype=np.uint8)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Lets visualize a few examples from the dataset using matplotlib:" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAeQAAAH4CAYAAACbup4ZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA58klEQVR4nO3df3zNdf/H8dcxs83P+f0rTWtcfi3UjESGvi3R1apFP7TQD3XxbYmkK2zlilQi+VkpikTzI+RSucxVXNpIiJIRlVVsFiM/Ztvn+0dfuzrn9WFn29n2PmeP++3mj/fT+3zO23q3l4/z2vvjsCzLEgAAUK4qlfcCAAAABRkAACNQkAEAMAAFGQAAA1CQAQAwAAUZAAADUJABADAABRkAAANQkAEAMAAFWUQOHTokDodDXn75ZY9dc+PGjeJwOGTjxo0euyYqBvYjTMJ+LDteW5Dnz58vDodDtm3bVt5LKTXr16+Xnj17Sr169SQ4OFgiIyPl3XffLe9lwYav78fvvvtORowYIV27dpXAwEBxOBxy6NCh8l4WLsLX92NiYqI4HA71KzAwsLyXViKVy3sBsLdq1SqJiYmRa6+9tmDzLV26VOLi4iQzM1NGjBhR3ktEBbJlyxaZPn26tGnTRlq3bi07duwo7yUBMnv2bKlevXrB2M/PrxxXU3IUZEPNmDFDGjduLBs2bJCAgAARERk6dKi0atVK5s+fT0FGmfrrX/8qx48flxo1asjLL79MQYYRYmNjpV69euW9DI/x2n+ydkdOTo6MHz9errnmGqlVq5ZUq1ZNunfvLsnJyRd9zdSpUyUkJESCgoKkR48esnv3bjVn7969EhsbK3Xq1JHAwECJiIiQVatWFbqe06dPy969eyUzM7PQudnZ2VK7du2CYiwiUrlyZalXr54EBQUV+nqYx5v3Y506daRGjRqFzoP38Ob9eIFlWZKdnS2+8tBCny7I2dnZ8uabb0pUVJRMnjxZEhMTJSMjQ6Kjo23/hv/OO+/I9OnTZdiwYfL000/L7t27pVevXnLkyJGCOXv27JEuXbrIt99+K2PGjJEpU6ZItWrVJCYmRlasWHHJ9aSmpkrr1q1lxowZha49KipK9uzZI+PGjZP9+/fLgQMHZMKECbJt2zYZPXp0kb8WKH/evB/he3xhP4aGhkqtWrWkRo0aMnDgQKe1eCXLS7399tuWiFhbt2696Jzc3Fzr3LlzTtlvv/1mNWzY0BoyZEhBdvDgQUtErKCgIOvw4cMFeUpKiiUi1ogRIwqy3r17W+Hh4dbZs2cLsvz8fKtr165WixYtCrLk5GRLRKzk5GSVJSQkFPrnO3XqlNW/f3/L4XBYImKJiFW1alVr5cqVhb4WZc/X9+OfvfTSS5aIWAcPHizS61B2fH0/Tps2zRo+fLi1aNEiKykpyYqPj7cqV65stWjRwjpx4kShrzeVT98h+/n5SZUqVUREJD8/X7KysiQ3N1ciIiJk+/btan5MTIw0bdq0YBwZGSmdO3eWtWvXiohIVlaWbNiwQfr37y8nT56UzMxMyczMlGPHjkl0dLSkpaVJenr6RdcTFRUllmVJYmJioWsPCAiQli1bSmxsrCxevFgWLlwoERERMnDgQPniiy+K+JWACbx5P8L3ePN+jI+Pl9dee03uueceueOOO2TatGmyYMECSUtLk1mzZhXxK2EOny7IIiILFiyQq666SgIDA6Vu3bpSv359+eijj+TEiRNqbosWLVTWsmXLgh/v2L9/v1iWJePGjZP69es7/UpISBARkaNHj3pk3cOHD5fVq1fL+++/L3fddZfce++9sn79emncuLHEx8d75D1Q9rx1P8I3+dJ+vOeee6RRo0ayfv36UnuP0ubTXdYLFy6UQYMGSUxMjDz55JPSoEED8fPzk0mTJsmBAweKfL38/HwRERk1apRER0fbzgkLCyvRmkX+aLaYN2+ejB49WipV+u/fmfz9/aVPnz4yY8YMycnJKfjbLbyDt+5H+CZf3I/NmjWTrKysUn2P0uTTBTkpKUlCQ0Nl+fLl4nA4CvILf1tzlZaWprJ9+/ZJ8+bNReSPBgKRPwrjDTfc4PkF/79jx45Jbm6u5OXlqd87f/685Ofn2/4ezOat+xG+ydf2o2VZcujQIenYsWOZv7en+PQ/WV/4IXHrTy3xKSkpsmXLFtv5K1eudPqMIzU1VVJSUqRPnz4iItKgQQOJioqSuXPnyi+//KJen5GRccn1uNvW36BBAwkODpYVK1ZITk5OQX7q1ClZvXq1tGrVih998kLeuh/hm7x5P9pda/bs2ZKRkSE33XRToa83ldffIb/11luybt06lcfHx0u/fv1k+fLlctttt0nfvn3l4MGDMmfOHGnTpo2cOnVKvSYsLEy6desmjz76qJw7d06mTZsmdevWdfoxo5kzZ0q3bt0kPDxcHnroIQkNDZUjR47Ili1b5
PDhw7Jz586LrjU1NVV69uwpCQkJl2xc8PPzk1GjRsnYsWOlS5cuEhcXJ3l5eTJv3jw5fPiwLFy4sGhfJJQZX9yPIiInTpyQ1157TURENm/eLCJ/HF4THBwswcHBMnz4cHe+PChjvrofQ0JCZMCAARIeHi6BgYGyadMmef/996VDhw4ydOhQ979Apimv9u6SutDWf7FfP/30k5Wfn29NnDjRCgkJsQICAqyOHTtaa9asse6//34rJCSk4FoX2vpfeukla8qUKVazZs2sgIAAq3v37tbOnTvVex84cMCKi4uzGjVqZPn7+1tNmza1+vXrZyUlJRXM8cSPmSxatMiKjIy0goODraCgIKtz585O7wFz+Pp+vLAmu19/XjvM4Ov78cEHH7TatGlj1ahRw/L397fCwsKsp556ysrOzi7Jl63cOSzLR444AQDAi/n0Z8gAAHgLCjIAAAagIAMAYAAKMgAABqAgAwBgAAoyAAAGoCADAGAAt0/q+vNZp4Crsv5xdvYjLoX9CJO4ux+5QwYAwAAUZAAADEBBBgDAABRkAAAMQEEGAMAAFGQAAAxAQQYAwAAUZAAADEBBBgDAABRkAAAMQEEGAMAAFGQAAAxAQQYAwAAUZAAADEBBBgDAABRkAAAMQEEGAMAAFGQAAAxQubwXAKBoPv/8c5Vdd911Ktu0aZPKbrvtNpUdO3bMMwsDUCLcIQMAYAAKMgAABqAgAwBgAAoyAAAG8NmmrgYNGqjsySefdBqPGjXKrWtNnz5dZc8884zKTp065ebqgOKzLMutrFu3biqbMWOGyu6++27PLAwoR127dlVZYmKiyiIiIlQWGRmpsv3793tkXUXBHTIAAAagIAMAYAAKMgAABqAgAwBgAJ9t6oqPj1fZyJEjncZ2jTB2/vd//1dloaGhKhswYIDKTp8+7dZ7AGWhSZMm5b0EoEDNmjVV5nrqXHh4uJozePBglTVv3lxlVapUcWsddk3ANHUBAFBBUZABADAABRkAAANQkAEAMIBPNHXZnTT08MMPF/q6pKQklc2dO1dlderUUdmcOXNUtmTJEpXFxsaq7Ny5c4WuDQDKUlhYmNO4bdu2xb5Ww4YNVdanTx+V2Z2Q1ahRo0Kv73A4VGb3ffXDDz9U2eLFi1W2d+/eQt+zLHCHDACAASjIAAAYgIIMAIABKMgAABjAJ5q60tPTVVa3bl2Vbd261Wlsd7KWu6d3tWnTRmUJCQkq+9vf/qayqVOnuvUeAFBSwcHBKvvoo49UdtVVVzmNq1atWuz3tGu6cvd769GjR53GKSkpas7mzZtV9sEHH6js0KFDbr2nKbhDBgDAABRkAAAMQEEGAMAAFGQAAAzgE01dV155ZbFe526TgZ3Jkyer7K677lJZpUr8nQdA+Zk3b57Krr32WpW58/1w0aJFKnP35MHly5er7MSJEyr74YcfnMZ2Tbu+imoBAIABKMgAABiAggwAgAEoyAAAGMAnmrree+89ld10000qu/rqq53GHTp0UHN27Njh1nuePXtWZd98843KittwBgCeYNc4ZXeS1sqVK53Gt99+e2ktCRfBHTIAAAagIAMAYAAKMgAABvCJz5DtfjD9mWeeUVlycrLT+LPPPlNzbr311kJfJyJSs2ZNlf31r39V2caNG1UGlAW7zwnh25o0aaKyqKgoldkdArJhwwansWvPjYjI5ZdfrrJ9+/apzPWJTReTnZ2tspycHLde64u4QwYAwAAUZAAADEBBBgDAABRkAAAM4BNNXXb279+vsri4OKfxqlWr1JwVK1aozK7Rq3nz5irz8/NT2bp16y61TOCSBg0apLLw8HC3XmvXuFORnpzj6xo3bqwyu+83ISEhbl3v1VdfLfGaLrBrKLTbj6mpqSr76KOPnMazZs1Sc7KyskqwOnNxhwwAgAEoyAAAGICCDACAASjIAAAYwGebuuy4nrg1ZMgQNWfJkiUqc30Kit21gJKKiYlR2ezZs1VWpUoVt65ndxLd8OHDi7wulL9WrVqpbOnSpSpr27atW9f75ZdfVHbkyBGn8eLFi91cnTZ48GCV2TV1tWnTRmWRkZFO4379+qk5H3zwgcpee+01lXnbqV/cIQMAYAAKMgAABqAgAwBgAAoyAAAGqFBNXa6WLVumsquuukpldqfJ2DXg2PHVE2XgeZ07d1ZZYGCgyuyaY+zYPXYP5rP7HrR582aVVa1aVWXHjx9X2QMPPKCyLVu2qMy1qaskXnrpJbfm2TWrRUdHO42feuopt64fFhamskcffdStdZiCO2QAAAxAQQYAwAAUZAAADEBBBgDAABW6qSs/P19lu3fvVllSUpLKBg4c6NZ7pKWlFX1hqJDsmrXs9qgdu8eNwnfk5eWp7P3331eZXQPX2bNnS2VNnrB3795Cs3/+859qzptvvqmy+++/X2Vz585V2Y4dO4qwwrLFHTIAAAagIAMAYAAKMgAABqAgAwBggArd1OWuatWqFfu1M2fOVNnjjz/uNOZRjiipqVOnlvcS4CG7du1Smd3pXT/++GNZLKfc7du3T2WTJk1S2Zo1a1Tm+r1WRGTQoEGeWFap4A4ZAAADUJABADAABRkAAANQkAEAMABNXW7o3bu3W/Pmz5+vsjvvvFNlQ4YMcRrT1AXgUipKA5e7srOz3ZpXs2bNUl6JZ3GHDACAASjIAAAYgIIMAIAB+AzZReXK+kvicDhUlpOTo7Lp06erzO6H2hMTE53Gzz//vJpj9xQU+JY6deo4jaOjo916XXp6uso+//xzj6wJ8AadOnUq7yWUCu6QAQAwAAUZAAADUJABADAABRkAAAPQ1OXilltuUVmNGjVUtmfPHpXt2LFDZQcPHlTZTTfd5DSeMmWKmtO3b99LLRM+ICsry2n88ccfqzkdOnRQWdOmTVXWvXt3ldntUXgnu6c9NWjQQGXr168vi+WUu3bt2rk1b+3ataW8Es/iDhkAAANQkAEAMAAFGQAAA1CQAQAwAE1dxeRu88SJEydUtmTJEqfxK6+8ouaEhYWpbP/+/W6uDhXNrbfeqrI5c+aUw0pQGqpXr66yTz75RGW9evVS2caNG0tjSWXmjjvuUNkDDzygMrsnQNl9jUzGHTIAAAagIAMAYAAKMgAABqAgAwBgAJq6XLRs2dKteZ5ssAoICFBZeHh4qb4nzLNr1y6VnT9/XmX+/v4qu/HGG1W2evVqlb333nsqy83NdRp/8MEHl1wnyt7OnTtVtnz5cpW9//77KrNr9Prmm288s7BSMGDAAKfxm2++qebYPf7Wrnntxx9/9Ni6ygJ3yAAAGICCDACAASjIAAAYgIIMAIABaOpysW/fPrfm2T3ububMmSqzO2GnX79+TuPMzEw1x9tP10HR2TXk1K9fX2UTJ05UWdWqVVVm9wjPm2++WWV5eXlO4zp16qg5c+fOVRnKzu+//66y4cOHq8zuEZ5ff/21yuwapV599VWnsacbv3r37q2yNm3aqOzFF190GlepUkXNsft/5d577y3B6szAHTIAAAagIAMAYAAKMgAABqAgAwBgAIdlWZZbEx2O0l6LEYKC
glT21Vdfqczu8YgLFy5UWdOmTVXm2txgdzKS62k1pnNzG3lMRdmPdmJiYlT2+OOPq8z1MZ8i7v13stvvKSkpbq3NFBV1PzZp0kRln3/+ucquuOIKlbk2jp05c6bY67D7etSqVUtlfn5+Kjt69KjTePz48WrOu+++q7KzZ88WZYllyt39yB0yAAAGoCADAGAACjIAAAagIAMAYACautxgd9pWQkKCymJjY1VWs2ZNlS1btsxp/Mwzz6g5GRkZRVliuauoTTQwE/vxv9q1a6ey+Ph4lbk+8rVTp07Ffk+7r0dSUpLKdu/erbJ58+Y5jdPT04u9DlPQ1AUAgBehIAMAYAAKMgAABuAzZHgEn9nBJOxHmITPkAEA8CIUZAAADEBBBgDAABRkAAAMQEEGAMAAFGQAAAxAQQYAwAAUZAAADEBBBgDAABRkAAAMQEEGAMAAFGQAAAxAQQYAwAAUZAAADOD24xcBAEDp4Q4ZAAADUJABADAABRkAAANQkAEAMAAFWUQOHTokDodDXn75ZY9dc+PGjeJwOGTjxo0euyYqBvYjTMJ+LDteW5Dnz58vDodDtm3bVt5LKRWJiYnicDjUr8DAwPJeGmz4+n5csWKFREdHS5MmTSQgIEAuu+wyiY2Nld27d5f30mDD1/ejiMj69eulZ8+eUq9ePQkODpbIyEh59913y3tZJVK5vBeAS5s9e7ZUr169YOzn51eOq0FF9fXXX0vt2rUlPj5e6tWrJ7/++qu89dZbEhkZKVu2bJH27duX9xJRgaxatUpiYmLk2muvLbh5Wbp0qcTFxUlmZqaMGDGivJdYLBRkw8XGxkq9evXKexmo4MaPH6+yBx98UC677DKZPXu2zJkzpxxWhYpqxowZ0rhxY9mwYYMEBASIiMjQoUOlVatWMn/+fK8tyF77T9buyMnJkfHjx8s111wjtWrVkmrVqkn37t0lOTn5oq+ZOnWqhISESFBQkPTo0cP2n+T27t0rsbGxUqdOHQkMDJSIiAhZtWpVoes5ffq07N27VzIzM93+M1iWJdnZ2cL5Ld7PF/bjnzVo0ECqVq0qx48fL9brUb68eT9mZ2dL7dq1C4qxiEjlypWlXr16EhQUVOjrTeXTBTk7O1vefPNNiYqKksmTJ0tiYqJkZGRIdHS07NixQ81/5513ZPr06TJs2DB5+umnZffu3dKrVy85cuRIwZw9e/ZIly5d5Ntvv5UxY8bIlClTpFq1ahITEyMrVqy45HpSU1OldevWMmPGDLf/DKGhoVKrVi2pUaOGDBw40Gkt8C6+sB+PHz8uGRkZ8vXXX8uDDz4o2dnZ0rt3b7dfD3N4836MioqSPXv2yLhx42T//v1y4MABmTBhgmzbtk1Gjx5d5K+FMSwv9fbbb1siYm3duvWic3Jzc61z5845Zb/99pvVsGFDa8iQIQXZwYMHLRGxgoKCrMOHDxfkKSkplohYI0aMKMh69+5thYeHW2fPni3I8vPzra5du1otWrQoyJKTky0RsZKTk1WWkJBQ6J9v2rRp1vDhw61FixZZSUlJVnx8vFW5cmWrRYsW1okTJwp9PcqWr+/HC/7yl79YImKJiFW9enVr7NixVl5entuvR9nw9f146tQpq3///pbD4SjYj1WrVrVWrlxZ6GtN5tN3yH5+flKlShUREcnPz5esrCzJzc2ViIgI2b59u5ofExMjTZs2LRhHRkZK586dZe3atSIikpWVJRs2bJD+/fvLyZMnJTMzUzIzM+XYsWMSHR0taWlpkp6eftH1REVFiWVZkpiYWOja4+Pj5bXXXpN77rlH7rjjDpk2bZosWLBA0tLSZNasWUX8SsAE3rwfL3j77bdl3bp1MmvWLGndurWcOXNG8vLy3H49zOHN+zEgIEBatmwpsbGxsnjxYlm4cKFERETIwIED5YsvvijiV8Ig5fwXgmJz52+AlmVZ8+fPt8LDwy1/f/+Cv0mJiHXFFVcUzLnwN8Dx48er1993331WQECAZVn//RvhpX5t377dsiz7vwF6QqNGjazevXt79JoouYq4H7OysqyGDRtaI0eO9Ng14Rm+vh+HDh1qtW/f3ulfZ3JycqwWLVpYkZGRxbqmCXy6y3rhwoUyaNAgiYmJkSeffFIaNGggfn5+MmnSJDlw4ECRr5efny8iIqNGjZLo6GjbOWFhYSVac2GaNWsmWVlZpfoeKB2+th9r164tvXr1kkWLFnn00AiUDW/djzk5OTJv3jwZPXq0VKr033/k9ff3lz59+siMGTMkJyen4O7fm/h0QU5KSpLQ0FBZvny5OByOgjwhIcF2flpamsr27dsnzZs3F5E/GqxE/vgPf8MNN3h+wYWwLEsOHTokHTt2LPP3Rsn52n4UETlz5oycOHGiXN4bJeOt+/HYsWOSm5tr+1HJ+fPnJT8/32s/RvH5z5BFxOlHhlJSUmTLli2281euXOn0GUdqaqqkpKRInz59ROSPH/OIioqSuXPnyi+//KJen5GRccn1FKWt3+5as2fPloyMDLnpppsKfT3M48378ejRoyo7dOiQ/Otf/5KIiIhCXw/zeOt+bNCggQQHB8uKFSskJyenID916pSsXr1aWrVq5bU/+uT1d8hvvfWWrFu3TuXx8fHSr18/Wb58udx2223St29fOXjwoMyZM0fatGkjp06dUq8JCwuTbt26yaOPPirnzp2TadOmSd26dZ3a6GfOnCndunWT8PBweeihhyQ0NFSOHDkiW7ZskcOHD8vOnTsvutbU1FTp2bOnJCQkFNq4EBISIgMGDJDw8HAJDAyUTZs2yfvvvy8dOnSQoUOHuv8FQpny1f0YHh4uvXv3lg4dOkjt2rUlLS1N5s2bJ+fPn5cXXnjB/S8QypQv7kc/Pz8ZNWqUjB07Vrp06SJxcXGSl5cn8+bNk8OHD8vChQuL9kUySfl+hF18F5oWLvbrp59+svLz862JEydaISEhVkBAgNWxY0drzZo11v3332+FhIQUXOtC08JLL71kTZkyxWrWrJkVEBBgde/e3dq5c6d67wMHDlhxcXFWo0aNLH9/f6tp06ZWv379rKSkpII5JW3rf/DBB602bdpYNWrUsPz9/a2wsDDrqaeesrKzs0vyZUMp8fX9mJCQYEVERFi1a9e2KleubDVp0sS66667rF27dpXky4ZS4uv70bIsa9GiRVZkZKQVHBxsBQUFWZ07d3Z6D2/ksCyOgAIAoLz59GfIAAB4CwoyAAAGoCADAGAACjIAAAagIAMAYAAKMgAABqAgAwBgALdP6vrzWaeAq7L+cXb2Iy6F/QiTuLsfuUMGAMAAFGQAAAxAQQYAwAAUZAAADEBBBgDAABRkAAAMQEEGAMAAFGQAAAxAQQYAwAAUZAAADEBBBgDAABRkAAAMQEEGAMAAFGQAAAxAQQYAwAAUZAAADEBBBgDAABRkAAAMULm8FwBUBA888IDK3njjDY9d3+FwqOzw4cMqmzhxospmz57tsXUAKD7ukAEAMAAFGQAAA1CQAQAwAAUZAAAD0NQFlIE+ffqozLIsj13f7lpNmjRR2fTp01XWoUMHp/H
QoUM9ti6gqPz8/JzGHTt2VHOef/55ldnNu+eee1S2fv36EqyudHGHDACAASjIAAAYgIIMAIABKMgAABiApq4KrG7duk7jVq1aufW6zZs3l8ZyfNqXX36psmuuucZpfPnll5f6OipV0n8Hdz1FbM+ePWqOXTMYUFKuDYUiIgsWLHAaX3XVVcW+frNmzYr92vLAHTIAAAagIAMAYAAKMgAABqAgAwBgAJq6fFCNGjVUNmTIEJWNHDnSady0aVO3ru96kg4KN2nSJJV98MEHTuNGjRp59D1feeUVlbk2konoRze6NvsBnjB48GCVPfPMMyq78sorncb/+c9/1JzTp0+r7OzZsypbuHBhUZZY7rhDBgDAABRkAAAMQEEGAMAAFGQAAAxAU5cXadu2rcoee+wxlUVHR6usuCfW/PDDD8V6HQq3f//+S46Lwu6/77lz54p9PcBddk2e48ePV9kTTzyhssqVdQl6+umnncavv/66mrNx40aVNWzYUGVVqlRR2fnz51VmCu6QAQAwAAUZAAADUJABADAABRkAAAPQ1GUAu8ceDho0SGUPPfSQyoKDg1U2b948lS1btkxlgYGBTuPevXurOX//+99VhrJj15QyYcIEldmdguTuiVu///6703j16tVurg4QeeSRR1Rm19S1detWldmdIPjdd985jZcvX67mhIeHq+yNN95QmeveNh13yAAAGICCDACAASjIAAAYgIIMAIABaOoqZS1btlTZmDFjnMZ33nmnmlO1alWVLVq0SGV2DQ8rV64swgr/68MPPyzW61B6VqxYobKbbrrJo+/h+gi8bdu2efT68B133323yl544QWV/etf/1LZAw88oDK7kwATExOdxv369VNzsrOzVbZkyRKVeRvukAEAMAAFGQAAA1CQAQAwQIX+DNnu6SDPP/+8ys6cOaOyhIQEld1zzz0qc/18TkR/Pvzuu++qOXafDa9fv15l8A59+/ZVmevTu5577jk1x9/fv9jvmZGRobI5c+aobNasWcV+D1QsdocY2X2e+/jjj6vM7vPiW265RWWuPTY5OTlqTmxsrMrsPrf2NtwhAwBgAAoyAAAGoCADAGAACjIAAAaoUE1drk+/WbdunZrTvn17lVmWpbJbb71VZfXr11eZ3SELzz77rNOYZi3v1bRpU5Xdd999Khs3bpzKXJ+25a5Dhw6pbMCAASr78ccfVXb06NFivScgIlKnTh2VJSUlqWz37t0qi46OVtmUKVMKfU+7xthPP/200Nd5I+6QAQAwAAUZAAADUJABADAABRkAAANUqKau0aNHO42vuuqqYl/L4XCo7Pbbb1fZP//5z2K/B8zXsWNHldmd9lZcmzZtUtmLL76oMp7QhLJw8uRJlT388MMqs2v+snuqnZ3Bgwc7jRcvXuzm6rwfd8gAABiAggwAgAEoyAAAGICCDACAAXy2qatPnz4qe+KJJ5zGx48fV3Nq166tsm+++UZld911l8r27NlThBXCF9g9eu73339XWbVq1Yp1/W7duqksLCxMZXZ7Lz09XWWvvvpqofPsHtsIiNh/z3Q9AVFEZODAgW5db9CgQSqrSE1crrhDBgDAABRkAAAMQEEGAMAAFGQAAAzgs01dMTExKqtUyfnvH8uXL1dz3n77bZV99dVXKjtz5kzxFwef8dlnn6nM7sS2ESNGqMz1pLgmTZq49Z6NGjVyK7MTFxensu3btzuN9+3bp+aMHDlSZb/++qtb7wmIiPzjH/9QWUVu4LLDHTIAAAagIAMAYAAKMgAABqAgAwBgAJ9t6tq/f7/KXB+Z2L17dzXnoYceKrU1oWJYv369W1mVKlWcxkOHDlVzJk+erLKAgIASrE67+uqrLzkWEWnXrp3KXnvtNZW99dZbKsvPzy/B6lBemjdvrrIHH3yw2Nf7+uuvVZaTk1Ps6/ki7pABADAABRkAAANQkAEAMAAFGQAAA/hsU9fu3btVlpeX5zRu0aKFmhMREaGybdu2eW5hwP9zbWixa5LaunWryuweLWrH7vSu4jbl2DV1zZ07V2X169dX2aRJk4r1nihbrk2GL7/8sppj9z3z448/Vtn111+vsmuvvVZlS5cuLcoSfR53yAAAGICCDACAASjIAAAYwGFZluXWRJdDNbzRCy+84DR+8skn1ZzffvtNZe3bt1dZenq65xbmA9zcRh5jyn60+zy3R48eKnv99ddV9v3335fKmi7w8/NTWdWqVVV26623Oo2HDx+u5nTq1Mmt93Tt0xARueWWW1Rm97mjJ1XU/VgSTz31lNPY9fuliMj8+fNVZnegzXPPPacyu/9X7L63+iJ39yN3yAAAGICCDACAASjIAAAYgIIMAIABfPZgEDsvvvii0/jOO+9Uc6644gqVVa9evdTWBO9m1zhl1yyYkpKistJu6rJrsDp58qTKFi5c6DS2e1Lahg0bVGb31Cm7r0flyhXq24xXuOGGG1Q2YcIEp3FaWpqaM2bMGJXZPbHp9OnTJVhdxcUdMgAABqAgAwBgAAoyAAAGoCADAGCACtVtkZWV5TT++eef1Ry7pi6gpOxOPXJ9cs6SJUvUnB9++MGt619zzTUqs2tGrFWrlsqefvppp3HTpk3VHLsGLjtfffWVyjZv3uzWa1E6mjdvrjK7k+NcT5MaOXKkmnPkyBGV2T0J7JFHHlHZJ598cqllQrhDBgDACBRkAAAMQEEGAMAAFGQAAAxQoZq6WrVq5TS+7rrr1BxfeIways7WrVtVtnbtWpV1795dZZMmTXIaDx48WM2xOzXLTrdu3VRWs2ZNt15bXHanfrmehicicvz48VJdBy7tvvvuU5ld86rrIxNXr17t1vUffvhhlTVu3Fhl//nPf9y6XkXGHTIAAAagIAMAYAAKMgAABqAgAwBgAIflejzLxSYa0uw0evRolZ09e1Zl7733nspmzJjhNLZ7/KLdCTbx8fEqs3vkWEXm5jbyGFP2o7tiY2NVtmDBAqdxYGBgWS3nkjIyMlT2xRdfqOzVV19VWXJycqmsqajYj/81f/58lfXt21dljRo1chq3bt1azXnjjTdUFhkZqTK777/333+/yvLz81Xmi9zdj9whAwBgAAoyAAAGoCADAGAACjIAAAbwuqaugwcPquzyyy8v1rXsHm0XGhparGtVdDTRFN0tt9ziNF65cmWxrzVt2jSV/fLLL269Ni8vz2k8derUYq/DFOzH/3rooYdUZte8euDAAadxSEiImmPXzDpx4kSVTZ48WWW5ubmXXKcvo6kLAAAvQkEGAMAAFGQAAAxAQQYAwABe9/jF8PBwlQ0ZMkRlEyZMUJlrk8vNN9/suYUBReT6eDs/P79yWgl82eHDh92ad+WVVzqNd+3apebYnba1Y8eOYq0LGnfIAAAYgIIMAIABKMgAABjA6w4GcVerVq1UdvLkSadxenp6WS3H53EQA0zCfoRJOBgEAAAvQkEGAMAAFGQAAAxAQQYAwAA+29SFskUTDUzCfoRJaOoCAMCLUJABADAABRkAAANQkAEAMAAFGQAAA1CQAQAwAAUZAAADUJABADAABRkAAA
O4fVIXAAAoPdwhAwBgAAoyAAAGoCADAGAACjIAAAagIIvIoUOHxOFwyMsvv+yxa27cuFEcDods3LjRY9dExcB+hEnYj2XHawvy/PnzxeFwyLZt28p7KaUmPT1d+vfvL8HBwVKzZk259dZb5fvvvy/vZcFGRdiP69evl549e0q9evUkODhYIiMj5d133y3vZcGGr+/H5s2bi8PhsP3VokWL8l5esVUu7wXA3qlTp6Rnz55y4sQJ+fvf/y7+/v4ydepU6dGjh+zYsUPq1q1b3ktEBbJq1SqJiYmRa6+9VhITE8XhcMjSpUslLi5OMjMzZcSIEeW9RFQg06ZNk1OnTjllP/zwg4wdO1ZuvPHGclpVyVGQDTVr1ixJS0uT1NRU6dSpk4iI9OnTR9q1aydTpkyRiRMnlvMKUZHMmDFDGjduLBs2bJCAgAARERk6dKi0atVK5s+fT0FGmYqJiVHZP/7xDxERuffee8t4NZ7jtf9k7Y6cnBwZP368XHPNNVKrVi2pVq2adO/eXZKTky/6mqlTp0pISIgEBQVJjx49ZPfu3WrO3r17JTY2VurUqSOBgYESEREhq1atKnQ9p0+flr1790pmZmahc5OSkqRTp04FxVhEpFWrVtK7d29ZunRpoa+Hebx5P2ZnZ0vt2rULirGISOXKlaVevXoSFBRU6OthHm/ej3bee+89ueKKK6Rr167Fer0JfLogZ2dny5tvvilRUVEyefJkSUxMlIyMDImOjpYdO3ao+e+8845Mnz5dhg0bJk8//bTs3r1bevXqJUeOHCmYs2fPHunSpYt8++23MmbMGJkyZYpUq1ZNYmJiZMWKFZdcT2pqqrRu3VpmzJhxyXn5+fmya9cuiYiIUL8XGRkpBw4ckJMnT7r3RYAxvHU/iohERUXJnj17ZNy4cbJ//345cOCATJgwQbZt2yajR48u8tcC5c+b96Orr776Sr799lu55557ivxao1he6u2337ZExNq6detF5+Tm5lrnzp1zyn777TerYcOG1pAhQwqygwcPWiJiBQUFWYcPHy7IU1JSLBGxRowYUZD17t3bCg8Pt86ePVuQ5efnW127drVatGhRkCUnJ1siYiUnJ6ssISHhkn+2jIwMS0Ss5557Tv3ezJkzLRGx9u7de8lroGz58n60LMs6deqU1b9/f8vhcFgiYomIVbVqVWvlypWFvhZlz9f3o6uRI0daImJ98803RX6tSXz6DtnPz0+qVKkiIn/cdWZlZUlubq5ERETI9u3b1fyYmBhp2rRpwTgyMlI6d+4sa9euFRGRrKws2bBhg/Tv319OnjwpmZmZkpmZKceOHZPo6GhJS0uT9PT0i64nKipKLMuSxMTES677zJkzIiJO/zx4QWBgoNMceA9v3Y8if+zFli1bSmxsrCxevFgWLlwoERERMnDgQPniiy+K+JWACbx5P/5Zfn6+vP/++9KxY0dp3bp1kV5rGp9v6lqwYIFMmTJF9u7dK+fPny/Ir7jiCjXXrl2+ZcuWBZ/Z7t+/XyzLknHjxsm4ceNs3+/o0aNOm7Y4Lnwmd+7cOfV7Z8+edZoD7+KN+1FEZPjw4fLFF1/I9u3bpVKlP/4e379/f2nbtq3Ex8dLSkpKid8DZc9b9+Of/fvf/5b09HSfaCz06YK8cOFCGTRokMTExMiTTz4pDRo0ED8/P5k0aZIcOHCgyNfLz88XEZFRo0ZJdHS07ZywsLASrVlEpE6dOhIQECC//PKL+r0LWZMmTUr8Pihb3rofc3JyZN68eTJ69OiCYiwi4u/vL3369JEZM2ZITk5Owd0WvIO37kdXixYtkkqVKsndd9/t8WuXNZ8uyElJSRIaGirLly8Xh8NRkCckJNjOT0tLU9m+ffukefPmIiISGhoqIn98I7rhhhs8v+D/V6lSJQkPD7f9of6UlBQJDQ2VGjVqlNr7o3R46348duyY5ObmSl5envq98+fPS35+vu3vwWzeuh//7Ny5c7Js2TKJioryiZsUn/8MWUTE+tMjn1NSUmTLli2281euXOn0GUdqaqqkpKRInz59RESkQYMGEhUVJXPnzrW9e83IyLjkeorS1h8bGytbt251KsrfffedbNiwQe68885CXw/zeOt+bNCggQQHB8uKFSskJyenID916pSsXr1aWrVqxUcoXshb9+OfrV27Vo4fP+7VP3v8Z15/h/zWW2/JunXrVB4fHy/9+vWT5cuXy2233SZ9+/aVgwcPypw5c6RNmzbqlBeRP/45pVu3bvLoo4/KuXPnZNq0aVK3bl2nH+uYOXOmdOvWTcLDw+Whhx6S0NBQOXLkiGzZskUOHz4sO3fuvOhaU1NTpWfPnpKQkFBo48Lf/vY3eeONN6Rv374yatQo8ff3l1deeUUaNmwoI0eOdP8LhDLli/vRz89PRo0aJWPHjpUuXbpIXFyc5OXlybx58+Tw4cOycOHCon2RUGZ8cT/+2aJFiyQgIEDuuOMOt+Ybr9z6u0voQlv/xX799NNPVn5+vjVx4kQrJCTECggIsDp27GitWbPGuv/++62QkJCCa11o63/ppZesKVOmWM2aNbMCAgKs7t27Wzt37lTvfeDAASsuLs5q1KiR5e/vbzVt2tTq16+flZSUVDDHE239P/30kxUbG2vVrFnTql69utWvXz8rLS2tuF8ylKKKsB8XLVpkRUZGWsHBwVZQUJDVuXNnp/eAOSrCfjxx4oQVGBho3X777cX9MhnHYVl/+vcKAABQLnz6M2QAALwFBRkAAANQkAEAMAAFGQAAA1CQAQAwAAUZAAADUJABADCA2yd1/fmsU8BVWf84O/sRl8J+hEnc3Y/cIQMAYAAKMgAABqAgAwBgAAoyAAAGoCADAGAACjIAAAagIAMAYAAKMgAABqAgAwBgAAoyAAAGoCADAGAACjIAAAagIAMAYAAKMgAABqAgAwBgAAoyAAAGoCADAGAACjIAAAaoXN4LAHBxdevWVdl9992nsttvv11l3bt3V1lSUpLKPvvsM6fxa6+9VpQlAvAQ7pABADAABRkAAANQkAEAMAAFGQAAAzgsy7LcmuhwlPZa4MXc3EYeU1H249ixY1X27LPPquzs2bMqW7JkicoGDhyosjNnzjiNQ0JC1Jzjx49fapnGYT/CJO7uR+6QAQAwAAUZAAADUJABADAABRkAAAPQ1FXKwsLCVPbOO++U6nvefPPNKvv9999Vdv78eY+9J000JdeuXTuVffzxxypr1KiRyuwavZ577jmVffPNNyr7y1/+4jSePn26mtO8eXOVPf744yr74YcfVFYeKsJ+vP7661Vm99/u559/VtmyZctUNm/ePM8sDApNXQAAeBEKMgAABqAgAwBgAAoyAAAGoKnLgzp06KCyTz75RGV2j9QrbXanNj322GNO48zMzGJfvyI00ZS2CRMmqOzvf/+7ytLS0lQWHh6uMrumvWHDhqls4sSJTuPq1atfcp0XJCcnq+yGG25w67WlrSLsx/T0dJXZNfy567fffnMaf/DBB2qOXTOY6+tERL788stir8MX0dQFAIAXoSADAGAACjIAAAagI
AMAYIDK5b0AX2J3clF5NHDZGTBggMreeOMNp7Fdkw7KTmBgoFvz8vLyVObuqWszZ85UmWtD0uTJk91a2/r16916T5SOTZs2qax9+/Yq8/PzU1loaKjK6tSp4zR++OGH1Ry7zG4/Hjx4UGV2jW9bt25VmWvT4o033qjmnDx5UmVPPvmkynbt2qUyk3GHDACAASjIAAAYgIIMAIABKMgAABjAZ5u6qlWrprL69es7je2asOyaBexOrGnbtq3K+vXr59bazp496zS2ezTi/v37Vfbpp5+qbPTo0SqrUqWKW+uAWTIyMlRm1wjj6VOhFi9e7DS2OzEsKChIZcuXL/foOlA0do2adipX1t/mr776apWNGjXKaWx36lqtWrVUZtc0ZvfYWbt9e+WVV6qsuFzXLyISFxfnseuXBe6QAQAwAAUZAAADUJABADAABRkAAAP4RFNXkyZNVDZ37lyV3XzzzcW6vt0j8Ny1evVqla1Zs8Zp/Oabbxb7+nbNGcX9c8I8do9t8/SjBV955RWncc2aNdWcpKQkle3bt8+j60DpyM3NVVlqaqrK+vfv7zQODg5Wc7p06aKy6OholdmdUGjX1BUbG6uy4jalmnIqYklwhwwAgAEoyAAAGICCDACAAbzuM2S7zxdcn1okInLTTTeVxXKczJ8/X2WPPfaYyuwOAikP8fHxTmOe9uQdatSooTLXJ/WIiGRlZamsXbt2KrM7AMIVh4BUPMePH1fZunXr3Mrcdd9996ksMTHRaTx+/Hi3rjVs2LBir8MU3CEDAGAACjIAAAagIAMAYAAKMgAABjC6qWvEiBEqe/TRR1XmySeG2LH7wfoFCxaozO5pI55s4LJr5gkMDCz29ewOe0D5WbRokcomTZqksqZNm6rs2WefVZndgTOuh9KIiDRq1Mhp/N5776k5y5YtUxlQGq655hqnsd1BOHZP4Dt06FBpLanMcIcMAIABKMgAABiAggwAgAEoyAAAGMCYpq7GjRurLCoqSmWl3cBlx66B6+GHHy7zdTz44IMq69WrV7Gvl56eXpLlwMOOHj2qsk2bNqmsW7duKrvjjjtU1r59e5XZPRnthx9+cBonJCSoOefPn1cZUFJ2J3X9z//8T6Gv+/LLL4v9nnanPV522WUq+/7774v9HsXFHTIAAAagIAMAYAAKMgAABqAgAwBgAGOautq0aaOyfv36lfk67B6haHcCV2m7/PLLVXbXXXeV+TpQduwap1555RWVde/eXWWup22J2DdK5uTkqCwuLs5pXB7NLCi6iRMnqiwiIkJldif82Z1qtWXLFqex3QlZJeFwOFR22223qczf37/Qa/Xs2VNl586dU5lds2PNmjVVVq9ePZU1a9as0HV4GnfIAAAYgIIMAIABKMgAABiAggwAgAEclpuf3Nt9IO9JGRkZKqtTp45H3+Prr792Gvft21fNOX78uMo8+QhFd9k1tH344YfFvt706dNVNmbMGKexXVOEuzzdAFKY0t6PprD7f2D//v0qq1WrlsrsvkY33nijytavX1/M1ZmrIuxHT3/PdP0zlEVTlyff4+eff1ZZcnKyylzrgIjI6tWrVbZ3717PLEzc/3NyhwwAgAEoyAAAGICCDACAASjIAAAYwJiTuuxOSsnPzy/29TZv3qyyu+++22lcXo8fbNu2rdPY7uSll156qdjX//XXX1W2YcMGlZWkiQtl495771WZXQOXu3766aeSLAcGcT1ZS8T90w3tvvcdO3bMaexuI5JdM5XdiWENGjRw63qu7L6Xv/rqqypLSkoq1vVNwh0yAAAGoCADAGAACjIAAAagIAMAYABjmro8rXr16iqrXLl0/7jDhw9X2fXXX6+y0NBQp3HHjh09uo7Bgwer7JNPPvHoe8DzoqOjVfbiiy+qzO4RillZWSqze/ziDTfcoLLvvvvO3SXCIHaPYw0MDHTrtXZ7yO7xn+6waw61e+Tjjh07VNa8eXOVpaamOo3tHrWYm5vr/gK9CHfIAAAYgIIMAIABKMgAABjAmM+Q7Z5g4/pZa1G0b99eZa6fYeTl5RX7+naqVq2qsoCAAI9d/+jRoyqz+6H8f//73x57T5Sehg0bOo1feOEFNadKlSoqe+SRR1T27bffquyzzz5T2cSJE1Xm+qSbH3/8US8Wxjl9+rRbWXmIj49X2RVXXKEyu8NHli5d6jT21c+L7XCHDACAASjIAAAYgIIMAIABKMgAABjAmKYuu0YVTx9mUbNmTY9erzTNmjVLZZ9++qnKVq1aVRbLQSl4/vnnncZ2jYivv/66yt544w23ru9wOFRmd2BDhw4dnMY0daEounXrprInnnjCrdcuX75cZbNnzy7xmrwVd8gAABiAggwAgAEoyAAAGICCDACAAYxp6tq5c6fKXE9sERHp379/WSynVK1du9ZpbNfAtX79epUV92ksKH/33nuvyu644w6ncXp6upozatQot65v95Qfu1OQ7DLAXXaNsW+99ZbKatWqpTK7/f3000+r7OzZs8VcnffjDhkAAANQkAEAMAAFGQAAA1CQAQAwgDFNXZmZmSq77777VGZ3oteaNWtKZU1F9eKLL6rM7nQt18c+0qzl++z2rWvji11zzODBg1XWq1cvlV1++eUlWB3gnmHDhqksLCzMrdfaNXDZPXa3IuMOGQAAA1CQAQAwAAUZAAADUJABADCAw3Lz6B67R7kBF5T1CVDeth8///xzlV133XVO4zNnzqg5didwuevXX39Vmd2j7VybEXNycor9nqZgP5ac3aM/Bw0apDI/Pz+Vbdu2TWXXX3+9ys6dO1e8xXkZd/cjd8gAABiAggwAgAEoyAAAGICCDACAAWjqgkfQRHNp7dq1U1nfvn2dxrfddpua06lTJ5Vt3bpVZV9//bXKxo4dq7IjR45ccp2+gv1YdD169HAa250yaNfA5XryoIj9Xv7oo49KsDrvRlMXAABehIIMAIABKMgAABiAggwAgAFo6oJH0EQDk7Afi65Pnz5OY3cfa5uYmKiyCRMmeGJJPoOmLgAAvAgFGQAAA1CQAQAwQOXyXgAAoPy5PqHp5MmTak5ycrLK3nvvvVJbU0XDHTIAAAagIAMAYAAKMgAABqAgAwBgAA4GgUdwEANMwn6ESTgYBAAAL0JBBgDAABRkAAAMQEEGAMAAFGQAAAxAQQYAwAAUZAAADEBBBgDAABRkAAAM4PZJXQAAoPRwhwwAgAEoyAAAGICCDACAASjIAAAYgIIMAIABKMgAABiAggwAgAEoyAAAGICCDACAAf4PkLEsNK/INnsAAAAASUVORK5CYII=", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "import matplotlib.pyplot as plt\n", - "\n", - "# plot a 3x3 grid of MNIST digits\n", - "idxs = np.random.randint(0, len(X_train), size=(3, 3))\n", - "fig, axes = plt.subplots(3, 3, figsize=(3 * 2, 3 * 2))\n", - "\n", - "for i in range(3):\n", - " for j in range(3):\n", - " axes[i, j].imshow(X_train[idxs[i, j]], cmap=\"gray\")\n", - " axes[i, j].axis(\"off\")\n", - " axes[i, j].set_title(f\"Label: {y_train[idxs[i, j]]}\")\n", - "\n", - "plt.show()" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Defining the Model\n", - "\n", - "To create a convolutional neural network using NNX define a `nnx.Module` subclass. We define the model by subclassing `nnx.Module` and defining a `forward` method that returns the model output. Like in PyTorch, the `__init__` method instantiates all the modules that will be used in the model. The `__call__` in this case\n", - "will define the forward computation. " - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" - ] - }, - { - "data": { - "text/plain": [ - "(1, 10)" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import jax\n", - "import jax.numpy as jnp\n", - "from flax import nnx\n", - "\n", - "\n", - "class CNN(nnx.Module):\n", - "\n", - " def __init__(self, *, rngs: nnx.Rngs):\n", - " self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)\n", - " self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)\n", - " self.linear1 = nnx.Linear(7 * 7 * 64, 256, rngs=rngs)\n", - " self.linear2 = nnx.Linear(256, 10, rngs=rngs)\n", - " self.num_calls = nnx.var(\"counts\", 0)\n", - "\n", - " def __call__(self, x: jax.Array) -> jax.Array:\n", - " self.num_calls += 1\n", - " x = self.conv1(x)\n", - " x = nnx.relu(x)\n", - " x = nnx.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n", - " x = self.conv2(x)\n", - " x = nnx.relu(x)\n", - " x = nnx.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n", - " x = x.reshape((x.shape[0], -1)) # flatten\n", - " x = self.linear1(x)\n", - " x = nnx.relu(x)\n", - " x = self.linear2(x)\n", - " return x\n", - "\n", - "\n", - "model = CNN(rngs=nnx.Rngs(0))\n", - "\n", - "y = model(X_train[:1])\n", - "y.shape" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "One notable difference with other frameworks is that `__init__`, by convention, accepts a `rngs: nnx.Rngs` keyword-only argument. This object is passed around to generate PRNG keys as random state is explicit in JAX.\n", - "\n", - "One of the nice things about NNX is that Module contain their own state, are fully inspectable, and you can run them eargerly. For example, we can easily check out the kernel shape of the first `Conv` layer:" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(3, 3, 1, 32)" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model.conv1.kernel.shape" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can also view the entire `State` of the model using the `.filter()` method. TODO: talk about collections." 
- ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "State({\n", - " 'conv1/bias': Variable(\n", - " collection='params',\n", - " value=(32,)\n", - " ),\n", - " 'conv1/kernel': Variable(\n", - " collection='params',\n", - " value=(3, 3, 1, 32)\n", - " ),\n", - " 'conv2/bias': Variable(\n", - " collection='params',\n", - " value=(64,)\n", - " ),\n", - " 'conv2/kernel': Variable(\n", - " collection='params',\n", - " value=(3, 3, 32, 64)\n", - " ),\n", - " 'linear1/bias': Variable(\n", - " collection='params',\n", - " value=(256,)\n", - " ),\n", - " 'linear1/kernel': Variable(\n", - " collection='params',\n", - " value=(3136, 256)\n", - " ),\n", - " 'linear2/bias': Variable(\n", - " collection='params',\n", - " value=(10,)\n", - " ),\n", - " 'linear2/kernel': Variable(\n", - " collection='params',\n", - " value=(256, 10)\n", - " )\n", - "})" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "jax.tree.map(jnp.shape, model.extract(nnx.Param))" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Training in eager mode\n", - "\n", - "For pedagogical purposes, we first train the model in eager mode. This will be uselful to take a look at some of NNX's features, its be more approachable for new users, and great for debugging, but it is not the recommended way to train models in JAX.\n", - "\n", - "Here we will run a simple `for` loop for just 10 iterations, at each step we will sample a batch of data, define a `loss_fn` to compute the loss, and use `nnx.value_and_grad` to compute the gradients of the loss with respect to the model parameters. Using the gradients we will update the parameters using stochastic gradient descent (SGD) via a simple `tree.map` operation. Finally, we will update the model's parameters using the `.update_state` method." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Step 0: loss=58.7676\n", - "Step 1: loss=80.0420\n", - "Step 2: loss=108.3005\n", - "Step 3: loss=26.6188\n", - "Step 4: loss=10.7236\n", - "Step 5: loss=4.7499\n", - "Step 6: loss=3.9177\n", - "Step 7: loss=2.9419\n", - "Step 8: loss=2.4733\n", - "Step 9: loss=1.8060\n" - ] - } - ], - "source": [ - "import optax\n", - "\n", - "for step in range(10):\n", - " idxs = np.random.randint(0, len(X_train), size=32)\n", - " x = jnp.array(X_train[idxs])\n", - " y = jnp.array(y_train[idxs])\n", - "\n", - " def loss_fn(model: CNN):\n", - " logits = model(x)\n", - " return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()\n", - "\n", - " loss, grads = nnx.value_and_grad(loss_fn, wrt=\"params\")(model)\n", - " params = model.extract(\"params\")\n", - " params = jax.tree.map(lambda w, g: w - 0.001 * g, params, grads)\n", - "\n", - " model.update(params)\n", - " print(f\"Step {step}: loss={loss:.4f}\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The loss is going down 🎉.\n", - "\n", - "### Training with the Functional API\n", - "\n", - "Now that we have a working model, lets see how to train it with `jax.jit` using NNX's Functional API. 
The `Module.split` method allows you to convert a Module into pytrees with functional semantics, this allows you to integrate with JAX's functional APIs like `jax.jit` and `jax.grad`.\n", - "\n", - "In this next example we will use the `.split` method to split the model into a `params: State` and `graphdef: GraphDef` objects. We pass the `\"params\"` filter to check that the Module's state only contain `Variables` with the `params` collection. Having `params` and `graphdef` its pretty easy to implement a jitted `train_step` much like you would in Flax or Haiku. `GraphDef` exposes an `apply` method which accepts some `State` and creates a function that runs the Module's `__call__` method. This function then returns the output of the Module along with the updated state." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "graphdef, params = model.split(\"params\")\n", - "\n", - "\n", - "@jax.jit\n", - "def train_step(params: nnx.State, x, y):\n", - " def loss_fn(params):\n", - " logits, _updates = graphdef.apply(params)(x)\n", - " return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()\n", - "\n", - " loss, grads = jax.value_and_grad(loss_fn)(params)\n", - " params = jax.tree.map(lambda w, g: w - 0.001 * g, params, grads)\n", - "\n", - " return loss, params" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Using `train_step` we can run a few more iterations and see that the loss is still going down, however, this time execution should be much faster." - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Step 0: loss=1.4396\n", - "Step 1: loss=1.4127\n", - "Step 2: loss=1.8718\n", - "Step 3: loss=1.7080\n", - "Step 4: loss=1.7984\n", - "Step 5: loss=1.0350\n", - "Step 6: loss=1.2076\n", - "Step 7: loss=0.9081\n", - "Step 8: loss=0.8217\n", - "Step 9: loss=0.6687\n" - ] - } - ], - "source": [ - "for step in range(10):\n", - " idxs = np.random.randint(0, len(X_train), size=32)\n", - " x = jnp.array(X_train[idxs])\n", - " y = jnp.array(y_train[idxs])\n", - "\n", - " loss, params = train_step(params, x, y)\n", - " print(f\"Step {step}: loss={loss:.4f}\")\n", - "\n", - "model.update(params)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Realistic Training using TrainState\n", - "\n", - "For real training scenarios, we recommend using `TrainState` to manage the state of your training loop. `TrainState` manages the `params` of your network along with other types of state, and uses `optax` to update the parameters according to the gradients.\n", - "\n", - "Next, we will define a `train_step` function that accepts a `TrainState` and a batch of data, and returns a new `TrainState` with updated parameters. The `apply_gradients` method will return a new `state` with the updated parameters. Flax users should be familiar with this API. In this case will will also define a `eval_step` function that will be used to evaluate the model on the test set and return some metrics." 
- ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "state = nnx.TrainState(\n", - " graphdef,\n", - " params=params,\n", - " tx=optax.adam(0.001),\n", - ")\n", - "\n", - "\n", - "@jax.jit\n", - "def train_step(state: nnx.TrainState, x, y):\n", - " def loss_fn(params):\n", - " logits, _updates = state.apply_fn(params)(x)\n", - " return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()\n", - "\n", - " grads = jax.grad(loss_fn)(state.params)\n", - "\n", - " state = state.apply_gradients(grads=grads)\n", - "\n", - " return state\n", - "\n", - "\n", - "@jax.jit\n", - "def eval_step(state: nnx.TrainState, x, y):\n", - " logits, _updates = state.apply_fn(state.params)(x)\n", - " metrics = {\n", - " 'accuracy': jnp.mean(jnp.argmax(logits, axis=-1) == y),\n", - " 'loss': optax.softmax_cross_entropy_with_integer_labels(logits, y).mean(),\n", - " }\n", - " return metrics" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now lets create a simple training loop that runs for 1000 iterations and prints the metrics every 100 steps. At the end of training we will compute the final metrics." - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Step 0: {'accuracy': Array(0.63119996, dtype=float32), 'loss': Array(1.1837534, dtype=float32)}\n", - "Step 100: {'accuracy': Array(0.9492, dtype=float32), 'loss': Array(0.16359854, dtype=float32)}\n", - "Step 200: {'accuracy': Array(0.9564, dtype=float32), 'loss': Array(0.14198248, dtype=float32)}\n", - "Step 300: {'accuracy': Array(0.96279997, dtype=float32), 'loss': Array(0.12757339, dtype=float32)}\n", - "Step 400: {'accuracy': Array(0.97169995, dtype=float32), 'loss': Array(0.09900841, dtype=float32)}\n", - "Step 500: {'accuracy': Array(0.96889997, dtype=float32), 'loss': Array(0.10143881, dtype=float32)}\n", - "Step 600: {'accuracy': Array(0.9745, dtype=float32), 'loss': Array(0.08513925, dtype=float32)}\n", - "Step 700: {'accuracy': Array(0.96379995, dtype=float32), 'loss': Array(0.11632324, dtype=float32)}\n", - "Step 800: {'accuracy': Array(0.97679996, dtype=float32), 'loss': Array(0.07204168, dtype=float32)}\n", - "Step 900: {'accuracy': Array(0.9765, dtype=float32), 'loss': Array(0.08413408, dtype=float32)}\n", - "Final metrics: {'accuracy': Array(0.9819, dtype=float32), 'loss': Array(0.05711861, dtype=float32)}\n" - ] - } - ], - "source": [ - "total_steps = 1000\n", - "eval_every = 100\n", - "\n", - "for step in range(total_steps):\n", - " if step % eval_every == 0:\n", - " metrics = eval_step(state, jnp.array(X_test), jnp.array(y_test))\n", - " print(f\"Step {step}: {metrics}\")\n", - "\n", - " idxs = np.random.randint(0, len(X_train), size=32)\n", - " x = jnp.array(X_train[idxs])\n", - " y = jnp.array(y_train[idxs])\n", - "\n", - " state = train_step(state, x, y)\n", - "\n", - "metrics = eval_step(state, jnp.array(X_test), jnp.array(y_test))\n", - "print(f\"Final metrics: {metrics}\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Inference\n", - "\n", - "Finally, now that we have a trained model, lets use it to make some predictions. We will update the `model` object with the trained parameters and use it to make predictions on the test set." 
- ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAeQAAAH4CAYAAACbup4ZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABBzklEQVR4nO3de1hVVf7H8S83lXukqJiGaGrmJW9Zk5cUUUa8JOakZY1iTVTeffLalKWOllrpoJk2hdXgpKaMk6GOlk6ieSnJUrPMMDXGS5PiDS/A+v3hD2qztnA4HDgLeL+exz/Wh7X3XpxWfNnnLNb2UEopAQAAbuXp7gEAAAAKMgAARqAgAwBgAAoyAAAGoCADAGAACjIAAAagIAMAYAAKMgAABqAgAwBggEpXkOvXry9Dhw7Nb2/ZskU8PDxky5YtLruGh4eHvPDCCy47Hyou5iNMwnx0rzItyEuXLhUPD4/8f9WqVZPGjRvLiBEj5OTJk2U5lBJLSUkpN5Nq165d8vTTT0vbtm3Fx8dHPDw83D0kIzAfy15ubq4sXbpU+vbtK/Xq1RN/f39p3ry5zJgxQy5fvuzu4bkV89F9FixYIE2bNpWqVavKLbfcIuPGjZOLFy+W+Ti8y/yKIjJt2jSJiIiQy5cvS2pqqixatEhSUlJk37594ufnV6Zj6dy5s2RlZUmVKlWKdVxKSoosXLjQdtJlZWWJt7dbXlpbKSkp8re//U1atmwpDRo0kO+++87dQzIK87HsXLp0SeLi4uSee+6RJ598UmrWrCmfffaZTJ06VT7++GP55JNPKv0vjMzHsjVx4kSZPXu2DBgwQEaPHi0HDhyQhIQE2b9/v2zYsKFsB6PKUGJiohIRtXv3bks+btw4JSJq2bJlNzz2woULLhlDeHi4GjJkSInPM3z4cFXGL5/TTpw4oS5duqSUKl/jLm3Mx7J35coVtW3bNi1/8cUXlYiojRs3umFUZmA+lr2MjAzl7e2tHn30UUuekJCgRET961//KtPxGPEZcmRkpIiIpKeni4jI0KFDJSAgQA4fPiwxMTESGBgogwcPFpHrb3nNmzdPmjVrJtWqVZNatWpJfHy8nDlzxnJOpZTMmDFD6tatK35+ftK1a1fZv3+/du0bfUayc+dOiYmJkZCQEPH395eWLVvK/Pnz88e3cOFCERHLW0x57D4jSUtLk549e0pQUJAEBARIt27dZMeOHZY+eW9Zbdu2TcaNGyehoaHi7+8vsbGxcvr0aUvfzMxMOXjwoGRmZhb5+taqVUt8fX2L7IfrmI/XlcZ8rFKlitx7771aHhsbKyIi33zzTaHHV0bMx+tKYz5+9tlnkp2dLYMGDbLkee3333+/0ONdzYj3DQ4fPiwiItWrV8/PsrOzJTo6Wjp27Chz587Nf6smPj5eli5dKnFxcTJq1ChJT0+XBQsWSFpammzbtk18fHxEROT555+XGTNmSExMjMTExMiePXukR48ecvXq1SLHs3HjRundu7eEhYXJ6NGjpXbt2vLNN9/I2rVrZfTo0RIfHy8ZGRmyceNGee+994o83/79+6VTp04SFBQkEyZMEB8fH1m8eLF06dJF/vOf/8jdd99t6T9y5EgJCQmRqVOnypEjR2TevHkyYsQIWb58eX6f5ORkiYuLk8TERMsiDJQc87Hs5+OJEydERKRGjRrFPraiYz6W3ny8cuWKiIh2w5L3en7xxRdFjt+lyvJ2PO8tmU2bNqnTp0+rY8eOqffff19Vr15d+fr6quPHjyullBoyZIgSETVp0iTL8Vu3blUiopKSkiz5+vXrLfmpU6dUlSpVVK9evVRubm5+vylTpigRsbwls3nzZiUiavPmzUoppbKzs1VERIQKDw9XZ86csVznt+cq7C0ZEVFTp07Nb/fr109VqVJFHT58OD/LyMhQgYGBqnPnztrrExUVZbnW2LFjlZeXlzp79qzWNzEx0XYMN1Je3koqC8xH98/HPFFRUSooKEj7HisT5mPZz8cvvvhCiYiaPn26Jc97zQICAgo93tXc8pZ1VFSUhIaGSr169WTQoEESEBAgycnJcsstt1j6PfXUU5b2ypUrJTg4WLp37y4///xz/r+2bdtKQECAbN68WURENm3aJFevXpWRI0da3ioZM2ZMkWNLS0uT9PR0GTNmjNx0002Wrzmz2CQnJ0f+/e9/S79+/aRBgwb5eVhYmDz88MOSmpoq586dsxzzxBNPWK7VqVMnycnJkR9//DE/Gzp0qCiluDt2Aeaje+fjzJkzZdOmTfLSSy9p32NlxHwsu/nYpk0bufvuu+Xll1+WxMREOXLkiKxbt07i4+PFx8dHsrKyiv09lYRb3rJeuHChNG7cWLy9vaVWrVrSpEkT8fS0/m7g7e0tdevWtWSHDh2SzMxMqVmzpu15T506JSKS/x+mUaNGlq+HhoZKSEhIoWPLe3uoefPmjn9DhTh9+rRcunRJmjRpon2tadOmkpubK8eOHZNmzZrl57feequlX96YC34OBNdgPl7njvm4fPly+fOf/yyPPfaYVmAqK+bjdWU1H1etWiUDBw6UYcOGiYiIl5eXjBs3Tv7zn//It99+69Q5neWWgty+fXtp165doX2qVq2qTcLc3FypWbOmJCUl2R4TGhrqsjG6k5eXl22ulCrjkVQOzMfCldZ83Lhxo/zxj3+UXr16yRtvvFGic1UkzMfCuXo+3nLLLZKamiqHDh2SEydOSKNGjaR27dpSp04dady4cUmGWmxGLOpyVMOGDWXTpk3SoUOHQlcNh4eHi8j13xh/+zbI6dOni/wtqmHDhiIism/fPomKirphP0ffngkNDRU/Pz/b37QOHjwonp6eUq9ePYfOBbMwH523c+dOiY2NlXbt2smKFSuM+rvU8or5WDKNGjXKf9fgwIED8t///rfMPxI04s+eHPXggw9KTk6OTJ8+Xftadna2nD17VkSufwbj4+MjCQkJlt+a5s2bV+Q12rRpIxERETJv3rz88+X57bn8/f1FRLQ+BXl5eUmPHj1kzZo1cuTIkfz85MmTsmzZMunYsaMEBQUVOa6CivNnTygdzMdfFWc+fvPNN9KrVy+pX7++rF27lj/JcxHm469K8vMxNzdXJkyYIH5+fvLkk08W+/iSKFe/lt53330SHx8vs2bNki+//FJ69OghPj4+cujQIVm5cqXMnz9fBgwYIKGhofLMM8/IrFmzpHfv3hITEyNpaWmybt26Iv+swtPTUxYtWiR9+vSRVq1aSVxcnISFhcnBgwctO7e0bdtWRERGjRol0dHR4uXlpf0tW54ZM2bIxo0bpWPHjvL000+Lt7e3LF68WK5cuSKzZ8926rUozp+Z/Pjjj/l/fvD555/nj0nk+m/Ljz76qFNjqOyYj79ydD6eP39eoqOj5cyZMzJ+/Hj56KOPLF9v2LCh/O53v
3NqDJUd8/FXxfn5OHr0aLl8+bK0atVKrl27JsuWLZNdu3bJO++8o31eXerKckn3jXaiKWjIkCHK39//hl9fsmSJatu2rfL19VWBgYGqRYsWasKECSojIyO/T05OjnrxxRdVWFiY8vX1VV26dFH79u3TdqIpuKw/T2pqqurevbsKDAxU/v7+qmXLliohISH/69nZ2WrkyJEqNDRUeXh4WJb4S4Fl/UoptWfPHhUdHa0CAgKUn5+f6tq1q9q+fbtDr4/dGIvzZyZ5x9v9u++++4o8vqJiPpb9fExPT7/hXJQCf3JT2TAf3fPzMTExUd15553K399fBQYGqm7duqlPPvmkyONKg4dSrBQCAMDdytVnyAAAVFQUZAAADEBBBgDAABRkAAAMQEEGAMAAFGQAAAxAQQYAwAAO79TlzKO1UHmU9Z+zMx9RGOYjTOLofOQOGQAAA1CQAQAwAAUZAAADUJABADAABRkAAANQkAEAMAAFGQAAA1CQAQAwAAUZAAADUJABADAABRkAAANQkAEAMAAFGQAAA1CQAQAwgMOPXwTgvOjoaC0bP368pR0ZGenSa9o9EnDNmjVatn37dkt73rx5Wp+rV6+6bFwA7HGHDACAASjIAAAYgIIMAIABPJRSyqGONp9HAXkcnEYuY/J8rF+/vpbt379fy6pVq1YGoym+lJQULXv11Ve1bPPmzWUxHKcwH2ESR+cjd8gAABiAggwAgAEoyAAAGICCDACAAVjUBZdgEc2vGjZsqGXfffedG0biOufPn9eyyZMna9nq1au17OTJk6UypsIwH2ESFnUBAFCOUJABADAABRkAAANQkAEAMACLugpo27atlv3lL3/RMn9/fy2bPn26lv373/92zcAMxyKaX/n6+mrZzJkztSwzM9PS3rRpk9bnz3/+s5YlJiZqWUxMjJY1a9ZMy1q3bq1lrmS3y1efPn1K9Zp2mI8wCYu6AAAoRyjIAAAYgIIMAIABKMgAABiARV0FxMXFadmbb77p0LHZ2dla1qNHDy379NNPiz8ww7GIxjx2i8vsFli9/vrrlnZISIjT1/z++++1rF27dlpmt/OXKzEfCzdp0qQis6CgIIfO9dprr2nZ7NmztczRHdsK/sxs06aN1uell15y6FymYFEXAADlCAUZAAADUJABADAABRkAAAOwqKuA2267TcvWr1+vZfXr19cyu9dow4YNWma3q1J5xyKa8qvgohm7HcOCg4OdPv9jjz2mZUuXLnX6fI5gPv5q8ODBWvb3v/9dyy5cuGBpZ2VlaX3sdii0Wzz41VdfadnEiRO1zM/PT8vefffdIq/ZqlUrh65pChZ1AQBQjlCQAQAwAAUZAAADUJABADAAi7ocMHLkSC2z253G7jXKyMjQsnr16rlmYAZhEU3F8eSTT2rZwoULnT7f8uXLtezhhx92+nyOYD7+auvWrVrWoUMHLSu4u5bdbl533HGHlr3yyitaFh0drWVXr17Vsl9++UXLateurWUF/elPf9Kyt956q8jj3IVFXQAAlCMUZAAADEBBBgDAABRkAAAM4O3uAQCo2CIjI909BDjAkd3YDhw4oGWxsbFaZrcw8NVXX9UyRxZwVSbcIQMAYAAKMgAABqAgAwBgAD5DBoAK7NChQ1pmtzHIsGHDLO2kpCStT2pqqpZdvnxZy3bv3l2cIeL/cYcMAIABKMgAABiAggwAgAEoyAAAGIBFXQBK1fz58909hErtpZde0rJevXppWc2aNS3tTz/9VOszbtw4LVu7dq2WlfbTr06ePFmq53cX7pABADAABRkAAANQkAEAMAAFGQAAA7CoC4DF448/7vSxdrs27d+/vyTDQQl99913WtajRw8t+/jjjy3tm2++Wetj98SmqVOnatnnn39enCEWym7HMLuFZBUBd8gAABiAggwAgAEoyAAAGICCDACAAVjUVcoWLFjg7iHACf7+/loWEhKiZfHx8VrWoEGDUhlTYVJSUrQsMzNTy8LDw7Xs3nvvtbRbtGjh9Dh++OEHLfvXv/7l9PlQOvbu3atljz76qKX91ltvaX1q166tZcHBwVrWrVs3p8d2/vx5S/svf/mL0+cqb7hDBgDAABRkAAAMQEEGAMAAFGQAAAzAoq5SdvHiRXcPAUWoWrWqlv3973/Xsr59+5bFcJwyaNAgdw9BROwX/XTo0EHLdu7cqWXZ2dmlMiY4Zt26dZZ2kyZNtD5jxozRsoceekjL/Pz8tOzWW291aBw//fSTpX3w4EGHjqsIuEMGAMAAFGQAAAxAQQYAwAAUZAAADMCiLgd4eHg4lHl66r/f2PWDWbKysrRMKeWGkZR/do/s+/TTT7Vs4sSJWjZ37txSGROcU3DHLBGR6dOnO5TZLeTbunWrQ9e1e1xkZcEdMgAABqAgAwBgAAoyAAAGoCADAGAAFnU5wG6Bj12Wk5OjZRcuXCiVMcFc586d07J//OMfTp3rkUce0TK7R0MCJinJYtbt27e7cCTlC3fIAAAYgIIMAIABKMgAABiAz5BdyO7JTomJiW4YCcrKl19+qWX9+vXTsmPHjjl1/q+++krLWrdurWWPP/64U+e388MPP2hZWlqalkVGRmpZSEiIy8aB8qskG+v079/f0p49e3ZJh1NucIcMAIABKMgAABiAggwAgAEoyAAAGIBFXUAJ3HbbbVq2YsUKl53fbgGXj4+Py84vInL06FFL+/XXX9f6vPbaa1rWp08fLRs3bpxD1zx+/LiDo0NlExQU5O4huA13yAAAGICCDACAASjIAAAYgIIMAIABWNTlQsuXL3f3EOCEb775RssaNWqkZV5eXloWEBCgZe3bt3fNwErBTz/9pGW///3vLe1vv/3WoXN9+OGHDmVAcezevdvdQ3Ab7pABADAABRkAAANQkAEAMAAFGQAAA7CoywGrVq3SspkzZ2qZh4dHWQwHLtasWTMtmzt3rpY99thjWuaOXYWuXbumZWfPntWyf/zjH1q2ePFiLXN0ERdQFvbt2+fuIbgNd8gAABiAggwAgAEoyAAAGICCDACAAVjU5YCMjAwty83N1bK6deuWxXBQBp555hktW7RokZZt2LBByyIiIpy6ZmpqqpbZ7Xz1448/atnKlSuduiZQGkqywLUyL47lDhkAAANQkAEAMAAFGQAAA1CQAQAwgIdSSjnUsRJ/0G4nMzPToX7BwcGlPBIzODiNXIb5iMIwH92rQ4cOWrZ161aHjt21a5elfc8997hkTO7k6HzkDhkAAANQkAEAMAAFGQAAA1CQAQAwADt1Oenzzz/Xsnbt2rlhJABQcbjjkaam4A4ZAAADUJABADAABRkAAAPwGbKTpk+frmUTJ050w0gAoOJwdNOliog7ZAAADEBBBgDAABRkAAAMQEEGAMAAPO0JLsHTdWAS5qN73XbbbVqWlJSkZXfddZeWxcbGWtpr1qxx3cDchKc9AQBQjlCQAQAwAAUZAAADUJABADAAi7rgEiyigUmYjzAJi7oAAChHKMgAABiAggwAgAEoyAAAGMDhRV0AAKD0cIcMAIABKMgAABiAggwAgAEoyAAA
GKDSFeT69evL0KFD89tbtmwRDw8P2bJli8uu4eHhIS+88ILLzoeKi/kIkzAf3atMC/LSpUvFw8Mj/1+1atWkcePGMmLECDl58mRZDqXEUlJSytWkys3NlUWLFkmrVq3E19dXqlevLpGRkbJ37153D81tmI/ud+3aNbnjjjvEw8ND5s6d6+7huBXz0X1M+fnoXaZX+3/Tpk2TiIgIuXz5sqSmpsqiRYskJSVF9u3bJ35+fmU6ls6dO0tWVpZUqVKlWMelpKTIwoULbSddVlaWeHu75aW9oWHDhklSUpL88Y9/lBEjRsjFixclLS1NTp065e6huR3z0X0SEhLk6NGj7h6GUZiPZc+Un49ueVV69uwp7dq1ExGRxx9/XKpXry6vvvqqrFmzRh566CHbYy5evCj+/v4uH4unp6dUq1bNped09flKasWKFfLOO+/I6tWrJTY21t3DMQ7z0T1OnTol06ZNk4kTJ8rzzz/v7uEYg/lYtkz6+WjEZ8iRkZEiIpKeni4iIkOHDpWAgAA5fPiwxMTESGBgoAwePFhErr+1MG/ePGnWrJlUq1ZNatWqJfHx8XLmzBnLOZVSMmPGDKlbt674+flJ165dZf/+/dq1b/QZyc6dOyUmJkZCQkLE399fWrZsKfPnz88f38KFC0VELG8x5bH7jCQtLU169uwpQUFBEhAQIN26dZMdO3ZY+uS9ZbVt2zYZN26chIaGir+/v8TGxsrp06ctfTMzM+XgwYOSmZlZ5Ov76quvSvv27SU2NlZyc3Pl4sWLRR5TmTEfryut+Zhn0qRJ0qRJE3nkkUccPqYyYj5eVxl+PhpRkA8fPiwiItWrV8/PsrOzJTo6WmrWrClz586VBx54QERE4uPjZfz48dKhQweZP3++xMXFSVJSkkRHR8u1a9fyj3/++eflueeekzvvvFPmzJkjDRo0kB49ejj0Ym/cuFE6d+4sBw4ckNGjR8srr7wiXbt2lbVr1+aPoXv37iIi8t577+X/u5H9+/dLp06dZO/evTJhwgR57rnnJD09Xbp06SI7d+7U+o8cOVL27t0rU6dOlaeeeko+/PBDGTFihKVPcnKyNG3aVJKTkwv9Xs6dOye7du2Su+66S6ZMmSLBwcESEBAgDRo0kBUrVhT5WlRGzEcrV87HPLt27ZJ33nlH5s2bx6MLi8B8tKrQPx9VGUpMTFQiojZt2qROnz6tjh07pt5//31VvXp15evrq44fP66UUmrIkCFKRNSkSZMsx2/dulWJiEpKSrLk69evt+SnTp1SVapUUb169VK5ubn5/aZMmaJERA0ZMiQ/27x5sxIRtXnzZqWUUtnZ2SoiIkKFh4erM2fOWK7z23MNHz5c3ejlExE1derU/Ha/fv1UlSpV1OHDh/OzjIwMFRgYqDp37qy9PlFRUZZrjR07Vnl5eamzZ89qfRMTE23HkGfPnj1KRFT16tVVrVq11Ouvv66SkpJU+/btlYeHh1q3bl2hx1dkzMeyn495427fvr166KGHlFJKpaenKxFRc+bMKfLYioz5yM9HtxTkgv/Cw8PV+vXr8/vlTbgff/zRcvyoUaNUcHCwOnXqlDp9+rTlX0BAgHr88ceVUkotW7ZMiYjlnEpdn4hFTbjdu3crEVGvvfZaod+LoxMuOztb+fn5qQcffFDrFx8frzw9PVVmZqbl9VmxYoWl3+rVq5WIqL179xY6Jjuffvpp/uu8Y8eO/Pz8+fOqRo0aqkOHDsU+Z0XBfLQqi/molFJvv/228vX1VUePHlVKUZDzMB+tKuPPR7cs6lq4cKE0btxYvL29pVatWtKkSRPx9LS+e+7t7S1169a1ZIcOHZLMzEypWbOm7XnzVsT9+OOPIiLSqFEjy9dDQ0MlJCSk0LHlvT3UvHlzx7+hQpw+fVouXbokTZo00b7WtGlTyc3NlWPHjkmzZs3y81tvvdXSL2/MBT8HcoSvr6+IiERERMjdd9+dnwcEBEifPn3k73//u2RnZxu36rEsMR+vK4v5eO7cOZk8ebKMHz9e6tWrV+zjKwPm43WV8eejW34Kt2/fPn8V4Y1UrVpVm4S5ublSs2ZNSUpKsj0mNDTUZWN0Jy8vL9tcOfFgrjp16oiISK1atbSv1axZU65duyYXL16U4ODgYp+7omA+Fs6V83Hu3Lly9epVGThwoBw5ckRERI4fPy4i13+gHjlyROrUqVPsP7OpSJiPhavIPx/L1W1Rw4YNZdOmTdKhQ4f832zshIeHi8j13xgbNGiQn58+fbrI36IaNmwoIiL79u2TqKioG/ZzdCFKaGio+Pn5ybfffqt97eDBg+Lp6Vmqdwp16tSR2rVry08//aR9LSMjQ6pVqyaBgYGldv2KjPlYfEePHpUzZ85Y7njyzJw5U2bOnClpaWnSqlWrUhtDRcV8LD7Tfj4ascraUQ8++KDk5OTI9OnTta9lZ2fL2bNnRUQkKipKfHx8JCEhwfJb07x584q8Rps2bSQiIkLmzZuXf748vz1X3t/8FexTkJeXl/To0UPWrFmTf0cgInLy5ElZtmyZdOzYUYKCgoocV0HFWdY/cOBAOXbsmGzcuDE/+/nnn2XNmjUSGRmp/aYNxzAff+XofBw1apQkJydb/i1evFhErv+5THJyskRERBT7+mA+/lZ5/flYru6Q77vvPomPj5dZs2bJl19+KT169BAfHx85dOiQrFy5UubPny8DBgyQ0NBQeeaZZ2TWrFnSu3dviYmJkbS0NFm3bp3UqFGj0Gt4enrKokWLpE+fPtKqVSuJi4uTsLAwOXjwoOzfv182bNggIiJt27YVkes/YKKjo8XLy0sGDRpke84ZM2bIxo0bpWPHjvL000+Lt7e3LF68WK5cuSKzZ8926rVITk6WuLg4SUxMtOw9a2fy5MmyYsUKeeCBB2TcuHESHBwsb7zxhly7dk1mzpzp1PXBfPwtR+djmzZtpE2bNpYs7wdxs2bNpF+/fk5dH8zH3yq3Px/LcgVZ3iq53bt3F9pvyJAhyt/f/4ZfX7JkiWrbtq3y9fVVgYGBqkWLFmrChAkqIyMjv09OTo568cUXVVhYmPL19VVdunRR+/btU+Hh4YWuIsyTmpqqunfvrgIDA5W/v79q2bKlSkhIyP96dna2GjlypAoNDVUeHh6WFYVSYFm/UteX10dHR6uAgADl5+enunbtqrZv3+7Q62M3xuL8mYlSSh0+fFjFxsaqoKAg5evrqyIjI9WuXbscOraiYj66bz7+Fqusr2M+8vPRQyknPgkHAAAuxYeHAAAYgIIMAIABKMgAABiAggwAgAEoyAAAGICCDACAASjIAAAYwOGduniIOApT1n/OznxEYZiPMImj85E7ZAAADEBBBgDAABRkAAAMQEEGAMAAFGQAAAxAQQYAwAAUZAAADEBBBgDAABRkAAAMQEEGAMAAFGQAAAxAQQYAwAAUZAAADEBBBgDAAA4/frEiCgoK0rKvvvpKy0aOHKllH374YamMCQDcoeAjAvfu3av1mThxopZt2LCh1MZU2XCHDACAASjIAAAYgIIMAIABPFTBDw5u1NHDo7T
HUuqaNWtmaSckJGh9unbtqmXvvvuulg0ZMsR1A6sAHJxGLlMR5iNKD/Ox+FauXGlp9+/fX+tz4MABLWvfvr2WZWVluW5gFYCj85E7ZAAADEBBBgDAABRkAAAMQEEGAMAAlWpjkObNm1vadgu47Jw6dao0hgNUCnPmzNGyvn37almTJk3KYji4gczMzCL7nD9/XstycnJKYziVEnfIAAAYgIIMAIABKMgAABiAggwAgAEq1aIuR/zvf//TMrsnnADQjR07VsuefvppLfPx8dGyNm3aWNp79uxx3cDgEidOnNCyq1evumEkFRN3yAAAGICCDACAASjIAAAYgIIMAIABKuyiLi8vLy17+OGHizxu586dWpabm+uSMRVHSEiIlp07d07L2CWnYqtWrZqWFXyMqIjIF198UarjsPv/adq0aVo2efJkLbN79NyyZcu0zO7Rfig7hw4dcvcQKj3ukAEAMAAFGQAAA1CQAQAwAAUZAAADVNhFXYMHD9aygo98++WXX7Q+Tz75ZKmN6UZuv/12Ldu8ebOW2S04mzFjhpZ9/vnnrhkY3O7ZZ5/VsuzsbC1z5aIuuwVcL7zwgpZNmjTJofPZLeD605/+pGWXL1926HwoHR07drS0PTw8tD7bt28vq+FUStwhAwBgAAoyAAAGoCADAGAACjIAAAaosIu6srKyiuxjtxtWYGBgaQzHomXLlpb2ggULtD61a9fWsvvvv1/LYmJitKzg4gwRkV27dhVniHCDdu3aadn48eO1bNasWaU6jsjISC2bMmWKQ8euXLlSy4YNG6Zl165dK/7A4DJ2O8D16tXL0v7LX/6i9dm0aVOpjQncIQMAYAQKMgAABqAgAwBgAAoyAAAGqBCLum655RYti4uLK/K4o0ePapnd7l2O8vTUf78ZMGCAlg0fPtzS7tSpk9PXtNtVyW4nJ5hv9OjRWlalSpVSv25sbKylvWrVKoeOs9sRbuDAgS4ZE1zHbqFqfHx8kceNGDFCyxYuXOiSMcEed8gAABiAggwAgAEoyAAAGKBCfIb82GOPaVnPnj2LPK7gZ7kiIidOnHB6HM2bN9ey5cuXO30+R+Tm5mrZnj17SvWacI0aNWpY2nYbgyiltOznn392+pp33HGHlr399ttFXvPQoUNa1qFDB6fHgbJz6dIlLXvjjTe0rODmL+vWrdP6lOTnI4rGHTIAAAagIAMAYAAKMgAABqAgAwBggAqxqMtu8w07q1evtrRTUlKcvqafn5+WJSUlOXUuu0U6O3bs0LIWLVpoWXh4uJY9+OCDWrZixQqnxgbXqFWrlpZ99NFHlnaTJk20Plu3btWyJUuWOHTNoKAgLVu8eHGR/Y4cOaL16devn5bxxKbyoW7dulr20EMPadm+ffssbbuNX0zRuHFjh/p99913pTwS1+IOGQAAA1CQAQAwAAUZAAADUJABADBAuVvUZff0G7vdh+wU3LHGbkciOz4+Plq2dOlSLbPbqctOwd1uRo4cqfX54IMPHMrq1aunZT/88IND40DpsHsC1zvvvKNlrVu3trSvXr2q9Zk+fbqWObqY6oEHHtAyu921Cv5/cODAAa1P9+7dtcxul7jytoimMnj66ae17P7779eyw4cPW9qXL18utTEVJjo62tKePXu21sdugauHh4eWJSYmatmwYcNKMLrSxR0yAAAGoCADAGAACjIAAAagIAMAYIByt6jLbhcqu0U0X3/9tZY988wzRZ7fbtHYsmXLtMxuwYyjRo8ebWnbLday8/3332uZ3UIGu13EUHbee+89LbNbFFXQm2++qWV2/83tFuS0adNGy+Lj44u8pp2YmBiHsm3btmlZ586dnbomSs/Bgwe1LDIyUsvWrl1raS9atKjUxpTHbrfEggsPAwICtD6OLsjt1auXcwNzE+6QAQAwAAUZAAADUJABADAABRkAAAOUu0VdGzdu1DK7HYPsHj138uTJIs9/6623allJFnC9//77WuboIq6C7rrrLi2zW9TVv39/Lfv000+duiaKr127dlpm99+poBEjRmjZ8OHDXTKm4ozDUdu3b3fZuVB6du7cqWV2i2MffvhhS9vVi7rsdtyy2zmu4CKuPXv2aH08PfV7Sbuf3f/85z+LMUL34w4ZAAADUJABADAABRkAAANQkAEAMEC5W9Rlt+OR3Qf86enpRZ6rTp06WpacnOzcwEQkLS1Ny15++WUtK7gIzW4nmilTpmiZ3S5Ido/i27BhQ6HjhOvUqFFDy0qys5CrjruRs2fPalnBnZzsdrmzW4i4ZcsWVw0LpSgwMFDL7B716ayqVatq2aOPPqpldjslZmdna9nChQst7QsXLmh9/vCHP2iZ3Y5kBXdFNB13yAAAGICCDACAASjIAAAYgIIMAIAByt2iri5dujjUb8eOHUX22bRpk5Y1bdrUofOnpqZq2ciRI7UsLCxMy/r27WtpjxkzRusTEhLi0DhefPFFLVu3bp1Dx6Lkfv75Zy373e9+p2U+Pj5Fnstu1yK7RyjefffdWma3WGvOnDla9u6772pZRkZGkWND+WW3c9yVK1e07LvvvnPq/Hbz0W6xlp2pU6dq2axZsyxtu0WGDRo00LITJ05o2eXLlx0ahym4QwYAwAAUZAAADEBBBgDAAB7KwZ0HXPmUmJJYvXq1lsXGxmrZpUuXtOz8+fOWdq1atZweR8Fz3YjdH+U7y27jkXvuuUfLXPlH/45y9QYWRTFlPrpSlSpVtMzu87PbbrtNy+bOnatlEydOdM3AyiHmY/H17t3b0l67dq3Wx+6JTXabHz3xxBNadu+992rZV199pWU9e/a0tI8cOaL1sXuCld04TOHofOQOGQAAA1CQAQAwAAUZAAADUJABADBAudsYZNWqVVpmt6jLz8/PocxZrlysZcduIcOzzz6rZe5YwIXSMX36dC2zW8C1bds2LSu4mQKQJzw8XMvsFljt37+/yHN9//33Wma3wNXbWy8t9erV07KCT74TETlz5oylbbfJiN2GPBUBd8gAABiAggwAgAEoyAAAGICCDACAAcrdTl12Pv74Yy2LjIx0w0gcc+jQIUv7r3/9q9ZnwYIFZTUcl2BnpOL7wx/+YGmvWLFC62P3uto9AerNN9903cAqAOZj4aKiorSs4C5cdk/Da9SokZb997//1TK7OXrs2DEtCw0N1bK4uDhL+5dfftH62I3NZOzUBQBAOUJBBgDAABRkAAAMQEEGAMAAFWJRV0hIiJbZLSoYO3aspV2zZk2nr2m3Q1ZmZqaWHThwQMsGDBhgaVeEXWdYRFM4f39/LSv4aMX69etrfVauXKlldnP77NmzTo+tImI+Fl/BXbMcfQ3tfhb+8MMPWmb3Gi1btkzLCi4uqwi7EbKoCwCAcoSCDACAASjIAAAYgIIMAIABKsSiLkfdfffdlrbdY+y8vLwcOpfd4+7sdtw6ceKEg6Mr31hEU7hXXnlFy8aMGWNpX7p0SevTs2dPLUtNTXXZuCoq5mPxFdzt7YEHHtD6BAcHa9mHH36oZf369XPZuCoCFnUBAFCOUJABADAABRkAAANQkAEAME
[… remainder of the base64-encoded "image/png" output elided: a 3x3 grid of MNIST test digits with their predicted labels …]",
-      "text/plain": [
-       "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "model.update(state.params)\n", - "\n", - "# plot a 3x3 grid of MNIST digits\n", - "idxs = np.random.randint(0, len(X_test), size=(3, 3))\n", - "fig, axes = plt.subplots(3, 3, figsize=(3 * 2, 3 * 2))\n", - "\n", - "for i in range(3):\n", - " for j in range(3):\n", - " logits = model(jnp.array([X_test[idxs[i, j]]]))\n", - " axes[i, j].imshow(X_test[idxs[i, j]], cmap=\"gray\")\n", - " axes[i, j].axis(\"off\")\n", - " axes[i, j].set_title(f\"Prediction: {jnp.argmax(logits)}\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Awesome! We hope you've enjoyed this tutorial and learned the basics of NNX." - ] - } - ], - "metadata": { - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.16" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/docs_nnx/guides/randomness.ipynb b/docs_nnx/guides/randomness.ipynb index 517e3f22..a0d41462 100644 --- a/docs_nnx/guides/randomness.ipynb +++ b/docs_nnx/guides/randomness.ipynb @@ -6,7 +6,7 @@ "source": [ "# Randomness\n", "\n", - "Random state in Flax NNX is radically simplified compared to systems like Haiku/Flax Linen in that it defines \"random state as object state\". In essence, this means that random state is just another type of state, and is stored in Variables and help by the models themselves. The main characteristic of the RNG system in Flax NNX is that its **explicit**, its **order-based**, and uses **dynamic counters**. This is a bit different from [Flax Linen's RNG system](https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html) which is (path + order)-based and uses a static counters." + "Random state in Flax NNX is radically simplified compared to systems like Haiku/Flax Linen in that it defines \"random state as object state\". In essence, this means that random state is just another type of state, it's stored in Variables and held by the models themselves. The main characteristic of the RNG system in Flax NNX is that its **explicit**, its **order-based**, and uses **dynamic counters**. This is a bit different from [Flax Linen's RNG system](https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html) which is (path + order)-based and uses a static counters." ] }, { @@ -234,7 +234,7 @@ "source": [ "## Filtering random state\n", "\n", - "Random state can be manipulated using [Filters](https://flax-nnx.readthedocs.io/en/latest/guides/filters_guide.html) just like any other type of state. It can be filtered using types (`RngState`, `RngKey`, `RngCount`) or using strings corresponding to the stream names (see [The Filter DSL](https://flax-nnx.readthedocs.io/en/latest/guides/filters_guide.html#the-filter-dsl)). Here's an example using `nnx.state` with various filters to select different substates of the `Rngs` inside a `Model`:" + "Random state can be manipulated using [Filters](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) just like any other type of state. It can be filtered using types (`RngState`, `RngKey`, `RngCount`) or using strings corresponding to the stream names (see [The Filter DSL](https://flax.readthedocs.io/en/latest/guides/filters_guide.html#the-filter-dsl)). 
Here's an example using `nnx.state` with various filters to select different substates of the `Rngs` inside a `Model`:" ] }, { diff --git a/docs_nnx/guides/randomness.md b/docs_nnx/guides/randomness.md index 815567b4..c4b9623e 100644 --- a/docs_nnx/guides/randomness.md +++ b/docs_nnx/guides/randomness.md @@ -10,7 +10,7 @@ jupytext: # Randomness -Random state in Flax NNX is radically simplified compared to systems like Haiku/Flax Linen in that it defines "random state as object state". In essence, this means that random state is just another type of state, and is stored in Variables and help by the models themselves. The main characteristic of the RNG system in Flax NNX is that its **explicit**, its **order-based**, and uses **dynamic counters**. This is a bit different from [Flax Linen's RNG system](https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html) which is (path + order)-based and uses a static counters. +Random state in Flax NNX is radically simplified compared to systems like Haiku/Flax Linen in that it defines "random state as object state". In essence, this means that random state is just another type of state, it's stored in Variables and held by the models themselves. The main characteristic of the RNG system in Flax NNX is that its **explicit**, its **order-based**, and uses **dynamic counters**. This is a bit different from [Flax Linen's RNG system](https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html) which is (path + order)-based and uses a static counters. ```{code-cell} ipython3 from flax import nnx @@ -99,7 +99,7 @@ As shown above, a key from the `default` stream can also be generated by calling ## Filtering random state -Random state can be manipulated using [Filters](https://flax-nnx.readthedocs.io/en/latest/guides/filters_guide.html) just like any other type of state. It can be filtered using types (`RngState`, `RngKey`, `RngCount`) or using strings corresponding to the stream names (see [The Filter DSL](https://flax-nnx.readthedocs.io/en/latest/guides/filters_guide.html#the-filter-dsl)). Here's an example using `nnx.state` with various filters to select different substates of the `Rngs` inside a `Model`: +Random state can be manipulated using [Filters](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) just like any other type of state. It can be filtered using types (`RngState`, `RngKey`, `RngCount`) or using strings corresponding to the stream names (see [The Filter DSL](https://flax.readthedocs.io/en/latest/guides/filters_guide.html#the-filter-dsl)). Here's an example using `nnx.state` with various filters to select different substates of the `Rngs` inside a `Model`: ```{code-cell} ipython3 model = Model(nnx.Rngs(params=0, dropout=1)) diff --git a/docs_nnx/guides/surgery.ipynb b/docs_nnx/guides/surgery.ipynb index 00a1839e..b179f681 100644 --- a/docs_nnx/guides/surgery.ipynb +++ b/docs_nnx/guides/surgery.ipynb @@ -6,15 +6,13 @@ "source": [ "# Model surgery\n", "\n", - "> **Attention**: This page relates to the new Flax NNX API.\n", + "Model surgery is an act of making modifications on an existing neural network's building blocks and parameters, such as layer replacement, parameter or state manipulation, or even \"monkey patching\". 
In this guide, you will learn how to perform model surgery in Flax NNX using several real-world scenarios:\n", "\n", - "In this guide you will learn how to do model surgery with Flax NNX with several real-scenario use cases:\n", + "* __Pythonic `nnx.Module` manipulation__: Using Pythonic ways to manipulate sub-`Module`s given a model.\n", "\n", - "* __Pythonic module manipulation__: Pythonic ways to manipulate sub-modules given a model.\n", + "* __Manipulation of an abstract model or state__: A key trick for playing with `flax.nnx.Module`s and states without memory allocation.\n", "\n", - "* __Manipulating an abstract model or state__: A key trick to play with Flax NNX modules and states without memory allocation.\n", - "\n", - "* __Checkpoint surgery: From a raw state to model__: How to manipulate parameter states when they are incompatible with existing model code.\n", + "* __Checkpoint surgery from a raw state to model__: How to manipulate parameter states when they are incompatible with existing model code.\n", "\n", "* __Partial initialization__: How to initialize only a part of the model from scratch using a naive method or a memory-efficient method." ] @@ -63,11 +61,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Pythonic module manipulation\n", + "## Pythonic `nnx.Module` manipulation\n", + "\n", + "It is easier to perform model surgery when:\n", "\n", - "Doing model surgery is easiest when you already have a fully fleshed-out model loaded with correct parameters, and you don't intend to change your model definition code.\n", + "1) You already have a fully fleshed-out model loaded with correct parameters; and\n", + "2) You don't intend to change your model definition code.\n", "\n", - "You can perform a variety of Pythonic operations on its sub-modules, such as sub-module swapping, module sharing, variable sharing, and monkey-patching:" + "You can perform a variety of Pythonic operations on its sub-`Module`s, such as sub-`Module` swapping, `Module` sharing, variable sharing, and monkey-patching:" ] }, { @@ -80,25 +81,25 @@ "x = jax.random.normal(jax.random.key(42), (3, 4))\n", "np.testing.assert_allclose(model(x), model.linear2(model.linear1(x)))\n", "\n", - "# Sub-module swapping\n", + "# Sub-`Module` swapping.\n", "original1, original2 = model.linear1, model.linear2\n", "model.linear1, model.linear2 = model.linear2, model.linear1\n", "np.testing.assert_allclose(model(x), original1(original2(x)))\n", "\n", - "# Module sharing (tying all weights)\n", + "# `Module` sharing (tying all weights together).\n", "model = TwoLayerMLP(4, rngs=nnx.Rngs(0))\n", "model.linear2 = model.linear1\n", "assert not hasattr(nnx.state(model), 'linear2')\n", "np.testing.assert_allclose(model(x), model.linear1(model.linear1(x)))\n", "\n", - "# Variable sharing (weight-tying)\n", + "# Variable sharing (weight-tying).\n", "model = TwoLayerMLP(4, rngs=nnx.Rngs(0))\n", "model.linear1.kernel = model.linear2.kernel # the bias parameter is kept separate\n", "assert hasattr(nnx.state(model), 'linear2')\n", "assert hasattr(nnx.state(model)['linear2'], 'bias')\n", "assert not hasattr(nnx.state(model)['linear2'], 'kernel')\n", "\n", - "# Monkey-patching\n", + "# Monkey-patching.\n", "model = TwoLayerMLP(4, rngs=nnx.Rngs(0))\n", "def awesome_layer(x): return x\n", "model.linear2 = awesome_layer\n", @@ -111,13 +112,14 @@ "source": [ "## Creating an abstract model or state without memory allocation\n", "\n", - "For more complex model surgery, a key technique is creating and manipulating an abstract 
model or state without allocating any real parameter data. This makes trial iteration faster and removes any concern on memory constraints.\n", + "To do more complex model surgery, the key technique you can use is creating and manipulating an abstract model or state without allocating any real parameter data. This makes trial iteration faster and removes any concern on memory constraints.\n", + "\n", + "To create an abstract model:\n", "\n", - "To create an abstract model,\n", "* Create a function that returns a valid Flax NNX model; and\n", "* Run `nnx.eval_shape` (not `jax.eval_shape`) upon it.\n", "\n", - "Now you can use `nnx.split` as usual to get its abstract state. Note that all the fields that should be `jax.Array` in a real model are now an abstract `jax.ShapeDtypeStruct` with only shape/dtype/sharding information." + "Now you can use `nnx.split` as usual to get its abstract state. Note that all fields that should be `jax.Array`s in a real model are now of an abstract `jax.ShapeDtypeStruct` type with only shape/dtype/sharding information." ] }, { @@ -164,7 +166,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "When you fill every `VariableState` leaf's `value`s with real jax arrays, the abstract model becomes equivalent to a real model." + "When you fill every `nnx.VariableState` pytree leaf's `value` attributes with real `jax.Array`s, the abstract model becomes equivalent to a real model." ] }, { @@ -188,9 +190,11 @@ "source": [ "## Checkpoint surgery\n", "\n", - "With the abstract state technique in hand, you can do arbitrary manipulation on any checkpoint (or runtime parameter pytree) to make them fit with your given model code, then call `nnx.update` to merge them.\n", + "With the abstract state technique in hand, you can perform arbitrary manipulation on any checkpoint - or runtime parameter pytree - to make them fit with your given model code, and then call `nnx.update` to merge them.\n", + "\n", + "This can be helpful if you are trying to significantly change the model code - for example, when migrating from Flax Linen to Flax NNX - and old weights are no longer naturally compatible.\n", "\n", - "This can be helpful when you are trying to change model code significantly (for example, when migrating from Flax Linen to Flax NNX), and old weights are no longer naturally compatible. Let's run a simple example here:" + "Let's run a simple example here:" ] }, { @@ -209,7 +213,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "In this new model, the sub-modules are renamed from `linear(1|2)` to `layer(1|2)`. Since the pytree structure changed, it's impossible to load the old checkpoint with the new model state structure:" + "In this new model, the sub-`Module`s are renamed from `linear(1|2)` to `layer(1|2)`. Since the pytree structure has changed, it is impossible to directly load the old checkpoint with the new model state structure:" ] }, { @@ -221,15 +225,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "This will throw error: : 'layer1'\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/ivyzheng/envs/py310/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py:1401: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. 
Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.\n", - " warnings.warn(\n" + "This will throw error: : Dict key mismatch; expected keys: ['linear1', 'linear2']; dict: {'layer1': {'bias': {'value': RestoreArgs(restore_type=None, dtype=None)}, 'kernel': {'value': RestoreArgs(restore_type=None, dtype=None)}}, 'layer2': {'bias': {'value': RestoreArgs(restore_type=None, dtype=None)}, 'kernel': {'value': RestoreArgs(restore_type=None, dtype=None)}}}.\n" ] } ], @@ -255,7 +251,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "But you can load the parameter tree as a raw dictionary, make the renames, and generate a new state that is guaranteed to be compatible with your new model definition." + "However, you can load the parameter pytree as a raw dictionary, perform the renames, and generate a new state that is guaranteed to be compatible with your new model definition." ] }, { @@ -267,45 +263,46 @@ "name": "stdout", "output_type": "stream", "text": [ - "{'linear1': {'bias': {'raw_value': Array([0., 0., 0., 0.], dtype=float32)},\n", - " 'kernel': {'raw_value': Array([[-0.80345297, -0.34071913, -0.9408296 , 0.01005968],\n", + "{'linear1': {'bias': {'value': Array([0., 0., 0., 0.], dtype=float32)},\n", + " 'kernel': {'value': Array([[-0.80345297, -0.34071913, -0.9408296 , 0.01005968],\n", " [ 0.26146442, 1.1247735 , 0.54563737, -0.374164 ],\n", " [ 1.0281805 , -0.6798804 , -0.1488401 , 0.05694951],\n", " [-0.44308168, -0.60587114, 0.434087 , -0.40541083]], dtype=float32)}},\n", - " 'linear2': {'bias': {'raw_value': Array([0., 0., 0., 0.], dtype=float32)},\n", - " 'kernel': {'raw_value': Array([[ 0.21010089, 0.8289361 , 0.04589564, 0.5422644 ],\n", + " 'linear2': {'bias': {'value': Array([0., 0., 0., 0.], dtype=float32)},\n", + " 'kernel': {'value': Array([[ 0.21010089, 0.8289361 , 0.04589564, 0.5422644 ],\n", " [ 0.41914317, 0.84359694, -0.47937787, -0.49135214],\n", " [-0.46072108, 0.4630125 , 0.39276958, -0.9441406 ],\n", " [-0.6690758 , -0.18474789, -0.57622856, 0.4821079 ]], dtype=float32)}}}\n" ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/ivyzheng/envs/flax-head/lib/python3.11/site-packages/orbax/checkpoint/type_handlers.py:1439: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. 
Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.\n", + " warnings.warn(\n" + ] } ], "source": [ - "def module_from_variables_dict(module_factory, variables, map_key_fn):\n", - " if map_key_fn is None:\n", - " map_key_fn = lambda path: path\n", - " mdl = nnx.eval_shape(module_factory)\n", - " graph_def, state = nnx.split(mdl)\n", - " state = state.flat_state()\n", - " for path, val in flax.traverse_util.flatten_dict(variables).items():\n", - " mapped_path = map_key_fn(path)\n", - " if mapped_path not in state:\n", - " raise ValueError(f\"{mapped_path} doesn't exist in {state.keys()}\")\n", - " state[mapped_path].value = val\n", - " state = nnx.State.from_flat_path(state)\n", - " return nnx.merge(graph_def, state)\n", - "\n", - "# Make your local change on the checkpoint.\n", - "raw = checkpointer.restore('/tmp/nnx-surgery-state')\n", - "pprint(raw)\n", - "raw['layer1'], raw['layer2'] = raw['linear1'], raw['linear2']\n", - "del raw['linear1'], raw['linear2']\n", - "\n", - "restored_model = module_from_variables_dict(\n", - " lambda: nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0))),\n", - " raw,\n", - " lambda path: path[:-1] if path[-1] == 'raw_value' else path\n", - ")\n", + "def process_raw_dict(raw_state_dict):\n", + " flattened = nnx.traversals.flatten_mapping(raw_state_dict)\n", + " # Cut the '.value' postfix on every leaf path.\n", + " flattened = {(path[:-1] if path[-1] == 'value' else path): value\n", + " for path, value in flattened.items()}\n", + " return nnx.traversals.unflatten_mapping(flattened)\n", + "\n", + "# Make your local change on the checkpoint dictionary.\n", + "raw_dict = checkpointer.restore('/tmp/nnx-surgery-state')\n", + "pprint(raw_dict)\n", + "raw_dict['layer1'] = raw_dict.pop('linear1')\n", + "raw_dict['layer2'] = raw_dict.pop('linear2')\n", + "\n", + "# Fit it into the model state.\n", + "abs_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))\n", + "graph_def, state = nnx.split(abs_model)\n", + "state.replace_by_pure_dict(process_raw_dict(raw_dict))\n", + "restored_model = nnx.merge(graph_def, state)\n", "\n", "np.testing.assert_allclose(restored_model(jnp.ones((3, 4))), old_model(jnp.ones((3, 4))))" ] @@ -316,7 +313,10 @@ "source": [ "## Partial initialization\n", "\n", - "In some cases (such as with LoRA), you may want to randomly-initialize only *part of* your model parameters. This can be achieved through naive partial initialization or memory-efficient partial initialization." + "In some cases - such as with LoRA (Low-Rank Adaption) - you may want to randomly-initialize only *part of* your model parameters. This can be achieved through:\n", + "\n", + "- Naive partial initialization; or\n", + "- Memory-efficient partial initialization." ] }, { @@ -325,9 +325,9 @@ "source": [ "### Naive partial initialization\n", "\n", - "You can simply initialize the whole model, then swap pre-trained parameters in. But this approach could allocate additional memory midway, if your modification requires re-creating module parameters that you will later discard. See this example below.\n", + "To do naive partial initialization, you can just initialize the whole model, then swap the pre-trained parameters in. However, this approach may allocate additional memory midway if your modification requires re-creating module parameters that you will later discard. 
Below is an example of this.\n", "\n", - "> Note: You can use `jax.live_arrays()` to check all the arrays live in memory at any given time. This call can be messed up when you run a single notebook cell multiple times (due to garbage-collecting old python variables), but restarting the kernel and running from scratch will always yield same output." + "> **Note:** You can use `jax.live_arrays()` to check all the arrays live in memory at any given time. This call can be “messed up” when you run a single Jupyter notebook cell multiple times (due to garbage-collection of old Python variables). However, restarting the Python kernel in the notebook and running the code from scratch will always yield the same output." ] }, { @@ -339,9 +339,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "Number of jax arrays in memory at start: 34\n", - "Number of jax arrays in memory midway: 38 (4 new created in LoRALinear - kernel, bias, lora_a & lora_b)\n", - "Number of jax arrays in memory at end: 36 (2 discarded - only lora_a & lora_b are used in model)\n" + "Number of jax arrays in memory at start: 38\n", + "Number of jax arrays in memory midway: 42 (4 new created in LoRALinear - kernel, bias, lora_a & lora_b)\n", + "Number of jax arrays in memory at end: 40 (2 discarded - only lora_a & lora_b are used in model)\n" ] } ], @@ -351,8 +351,8 @@ "\n", "simple_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(42)))\n", "print(f'Number of jax arrays in memory at start: {len(jax.live_arrays())}')\n", - "# On this line, extra kernel and bias is created inside the new LoRALinear!\n", - "# They are wasted since you are going to use the kernel and bias in `old_state` anyway.\n", + "# In this line, extra kernel and bias is created inside the new LoRALinear!\n", + "# They are wasted, because you are going to use the kernel and bias in `old_state` anyway.\n", "simple_model.linear1 = nnx.LoRALinear(4, 4, lora_rank=3, rngs=nnx.Rngs(42))\n", "print(f'Number of jax arrays in memory midway: {len(jax.live_arrays())}'\n", " ' (4 new created in LoRALinear - kernel, bias, lora_a & lora_b)')\n", @@ -367,7 +367,7 @@ "source": [ "### Memory-efficient partial initialization\n", "\n", - "Use `nnx.jit`'s efficiently compiled code to make sure only the state parameters you need are initialized:" + "To do memory-efficient partial initialization, use `nnx.jit`'s efficiently compiled code to make sure only the state parameters you need are initialized:" ] }, { @@ -379,8 +379,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "Number of jax arrays in memory at start: 40\n", - "Number of jax arrays in memory at end: 42 (2 new created - lora_a and lora_b)\n" + "Number of jax arrays in memory at start: 44\n", + "Number of jax arrays in memory at end: 46 (2 new created - lora_a and lora_b)\n" ] } ], @@ -389,7 +389,7 @@ "old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0)))\n", "\n", "# Use `nnx.jit` (which wraps `jax.jit`) to automatically skip unused arrays - memory efficient!\n", - "@functools.partial(nnx.jit, donate_argnums=0, static_argnums=1)\n", + "@nnx.jit(donate_argnums=0)\n", "def partial_init(old_state, rngs):\n", " model = TwoLayerMLP(4, rngs=rngs)\n", " # Create a new state.\n", @@ -398,12 +398,26 @@ " nnx.update(model, old_state)\n", " return model\n", "\n", - "print(f'Number of jax arrays in memory at start: {len(jax.live_arrays())}')\n", + "print(f'Number of JAX Arrays in memory at start: {len(jax.live_arrays())}')\n", "# Note that `old_state` will be deleted after this `partial_init` call.\n", 
"good_model = partial_init(old_state, nnx.Rngs(42))\n", - "print(f'Number of jax arrays in memory at end: {len(jax.live_arrays())}'\n", + "print(f'Number of JAX Arrays in memory at end: {len(jax.live_arrays())}'\n", " ' (2 new created - lora_a and lora_b)')" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -420,7 +434,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.14" + "version": "3.11.9" } }, "nbformat": 4, diff --git a/docs_nnx/guides/surgery.md b/docs_nnx/guides/surgery.md index e829f850..904eb7cf 100644 --- a/docs_nnx/guides/surgery.md +++ b/docs_nnx/guides/surgery.md @@ -10,15 +10,13 @@ jupytext: # Model surgery -> **Attention**: This page relates to the new Flax NNX API. +Model surgery is an act of making modifications on an existing neural network's building blocks and parameters, such as layer replacement, parameter or state manipulation, or even "monkey patching". In this guide, you will learn how to perform model surgery in Flax NNX using several real-world scenarios: -In this guide you will learn how to do model surgery with Flax NNX with several real-scenario use cases: +* __Pythonic `nnx.Module` manipulation__: Using Pythonic ways to manipulate sub-`Module`s given a model. -* __Pythonic module manipulation__: Pythonic ways to manipulate sub-modules given a model. +* __Manipulation of an abstract model or state__: A key trick for playing with `flax.nnx.Module`s and states without memory allocation. -* __Manipulating an abstract model or state__: A key trick to play with Flax NNX modules and states without memory allocation. - -* __Checkpoint surgery: From a raw state to model__: How to manipulate parameter states when they are incompatible with existing model code. +* __Checkpoint surgery from a raw state to model__: How to manipulate parameter states when they are incompatible with existing model code. * __Partial initialization__: How to initialize only a part of the model from scratch using a naive method or a memory-efficient method. @@ -52,36 +50,39 @@ class TwoLayerMLP(nnx.Module): return self.linear2(x) ``` -## Pythonic module manipulation +## Pythonic `nnx.Module` manipulation + +It is easier to perform model surgery when: -Doing model surgery is easiest when you already have a fully fleshed-out model loaded with correct parameters, and you don't intend to change your model definition code. +1) You already have a fully fleshed-out model loaded with correct parameters; and +2) You don't intend to change your model definition code. -You can perform a variety of Pythonic operations on its sub-modules, such as sub-module swapping, module sharing, variable sharing, and monkey-patching: +You can perform a variety of Pythonic operations on its sub-`Module`s, such as sub-`Module` swapping, `Module` sharing, variable sharing, and monkey-patching: ```{code-cell} ipython3 model = TwoLayerMLP(4, rngs=nnx.Rngs(0)) x = jax.random.normal(jax.random.key(42), (3, 4)) np.testing.assert_allclose(model(x), model.linear2(model.linear1(x))) -# Sub-module swapping +# Sub-`Module` swapping. original1, original2 = model.linear1, model.linear2 model.linear1, model.linear2 = model.linear2, model.linear1 np.testing.assert_allclose(model(x), original1(original2(x))) -# Module sharing (tying all weights) +# `Module` sharing (tying all weights together). 
model = TwoLayerMLP(4, rngs=nnx.Rngs(0)) model.linear2 = model.linear1 assert not hasattr(nnx.state(model), 'linear2') np.testing.assert_allclose(model(x), model.linear1(model.linear1(x))) -# Variable sharing (weight-tying) +# Variable sharing (weight-tying). model = TwoLayerMLP(4, rngs=nnx.Rngs(0)) model.linear1.kernel = model.linear2.kernel # the bias parameter is kept separate assert hasattr(nnx.state(model), 'linear2') assert hasattr(nnx.state(model)['linear2'], 'bias') assert not hasattr(nnx.state(model)['linear2'], 'kernel') -# Monkey-patching +# Monkey-patching. model = TwoLayerMLP(4, rngs=nnx.Rngs(0)) def awesome_layer(x): return x model.linear2 = awesome_layer @@ -90,13 +91,14 @@ np.testing.assert_allclose(model(x), model.linear1(x)) ## Creating an abstract model or state without memory allocation -For more complex model surgery, a key technique is creating and manipulating an abstract model or state without allocating any real parameter data. This makes trial iteration faster and removes any concern on memory constraints. +To do more complex model surgery, the key technique you can use is creating and manipulating an abstract model or state without allocating any real parameter data. This makes trial iteration faster and removes any concern on memory constraints. + +To create an abstract model: -To create an abstract model, * Create a function that returns a valid Flax NNX model; and * Run `nnx.eval_shape` (not `jax.eval_shape`) upon it. -Now you can use `nnx.split` as usual to get its abstract state. Note that all the fields that should be `jax.Array` in a real model are now an abstract `jax.ShapeDtypeStruct` with only shape/dtype/sharding information. +Now you can use `nnx.split` as usual to get its abstract state. Note that all fields that should be `jax.Array`s in a real model are now of an abstract `jax.ShapeDtypeStruct` type with only shape/dtype/sharding information. ```{code-cell} ipython3 abs_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0))) @@ -104,7 +106,7 @@ gdef, abs_state = nnx.split(abs_model) pprint(abs_state) ``` -When you fill every `VariableState` leaf's `value`s with real jax arrays, the abstract model becomes equivalent to a real model. +When you fill every `nnx.VariableState` pytree leaf's `value` attributes with real `jax.Array`s, the abstract model becomes equivalent to a real model. ```{code-cell} ipython3 model = TwoLayerMLP(4, rngs=nnx.Rngs(0)) @@ -118,9 +120,11 @@ np.testing.assert_allclose(abs_model(x), model(x)) # They are equivalent now! ## Checkpoint surgery -With the abstract state technique in hand, you can do arbitrary manipulation on any checkpoint (or runtime parameter pytree) to make them fit with your given model code, then call `nnx.update` to merge them. +With the abstract state technique in hand, you can perform arbitrary manipulation on any checkpoint - or runtime parameter pytree - to make them fit with your given model code, and then call `nnx.update` to merge them. + +This can be helpful if you are trying to significantly change the model code - for example, when migrating from Flax Linen to Flax NNX - and old weights are no longer naturally compatible. -This can be helpful when you are trying to change model code significantly (for example, when migrating from Flax Linen to Flax NNX), and old weights are no longer naturally compatible. 
Let's run a simple example here: +Let's run a simple example here: ```{code-cell} ipython3 # Save a version of model into a checkpoint @@ -129,7 +133,7 @@ old_model = TwoLayerMLP(4, rngs=nnx.Rngs(0)) checkpointer.save(f'/tmp/nnx-surgery-state', nnx.state(model), force=True) ``` -In this new model, the sub-modules are renamed from `linear(1|2)` to `layer(1|2)`. Since the pytree structure changed, it's impossible to load the old checkpoint with the new model state structure: +In this new model, the sub-`Module`s are renamed from `linear(1|2)` to `layer(1|2)`. Since the pytree structure has changed, it is impossible to directly load the old checkpoint with the new model state structure: ```{code-cell} ipython3 class ModifiedTwoLayerMLP(nnx.Module): @@ -149,49 +153,45 @@ except Exception as e: print(f'This will throw error: {type(e)}: {e}') ``` -But you can load the parameter tree as a raw dictionary, make the renames, and generate a new state that is guaranteed to be compatible with your new model definition. +However, you can load the parameter pytree as a raw dictionary, perform the renames, and generate a new state that is guaranteed to be compatible with your new model definition. ```{code-cell} ipython3 -def module_from_variables_dict(module_factory, variables, map_key_fn): - if map_key_fn is None: - map_key_fn = lambda path: path - mdl = nnx.eval_shape(module_factory) - graph_def, state = nnx.split(mdl) - state = state.flat_state() - for path, val in flax.traverse_util.flatten_dict(variables).items(): - mapped_path = map_key_fn(path) - if mapped_path not in state: - raise ValueError(f"{mapped_path} doesn't exist in {state.keys()}") - state[mapped_path].value = val - state = nnx.State.from_flat_path(state) - return nnx.merge(graph_def, state) - -# Make your local change on the checkpoint. -raw = checkpointer.restore('/tmp/nnx-surgery-state') -pprint(raw) -raw['layer1'], raw['layer2'] = raw['linear1'], raw['linear2'] -del raw['linear1'], raw['linear2'] - -restored_model = module_from_variables_dict( - lambda: nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0))), - raw, - lambda path: path[:-1] if path[-1] == 'raw_value' else path -) +def process_raw_dict(raw_state_dict): + flattened = nnx.traversals.flatten_mapping(raw_state_dict) + # Cut the '.value' postfix on every leaf path. + flattened = {(path[:-1] if path[-1] == 'value' else path): value + for path, value in flattened.items()} + return nnx.traversals.unflatten_mapping(flattened) + +# Make your local change on the checkpoint dictionary. +raw_dict = checkpointer.restore('/tmp/nnx-surgery-state') +pprint(raw_dict) +raw_dict['layer1'] = raw_dict.pop('linear1') +raw_dict['layer2'] = raw_dict.pop('linear2') + +# Fit it into the model state. +abs_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0))) +graph_def, state = nnx.split(abs_model) +state.replace_by_pure_dict(process_raw_dict(raw_dict)) +restored_model = nnx.merge(graph_def, state) np.testing.assert_allclose(restored_model(jnp.ones((3, 4))), old_model(jnp.ones((3, 4)))) ``` ## Partial initialization -In some cases (such as with LoRA), you may want to randomly-initialize only *part of* your model parameters. This can be achieved through naive partial initialization or memory-efficient partial initialization. +In some cases - such as with LoRA (Low-Rank Adaption) - you may want to randomly-initialize only *part of* your model parameters. This can be achieved through: + +- Naive partial initialization; or +- Memory-efficient partial initialization. 
+++ ### Naive partial initialization -You can simply initialize the whole model, then swap pre-trained parameters in. But this approach could allocate additional memory midway, if your modification requires re-creating module parameters that you will later discard. See this example below. +To do naive partial initialization, you can just initialize the whole model, then swap the pre-trained parameters in. However, this approach may allocate additional memory midway if your modification requires re-creating module parameters that you will later discard. Below is an example of this. -> Note: You can use `jax.live_arrays()` to check all the arrays live in memory at any given time. This call can be messed up when you run a single notebook cell multiple times (due to garbage-collecting old python variables), but restarting the kernel and running from scratch will always yield same output. +> **Note:** You can use `jax.live_arrays()` to check all the arrays live in memory at any given time. This call can be “messed up” when you run a single Jupyter notebook cell multiple times (due to garbage-collection of old Python variables). However, restarting the Python kernel in the notebook and running the code from scratch will always yield the same output. ```{code-cell} ipython3 # Some pretrained model state @@ -199,8 +199,8 @@ old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0))) simple_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(42))) print(f'Number of jax arrays in memory at start: {len(jax.live_arrays())}') -# On this line, extra kernel and bias is created inside the new LoRALinear! -# They are wasted since you are going to use the kernel and bias in `old_state` anyway. +# In this line, extra kernel and bias is created inside the new LoRALinear! +# They are wasted, because you are going to use the kernel and bias in `old_state` anyway. simple_model.linear1 = nnx.LoRALinear(4, 4, lora_rank=3, rngs=nnx.Rngs(42)) print(f'Number of jax arrays in memory midway: {len(jax.live_arrays())}' ' (4 new created in LoRALinear - kernel, bias, lora_a & lora_b)') @@ -211,14 +211,14 @@ print(f'Number of jax arrays in memory at end: {len(jax.live_arrays())}' ### Memory-efficient partial initialization -Use `nnx.jit`'s efficiently compiled code to make sure only the state parameters you need are initialized: +To do memory-efficient partial initialization, use `nnx.jit`'s efficiently compiled code to make sure only the state parameters you need are initialized: ```{code-cell} ipython3 # Some pretrained model state old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0))) # Use `nnx.jit` (which wraps `jax.jit`) to automatically skip unused arrays - memory efficient! -@functools.partial(nnx.jit, donate_argnums=0, static_argnums=1) +@nnx.jit(donate_argnums=0) def partial_init(old_state, rngs): model = TwoLayerMLP(4, rngs=rngs) # Create a new state. @@ -227,9 +227,17 @@ def partial_init(old_state, rngs): nnx.update(model, old_state) return model -print(f'Number of jax arrays in memory at start: {len(jax.live_arrays())}') +print(f'Number of JAX Arrays in memory at start: {len(jax.live_arrays())}') # Note that `old_state` will be deleted after this `partial_init` call. 
good_model = partial_init(old_state, nnx.Rngs(42)) -print(f'Number of jax arrays in memory at end: {len(jax.live_arrays())}' +print(f'Number of JAX Arrays in memory at end: {len(jax.live_arrays())}' ' (2 new created - lora_a and lora_b)') ``` + +```{code-cell} ipython3 + +``` + +```{code-cell} ipython3 + +``` diff --git a/docs_nnx/guides/transforms.ipynb b/docs_nnx/guides/transforms.ipynb index edf8ef8e..4ad1e48a 100644 --- a/docs_nnx/guides/transforms.ipynb +++ b/docs_nnx/guides/transforms.ipynb @@ -5,16 +5,18 @@ "id": "962be290", "metadata": {}, "source": [ - "# Transforms\n", - "JAX transformations in general operate on [Pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) of arrays\n", - "and abide by value semantics, this presents a challenge for Flax NNX which represents Modules as regular Python objects\n", - "that follow reference semantics. To address this, Flax NNX introduces its own set of transformations that extend JAX\n", - "transformations to allow Modules and other Flax NNX objects to be passed in and out of transformations while preserving\n", + "# Transformations\n", + "\n", + "In general, JAX transformations (transforms) operate on [pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) of `jax.Array`s\n", + "and abide by value semantics. This presents a challenge for Flax NNX, which represents `nnx.Module`s as regular Python objects\n", + "that follow reference semantics. To address this, Flax NNX introduced its own set of transforms that extend JAX\n", + "transforms to allow `nnx.Module`s and other Flax NNX objects to be passed in and out of transforms while preserving\n", "reference semantics.\n", "\n", - "Flax NNX transformations should feel quite familar to those who have used JAX transformations before as they use the\n", - "same APIs and behave like the JAX transformations when only working with Pytrees of arrays. However, when working with\n", + "Flax NNX transforms should feel quite familiar if you have used JAX transforms before. They use the\n", + "same APIs and behave like the JAX transforms when only working with pytrees of `jax.Array`s. However, when working with\n", "Flax NNX objects, they allow Python's reference semantics to be preserved for these objects, this includes:\n", + "\n", "* Preserving shared references across multiple objects in the inputs and outputs of the transformation.\n", "* Propagating any state changes made to the objects inside the transformation to the objects outside the transformation.\n", "* Enforcing consistency of how objects are transformed when aliases are present across multiple inputs and outputs." @@ -37,11 +39,12 @@ "id": "b44fb248", "metadata": {}, "source": [ - "Throughout this guide we will use `nnx.vmap` as a case study to demonstrate how Flax NNX transformations work but the principles\n", - "outlined here extend to all transformations.\n", + "Throughout this guide, `nnx.vmap` is used as a case study to demonstrate how Flax NNX transforms work. However, the principles\n", + "outlined in this document extends to all transforms.\n", "\n", - "## Basic Example\n", - "To begin, let's look at a simple example of using `nnx.vmap` to extend an elementwise `vector_dot` function to work on\n", + "## Basic example\n", + "\n", + "To begin, let's look at a simple example of using `nnx.vmap` to extend an element wise `vector_dot` function to work on\n", "batched inputs. 
We will define a `Weights` Module with no methods to hold some parameters, these weights will be passed\n", "as an input to the `vector_dot` function along with some data. Both the weights and data will be batched on axis `0` and we will use\n", "`nnx.vmap` to apply `vector_dot` to each batch element, and the result will be a batched on axis `1`:" @@ -63,7 +66,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -75,7 +78,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -112,10 +115,10 @@ "id": "d2b222eb", "metadata": {}, "source": [ - "Notice that `in_axes` interacts naturally with the `Weights` Module, treating it as if it where a Pytree of arrays. Prefix patterns are also allowed, `in_axes=(0, 0)` would've also worked in this case.\n", + "Notice that `in_axes` interacts naturally with the `Weights` Module, treating it as if it were a pytree of `jax.Array`s. Prefix patterns are also allowed, so `in_axes=(0, 0)` would have also worked in this case.\n", "\n", - "Objects are also allowed as outputs of Flax NNX transformations, this can be useful to transform initializers. For example,\n", - "we can define a `create_weights` function to create an single `Weights` Module and use `nnx.vmap` to create a stack of\n", + "Objects are also allowed as outputs of Flax NNX transforms, which can be useful to transform initializers. For example,\n", + "you can define a `create_weights` function to create an single `Weights` `nnx.Module`, and use `nnx.vmap` to create a stack of\n", "`Weights` with the same shapes as before:" ] }, @@ -128,7 +131,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -140,7 +143,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -167,8 +170,9 @@ "id": "fac3dca9", "metadata": {}, "source": [ - "## Transforming Methods\n", - "Methods in Python are just functions that take the instance as the first argument, this means that you can decorate methods from `Module` and other Flax NNX subtypes. For example, we can refactor `Weights` from the previous example and decorate `__init__` with `vmap` to do the work of `create_weights`, and add a `__call__` method and decorate it with `vmap` to do the work of `vector_dot`:" + "## Transforming methods\n", + "\n", + "Methods in Python are just functions that take the instance as the first argument, this means that you can decorate methods from `Module` and other Flax NNX subtypes. For example, we can refactor `Weights` from the previous example and decorate `__init__` with `vmap` to do the work of `create_weights`, and add a `__call__` method and decorate it with `@nnx.vmap` to do the work of `vector_dot`:" ] }, { @@ -187,7 +191,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -199,7 +203,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -236,7 +240,7 @@ "id": "13b52d61", "metadata": {}, "source": [ - "Throughout the rest of the guide we will focus on transforming individual functions, however, note all examples can easily be written in this method style." + "The rest of the guide will focus on transforming individual functions. But do note that all examples can be written in this method style." ] }, { @@ -245,8 +249,9 @@ "metadata": {}, "source": [ "## State propagation\n", - "So far our functions have been stateless. However, the real power of Flax NNX transformations comes when we have stateful functions since one of their main features is to propagate state changes to preserve reference semantics. Let's update our example by adding\n", - "a `count` attribute to `Weights` and incrementing it in the new `stateful_vector_dot` function." + "\n", + "So far our functions have been stateless. However, the real power of Flax NNX transforms comes when you have stateful functions, because one of their main features is to propagate state changes to preserve reference semantics. Let's update the previous example by adding\n", + "a `count` attribute to `Weights` and incrementing it in the new `stateful_vector_dot` function:" ] }, { @@ -300,7 +305,7 @@ "id": "322312ee", "metadata": {}, "source": [ - "After running `stateful_vector_dot` once we verify that the `count` attribute was correctly updated. Because Weights is vectorized, `count` was initialized as an `arange(10)`, and all of its elements were incremented by 1 inside the transformation. The most important part is that updates were propagated to the original `Weights` object outside the transformation. Nice!" + "After running `stateful_vector_dot` once, you verified that the `count` attribute was correctly updated. Because `Weights` was vectorized, `count` was initialized as an `arange(10)`, and all of its elements were incremented by `1` inside the transformation. The most important part is that updates were propagated to the original `Weights` object outside the transformation. Nice!" ] }, { @@ -309,9 +314,12 @@ "metadata": {}, "source": [ "### Graph updates propagation\n", - "JAX transformations see inputs as pytrees of arrays, and Flax NNX see inputs pytrees of arrays and Python references, where references form a graph. Flax NNX's state propagation machinery can track arbitrary updates to the objects as long as they're local to the inputs (updates to globals inside transforms are not supported). This means that you can modify graph structure as needed, including updating existing attributes, adding/deleting attributes, swapping attributes, sharing (new) references between objects, sharing Variables between objects, etc. The sky is the limit!\n", "\n", - "The following example demonstrates performing some arbitrary updates to the `Weights` object inside `nnx.vmap` and verifying that the updates are correctly propagated to the original `Weights` object outside the transformation." + "JAX transforms see inputs as pytrees of `jax.Array`s, and Flax NNX sees inputs as pytrees of `jax.Array`s and Python references, where references form a graph. 
Flax NNX's state propagation machinery can track arbitrary updates to the objects as long as they're local to the inputs (updates to globals inside transforms are not supported).\n", + "\n", + "This means that you can modify graph structure as needed, including updating existing attributes, adding/deleting attributes, swapping attributes, sharing (new) references between objects, sharing `nnx.Variable`s between objects, etc. Sky is the limit!\n", + "\n", + "The following example demonstrates performing some arbitrary updates to the `Weights` object inside `nnx.vmap`, and verifying that the updates are correctly propagated to the original `Weights` object outside the transformation:" ] }, { @@ -323,7 +331,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -335,7 +343,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -383,7 +391,7 @@ "> With great power comes great responsibility.\n", ">
\\- Uncle Ben\n", "\n", - "While this feature is very powerful, it must be used with care as it can clash with JAX's underlying assumptions for certain transformations. For example, `jit` expects the structure of the inputs to be stable in order to cache the compiled function, so changing the graph structure inside a `nnx.jit`-ed function cause continuous recompilations and performance degradation. On the other hand, `scan` only allows a fixed `carry` structure, so adding/removing substates declared as carry will cause an error." + "While this feature is very powerful, it must be used with care because it can clash with JAX's underlying assumptions for certain transforms. For example, `jit` expects the structure of the inputs to be stable in order to cache the compiled function, so changing the graph structure inside an `nnx.jit`-ed function causes continuous recompilations and performance degradation. On the other hand, `scan` only allows a fixed `carry` structure, so adding/removing sub-states declared as carry will cause an error." ] }, { @@ -391,21 +399,22 @@ "id": "0d11d191", "metadata": {}, "source": [ - "## Transforming Substates (Lift Types)\n", + "## Transforming sub-states (lift types)\n", "\n", - "Certain JAX transformation allow the use of pytree prefixes to specify how different parts of the inputs/outputs should be transformed. Flax NNX supports pytree prefixes for pytree structures but currently it doesn't have the notion of a prefix for graph objects. Instead, Flax NNX introduces the concept of `Lift Types` which allow specifying how different substates of an object should be transformed. Different transformations support different Lift Types, here is the list of currently supported Lift Types for each transformation:\n", + "Certain JAX transforms allow the use of pytree prefixes to specify how different parts of the inputs/outputs should be transformed. Flax NNX supports pytree prefixes for pytree structures but currently it doesn't have the notion of a prefix for graph objects. Instead, Flax NNX introduces the concept of “lift types” which allow specifying how different sub-states of an object should be transformed. Different transforms support different lift types, here is the list of currently supported FLax NNX lift types for each JAX transformation:\n", "\n", - "| Lift Type | Transforms |\n", + "| Lift type | JAX transforms |\n", "|------------------|-----------------------------------------|\n", "| `StateAxes` | `vmap`, `pmap`, `scan` |\n", - "| `StateSharding` | `jit`, `shard_map` |\n", + "| `StateSharding` | `jit`, `shard_map`* |\n", "| `DiffState` | `grad`, `value_and_grad`, `custom_vjp` |\n", "\n", - "> NOTE: `shard_map` is not yet implemented.\n", + "> **Note:** * Flax NNX `shard_map` has not been implemented yet at the time of writing this version of the document.\n", "\n", - "If we want to specify how to vectorize different substates of an object in `nnx.vmap`, we create a `StateAxes` which maps a set of substates via [Filters](https://flax-nnx.readthedocs.io/en/latest/guides/filters_guide.html) to their corresponding axes, and pass the `StateAxes` to `in_axes` and `out_axes` as if it were a pytree prefix. 
Let's use the previous `stateful_vector_dot` example and\n", - "vectorize only the `Param` variables and broadcast the `count` variable so we only keep a single count for all the batch elements.\n", - "To do this we will define a `StateAxes` with a filter that matches the `Param` variables and maps them to axis `0`, and all the `Count` variables to `None`, and pass this `StateAxes` to `in_axes` for the `Weights` object." + "To specify how to vectorize different sub-states of an object in `nnx.vmap`, the Flax team created a `nnx.StateAxes`. `StateAxes` maps a set of sub-states via Flax NNX [Filters](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) to their corresponding axes, and you can pass the `nnx.StateAxes` to `in_axes` and `out_axes` as if it/they were a pytree prefix.\n", + "\n", + "Let's use the previous `stateful_vector_dot` example and vectorize only the `nnx.Param` variables and broadcast the `count` variable so we only keep a single count for all the batch elements.\n", + "To do this we will define a `nnx.StateAxes` with a filter that matches the `nnx.Param` variables and maps them to axis `0`, and all the `Count` variables to `None`, and pass this `nnx.StateAxes` to `in_axes` for the `Weights` object." ] }, { @@ -458,7 +467,7 @@ "id": "1cfd87e1", "metadata": {}, "source": [ - "Here count is now a scalar since its not being vectorized. Also, note that `StateAxes` can only be used directly on Flax NNX objects, it cannot be used as a prefix for a pytree of objects." + "Here, `count` is now a scalar since it's not being vectorized. Also, note that `nnx.StateAxes` can only be used directly on Flax NNX objects, and it cannot be used as a prefix for a pytree of objects." ] }, { @@ -466,10 +475,13 @@ "id": "1c8bb104", "metadata": {}, "source": [ - "### Random State\n", - "In Flax NNX random state is just regular state. This means that its stored inside Modules that need it and its treated as any other type of state. This is a simplification over Flax Linen where random state was handled by a separate mechanism. In practice Modules simply need to keep a reference to a `Rngs` object that is passed to them during initialization, and use it to generate a unique key for each random operation. For the purposes of this guide, this means that random state can be transformed like any other type of state but we also need be aware of how the state is laid out so we can transform it correctly.\n", + "### Random state\n", + "\n", + "In Flax NNX, a random state is just a regular state. This means that it is stored inside `nnx.Module`s that need it, and it is treated as any other type of state. This is a simplification over Flax Linen, where a random state was handled by a separate mechanism. In practice `nnx.Module`s simply need to keep a reference to a `Rngs` object that is passed to them during initialization, and use it to generate a unique key for each random operation. For the purposes of this guide, this means that random state can be transformed like any other type of state but we also need be aware of how the state is laid out so we can transform it correctly.\n", + "\n", + "Suppose you want to change things up a bit and apply the same weights to all elements in the batch. But you also want to add different random noise to each element.\n", "\n", - "Let's suppose we want change things up a bit and apply the same weights to all elements in the batch but we want to add different random noise to each element. 
To do this we will add a `Rngs` attribute to `Weights`, created from a `seed` key argument passed during construction, this seed key must be `split` before hand so we can vectorize it succesfully. For pedagogical reasons, we will assign the seed key to a `noise` Stream and sample from it. To vectorize the RNG state we must configure `StateAxes` to map all `RngState` (base class for all variables in `Rngs`) to axis `0`, and `Param` and `Count` to `None`." + "To do this, you will add an `Rngs` attribute to `Weights`, created from a `seed` key argument passed during construction. This seed key must be `split` beforehand, so that you can vectorize it successfully. For pedagogical reasons, you will assign the seed key to a `noise` “stream” and sample from it. To vectorize the PRNG state, you must configure `nnx.StateAxes` to map all `RngState`s (a base class for all variables in `Rngs`) to axis `0`, and `nnx.Param` and `Count` to `None`." ] }, { @@ -488,7 +500,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -500,7 +512,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -547,7 +559,7 @@ "source": [ "Because `Rngs`'s state is updated in place and automatically propagated by `nnx.vmap`, we will get a different result every time that `noisy_vector_dot` is called.\n", "\n", - "In the example above we manually split the random state during construction, this is fine as it makes the intention clear but it also doesn't let us use `Rngs` outside of `vmap` since its state is always split. To solve this we pass an unplit seed and use the `nnx.split_rngs` decorator before `vmap` to split the `RngState` right before each call to the function and then \"lower\" it back so its usable." + "In the example above, you manually split the random state during construction. This is fine, as it makes the intention clear, but it also doesn't let you use `Rngs` outside of `nnx.vmap` because its state is always split. To solve this, you can pass an unsplit seed and use the `nnx.split_rngs` decorator before `nnx.vmap` to split the `RngState` right before each call to the function, and then \"lower\" it back so that it becomes usable." ] }, { @@ -566,7 +578,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -578,7 +590,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -615,18 +627,69 @@ "nnx.display(weights)" ] }, + { + "cell_type": "markdown", + "id": "60eee7f9", + "metadata": {}, + "source": [ + "## Rules and limitations\n", + "In this section we will cover some rules and limitations apply when using Modules inside transformations.\n", + "\n", + "### Mutable Module cannot be passed by closure\n", + "\n", + "While Python allows for passing objects as closures to functions, this is generally not supported by Flax NNX transforms. The reason is that because Modules are mutable it is very easy to capture tracer into a Module created outside of the transform, this is silent error in JAX. To avoid this, Flax NNX checks that the Modules and Variables being mutated are passed as arguments to the transformed function.\n", + "\n", + "For example, if we a have stateful Module such as `Counter` that increments a counter every time it is called, and we try to pass it as a closure to a function decorated with `nnx.jit`, we would be leaking the tracer. However Flax NNX will raise an error instead to prevent this:" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "f8b95c03", + "metadata": {}, + "outputs": [], + "source": [ + "class Counter(nnx.Module):\n", + " def __init__(self):\n", + " self.count = nnx.Param(jnp.array(0))\n", + "\n", + " def increment(self):\n", + " self.count += jnp.array(1)\n", + "\n", + "counter = Counter()\n", + "\n", + "@nnx.jit\n", + "def f(x):\n", + " counter.increment()\n", + " return 2 * x\n", + "\n", + "try:\n", + " y = f(3)\n", + "except Exception as e:\n", + " print(e)" + ] + }, + { + "cell_type": "markdown", + "id": "6f37e23b", + "metadata": {}, + "source": [ + "To solve this issue pass all Module as arguments to the functions being transformed. In this case `f` should accept `counter` as an argument." + ] + }, { "cell_type": "markdown", "id": "75edf7a8", "metadata": {}, "source": [ - "## Consistent aliasing\n", - "The main issue with allowing for reference semantics in transforms that references can be shared across inputs and outputs, this can be problematic if not taken care of because it would lead to ill-defined or inconsistent behavior. In the example below we have a single `Weights` Module `m` whose reference appears in multiple places in `arg1` and `arg2`. The problem is that we also specified we wanted to vectorize `arg1` in axis `0` and `arg2` in axis `1`, this is fine in JAX due to referential transparency of pytrees but its problematic in Flax NNX since we are trying to vectorize `m` in two different ways. NNX will enforce consistency by raising an error." + "### Consistent aliasing\n", + "\n", + "The main issue with allowing for reference semantics in transforms is that references can be shared across inputs and outputs. This can be problematic if it is not taken care of because it would lead to ill-defined or inconsistent behavior. In the example below you have a single `Weights` `nnx.Module` - `m` ` whose reference appears in multiple places in `arg1` and `arg2`. The problem here is that you also specify that you want to vectorize `arg1` in axis `0` and `arg2` in axis `1`. This would be fine in JAX because of referential transparency of pytrees. But this would be problematic in Flax NNX because you are trying to vectorize `m` in two different ways. Flax NNX will enforce consistency by raising an error." 
] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "id": "46b1cc25", "metadata": {}, "outputs": [ @@ -635,14 +698,10 @@ "output_type": "stream", "text": [ "Inconsistent aliasing detected. The following nodes have different prefixes:\n", - "Node: \n", + "Node: \n", " param: 0\n", " param: 0\n", - " param: 1\n", - "Node: \n", - " : 0\n", - " : 0\n", - " : 1\n" + " param: 1\n" ] } ], @@ -670,12 +729,12 @@ "id": "46aa978c", "metadata": {}, "source": [ - "Inconsistent aliasing can also happen between inputs and outputs. In the next example we have a trivial function that accepts and immediately return `arg1`, however `arg1` is vectorized on axis `0` on the input and axis `1` on the output. As expected, this is problematic and Flax NNX will raise an error." + "Inconsistent aliasing can also happen between inputs and outputs. In the next example you have a trivial function that accepts and immediately returns `arg1`. However, `arg1` is vectorized on axis `0` on the input, and axis `1` on the output. As expected, this is problematic and Flax NNX will raise an error." ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "id": "cca9cf31", "metadata": {}, "outputs": [ @@ -684,14 +743,10 @@ "output_type": "stream", "text": [ "Inconsistent aliasing detected. The following nodes have different prefixes:\n", - "Node: \n", + "Node: \n", " param: 0\n", " param: 0\n", - " param: 1\n", - "Node: \n", - " : 0\n", - " : 0\n", - " : 1\n" + " param: 1\n" ] } ], @@ -711,13 +766,20 @@ "id": "13f9aeea", "metadata": {}, "source": [ - "## Axes Metadata\n", - "Flax NNX Variables can have hold arbitrary metadata which can be added by simply passing them as keyword arguments to their constructor. This is often used to store `sharding` information which is used by the `nnx.spmd` APIs like `nnx.get_partition_spec` and `nnx.get_named_sharding`. However, its often important to keep this axes-related information in sync to what the actual state of the axes is when transforms are involved, for example, if we vectorize a variable on axis `1` we should remove the `sharding` information at position `1` when inside a `vmap` or `scan` to reflect the fact that the axes is temporarily removed. To achieve this Flax NNX transforms provide a non-standard `transform_metadata` dictionary argument, when the `nnx.PARTITION_NAME` key is present the `sharding` metadata will be updated as specified by `in_axes` and `out_axes`. Let's see an example of this in action:" + "## Axis metadata\n", + "\n", + "Flax NNX `Variable`s can hold arbitrary metadata, which can be added by simply passing it as keyword arguments to its constructor. This is often used to store `sharding` information, as used by the `nnx.spmd` APIs (like `nnx.get_partition_spec` and `nnx.get_named_sharding`).\n", + "\n", + "However, it is often important to keep this axes-related information in sync to what the actual state of the axes is when transforms are involved. For example, if you vectorize a variable on axis `1`, you should remove the `sharding` information at position `1` when inside a `vmap` or `scan` to reflect the fact that the axes are temporarily removed.\n", + "\n", + "To achieve this, Flax NNX transforms provide a non-standard `transform_metadata` dictionary argument. 
And when the `nnx.PARTITION_NAME` key is present, the `sharding` metadata will be updated as specified by `in_axes` and `out_axes`.\n", + "\n", + "Let's see an example of this in action:" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "id": "d85c772c", "metadata": {}, "outputs": [ @@ -754,14 +816,14 @@ "id": "a23bda09", "metadata": {}, "source": [ - "Here we added a `sharding` metadata to the `Param` variables and used `transform_metadata` to update the `sharding` metadata to reflect the axes changes, specifically we can see that first axis `b` was removed from the `sharding` metadata when inside `vmap` and then added back when outside `vmap`.\n", + "Here, you added a `sharding` metadata to the `nnx.Param` variables, and used `transform_metadata` to update the `sharding` metadata to reflect the axis changes. Specifically, you can see that the first axis `b` was removed from the `sharding` metadata when inside of `nnx.vmap`, and then added back when outside of `nnx.vmap`.\n", "\n", - "We can verify that this also works when Modules are created inside the transformation, the new `sharding` axes will be added to the Module's Variables outside the transformation, matching the axes of the transformed Variables." + "You can verify that this also works when `nnx.Module`s are created inside the transformation - the new `sharding` axes will be added to the `nnx.Module` `nnx.Variable`s outside the transformation, matching the axes of the transformed `nnx.Variable`s." ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "id": "358e51f7", "metadata": {}, "outputs": [ diff --git a/docs_nnx/guides/transforms.md b/docs_nnx/guides/transforms.md index 1b185e50..0cd7046f 100644 --- a/docs_nnx/guides/transforms.md +++ b/docs_nnx/guides/transforms.md @@ -10,16 +10,18 @@ jupytext: jupytext_version: 1.13.8 --- -# Transforms -JAX transformations in general operate on [Pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) of arrays -and abide by value semantics, this presents a challenge for Flax NNX which represents Modules as regular Python objects -that follow reference semantics. To address this, Flax NNX introduces its own set of transformations that extend JAX -transformations to allow Modules and other Flax NNX objects to be passed in and out of transformations while preserving +# Transformations + +In general, JAX transformations (transforms) operate on [pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) of `jax.Array`s +and abide by value semantics. This presents a challenge for Flax NNX, which represents `nnx.Module`s as regular Python objects +that follow reference semantics. To address this, Flax NNX introduced its own set of transforms that extend JAX +transforms to allow `nnx.Module`s and other Flax NNX objects to be passed in and out of transforms while preserving reference semantics. -Flax NNX transformations should feel quite familar to those who have used JAX transformations before as they use the -same APIs and behave like the JAX transformations when only working with Pytrees of arrays. However, when working with +Flax NNX transforms should feel quite familiar if you have used JAX transforms before. They use the +same APIs and behave like the JAX transforms when only working with pytrees of `jax.Array`s. 
However, when working with Flax NNX objects, they allow Python's reference semantics to be preserved for these objects, this includes: + * Preserving shared references across multiple objects in the inputs and outputs of the transformation. * Propagating any state changes made to the objects inside the transformation to the objects outside the transformation. * Enforcing consistency of how objects are transformed when aliases are present across multiple inputs and outputs. @@ -30,11 +32,12 @@ from jax import numpy as jnp, random from flax import nnx ``` -Throughout this guide we will use `nnx.vmap` as a case study to demonstrate how Flax NNX transformations work but the principles -outlined here extend to all transformations. +Throughout this guide, `nnx.vmap` is used as a case study to demonstrate how Flax NNX transforms work. However, the principles +outlined in this document extend to all transforms. -## Basic Example -To begin, let's look at a simple example of using `nnx.vmap` to extend an elementwise `vector_dot` function to work on +## Basic example + +To begin, let's look at a simple example of using `nnx.vmap` to extend an elementwise `vector_dot` function to work on batched inputs. We will define a `Weights` Module with no methods to hold some parameters, these weights will be passed as an input to the `vector_dot` function along with some data. Both the weights and data will be batched on axis `0` and we will use `nnx.vmap` to apply `vector_dot` to each batch element, and the result will be a batched on axis `1`: @@ -61,10 +64,10 @@ print(f'{y.shape = }') nnx.display(weights) ``` -Notice that `in_axes` interacts naturally with the `Weights` Module, treating it as if it where a Pytree of arrays. Prefix patterns are also allowed, `in_axes=(0, 0)` would've also worked in this case. +Notice that `in_axes` interacts naturally with the `Weights` Module, treating it as if it were a pytree of `jax.Array`s. Prefix patterns are also allowed, so `in_axes=(0, 0)` would have also worked in this case. -Objects are also allowed as outputs of Flax NNX transformations, this can be useful to transform initializers. For example, -we can define a `create_weights` function to create an single `Weights` Module and use `nnx.vmap` to create a stack of +Objects are also allowed as outputs of Flax NNX transforms, which can be useful to transform initializers. For example, +you can define a `create_weights` function to create a single `Weights` `nnx.Module`, and use `nnx.vmap` to create a stack of `Weights` with the same shapes as before: ```{code-cell} ipython3 @@ -79,8 +82,9 @@ weights = nnx.vmap(create_weights)(seeds) nnx.display(weights) ``` -## Transforming Methods -Methods in Python are just functions that take the instance as the first argument, this means that you can decorate methods from `Module` and other Flax NNX subtypes. For example, we can refactor `Weights` from the previous example and decorate `__init__` with `vmap` to do the work of `create_weights`, and add a `__call__` method and decorate it with `vmap` to do the work of `vector_dot`: +## Transforming methods + +Methods in Python are just functions that take the instance as the first argument. This means that you can decorate methods from `Module` and other Flax NNX subtypes. 
For example, we can refactor `Weights` from the previous example and decorate `__init__` with `vmap` to do the work of `create_weights`, and add a `__call__` method and decorate it with `@nnx.vmap` to do the work of `vector_dot`: ```{code-cell} ipython3 class WeightStack(nnx.Module): @@ -104,13 +108,14 @@ print(f'{y.shape = }') nnx.display(weights) ``` -Throughout the rest of the guide we will focus on transforming individual functions, however, note all examples can easily be written in this method style. +The rest of the guide will focus on transforming individual functions. But do note that all examples can be written in this method style. +++ ## State propagation -So far our functions have been stateless. However, the real power of Flax NNX transformations comes when we have stateful functions since one of their main features is to propagate state changes to preserve reference semantics. Let's update our example by adding -a `count` attribute to `Weights` and incrementing it in the new `stateful_vector_dot` function. + +So far our functions have been stateless. However, the real power of Flax NNX transforms comes when you have stateful functions, because one of their main features is to propagate state changes to preserve reference semantics. Let's update the previous example by adding +a `count` attribute to `Weights` and incrementing it in the new `stateful_vector_dot` function: ```{code-cell} ipython3 class Count(nnx.Variable): pass @@ -139,14 +144,17 @@ y = nnx.vmap(stateful_vector_dot, in_axes=0, out_axes=1)(weights, x) weights.count ``` -After running `stateful_vector_dot` once we verify that the `count` attribute was correctly updated. Because Weights is vectorized, `count` was initialized as an `arange(10)`, and all of its elements were incremented by 1 inside the transformation. The most important part is that updates were propagated to the original `Weights` object outside the transformation. Nice! +After running `stateful_vector_dot` once, you verified that the `count` attribute was correctly updated. Because `Weights` was vectorized, `count` was initialized as an `arange(10)`, and all of its elements were incremented by `1` inside the transformation. The most important part is that updates were propagated to the original `Weights` object outside the transformation. Nice! +++ ### Graph updates propagation -JAX transformations see inputs as pytrees of arrays, and Flax NNX see inputs pytrees of arrays and Python references, where references form a graph. Flax NNX's state propagation machinery can track arbitrary updates to the objects as long as they're local to the inputs (updates to globals inside transforms are not supported). This means that you can modify graph structure as needed, including updating existing attributes, adding/deleting attributes, swapping attributes, sharing (new) references between objects, sharing Variables between objects, etc. The sky is the limit! -The following example demonstrates performing some arbitrary updates to the `Weights` object inside `nnx.vmap` and verifying that the updates are correctly propagated to the original `Weights` object outside the transformation. +JAX transforms see inputs as pytrees of `jax.Array`s, and Flax NNX sees inputs as pytrees of `jax.Array`s and Python references, where references form a graph. Flax NNX's state propagation machinery can track arbitrary updates to the objects as long as they're local to the inputs (updates to globals inside transforms are not supported). 
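As a minimal, self-contained illustration of this propagation rule, here is an editor's sketch (the `Box` Module and `bump` function are made-up names, not one of the guide's own cells): an in-place update made to an input Module inside `nnx.jit` is tracked and written back to the original object.

```python
from flax import nnx
import jax.numpy as jnp

class Box(nnx.Module):
  def __init__(self):
    # a single Param holding a counter, used only to show state propagation
    self.count = nnx.Param(jnp.array(0))

@nnx.jit
def bump(box: Box):
  # the update is local to the input `box`, so Flax NNX can track it
  box.count += jnp.array(1)

box = Box()
bump(box)
nnx.display(box)  # the incremented `count` is visible outside the transform
```
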
+ +This means that you can modify graph structure as needed, including updating existing attributes, adding/deleting attributes, swapping attributes, sharing (new) references between objects, sharing `nnx.Variable`s between objects, etc. Sky is the limit! + +The following example demonstrates performing some arbitrary updates to the `Weights` object inside `nnx.vmap`, and verifying that the updates are correctly propagated to the original `Weights` object outside the transformation: ```{code-cell} ipython3 class Count(nnx.Variable): pass @@ -181,25 +189,26 @@ nnx.display(weights) > With great power comes great responsibility. >
\- Uncle Ben -While this feature is very powerful, it must be used with care as it can clash with JAX's underlying assumptions for certain transformations. For example, `jit` expects the structure of the inputs to be stable in order to cache the compiled function, so changing the graph structure inside a `nnx.jit`-ed function cause continuous recompilations and performance degradation. On the other hand, `scan` only allows a fixed `carry` structure, so adding/removing substates declared as carry will cause an error. +While this feature is very powerful, it must be used with care because it can clash with JAX's underlying assumptions for certain transforms. For example, `jit` expects the structure of the inputs to be stable in order to cache the compiled function, so changing the graph structure inside an `nnx.jit`-ed function causes continuous recompilations and performance degradation. On the other hand, `scan` only allows a fixed `carry` structure, so adding/removing sub-states declared as carry will cause an error. +++ -## Transforming Substates (Lift Types) +## Transforming sub-states (lift types) -Certain JAX transformation allow the use of pytree prefixes to specify how different parts of the inputs/outputs should be transformed. Flax NNX supports pytree prefixes for pytree structures but currently it doesn't have the notion of a prefix for graph objects. Instead, Flax NNX introduces the concept of `Lift Types` which allow specifying how different substates of an object should be transformed. Different transformations support different Lift Types, here is the list of currently supported Lift Types for each transformation: +Certain JAX transforms allow the use of pytree prefixes to specify how different parts of the inputs/outputs should be transformed. Flax NNX supports pytree prefixes for pytree structures but currently it doesn't have the notion of a prefix for graph objects. Instead, Flax NNX introduces the concept of “lift types” which allow specifying how different sub-states of an object should be transformed. Different transforms support different lift types; here is the list of currently supported Flax NNX lift types for each JAX transformation: -| Lift Type | Transforms | +| Lift type | JAX transforms | |------------------|-----------------------------------------| | `StateAxes` | `vmap`, `pmap`, `scan` | -| `StateSharding` | `jit`, `shard_map` | +| `StateSharding` | `jit`, `shard_map`* | | `DiffState` | `grad`, `value_and_grad`, `custom_vjp` | -> NOTE: `shard_map` is not yet implemented. +> **Note:** * Flax NNX `shard_map` has not yet been implemented at the time of writing. + +To specify how to vectorize different sub-states of an object in `nnx.vmap`, the Flax team created a `nnx.StateAxes`. `StateAxes` maps a set of sub-states via Flax NNX [Filters](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) to their corresponding axes, and you can pass the `nnx.StateAxes` to `in_axes` and `out_axes` as if it were a pytree prefix. -If we want to specify how to vectorize different substates of an object in `nnx.vmap`, we create a `StateAxes` which maps a set of substates via [Filters](https://flax-nnx.readthedocs.io/en/latest/guides/filters_guide.html) to their corresponding axes, and pass the `StateAxes` to `in_axes` and `out_axes` as if it were a pytree prefix. 
Let's use the previous `stateful_vector_dot` example and -vectorize only the `Param` variables and broadcast the `count` variable so we only keep a single count for all the batch elements. -To do this we will define a `StateAxes` with a filter that matches the `Param` variables and maps them to axis `0`, and all the `Count` variables to `None`, and pass this `StateAxes` to `in_axes` for the `Weights` object. +Let's use the previous `stateful_vector_dot` example and vectorize only the `nnx.Param` variables and broadcast the `count` variable so we only keep a single count for all the batch elements. +To do this we will define a `nnx.StateAxes` with a filter that matches the `nnx.Param` variables and maps them to axis `0`, and all the `Count` variables to `None`, and pass this `nnx.StateAxes` to `in_axes` for the `Weights` object. ```{code-cell} ipython3 class Weights(nnx.Module): @@ -227,14 +236,17 @@ y = nnx.vmap(stateful_vector_dot, in_axes=(state_axes, 0), out_axes=1)(weights, weights.count ``` -Here count is now a scalar since its not being vectorized. Also, note that `StateAxes` can only be used directly on Flax NNX objects, it cannot be used as a prefix for a pytree of objects. +Here, `count` is now a scalar since it's not being vectorized. Also, note that `nnx.StateAxes` can only be used directly on Flax NNX objects, and it cannot be used as a prefix for a pytree of objects. +++ -### Random State -In Flax NNX random state is just regular state. This means that its stored inside Modules that need it and its treated as any other type of state. This is a simplification over Flax Linen where random state was handled by a separate mechanism. In practice Modules simply need to keep a reference to a `Rngs` object that is passed to them during initialization, and use it to generate a unique key for each random operation. For the purposes of this guide, this means that random state can be transformed like any other type of state but we also need be aware of how the state is laid out so we can transform it correctly. +### Random state -Let's suppose we want change things up a bit and apply the same weights to all elements in the batch but we want to add different random noise to each element. To do this we will add a `Rngs` attribute to `Weights`, created from a `seed` key argument passed during construction, this seed key must be `split` before hand so we can vectorize it succesfully. For pedagogical reasons, we will assign the seed key to a `noise` Stream and sample from it. To vectorize the RNG state we must configure `StateAxes` to map all `RngState` (base class for all variables in `Rngs`) to axis `0`, and `Param` and `Count` to `None`. +In Flax NNX, a random state is just a regular state. This means that it is stored inside `nnx.Module`s that need it, and it is treated as any other type of state. This is a simplification over Flax Linen, where a random state was handled by a separate mechanism. In practice `nnx.Module`s simply need to keep a reference to a `Rngs` object that is passed to them during initialization, and use it to generate a unique key for each random operation. For the purposes of this guide, this means that random state can be transformed like any other type of state, but we also need to be aware of how the state is laid out so we can transform it correctly. + +Suppose you want to change things up a bit and apply the same weights to all elements in the batch. But you also want to add different random noise to each element. 
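Before wiring this into `Weights`, it may help to see how an `nnx.Rngs` stream hands out keys. The following is an editor's aside rather than one of the guide's cells; the `noise` stream name follows the guide, and seeding it with a plain integer here is just for illustration:

```python
from flax import nnx

# Seed a stream named `noise`; each call to the stream returns a fresh PRNG key
# derived from that seed, which is what `Weights` will use to sample its noise.
rngs = nnx.Rngs(noise=0)
key_a = rngs.noise()
key_b = rngs.noise()
# key_a and key_b are distinct keys drawn from the same `noise` stream.
```
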
+ +To do this, you will add an `Rngs` attribute to `Weights`, created from a `seed` key argument passed during construction. This seed key must be `split` beforehand, so that you can vectorize it successfully. For pedagogical reasons, you will assign the seed key to a `noise` “stream” and sample from it. To vectorize the PRNG state, you must configure `nnx.StateAxes` to map all `RngState`s (a base class for all variables in `Rngs`) to axis `0`, and `nnx.Param` and `Count` to `None`. ```{code-cell} ipython3 class Weights(nnx.Module): @@ -268,7 +280,7 @@ nnx.display(weights) Because `Rngs`'s state is updated in place and automatically propagated by `nnx.vmap`, we will get a different result every time that `noisy_vector_dot` is called. -In the example above we manually split the random state during construction, this is fine as it makes the intention clear but it also doesn't let us use `Rngs` outside of `vmap` since its state is always split. To solve this we pass an unplit seed and use the `nnx.split_rngs` decorator before `vmap` to split the `RngState` right before each call to the function and then "lower" it back so its usable. +In the example above, you manually split the random state during construction. This is fine, as it makes the intention clear, but it also doesn't let you use `Rngs` outside of `nnx.vmap` because its state is always split. To solve this, you can pass an unsplit seed and use the `nnx.split_rngs` decorator before `nnx.vmap` to split the `RngState` right before each call to the function, and then "lower" it back so that it becomes usable. ```{code-cell} ipython3 weights = Weights( @@ -297,8 +309,43 @@ print(jnp.allclose(y1, y2)) nnx.display(weights) ``` -## Consistent aliasing -The main issue with allowing for reference semantics in transforms that references can be shared across inputs and outputs, this can be problematic if not taken care of because it would lead to ill-defined or inconsistent behavior. In the example below we have a single `Weights` Module `m` whose reference appears in multiple places in `arg1` and `arg2`. The problem is that we also specified we wanted to vectorize `arg1` in axis `0` and `arg2` in axis `1`, this is fine in JAX due to referential transparency of pytrees but its problematic in Flax NNX since we are trying to vectorize `m` in two different ways. NNX will enforce consistency by raising an error. +## Rules and limitations +In this section we will cover some rules and limitations that apply when using Modules inside transformations. + +### Mutable Modules cannot be passed by closure + +While Python allows for passing objects as closures to functions, this is generally not supported by Flax NNX transforms. The reason is that, because Modules are mutable, it is very easy to capture a tracer in a Module created outside of the transform, which is a silent error in JAX. To avoid this, Flax NNX checks that the Modules and Variables being mutated are passed as arguments to the transformed function. + +For example, if we have a stateful Module such as `Counter` that increments a counter every time it is called, and we try to pass it as a closure to a function decorated with `nnx.jit`, we would be leaking the tracer. 
However, Flax NNX will raise an error instead to prevent this: + +```{code-cell} ipython3 +class Counter(nnx.Module): + def __init__(self): + self.count = nnx.Param(jnp.array(0)) + + def increment(self): + self.count += jnp.array(1) + +counter = Counter() + +@nnx.jit +def f(x): + counter.increment() + return 2 * x + +try: + y = f(3) +except Exception as e: + print(e) +``` + +To solve this issue, pass all Modules as arguments to the functions being transformed. In this case `f` should accept `counter` as an argument. + ++++ + +### Consistent aliasing + +The main issue with allowing for reference semantics in transforms is that references can be shared across inputs and outputs. This can be problematic if it is not taken care of because it would lead to ill-defined or inconsistent behavior. In the example below you have a single `Weights` `nnx.Module`, `m`, whose reference appears in multiple places in `arg1` and `arg2`. The problem here is that you also specify that you want to vectorize `arg1` in axis `0` and `arg2` in axis `1`. This would be fine in JAX because of referential transparency of pytrees. But this would be problematic in Flax NNX because you are trying to vectorize `m` in two different ways. Flax NNX will enforce consistency by raising an error. ```{code-cell} ipython3 class Weights(nnx.Module): @@ -319,7 +366,7 @@ except ValueError as e: print(e) ``` -Inconsistent aliasing can also happen between inputs and outputs. In the next example we have a trivial function that accepts and immediately return `arg1`, however `arg1` is vectorized on axis `0` on the input and axis `1` on the output. As expected, this is problematic and Flax NNX will raise an error. +Inconsistent aliasing can also happen between inputs and outputs. In the next example you have a trivial function that accepts and immediately returns `arg1`. However, `arg1` is vectorized on axis `0` on the input, and axis `1` on the output. As expected, this is problematic and Flax NNX will raise an error. ```{code-cell} ipython3 @nnx.vmap(in_axes=0, out_axes=1) @@ -332,8 +379,15 @@ except ValueError as e: print(e) ``` -## Axes Metadata -Flax NNX Variables can have hold arbitrary metadata which can be added by simply passing them as keyword arguments to their constructor. This is often used to store `sharding` information which is used by the `nnx.spmd` APIs like `nnx.get_partition_spec` and `nnx.get_named_sharding`. However, its often important to keep this axes-related information in sync to what the actual state of the axes is when transforms are involved, for example, if we vectorize a variable on axis `1` we should remove the `sharding` information at position `1` when inside a `vmap` or `scan` to reflect the fact that the axes is temporarily removed. To achieve this Flax NNX transforms provide a non-standard `transform_metadata` dictionary argument, when the `nnx.PARTITION_NAME` key is present the `sharding` metadata will be updated as specified by `in_axes` and `out_axes`. Let's see an example of this in action: +## Axis metadata +
+Flax NNX `Variable`s can hold arbitrary metadata, which can be added by simply passing it as keyword arguments to its constructor. This is often used to store `sharding` information, as used by the `nnx.spmd` APIs (like `nnx.get_partition_spec` and `nnx.get_named_sharding`). + +However, it is often important to keep this axes-related information in sync with what the actual state of the axes is when transforms are involved. 
For example, if you vectorize a variable on axis `1`, you should remove the `sharding` information at position `1` when inside a `vmap` or `scan` to reflect the fact that the axes are temporarily removed. + +To achieve this, Flax NNX transforms provide a non-standard `transform_metadata` dictionary argument. And when the `nnx.PARTITION_NAME` key is present, the `sharding` metadata will be updated as specified by `in_axes` and `out_axes`. + +Let's see an example of this in action: ```{code-cell} ipython3 class Weights(nnx.Module): @@ -352,9 +406,9 @@ print(f'Outter {m.param.shape = }') print(f'Outter {m.param.sharding = }') ``` -Here we added a `sharding` metadata to the `Param` variables and used `transform_metadata` to update the `sharding` metadata to reflect the axes changes, specifically we can see that first axis `b` was removed from the `sharding` metadata when inside `vmap` and then added back when outside `vmap`. +Here, you added a `sharding` metadata to the `nnx.Param` variables, and used `transform_metadata` to update the `sharding` metadata to reflect the axis changes. Specifically, you can see that the first axis `b` was removed from the `sharding` metadata when inside of `nnx.vmap`, and then added back when outside of `nnx.vmap`. -We can verify that this also works when Modules are created inside the transformation, the new `sharding` axes will be added to the Module's Variables outside the transformation, matching the axes of the transformed Variables. +You can verify that this also works when `nnx.Module`s are created inside the transformation - the new `sharding` axes will be added to the `nnx.Module` `nnx.Variable`s outside the transformation, matching the axes of the transformed `nnx.Variable`s. ```{code-cell} ipython3 @nnx.vmap(out_axes=1, axis_size=4, transform_metadata={nnx.PARTITION_NAME: 'b'}) diff --git a/docs_nnx/guides/why.ipynb b/docs_nnx/guides/why.ipynb deleted file mode 100644 index d38fe6c8..00000000 --- a/docs_nnx/guides/why.ipynb +++ /dev/null @@ -1,770 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Why NNX?\n", - "\n", - "\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/flax/nnx/docs/why.ipynb)\n", - "\n", - "Four years ago we developed the Flax \"Linen\" API to support modeling research on JAX, with a focus on scaling scaling and performance. We've learned a lot from our users over these years.\n", - "\n", - "We introduced some ideas that have proven to be good:\n", - " - Organizing variables into [collections](https://flax.readthedocs.io/en/latest/glossary.html#term-Variable-collections) or types to support JAX transforms and segregation of different data types in training loops.\n", - " - Automatic and efficient [PRNG management](https://flax.readthedocs.io/en/latest/glossary.html#term-RNG-sequences) (with support for splitting/broadcast control across map transforms)\n", - " - [Variable Metadata](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_partitioning) for SPMD annotations, optimizer metadata, and other uses.\n", - "\n", - "However, one choice we made was to use functional \"define by call\" semantics for NN programming via the lazy initialization of parameters. This made for concise (`compact`) implementation code, allowed for a single specification when transforming a layer, and aligned our API with Haiku. 
Lazy initialization meant that the semantics of modules and variables in Flax were non-pythonic and often surprising. It also led to implementation complexity and obscured the core ideas of transformations on neural nets.\n", - "\n", - "NNX is an attempt to keep the features that made Linen useful while introducing some new principles:\n", - "\n", - "- Regular Python semantics for Modules, including (within JIT boundaries) support for mutability and shared references.\n", - "- A simple API to interact directly with the JAX, this includes the ability to easily implement custom lifted Modules and other purely functional tricks.\n", - "\n", - "We'd love to hear from any of our users about their thoughts on these ideas.\n", - "\n", - "[[nnx on github](https://github.com/google/flax/tree/main/flax/nnx)]\n", - "[[this doc on github](https://github.com/google/flax/blob/main/flax/nnx/docs/why.ipynb)]" - ] - }, - { - "cell_type": "code", - "execution_count": 108, - "metadata": {}, - "outputs": [], - "source": [ - "! pip install -U git+https://github.com/google/flax.git\n", - "from functools import partial\n", - "import jax\n", - "from jax import random, numpy as jnp\n", - "from flax import nnx" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### NNX is Pythonic\n", - "The main feature of NNX Module is that it adheres to Python semantics. This means that:\n", - "\n", - "* fields are mutable so you can perform inplace updates\n", - "* Module references can be shared between multiple Modules\n", - "* Module construction implies parameter initialization\n", - "* Module methods can be called directly" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "outputId": "d8ef66d5-6866-4d5c-94c2-d22512bfe718" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "model = CounterLinear(\n", - " linear=Linear(\n", - " in_features=4,\n", - " out_features=4,\n", - " use_bias=True,\n", - " dtype=None,\n", - " param_dtype=,\n", - " precision=None,\n", - " kernel_init=.init at 0x7f3dc9ad3370>,\n", - " bias_init=,\n", - " dot_general=\n", - " )\n", - ")\n" - ] - } - ], - "source": [ - "class Count(nnx.Variable): # custom Variable types define the \"collections\"\n", - " pass\n", - "\n", - "\n", - "class CounterLinear(nnx.Module):\n", - " def __init__(self, din, dout, *, rngs): # explicit RNG threading\n", - " self.linear = nnx.Linear(din, dout, rngs=rngs)\n", - " self.count = Count(jnp.zeros((), jnp.int32)) # typed Variable collections\n", - "\n", - " def __call__(self, x):\n", - " self.count.value += 1 # in-place stateful updates\n", - " return self.linear(x)\n", - "\n", - "\n", - "model = CounterLinear(4, 4, rngs=nnx.Rngs(0)) # no special `init` method\n", - "y = model(jnp.ones((2, 4))) # call methods directly\n", - "\n", - "print(f'{model = }')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Because NNX Modules contain their own state, they are very easily to inspect:" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "outputId": "10a46b0f-2993-4677-c26d-36a4ddf33449" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "model.count = Array(1, dtype=int32)\n", - "model.linear.kernel = Array([[ 0.4541089 , -0.5264876 , -0.36505195, -0.57566494],\n", - " [ 0.38802508, 0.5655534 , 0.4870657 , 0.2267774 ],\n", - " [-0.9015767 , 0.24465278, -0.5844707 , 0.18421966],\n", - " [-0.06992685, -0.64693886, 0.20232596, 1.1200062 ]], dtype=float32)\n" - ] - } 
- ], - "source": [ - "print(f'{model.count = }')\n", - "print(f'{model.linear.kernel = }')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Intuitive Surgery\n", - "\n", - "In NNX surgery can be done at the Module level by simply updating / replacing existing fields." - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": { - "outputId": "e6f86be8-3537-4c48-f471-316ee0fb6c45" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[1.7531997, 1.6318591, 2.1417565, 3.120555 ],\n", - " [1.7531997, 1.6318591, 2.1417565, 3.120555 ]], dtype=float32)" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# pretend this came from a checkpoint or elsewhere:\n", - "pretrained_weight = random.uniform(random.key(0), (4, 4))\n", - "\n", - "# you can replace weights directly\n", - "model.linear.kernel = pretrained_weight\n", - "y = model(jnp.ones((2, 4)))\n", - "y" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": { - "outputId": "5190ac7b-12f7-4400-d5bb-f91b97a557b6" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[1.624419 , 0.8313738 , 0.37612876, 1.9937458 ],\n", - " [1.624419 , 0.8313738 , 0.37612876, 1.9937458 ]], dtype=float32)" - ] - }, - "execution_count": 26, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def load_pretrained_fragment():\n", - " # pretend this inits / loads some fragment of a model\n", - " replacement = nnx.Linear(4, 4, rngs=nnx.Rngs(1))\n", - " return replacement\n", - "\n", - "# you can replace modules directly\n", - "model.linear = load_pretrained_fragment()\n", - "y = model(jnp.ones((2, 4)))\n", - "y" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Not only is this easier than messing with dictionary structures and aligning that with code changes, but one can even replace a field with a completely different Module type, or even change the architecture (e.g. share two Modules that were not shared before)." - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "metadata": {}, - "outputs": [], - "source": [ - "rngs = nnx.Rngs(0)\n", - "model = nnx.Sequence(\n", - " [\n", - " nnx.Conv(1, 16, [3, 3], padding='SAME', rngs=rngs),\n", - " partial(nnx.max_pool, window_shape=(2, 2), strides=(2, 2)),\n", - " nnx.Conv(16, 32, [3, 3], padding='SAME', rngs=rngs),\n", - " partial(nnx.max_pool, window_shape=(2, 2), strides=(2, 2)),\n", - " lambda x: x.reshape((x.shape[0], -1)), # flatten\n", - " nnx.Linear(32 * 7 * 7, 10, rngs=rngs),\n", - " ]\n", - ")\n", - "\n", - "y = model(jnp.ones((2, 28, 28, 1)))\n", - "\n", - "# Do some weird surgery of the stack:\n", - "for i, layer in enumerate(model):\n", - " if isinstance(layer, nnx.Conv):\n", - " model[i] = nnx.Linear(layer.in_features, layer.out_features, rngs=rngs)\n", - "\n", - "y = model(jnp.ones((2, 28, 28, 1)))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Note that here we are replacing `Conv` with `Linear` as a silly example, but in reality you would do things like replacing a layer with its quantized version, or changing a layer with an optimized version, etc." 
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Interacting with JAX is easy\n", - "\n", - "While NNX Modules inherently follow reference semantics, they can be easily converted into a pure functional representation that can be used with JAX transformations and other value-based, functional code.\n", - "\n", - "NNX has two very simple APIs to interact with JAX: `split` and `merge`.\n", - "\n", - "The `Module.split` method allows you to convert into a `State` dict-like object that contains the dynamic state of the Module, and a `GraphDef` object that contains the graphdef structure of the Module." - ] - }, - { - "cell_type": "code", - "execution_count": 96, - "metadata": { - "outputId": "9a3f378b-739e-4f45-9968-574651200ede" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "state = State({\n", - " 'count': Array(0, dtype=int32),\n", - " 'linear/bias': Array([0., 0., 0., 0.], dtype=float32),\n", - " 'linear/kernel': Array([[ 0.4541089 , -0.5264876 , -0.36505195, -0.57566494],\n", - " [ 0.38802508, 0.5655534 , 0.4870657 , 0.2267774 ],\n", - " [-0.9015767 , 0.24465278, -0.5844707 , 0.18421966],\n", - " [-0.06992685, -0.64693886, 0.20232596, 1.1200062 ]], dtype=float32)\n", - "})\n", - "\n", - "graphdef = GraphDef(\n", - " type=CounterLinear,\n", - " index=0,\n", - " static_fields=(),\n", - " variables=(('count', Count(\n", - " value=Empty\n", - " )),),\n", - " submodules=(\n", - " ('linear', GraphDef(\n", - " type=Linear,\n", - " index=1,\n", - " static_fields=(('bias_init', ), ('dot_general', ), ('dtype', None), ('in_features', 4), ('kernel_init', .init at 0x7f3dc9ad3370>), ('out_features', 4), ('param_dtype', ), ('precision', None), ('use_bias', True)),\n", - " variables=(('bias', Param(\n", - " value=Empty\n", - " )), ('kernel', Param(\n", - " value=Empty\n", - " ))),\n", - " submodules=()\n", - " ))\n", - " )\n", - ")\n" - ] - } - ], - "source": [ - "model = CounterLinear(4, 4, rngs=nnx.Rngs(0))\n", - "\n", - "graphdef, state = model.split()\n", - "\n", - "# state is a dictionary-like JAX pytree\n", - "print(f'{state = }')\n", - "\n", - "# graphdef is also a JAX pytree, but containing no data, just metadata\n", - "print(f'\\n{graphdef = }')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The `GraphDef.merge` method allows you to take a `GraphDef` and one or more `State` objects and merge them back into a `Module` object.\n", - "\n", - "Using `split` and `merge` in conjunction allows you to carry your Module in and out of any JAX transformation. 
Here is a simple jitted `forward` function as an example:" - ] - }, - { - "cell_type": "code", - "execution_count": 97, - "metadata": { - "outputId": "0007d357-152a-449e-bcb9-b1b5a91d2d8d" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "y.shape = (2, 4)\n", - "state[\"count\"] = Array(1, dtype=int32)\n" - ] - } - ], - "source": [ - "@jax.jit\n", - "def forward(graphdef: nnx.GraphDef, state: nnx.State, x: jax.Array):\n", - " model = graphdef.merge(state)\n", - " y = model(x)\n", - " state, _ = model.split()\n", - " return y, state\n", - "\n", - "x = jnp.ones((2, 4))\n", - "y, state = forward(graphdef,state, x)\n", - "\n", - "print(f'{y.shape = }')\n", - "print(f'{state[\"count\"] = }')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Custom lifting and transformation\n", - "\n", - "By using the same mechanism inside Module methods you can implement lifted Modules, that is, Modules that use a JAX transformation to have a distinct behavior.\n", - "\n", - "One of Linen's current pain points is that it is not easy to interact with JAX transformations that are not currently supported by the framework. NNX makes it very easy to implement custom lifted Modules or bespoke custom functional transforms for specific use cases.\n", - "\n", - "As an example here we will create a `LinearEnsemble` Module that uses `jax.vmap` both during `__init__` and `__call__` to vectorize the computation over multiple `CounterLinear` models (defined above). The example is a little bit longer, but notice how each method conceptually very simple.\n", - "\n", - "It uses the single additional method `update` to locally modify model state." - ] - }, - { - "cell_type": "code", - "execution_count": 98, - "metadata": { - "outputId": "fdd212d7-4994-4fa5-d922-5a7d7cfad3e3" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "y.shape = (8, 4)\n", - "ensemble.models.count = Array(1, dtype=int32)\n", - "state = State({\n", - " 'models/count': (),\n", - " 'models/linear/bias': (8, 4),\n", - " 'models/linear/kernel': (8, 4, 4)\n", - "})\n" - ] - } - ], - "source": [ - "class LinearEnsemble(nnx.Module):\n", - " def __init__(self, din, dout, *, num_models, rngs: nnx.Rngs):\n", - " # get raw rng seeds\n", - " keys = rngs.fork(num_models) # split all keys into `num_models`\n", - "\n", - " # define pure init fn and vmap\n", - " def vmap_init(keys):\n", - " return CounterLinear(din, dout, rngs=nnx.Rngs(keys)).split(\n", - " nnx.Param, Count\n", - " )\n", - " params, counts, graphdef = jax.vmap(\n", - " vmap_init, in_axes=(0,), out_axes=(0, None, None)\n", - " )(keys)\n", - "\n", - " # update wrapped submodule reference\n", - " self.models = graphdef.merge(params, counts)\n", - "\n", - " def __call__(self, x):\n", - " # get module values, define pure fn,\n", - " # notice that we split the data into two collections by their types.\n", - " params, counts, graphdef = self.models.split(nnx.Param, Count)\n", - "\n", - " # define pure init fn and vmap\n", - " def vmap_apply(x, params, counts, graphdef):\n", - " model = graphdef.merge(params, counts)\n", - " y = model(x)\n", - " params, counts, graphdef = model.split(nnx.Param, Count)\n", - " return y, params, counts, graphdef\n", - "\n", - " y, params, counts, graphdef = jax.vmap(\n", - " vmap_apply,\n", - " in_axes=(None, 0, None, None),\n", - " out_axes=(0, 0, None, None)\n", - " )(x, params, counts, graphdef)\n", - "\n", - " # update wrapped module\n", - " # uses `update` to integrate the new 
state\n", - " self.models.update(params, counts, graphdef)\n", - " return y\n", - "\n", - "x = jnp.ones((4,))\n", - "ensemble = LinearEnsemble(4, 4, num_models=8, rngs=nnx.Rngs(0))\n", - "\n", - "# forward pass\n", - "y = ensemble(x)\n", - "\n", - "print(f'{y.shape = }')\n", - "print(f'{ensemble.models.count = }')\n", - "print(f'state = {jax.tree.map(jnp.shape, ensemble.get_state())}')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Convenience lifted transforms" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Like linen, for convenience we still provide simple lifted transforms for standard JAX transforms, usable as class transforms and decorators. We've endeavored to simplify the API for scan and vmap compared to the flax specifications." - ] - }, - { - "cell_type": "code", - "execution_count": 112, - "metadata": { - "outputId": "c4800a49-efd1-4ee5-e703-6e63e18da4cb" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "State({\n", - " 'scan_module/bias': Array([[0., 0.],\n", - " [0., 0.],\n", - " [0., 0.],\n", - " [0., 0.]], dtype=float32),\n", - " 'scan_module/kernel': Array([[[-0.32325608, 0.16164146],\n", - " [ 0.46505648, -0.34060344]],\n", - " \n", - " [[-1.1558908 , 1.2445341 ],\n", - " [-1.3710847 , -0.1787171 ]],\n", - " \n", - " [[-0.68510336, 0.25847596],\n", - " [ 1.0730107 , -0.11857361]],\n", - " \n", - " [[-0.01770882, 0.5472832 ],\n", - " [-0.84826714, 0.17867221]]], dtype=float32)\n", - "})" - ] - }, - "execution_count": 112, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# class transform:\n", - "ScannedLinear = nnx.Scan.constructor(nnx.Linear, variable_axes={nnx.Param: 0}, length=4)\n", - "\n", - "scanned = ScannedLinear(2, 2, rngs=nnx.Rngs(0))\n", - "scanned.get_state()" - ] - }, - { - "cell_type": "code", - "execution_count": 113, - "metadata": { - "outputId": "9efd6e71-d180-4674-ade0-2b02057a400b" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "State({\n", - " 'model/bias': Array([[0., 0.],\n", - " [0., 0.],\n", - " [0., 0.],\n", - " [0., 0.]], dtype=float32),\n", - " 'model/kernel': Array([[[-0.32325608, 0.16164146],\n", - " [ 0.46505648, -0.34060344]],\n", - " \n", - " [[-1.1558908 , 1.2445341 ],\n", - " [-1.3710847 , -0.1787171 ]],\n", - " \n", - " [[-0.68510336, 0.25847596],\n", - " [ 1.0730107 , -0.11857361]],\n", - " \n", - " [[-0.01770882, 0.5472832 ],\n", - " [-0.84826714, 0.17867221]]], dtype=float32)\n", - "})" - ] - }, - "execution_count": 113, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# method decorators:\n", - "\n", - "class ScannedLinear(nnx.Module):\n", - "\n", - " @partial(nnx.scan, variable_axes={nnx.Param: 0}, length=4)\n", - " def __init__(self, din, dout, *, rngs: nnx.Rngs):\n", - " self.model = nnx.Linear(din, dout, rngs=nnx.Rngs(rngs))\n", - "\n", - " @partial(nnx.scan, variable_axes={nnx.Param: 0}, length=4)\n", - " def __call__(self, x):\n", - " return self.model(x)\n", - "\n", - "scanned = ScannedLinear(2, 2, rngs=nnx.Rngs(0))\n", - "scanned.get_state()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Aside: Why aren't Modules Pytrees?\n", - "\n", - "A common questions is why aren't NNX Modules registered as Pytrees? (in the style of Equinox, Treex, PytreeClass, etc.) 
It _is_ trivial to define a pytree registration in terms of `split`/`merge`.\n", - "\n", - "The problem is that Pytrees impose value semantics (referencial transparency) while Modules assume reference semantics, and therefore it is dangerous in general to automatically treat Modules as Pytrees.\n", - "\n", - "As an example, lets take a look at what would happen if we allowed this very simple program to be valid:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "@jax.jit\n", - "def f(m1: nnx.Module, m2: nnx.Module):\n", - " return m1, m2" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Here we are just creating a jitted function `f` that takes in two Modules `(m1, m2)` and returns them as is. What could go wrong?\n", - "\n", - "There are two main problems with this:\n", - "* Shared references are not maintained, that is, if `m1.shared` is the same as `m2.shared` outside `f`, this will NOT be true both inside `f`, and at the output of `f`.\n", - "* Even if you accept this fact and added code to compensate for this, `f` would now behave differently depending on whether its being `jit`ted or not, this is an undesirable asymmetry and `jit` would no longer be a no-op." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Standardized \"Hooks\"\n", - "\n", - "NNX introduces a standard getter/setter/creator interface for custom variables (similar to Haiku hooks). This is used internally to support SPMD metadata for managing sharding information, but is available for user-defined applications." - ] - }, - { - "cell_type": "code", - "execution_count": 59, - "metadata": { - "outputId": "c4e6586a-bfe0-4f26-d05b-8c9e395971b2" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "self.kernel.shape = (4, 8)\n", - "outer kernel shape = (8, 4)\n" - ] - } - ], - "source": [ - "class TransposedParam(nnx.Variable):\n", - " def create_value(self, value):\n", - " return value.T # called on variable creation to transform initial value\n", - " def get_value(self):\n", - " return self.value.T # called when value fetched via module getattr\n", - " def set_value(self, value):\n", - " return self.replace(value=value.T) # called when setting value from module setattr\n", - "\n", - "\n", - "class OddLinear(nnx.Module):\n", - " def __init__(self, din, dout, *, rngs):\n", - " self.kernel = TransposedParam(random.uniform(rngs.params(), (din, dout)))\n", - " self.bias = nnx.Param(jnp.zeros((dout,)))\n", - "\n", - " def __call__(self, x):\n", - " print(f'{self.kernel.shape = }')\n", - " return x @ self.kernel + self.bias\n", - "\n", - "\n", - "model = OddLinear(4, 8, rngs=nnx.Rngs(0))\n", - "y = model(jnp.ones((2, 4)))\n", - "\n", - "print(f'outer kernel shape = {model.split()[0][\"kernel\"].shape}')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "SPMD metadata is handled using `nnx.with_partitioning` helpers, but it's easy to add one's own metadata schema:" - ] - }, - { - "cell_type": "code", - "execution_count": 114, - "metadata": { - "outputId": "ef312738-0f56-4c0e-9aaf-3319d131f1a2" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "state.variables['kernel'].meta='foo'\n", - "state.variables['kernel'].other_meta=0\n", - "state.variables['bias'].meta='bar'\n", - "state.variables['bias'].other_meta=1\n" - ] - } - ], - "source": [ - "class MetadataParam(nnx.Param):\n", - " def __init__(self, *args, **kwargs):\n", 
- " for key in kwargs:\n", - " setattr(self, key, kwargs[key])\n", - " super().__init__(*args)\n", - "\n", - "\n", - "class AnnotatedLinear(nnx.Module):\n", - " def __init__(self, din, dout, *, rngs):\n", - " self.kernel = TransposedParam(random.uniform(rngs.params(), (din, dout)), meta='foo', other_meta=0)\n", - " self.bias = TransposedParam(jnp.zeros((dout,)), meta='bar', other_meta=1)\n", - "\n", - " def __call__(self, x):\n", - " return x @ self.kernel + self.bias\n", - "\n", - "\n", - "model = AnnotatedLinear(4, 8, rngs=nnx.Rngs(0))\n", - "y = model(jnp.ones((2, 4)))\n", - "\n", - "graphdef, state = model.split()\n", - "\n", - "print(f\"{state.variables['kernel'].meta=}\\n{state.variables['kernel'].other_meta=}\")\n", - "print(f\"{state.variables['bias'].meta=}\\n{state.variables['bias'].other_meta=}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Shape Inference\n", - "\n", - "Shape inference is still possible in NNX using abstract evaluation when it's really needed, it just isn't automatic." - ] - }, - { - "cell_type": "code", - "execution_count": 129, - "metadata": { - "outputId": "942a3788-bcbf-426d-87e6-c5a041172c64" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "State({\n", - " 'encoder/bias': (4,),\n", - " 'encoder/kernel': (3, 3, 3, 4),\n", - " 'linear/bias': (4,),\n", - " 'linear/kernel': (144, 4)\n", - "})" - ] - }, - "execution_count": 129, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def batched_flatten(x):\n", - " return jnp.reshape(x, (x.shape[0], -1))\n", - "\n", - "class Example(nnx.Module):\n", - " def __init__(self, *,\n", - " in_filters=3,\n", - " out_filters=4,\n", - " input_shape=None, # provide an example input size\n", - " rngs):\n", - " self.encoder = nnx.Conv(in_filters, out_filters,\n", - " kernel_size=(3, 3),\n", - " strides=(1, 1),\n", - " padding=\"SAME\",\n", - " rngs=rngs)\n", - " # calculate the flattened shape post-conv using jax.eval_shape\n", - " encoded_shape = jax.eval_shape(\n", - " lambda x: batched_flatten(self.encoder(x)),\n", - " jax.ShapeDtypeStruct(input_shape, jnp.float32)\n", - " ).shape\n", - " # use this shape information to continue initializing\n", - " self.linear = nnx.Linear(encoded_shape[-1], 4, rngs=rngs)\n", - "\n", - " def __call__(self, x):\n", - " x = self.encoder(x)\n", - " x = batched_flatten(x)\n", - " return self.linear(x)\n", - "\n", - "model = Example(in_filters=3,\n", - " out_filters=4,\n", - " input_shape=(2, 6, 6, 3),\n", - " rngs=nnx.Rngs(0))\n", - "\n", - "graphdef, state = model.split()\n", - "jax.tree.map(jnp.shape, state)" - ] - } - ], - "metadata": { - "jupytext": { - "formats": "ipynb,md:myst" - }, - "language_info": { - "name": "python", - "version": "3.10.13" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/docs_nnx/guides/why.md b/docs_nnx/guides/why.md deleted file mode 100644 index b080319b..00000000 --- a/docs_nnx/guides/why.md +++ /dev/null @@ -1,409 +0,0 @@ ---- -jupytext: - formats: ipynb,md:myst - text_representation: - extension: .md - format_name: myst - format_version: 0.13 - jupytext_version: 1.13.8 ---- - -# Why NNX? - - -[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/flax/nnx/docs/why.ipynb) - -Four years ago we developed the Flax "Linen" API to support modeling research on JAX, with a focus on scaling scaling and performance. We've learned a lot from our users over these years. 
- -We introduced some ideas that have proven to be good: - - Organizing variables into [collections](https://flax.readthedocs.io/en/latest/glossary.html#term-Variable-collections) or types to support JAX transforms and segregation of different data types in training loops. - - Automatic and efficient [PRNG management](https://flax.readthedocs.io/en/latest/glossary.html#term-RNG-sequences) (with support for splitting/broadcast control across map transforms) - - [Variable Metadata](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_partitioning) for SPMD annotations, optimizer metadata, and other uses. - -However, one choice we made was to use functional "define by call" semantics for NN programming via the lazy initialization of parameters. This made for concise (`compact`) implementation code, allowed for a single specification when transforming a layer, and aligned our API with Haiku. Lazy initialization meant that the semantics of modules and variables in Flax were non-pythonic and often surprising. It also led to implementation complexity and obscured the core ideas of transformations on neural nets. - -NNX is an attempt to keep the features that made Linen useful while introducing some new principles: - -- Regular Python semantics for Modules, including (within JIT boundaries) support for mutability and shared references. -- A simple API to interact directly with the JAX, this includes the ability to easily implement custom lifted Modules and other purely functional tricks. - -We'd love to hear from any of our users about their thoughts on these ideas. - -[[nnx on github](https://github.com/google/flax/tree/main/flax/nnx)] -[[this doc on github](https://github.com/google/flax/blob/main/flax/nnx/docs/why.ipynb)] - -```{code-cell} -! pip install -U git+https://github.com/google/flax.git -from functools import partial -import jax -from jax import random, numpy as jnp -from flax import nnx -``` - -### NNX is Pythonic -The main feature of NNX Module is that it adheres to Python semantics. This means that: - -* fields are mutable so you can perform inplace updates -* Module references can be shared between multiple Modules -* Module construction implies parameter initialization -* Module methods can be called directly - -```{code-cell} -:outputId: d8ef66d5-6866-4d5c-94c2-d22512bfe718 - -class Count(nnx.Variable): # custom Variable types define the "collections" - pass - - -class CounterLinear(nnx.Module): - def __init__(self, din, dout, *, rngs): # explicit RNG threading - self.linear = nnx.Linear(din, dout, rngs=rngs) - self.count = Count(jnp.zeros((), jnp.int32)) # typed Variable collections - - def __call__(self, x): - self.count.value += 1 # in-place stateful updates - return self.linear(x) - - -model = CounterLinear(4, 4, rngs=nnx.Rngs(0)) # no special `init` method -y = model(jnp.ones((2, 4))) # call methods directly - -print(f'{model = }') -``` - -Because NNX Modules contain their own state, they are very easily to inspect: - -```{code-cell} -:outputId: 10a46b0f-2993-4677-c26d-36a4ddf33449 - -print(f'{model.count = }') -print(f'{model.linear.kernel = }') -``` - -#### Intuitive Surgery - -In NNX surgery can be done at the Module level by simply updating / replacing existing fields. 
- -```{code-cell} -:outputId: e6f86be8-3537-4c48-f471-316ee0fb6c45 - -# pretend this came from a checkpoint or elsewhere: -pretrained_weight = random.uniform(random.key(0), (4, 4)) - -# you can replace weights directly -model.linear.kernel = pretrained_weight -y = model(jnp.ones((2, 4))) -y -``` - -```{code-cell} -:outputId: 5190ac7b-12f7-4400-d5bb-f91b97a557b6 - -def load_pretrained_fragment(): - # pretend this inits / loads some fragment of a model - replacement = nnx.Linear(4, 4, rngs=nnx.Rngs(1)) - return replacement - -# you can replace modules directly -model.linear = load_pretrained_fragment() -y = model(jnp.ones((2, 4))) -y -``` - -Not only is this easier than messing with dictionary structures and aligning that with code changes, but one can even replace a field with a completely different Module type, or even change the architecture (e.g. share two Modules that were not shared before). - -```{code-cell} -rngs = nnx.Rngs(0) -model = nnx.Sequence( - [ - nnx.Conv(1, 16, [3, 3], padding='SAME', rngs=rngs), - partial(nnx.max_pool, window_shape=(2, 2), strides=(2, 2)), - nnx.Conv(16, 32, [3, 3], padding='SAME', rngs=rngs), - partial(nnx.max_pool, window_shape=(2, 2), strides=(2, 2)), - lambda x: x.reshape((x.shape[0], -1)), # flatten - nnx.Linear(32 * 7 * 7, 10, rngs=rngs), - ] -) - -y = model(jnp.ones((2, 28, 28, 1))) - -# Do some weird surgery of the stack: -for i, layer in enumerate(model): - if isinstance(layer, nnx.Conv): - model[i] = nnx.Linear(layer.in_features, layer.out_features, rngs=rngs) - -y = model(jnp.ones((2, 28, 28, 1))) -``` - -Note that here we are replacing `Conv` with `Linear` as a silly example, but in reality you would do things like replacing a layer with its quantized version, or changing a layer with an optimized version, etc. - -+++ - -### Interacting with JAX is easy - -While NNX Modules inherently follow reference semantics, they can be easily converted into a pure functional representation that can be used with JAX transformations and other value-based, functional code. - -NNX has two very simple APIs to interact with JAX: `split` and `merge`. - -The `Module.split` method allows you to convert into a `State` dict-like object that contains the dynamic state of the Module, and a `GraphDef` object that contains the graphdef structure of the Module. - -```{code-cell} -:outputId: 9a3f378b-739e-4f45-9968-574651200ede - -model = CounterLinear(4, 4, rngs=nnx.Rngs(0)) - -graphdef, state = model.split() - -# state is a dictionary-like JAX pytree -print(f'{state = }') - -# graphdef is also a JAX pytree, but containing no data, just metadata -print(f'\n{graphdef = }') -``` - -The `GraphDef.merge` method allows you to take a `GraphDef` and one or more `State` objects and merge them back into a `Module` object. - -Using `split` and `merge` in conjunction allows you to carry your Module in and out of any JAX transformation. Here is a simple jitted `forward` function as an example: - -```{code-cell} -:outputId: 0007d357-152a-449e-bcb9-b1b5a91d2d8d - -@jax.jit -def forward(graphdef: nnx.GraphDef, state: nnx.State, x: jax.Array): - model = graphdef.merge(state) - y = model(x) - state, _ = model.split() - return y, state - -x = jnp.ones((2, 4)) -y, state = forward(graphdef,state, x) - -print(f'{y.shape = }') -print(f'{state["count"] = }') -``` - -#### Custom lifting and transformation - -By using the same mechanism inside Module methods you can implement lifted Modules, that is, Modules that use a JAX transformation to have a distinct behavior. 
- -One of Linen's current pain points is that it is not easy to interact with JAX transformations that are not currently supported by the framework. NNX makes it very easy to implement custom lifted Modules or bespoke custom functional transforms for specific use cases. - -As an example here we will create a `LinearEnsemble` Module that uses `jax.vmap` both during `__init__` and `__call__` to vectorize the computation over multiple `CounterLinear` models (defined above). The example is a little bit longer, but notice how each method conceptually very simple. - -It uses the single additional method `update` to locally modify model state. - -```{code-cell} -:outputId: fdd212d7-4994-4fa5-d922-5a7d7cfad3e3 - -class LinearEnsemble(nnx.Module): - def __init__(self, din, dout, *, num_models, rngs: nnx.Rngs): - # get raw rng seeds - keys = rngs.fork(num_models) # split all keys into `num_models` - - # define pure init fn and vmap - def vmap_init(keys): - return CounterLinear(din, dout, rngs=nnx.Rngs(keys)).split( - nnx.Param, Count - ) - params, counts, graphdef = jax.vmap( - vmap_init, in_axes=(0,), out_axes=(0, None, None) - )(keys) - - # update wrapped submodule reference - self.models = graphdef.merge(params, counts) - - def __call__(self, x): - # get module values, define pure fn, - # notice that we split the data into two collections by their types. - params, counts, graphdef = self.models.split(nnx.Param, Count) - - # define pure init fn and vmap - def vmap_apply(x, params, counts, graphdef): - model = graphdef.merge(params, counts) - y = model(x) - params, counts, graphdef = model.split(nnx.Param, Count) - return y, params, counts, graphdef - - y, params, counts, graphdef = jax.vmap( - vmap_apply, - in_axes=(None, 0, None, None), - out_axes=(0, 0, None, None) - )(x, params, counts, graphdef) - - # update wrapped module - # uses `update` to integrate the new state - self.models.update(params, counts, graphdef) - return y - -x = jnp.ones((4,)) -ensemble = LinearEnsemble(4, 4, num_models=8, rngs=nnx.Rngs(0)) - -# forward pass -y = ensemble(x) - -print(f'{y.shape = }') -print(f'{ensemble.models.count = }') -print(f'state = {jax.tree.map(jnp.shape, ensemble.get_state())}') -``` - -#### Convenience lifted transforms - -+++ - -Like linen, for convenience we still provide simple lifted transforms for standard JAX transforms, usable as class transforms and decorators. We've endeavored to simplify the API for scan and vmap compared to the flax specifications. - -```{code-cell} -:outputId: c4800a49-efd1-4ee5-e703-6e63e18da4cb - -# class transform: -ScannedLinear = nnx.Scan.constructor(nnx.Linear, variable_axes={nnx.Param: 0}, length=4) - -scanned = ScannedLinear(2, 2, rngs=nnx.Rngs(0)) -scanned.get_state() -``` - -```{code-cell} -:outputId: 9efd6e71-d180-4674-ade0-2b02057a400b - -# method decorators: - -class ScannedLinear(nnx.Module): - - @partial(nnx.scan, variable_axes={nnx.Param: 0}, length=4) - def __init__(self, din, dout, *, rngs: nnx.Rngs): - self.model = nnx.Linear(din, dout, rngs=nnx.Rngs(rngs)) - - @partial(nnx.scan, variable_axes={nnx.Param: 0}, length=4) - def __call__(self, x): - return self.model(x) - -scanned = ScannedLinear(2, 2, rngs=nnx.Rngs(0)) -scanned.get_state() -``` - -#### Aside: Why aren't Modules Pytrees? - -A common questions is why aren't NNX Modules registered as Pytrees? (in the style of Equinox, Treex, PytreeClass, etc.) It _is_ trivial to define a pytree registration in terms of `split`/`merge`. 
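For concreteness, a minimal sketch of such a registration, assuming the `split`/`merge` API used in this document together with `jax.tree_util.register_pytree_node` (the `register_module_as_pytree` helper is hypothetical, and it is exactly what NNX chooses not to do, for the reasons that follow):

```python
import jax

def register_module_as_pytree(cls):  # hypothetical helper, for illustration only
  def flatten(module):
    # Dynamic state becomes the pytree children, the static GraphDef the aux data.
    graphdef, state = module.split()
    return (state,), graphdef

  def unflatten(graphdef, children):
    # Rebuild the Module from GraphDef + State.
    (state,) = children
    return graphdef.merge(state)

  jax.tree_util.register_pytree_node(cls, flatten, unflatten)
  return cls
```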
- -The problem is that Pytrees impose value semantics (referencial transparency) while Modules assume reference semantics, and therefore it is dangerous in general to automatically treat Modules as Pytrees. - -As an example, lets take a look at what would happen if we allowed this very simple program to be valid: - -```{code-cell} -@jax.jit -def f(m1: nnx.Module, m2: nnx.Module): - return m1, m2 -``` - -Here we are just creating a jitted function `f` that takes in two Modules `(m1, m2)` and returns them as is. What could go wrong? - -There are two main problems with this: -* Shared references are not maintained, that is, if `m1.shared` is the same as `m2.shared` outside `f`, this will NOT be true both inside `f`, and at the output of `f`. -* Even if you accept this fact and added code to compensate for this, `f` would now behave differently depending on whether its being `jit`ted or not, this is an undesirable asymmetry and `jit` would no longer be a no-op. - -+++ - -### Standardized "Hooks" - -NNX introduces a standard getter/setter/creator interface for custom variables (similar to Haiku hooks). This is used internally to support SPMD metadata for managing sharding information, but is available for user-defined applications. - -```{code-cell} -:outputId: c4e6586a-bfe0-4f26-d05b-8c9e395971b2 - -class TransposedParam(nnx.Variable): - def create_value(self, value): - return value.T # called on variable creation to transform initial value - def get_value(self): - return self.value.T # called when value fetched via module getattr - def set_value(self, value): - return self.replace(value=value.T) # called when setting value from module setattr - - -class OddLinear(nnx.Module): - def __init__(self, din, dout, *, rngs): - self.kernel = TransposedParam(random.uniform(rngs.params(), (din, dout))) - self.bias = nnx.Param(jnp.zeros((dout,))) - - def __call__(self, x): - print(f'{self.kernel.shape = }') - return x @ self.kernel + self.bias - - -model = OddLinear(4, 8, rngs=nnx.Rngs(0)) -y = model(jnp.ones((2, 4))) - -print(f'outer kernel shape = {model.split()[0]["kernel"].shape}') -``` - -SPMD metadata is handled using `nnx.with_partitioning` helpers, but it's easy to add one's own metadata schema: - -```{code-cell} -:outputId: ef312738-0f56-4c0e-9aaf-3319d131f1a2 - -class MetadataParam(nnx.Param): - def __init__(self, *args, **kwargs): - for key in kwargs: - setattr(self, key, kwargs[key]) - super().__init__(*args) - - -class AnnotatedLinear(nnx.Module): - def __init__(self, din, dout, *, rngs): - self.kernel = TransposedParam(random.uniform(rngs.params(), (din, dout)), meta='foo', other_meta=0) - self.bias = TransposedParam(jnp.zeros((dout,)), meta='bar', other_meta=1) - - def __call__(self, x): - return x @ self.kernel + self.bias - - -model = AnnotatedLinear(4, 8, rngs=nnx.Rngs(0)) -y = model(jnp.ones((2, 4))) - -graphdef, state = model.split() - -print(f"{state.variables['kernel'].meta=}\n{state.variables['kernel'].other_meta=}") -print(f"{state.variables['bias'].meta=}\n{state.variables['bias'].other_meta=}") -``` - -## Shape Inference - -Shape inference is still possible in NNX using abstract evaluation when it's really needed, it just isn't automatic. 
- -```{code-cell} -:outputId: 942a3788-bcbf-426d-87e6-c5a041172c64 - -def batched_flatten(x): - return jnp.reshape(x, (x.shape[0], -1)) - -class Example(nnx.Module): - def __init__(self, *, - in_filters=3, - out_filters=4, - input_shape=None, # provide an example input size - rngs): - self.encoder = nnx.Conv(in_filters, out_filters, - kernel_size=(3, 3), - strides=(1, 1), - padding="SAME", - rngs=rngs) - # calculate the flattened shape post-conv using jax.eval_shape - encoded_shape = jax.eval_shape( - lambda x: batched_flatten(self.encoder(x)), - jax.ShapeDtypeStruct(input_shape, jnp.float32) - ).shape - # use this shape information to continue initializing - self.linear = nnx.Linear(encoded_shape[-1], 4, rngs=rngs) - - def __call__(self, x): - x = self.encoder(x) - x = batched_flatten(x) - return self.linear(x) - -model = Example(in_filters=3, - out_filters=4, - input_shape=(2, 6, 6, 3), - rngs=nnx.Rngs(0)) - -graphdef, state = model.split() -jax.tree.map(jnp.shape, state) -``` diff --git a/docs_nnx/index.rst b/docs_nnx/index.rst index ce1b81b2..87d584a6 100644 --- a/docs_nnx/index.rst +++ b/docs_nnx/index.rst @@ -1,6 +1,6 @@ Flax -======== +==== .. div:: sd-text-left sd-font-italic **N**\ eural **N**\ etworks for JA\ **X** @@ -8,20 +8,18 @@ Flax ---- -Flax delivers an **end-to-end and flexible user experience for researchers -who use JAX with neural networks**. Flax -exposes the full power of `JAX `__. +Flax provides a **flexible end-to-end user experience for researchers and developers who use JAX for neural networks**. Flax enables you to use the full power of `JAX `__. -At its core is **Flax NNX, a simplified API that makes it easier to create, inspect, -debug, and analyze neural networks in JAX.** It has first class support -for Python reference semantics, allowing users to express their models using regular -Python objects. Flax NNX is an evolution of the previous Flax Linen APIs, and it took years of -experience to bring a simpler and more user-friendly experience. +At the core of Flax is **NNX - a simplified API that makes it easier to create, inspect, +debug, and analyze neural networks in JAX.** Flax NNX has first class support +for Python reference semantics, enabling users to express their models using regular +Python objects. Flax NNX is an evolution of the previous `Flax Linen `__ +API, and it took years of experience to bring a simpler and more user-friendly API. .. note:: - Flax Linen API is not going to be deprecated in the near future as most of our users still - rely on this API, however new users are encouraged to use Flax NNX. - For existing Linen users to move to NNX, check out the `evolution guide `_. + Flax Linen API is not going to be deprecated in the near future as most of Flax users still rely on this API. However, new users are encouraged to use Flax NNX. Check out `Why Flax NNX `_ for a comparison between Flax NNX and Linen, and our reasoning to make the new API. + + To move your Flax Linen codebase to Flax NNX, get familiarized with the API in `NNX Basics `_ and then start your move following the `evolution guide `_. Features ^^^^^^^^^ @@ -177,14 +175,14 @@ Learn more .. card:: :material-regular:`menu_book;2em` API reference :class-card: sd-text-black sd-bg-light - :link: /api_reference/index.html + :link: api_reference/index.html .. grid-item:: :columns: 6 6 6 4 .. 
card:: :material-regular:`import_contacts;2em` Glossary :class-card: sd-text-black sd-bg-light - :link: glossary.html + :link: nnx_glossary.html ---- @@ -195,9 +193,10 @@ Learn more nnx_basics mnist_tutorial + why guides/index examples/index - glossary - The Flax philosophy + nnx_glossary + The Flax philosophy How to contribute api_reference/index diff --git a/docs_nnx/mnist_tutorial.ipynb b/docs_nnx/mnist_tutorial.ipynb index a01acbcf..147ed07b 100644 --- a/docs_nnx/mnist_tutorial.ipynb +++ b/docs_nnx/mnist_tutorial.ipynb @@ -5,15 +5,16 @@ "id": "0", "metadata": {}, "source": [ - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/nnx/mnist_tutorial.ipynb)\n", - "[![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/nnx/mnist_tutorial.ipynb)\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs_nnx/mnist_tutorial.ipynb)\n", + "[![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs_nnx/mnist_tutorial.ipynb)\n", "\n", - "# MNIST Tutorial\n", + "# MNIST tutorial\n", "\n", - "Welcome to Flax NNX! This tutorial will guide you through building and training a simple convolutional\n", - "neural network (CNN) on the MNIST dataset using the Flax NNX API. Flax NNX is a Python neural network library\n", - "built upon [JAX](https://github.com/google/jax) and currently offered as an experimental module within\n", - "[Flax](https://github.com/google/flax)." + "Welcome to Flax NNX! In this tutorial you will learn how to build and train a simple convolutional neural network (CNN) to classify handwritten digits on the MNIST dataset using the Flax NNX API.\n", + "\n", + "Flax NNX is a Python neural network library built upon [JAX](https://github.com/google/jax). If you have used the Flax Linen API before, check out [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html). You should have some knowledge of the main concepts of deep learning.\n", + "\n", + "Let’s get started!" ] }, { @@ -23,8 +24,7 @@ "source": [ "## 1. Install Flax\n", "\n", - "If `flax` is not installed in your environment, you can install it from PyPI, uncomment and run the\n", - "following cell:" + "If `flax` is not installed in your Python environment, use `pip` to install the package from PyPI (below, just uncomment the code in the cell if you are working from Google Colab/Jupyter Notebook):" ] }, { @@ -46,11 +46,9 @@ "id": "3", "metadata": {}, "source": [ - "## 2. Load the MNIST Dataset\n", + "## 2. Load the MNIST dataset\n", "\n", - "First, the MNIST dataset is loaded and prepared for training and testing using\n", - "Tensorflow Datasets. Image values are normalized, the data is shuffled and divided\n", - "into batches, and samples are prefetched to enhance performance." + "First, you need to load the MNIST dataset and then prepare the training and testing sets via Tensorflow Datasets (TFDS). You normalize image values, shuffle the data and divide it into batches, and prefetch samples to enhance performance." 
] }, { @@ -72,10 +70,10 @@ } ], "source": [ - "import tensorflow_datasets as tfds # TFDS for MNIST\n", - "import tensorflow as tf # TensorFlow operations\n", + "import tensorflow_datasets as tfds # TFDS to download MNIST.\n", + "import tensorflow as tf # TensorFlow / `tf.data` operations.\n", "\n", - "tf.random.set_seed(0) # set random seed for reproducibility\n", + "tf.random.set_seed(0) # Set the random seed for reproducibility.\n", "\n", "train_steps = 1200\n", "eval_every = 200\n", @@ -95,13 +93,13 @@ " 'image': tf.cast(sample['image'], tf.float32) / 255,\n", " 'label': sample['label'],\n", " }\n", - ") # normalize test set\n", + ") # Normalize the test set.\n", "\n", - "# create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from\n", + "# Create a shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from.\n", "train_ds = train_ds.repeat().shuffle(1024)\n", - "# group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency\n", + "# Group into batches of `batch_size` and skip incomplete batches, prefetch the next sample to improve latency.\n", "train_ds = train_ds.batch(batch_size, drop_remainder=True).take(train_steps).prefetch(1)\n", - "# group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency\n", + "# Group into batches of `batch_size` and skip incomplete batches, prefetch the next sample to improve latency.\n", "test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1)" ] }, @@ -110,9 +108,9 @@ "id": "5", "metadata": {}, "source": [ - "## 3. Define the Network with Flax NNX\n", + "## 3. Define the model with Flax NNX\n", "\n", - "Create a convolutional neural network with Flax NNX by subclassing `nnx.Module`." + "Create a CNN for classification with Flax NNX by subclassing `nnx.Module`:" ] }, { @@ -135,7 +133,7 @@ } ], "source": [ - "from flax import nnx # Flax NNX API\n", + "from flax import nnx # The Flax NNX API.\n", "from functools import partial\n", "\n", "class CNN(nnx.Module):\n", @@ -156,7 +154,9 @@ " x = self.linear2(x)\n", " return x\n", "\n", + "# Instantiate the model.\n", "model = CNN(rngs=nnx.Rngs(0))\n", + "# Visualize it.\n", "nnx.display(model)" ] }, @@ -165,9 +165,9 @@ "id": "7", "metadata": {}, "source": [ - "### Run model\n", + "### Run the model\n", "\n", - "Let's put our model to the test! We'll perform a forward pass with arbitrary data and print the results." + "Let's put the CNN model to the test! Here, you’ll perform a forward pass with arbitrary data and print the results." ] }, { @@ -203,9 +203,9 @@ "id": "9", "metadata": {}, "source": [ - "## 4. Create Optimizer and Metrics\n", + "## 4. Create the optimizer and define some metrics\n", "\n", - "In Flax NNX, we create an `Optimizer` object to manage the model's parameters and apply gradients during training. `Optimizer` receives the model's reference so it can update its parameters, and an `optax` optimizer to define the update rules. Additionally, we'll define a `MultiMetric` object to keep track of the `Accuracy` and the `Average` loss." + "In Flax NNX, you need to create an `nnx.Optimizer` object to manage the model's parameters and apply gradients during training. `nnx.Optimizer` receives the model's reference, so that it can update its parameters, and an [Optax](https://optax.readthedocs.io/) optimizer to define the update rules. Additionally, you will define an `nnx.MultiMetric` object to keep track of the `Accuracy` and the `Average` loss." 
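A rough sketch of the two objects described above, assuming the `model` instance created in step 3 and illustrative hyperparameter values (not taken from this change):

```python
import optax
from flax import nnx

learning_rate = 0.005  # illustrative value
momentum = 0.9         # illustrative value

# `model` is the CNN instantiated in step 3 above.
optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum))
metrics = nnx.MultiMetric(
  accuracy=nnx.metrics.Accuracy(),
  loss=nnx.metrics.Average('loss'),
)
```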
] }, { @@ -247,9 +247,13 @@ "id": "13", "metadata": {}, "source": [ - "## 5. Define step functions\n", + "## 5. Define training step functions\n", + "\n", + "In this section, you will define a loss function using the cross entropy loss ([`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.softmax_cross_entropy_with_integer_labels)) that the CNN model will optimize over.\n", "\n", - "We define a loss function using cross entropy loss (see more details in [`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.softmax_cross_entropy_with_integer_labels)) that our model will optimize over. In addition to the loss, the logits are also outputted since they will be used to calculate the accuracy metric during training and testing. During training, we'll use `nnx.value_and_grad` to compute the gradients and update the model's parameters using the optimizer. During both training and testing, the loss and logits are used to calculate the metrics." + "In addition to the `loss`, during training and testing you will also get the `logits`, which will be used to calculate the accuracy metric. \n", + "\n", + "During training - the `train_step` - you will use `nnx.value_and_grad` to compute the gradients and update the model's parameters using the `optimizer` you have already defined. And during both training and testing (the `eval_step`), the `loss` and `logits` will be used to calculate the metrics." ] }, { @@ -271,13 +275,13 @@ " \"\"\"Train for a single step.\"\"\"\n", " grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)\n", " (loss, logits), grads = grad_fn(model, batch)\n", - " metrics.update(loss=loss, logits=logits, labels=batch['label']) # inplace updates\n", - " optimizer.update(grads) # inplace updates\n", + " metrics.update(loss=loss, logits=logits, labels=batch['label']) # In-place updates.\n", + " optimizer.update(grads) # In-place updates.\n", "\n", "@nnx.jit\n", "def eval_step(model: CNN, metrics: nnx.MultiMetric, batch):\n", " loss, logits = loss_fn(model, batch)\n", - " metrics.update(loss=loss, logits=logits, labels=batch['label']) # inplace updates" + " metrics.update(loss=loss, logits=logits, labels=batch['label']) # In-place updates." ] }, { @@ -285,12 +289,9 @@ "id": "17", "metadata": {}, "source": [ - "The [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) decorator traces the `train_step` function for just-in-time compilation with\n", - "[XLA](https://www.tensorflow.org/xla), optimizing performance on\n", - "hardware accelerators. `nnx.jit` is similar to [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit),\n", - "except it can transforms functions that contain Flax NNX objects as inputs and outputs.\n", + "In the code above, the [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) transformation decorator traces the `train_step` function for just-in-time compilation with [XLA](https://www.tensorflow.org/xla), optimizing performance on hardware accelerators, such as Google TPUs and GPUs. `nnx.jit` is a \"lifted\" version of the `jax.jit` transform that allows its function input and outputs to be Flax NNX objects. Similarly, `nnx.value_and_grad ` is a lifted version of `jax.value_and_grad `. 
Check out [the lifted transforms guide](https://flax.readthedocs.io/en/latest/guides/transforms.html) to learn more.\n", "\n", - "**NOTE**: in the above code we performed serveral inplace updates to the model, optimizer, and metrics, and we did not explicitely return the state updates. This is because Flax NNX transforms respect reference semantics for Flax NNX objects, and will propagate the state updates of the objects passed as input arguments. This is a key feature of Flax NNX that allows for a more concise and readable code." + "> **Note:** The code shows how to perform several in-place updates to the model, the optimizer, and the metrics, but _state updates_ were not explicitly returned. This is because Flax NNX transformations respect _reference semantics_ for Flax NNX objects, and will propagate the state updates of the objects passed as input arguments. This is a key feature of Flax NNX that allows for a more concise and readable code. You can learn more in [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html)." ] }, { @@ -298,11 +299,11 @@ "id": "21", "metadata": {}, "source": [ - "## 6. Train and Evaluate\n", + "## 6. Train and evaluate the model\n", "\n", - "Now we train a model using batches of data for 10 epochs, evaluate its performance\n", - "on the test set after each epoch, and log the training and testing metrics (loss and\n", - "accuracy) throughout the process. Typically this leads to a model with around 99% accuracy." + "Now, you can train the CNN model using batches of data for 10 epochs, evaluate the model’s performance\n", + "on the test set after each epoch, and log the training and testing metrics (the loss and\n", + "the accuracy) during the process. Typically this leads to the model achieving around 99% accuracy." 
] }, { @@ -422,25 +423,25 @@ "\n", "for step, batch in enumerate(train_ds.as_numpy_iterator()):\n", " # Run the optimization for one step and make a stateful update to the following:\n", - " # - the train state's model parameters\n", - " # - the optimizer state\n", - " # - the training loss and accuracy batch metrics\n", + " # - The train state's model parameters\n", + " # - The optimizer state\n", + " # - The training loss and accuracy batch metrics\n", " train_step(model, optimizer, metrics, batch)\n", "\n", - " if step > 0 and (step % eval_every == 0 or step == train_steps - 1): # one training epoch has passed\n", - " # Log training metrics\n", - " for metric, value in metrics.compute().items(): # compute metrics\n", - " metrics_history[f'train_{metric}'].append(value) # record metrics\n", - " metrics.reset() # reset metrics for test set\n", + " if step > 0 and (step % eval_every == 0 or step == train_steps - 1): # One training epoch has passed.\n", + " # Log the training metrics.\n", + " for metric, value in metrics.compute().items(): # Compute the metrics.\n", + " metrics_history[f'train_{metric}'].append(value) # Record the metrics.\n", + " metrics.reset() # Reset the metrics for the test set.\n", "\n", - " # Compute metrics on the test set after each training epoch\n", + " # Compute the metrics on the test set after each training epoch.\n", " for test_batch in test_ds.as_numpy_iterator():\n", " eval_step(model, metrics, test_batch)\n", "\n", - " # Log test metrics\n", + " # Log the test metrics.\n", " for metric, value in metrics.compute().items():\n", " metrics_history[f'test_{metric}'].append(value)\n", - " metrics.reset() # reset metrics for next training epoch\n", + " metrics.reset() # Reset the metrics for the next training epoch.\n", "\n", " print(\n", " f\"[train] step: {step}, \"\n", @@ -459,9 +460,9 @@ "id": "23", "metadata": {}, "source": [ - "## 7. Visualize Metrics\n", + "## 7. Visualize the metrics\n", "\n", - "Use Matplotlib to create plots for loss and accuracy." + "With Matplotlib, you can create plots for the loss and the accuracy:" ] }, { @@ -503,9 +504,9 @@ "id": "25", "metadata": {}, "source": [ - "## 10. Perform inference on test set\n", + "## 10. Perform inference on the test set\n", "\n", - "Define a jitted inference function, `pred_step`, to generate predictions on the test set using the learned model parameters. This will enable you to visualize test images alongside their predicted labels for a qualitative assessment of model performance." + "Create a `jit`-compiled model inference function (with `nnx.jit`) - `pred_step` - to generate predictions on the test set using the learned model parameters. This will enable you to visualize test images alongside their predicted labels for a qualitative assessment of model performance." ] }, { @@ -515,12 +516,22 @@ "metadata": {}, "outputs": [], "source": [ + "model.eval() # Switch to evaluation mode.\n", + "\n", "@nnx.jit\n", "def pred_step(model: CNN, batch):\n", " logits = model(batch['image'])\n", " return logits.argmax(axis=1)" ] }, + { + "cell_type": "markdown", + "id": "1d6cb81f", + "metadata": {}, + "source": [ + "Note that we use `.eval()` to ensure that the model is in evaluation mode, even though we are not using `Dropout` or `BatchNorm` in this model, `.eval()` ensure that the outputs are deterministic." + ] + }, { "cell_type": "code", "execution_count": 10, @@ -556,7 +567,9 @@ "id": "28", "metadata": {}, "source": [ - "Congratulations! You made it to the end of the annotated MNIST example." 
+ "Congratulations! You have learned how to use Flax NNX to build and train a simple classification model end-to-end on the MNIST dataset.\n", + "\n", + "Next, check out [Why Flax NNX?](https://flax.readthedocs.io/en/latest/why.html) and get started with a series of [Flax NNX Guides](https://flax.readthedocs.io/en/latest/guides/index.html)." ] } ], diff --git a/docs_nnx/mnist_tutorial.md b/docs_nnx/mnist_tutorial.md index 74039533..74be0ec5 100644 --- a/docs_nnx/mnist_tutorial.md +++ b/docs_nnx/mnist_tutorial.md @@ -9,22 +9,22 @@ jupytext: jupytext_version: 1.13.8 --- -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/nnx/mnist_tutorial.ipynb) -[![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/nnx/mnist_tutorial.ipynb) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs_nnx/mnist_tutorial.ipynb) +[![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs_nnx/mnist_tutorial.ipynb) -# MNIST Tutorial +# MNIST tutorial -Welcome to Flax NNX! This tutorial will guide you through building and training a simple convolutional -neural network (CNN) on the MNIST dataset using the Flax NNX API. Flax NNX is a Python neural network library -built upon [JAX](https://github.com/google/jax) and currently offered as an experimental module within -[Flax](https://github.com/google/flax). +Welcome to Flax NNX! In this tutorial you will learn how to build and train a simple convolutional neural network (CNN) to classify handwritten digits on the MNIST dataset using the Flax NNX API. + +Flax NNX is a Python neural network library built upon [JAX](https://github.com/google/jax). If you have used the Flax Linen API before, check out [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html). You should have some knowledge of the main concepts of deep learning. + +Let’s get started! +++ ## 1. Install Flax -If `flax` is not installed in your environment, you can install it from PyPI, uncomment and run the -following cell: +If `flax` is not installed in your Python environment, use `pip` to install the package from PyPI (below, just uncomment the code in the cell if you are working from Google Colab/Jupyter Notebook): ```{code-cell} ipython3 :tags: [skip-execution] @@ -32,17 +32,15 @@ following cell: # !pip install flax ``` -## 2. Load the MNIST Dataset +## 2. Load the MNIST dataset -First, the MNIST dataset is loaded and prepared for training and testing using -Tensorflow Datasets. Image values are normalized, the data is shuffled and divided -into batches, and samples are prefetched to enhance performance. +First, you need to load the MNIST dataset and then prepare the training and testing sets via Tensorflow Datasets (TFDS). You normalize image values, shuffle the data and divide it into batches, and prefetch samples to enhance performance. ```{code-cell} ipython3 -import tensorflow_datasets as tfds # TFDS for MNIST -import tensorflow as tf # TensorFlow operations +import tensorflow_datasets as tfds # TFDS to download MNIST. +import tensorflow as tf # TensorFlow / `tf.data` operations. -tf.random.set_seed(0) # set random seed for reproducibility +tf.random.set_seed(0) # Set the random seed for reproducibility. 
train_steps = 1200 eval_every = 200 @@ -62,22 +60,22 @@ test_ds = test_ds.map( 'image': tf.cast(sample['image'], tf.float32) / 255, 'label': sample['label'], } -) # normalize test set +) # Normalize the test set. -# create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from +# Create a shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from. train_ds = train_ds.repeat().shuffle(1024) -# group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency +# Group into batches of `batch_size` and skip incomplete batches, prefetch the next sample to improve latency. train_ds = train_ds.batch(batch_size, drop_remainder=True).take(train_steps).prefetch(1) -# group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency +# Group into batches of `batch_size` and skip incomplete batches, prefetch the next sample to improve latency. test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1) ``` -## 3. Define the Network with Flax NNX +## 3. Define the model with Flax NNX -Create a convolutional neural network with Flax NNX by subclassing `nnx.Module`. +Create a CNN for classification with Flax NNX by subclassing `nnx.Module`: ```{code-cell} ipython3 -from flax import nnx # Flax NNX API +from flax import nnx # The Flax NNX API. from functools import partial class CNN(nnx.Module): @@ -98,13 +96,15 @@ class CNN(nnx.Module): x = self.linear2(x) return x +# Instantiate the model. model = CNN(rngs=nnx.Rngs(0)) +# Visualize it. nnx.display(model) ``` -### Run model +### Run the model -Let's put our model to the test! We'll perform a forward pass with arbitrary data and print the results. +Let's put the CNN model to the test! Here, you’ll perform a forward pass with arbitrary data and print the results. ```{code-cell} ipython3 :outputId: 2c580f41-bf5d-40ec-f1cf-ab7f319a84da @@ -115,9 +115,9 @@ y = model(jnp.ones((1, 28, 28, 1))) nnx.display(y) ``` -## 4. Create Optimizer and Metrics +## 4. Create the optimizer and define some metrics -In Flax NNX, we create an `Optimizer` object to manage the model's parameters and apply gradients during training. `Optimizer` receives the model's reference so it can update its parameters, and an `optax` optimizer to define the update rules. Additionally, we'll define a `MultiMetric` object to keep track of the `Accuracy` and the `Average` loss. +In Flax NNX, you need to create an `nnx.Optimizer` object to manage the model's parameters and apply gradients during training. `nnx.Optimizer` receives the model's reference, so that it can update its parameters, and an [Optax](https://optax.readthedocs.io/) optimizer to define the update rules. Additionally, you will define an `nnx.MultiMetric` object to keep track of the `Accuracy` and the `Average` loss. ```{code-cell} ipython3 import optax @@ -134,9 +134,13 @@ metrics = nnx.MultiMetric( nnx.display(optimizer) ``` -## 5. Define step functions +## 5. Define training step functions + +In this section, you will define a loss function using the cross entropy loss ([`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.softmax_cross_entropy_with_integer_labels)) that the CNN model will optimize over. + +In addition to the `loss`, during training and testing you will also get the `logits`, which will be used to calculate the accuracy metric. 
-We define a loss function using cross entropy loss (see more details in [`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.softmax_cross_entropy_with_integer_labels)) that our model will optimize over. In addition to the loss, the logits are also outputted since they will be used to calculate the accuracy metric during training and testing. During training, we'll use `nnx.value_and_grad` to compute the gradients and update the model's parameters using the optimizer. During both training and testing, the loss and logits are used to calculate the metrics. +During training - the `train_step` - you will use `nnx.value_and_grad` to compute the gradients and update the model's parameters using the `optimizer` you have already defined. And during both training and testing (the `eval_step`), the `loss` and `logits` will be used to calculate the metrics. ```{code-cell} ipython3 def loss_fn(model: CNN, batch): @@ -151,29 +155,26 @@ def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, b """Train for a single step.""" grad_fn = nnx.value_and_grad(loss_fn, has_aux=True) (loss, logits), grads = grad_fn(model, batch) - metrics.update(loss=loss, logits=logits, labels=batch['label']) # inplace updates - optimizer.update(grads) # inplace updates + metrics.update(loss=loss, logits=logits, labels=batch['label']) # In-place updates. + optimizer.update(grads) # In-place updates. @nnx.jit def eval_step(model: CNN, metrics: nnx.MultiMetric, batch): loss, logits = loss_fn(model, batch) - metrics.update(loss=loss, logits=logits, labels=batch['label']) # inplace updates + metrics.update(loss=loss, logits=logits, labels=batch['label']) # In-place updates. ``` -The [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) decorator traces the `train_step` function for just-in-time compilation with -[XLA](https://www.tensorflow.org/xla), optimizing performance on -hardware accelerators. `nnx.jit` is similar to [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit), -except it can transforms functions that contain Flax NNX objects as inputs and outputs. +In the code above, the [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) transformation decorator traces the `train_step` function for just-in-time compilation with [XLA](https://www.tensorflow.org/xla), optimizing performance on hardware accelerators, such as Google TPUs and GPUs. `nnx.jit` is a "lifted" version of the `jax.jit` transform that allows its function input and outputs to be Flax NNX objects. Similarly, `nnx.value_and_grad ` is a lifted version of `jax.value_and_grad `. Check out [the lifted transforms guide](https://flax.readthedocs.io/en/latest/guides/transforms.html) to learn more. -**NOTE**: in the above code we performed serveral inplace updates to the model, optimizer, and metrics, and we did not explicitely return the state updates. This is because Flax NNX transforms respect reference semantics for Flax NNX objects, and will propagate the state updates of the objects passed as input arguments. This is a key feature of Flax NNX that allows for a more concise and readable code. +> **Note:** The code shows how to perform several in-place updates to the model, the optimizer, and the metrics, but _state updates_ were not explicitly returned. 
This is because Flax NNX transformations respect _reference semantics_ for Flax NNX objects, and will propagate the state updates of the objects passed as input arguments. This is a key feature of Flax NNX that allows for a more concise and readable code. You can learn more in [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html). +++ -## 6. Train and Evaluate +## 6. Train and evaluate the model -Now we train a model using batches of data for 10 epochs, evaluate its performance -on the test set after each epoch, and log the training and testing metrics (loss and -accuracy) throughout the process. Typically this leads to a model with around 99% accuracy. +Now, you can train the CNN model using batches of data for 10 epochs, evaluate the model’s performance +on the test set after each epoch, and log the training and testing metrics (the loss and +the accuracy) during the process. Typically this leads to the model achieving around 99% accuracy. ```{code-cell} ipython3 :outputId: 258a2c76-2c8f-4a9e-d48b-dde57c342a87 @@ -187,25 +188,25 @@ metrics_history = { for step, batch in enumerate(train_ds.as_numpy_iterator()): # Run the optimization for one step and make a stateful update to the following: - # - the train state's model parameters - # - the optimizer state - # - the training loss and accuracy batch metrics + # - The train state's model parameters + # - The optimizer state + # - The training loss and accuracy batch metrics train_step(model, optimizer, metrics, batch) - if step > 0 and (step % eval_every == 0 or step == train_steps - 1): # one training epoch has passed - # Log training metrics - for metric, value in metrics.compute().items(): # compute metrics - metrics_history[f'train_{metric}'].append(value) # record metrics - metrics.reset() # reset metrics for test set + if step > 0 and (step % eval_every == 0 or step == train_steps - 1): # One training epoch has passed. + # Log the training metrics. + for metric, value in metrics.compute().items(): # Compute the metrics. + metrics_history[f'train_{metric}'].append(value) # Record the metrics. + metrics.reset() # Reset the metrics for the test set. - # Compute metrics on the test set after each training epoch + # Compute the metrics on the test set after each training epoch. for test_batch in test_ds.as_numpy_iterator(): eval_step(model, metrics, test_batch) - # Log test metrics + # Log the test metrics. for metric, value in metrics.compute().items(): metrics_history[f'test_{metric}'].append(value) - metrics.reset() # reset metrics for next training epoch + metrics.reset() # Reset the metrics for the next training epoch. print( f"[train] step: {step}, " @@ -219,9 +220,9 @@ for step, batch in enumerate(train_ds.as_numpy_iterator()): ) ``` -## 7. Visualize Metrics +## 7. Visualize the metrics -Use Matplotlib to create plots for loss and accuracy. +With Matplotlib, you can create plots for the loss and the accuracy: ```{code-cell} ipython3 :outputId: 431a2fcd-44fa-4202-f55a-906555f060ac @@ -240,17 +241,21 @@ ax2.legend() plt.show() ``` -## 10. Perform inference on test set +## 10. Perform inference on the test set -Define a jitted inference function, `pred_step`, to generate predictions on the test set using the learned model parameters. This will enable you to visualize test images alongside their predicted labels for a qualitative assessment of model performance. +Create a `jit`-compiled model inference function (with `nnx.jit`) - `pred_step` - to generate predictions on the test set using the learned model parameters. 
This will enable you to visualize test images alongside their predicted labels for a qualitative assessment of model performance. ```{code-cell} ipython3 +model.eval() # Switch to evaluation mode. + @nnx.jit def pred_step(model: CNN, batch): logits = model(batch['image']) return logits.argmax(axis=1) ``` +Note that we use `.eval()` to ensure that the model is in evaluation mode, even though we are not using `Dropout` or `BatchNorm` in this model, `.eval()` ensure that the outputs are deterministic. + ```{code-cell} ipython3 :outputId: 1db5a01c-9d70-4f7d-8c0d-0a3ad8252d3e @@ -264,4 +269,6 @@ for i, ax in enumerate(axs.flatten()): ax.axis('off') ``` -Congratulations! You made it to the end of the annotated MNIST example. +Congratulations! You have learned how to use Flax NNX to build and train a simple classification model end-to-end on the MNIST dataset. + +Next, check out [Why Flax NNX?](https://flax.readthedocs.io/en/latest/why.html) and get started with a series of [Flax NNX Guides](https://flax.readthedocs.io/en/latest/guides/index.html). diff --git a/docs_nnx/nnx_basics.ipynb b/docs_nnx/nnx_basics.ipynb index 51e7480c..f5b74326 100644 --- a/docs_nnx/nnx_basics.ipynb +++ b/docs_nnx/nnx_basics.ipynb @@ -36,7 +36,7 @@ }, "outputs": [], "source": [ - "# ! pip install -U flax treescope" + "# ! pip install -U flax" ] }, { @@ -54,19 +54,11 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## The Flax `nnx.Module` system\n", + "## The Flax NNX Module system\n", "\n", - "The main difference between the Flax[`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) and other `Module` systems in [Flax Linen](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/module.html) or [Haiku](https://dm-haiku.readthedocs.io/en/latest/notebooks/basics.html#Built-in-Haiku-nets-and-nested-modules) is that in NNX everything is **explicit**. This means, among other things, that:\n", + "The main difference between the Flax[`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) and other `Module` systems in [Flax Linen](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/module.html) or [Haiku](https://dm-haiku.readthedocs.io/en/latest/notebooks/basics.html#Built-in-Haiku-nets-and-nested-modules) is that in NNX everything is **explicit**. This means, among other things, that the [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) itself holds the state (such as parameters) directly, the [PRNG](https://jax.readthedocs.io/en/latest/random-numbers.html) state is threaded by the user, and all shape information must be provided on initialization (no shape inference).\n", "\n", - "1) The [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) itself holds the state (such as parameters) directly.\n", - "2) The [PRNG](https://jax.readthedocs.io/en/latest/random-numbers.html) state is threaded by the user.\n", - "3) All shape information must be provided on initialization (no shape inference).\n", - "\n", - "Let's begin by creating a `Linear` [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html). 
The following code shows that:\n", - "\n", - "- Dynamic state is usually stored in [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)s, and static state (all types not handled by NNX), such as integers or strings are stored directly.\n", - "- Attributes of type [`jax.Array`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html) and `numpy.ndarray` are also treated as dynamic states, although storing them inside [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable)s, such as `Param`, is preferred.\n", - "- The [`nnx.Rngs`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/rnglib.html#flax.nnx.Rngs) object can be used to get new unique keys based on a root PRNG key passed to the constructor." + "Let's begin by creating a `Linear` [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html). As shown next, dynamic state is usually stored in `nnx.Param`s, and static state (all types not handled by NNX) such as integers or strings are stored directly. Attributes of type `jax.Array` and `numpy.ndarray` are also treated as dynamic states, although storing them inside `nnx.Variable`s, such as `Param`, is preferred. Also the [`nnx.Rngs`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/rnglib.html#flax.nnx.Rngs) object can be used to get new unique keys based on a root PRNG key passed to the constructor." ] }, { @@ -90,9 +82,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Also note that:\n", - "\n", - "- The inner values of [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) can be accessed using the `value` property, but for convenience they implement all numeric operators and can be used directly in arithmetic expressions (as shown in the code above).\n", + "Also note that the inner values of `nnx.Variable`s can be accessed using the `value` property, but for convenience they implement all numeric operators and can be used directly in arithmetic expressions (as shown in the code above).\n", "\n", "To initialize a Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html), you just call the constructor, and all the parameters of a `Module` are usually created eagerly. Since [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s hold their own state methods, you can call them directly without the need for a separate `apply` method.\n", "This can be very convenient for debugging, allowing you to directly inspect the entire structure of the model." @@ -113,7 +103,19 @@ { "data": { "text/html": [ - "
" + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" ], "text/plain": [ "" @@ -190,7 +192,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Nested `nnx.Module`s\n", + "### Nested Modules\n", "\n", "Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s can be used to compose other `Module`s in a nested structure. These can be assigned directly as attributes, or inside an attribute of any (nested) pytree type, such as a `list`, `dict`, `tuple`, and so on.\n", "\n", @@ -205,7 +207,19 @@ { "data": { "text/html": [ - "
" + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" ], "text/plain": [ "" @@ -247,7 +261,7 @@ "source": [ "### Model surgery\n", "\n", - "Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s are mutable by default. This means that their structure can be changed at any time, which makes model surgery quite easy as any sub-`Module` attribute can be replaced with anything else, such as new `Module`s, existing shared `Module`s, `Module`s of different types, and so on. Moreover, [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable)s can also be modified or replaced/shared.\n", + "Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s are mutable by default. This means that their structure can be changed at any time, which makes [model surgery](https://flax.readthedocs.io/en/latest/guides/surgery.html) quite easy, as any sub-`Module` attribute can be replaced with anything else, such as new `Module`s, existing shared `Module`s, `Module`s of different types, and so on. Moreover, [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable)s can also be modified or replaced/shared.\n", "\n", "The following example shows how to replace the `Linear` layers in the `MLP` model from the previous example with `LoraLinear` layers:" ] @@ -260,7 +274,19 @@ { "data": { "text/html": [ - "
" + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" ], "text/plain": [ "" @@ -355,7 +381,9 @@ "1. The updates to each of the [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) and [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) layer's state is automatically propagated from within `loss_fn` to `train_step` all the way to the `model` reference outside.\n", "2. The `optimizer` holds a mutable reference to the `model` - this relationship is preserved inside the `train_step` function making it possible to update the model's parameters using the optimizer alone.\n", "\n", - "### `nnx.scan` over layers\n", + "> **Note**
`nnx.jit` has performance overhead for small models, check the [Performance Considerations](https://flax.readthedocs.io/en/latest/guides/performance.html) guide for more information.\n", + "\n", + "### Scan over layers\n", "\n", "The next example uses Flax [`nnx.vmap`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.vmap) to create a stack of multiple MLP layers and [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan) to iteratively apply each layer of the stack to the input.\n", "\n", @@ -382,7 +410,19 @@ { "data": { "text/html": [ - "
" + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" ], "text/plain": [ "" @@ -441,7 +481,19 @@ { "data": { "text/html": [ - "
" + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" ], "text/plain": [ "" @@ -474,7 +526,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### `State` and `GraphDef`\n", + "### State and GraphDef\n", "\n", "A Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) can be decomposed into [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) and [`nnx.GraphDef`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.GraphDef) using the [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split) function:\n", "\n", @@ -490,7 +542,19 @@ { "data": { "text/html": [ - "
" + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" ], "text/plain": [ "" @@ -502,7 +566,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -522,7 +586,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### `split`, `merge`, and `update`\n", + "### Split, merge, and update\n", "\n", "Flax's [`nnx.merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge) is the reverse of [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split). It takes the [`nnx.GraphDef`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.GraphDef) + [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) and reconstructs the [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html). The example below demonstrates this as follows:\n", "\n", @@ -574,14 +638,14 @@ "source": [ "The key insight of this pattern is that using mutable references is fine within a transform context (including the base eager interpreter) but it is necessary to use the Functional API when crossing boundaries.\n", "\n", - "**Why aren't Flax `nnx.Module`s just pytrees?** The main reason is that it is very easy to lose track of shared references by accident this way, for example if you pass two [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s that have a shared `Module` through a JAX boundary, you will silently lose that sharing. Flax's Functional API makes this behavior explicit, and thus it is much easier to reason about." + "**Why aren't Modules just pytrees?** The main reason is that it is very easy to lose track of shared references by accident this way, for example if you pass two [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s that have a shared `Module` through a JAX boundary, you will silently lose that sharing. Flax's Functional API makes this behavior explicit, and thus it is much easier to reason about." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Fine-grained `State` control\n", + "### Fine-grained State control\n", "\n", "Experienced [Flax Linen](https://flax-linen.readthedocs.io/) or [Haiku](https://dm-haiku.readthedocs.io/) API users may recognize that having all the states in a single structure is not always the best choice as there are cases in which you may want to handle different subsets of the state differently. This a common occurrence when interacting with [JAX transforms](https://jax.readthedocs.io/en/latest/key-concepts.html#transformations).\n", "\n", @@ -603,7 +667,31 @@ { "data": { "text/html": [ - "
" + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" ], "text/plain": [ "" @@ -615,7 +703,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" diff --git a/docs_nnx/nnx_basics.md b/docs_nnx/nnx_basics.md index dc1e103e..61b96e2d 100644 --- a/docs_nnx/nnx_basics.md +++ b/docs_nnx/nnx_basics.md @@ -32,7 +32,7 @@ Install Flax with `pip` and impost necessary dependencies: ```{code-cell} ipython3 :tags: [skip-execution] -# ! pip install -U flax treescope +# ! pip install -U flax ``` ```{code-cell} ipython3 @@ -41,19 +41,11 @@ import jax import jax.numpy as jnp ``` -## The Flax `nnx.Module` system +## The Flax NNX Module system -The main difference between the Flax[`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) and other `Module` systems in [Flax Linen](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/module.html) or [Haiku](https://dm-haiku.readthedocs.io/en/latest/notebooks/basics.html#Built-in-Haiku-nets-and-nested-modules) is that in NNX everything is **explicit**. This means, among other things, that: +The main difference between the Flax[`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) and other `Module` systems in [Flax Linen](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/module.html) or [Haiku](https://dm-haiku.readthedocs.io/en/latest/notebooks/basics.html#Built-in-Haiku-nets-and-nested-modules) is that in NNX everything is **explicit**. This means, among other things, that the [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) itself holds the state (such as parameters) directly, the [PRNG](https://jax.readthedocs.io/en/latest/random-numbers.html) state is threaded by the user, and all shape information must be provided on initialization (no shape inference). -1) The [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) itself holds the state (such as parameters) directly. -2) The [PRNG](https://jax.readthedocs.io/en/latest/random-numbers.html) state is threaded by the user. -3) All shape information must be provided on initialization (no shape inference). - -Let's begin by creating a `Linear` [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html). The following code shows that: - -- Dynamic state is usually stored in [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)s, and static state (all types not handled by NNX), such as integers or strings are stored directly. -- Attributes of type [`jax.Array`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html) and `numpy.ndarray` are also treated as dynamic states, although storing them inside [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable)s, such as `Param`, is preferred. -- The [`nnx.Rngs`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/rnglib.html#flax.nnx.Rngs) object can be used to get new unique keys based on a root PRNG key passed to the constructor. +Let's begin by creating a `Linear` [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html). As shown next, dynamic state is usually stored in `nnx.Param`s, and static state (all types not handled by NNX) such as integers or strings are stored directly. Attributes of type `jax.Array` and `numpy.ndarray` are also treated as dynamic states, although storing them inside `nnx.Variable`s, such as `Param`, is preferred. 
Also, the [`nnx.Rngs`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/rnglib.html#flax.nnx.Rngs) object can be used to get new unique keys based on a root PRNG key passed to the constructor.
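
A minimal sketch of the `Linear` module described above, assuming only the public `flax.nnx` API shown elsewhere in these docs (`nnx.Module`, `nnx.Param`, `nnx.Rngs`); the exact initializer choices here are illustrative:

```python
import jax
import jax.numpy as jnp
from flax import nnx

class Linear(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    key = rngs.params()                       # draw a fresh key from the Rngs stream
    # dynamic state lives in nnx.Param; static state (din, dout) is stored directly
    self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))
    self.din, self.dout = din, dout

  def __call__(self, x: jax.Array):
    # Params support numeric operators directly, or use .value explicitly
    return x @ self.w + self.b

model = Linear(2, 3, rngs=nnx.Rngs(params=0))  # eager parameter creation
y = model(jnp.ones((1, 2)))                    # call directly, no separate apply()
```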
The updates to each of the [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) and [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) layer's state is automatically propagated from within `loss_fn` to `train_step` all the way to the `model` reference outside. 2. The `optimizer` holds a mutable reference to the `model` - this relationship is preserved inside the `train_step` function making it possible to update the model's parameters using the optimizer alone. -### `nnx.scan` over layers +> **Note**
`nnx.jit` has a performance overhead for small models; check the [Performance Considerations](https://flax.readthedocs.io/en/latest/guides/performance.html) guide for more information. + +### Scan over layers
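
To make the training-step discussion above concrete, here is a minimal sketch of an `nnx.jit`-compiled `train_step` under the usual assumptions of this guide (an existing `model`, an `nnx.Optimizer` wrapping it, and `optax` for the gradient transformation); names like `batch` are placeholders:

```python
import jax.numpy as jnp
import optax
from flax import nnx

optimizer = nnx.Optimizer(model, optax.adam(1e-3))  # holds a mutable reference to model

@nnx.jit
def train_step(model, optimizer, x, y):
  def loss_fn(model):
    y_pred = model(x)                       # BatchNorm/Dropout state updates propagate out
    return jnp.mean((y_pred - y) ** 2)

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)                   # updates model parameters in place
  return loss
```

Because `optimizer` keeps a reference to `model`, calling `optimizer.update(grads)` is enough to update the parameters; no functional state threading is needed inside the transform.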
This is a common occurrence when interacting with [JAX transforms](https://jax.readthedocs.io/en/latest/key-concepts.html#transformations).
(Refer to :term:`Split and merge` and :term:`Module` to learn more.)
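
As a brief illustration of the split/merge terms defined in the glossary, the following sketch (assuming the `Linear` module from the basics guide) shows how a module decomposes into a `GraphDef` and a variable-state pytree and is reconstructed from them:

```python
from flax import nnx

model = Linear(2, 3, rngs=nnx.Rngs(params=0))

# Split into static structure (GraphDef) and a pure pytree of variable state
graphdef, state = nnx.split(model)

# state can cross JAX transformation boundaries as a regular pytree;
# merge reconstructs an equivalent Module from the two parts
restored = nnx.merge(graphdef, state)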
Define the network\n", - "\n", - "Create a convolutional neural network with the Linen API by subclassing\n", - "[Flax Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html).\n", - "Because the architecture in this example is relatively simple—you're just\n", - "stacking layers—you can define the inlined submodules directly within the\n", - "`__call__` method and wrap it with the\n", - "[`@compact`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/decorators.html#flax.linen.compact)\n", - "decorator. To learn more about the Flax Linen `@compact` decorator, refer to the [`setup` vs `compact`](https://flax.readthedocs.io/en/latest/guides/setup_or_nncompact.html) guide." - ] - }, - { - "cell_type": "code", - "execution_count": 49, - "id": "cbc079cd", - "metadata": {}, - "outputs": [], - "source": [ - "from flax import linen as nn # Linen API\n", - "\n", - "class CNN(nn.Module):\n", - " \"\"\"A simple CNN model.\"\"\"\n", - "\n", - " @nn.compact\n", - " def __call__(self, x):\n", - " x = nn.Conv(features=32, kernel_size=(3, 3))(x)\n", - " x = nn.relu(x)\n", - " x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n", - " x = nn.Conv(features=64, kernel_size=(3, 3))(x)\n", - " x = nn.relu(x)\n", - " x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n", - " x = x.reshape((x.shape[0], -1)) # flatten\n", - " x = nn.Dense(features=256)(x)\n", - " x = nn.relu(x)\n", - " x = nn.Dense(features=10)(x)\n", - " return x" - ] - }, - { - "cell_type": "markdown", - "id": "hy7iRu7_zlx-", - "metadata": {}, - "source": [ - "### View model layers\n", - "\n", - "Create an instance of the Flax Module and use the [`Module.tabulate`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.tabulate) method to visualize a table of the model layers by passing an RNG key and template image input."
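
Although the quick-start notebook above is being deleted in this patch, its CNN pattern carries over to the NNX docs that replace it. A minimal sketch of the equivalent explicit NNX definition, assuming the `nnx.Conv`, `nnx.Linear`, `nnx.avg_pool`, and `nnx.relu` APIs (layer sizes are illustrative):

```python
from flax import nnx

class CNN(nnx.Module):
  """A simple CNN, defined with explicit shapes (no shape inference)."""

  def __init__(self, *, rngs: nnx.Rngs):
    self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
    self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
    self.linear1 = nnx.Linear(3136, 256, rngs=rngs)   # 7*7*64 after two 2x2 poolings
    self.linear2 = nnx.Linear(256, 10, rngs=rngs)

  def __call__(self, x):
    x = nnx.avg_pool(nnx.relu(self.conv1(x)), window_shape=(2, 2), strides=(2, 2))
    x = nnx.avg_pool(nnx.relu(self.conv2(x)), window_shape=(2, 2), strides=(2, 2))
    x = x.reshape(x.shape[0], -1)                     # flatten
    x = nnx.relu(self.linear1(x))
    return self.linear2(x)
```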
- ] - }, - { - "cell_type": "code", - "execution_count": 50, - "id": "lDHfog81zLQa", - "metadata": { - "outputId": "2c580f41-bf5d-40ec-f1cf-ab7f319a84da" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\u001b[3m CNN Summary \u001b[0m\n", - "┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1mpath \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mmodule\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1minputs \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1moutputs \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mflops \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mvjp_flops\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mparams \u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━┩\n", - "│ │ CNN │ \u001b[2mfloat32\u001b[0m[1… │ \u001b[2mfloat32\u001b[0m[… │ 8708106 │ 26957556 │ │\n", - "├─────────┼────────┼────────────┼───────────┼─────────┼───────────┼────────────┤\n", - "│ Conv_0 │ Conv │ \u001b[2mfloat32\u001b[0m[1… │ \u001b[2mfloat32\u001b[0m[… │ 455424 │ 1341472 │ bias: │\n", - "│ │ │ │ │ │ │ \u001b[2mfloat32\u001b[0m[3… │\n", - "│ │ │ │ │ │ │ kernel: │\n", - "│ │ │ │ │ │ │ \u001b[2mfloat32\u001b[0m[3… │\n", - "│ │ │ │ │ │ │ │\n", - "│ │ │ │ │ │ │ \u001b[1m320 \u001b[0m\u001b[1;2m(1.3 \u001b[0m │\n", - "│ │ │ │ │ │ │ \u001b[1;2mKB)\u001b[0m │\n", - "├─────────┼────────┼────────────┼───────────┼─────────┼───────────┼────────────┤\n", - "│ Conv_1 │ Conv │ \u001b[2mfloat32\u001b[0m[1… │ \u001b[2mfloat32\u001b[0m[… │ 6566144 │ 19704320 │ bias: │\n", - "│ │ │ │ │ │ │ \u001b[2mfloat32\u001b[0m[6… │\n", - "│ │ │ │ │ │ │ kernel: │\n", - "│ │ │ │ │ │ │ \u001b[2mfloat32\u001b[0m[3… │\n", - "│ │ │ │ │ │ │ │\n", - "│ │ │ │ │ │ │ \u001b[1m18,496 \u001b[0m │\n", - "│ │ │ │ │ │ │ \u001b[1;2m(74.0 KB)\u001b[0m │\n", - "├─────────┼────────┼────────────┼───────────┼─────────┼───────────┼────────────┤\n", - "│ Dense_0 │ Dense │ \u001b[2mfloat32\u001b[0m[1… │ \u001b[2mfloat32\u001b[0m[… │ 1605888 │ 5620224 │ bias: │\n", - "│ │ │ │ │ │ │ \u001b[2mfloat32\u001b[0m[2… │\n", - "│ │ │ │ │ │ │ kernel: │\n", - "│ │ │ │ │ │ │ \u001b[2mfloat32\u001b[0m[3… │\n", - "│ │ │ │ │ │ │ │\n", - "│ │ │ │ │ │ │ \u001b[1m803,072 \u001b[0m │\n", - "│ │ │ │ │ │ │ \u001b[1;2m(3.2 MB)\u001b[0m │\n", - "├─────────┼────────┼────────────┼───────────┼─────────┼───────────┼────────────┤\n", - "│ Dense_1 │ Dense │ \u001b[2mfloat32\u001b[0m[1… │ \u001b[2mfloat32\u001b[0m[… │ 5130 │ 17940 │ bias: │\n", - "│ │ │ │ │ │ │ \u001b[2mfloat32\u001b[0m[1… │\n", - "│ │ │ │ │ │ │ kernel: │\n", - "│ │ │ │ │ │ │ \u001b[2mfloat32\u001b[0m[2… │\n", - "│ │ │ │ │ │ │ │\n", - "│ │ │ │ │ │ │ \u001b[1m2,570 \u001b[0m │\n", - "│ │ │ │ │ │ │ \u001b[1;2m(10.3 KB)\u001b[0m │\n", - "├─────────┼────────┼────────────┼───────────┼─────────┼───────────┼────────────┤\n", - "│\u001b[1m \u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m Total\u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m824,458 \u001b[0m\u001b[1m \u001b[0m│\n", - "│\u001b[1m \u001b[0m│\u001b[1m \u001b[0m│\u001b[1m \u001b[0m│\u001b[1m \u001b[0m│\u001b[1m \u001b[0m│\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1;2m(3.3 
MB)\u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\n", - "└─────────┴────────┴────────────┴───────────┴─────────┴───────────┴────────────┘\n", - "\u001b[1m \u001b[0m\n", - "\u001b[1m Total Parameters: 824,458 \u001b[0m\u001b[1;2m(3.3 MB)\u001b[0m\u001b[1m \u001b[0m\n", - "\n", - "\n" - ] - } - ], - "source": [ - "import jax\n", - "import jax.numpy as jnp # JAX NumPy\n", - "\n", - "cnn = CNN()\n", - "print(cnn.tabulate(jax.random.key(0), jnp.ones((1, 28, 28, 1)),\n", - " compute_flops=True, compute_vjp_flops=True))" - ] - }, - { - "cell_type": "markdown", - "id": "4b5ac16e", - "metadata": {}, - "source": [ - "## 4. Create a `TrainState`\n", - "\n", - "A common pattern in Flax is to create a single dataclass that represents the\n", - "entire training state, including step number, parameters, and optimizer state.\n", - "\n", - "Because this is such a common pattern, Flax provides the class\n", - "[`flax.training.train_state.TrainState`](https://flax.readthedocs.io/en/latest/flax.training.html#train-state)\n", - "that serves most basic usecases." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "qXr7JDpIxGNZ", - "metadata": { - "outputId": "1249b7fb-6787-41eb-b34c-61d736300844" - }, - "outputs": [], - "source": [ - "!pip install -q clu" - ] - }, - { - "cell_type": "code", - "execution_count": 52, - "id": "CJDaJNijyOji", - "metadata": {}, - "outputs": [], - "source": [ - "from clu import metrics\n", - "from flax.training import train_state # Useful dataclass to keep train state\n", - "from flax import struct # Flax dataclasses\n", - "import optax # Common loss functions and optimizers" - ] - }, - { - "cell_type": "markdown", - "id": "8b86b5f1", - "metadata": {}, - "source": [ - "We will be using the `clu` library for computing metrics. For more information on `clu`, refer to the [repo](https://github.com/google/CommonLoopUtils) and [notebook](https://colab.research.google.com/github/google/CommonLoopUtils/blob/master/clu_synopsis.ipynb#scrollTo=ueom-uBWLbeQ)." - ] - }, - { - "cell_type": "code", - "execution_count": 53, - "id": "7W0qf7FC9uG5", - "metadata": {}, - "outputs": [], - "source": [ - "@struct.dataclass\n", - "class Metrics(metrics.Collection):\n", - " accuracy: metrics.Accuracy\n", - " loss: metrics.Average.from_output('loss')" - ] - }, - { - "cell_type": "markdown", - "id": "f3ce5e4c", - "metadata": {}, - "source": [ - "You can then subclass `train_state.TrainState` so that it also contains metrics. This has the advantage that we only need\n", - "to pass around a single argument to functions like `train_step()` (see below) to calculate the loss, update the parameters and compute the metrics all at once." - ] - }, - { - "cell_type": "code", - "execution_count": 54, - "id": "e0102447", - "metadata": {}, - "outputs": [], - "source": [ - "class TrainState(train_state.TrainState):\n", - " metrics: Metrics\n", - "\n", - "def create_train_state(module, rng, learning_rate, momentum):\n", - " \"\"\"Creates an initial `TrainState`.\"\"\"\n", - " params = module.init(rng, jnp.ones([1, 28, 28, 1]))['params'] # initialize parameters by passing a template image\n", - " tx = optax.sgd(learning_rate, momentum)\n", - " return TrainState.create(\n", - " apply_fn=module.apply, params=params, tx=tx,\n", - " metrics=Metrics.empty())" - ] - }, - { - "cell_type": "markdown", - "id": "a15de484", - "metadata": {}, - "source": [ - "## 5. 
The training step\n", - "\n", - "Define a function that:\n", - "\n", - "- Evaluates the neural network given the parameters and a batch of input images\n", - " with [`TrainState.apply_fn`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState) (which contains the [`Module.apply`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.apply)\n", - " method (the forward pass)).\n", - "- Computes the cross-entropy loss using the predefined [`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api.html#optax.softmax_cross_entropy_with_integer_labels). Note that this function expects integer labels, so there is no need to convert labels to a one-hot encoding.\n", - "- Evaluates the gradient of the loss function using\n", - " [`jax.grad`](https://jax.readthedocs.io/en/latest/jax.html#jax.grad).\n", - "- Applies a\n", - " [pytree](https://jax.readthedocs.io/en/latest/pytrees.html#pytrees-and-jax-functions)\n", - " of gradients to the optimizer to update the model's parameters.\n", - "\n", - "Use JAX's [@jit](https://jax.readthedocs.io/en/latest/jax.html#jax.jit)\n", - "decorator to trace the entire `train_step` function and just-in-time compile\n", - "it with [XLA](https://www.tensorflow.org/xla) into fused device operations\n", - "that run faster and more efficiently on hardware accelerators." - ] - }, - { - "cell_type": "code", - "execution_count": 55, - "id": "9b0af486", - "metadata": {}, - "outputs": [], - "source": [ - "@jax.jit\n", - "def train_step(state, batch):\n", - " \"\"\"Train for a single step.\"\"\"\n", - " def loss_fn(params):\n", - " logits = state.apply_fn({'params': params}, batch['image'])\n", - " loss = optax.softmax_cross_entropy_with_integer_labels(\n", - " logits=logits, labels=batch['label']).mean()\n", - " return loss\n", - " grad_fn = jax.grad(loss_fn)\n", - " grads = grad_fn(state.params)\n", - " state = state.apply_gradients(grads=grads)\n", - " return state" - ] - }, - { - "cell_type": "markdown", - "id": "0ff5145f", - "metadata": {}, - "source": [ - "## 6. Metric computation\n", - "\n", - "Create a separate function for loss and accuracy metrics. Loss is calculated using the `optax.softmax_cross_entropy_with_integer_labels` function, while accuracy is calculated using `clu.metrics`." - ] - }, - { - "cell_type": "code", - "execution_count": 56, - "id": "961bf70b", - "metadata": {}, - "outputs": [], - "source": [ - "@jax.jit\n", - "def compute_metrics(*, state, batch):\n", - " logits = state.apply_fn({'params': state.params}, batch['image'])\n", - " loss = optax.softmax_cross_entropy_with_integer_labels(\n", - " logits=logits, labels=batch['label']).mean()\n", - " metric_updates = state.metrics.single_from_model_output(\n", - " logits=logits, labels=batch['label'], loss=loss)\n", - " metrics = state.metrics.merge(metric_updates)\n", - " state = state.replace(metrics=metrics)\n", - " return state" - ] - }, - { - "cell_type": "markdown", - "id": "497241c3", - "metadata": {}, - "source": [ - "## 7. Download data" - ] - }, - { - "cell_type": "code", - "execution_count": 57, - "id": "bff5393e", - "metadata": {}, - "outputs": [], - "source": [ - "num_epochs = 10\n", - "batch_size = 32\n", - "\n", - "train_ds, test_ds = get_datasets(num_epochs, batch_size)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "809ae1a0", - "metadata": {}, - "source": [ - "## 8.
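
For readers migrating away from the deleted Linen quick start, a rough NNX counterpart of that `train_step` looks like the sketch below; it assumes the NNX `CNN` from earlier, an `nnx.Optimizer`, and `optax`, and is illustrative rather than part of the deleted notebook:

```python
import optax
from flax import nnx

model = CNN(rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.sgd(0.01, 0.9))

@nnx.jit
def train_step(model, optimizer, batch):
  def loss_fn(model):
    logits = model(batch['image'])
    return optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch['label']).mean()

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)   # parameters are updated in place via the shared reference
  return loss
```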
Seed the randomness\n", - "\n", - "- Set the TF random seed to ensure dataset shuffling (with `tf.data.Dataset.shuffle`) is reproducible.\n", - "- Get one\n", - " [PRNGKey](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.PRNGKey.html#jax.random.PRNGKey)\n", - " and use it for parameter initialization. (Learn\n", - " more about\n", - " [JAX PRNG design](https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html)\n", - " and [PRNG chains](https://flax.readthedocs.io/en/latest/philosophy.html#how-are-parameters-represented-and-how-do-we-handle-general-differentiable-algorithms-that-update-stateful-variables).)" - ] - }, - { - "cell_type": "code", - "execution_count": 58, - "id": "xC4MFyBsfT-U", - "metadata": {}, - "outputs": [], - "source": [ - "tf.random.set_seed(0)" - ] - }, - { - "cell_type": "code", - "execution_count": 59, - "id": "e4f6f4d3", - "metadata": {}, - "outputs": [], - "source": [ - "init_rng = jax.random.key(0)" - ] - }, - { - "cell_type": "markdown", - "id": "80fbb60b", - "metadata": {}, - "source": [ - "## 9. Initialize the `TrainState`\n", - "\n", - "Remember that the function `create_train_state` initializes the model parameters, optimizer and metrics\n", - "and puts them into the training state dataclass that is returned." - ] - }, - { - "cell_type": "code", - "execution_count": 60, - "id": "445fcab0", - "metadata": {}, - "outputs": [], - "source": [ - "learning_rate = 0.01\n", - "momentum = 0.9" - ] - }, - { - "cell_type": "code", - "execution_count": 61, - "id": "5221eafd", - "metadata": {}, - "outputs": [], - "source": [ - "state = create_train_state(cnn, init_rng, learning_rate, momentum)\n", - "del init_rng # Must not be used anymore." - ] - }, - { - "cell_type": "markdown", - "id": "b1c00230", - "metadata": {}, - "source": [ - "## 10. Train and evaluate\n", - "\n", - "Create a \"shuffled\" dataset by:\n", - "- Repeating the dataset a number of times equal to the number of training epochs.\n", - "- Allocating a buffer of size 1024 (containing the first 1024 samples in the dataset) from which to randomly sample batches.\n", - " - Every time a sample is randomly drawn from the buffer, the next sample in the dataset is loaded into the buffer.\n", - "\n", - "Define a training loop that:\n", - "- Randomly samples batches from the dataset.\n", - "- Runs an optimization step for each training batch.\n", - "- Computes the mean training metrics across each batch in an epoch.\n", - "- Computes the metrics for the test set using the updated parameters.\n", - "- Records the train and test metrics for visualization.\n", - "\n", - "Once training and testing are done after 10 epochs, the output should show that your model was able to achieve approximately 99% accuracy."
- ] - }, - { - "cell_type": "code", - "execution_count": 62, - "id": "74295360", - "metadata": {}, - "outputs": [], - "source": [ - "# since train_ds is replicated num_epochs times in get_datasets(), we divide by num_epochs\n", - "num_steps_per_epoch = train_ds.cardinality().numpy() // num_epochs" - ] - }, - { - "cell_type": "code", - "execution_count": 63, - "id": "cRtnMZuQFlKl", - "metadata": {}, - "outputs": [], - "source": [ - "metrics_history = {'train_loss': [],\n", - " 'train_accuracy': [],\n", - " 'test_loss': [],\n", - " 'test_accuracy': []}" - ] - }, - { - "cell_type": "code", - "execution_count": 64, - "id": "2c40ce90", - "metadata": { - "outputId": "258a2c76-2c8f-4a9e-d48b-dde57c342a87" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "train epoch: 1, loss: 0.20290373265743256, accuracy: 93.87000274658203\n", - "test epoch: 1, loss: 0.07591685652732849, accuracy: 97.60617065429688\n", - "train epoch: 2, loss: 0.05760224163532257, accuracy: 98.28500366210938\n", - "test epoch: 2, loss: 0.050395529717206955, accuracy: 98.3974380493164\n", - "train epoch: 3, loss: 0.03897436335682869, accuracy: 98.83000183105469\n", - "test epoch: 3, loss: 0.04574578255414963, accuracy: 98.54767608642578\n", - "train epoch: 4, loss: 0.028721099719405174, accuracy: 99.15166473388672\n", - "test epoch: 4, loss: 0.035722777247428894, accuracy: 98.91827392578125\n", - "train epoch: 5, loss: 0.021948494017124176, accuracy: 99.37999725341797\n", - "test epoch: 5, loss: 0.035723842680454254, accuracy: 98.87820434570312\n", - "train epoch: 6, loss: 0.01705147698521614, accuracy: 99.54833221435547\n", - "test epoch: 6, loss: 0.03456473350524902, accuracy: 98.96835327148438\n", - "train epoch: 7, loss: 0.014007646590471268, accuracy: 99.6116714477539\n", - "test epoch: 7, loss: 0.04089202359318733, accuracy: 98.7880630493164\n", - "train epoch: 8, loss: 0.011265480890870094, accuracy: 99.73333740234375\n", - "test epoch: 8, loss: 0.03337760642170906, accuracy: 98.93830108642578\n", - "train epoch: 9, loss: 0.00918484665453434, accuracy: 99.78334045410156\n", - "test epoch: 9, loss: 0.034478139132261276, accuracy: 98.96835327148438\n", - "train epoch: 10, loss: 0.007260234095156193, accuracy: 99.84166717529297\n", - "test epoch: 10, loss: 0.032822880893945694, accuracy: 99.07852172851562\n" - ] - } - ], - "source": [ - "for step,batch in enumerate(train_ds.as_numpy_iterator()):\n", - "\n", - " # Run optimization steps over training batches and compute batch metrics\n", - " state = train_step(state, batch) # get updated train state (which contains the updated parameters)\n", - " state = compute_metrics(state=state, batch=batch) # aggregate batch metrics\n", - "\n", - " if (step+1) % num_steps_per_epoch == 0: # one training epoch has passed\n", - " for metric,value in state.metrics.compute().items(): # compute metrics\n", - " metrics_history[f'train_{metric}'].append(value) # record metrics\n", - " state = state.replace(metrics=state.metrics.empty()) # reset train_metrics for next training epoch\n", - "\n", - " # Compute metrics on the test set after each training epoch\n", - " test_state = state\n", - " for test_batch in test_ds.as_numpy_iterator():\n", - " test_state = compute_metrics(state=test_state, batch=test_batch)\n", - "\n", - " for metric,value in test_state.metrics.compute().items():\n", - " metrics_history[f'test_{metric}'].append(value)\n", - "\n", - " print(f\"train epoch: {(step+1) // num_steps_per_epoch}, \"\n", - " f\"loss: 
{metrics_history['train_loss'][-1]}, \"\n", - " f\"accuracy: {metrics_history['train_accuracy'][-1] * 100}\")\n", - " print(f\"test epoch: {(step+1) // num_steps_per_epoch}, \"\n", - " f\"loss: {metrics_history['test_loss'][-1]}, \"\n", - " f\"accuracy: {metrics_history['test_accuracy'][-1] * 100}\")" - ] - }, - { - "cell_type": "markdown", - "id": "gfsecJzvzgCT", - "metadata": {}, - "source": [ - "## 11. Visualize metrics" - ] - }, - { - "cell_type": "code", - "execution_count": 65, - "id": "Zs5atiqIG9Kz", - "metadata": { - "outputId": "431a2fcd-44fa-4202-f55a-906555f060ac" - }, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAA3cAAAE/CAYAAADlpzo+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/av/WaAAAACXBIWXMAAAsTAAALEwEAmpwYAABsiElEQVR4nO3dd3yddd3/8dcneyfNaJs26aJ7JAFKyxLUKlBWAQEBmQK9uRUEb/EWt7hufooDFcEyBVFUFK1QLFBEZLeFpLt00qRJ23Rk7+T7++O6kp6maXPSjJPxfj4e53HONc/3nJ7mOu/zXeacQ0RERERERAa2sFAXQERERERERLpP4U5ERERERGQQULgTEREREREZBBTuREREREREBgGFOxERERERkUFA4U5ERERERGQQULgTEREREREZBBTuRPqYmW03s0+EuhwiIiK9ycxeNbMDZhYd6rKIDBUKdyIiIiLSo8xsHPARwAEX9uHzRvTVc4n0Rwp3Iv2AmUWb2c/NrNi//bz1l04zSzez58yszMz2m9l/zCzM3/YVM9tpZpVmttHM5oX2lYiIiABwLfA28DhwXetKM8s2s7+aWamZ7TOzXwVsu9nM1vvXtHVmdoK/3pnZxID9Hjez7/uPP2pmRf71cBfwmJkN86+bpX7N4XNmlhVwfKqZPeZfbw+Y2d/89WvM7IKA/SLNbK+Z5fXSeyTS4xTuRPqHrwMnA3lALjAH+Ia/7UtAEZABjAC+BjgzmwLcCpzknEsEzga292mpRUREOnYt8JR/O9vMRphZOPAc8CEwDhgNPA1gZpcB3/GPS8Kr7dsX5HONBFKBscBCvO+3j/nLY4Ba4FcB+z8JxAEzgOHAz/z1TwBXB+x3LlDinMsPshwiIaeqa5H+4TPAbc65PQBmdjfwG+CbQCOQCYx1zm0G/uPv0wxEA9PNrNQ5tz0UBRcREQlkZqfjBas/Oef2mtkW4Cq8mrxRwJedc03+7q/79zcBP3LOLfeXN3fhKVuAbzvn6v3lWuAvAeX5AfAv/3EmMB9Ic84d8Hf5t3//O+CbZpbknKsArsELgiIDhmruRPqHUXi/ZLb60F8H8GO8i9yLZrbVzO4C8IPeHXi/dO4xs6fNbBQiIiKhdR3wonNur7/8e39dNvBhQLALlA1sOcbnK3XO1bUumFmcmf3GzD40swrgNSDFrznMBvYHBLs2zrli4A3gU2aWghcCnzrGMomEhMKdSP9QjPcrZ6sx/jqcc5XOuS855yYAFwD/09q3zjn3e+dc6y+kDvh/fVtsERGRg8wsFrgcONPMdvn94L6I1+VgNzDmCIOeFALHHeG0NXjNKFuNbLfdtVv+EjAFmOucSwLOaC2e/zypfnjryG/xmmZeBrzlnNt5hP1E+iWFO5HQiDSzmNYb8AfgG2aWYWbpwLfwmodgZueb2UQzM6ACaAaazWyKmX3cH3ilDq8ZSnNoXo6IiAgAF+Fdi6bj9SPPA6bhdSm4CCgB7jGzeP8aeJp/3MPAnWZ2onkmmlnrj575wFVmFm5m5wBndlKGRLxrYpmZpQLfbt3gnCsBXgB+7Q+8EmlmZwQc+zfgBOB2vD54IgOKwp1IaCzBu/C03mKAFcAqYDXwHvB9f99JwMtAFfAW8Gvn3Kt4/e3uAfYCu/A6hX+tz16BiIjI4a4DHnPO7XDO7Wq94Q1ociVeC5SJwA68wcI+DeCc+zPwA7wmnJV4ISvVP+ft/nFleH3U/9ZJGX4OxOJdH98G/tlu+zV4/dk3AHvwujjgl6O1v9544K/Bv2yR/sGca1+TLSIiIiIyNJnZt4DJzrmrO91ZpJ/RaJkiIiIiInhz4AE34tXuiQw4apYpIiIiIkOemd2MN+DKC86510JdHpFjoWaZIiIiIiIig4Bq7kRERERERAYBhTsREREREZFBYEANqJKenu7GjRsX6mKIiEgvW7ly5V7nXEaoyzFQ6PooIjJ0HO0aOaDC3bhx41ixYkWoiyEiIr3MzD4MdRkGEl0fRUSGjqNdI9UsU0REpIeZ2aNmtsfM1hxhu5nZL8xss5mtMrMTAradY2Yb/W139V2pRURkoFO4ExER6XmPA+ccZft8YJJ/Wwg8AGBm4cD9/vbpwJVmNr1XSyoiIoOGwp2IiEgP8+fI2n+UXRYATzjP20CKmWUCc4DNzrmtzrkG4Gl/XxERkU4NqD53IiL9RWNjI0VFRdTV1YW6KANaTEwMWVlZREZGhroofW003mTJrYr8dR2tn3ssT6DP6MAzhP8/iEgPUbgTETkGRUVFJCYmMm7cOMws1MUZkJxz7Nu3j6KiIsaPHx/q4vS1jj407ijrDz+B2UK8Jp2MGTPmsO36jA4sQ/z/g4j0EDXLFBE5BnV1daSlpelLczeYGWlpaUO1ZqkIyA5YzgKKj7L+MM65Rc652c652RkZh4+Irc/owDLE/z+ISA9RuBMROUb60tx9Q/g9XAxc64+aeTJQ7pwrAZYDk8xsvJlFAVf4+x6TIfz+Dkj69xKR7lKzTBERkR5mZn8APgqkm1kR8G0gEsA59yCwBDgX2AzUADf425rM7FZgKRAOPOqcW9vnL0BERAYk1dyJiAxAZWVl/PrXv+7yceeeey5lZWVdPu7666/nmWee6fJxQ5Vz7krnXKZzLtI5l+Wce8Q596Af7PBHyfy8c+4459ws59yKgGOXOOcm+9t+ELpX0X19/TkVERnqhlS4W1tczlPvHHFCdxGRAeNIX5qbm5uPetySJUtISUnppVKJHGqwfk47K7+IDG1NzS0cqG7gw33VrC4q5/VNe1myuoSn393Bpt2VvfrcQTXLNLNzgPvwmog87Jy7p932zwBf8RergP92zhUc7VgzSwX+CIwDtgOXO+cOdPP1HNVL63Zz37JNLMgbTUK0WqSKyMB11113sWXLFvLy8oiMjCQhI
YHMzEzy8/NZt24dF110EYWFhdTV1XH77bezcOFCAMaNG8eKFSuoqqpi/vz5nH766bz55puMHj2av//978TGxnb63MuWLePOO++kqamJk046iQceeIDo6GjuuusuFi9eTEREBGeddRb33nsvf/7zn7n77rsJDw8nOTmZ1157rbffGulH+vpz+tBDD7Fo0SIaGhqYOHEiTz75JHFxcezevZtbbrmFrVu3AvDAAw9w6qmn8sQTT3DvvfdiZuTk5PDkk09y/fXXc/7553PppZcCkJCQQFVVFa+++ip33313UOX/5z//yde+9jWam5tJT0/npZdeYsqUKbz55ptkZGTQ0tLC5MmTefvtt0lPT++DfwkR6YrmFkdlXSMVtU1U1DVSUdvo3zdR3va4kYq6pkO2ta6vbjjyD0DfXTCDSSMSe63snSYcMwsH7gc+iTeK13IzW+ycWxew2zbgTOfcATObDywC5nZy7F3AMufcPWZ2l7/8FXpRbnYKzsHqonJOOS6tN59KRIaQu/+xlnXFFT16zumjkvj2BTOOuP2ee+5hzZo15Ofn8+qrr3LeeeexZs2atiHUH330UVJTU6mtreWkk07iU5/6FGlph/7d27RpE3/4wx946KGHuPzyy/nLX/7C1VdffdRy1dXVcf3117Ns2TImT57MtddeywMPPMC1117Ls88+y4YNGzCztiZ13/3ud1m6dCmjR49WM7sQCsVnFPr+c3rJJZdw8803A/CNb3yDRx55hNtuu40vfOELnHnmmTz77LM0NzdTVVXF2rVr+cEPfsAbb7xBeno6+/cfbc55z7vvvttp+VtaWrj55pt57bXXGD9+PPv37ycsLIyrr76ap556ijvuuIOXX36Z3NxcBTuRXtLc4qiq88JWeQfh69BQdvi2qvqmo54/zCAxJpKk2AiSYiJJiolkXHocybHe46TYSJJiIvx7f9nfNzU+qldfezDVV3OAzc65rQBm9jSwAGgLd865NwP2fxtv6ObOjl2A19kc4LfAq/R2uMtKAWBVUZnCnYgMKnPmzDlkbqxf/OIXPPvsswAUFhayadOmw740jx8/nry8PABOPPFEtm/f3unzbNy4kfHjxzN58mQArrvuOu6//35uvfVWYmJiuOmmmzjvvPM4//zzATjttNO4/vrrufzyy7nkkkt64JXKQNbbn9M1a9bwjW98g7KyMqqqqjj77LMBeOWVV3jiiScA2mqRn3jiCS699NK2gJWamtoj5S8tLeWMM85o26/1vJ/97GdZsGABd9xxB48++ig33HBDp88nMpQ456hvaqGqvonq+iYq/ZBV1Xpff+hyZZ23X1V9E5X1TVTVNVJd39y239GYQWJ0YPiKYExq3CHLHQY1/3F8VARhYf1zdNtgwt1ooDBguQiYe5T9bwReCOLYEf6wzzjnSsxseFAl7obU+CiyU2MpKCrr7acSkSGks9qLvhAfH9/2+NVXX+Xll1/mrbfeIi4ujo9+9KMdzp0VHR3d9jg8PJza2tpOn8e5DufTJiIignfffZdly5bx9NNP86tf/YpXXnmFBx98kHfeeYfnn3+evLw88vPzD/vyLr2vP3xGofc/p9dffz1/+9vfyM3N5fHHH+fVV1894r7OuQ6nHoiIiKClpaVtn4aGhi6V/0jnzc7OZsSIEbzyyiu88847PPXUU0csm8hA0tzi2gLV4aGskar65oDHTf5yY1tAaz2uqr6JxuaOrzGBwgwSoiNIjIkkITqChBgviGWlxLYtx0e3hrODAS65tfYsNpKEfhzOuiuYcNfRK+/wnTezj+GFu9O7euwRn9xsIbAQYMyYMV05tEO5WSm8v6Os2+cREQmlxMREKis77pRdXl7OsGHDiIuLY8OGDbz99ts99rxTp05l+/btbN68ua1P05lnnklVVRU1NTWce+65nHzyyUycOBGALVu2MHfuXObOncs//vEPCgsLFe6GkL7+nFZWVpKZmUljYyNPPfUUo0ePBmDevHk88MAD3HHHHTQ3N1NdXc28efO4+OKL+eIXv0haWhr79+8nNTWVcePGsXLlSi6//HL+/ve/09jY2KXyn3LKKXz+859n27Ztbc0yW2vvbrrpJq6++mquueYawsPDu/16RXqbc47SynoKD9SwY38NO/bVtj3eeaCWAzUN1Bylf1mg2MhwEmIiSIz2wldCdATZqXEkBgQyL7R59223mIP3idGRxESGaU7Iowgm3BUB2QHLWUBx+53MLAd4GJjvnNsXxLG7zSzTr7XLBPZ09OTOuUV4ffiYPXt2l4JhR3KzUnhuVQmllfVkJEZ3foCISD+UlpbGaaedxsyZM4mNjWXEiBFt28455xwefPBBcnJymDJlCieffHKPPW9MTAyPPfYYl112WduAKrfccgv79+9nwYIFbTUXP/vZzwD48pe/zKZNm3DOMW/ePHJzc3usLNL/9fXn9Hvf+x5z585l7NixzJo1qy1Y3nfffSxcuJBHHnmE8PBwHnjgAU455RS+/vWvc+aZZxIeHs7xxx/P448/zs0338yCBQuYM2cO8+bNO6S2LtCRyp+RkcGiRYu45JJLaGlpYfjw4bz00ksAXHjhhdxwww1qkin9SlV9E4X7vcBW2Ho7UMuO/TUUHaihrrHlkP1HJEWTPSyOOeNTSYuPagtfiYcFtEhvW1QE8dHhRIQPqUH6Q8aO1MSmbQezCOADYB6wE1gOXBU4qaqZjQFeAa4N7H93tGPN7MfAvoABVVKdc/97tLLMnj3brVix4mi7dOrdbfu5/Ddv8ch1s5k3bUTnB4iIdGD9+vVMmzYt1MUYFDp6L81spXNudoiKNOB0dH3UZ7T/WbFiBV/84hf5z3/+c8R99O8mPa2xuYWSsrq2GrdDgtyBWvZXNxyyf2uN2pjUWLKHxTEmLY7sYXFkp8aRNSyWmEjVOofa0a6RndbcOeeazOxWYCnedAaP+uHsFn/7g8C3gDTg1341aZNzbvaRjvVPfQ/wJzO7EdgBXNatVxmkmaOTCDMoKCpXuBMREZE+cc899/DAAw+or530OOcc+6ob2kJb0YFaduyraQtzJeV1NLccrMyJCDNGD4tlTGocZ49KZkxqHNmp3nL2sDhS4iLV7HEAC2qyN+fcEmBJu3UPBjy+Cbgp2GP99fvwavT6VFxUBJNHJFJQWNbXTy0i0u99/vOf54033jhk3e23365mZNKvDMTP6V133cVdd90V6mLIAFXb0EzhgcBat4PNJnfsrzms31t6QjRjUmM5ceywttCW7Ye4kUkxaiI5iA3Jmbxzs1JYum7XEUe0EhEZqu6///5QF0GkU/qcykDV3OKoqG2krNabf628tpGymgZvXY2/3LrNX95X3cDeqvpDzhMXFd4W2E45Lo0xqXF+DZzXdDIuakh+xReGaLjLyU7mjysKKdxfy5i0uFAXR0REREQGCOcc1Q3NlNU0HBLCytrCmj9xdm0jZbUNh6yrrDv6/GtxUeEkx0a23calx3H8mJS20NYa4NLio1RBIR0akuGudTLz/KIyhTsRERGRIayusZnNe6rYVV4XUGvWcFhgqwio
bWtqOfKAhJHhdkhAG54Yw6ThiW3LKXGH3ifHRrVti4pQc0npniEZ7qaMTCQ6IoxVhWVcmDsq1MURERERkV7mnGN3RT3rSypYv6uCDSWVrC+pYOve6kMGHAEwg8ToCFLiotqC2OhhsaTEHjmYta6LiwpXrZqEzJAMd5HhYcwYlURBUVmoiyIiIjJolZWV8fvf/57Pfe5zXT725z//OQsXLiQuTi1spOvqGpvZtLuqLcitL6lgw65KymoOTko/OiWWaZmJnD1jJNMyk8gaFtsW0BJjIgkPU0CTgWdIhjuA3OwU/vDuDpqaWzRikIgMOL39pXncuHGsWLGC9PT07hRThriysjJ+/etfH/Pn9Oqrr+4X4a6pqYmIiCH7lalfc85RUl7Hhl0VrPdr4jbsqmRraRWtlXGxkeFMGZnI/JkjmToyiWmZSUwZ6TWTFBlshuxfqtysFB57Yzub9lQxLTMp1MUREemSwfKlWQa3u+66iy1btpCXl8cnP/lJhg8fzp/+9Cfq6+u5+OKLufvuu6murubyyy+nqKiI5uZmvvnNb7J7926Ki4v52Mc+Rnp6Ov/61786PP9///d/s3z5cmpra7n00ku5++67AVi+fDm333471dXVREdHs2zZMuLi4vjKV77C0qVLMTNuvvlmbrvttkN+yFixYgV33nknr776Kt/5zncoLi5m+/btpKen88Mf/pBrrrmG6upqAH71q19x6qmnAvCjH/2IJ598krCwMObPn8/NN9/MZZddxnvvvQfApk2buOKKK1i5cmUfvOuDV21DMx/srjwsyJXXHqyNyxoWy7TMJM6dOZKpmV6QG5Map1o4GTKGbrjLTgGgoLBM4U5EuueFu2DX6p4958hZMP+eI27u7S/NgX7605/y6KOPAnDTTTdxxx13dHjuT3/609x1110sXryYiIgIzjrrLO69994ee0ukG0LwGQVv4u41a9aQn5/Piy++yDPPPMO7776Lc44LL7yQ1157jdLSUkaNGsXzzz8PQHl5OcnJyfz0pz/lX//611Frj3/wgx+QmppKc3Mz8+bNY9WqVUydOpVPf/rT/PGPf+Skk06ioqKC2NhYFi1axLZt23j//feJiIhg//79nb7ElStX8vrrrxMbG0tNTQ0vvfQSMTExbNq0iSuvvJIVK1bwwgsv8Le//Y133nmHuLg49u/fT2pqKsnJyeTn55OXl8djjz3G9ddf36W3dyhzzrGzrJYNJQFBblcF2/dWt9XGxUV5tXHn5WQybWQiU/3auKQY1cbJ0DZkw924tDiSYiIoKCrjijljQl0cEZEu6e0vza1WrlzJY489xjvvvINzjrlz53LmmWeydevWw869f/9+nn32WTZs2ICZUVZW1ptvgQwwL774Ii+++CLHH388AFVVVWzatImPfOQj3HnnnXzlK1/h/PPP5yMf+UjQ5/zTn/7EokWLaGpqoqSkhHXr1mFmZGZmctJJJwGQlOT9gPvyyy9zyy23tDWvTE1N7fT8F154IbGxsQA0NjZy6623kp+fT3h4OB988EHbeW+44Ya2mvDW895000089thj/PSnP+WPf/wj7777btCvayipaWhi465KNuyqZEPJwSAXOGXAmNQ4po5M5IKcUUzLTGTqSK82Lky1cSKHGbLhzszIzU6hoLA81EURkYGuk9qL3tYbX5pbvf7661x88cXEx8cDcMkll/Cf//yHc84557BzNzU1ERMTw0033cR5553H+eef36OvU7ohxJ9R8GpjvvrVr/Jf//Vfh21buXIlS5Ys4atf/SpnnXUW3/rWtzo937Zt27j33ntZvnw5w4YN4/rrr6eurg7nXIcjFR5pfUREBC0tLQDU1dUdsq31cw/ws5/9jBEjRlBQUEBLSwsxMTFHPe+nPvUp7r77bj7+8Y9z4oknkpaW1ulrGuz2Vzew8sMDfnNKL8ht31eN82vj4qPCmZqZxIW5o5iWmcS0zESmjEwiIXrIfl0V6bIh/b8lNyuFB/69hdqGZmKjwkNdHBGRY9LTX5rbn7sjkydP7vDc7777LsuWLePpp5/mV7/6Fa+88soxvSYZHBITE6msrATg7LPP5pvf/Caf+cxnSEhIYOfOnURGRtLU1ERqaipXX301CQkJPP7444cce6Qa5oqKCuLj40lOTmb37t288MILfPSjH2Xq1KkUFxezfPlyTjrpJCorK4mNjeWss87iwQcf5KMf/Whbs8zU1FTGjRvHypUrmT9/Pn/5y1+O+FrKy8vJysoiLCyM3/72tzQ3NwNw1lln8d3vfperrrrqkGaZMTExnH322fz3f/83jzzySM++sQNEaWU972zbxztb9/Putv1s3F3Ztm1cWhxTRyaxIM8PciO90SpVGyfSPUM63OVkJdPc4lhXUs6JYztvniEi0l/05pfmQGeccQbXX389d911F845nn32WZ588kmKi4sPO3dVVRU1NTWce+65nHzyyUycOLE33wIZANLS0jjttNOYOXMm8+fP56qrruKUU04BICEhgd/97nds3ryZL3/5y4SFhREZGckDDzwAwMKFC5k/fz6ZmZkd9g3Nzc3l+OOPZ8aMGUyYMIHTTjsNgKioKP74xz9y2223UVtbS2xsLC+//DI33XQTH3zwATk5OURGRnLzzTdz66238u1vf5sbb7yRH/7wh8ydO/eIr+Vzn/scn/rUp/jzn//Mxz72sbZavXPOOYf8/Hxmz55NVFQU5557Lj/84Q8B+MxnPsNf//pXzjrrrB59X/urXeV1vLNtH29v3c872/axtdQbfCYuKpwTxw7jwrxRzBmfyvTMJOJVGyfSK+xIv8r2R7Nnz3YrVqzosfPtqahjzg+X8c3zp3Pj6eN77LwiMvitX7+eadOmhbQMV111FatWrWL+/PlkZWXx8MMPA0f/0jx79mx++ctfcv/99x/xSzMcOhVCRwOqLF269LBzjx49mgULFrQ1jbvzzju57rrrOn0dHb2XZrbSOTe7m2/RkNHR9bE/fEaHunvvvZfy8nK+973vBX3MQPp3KzpQwzt+kHtn234+3FcDeJN/zx43jLkT0pg7PpWZo5OJ1LRTIj3maNfIIR3uAE75v2WcNC6VX1x5fI+eV0QGt4H0Bay/U7jrPoW7/ufiiy9my5YtvPLKK12aL7K//rs55/hwXw3vbtvP235Ty51ltQAkx0Zy0rhUTp6QytzxaUwflaSpB0R60dGukUO+TjwnK5lVRWWhLoaIiIgcwdy5c6mvrz9k3ZNPPsmsWbNCVKLOPfvss6EuQrc459hSWn1In7ldFd6AM2nxUcwZn8rNHxnP3AlpTBmRqL5yIv3EkA93udkpLF27m7KaBlLiokJdHBGRPjUQvzTL0PPOO++EugiDXkuLY9OeqrYw9862/eyt8v42ZCRGM3d8KnMnpHHy+FQmDk/ocIRQEQk9hbusFABWFZVzxuSM0BZGRKSP6UuzyNDU3OLYsKuirc/cu9v2c6CmEYDM5BhOn5jW1mdufHq8wpzIADHkw92srGQACgrLFO5EpEuONL+VBG8g9fseiPQZHVh68/9DU3MLa4s
r2mrmlm/fT4U/UXh2aizzpo1gzvhUTh6fRnZqrD43IgPUkA93STGRHJcRT4H63YlIF8TExLBv3z7S0tL0JegYOefYt29f22TQ0rP0GR1Yevr/Q2NzC6uKytvC3MoPD1BV74W58enxnDsrk7n+ACijUmJ75DlFJPSGfLgDr2nma5v26hdOEQlaVlYWRUVFlJaWhrooA1pMTAxZWVmhLsagpM/owNPd/w81DU28tG43/ygo4Y3Ne6lt9CZanzg8gQV5o9qaWY5I0g8qIoOVwh3eoCp/fX8nJeV1+vVKRIISGRnJ+PGaH1P6L31Gh4b6pmb+vbGUxQXFLFu/h9rGZkYmxXDZ7CxOnpDGnPGppCdEh7qYItJHFO7wpkMAWFVUpnAnIiIi/VpTcwtvbd3H4vxi/rl2F5V1TQyLi+SSE0ZzYe4oThqXqqkJRIYohTtgWmYSkeFGfmE558zMDHVxRERkEDCzc4D7gHDgYefcPe22DwMeBY4D6oDPOufW+NtuB24GDHjIOffzPiy69EMtLY73dhxgcUExS1aXsLeqgYToCM6aMYILc0dx2sR0IsPDQl1MEWlpgeo9UF4EZTu8+/IiKC/0bmd8GaYv6LWnDyrcBXGBmgo8BpwAfN05d6+/fgrwx4BdJwDfcs793My+g3fhau0M8DXn3JJuvJZjFhMZzrTMJAoKy0Lx9CIiMsiYWThwP/BJoAhYbmaLnXPrAnb7GpDvnLvYv47eD8wzs5l418c5QAPwTzN73jm3qW9fhYSac461xRX8o6CY51aVsLOsluiIMOZNG86FuaP46JThxESGh7qYIkNLY92hYa01vLUGuYqd0Nxw6DHRyZCc5d0i43u1eJ2GuyAvUPuBLwAXBR7rnNsI5AWcZyfwbMAuP2sNgqGWk5XM394vpqXFqSmDiIh01xxgs3NuK4CZPQ0sAAKvndOB/wNwzm0ws3FmNgKYBrztnKvxj/03cDHwoz4sv4TQltIqFucX849VxWwtrSYizPjIpHTuPHsyn5g2gsSYyFAXUWRwcg5q9kO5H9TKCg8PctXtBqmyMEjM9ILb6BO8WrnkLEjOhpRs73FMcp+9hGBq7jq9QDnn9gB7zOy8o5xnHrDFOfdhN8rba3KzUvjd2zvYureKicMTQ10cEREZ2EYDhQHLRcDcdvsUAJcAr5vZHGAskAWsAX5gZmlALXAusKLXSywhtbOsln8UFPOPgmLWFldgBnPHp3LT6RM4Z+ZIUuOjQl1EkYGvqQEqiwNCW9HhQa6p9tBjIuMOhrWROX5gyz64LmkUhPefH1yCCXfBXKCCcQXwh3brbjWza/EuWl9yzh04hvP2iLzsFAAKCssV7kREpLs6agLSfobqe4D7zCwfWA28DzQ559ab2f8DXgKq8EJg02FPYLYQWAgwZsyYniu59JnSynqWrC5hcUExKz/0vgLlZqfwzfOnc96sTEYma8qCXtdYB1W7oGoPRCdC2iQI15AU/V5LMzTVef9+Te1v9VBXfrDGLbD2rXIXh/0pjh/uBbUR02Hy2QHBLQtSxkDsMBhAU6UF8+kN5gJ19BOYRQEXAl8NWP0A8D3/XN8DfgJ8toNj++TiNSEjgfiocAqKyvjUiZpzSUREuqUIyA5YzgKKA3dwzlUANwCYN8nqNv+Gc+4R4BF/2w/989Hu+EXAIoDZs2d36bosoVNe28jSNbtYXFDMm1v20uJgyohEvnz2FC7IGcWYtLhQF3Hgc877cl+1xwtulbv9+11Qtdu/97fVlR96bEQsjJzp1dBk5nq34dMgQtNJHKalxavlaqr3g1bA46b6Q7c11bfbHhDEDglpQR7XctjvXR0LjzoY1I6b5we2wFq30RA5uH5ECSbcdXqBCsJ84D3n3O7WFYGPzewh4LmODuyri1d4mDErK1mDqoiISE9YDkwys/F4/c2vAK4K3MHMUoAa51wDcBPwmh/4MLPhzrk9ZjYGr+nmKX1ZeOlZNQ1NvLx+D/8oKObfG0tpaG5hTGocn/voRC7IHcWUkT3cYqiuwqupiIyFqHivWVlkHIQN8NE0W1qgZm+7kLb70Metwa190zqAiBhIGAGJIyFjCkw4ExKGQ8JIb33tASgp8G6r/wwrHvGOC4uE4VO9oDfSD3wjZ3rv7VDQVA/7tkDpBtj7gXdf+gHs23T4wCFdER7l/ZtERHuhOiL64HJkrFdjFhETcPO3R7ZbDtwe6Z8nKtELcPEZA/9z30XBhLtOL1BBuJJ2TTLNLNM5V+IvXozXxyCkcrNSeOyN7dQ3NRMdodGnRETk2DjnmszsVmAp3kjTjzrn1prZLf72B/EGTnnCzJrx+rHfGHCKv/h97hqBz4ey24Icm/qmZl77YC//KCjmpXW7qW1sZkRSNNecMpYLckeRm5WM9WRTL+egaAWsfBzW/KXjcNMa8qLivBH7olqX4ztY39n2OIhKOPg4rBvfm5rq/Zq01nAWWNu2+2CAq9oDrvnw42OS/YA2HLLnHAxwCSMhccTBbTHJnTevy/20d9/SAmXbD4a9klWw8QV4/3f+jgbpkw7W7mXmwshZXiAZqOqrvPAWGOBKN8CB7QHvu8GwsZA+BSZ+3AtPnQWtIwW47nxm5Ig6DXfBXKDMbCRev7kkoMXM7gCmO+cqzCwOb6TN/2p36h+ZWR5es8ztHWzvc7nZKTQ0t7ChpJJcvw+eiIjIsfCn91nSbt2DAY/fAiYd4diP9G7ppDc0tzje2rKPxQU7+eeaXVT4k4tffMJoLsgZxZzxqYT39IjctWVeLdPKx2H3Gi9w5X4axp/hhaaGamisgYYaaKz2ltse13jbKorb7VMDLY1dK0dETCdB0F8fEQ3Vew8NcLUd/XZhXnBoDWcjZ/phbeTB2rbEEV6Qi4ztgTeynbAwSJ3g3WZc7K1zznuvSgpg1yrv/sM3vfe/VcpYP+zlQGae9zhheM+Xrztq9kPpRti78WCA2/uB1yetVVgEpB4HI2bAzEsgYyqkT/YCbW+839JjguoxGsQFahdec82Ojq0B0jpYf02XStoHWgPdqqIyhTsRERHplHP+5OL5xTy/ehd7q+qJjwrn7BkjuSB3FKdP6oXJxZ2DouV+Ld1fvVq6UcfDBffBzE95A4N0V3NjB8GwXShsqAp43NG+NV5tW+D6pgaIS/OCWdpxMPZUP7CNOPQ+Lr3/DWxiBsmjvdvUcw+ur94bUMPnB7/1iw9uTxh5aA1fZo7X36s3B+lwzqsF3bvRC3KlGw/WyAUO5R8R6wW2MSdDxnVejVzGVEgd369GgJTg9bP/NaE1KjmG9IQo8gvLuUa9G0REROQIymsaWfSfLfzt/WJ2ltUSFRHGvKnDuSB3FB+f2kuTi9cegFV/8kLdnnVev6K8K+GE62BUXs8+V3gkxKZ4Nzm6+HSYOM+7taorh12rveacraFv80vgWrztscP8ppytA7fkebWEXe0f1tLiDeXfFuBaw9wHUB8wWEx0MmRM9kaDbA1wGZMhecyQ65M22CncBTAzcrNSKCgqC3VRREREpB9qbnH8cXkhP166gfLaRs6YnMH/fHIyZ83opc
nFnYPCd7xAt/ZZb6TAUSfABb/wa+kSev45pftikmHc6d6tVUONF8pL8g/243vnwYODkkQleP32AkNfxhQvaDc3wv6th4e4vZsO7V8ZP9w7ZtalBwNcxlSvRnQADecvx07hrp2crBRe2biHyrrG3vkjLSIiIgPSiu37+fbitawtrmDO+FS+c8EMpo9K6p0nqz0ABX/0Ql3per+W7jNw4nXel34ZeKLiIGu2d2vV1OA1lQzsx/feE15TVoDwaG+S7PLCQ4f/T872+sCN+8jBAJc+GeJS+/Y1Sb+jcNdObnYyzsHqneWcelx6qIsjIiIiIba7oo57XtjAs+/vZGRSDL+48nguyMns2dEuwaul2/G2F+jW/c2rpRt9Ilz4S5hxiWrpBqOIKH/wlZyD61qavakHSgq8Wr7yIphxUcCgJpP1WZAjUrhrJzcrBYCCQoU7ERGRoay+qZlHX9/OL1/ZRFOz49aPTeRzHzuOuKge/vpUsx8KnvZC3d6NEJ0Ex1/t9aUL/NIvQ0NYuF8bNxlyLgt1aWSAUbhrZ1h8FGNS41ilfnciIiJD1isbdvPdf6xj+74aPjl9BN84bxpj03pw0mrnYMdbsOIxWPd3aK6H0bNhwf3e0PtDZYJsEelRCncdyM1OYeX2/aEuhoiIiPSxbXur+d5z63hlwx4mZMTz+A0n8dEpPThPWc1+KPiDX0v3gVdLd8K1Xl+6kbN67nlEZEhSuOtAblYy/ygoZk9lHcMTY0JdHBEREell1fVN/PKVzTzy+laiI8L5+rnTuO7UcURF9MAw8c7Bh2/4fen+7o2OmHUSLPi115dKtXQi0kMU7jrQNpl5YTmfmK5wJyIiMlg55/h7fjH/98J6dlfUc+mJWfzvOVN65sfd6n1Q8HtY+VvYt8mba+zE672+dCNndv/8IiLtKNx1YMaoJMLDjIKiMj4xfUSoiyMiIiK9YM3Ocr69eC0rPzxATlYyD1x9IieMGda9kzoH21/3aunWL/Zq6bLnwkcegOkXecPhi4j0EoW7DsRFRTBpeAIFReWhLoqIiIj0sP3VDfx46UaeXr6D1LgofvSpHC49MYuwsG5MbVC9F/J/D+/9FvZt9iaxnv1Zr5ZuxPSeK7yIyFEo3B1BXnYK/1y7C+dcz89jIyIiIn2uqbmFp97ZwU9e3Eh1QzOfPW08X5g3ieTYyGM7oXOw7TW/lu4f0NII2SfDR+6E6QtUSycifU7h7ghyslJ4enkhO/bX9OzQxyIiItLn3tyyl7sXr2Pj7kpOm5jGdy6YwaQRicGfoKEa9m6C0o3eXHSlG6FkFZTv8GrpTrrJG/Fy+LTeexEiIp1QuDuC3OxkAPILyxTuREREBqidZbX88Pn1PL+6hKxhsTx49YmcPWPEkVvl1B6A0g8OBrjWW/mOg/tYOKQdB6Ny4eNf92rpImP75gWJiByFwt0RTB6RSExkGAWF5SzIGx3q4oiIiEgX1DU2s+i1rfz61c0A/M8nJ7PwjAnERIZ7zSkrd0PpBm+uudKNBx9X7T54kogYSJsE2XPghGsgfTJkTIXUCRARFaJXJiJyZAp3RxAZHsaMUcmsKioLdVFEREQkSM45lq7dzfefX8fOA9VcPTWc23ObSa9dCktaw9wGqAsYNC0qETKmwMRPePfpUyBjMqSMhbDw0L0YEZEuUrg7itysFH7/7oc0NbcQEd4Dk5iKiIhIz2tuggPbKNn8Pv95800iDmziscgSJsQXE769Frb7+8Wle+Ft5qcOBriMqZCYCRo8TUQGAYW7o8jNTubRN1r4YHcV00clhbo4IiIiQ1tjnTcZeGs/uL0bofQD3L7NWEsjmcDlQHXcCGJHzyAs4ywvzLXWxsWnhfoViIj0KoW7o8jNSgGgoKhM4U5ERCQUtvwL3vmN15Sy7ENwLd56C8MNG0dxxBheclNY3TiSCdNO4Mr580hNSw9tmUVEQkTh7ijGpsWRHBtJQWEZV84ZE+riiIiIDC3vPgQv/C8kjoKs2ZBzeVstXH5NGt9esoWCHeWcOHYYd184g5mjk0NdYhGRkFK4OwozIycrmYKi8s53FhERkZ7R0gxLvw7vPACTzoZLH4Fob066PZV1/OifG3lm5XsMT4zm55/OY0HeqCNPbSAiMoQo3HUiLzuFX7+6hdqGZmKjNGKWiIhIr6qvgr/cCB/8E07+HJz1fQgLp6Gphd++uZ37lm2ivqmZW848jls/PpGEaH2VERFpFdRfRDM7B7gPCAceds7d0277VOAx4ATg6865ewO2bQcqgWagyTk321+fCvwRGIc3jtXlzrkD3Xs5PS83K4XmFsfa4nJmj0sNdXFEREQGr/Kd8IdPw+61cO69MOdmAF77oJS7/7GWLaXVfHzqcL55/nTGp8eHuLAiIv1Pp+P7m1k4cD8wH5gOXGlm09vtth/4AnAvHfuYcy6vNdj57gKWOecmAcv85X4nJ9trv59fWBbagoiIiAxmxe/DQx+H/dvhqj+3Bbs/vLuDax99l+YWx6PXz+bR609SsBMROYJgJm+bA2x2zm11zjUATwMLAndwzu1xzi0HGrvw3AuA3/qPfwtc1IVj+8zwxBhGJceo352IiEhvWf8cPHYuhEfCjUth0ifaNr24dhcT0uNZ+sUz+PjUESEspIhI/xdMuBsNFAYsF/nrguWAF81spZktDFg/wjlXAuDfD+/COftUTlYKq4rKQl0MERGRwcU5ePOX8MerYfg0uGkZjJgRsNmxqsgbDTM6Qv3eRUQ6E0y462j4KdeF5zjNOXcCXrPOz5vZGV04FjNbaGYrzGxFaWlpVw7tMbnZKXy4r4YD1Q0heX4REZFBp7kRnrsDXvwGTL8QrnsOEg+tmSs6UMu+6gZys1NCUkQRkYEmmHBXBGQHLGcBxcE+gXOu2L/fAzyL18wTYLeZZQL493uOcPwi59xs59zsjIyMYJ+2R+X6/e5W7VTTTBERkW6rLYOnLoWVj8Pp/wOXPg5RcYftVuC3mslTuBMRCUow4W45MMnMxptZFHAFsDiYk5tZvJkltj4GzgLW+JsXA9f5j68D/t6VgvelWaOTMYMCDaoiIiLSPQe2wyNnwfbXYcH98IlvQ1jHX0cKCsuIighjysjEvi2jiMgA1elUCM65JjO7FViKNxXCo865tWZ2i7/9QTMbCawAkoAWM7sDb2TNdOBZf2LRCOD3zrl/+qe+B/iTmd0I7AAu69FX1oMSYyI5LiNB4U5ERKQ7Ct+FP1wJLU1wzbMw/ug9NQoKy5k5KonI8GB+ixYRkaDmuXPOLQGWtFv3YMDjXXjNNdurAHKPcM59wLygSxpiOVnJvPbBXpxz+GFVREREgrX6Gfjb5yBpFHzmz5A+6ai7NzW3sHpnOZ8+Kfuo+4mIyEH6KSxIedkp7K2qp7i8LtRFERERGTicg3//GP5yI4w+0RsRs5NgB7C5tIraxmb1txMR6QKFuyDlZqUA6ncnIiLBMbNzzGyjmW02s7s62D7MzJ41s1Vm9q6ZzQzY9kUzW2tma8zsD2YW07el7yFN9fDsLfCv70POp+Hav0F8WlCHtl5vNVKmi
EjwFO6CNDUzkchwaxu5S0RE5EjMLBy4H28aoOnAlWY2vd1uXwPynXM5wLXAff6xo4EvALOdczPx+rtf0Vdl7zE1++GJi2DV0/Cxr8PFv4GI6KAPzy8sJykmgnFph4+iKSIiHVO4C1J0RDjTM5NUcyciIsGYA2x2zm11zjUATwML2u0zHVgG4JzbAIwzs9aJ3iKAWDOLAOLowhRE/cLeTfDwPNi5Ej71CJz5v9DF/uoFhWXkZqeon7uISBco3HVBTlYKa3ZW0NzSlTncRURkCBoNFAYsF/nrAhUAlwCY2RxgLJDlnNsJ3Is3knQJUO6ce7HXS9xTtv0HHv4E1FXAdf+AWZd2+RS1Dc1s3F2p/nYiIl2kcNcFudkpVNU3sbW0KtRFERGR/q2j6qb2vwzeAwwzs3zgNuB9oMnMhuHV8o0HRgHxZnb1YU9gttDMVpjZitLS0h4t/DF7/3fw5MWQMAJuehnGzD2m06wtLqe5xbX1dxcRkeAo3HVBXnYyAPlqmikiIkdXBASO4Z9Fu6aVzrkK59wNzrk8vD53GcA24BPANudcqXOuEfgrcGr7J3DOLXLOzXbOzc7IyOillxGklhZ4+W74++dh3Glw44uQOv6YT9d6nc3xr7siIhIchbsumJCeQEJ0BKuKykNdFBER6d+WA5PMbLyZReENiLI4cAczS/G3AdwEvOacq8BrjnmymcWZ1+FsHrC+D8veNY218MwN8PpP4YTr4DPPQGxKt065qqicUckxDE8cmIOEioiESlCTmIsnLMyYNTpZI2aKiMhROeeazOxWYCneaJePOufWmtkt/vYHgWnAE2bWDKwDbvS3vWNmzwDvAU14zTUXheBldK5qD/zhSm/glLO+D6fc2uWBUzpSUFSmKRBERI6Bwl0X5WQn8+jr26hvaiY6IjzUxRERkX7KObcEWNJu3YMBj98COpzN2zn3beDbvVrA7tq9Dn7/aajZC5/+HUw7v0dOe6C6gQ/31XDlnDE9cj4RkaFEzTK7KC8rhcZmx/qSylAXRUREJDQ2vwyPnAXNDXDDkh4LdkBb6xgNpiIi0nUKd13U2kxE892JiMiQtPxheOpyGDYObn4FRh3fo6cvKCzHDGZlaTAVEZGuUrPMLspMjiE9IVr97kREZGhpaYYXvwFv/xomnQ2XPgLRiT3+NAVFZUwa7g1gJiIiXaO/nF1kZuRlJ6vmTkREho76KvjLTfDBCzD3v+HsH0BYz/c7d85RUFjGx6YO7/Fzi4gMBWqWeQxys1LYureairrGUBdFRESkd5XvhMfOgU1L4dx7Yf49vRLsAHaW1bKvukEjZYqIHCOFu2OQk52Cc7BG892JiMhgVpwPD8+D/dvhqj/DnJt79ekKCr3rap4GUxEROSYKd8cg1+/kna9+dyIiMlhteB4emw9hEXDjUpj0iV5/yoKiMqIiwpgysuf78omIDAUKd8cgJS6KsWlxrCpUzZ2IiAwyzsGbv4KnPwMZU+GmZTBiRp88dX5hGTNGJREVoa8nIiLHQn89j1FuVopGzBQRkcGluRGe+yK8+HWYfiFc/zwkjuiTp25qbmF1UbnmtxMR6QaFu2OUm51CSXkdeyrqQl0UERGR7qsrh6cug5WPwelfhEsfh6i4Pnv6zaVV1DY2k5ut+e1ERI6Vwt0xau13V6BBVUREZKCr2Q+PnAXb/wML7odPfAfC+vYrQusUQ6q5ExE5dgp3x2jGqGTCw0zz3YmIyMAXOwzGngbXPAvHXx2SIhQUlZMUE8G4tPiQPL+IyGCgScyPUWxUOJNHJKrfnYiIDHxmcP5PQ1qEgsIycrNTCAuzkJZDRGQgC6rmzszOMbONZrbZzO7qYPtUM3vLzOrN7M6A9dlm9i8zW29ma83s9oBt3zGznWaW79/O7ZmX1HfyspMpKCzDORfqooiIiAxYdY3NbNhVqSaZIiLd1Gm4M7Nw4H5gPjAduNLMprfbbT/wBeDeduubgC8556YBJwOfb3fsz5xzef5tybG+iFDJzUqhoq6J7ftqQl0UERGRAWttcTnNLY7c7JRQF0VEZEALpuZuDrDZObfVOdcAPA0sCNzBObfHObccaGy3vsQ5957/uBJYD4zukZL3Azn+L4yr1DRTRETkmOX788a2DlYmIiLHJphwNxooDFgu4hgCmpmNA44H3glYfauZrTKzR81sWFfPGWqTRyQQExlGvgZVEREROWYFhWVkJscwPCkm1EURERnQggl3HfVs7lInMzNLAP4C3OGcq/BXPwAcB+QBJcBPjnDsQjNbYWYrSktLu/K0vS4iPIyZo5JZpekQREREjtmqojL1txMR6QHBhLsiIDtgOQsoDvYJzCwSL9g95Zz7a+t659xu51yzc64FeAiv+edhnHOLnHOznXOzMzIygn3aPpObncKaneU0NreEuigiIiIDTllNA9v31ai/nYhIDwgm3C0HJpnZeDOLAq4AFgdzcjMz4BFgvXPup+22ZQYsXgysCa7I/Utudgr1TS1s3FUZ6qKIiIgMOAV+65fcbPW3ExHprk7nuXPONZnZrcBSIBx41Dm31sxu8bc/aGYjgRVAEtBiZnfgjayZA1wDrDazfP+UX/NHxvyRmeXhNfHcDvxXD76uPtPa+XtVUTkzR+vCJCIi0hUFhWWYwSxdQ0VEui2oScz9MLak3boHAx7vwmuu2d7rdNxnD+fcNcEXs/8akxpHSlwkBYVlXDV3TKiLIyIiMqAUFJYxMSOBxJjIUBdFRGTAC2oSczkyMyMnK4UCTYcgIiLSJc45CorK1N9ORKSHKNz1gLysZD7YXUlNQ1OoiyIiIjJg7CyrZW9Vg+a3ExHpIQp3PSA3O4UWB2t2VnS+s4iIiAC0TSWkmjsRkZ6hcNcDcvy5eVapaaaIiEjQCgrLiAoPY+rIpFAXRURkUFC46wEZidGMToklv7As1EUREREZMPILy5g+KomoCH0dERHpCfpr2kNyspLbmpeIiIjI0TW3OFbvLCdPTTJFRHqMwl0Pyc1OYcf+GvZXN4S6KCIiIv3e5j1V1DQ0a/JyEZEepHDXQ3L9fneaEkFERKRzBX5Xhtbrp4iIdJ/CXQ+ZlZWMGawqVNNMERGRzuQXlZEYE8G4tPhQF0VEZNBQuOshCdERTMxIUM2diIhIEFYVlZGblUJYmIW6KCIig4bCXQ/KzU5hVVEZzrlQF0VERELMzM4xs41mttnM7upg+zAze9bMVpnZu2Y2018/xczyA24VZnZHn7+AXlTX2MyGkkr1txMR6WEKdz0oNyuZvVUN7CyrDXVRREQkhMwsHLgfmA9MB640s+ntdvsakO+cywGuBe4DcM5tdM7lOefygBOBGuDZvip7X1hbXEFTi1N/OxGRHqZw14Ny/eGcC9TvTkRkqJsDbHbObXXONQBPAwva7TMdWAbgnNsAjDOzEe32mQdscc592NsF7kutg6loGgQRkZ6lcNeDpo5MIio8jFXqdyciMtSNBgoDlov8dYEKgEsAzGwOMBbIarfPFcAfeqmMIVNQVEZmcgzDk2JCXRQRkUFF4a4HRUWEMW1UEvn+L5IiIjJk
dTRKSPsO2fcAw8wsH7gNeB9oajuBWRRwIfDnDp/AbKGZrTCzFaWlpT1S6L5SUFimJpkiIr1A4a6H5WUls2ZnOc0tGlRFRGQIKwKyA5azgOLAHZxzFc65G/y+ddcCGcC2gF3mA+8553Z39ATOuUXOudnOudkZGRk9WvjeVFbTwPZ9NeRoMBURkR6ncNfDcrJSqG5oZktpVaiLIiIiobMcmGRm4/0auCuAxYE7mFmKvw3gJuA151xFwC5XMgibZK4q8vql56nmTkSkxync9bDWQVXUNFNEZOhyzjUBtwJLgfXAn5xza83sFjO7xd9tGrDWzDbg1dLd3nq8mcUBnwT+2rcl730FhWWYwcws1dyJiPS0iFAXYLCZkB5PYnQEq4rKuHx2ducHiIjIoOScWwIsabfuwYDHbwGTjnBsDZDWqwUMkYKiMo7LSCApJjLURRERGXRUc9fDwsKMWVnJmg5BRESkHecc+YXlGkxFRKSXKNz1gtzsFNaXVFDX2BzqooiIiPQbxeV17K2qJ0+DqYiI9AqFu16Qm5VMU4tjfUlF5zuLiIgMEa2Tl+dq8nIRkV6hcNcLWi9aBRpURUREpE1BURlR4WFMHZkU6qKIiAxKQYU7MzvHzDaa2WYzu6uD7VPN7C0zqzezO4M51sxSzewlM9vk3w/r/svpH0YmxZCRGN023LOIiIh4P3pOG5VEVIR+WxYR6Q2d/nU1s3DgfrxhmqcDV5rZ9Ha77Qe+ANzbhWPvApY55yYBy/zlQcHMyM1KIb+oLNRFERER6ReaWxyri8rJ0xQIIiK9JpifzuYAm51zW51zDcDTwILAHZxze5xzy4HGLhy7APit//i3wEXH9hL6p7zsZLaWVlNe2/4tERERGXq2lFZR3dCs/nYiIr0omHA3GigMWC7y1wXjaMeOcM6VAPj3wzs6gZktNLMVZraitLQ0yKcNvRx/mOc1O9U0U0REJF+DqYiI9Lpgwp11sM4Fef7uHOvt7Nwi59xs59zsjIyMrhwaUjl+s5N8DaoiIiJCQWEZiTERjE+LD3VRREQGrWDCXRGQHbCcBRQHef6jHbvbzDIB/Ps9QZ5zQEiJi2JcWhyr1O9ORESEgqIycrKSCQvr6HdfERHpCcGEu+XAJDMbb2ZRwBXA4iDPf7RjFwPX+Y+vA/4efLEHhtzsFAoK1SxTRESGtrrGZjaUVJLrd1kQEZHe0Wm4c841AbcCS4H1wJ+cc2vN7BYzuwXAzEaaWRHwP8A3zKzIzJKOdKx/6nuAT5rZJuCT/vKgkpuVwq6KOnZX1IW6KCIiIiGzrqSCphan/nYiIr0sIpidnHNLgCXt1j0Y8HgXXpPLoI711+8D5nWlsANNbrbX766gsIyzZowMcWlERERCo8Dvf56ncCci0qs0i2gvmjEqmfAwo0D97kREZAgrKCxjZFIMI5JiQl0UEZFBTeGuF8VEhjN1ZCKritTvTkREhq6CovK21iwiItJ7FO56WU5WCgWFZbS0dGkGCBERkUGhvKaRbXur1d9ORKQPKNz1srzsZCrqmti+rzrURREREelzq3aWAWikTBGRPqBw18ty/IuZmmaKiMhQ1DqYyqwsNcsUEeltCne9bNLwBGIjw8n3L24iIiJDSX5hOcdlxJMUExnqooiIDHoKd70sIjyMWaOTWaURM0VEZIhxzpFfWKb+diIifUThrg/kZCWzpriCxuaWUBdFRESkz5SU17G3ql7z24mI9BGFuz6Qm51CQ1MLG3dVhrooIiIifaa1v50GUxER6RsKd32g9aKmycxFRGQoyS8qIyo8jKmZiaEuiojIkKBw1weyU2MZFhfZ9gumiIjIULCqsJxpmYlER4SHuigiIkOCwl0fMDNys1M0HYKIiAwZzS2O1TvLNZiKiEgfUrjrIzlZKXywu5Lq+qZQF0VERKTXbS2toqq+Sf3tRET6kMJdH8nLTqbFwZqdqr0TEZHBr3V+V9XciYj0HYW7PpLj/3KpppkiIjIUFBSVkRgdwYT0+FAXRURkyFC46yPpCdGMToklXyNmiojIEFBQWE5OdjJhYRbqooiIDBkKd30oLztFI2aKiMigV9fYzPqSCvW3ExHpYwp3fSgnK5miA7Xsq6oPdVFERER6zfqSCppaXFuXBBER6RsKd32otVO5+t2JiMhg1tpKJU+DqYiI9CmFuz40c3QyZl4ncxERkcGqoKicEUnRjEyOCXVRRESGFIW7PpQQHcGk4QnqdyciMgSY2TlmttHMNpvZXR1sH2Zmz5rZKjN718xmBmxLMbNnzGyDma03s1P6tvTdU1BYpv52IiIhoHDXx3KzUigoKsc5F+qiiIhILzGzcOB+YD4wHbjSzKa32+1rQL5zLge4FrgvYNt9wD+dc1OBXGB975e6Z5TXNLJ1b7XmtxMRCYGgwl0Qvz6amf3C377KzE7w108xs/yAW4WZ3eFv+46Z7QzYdm6PvrJ+Kic7hf3VDRQdqA11UUREpPfMATY757Y65xqAp4EF7faZDiwDcM5tAMaZ2QgzSwLOAB7xtzU458r6rOTdtGpnGaD+diIiodBpuAvy18f5wCT/thB4AMA5t9E5l+ecywNOBGqAZwOO+1nrdufcku6+mE61NEPlrl5/mqPJ85upqN+diMigNhooDFgu8tcFKgAuATCzOcBYIAuYAJQCj5nZ+2b2sJkdNhO4mS00sxVmtqK0tLQ3XsMxae16MCsrObQFEREZgoKpuQvm18cFwBPO8zaQYmaZ7faZB2xxzn3Y7VIfq1fvgQdPh8J3Q1aEKSMTiQoP04iZIiKDW0czd7dvj38PMMzM8oHbgPeBJiACOAF4wDl3PFANHNZqxjm3yDk32zk3OyMjoyfL3i0FReVMyIgnKSYy1EURERlyggl3wfz6GMw+VwB/aLfuVr8Z56NmNiyIsnTPrMsgKgEePx9W/anXn64jURFhTB+VRL4GVRERGcyKgOyA5SygOHAH51yFc+4Gv3XLtUAGsM0/tsg5946/6zN4Ya/fc86RX1jW1kpFRET6VjDhLphfH4+6j5lFARcCfw7Y/gBwHJAHlAA/6fDJe7LZScZkuPkVyDoJ/nozLPsetLR075zHIC87hTU7y2lu0aAqIiKD1HJgkpmN96+BVwCLA3fwR8SM8hdvAl7zA98uoNDMpvjb5gHr+qrg3bGroo7SynoNpiIiEiLBhLtOf30MYp/5wHvOud2tK5xzu51zzc65FuAhvOafh+nxZidxqXDNs3D8NfCfe+GZ66Ghpvvn7YKcrGRqGprZvKeqT59XRET6hnOuCbgVWIo30uWfnHNrzewWM7vF320asNbMNuBdJ28POMVtwFNmtgrvR9Af9lnhu6G1v53CnYhIaEQEsU/br4/ATrxfH69qt89ivCaWTwNzgXLnXEnA9itp1yTTzDID9rkYWHMM5T82EVFw4S8hYyq8+A048CFc+QdIGtUnT9960SsoLGPKyMQ+eU4REelb/kBhS9qtezDg8Vt4A5F1dGw+MLs3y9cb8gvLiQw3pmXq2iYiEgqd1twF+evjEmArsBmvFu5zrcebWRzwSeCv7U7
9IzNb7f8q+THgi919MV1iBqfeClc+Dfs2w0Mfh+L3++Spx6fFkxgToREzRURkUCkoLGN6ZhLREeGhLoqIyJAUTM1dML8+OuDzRzi2BkjrYP01XSppb5lyDnx2KfzhCnh0Plz8IMy4qFefMizMyMlKVrgTEZFBo6XFsXpnORcf3348NRER6StBTWI+6I2c6Q20MnIW/Pk6eO3H4Hp3sJPcrBQ2lFRS19jcq88jIiLSF7buraKqvkn97UREQkjhrlXCcLjuHzDrcnjl+/DXhdBY12tPl5OVQlOLY11JRa89h4iISF/JL/Tmb83L1uTlIiKhonAXKDIGLlkEH/8mrP4T/PYCqNrTK0+VFzCoioiIyEBXUFhGQnQEE9ITQl0UEZEhS+GuPTM44064/AnYtdobaGVXzw/kOTI5hhFJ0awqKu/xc4uIiPS1gqIycrKSCQvraOpbERHpCwp3RzJ9AXz2BWhpgkfPho0v9PhT5GSlqOZOREQGvLrGZtaXVKi/nYhIiCncHc2o472BVtImwh+uhDd/2aMDreRlp7B1bzXltY09dk4REZG+tr6kgsZmR25WSqiLIiIypCncdSZpFNzwAky/0JvwfPFt0NTQI6fOyfI6na9W00wRERnAWrsY5GowFRGRkFK4C0ZUHFz6OJzxZXj/SXjyYqjZ3+3T5oxOAdB8dyIiMqAVFJYxPDGakUkxoS6KiMiQpnAXrLAw+Pg34JKHoGi5N9BK6QfdOmVyXCQT0uPV705ERAa0/KIycrNTMNNgKiIioaRw11U5l8P1z0FDFTz8Cdi8rHuny0pWzZ2IiAxY5bWNbC2tbpviR0REQkfh7lhkz/EGWknOgqcug3cfOuZT5WansLuinl3lvTdhuoiISG9p7TeuwVREREJP4e5YpYyBG5fCpE/Ckjvh+TuhuanLp8nxL4aqvRMRkYGo9fo1K0uDqYiIhJrCXXdEJ8IVv4dTb4PlD8HvL4Pasi6dYsaoJCLCTP3uRERkQMovLGNCRjzJsZGhLoqIyJCncNddYeFw1vfhwl/CttfgkU/Cvi1BHx4TGc7UzETV3ImIyIC0qqhMTTJFRPoJhbuecsK1cO3foboUHp4H218P+tCcrBRWFZXT0tJzE6SLiIj0tl3ldeyuqCdXTTJFRPoFhbueNO50uGkZxGfAExfBe08GdVheVgqVdU1s21fdu+UTERHpQfl+l4JcjZQpItIvKNz1tLTj4MaXYPxHYPGt8OI3oKX5qIfkZHu/eD76+jbqGo++r4iISH9RUFRGZLgxLTMp1EUREREU7npHbApc9Wc46WZ485fw9GegvvKIu08ensgVJ2Xz1Ds7OOtnr/GvjXv6rqwiIiLHqKCwjGmZScREhoe6KCIigsJd7wmPgPPuhXPvhU0vwqPnQNmODncNCzPu+VQOv795LpHhxg2PLee/f7eSkvLaPi60iIhIcFpaHKuLyjWYiohIP6Jw19vm3Ayf+TOUFcJDH4fCd4+466nHpfPC7Wfw5bOn8MqGPcz7yb956LWtNDa39GGBRUREOrd1bzWV9U3qbyci0o8o3PWFifPgppchKgEePx9W/emIu0ZFhPH5j03k5f85k5MnpPGDJeu54Jevs2L7/j4ssIiIyNG1zs+qkTJFRPoPhbu+kjEZbn4Fsk6Cv94Mr3wfWo5cI5edGscj183mN9ecSEVtI5c++Bb/+0wB+6sb+rDQIiIiHSsoKiMhOoIJGQmhLoqIiPgU7vpSXCpc8ywcfw289mN45npoqDni7mbG2TNG8vKXzuS/zpzAX9/bycd/8ip/XL5Dc+KJiEhIFRSWMWt0MuFhFuqiiIiIL6hwZ2bnmNlGM9tsZnd1sN3M7Bf+9lVmdkLAtu1mttrM8s1sRcD6VDN7ycw2+ffDeuYl9XMRUXDhL+Gs78O6xfDYfKgoPuohcVERfHX+NJ7/wkeYPDyRr/xlNZf95i3Wl1T0UaFFREQOqm9qZl1JhfrbiYj0M52GOzMLB+4H5gPTgSvNbHq73eYDk/zbQuCBdts/5pzLc87NDlh3F7DMOTcJWOYvDw1mcOptcOXTsG+zN9BK8fudHjZlZCJ//K+TufeyXLbtreb8X77O959bR1V9Ux8UWkRExLO+pJLGZkdetvrbiYj0J8HU3M0BNjvntjrnGoCngQXt9lkAPOE8bwMpZpbZyXkXAL/1H/8WuCj4Yg8SU86Bzy6FsAh4dD786VpY9j0oeBqKVkJd+WGHmBmXnpjFK186k8tnZ/Pw69v4xE/+zZLVJTinppoiItL72gZTUc2diEi/EhHEPqOBwoDlImBuEPuMBkoAB7xoZg74jXNukb/PCOdcCYBzrsTMhnf05Ga2EK82kDFjxgRR3AFm5ExvoJWlX4Od78H658A1H9yeMALSJkH6RP9+EqRNJCVlLP93ySwum53F159dw+eeeo8zJ2fw3QUzGJsWH7rXIyIig15BURnDE6MZmRQT6qKIiEiAYMJdRz2l21cRHW2f05xzxX54e8nMNjjnXgu2gH4YXAQwe/bswVk1lTAcPvWw97ipAQ5sh32bYK9/27fJ659XGzAdQlgkpE7ghPRJPD9tIq9nDuM3a7dy6c8+5OqP5nHLRycQHREekpcjIiKDW0FhGTlZKZhpMBURkf4kmHBXBGQHLGcB7UcAOeI+zrnW+z1m9ixeM8/XgN1mlunX2mUCe47tJQwyEVHetAkZkw/fVrP/YNjbu8nrr7d3E2EfLOWMlkbOMCAC9v0nkQ/ezCJj/ExGjp/p1falT4Zh4yA8sq9fkYiIDCIVdY1sKa3m4uNHh7ooIiLSTjDhbjkwyczGAzuBK4Cr2u2zGLjVzJ7Ga7JZ7oe2eCDMOVfpPz4L+G7AMdcB9/j3f+/2qxns4lJhzFzvFqi5Cco+bAt+DVtW07xtDeGbX4Qtfz64X1iEF/AOa+Y5CeLTvYFeREREjmJ1kdcfXP3tRET6n07DnXOuycxuBZYC4cCjzrm1ZnaLv/1BYAlwLrAZqAFu8A8fATzrN9uIAH7vnPunv+0e4E9mdiOwA7isx17VUBMeAWnHeTfOIfNUGNbYzIP/3sKTr65iUvhuPjezmdNTDhC2f7NX47flFWiuP3iOmJSDQS8w+KVOgIjoUL0yEZEBy8zOAe7Du3Y+7Jy7p932YcCjwHFAHfBZ59waf9t2oBJoBprajTYdUvn+YCo5o1NCWg4RETlcMDV3OOeW4AW4wHUPBjx2wOc7OG4rkHuEc+4D5nWlsBK8mMhw7vjEZC7KG823Fq/l2pWlzBiVxPcv+gLHjxkGLc1QXgh7N8PeDw429dz6Lyj4/cETWRgMnwHTLoAZF0HGlJC9JhGRgSJgGqFP4nVdWG5mi51z6wJ2+xqQ75y72Mym+vsHXhc/5pzb22eFDlJBYRkT0uNJjlMzfxGR/iaocCcD17j0eH57w0ksWb2L7z63lkseeJMr54zhK2dPJXnYOK+Z5qRPHHpQfaXfn88Pfttfh1f/D179IWRMhekLYPpFMHyamnKKiHSsbRohAL/bwgIgMNxNB/4PwDm3wczGmdkI59zuPi9tFx
QUlXHqcemhLoaIiHRA4W4IMDPOy8nkjMnp/OylTTz+5jaWrtnFV8+dxqdOGH34aGfRiTDqeO/WqqIENjwH6/4Or/0Y/v3/vKabMy7ywt6ImQp6IiIHBTONUAFwCfC6mc0BxuINSLabI08jFFK7yuvYXVFPbpYmLxcR6Y+CmcRcBonEmEi+dcF0/nHb6YxNi+POPxfw6UVv88Huys4PTsqEOTfD9c/BlzbCeT+FpFHwn5/Ag6fDL0+Al++G4nzQZOoiIsFMI3QPMMzM8oHbgPeBJn/bac65E4D5wOfN7IzDnsBsoZmtMLMVpaWlPVfyoygoKgMgR4OpiIj0Swp3Q9CMUck8c8up3HPJLD7YXcm59/2He17YQE1DU+cHgzcv30k3wnWL4c5NcMF9XvPON+6DRWfCL/LgpW/BzpUKeiIyVHU6jZBzrsI5d4NzLg+4FsgAtvnb2qYRAlqnEaLd8Yucc7Odc7MzMjJ65UW0V1BYRkSYMT0zqU+eT0REukbhbogKCzOumDOGZf9zJhcfP5oH/72FT/70NV5cu6trJ4pPhxOvh2uehS9vhgt/5TXXfOt+eOjj8PMcWPp1KFwOLS298lpERPqhtmmEzCwKbxqhxYE7mFmKvw3gJuA151yFmcWbWaK/T+s0Qmv6sOxHVFBUxrTMJGIiw0NdFBER6YD63A1xaQnR/PiyXC4/KZtvPLuGhU+u5BPThvPtC2aQnRrXtZPFpcIJ13i32gOw8QWvj967i+CtX0HSaH8wlgWQNQfC9NvCkOUcVBR7tbvF70HJKu+HgrGnwbjTvSk41IdTBrAgpxGaBjxhZs14A63c6B9+tGmEQqalxbGqsJwFx48KdVFEROQIzA2gZnOzZ892K1asCHUxBq3G5hYee2MbP395Ey3OcdvHJ3HzRyYQFdHNEFZXDhv/6QW9zS978+sljITpF3qjbo45GcL0K/CgVrPfC3E7/Vvxe1DlDwgYFgEZ06BqF1T7/YYSM/2gdxqMPd2bc1Fhb0gxs5X9aW63/q4vro+b91TxiZ/+mx9fmsNls7M7P0BERHrF0a6RqrmTNpHhYSw84zjOzxnF3f9Yy4+XbuSv7xXx9fOmccakDCLCjzHkxSRD7qe9W10FbHoR1v0N3nvCq9WLH35wHr0xp3qTssvA1VANJQV+kPNr5g5sP7g9fTJM+BiMPgFGnQAjZ0FkjFebt3cTfPg6bH/Dm4JjzTPeMfHD/aDn1+xlTFXYE+ljBf7k5XkaTEVEpN9SzZ0c0SsbdvOtv6+l6EAtqfFRnDNzJOfPymTuhDTCw3rgi3V9lR/0/u7dN9ZAXDpMO9+r0Rv3EQW9/q6pAfasPVgbt/M9KN0Azu9fmZTlhbjWIDcqzwv7wXAO9m/1Qt6Hftir2Olti0s7GPTGngbDp6uZ7yCjmruu6Yvr47f/voZnVhax6jtn98w1QEREjsnRrpEKd3JUdY3NvLqxlOdXl7Bs/W5qGppJT4hi/sxMzsvJ5KRxqT1zkW+o9ppsrvu714SzsRpiU2HqeV6N3vgzITyy+88jx66lBfZtCghyK2HXGq+ZLXj/XqNPPBjkRp/gjazaU5zzagA/fMOr2fvwdSjb4T/3MK/Wd9zpXg3fiJlq6jvAKdx1TV9cHxfc/waxkWE8vfCUXn0eERE5OoU76RG1Dc28unEPz60qYdmG3dQ1tjA8MZpzZ3lB78QxwwjriaDXWAubl/lB7wVoqISYFC/oTb8IJnwUIqI6OYl0i3NQXnhojVxxvvdvARAZ79XCBQa5lLF931SybMfBoLf9DTiwzVsfnQxjTzlYuzcyR7XAA4zCXdf09vWxvqmZWd9+kRtOH8dX50/rtecREZHOqc+d9IjYqHDmz8pk/qxMahqaeGXDHp5fVcIf3t3B429uZ0SSF/TOz8nk+OxuBL3IWK9p5rTzobEOtv7LC3rrn4P8p7wv7lPmw9RzvRqb/igswnsdkXHefUSsvxzbP2uUqve2C3LvHRzcJCwSRs6EnMsP1sylT+4fryNlDOSNgbwrveXynQebcH74BnzgDzAYlegN3DPuNK+5b2auaoJFumBDSSUNzS3kZaWEuigiInIUCndyTOKiIjg/ZxTn54yiqr6JZet38/yqEp56ZwePvbGdUckxbTV6edkp2LHW6ETGeEFuynxoqoet//aC3obnYNXTPfui+kp49MGgd0gAjDn4ODLOe+1ty4EBMa7d8f66iHb7Hym81Fd6tXCBQa61eSMGGVNg4icP9pUbMRMiovvq3eme5NFeCM253Fuu3BXQZ+8NePklb31kPIyZe7Bmb9QJqg0WOYqCojIAcjWYiohIv6ZmmdKjKusaedkPev/+oJTGZsfolFjOz/GC3qzRycce9AI1N3oBpbW/V3/T3Og1L22sgaa6g48bawNu/nJTXcC2Gq+28pD9a4Bj+H8aFnF4OGxpgn2bD54vZczBZpWtA55EJ/bgG9HPVO0J6LP3BuxZ562PiIXsOQcHaMma3TuB1jlobmj3714HTbUH/91bHzcFfFbaPkOB2+u8QB+dGHBLarfcbl1k7IAZZVTNMrumt6+P//OnfP6zaS/vfm1ez/wNFxGRY6ZmmdJnEmMiufj4LC4+Povy2kZeWreb51cV88jr2/jNa1sZkxrHeTmZnDcrkxmjko79S0J4JGSf1LOF76/aAkH7cBgYENstN7YLjK1BwjmYdenBQBefHupX17cShsOMi70bQPU+2PGmV7u3/Q341w8B59WuZs/xgt7Imf77Hxi4AgNZQIAPJrC1jiTaVeFRATW4MV6wa6rzamLrKqClsfNzWDhEJxwhBB4pHLZbH5Xg3feHZrnSZ1YVlZOb1Y1WGCIi0icU7qTXJMdGcumJWVx6YhZlNQ28uG43z60qYdFrW3ng1S2MS/OC3vk5o5g6MlFfGo7EzKtFiojuv30MB6r4NG+OxWkXeMu1B+DDt/zavf/Aaz86chiz8HbNaf3AFRkLUXHedA2tTWtb17feH/Y49gj7Bpy3szDVVO8FvfoK/77Sm27ksHWVh66r2Q8HPjy4vrE6uPcuMv7IQTB5NHz8G8H/O0i/VlHXyJbSKhbkjgp1UUREpBMKd9InUuKiuHx2NpfPzmZ/dQMvrt3Fc6tKeODVLdz/ry1MyIjn/FmZnJcziikjB3GzQOnfYod5A/VMPddbriv35tprC1wBIay/DcjS+gNAd2tjW5qhoarjINhhQAzYt3qvd5+QoXA3iKwpKsc59bcTERkIFO6kz6XGR3HFnDFcMWcM+6rq+efaXTy/qoRf/Wszv3hlM5OGJ/g1eplMHK6gJyEUkwyjjg91KfpWWLj3uoOdbF4GvXx/MJWcLH0mRET6O4U7Cam0hGg+M3csn5k7ltLKev65poTnVpVw37JN/PzlTUwdmch5/qibEzISQl1cEZEhp6CwjPHp8aTEaURZEZH+TuFO+o2MxGiuOWUc15wyjj0VdSxZXcLzq0v4yUsf8JOXPmBaZpI36uasTMalx4e6uCIiQ0JBYTknT0gNdTFERCQICnfSLw1PiuH608Zz/WnjKSmv5YXVu
3huVTE/XrqRHy/dyMzRSZw7K5NTJqQxY1QyURFhoS6yiMigs7uijl0VdepvJyIyQCjcSb+XmRzLZ08fz2dPH8/OslpeWO013fzRPzcCEBURxqzRyZwwJoUTxgzjhLHDGJEUE+JSi4gMfAWFZYAGUxERGSiCCndmdg5wHxAOPOycu6fddvO3nwvUANc7594zs2zgCWAk0AIscs7d5x/zHeBmoNQ/zdecc0u6/YpkUBudEstNH5nATR+ZwO6KOt778ADv7TjAezvK+O1bH/LQf7a17Xf8mBSOHzOME8akqHZPROQYFBSVERFmTM9MCnVRREQkCJ2GOzMLB+4HPgkUAcvNbLFzbl3AbvOBSf5tLvCAf98EfMkPeonASjN7KeDYnznn7u25lyNDyYikGObPymT+rEwA6puaWVdcwXs7yrzA9+EBnltVAqh2T0TkWBQUljM1M5GYSE1aLyIyEARTczcH2Oyc2wpgZk8DC4DAcLcAeMI554C3zSzFzDKdcyVACYBzrtLM1gOj2x0r0iOiI8I5fswwjh8zjBsZD8Cu8rq2oPfejgP89s3Da/daw970zCTV7omI+FpaHAVFZVyoyctFRAaMYMLdaKAwYLkIr1aus31G4wc7ADMbBxwPvBOw361mdi2wAq+G70DQJRcJwsjkGM6dlcm5AbV7a4sreO/DA7y/o4yVAbV70a21e2OHtdXwDVftnogMUdv2VVNZ16T+diIiA0gw4c46WOe6so+ZJQB/Ae5wzlX4qx8Avufv9z3gJ8BnD3tys4XAQoAxY8YEUVyRI4uOCPdq6sYMa1tXUl7Lex/6TTl3HODxN7az6LUWwKvdCwx700clERmu2j0RGfxaB1PJU7gTERkwggl3RUB2wHIWUBzsPmYWiRfsnnLO/bV1B+fc7tbHZvYQ8FxHT+6cWwQsApg9e3b7UCnSbZnJsZyXE8t5OQdr99bsrOB9P+wt37affxR4H/noiDByspI5wW/+ecLYFIYnqnZPRAafgsIy4qPCOS4jIdRFERGRIAUT7pYDk8xsPLATuAK4qt0+i/GaWD6N12Sz3DlX4o+i+Qiw3jn308ADAvrkAVwMrOnG6xDpMdER4Zw4dhgnjj1Yu1dcVuv33fNq+B59YxuNr20FIGtYrF8bmMIJY4cxLVO1eyIy8BUUlTMrK5nwsI4a54iISH/UabhzzjWZ2a3AUrypEB51zq01s1v87Q8CS/CmQdiMNxXCDf7hpwHXAKvNLN9f1zrlwY/MLA+vWeZ24L966DWJ9LhRKbGMSonl/BxvYIG6xmbWFpe3hb13tu1jsV+7FxMZxpSRSUwansDE4QlMzPDus1Pj9CVJRAaEhqYW1hVXcMNp40JdFBER6YKg5rnzw9iSduseDHjsgM93cNzrdNwfD+fcNV0qqUg/EhMZzoljUzlxbCoAzjmKyw/Ou7ehpJJ/f1DKMyuL2o6JighjQnq8F/j826ThiYxLjyM6QsOMi0j/sWFXBQ3NLRpMRURkgAkq3InI0ZkZo1NiGZ0SywUBw4aX1zSyubSSzXuq2m4FRWU8v7oE5/cgDQ8zxqTGcVxGApNGHKzpO254AgnR+i8qIn2vdTAVhTsRkYFF3xxFelFyXOQhNXytahua2VJaxZbSg6Fv054qXt24h6aWg+MGZSbHHFLT1xr80hKi+/qliMgQkl9YTnpCNKOSNWCUiMhAonAnEgKxUeHMHJ3MzNHJh6xvbG7hw301bN7jBb9NuyvZXFrF0+8WUtvY3LZfanwUEzO82j2vead3n5kcgzeOkYjIsSsoKiMvO1l/T0REBhiFO5F+JDI8rK2WLlBLi6O4vPaQ5p2b91TxwpoSymoa2/aLjwr3Al9GAhMDmniOSY0jQiN4ikgQKusa2VJaxYKAJuYiIjIwKNyJDABhYUbWsDiyhsXx0SnD29Y759hX3dDWrHOLH/re3LKPv76/s22/qPAwxqV7/fpGJMUwIimG4YnR/uNohifGkBQboV/pRYTVO8txTv3tREQGIoU7kQHMzEhPiCY9IZqTJ6Qdsq2irrEt7G0u9YLfxt2V/GfTXqrqmw47V3REGMOTohmR6IW/DD/8KQSKDC0FheUA5GQld7KniIj0Nwp3IoNUUkwkx48ZxvFjhh22rbq+iT2V9eypqGO3f7+nsp7dFXXsqahn/a4K/v1BvUKgyBBUUFjGuLQ4UuKiQl0UERHpIoU7kSEoPjqC8dERjE+PP+p+rSFwtx/+AkPg7oo6hUCRQaigqIw541M731FERPodhTsROaLuhMDdAUEwmBCY4TcvTU/07jMSorz7xIPr46PCFQRlwDCzc4D7gHDgYefcPe22DwMeBY4D6oDPOufWBGwPB1YAO51z5/dFmXdX1FFSXkduVkpfPJ2IiPQwhTsR6bbuhsDdFfXsrapn+75qVnx4gP3VDR0eHxMZ1tbH0At+Ue2Wo0lPiCI9MZrEaNUISuj4wex+4JNAEbDczBY759YF7PY1IN85d7GZTfX3nxew/XZgPZDUR8XW5OUiIgOcwp2I9JlgQ2BTcwv7qxvYU+mFvr1VDd59wHLRgRryC70gGDDve5uoiDC/NjAgAAaEwYy2GsJoNQ2V3jAH2Oyc2wpgZk8DC4DAcDcd+D8A59wGMxtnZiOcc7vNLAs4D/gB8D99VehVReVEhBkzRvVZnhQRkR6kcCci/U5EeBjDk2IYnhTT6b7NLY791X74a71VesulfhAsKa9j9c5y9lU30NxBEowKDyOtLQRGHdI8ND0hiuTYSFLi/PvYSJJiIwkPUxiUoxoNFAYsFwFz2+1TAFwCvG5mc4CxQBawG/g58L9AYq+XNLBARWVMGZlITGR4Xz6tiIj0EIU7ERnQwsOMjESvJq4zLS2OAzUNB2sCq+opraw/dLmqnvUlleyrrqexuYMqQV9iTATJsZF+8Iv0H0cdspzib09uXY6LUr/BoaOjf+T2H6h7gPvMLB9YDbwPNJnZ+cAe59xKM/voEZ/AbCGwEGDMmDHdLnBLi6OgsIzzNXm5iMiApXAnIkNGWJiRlhBNWkI0UzqpEHHOUV7byN6qBsprG6mobaSstoHymkbKahspr22kvMa7L6ttZHdFFWU1jZTXNhw1FEaEWVsoTD4sBEYdstwWGv376AjVpgwgRUB2wHIWUBy4g3OuArgBwLzEv82/XQFcaGbnAjFAkpn9zjl3dbvjFwGLAGbPnn3kD12Qtu+rpqKuiTwNpiIiMmAp3ImIdMDMSImL6vJcX845ahubvdDXGv5qAsJhwPry2kb2VzewtbTaC5B1jbijfEWPjQxvC31JfgBMiokkKTaCpBh/OTaSpJiIg9v95QQNMNPXlgOTzGw8sBMvsF0VuIOZpQA1zrkG4CbgNT/wfdW/4dfc3dk+2PWGgqIyQIOpiIgMZAp3IiI9yMyIi4ogLiqCzOTYLh3b3OKoqms6LASWtdYc1hxcX1bbSOH+GirrmqiobaSyg2kmAoUZJLYFQC8MtgbDgyGxo6DorYuNVHPSrnDONZnZrcBSvKkQHnXOrTWzW/ztDwLTgCfMrBlvoJUbQ1Zg
oKCwnLiocCYOTwhlMUREpBsU7kRE+onwMPOaYMZFdvnY5hZHZV0jFbVNVNR5YbCirtFvUhq4rqmtmenWvVVt22oamo96/shwOxgA/ZrBpHY1h621henxUZw6Mf1Y34ZBwzm3BFjSbt2DAY/fAiZ1co5XgVd7oXiHyS8sY9boZA0WJCIygCnciYgMAuFhx9aMtFVDU4sXDv2awI6DYSPltQe3F5fVti03NLe0nWtcWhyvfvljPfXSpA80NLWwrqSCG04dF+qiiIhINyjciYgIURFhbYPNHIu6xmY/BDbRGBD0ZGCIDDf+eftHiIoIC3VRRESkGxTuRESk22Iiw4mJDGd4n87KJj3FzJiQob52IiIDnX6iExERERERGQQU7kRERERERAYBhTsREREREZFBIKhwZ2bnmNlGM9tsZnd1sN3M7Bf+9lVmdkJnx5pZqpm9ZGab/PthPfOSREREREREhp5Ow52ZhQP3A/OB6cCVZja93W7z8ebqmQQsBB4I4ti7gGXOuUnAMn9ZREREREREjkEwNXdzgM3Oua3OuQbgaWBBu30WAE84z9tAiplldnLsAuC3/uPfAhd176WIiIiIiIgMXcGEu9FAYcBykb8umH2OduwI51wJgH8/PPhii4iIiIiISKBgwp11sM4FuU8wxx79yc0WmtkKM1tRWlralUNFRERERESGjGDCXRGQHbCcBRQHuc/Rjt3tN93Ev9/T0ZM75xY552Y752ZnZGQEUVwREREREZGhJ5hwtxyYZGbjzSwKuAJY3G6fxcC1/qiZJwPlflPLox27GLjOf3wd8PduvhYREREREZEhy5zrvJWkmZ0L/BwIBx51zv3AzG4BcM49aGYG/Ao4B6gBbnDOrTjSsf76NOBPwBhgB3CZc25/J+UoBT7s+ss8RDqwt5vnGGr0nnWd3rOu03vWdYP5PRvrnFNzjSD10PURBvdnqrfoPes6vWddo/er6wb7e3bEa2RQ4W4wMbMVzrnZoS7HQKL3rOv0nnWd3rOu03smPU2fqa7Te9Z1es+6Ru9X1w3l9yyoScxFRERERESkf1O4ExERERERGQSGYrhbFOoCDEB6z7pO71nX6T3rOr1n0tP0meo6vWddp/esa/R+dd2Qfc+GXJ87ERERERGRwWgo1tyJiIiIiIgMOkMq3JnZOWa20cw2m9ldoS5Pf2dm2Wb2LzNbb2Zrzez2UJdpIDCzcDN738yeC3VZBgIzSzGzZ8xsg/9ZOyXUZervzOyL/v/JNWb2BzOLCXWZZGDT9bFrdH08drpGdo2ukV031K+RQybcmVk4cD8wH5gOXGlm00Nbqn6vCfiSc24acDLweb1nQbkdWB/qQgwg9wH/dM5NBXLRe3dUZjYa+AIw2zk3E28O0StCWyoZyHR9PCa6Ph47XSO7RtfILtA1cgiFO2AOsNk5t9U51wA8DSwIcZn6NedciXPuPf9xJd4flNGhLVX/ZmZZwHnAw6Euy0BgZknAGcAjAM65BudcWUgLNTBEALFmFgHEAcUhLo8MbLo+dpGuj8dG18iu0TXymA3pa+RQCnejgcKA5SL0hzhoZjYOOB54J8RF6e9+Dvwv0BLicgwUE4BS4DG/mc7DZhYf6kL1Z865ncC9wA6gBCh3zr0Y2lLJAKfrYzfo+tglP0fXyK7QNbKLdI0cWuHOOlinoUKDYGYJwF+AO5xzFaEuT39lZucDe5xzK0NdlgEkAjgBeMA5dzxQDai/z1GY2TC8WpXxwCgg3syuDm2pZIDT9fEY6foYPF0jj4mukV2ka+TQCndFQHbAchZDrJr2WJhZJN6F6ynn3F9DXZ5+7jTgQjPbjtes6eNm9rvQFqnfKwKKnHOtv3g/g3chkyP7BLDNOVfqnGsE/gqcGuIyycCm6+Mx0PWxy3SN7DpdI7tuyF8jh1K4Ww5MMrPxZhaF17lycYjL1K+ZmeG1817vnPtpqMvT3znnvuqcy3LOjcP7fL3inBtSvxZ1lXNuF1BoZlP8VfOAdSEs0kCwAzjZzOL8/6PzUAd76R5dH7tI18eu0zWy63SNPCZD/hoZEeoC9BXnXJOZ3QosxRs551Hn3NoQF6u/Ow24BlhtZvn+uq8555aErkgyCN0GPOV/qdwK3BDi8vRrzrl3zOwZ4D28EfveBxaFtlQykOn6eEx0fZS+omtkF+gaCeacmtWLiIiIiIgMdEOpWaaIiIiIiMigpXAnIiIiIiIyCCjciYiIiIiIDAIKdyIiIiIiIoOAwp2IiIiIiMggoHAnIiIiIiIyCCjciYiIiIiIDAIKdyIiIiIiIoPA/wdADUHMxal/GAAAAABJRU5ErkJggg==", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "import matplotlib.pyplot as plt # Visualization\n", - "\n", - "# Plot loss and accuracy in subplots\n", - "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))\n", - "ax1.set_title('Loss')\n", - "ax2.set_title('Accuracy')\n", - "for dataset in ('train','test'):\n", - " ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss')\n", - " ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy')\n", - "ax1.legend()\n", - "ax2.legend()\n", - "plt.show()\n", - "plt.clf()" - ] - }, - { - "cell_type": "markdown", - "id": "qQbKS0tV3sZ1", - "metadata": {}, - "source": [ - "## 12. Perform inference on test set\n", - "\n", - "Define a jitted inference function `pred_step`. Use the learned parameters to do model inference on the test set and visualize the images and their corresponding predicted labels." - ] - }, - { - "cell_type": "code", - "execution_count": 66, - "id": "DFwxgBQf44ks", - "metadata": {}, - "outputs": [], - "source": [ - "@jax.jit\n", - "def pred_step(state, batch):\n", - " logits = state.apply_fn({'params': state.params}, test_batch['image'])\n", - " return logits.argmax(axis=1)\n", - "\n", - "test_batch = test_ds.as_numpy_iterator().next()\n", - "pred = pred_step(state, test_batch)" - ] - }, - { - "cell_type": "code", - "execution_count": 67, - "id": "5d5nF3u44JFI", - "metadata": { - "outputId": "1db5a01c-9d70-4f7d-8c0d-0a3ad8252d3e" - }, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAqkAAAKqCAYAAAAZssdpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/av/WaAAAACXBIWXMAAAsTAAALEwEAmpwYAABhcUlEQVR4nO3debxV8/7H8c+neZ7k0qDipktRcRMqDcqQuBVFbshM5Ip0yVS5dCV0dQ0ZKq6hIopKoSRjUt1QJNU9NKGRSnPr98c5Hr/z+e5jD2dP33XO6/l47MfjvPdee63vOefb2p+z+uzv1iAIBAAAAPBJiWwPAAAAAHBRpAIAAMA7FKkAAADwDkUqAAAAvEORCgAAAO9QpAIAAMA7oShSVTVHVTvFsV2gqg0LeYxCPxf+YK4gHswTxIu5gngwT9IjFEWqr1S1rKqOVtUfVXWzqk5V1TrZHhf8o6odVHWOqv6sqjnZHg/8pKr9VXWVqv6iqutUdaSqlsr2uOAfzimIh6oOUdW9qro93+2IbI8rXhSpyblRRE4WkaYiUltEtorIv7M5IHhrh4iMFZGB2R4IvDZVRI4PgqCKiBwjIs1E5G/ZHRI8xTkF8ZoYBEGlfLdV2R5QvEJVpKpqS1X9RFW3qup6VX1UVcs4m52VdyVio6qOUNUS+Z5/uap+rapbVPUtVa2f5JAOF5G3giD4MQiCXSIyQUSaJLlPpIBvcyUIgvlBEDwvIqE5ORQHHs6TlUEQbP1t9yJyQESK1X/v+crDucI5xUO+zZOwC1WRKiL7ReQmEakpuVcwO4rIdc423UWkhYgcLyJdReRyERFV7SYit4vIuSJysIh8ICLjCzqIqt6WN8EKvOXbdIyItFbV2qpaQUR6i8iMlHynSJZvcwV+8m6eqOpfVfUXEdkouVdSn0zFN4qkeTdX4CUf58k5mtuSuFRV+6bim8yYIAi8v4lIjoh0KuD+/iIyOV8OROTMfPk6EZmd9/UMEbki32MlRORXEamf77kNExxXFcmdQIGI7BOR/4pIjWz/vIrzzde5km9fnUQkJ9s/p+J+832e5D3/SBH5h4gcmu2fV3G++T5XOKf4cfN1nohIY8ltRywpIq1EZL2IXJjtn1e8t1BdSVXVRqo6TVV/yLvSMExy/1rJb3W+r7+T3F+OiEh9EXkk318ZmyX3v9OSeaPTEyJSTkQOEpGKIvKacCXVCx7OFXjI53kSBMG3IrJURB5Pxf6QHJ/nCvzh2zwJguCrIAjWBUGwPwiCj0XkERHpUdj9ZVqoilTJLQqXiciRQe4bC26X3F9gfofl+7qeiKzL+3q1iFwTBEG1fLfyeb80Q1VvV/tOOHPLt2kzEXk2CILNQRDsltw3TbVUVXdCIvN8myvwk+/zpJSI/LHQ3x1Syfe5Aj/4Pk+CAsbjrbAVqZVF5BcR2a6qR4lIQb0VA1W1uqoeJrnvvp+Yd/9oERmkqk1ERFS1qqr2LOggQRAMC+w74cwt36aficglefsqLbmX7dcFQbAxNd8ukuDVXFHVEqpaTkRK50Ytp5HN9Mg83+bJlar6h7yvG4vIIBGZnapvFknxba5wTvGTb/Oka96xVFVbSu5qIa+n7ttNr7AVqbeIyF9FZJuIPC3//4vN73URWSgii0VkuuS+uUmCIJgsIsNFZELeJfglItI5BePZJSLfisgGETlLchuikX2+zZW2IrJTRN6U3L+cd4rI20nuE8nzbZ60FpEvVXWH5M6VNyX3Sgyyz7e5wjnFT77Nk14isiJvPP8RkeFBEDyX5D4zRvMaawEAAABvhO1KKgAAAIoBilQAAAB4hyIVAAAA3qFIBQAAgHdKRXtQVXlXVcgFQZCR9dCYK+GXibnCPAk/zimIF+cUxCPaPOFKKgAAALxDkQoAAADvUKQCAADAOxSpAAAA8A5FKgAAALxDkQoAAADvUKQCAADAOxSpAAAA8A5FKgAAALxDkQoAAADvUKQCAADAOxSpAAAA8A5FKgAAALxDkQoAAADvUKQCAADAOxSpAA
AA8A5FKgAAALxDkQoAAADvUKQCAADAO6WyPQCgqLj00ktNHjdunMmzZs0y+bTTTkv3kIq92rVrm1yrVi2TDzrooIT2d+qpp0bdfxAEEc+ZPn26ybNnzzZ506ZNCY0B4fTBBx+YXNDvvXfv3ibv2LEjrWMCfMeVVAAAAHiHIhUAAADeoUgFAACAd+hJdbRp08bkbt26mVyjRg2Tt27davJ9991n8vjx4012+xDfeOMNk7t27RrvUOGZ008/3eQDBw6Y3LZtW5M7dOhg8pw5c9IzsCLs+eefN7ljx44mly9f3uRy5cqZXLZsWZML6imNRlVjPv+iiy4yecuWLSa/9957Ji9cuNDkf/3rXybv3LkzoTHCD7t37zb5zDPPjNjmyCOPNHnx4sXpHBI81LlzZ5Pr1q0bsc2DDz5ocpUqVUyeNm2ayY8//rjJM2bMSGaIGcWVVAAAAHiHIhUAAADeoUgFAACAdzRaD5aqJtag5blOnTqZfOedd0Zs4/akliiRWB2/fv16k911GV179+412e2RS1YQBBp7q+QVtbkSj+rVq5u8YMECkxs0aGDyrl27TG7atKnJK1euTN3gCiETcyXV88Tt+3XPZ+7jGzZsSOXhI5QqFdnmH2st1lh9rVOmTDF5wIABJufk5MQ/wBTgnFI4vXr1Mrl///4R29x8880mf/zxx+kcUtqF8ZySbu7ayi+99JLJzZo1M9ntNy2MX375xWS319mdi8uWLTPZ7adOtWjzhCupAAAA8A5FKgAAALxDkQoAAADvFOl1Um+44QaT3TVMK1WqFHMfbm+G23d43HHHmdykSZNEhpj1PkQU3iWXXGKy24Pq+vbbb03md5+8M844I+rje/bsMXnu3LnpHE5EP5mIyNNPP23yHXfcYbLbt/7EE0+Y7K7VPHr0aJMz3ZOKwrn++utNPvHEEyO26d27t8lh70mFyCGHHGLy1KlTTW7evHnax+D2tbprdi9atMjke+65x+ShQ4emZ2Bx4EoqAAAAvEORCgAAAO9QpAIAAMA7oe5JLVOmjMlPPvmkyW7PoLse4aZNmyL26X6esrue2P79+01210CcPn26yS1btow4Rn7Dhw+P+jj81aNHj4S2z2ZfT1H1zjvvZHsIxueffx5xX6xzwOmnnx71cfe8hXCaNWuWya1bt87SSJBJ1apVMzkTPajJcteQr1Gjhsk33nhjxsbClVQAAAB4hyIVAAAA3qFIBQAAgHdC3ZPavn17k/v06RN1e7cH9eyzz47YZuHChQmNoUQJW+dXrFgx6vZr1qwxec6cOQkdD9nhrmX5e/dFM23atFQNByFSr149k/v162eyu35muXLlTH7xxRdNTvdar0iPr776KttDQAZUrVrV5Lvvvjvlx9i+fbvJa9euNdl9r0zNmjVNHjx4sMnu+3t27txp8gUXXFCocaYCV1IBAADgHYpUAAAAeIciFQAAAN7RIAh+/0HV33/QA7Nnzza5Q4cOUbd310B9++23Ez5mnTp1THbXE7vmmmuiPv+EE04wOdEe2EQFQZCRRRZ9nyvJateuXcR97777btTnfPrppyafcsopJrtr7mZbJuZK2OdJ+fLlTXb7z6666qqI51x99dUm165d2+Q9e/aY7K6d7Ga3XyzTOKekhvt7F4l8TevcuXOmhpMWxfGc8vrrr5tc0HtfEjFz5syI+55++mmTp0yZYvJJJ51ksrsm79SpU01evnx5EiNMXrR5wpVUAAAAeIciFQAAAN6hSAUAAIB3QrVOauXKlU3+4x//GHV7t5fD/ezkeBx22GEmP/744yZ36dLF5AMHDph88803m7xo0aKEx4DMO/TQQ00eM2ZMwvu47777TPatBxWRevbsafJ5551n8tFHH23ysccea3K0Hv/fc/HFF5s8adKkhPeB8HF71kUi3/OA8Dn99NOTev7kyZNN7t27d8Q2u3fvjrqPefPmRc1hwpVUAAAAeIciFQAAAN6hSAUAAIB3QtWT6vZ/uZ+J7XLXAnP7RQvi7nP69OkmN2nSxOR9+/aZfMcdd5g8atSomMeEf2rUqGHy4YcfHvM5H374ocmx1lFF6rl96/fee6/Jxx9/vMnu+oHJUk18WchHHnnE5BtvvNHkVatWRX3+iy++aPLcuXNNjtW/huz48ssvI+5z19lt2LChyStWrEjrmJC8jRs3muyui+x6//33TXbXWi/u/365kgoAAADvUKQCAADAOxSpAAAA8E6oelITFatntVWrVhH3jR071uRGjRpF3ceTTz5p8ogRI+IcHXxWt27dhJ/j9hZl+zPWi6MjjjjC5H79+iX0/LVr15q8YcMGk//3v/+Z/PHHH8fcZ7ly5UyuVq2ayZ06dTK5YsWKJp977rkmV6hQwWR3nVV3fWh3vd6PPvoo+oCRNSVLljTZ7ZmmJ9V/t99+u8nPPvts1O1r1qxpcvXq1U3etGlTSsYVVlxJBQAAgHcoUgEAAOAdilQAAAB4hyIVAAAA3gnVG6fmz58fNbds2dLks88+2+SlS5eaPHTo0IhjuIu2u2+k+Nvf/mbylClTfn/ACA33zS1///vfYz7nxx9/NPmpp55K6ZiQOPeNTg8++GDU7d2F8NevXx91f9ngvnmzS5cuJt95550mn3HGGSZ37NjR5AceeMDku+66K9khAiikxo0bm9ytWzeTY53DijqupAIAAMA7FKkAAADwDkUqAAAAvKNBEPz+g6q//6AHBgwYYHIqFtJ/5513oh5jyZIlSR8jk4Ig0Ewcx/e5Eovb11dQv7Jr6tSpJru9RGGTibkS9nnio4MOOsjkxx9/3OQePXqYvG7dOpMPO+ywhI7HOSU13N+TiEjfvn1NvvTSS01+7rnn0jmklCuO5xT3g2BmzJhhstuD6lqzZo3JTZo0idhm+/bthRydn6LNE66kAgAAwDsUqQAAAPAORSoAAAC8E6p1Ul0vvfSSyYn2pL7yyisR91100UUm7927N/GBIXRq1KiR8HMee+yxNIwESMymTZtMHjZsmMk9e/Y0uU6dOmkfEwon2ntEEA5uT+m//vUvky+//HKTTzrpJJPdntbXXnst4hjPPPOMyS+//HKiwwwNrqQCAADAOxSpAAAA8A5FKgAAALwT6p7U0047LaHtN2/ebPLFF18csQ09qMVDpUqVTL7hhhuibn/gwIGI+7Zt25bSMQGpcOWVV5rs9jkuXLgwk8NBAvbv32/y3LlzszQSpMqYMWNMdntMx44da3Lr1q1N7tixY8Q+K1asaPKcOXNM3rBhQ8Lj9BVXUgEAAOAdilQAAAB4hyIVAAAA3glVT+rxxx9v8ujRoxN6fpUqVUxu2bJlxDYffvhh4gND6Nx9990mlygR/e+1mTNnRtw3b968lI4JiatatarJ+/btM3nHjh2ZHE5G/PnPfzb5jjvuMLlLly4mu/3U48ePT8/AkDT3d5WTk5OdgSBttmzZYnL37t2j5kmTJkXsw11b1e1rdddG3rVrV8Lj9AVXUgEAAOAdilQAAAB4hyIVAAAA3glVT+o//vEPk1XV5MWLF5vcvHlzk0uVst9u9erVUzY2h
Evfvn2jPr57926TR4wYkc7hoJC++uorkx9++GGTH3rooUwOJyXcfrLjjjvOZHcd1Jo1a5rsrovqnjdHjhyZ7BABpMl7771nckHvfXB7Us866yyTBw0aZPLgwYNTM7gs4EoqAAAAvEORCgAAAO9QpAIAAMA7Xvekuj2lZ5xxhslvvvmmyS+++KLJrAeI39StW9fkWOuirly50uT3338/5WNC8mrXrm3y7bffbnLp0qVNXrRokclvv/22ye3btze5TJkySY4wct1Sd71n97O6E/XOO++YfNNNN5ns9u3CDxdffHG2hwAP7dy50+QVK1ZEbOP2pBZlXEkFAACAdyhSAQAA4B2KVAAAAHjH657UY445xmS3j7BevXqZHA5CrGvXriaXK1cu6vYvvPBCOoeDFOnXr5/J7jqp9913X9Tn//TTTya7a47G6l1212p21yiNx+eff26y+9ner776qsnuZ3n/8ssvJrs9bfBT2bJlI+6bNm1aFkaCVCpZsqTJbl98LHfccYfJF110UdJjCjOupAIAAMA7FKkAAADwDkUqAAAAvON1T2os7tqX3bp1y85A4L0TTzwx6uO//vqryXPnzk3ncJAijz32mMmLFy82+amnnjK5Vq1aJlerVs3kDRs2RD2eu37uJ598YnJBPam7du0y2e0x/eKLL6IeE8XH+vXrsz0EJOnPf/6zye6/d3dt51RwX7/mz5+f8mNkC1dSAQAA4B2KVAAAAHiHIhUAAADe8bon1e0vW7BggcktWrQw+YILLoi6v++//97kWbNmFX5wCJXHH3/c5PPPP99kdz3NefPmpX1MSL2PPvrI5CZNmpjs9qS6edGiRekZGIBiwe0HXbVqlcmp6El97733TL7++utNXrZsWdLH8AVXUgEAAOAdilQAAAB4hyIVAAAA3tFonzWtqol/EHUatWnTxuSpU6eaXLVqVZPdNQ87d+5scnHoPwuCQGNvlTzf5goSl4m5wjwJP84piBfnFJGzzz7bZLd/9PTTTzd506ZNJg8aNChin+77a955551khph10eYJV1IBAADgHYpUAAAAeIciFQAAAN4JVU8qEkf/GOJF/xjiwTkF8eKcgnjQkwoAAIBQoUgFAACAdyhSAQAA4B2KVAAAAHiHIhUAAADeoUgFAACAdyhSAQAA4J2o66QCAAAA2cCVVAAAAHiHIhUAAADeoUgFAACAdyhSAQAA4B2KVAAAAHiHIhUAAADeCUWRqqo5qtopju0CVW1YyGMU+rnwB3MF8WCeIF7MFcSDeZIeoShSfaeqZVR1maquyfZY4CdV7aCqc1T1Z1XNyfZ44CdVHaKqe1V1e77bEdkeF/zDOQXxUtXjVfX9vPPJj6p6Y7bHFC+K1NQYKCI/ZXsQ8NoOERkruXMFiGZiEASV8t1WZXtA8BLnFMSkqjVFZKaIPCkiB4lIQxF5O6uDSkCoilRVbamqn6jqVlVdr6qPqmoZZ7OzVHWVqm5U1RGqWiLf8y9X1a9VdYuqvqWq9VMwpsNF5CIR+Wey+0Lq+DZXgiCYHwTB8yJCweER3+YJ/OXbXOGc4iff5omI3CwibwVB8GIQBLuDINgWBMHXSe4zY0JVpIrIfhG5SURqisjJItJRRK5ztukuIi1E5HgR6Soil4uIqGo3EbldRM4VkYNF5AMRGV/QQVT1trwJVuDN2fzfefvdmfy3hxTyca7APz7Ok3NUdbOqLlXVvqn4JpESPs4V+Me3eXKSiGxW1Y9V9SdVnaqq9VL0vaZfEATe30QkR0Q6FXB/fxGZnC8HInJmvnydiMzO+3qGiFyR77ESIvKriNTP99yGCY6ru4jMzPu6vYisyfbPqrjffJ0r+fbVSURysv1zKu43X+eJiDQWkdoiUlJEWonIehG5MNs/r+J883Wu5NsX5xQPbr7OExFZLiJbReQEESknIqNE5KNs/7zivYXqSqqqNlLVaar6g6r+IiLDJPevlfxW5/v6O8k94YuI1BeRR/L9lbFZRFRE6hRyLBVF5AERuaEwz0d6+TRX4C/f5kkQBF8FQbAuCIL9QRB8LCKPiEiPwu4PqePbXIGfPJwnOyW3SP4sCIJdIjJURFqpatUk9pkxoSpSReQJEVkmIkcGQVBFci+Lq7PNYfm+rici6/K+Xi0i1wRBUC3frXzeC4GhqrerfXetueVtdqSINBCRD1T1BxF5TURq5U3MBqn6hlFoPs0V+Mv3eRIUMB5kh+9zBX7wbZ58Ibnnkd/89nUozithK1Iri8gvIrJdVY8SkYL6tQaqanVVPUxEbhSRiXn3jxaRQaraREREVauqas+CDhIEwbDAvrvW3PI2WyK5E6153u1KEfkx7+vVBewWmeXTXBFVLaGq5USkdG7UchrZTI/M822edM07lqpqSxH5m4i8nrpvF0nwba5wTvGTV/NERMaJSHdVba6qpUXkLhH5MAiCrSn5btMsbEXqLSLyVxHZJiJPy///YvN7XUQWishiEZkuImNERIIgmCwiw0VkQt4l+CUi0rmwAwmCYF8QBD/8dpPcy/IH8vL+wu4XKePNXMnTVnL/2+VNyf3LeaeEaBmQIsy3edJLRFbkjec/IjI8CILnktwnUsO3ucI5xU9ezZMgCN6V3Ku50yV3qcyGeeMLBQ2CIPZWAAAAQAaF7UoqAAAAigGKVAAAAHiHIhUAAADeoUgFAACAd0pFe1BVeVdVyAVBkJG10Jgr4ZeJucI8CT/OKYgX5xTEI9o84UoqAAAAvEORCgAAAO9QpAIAAMA7FKkAAADwDkUqAAAAvEORCgAAAO9QpAIAAMA7UddJBQAAQHa0b98+4r45c+aYPHToUJOHDBmSxhFlFldSAQAA4B2KVAAAAHiHIhUAAADeoScVAADAA24Pqtt/WtxwJRUAAADeoUgFAACAdyhSAQAA4B16UgEASNJDDz1kcv/+/U1+9dVXTT7//PPTPSSEUEHrosby3nvvpXwcvuBKKgAAALxDkQoAAADvUKQCAADAOxoEwe8/qPr7DxZTkyZNMrl69eomd+zYMZPDiSkIAs3EccI2VypXrmzywoULTd65c6fJN9xwQ8Q+3n///dQPLIsyMVfCNk8QiXNKwfbv32/ygQMHTF63bp3JF1xwQcQ+5s2bl/qBZRHnlNiGDBli8uDBg2M+x+1B7dChQwpHlHnR5glXUgEAAOAdilQAAAB4hyIVAAAA3qEnNYbWrVub7PaCzJ071+ROnTqle0gJoX+sYGXKlDF5xowZJrdr187k2bNnR+zjjDPOSP3Asoj+McSDc0rB3HVPx48fb3KJEvaakNuzKiJSsmTJ1A8sizinRHLXQZ0zZ07C+1DNyD/BjKEnFQAAAKFCkQoAAADvUKQCAADAO6WyPYBUqlGjhsmbN29Oep8NGzY0uVSpIvUjK7b27Nlj8saNG6NuX69evYj73L5Wd58Aig/3/R0F9Zwm8jiKhmR7UMO+BmqyuJIKAAAA71CkAgAAwDsUqQAAAPBOqBssmzRpYvKzzz5rctu2bU12P489Hsccc0zUxydMmJDwPhE+jRo1irjv5JNPNtldMxfFT7Vq
1SLuGz58uMlffvmlyY8++mg6h4QMcdeudNdFdXNBXn75ZZPdtVcRPm5PaizuWuxuLm64kgoAAADvUKQCAADAOxSpAAAA8E6oe1Jvvvlmk1u0aGFy+fLlTY6nJ7Vq1aomX3311Sbv3r3b5BdeeCHmPgEUTe46yvPnz4/Yxu1TddfTHDFihMnbt283+ZVXXjF56tSpJn/66acmp2J9aCQuFeuksnZq0dOuXbuEtue9DRZXUgEAAOAdilQAAAB4hyIVAAAA3glVT2qpUna4xx9/fNTt3c9bj6dX68gjjzS5SpUqJk+ZMsXkXbt2xdwn/Pfhhx+a3KNHD5PdNRBFRPr27WsyvURFX5s2bUx2zwcFrZPq+uKLL0xu1qyZyeXKlTP52muvjZrdXvv9+/ebvGrVKpMXLFhgstsH6a7b6o4XBUvFOqnua1bdunVNXrNmTSFHh2xJdJ3UIUOGpGUcYcWVVAAAAHiHIhUAAADeoUgFAACAd0LVk3rZZZeZ3Lx586jbf//99wkfg89KLp7cz1N31zxE8VCjRg2TH3/8cZO7detmcpkyZUwu6Jxz6623mjx58mST//nPf5p80003xTXW37jrQbuaNm0aNbu9lH369DG5bNmyCY2nuPr444+j5latWplc0JqoJ554YtRMT6r/Eu0p7dChQ3oGUkRwJRUAAADeoUgFAACAdyhSAQAA4B2ve1LdXqjrr78+6vbPPvusyVu2bIm6fUG9XGeffXZ8g0ORsnfvXpPdtSbdNXpFRBo3bmxyxYoVTd6xY0eKRod0cXtQZ86caXKLFi2iPn/btm0mF9SPNnHixKj7cHtW77//fpPdeZXsOapkyZIm/+EPfzB5/vz5Se2/uHL7RV977TWTW7dubXJB66a6/cEvv/yyye7vDv4ZPHhwQtu/99576RlIPnPmzIn6uLvGt09rtXIlFQAAAN6hSAUAAIB3KFIBAADgHY22HqSqZnWxyEsvvdTkcePGRd1+9uzZJn/zzTdRtz/mmGMi7mvbtm3U57i9iytXrjR57NixJo8YMSLq/tItCILID51Pg2zPlVRz1011+09FItdSrVWrlskbNmxI/cDSKBNzJdvzJNkeVHf7e++912R3bcyiiHNK4bh97gWtk+r2qbrblC5dOvUDS6PicE5xJbrGttuHXBjt27ePmhPtk03FmBIRbZ5wJRUAAADeoUgFAACAdyhSAQAA4B2v1kl110UdMGBAQs/v2LFj1JwKbk/QUUcdZXLPnj1NznZPKlCcVa9e3eREe1BfeOEFky+77DKT3T5D4Pd8+umnJp944okR27i9gAWtpQq/JLqm6NChQ1M+Brfn1O1JDTP+BQAAAMA7FKkAAADwDkUqAAAAvEORCgAAAO949cYpdzH0Ro0aZWkk/++HH34wOdYHCvzxj39M53CQJQW9gaGgxbjhlxtvvNHkWG+Uuu+++0x23xTBG6VQWCNHjjT5pZdeitgm1mL+N910U9R9wn+JvtGqIHPmzDG5KL1RysWVVAAAAHiHIhUAAADeoUgFAACAd7zqSc3JyTH5zjvvNLlevXoJ7e/VV1812V3c392/iMjGjRtNPvroo03eunVrQmNA0VBQ/2kQBFkYCRJx8803R338/fffN9ldFJu+Y6RLQX3usRbzP+mkk9I6JqSf2z/63nvvJb2PoowrqQAAAPAORSoAAAC8Q5EKAAAA73jVk+oaMWJESvd3xhlnxNxmwoQJJtODCoTXzJkzTe7Ro4fJ7trMBfWp57dhwwaT33zzTZO/++67RIeIYqqgfudY66TSB+8ft6fU7Wt3uY/H6klNxbqqsRSmLzZTuJIKAAAA71CkAgAAwDsUqQAAAPCO1z2pqdamTZtsDwEhsWTJEpMbN26cpZEgGVdffbXJNWrUMLlDhw4mJ9r/tWPHDpMfe+yxiG1uu+22hPaJomn16tUmr1u3LmKbww47zGS3R9VdRxXZl2g/p7vGqdtnPHToUJNj9bimgntMn3AlFQAAAN6hSAUAAIB3KFIBAADgnWLVk1qvXr2Y20ycODEDI4HvjjnmmGwPASngrnPcqVMnk48//niT3XVTjzjiCJPdz06/8MILTT711FMLM0wUA/PmzTP5k08+idimbt26JrvrpLrzz83uMZB5bt/wnDlzTHZ7Ul2Z6EF1e/FZJxUAAABIAEUqAAAAvEORCgAAAO8U6Z7UBg0amFy9enWTt2zZEvGcVatWpXNICCl3vUKRgj97G+GyaNGihLY/6KCD0jQSFDcFrXnq3ueed9x1VN0eVvgn1hqksXpUUyHM6+tyJRUAAADeoUgFAACAdyhSAQAA4J0i3ZNap04dkytXrmzyd999F/Gcgj5PGcXPlClTTG7cuHHENu5nLqPocdfLveKKK7I0EhQ1I0eOjLivR48eJrt9726P6o033mjypEmTUjQ6pIq7BmmsNUmHDBkSc5/uWqruPt11UMOMK6kAAADwDkUqAAAAvEORCgAAAO8U6Z7Uo48+OurjM2bMyNBIEDb0JkNEpGLFiib/4Q9/iLr9mDFj0jkcFCHz5s2LuC/WOqlhXu8S8YmnJzWebYoKrqQCAADAOxSpAAAA8A5FKgAAALxTpHtSjzvuuKiPb9myJUMjARBGd999t8mlS5c22V3Hcvr06WkfE4quhx56yOT+/fub7PaxXnjhhekeEpBVXEkFAACAdyhSAQAA4B2KVAAAAHinSPekup9jfOmll5r83XffZXA0CJOJEyeafO2110Zss3btWpPpcS565s6da3KnTp1Mds8xa9asSfuYUHQNHDgwagaKG66kAgAAwDsUqQAAAPAORSoAAAC8o0EQ/P6Dqr//IEIhCIKMfNgzcyX8MjFXwj5P6tSpE/Vxt0+5KOKcgnhxTkE8os0TrqQCAADAOxSpAAAA8A5FKgAAALxDT2oRR/8Y4kX/GOLBOQXx4pyCeNCTCgAAgFChSAUAAIB3KFIBAADgHYpUAAAAeIciFQAAAN6hSAUAAIB3KFIBAADgnajrpAIAAADZwJVUAAAAeIciFQAAAN6hSAUAAIB3KFIBAADgHYpUAAAAeIciFQAAAN4JRZGqqjmq2imO7QJVbVjIYxT6ufAHcwXxYJ4gXswVxIN5kh6hKFJ9paoDVXWJqm5T1f+p6sBsjwl+UtUZqro9322Pqn6Z7XHBL5pruKpuyrs9oKqa7XHBP7z+IB5hnyelsj2AkFMRuUREvhCRP4rI26q6OgiCCdkdFnwTBEHn/FlV3xORd7MzGnjsahHpJiLNRCQQkXdEZJWIjM7imOAnXn8Qj1DPk1BdSVXVlqr6iapuVdX1qvqoqpZxNjtLVVep6kZVHaGqJfI9/3JV/VpVt6jqW6paP5nxBEHwQBAEi4Ig2BcEwTci8rqItE5mn0gN3+aKM7YGInKKiDyfqn2icDycJ31E5KEgCNYEQbBWRB4SkUuT3CdSwLe5wuuPn5gnqRWqIlVE9ovITSJSU0ROFpGOInKds013EWkhIseLSFcRuVxERFW7icjtInKuiBwsIh+IyPiCDqKqt+VNsAJvv/McldzCY2lS3yF
Sxdu5Irl/1X4QBMH/kvj+kBq+zZMmIvJ5vvx53n3IPt/mSv7n8PrjD+ZJKgVB4P1NRHJEpFMB9/cXkcn5ciAiZ+bL14nI7LyvZ4jIFfkeKyEiv4pI/XzPbZjEGIdK7gtK2Wz/vIrzLSRzZYWIXJrtn1Vxvvk6TyT3Be6ofPnIvP1otn9mxfXm61xxxsLrD/OkSM6TUF1JVdVGqjpNVX9Q1V9EZJjk/rWS3+p8X38nIrXzvq4vIo/k+ytjs+T2atRJwbj6Se7VsS5BEOxOdn9InsdzpY2IHCoik5LdF5Ln4TzZLiJV8uUqIrI9yHuFQfZ4OFd+GxevPx5hnqRWqIpUEXlCRJaJyJFBEFSR3Mvi7jtfD8v3dT0RWZf39WoRuSYIgmr5buWDIPjYPYiq3q72ndjm5mx7uYjcJiIdgyBYk6LvE8nzbq7k6SMirwVBUNBjyDzf5slSyX3T1G+aSZj+a65o822u8PrjJ+ZJCoWtSK0sIr+IyHZVPUpE+hawzUBVra6qh4nIjSIyMe/+0SIySFWbiIioalVV7VnQQYIgGBYEQaXfu/22nar2lty/kk4LgmBV6r5NpIBXcyVvP+VFpKeIPJuS7xCp4Ns8+Y+I3KyqdVS1togMEOaLL7yaK7z+eIt5kkJhK1JvEZG/isg2EXla/v8Xm9/rIrJQRBaLyHQRGSMiEgTBZBEZLiIT8i7BLxGRzgU8PxH3ishBIvJZvr9gWCrGD77NFZHcpYV+FpE5KdgXUsO3efKkiEwVkS/z9jc97z5kn29zhdcfPzFPUkhpdQIAAIBvwnYlFQAAAMUARSoAAAC8Q5EKAAAA71CkAgAAwDuloj2oqryrKuSCIHDXZ0sL5kr4ZWKuME/Cj3MK4sU5BfGINk+4kgoAAADvUKQCAADAOxSpAAAA8A5FKgAAALxDkQoAAADvUKQCAADAOxSpAAAA8A5FKgAAALxDkQoAAADvUKQCAADAOxSpAAAA8E6pbA8AAAAAIg0bNjT5pptuitimb9++Ufcxbdo0k6+66iqTf/zxx0KOLvO4kgoAAADvUKQCAADAOxSpAAAA8I4GQfD7D6r+/oMZcNhhh5ncqlUrk9u0aWNyt27dTK5Ro4bJa9euNfmzzz6LOOYNN9xg8ubNm+Maq6+CINBMHCfbcwXJy8RcKerzxO0nExE5+uijk9rn22+/bfLu3buT2l+yiuo5pUKFCibfeeedJh977LEmd+nSJanj/fTTTxH3ub2ErhdffNHkhQsXmvzLL78kNaZU45wSqVQp+1agwYMHm9yvXz+Tq1SpkvQxv/jiC5Pdubtu3bqkj5GMaPOEK6kAAADwDkUqAAAAvEORCgAAAO9ktSf1yiuvNPmvf/2ryY0bNza5Zs2aJqvaNoZo30u8HnnkEZMHDBiQ9D6zqaj2jyH16B+LrWrVqiY//PDDJrvnMBGRsmXLJnXM7777zuR77rnH5HHjxiW1/0QV1XNKz549TZ4wYYI7HpOTfb1x91eYfc6aNcvkgQMHmuz2ImYa55RI119/vcmjRo0yOZ559umnn5p83HHHmVymTJmo+xw+fLjJgwYNijLi9KMnFQAAAKFCkQoAAADvUKQCAADAO6Vib5I6F198scn//ve/TS5durTJ7ppvGzZsMNnts/j2229Nfuedd0yuXbu2yZdddlnEGHv37h11jDk5ORHPAVA0HXXUUSa755Q6deqkfQz169c3efTo0SZXrlzZZLfHDfG59dZbU7q/bdu2mbxy5UqT3T7CwujUqZPJbq+h22e7ffv2pI+J5Ljrv8fyn//8J+K+a665xmR3jfgxY8aYXLFixYSO6ROupAIAAMA7FKkAAADwDkUqAAAAvJPRntTu3bubvHXrVpOfe+45kx999FGT16xZk9LxtG7dOuI+d23Wq666yuQ77rgjpWNAOBTU0+N+lrerR48eJp977rkmH3744QmNwe1pc+fqnj17EtofIntK3X/vbt+6u/3y5ctNdnvDCuOGG24w+cILLzS5WrVqJo8YMcLkL7/80uQ5c+YkPabioF69elEf37Fjh8kjR440eenSpSa/9dZbJu/atcvkeM4pJ5xwgsnu57j379/f5DPOOMPkSZMmmeyek+hRzbwuXboktL37718kcu3k+++/3+TFixebXFCtExZcSQUAAIB3KFIBAADgHYpUAAAAeEejfVZwqj8T99prrzV50aJFJs+fPz+Vh4vJ7d0Siezze+WVV0zu1atXWseUakX1c7YbNGhgsruWpPvZxiVK2L/HOnfubPJ5551n8jHHHGNy+fLlI8bwxz/+Ma6xpovb07Zz586k9lccPmd7yJAhJrs9fW7PXyxun2JB54fp06cntE+X2yt58803m9y3b1+Tf/rpJ5PbtWtn8qpVq5IaT1E9p3zwwQcmt2rVyuRmzZqZvGTJkrSPKRb3POX2oB555JEm//3vfzf5oYceSs/A8hSHc0qiBgwYYLLbU+6u/15QjbZp0yaTGzZsaPK0adNMbtOmjcnuerq33XZblBGnX7R5wpVUAAAAeIciFQAAAN6hSAUAAIB3MrpOqvuZ0z5yexe3bNmSpZEgv7Jly5rsrv3o9qSuXbvW5IMOOsjkcuXKJT2mBQsWmOx+Vrfb8+z2ybrrXT722GNRj/fjjz+afODAgXiGWawccsghJrufx3755Zeb7Pag7t+/3+ScnByT3Z6/WbNmRd0+Fb7//nuT3T5at3eyRYsWJrtrbybbk1pUnXPOOSa7c8eHHlSXO6bXXnvNZLfX0O1NTHdPKiL9+9//Ntk9j7vrae/evTtiH3feeafJpUrZUs5dz9nta432XiTfcCUVAAAA3qFIBQAAgHcoUgEAAOCdjPakZlvz5s1NdvsYRUS2bt1qsts/guxwe2h++OEHk2vWrGmy22vofka1m//zn/+Y7PZ2uccTiex7Lah3KJpYa9O5c9Fd2zXR4xUHr7/+usktW7aMur3bg+quI+l+PruP3L7C8ePHm+x+Xrv7M0Iu99/boEGDsjOQJPzjH/8w+fzzzzfZ7V1034NBn3v67dmzx2T3HFOYc87pp59uckG1TVhxJRUAAADeoUgFAACAdyhSAQAA4J0i3ZPq9tu4695VqFAh4jnuNl999VXqB4aEuX08J598ssmNGjUy2V2zdP369ekZWALcfse77ror6vYTJ040efHixakeUugtWrTIZPfz1V3uOqbu52hPnjw5JePKpMMOOyzq4xdeeKHJF198cTqHgyzauXOnye5586yzzjK5adOmJnOOCacxY8ZEfXzHjh0mf/LJJ+kcTkpxJRUAAADeoUgFAACAdyhSAQAA4J0i3ZPq9tv07Nkz5nNWrlyZruEgjZYvX57tIURwPwt86NChJpctW9Zk93PhBw4cmJ6BhdhFF11kstuDqqomL1261GS3J2/16tUpHF12tGjRIurjbm8+8JvTTjvNZHpSw8l9LXG9/PLLJr/xxhvpHE5KcfYCAACAdyhSAQAA4B2KVAAAAHinSPekup9h7SponcopU6akaTQo6g4++GCTR48ebbLbN/Tiiy+afPPNN5
u8ffv2FI6uaOjcubPJsXpQ+/bta3JR6EEtV66cyW7vveu5555L53DgEXcuxFpDF+F09dVXm1yzZs2o24dx/effcCUVAAAA3qFIBQAAgHcoUgEAAOCdUPekur1Zjz32mMnu57m7Xn/99ZSPCcVHyZIlTX777bdNrlWrlsnu5yc///zzJm/YsCGFoyua3HVOXW7/5YcffpjO4WTF3//+d5OPOuqoqNu/+uqr6RwOPOL++6hYsWKWRoJUcV9HRERGjBhhchAEJn/22WcmT5s2LfUDyxCupAIAAMA7FKkAAADwDkUqAAAAvEORCgAAAO+E+o1TAwYMMLlPnz4mu83E7hsIvvvuu/QMDMXCwIEDTW7WrJnJe/fuNfnCCy802X2jFWKrVq2aye6/8aKoYcOGJl9//fVRt9+zZ4/JOTk5qR5SKFWvXt3kq666ymT3zSZffvmlye6Ha+zatSuFo4tPqVL2Jbtly5Ym33rrrSbH+veRje8B0R100EEmDxs2LGKbSpUqRd2H+8EwYcaVVAAAAHiHIhUAAADeoUgFAACAd0LVk9q2bVuTb7/99qjbr1mzxuRLLrnE5N27d6dmYCjyhg4dGnHfXXfdFfU5vXr1MjnMCyr7Ys6cOSa3b9/e5Pfeey9zg0mTOnXqmLxgwQKTq1SpEvX5TzzxhMlLlixJzcBCzu1B/ec//5nQ8z///HOTly9fbvL48eNNXrp0qckrVqxI6HgFue+++0y+5ZZbTFZVk92e1HXr1pn87LPPJj0mROf20bu2bt1qcr9+/Ux265aCfPvttyYXpffbcCUVAAAA3qFIBQAAgHcoUgEAAOAdr3tS3V6OcePGmVyuXDmT3f4btyeQHlTEq2vXribH6j8VEXnxxRdNnjFjRkrHBJH169dHfdztUXX7OX1Uu3Ztk92+21g9qK7JkycnPaaiqEWLFkk9v3nz5ia76yL37NnT5F9//dXkxYsXmzxlypSIY8yaNctkd91T9xixbNq0yeSrr77a5G3btiW0P4g0atTI5DFjxkTd/g9/+EPUx3/66SeTW7dubXI8a0F369bN5LVr18Z8TlhwJRUAAADeoUgFAACAdyhSAQAA4B2ve1Ivuugik+vXrx91+5UrV5p87LHHmjxv3rzUDAxFzoUXXmjyv/71r5jPcdfhvfjii1M5JBQRpUuXNvmyyy4zedSoUSaXKVMmof3fdtttJn/44YcJPb+oOuSQQ0x2e/1c7nqVr732WtT9denSJer+KlSoYHKrVq1ijiee/sNEDBo0yGT65BN3xx13mPyPf/zD5Fhr08Zy5JFHRt1fPNwxunP3448/NvnHH39M+BjZwpVUAAAAeIciFQAAAN6hSAUAAIB3NFr/hKqmtkEmQaNHjzbZ/ezlEiVsjX3gwIGo+3M/w3rChAkR27zxxhtR9/HNN9+YvG/fvqjbJ+qwww4z2V1DLdG1XoMgSLzBpRCyPVcSdcwxx5j8zjvvmOz2n7n9pyIinTt3Ntn9rO6wycRcSXaeuGuGur+XLVu2mOz2ar3wwgvJHD6C+++1U6dOEdu4a1ueeeaZCR1j+/btJj/11FMmu32He/fuTWj/iQrLOcXt+fzggw+ibu/2Cj/33HNRt3fX6XbXpz3jjDOiPr+g3sNke1Ldfe7cudPk4cOHm/zMM8+YnOr1NcNwTonFfQ0+6KCD3OOb/PTTT5vsrt3csGHDqMdLtse1IG4P6sMPP2zygw8+mPQxkhFtnnAlFQAAAN6hSAUAAIB3KFIBAADgHa97UqtVq2bysGHDTHb7v4444oikjxmrH+STTz4xOScnx2S3pzXW/sqXL2/yAw88EHX/J510UuSgowhL/1i6uT9ndy3J4447zuS5c+eafOWVV0bs012XN+zC2D/m9lLdcMMNUbdfvny5ye56grH06NHD5EMPPdTk6tWrJ7S/grh977fccovJ06dPT/oYyQjLOcX9N//ll1+afPjhh5vsrnvqrqtdqpRdVvy8884z2V3X210X1ZWJntRY+3v//fdN7tChQ1LHd4XxnOJasGCBye5rhfszf+KJJ0zu1q2bye45w/X222+bPHbs2Iht3HW93fdHuGstu2NcuHChySeccELUMaUbPakAAAAIFYpUAAAAeIciFQAAAN7xuic1lkqVKpncq1cvk90e1WuuucbkqlWrRuwz1WuUpXp/bl9ULGHpH0s3d03c888/3+T9+/eb7PY2umv2FkVFoX/soYceMrlfv34mly5dOp2HL5C7frO7FqU75kcffTTq87MtrOeUmTNnmnzaaaeZ7K5Hu2fPHpPd9TGTPZdv3rw54r7HH3/cZLd30P2c92+//dZkd21Yd/vZs2eb/NZbb5m8YsWKKCNOXFE4p7h96BMnTnSPb3Ki82LWrFkmd+3a1eRdu3bF3EeDBg1MHjdunMktWrQwedKkSSa7awRnGj2pAAAACBWKVAAAAHiHIhUAAADeCXVPaqIOPvhgk9u1axexzSmnnGKy+xnvbu9H/fr1ox4z2X4Vt7fkqquuSuj5Ye0fS5a7nuybb75psrsG76hRo0zu379/OobltaLQP+Zy130cM2aMyXXr1jU50Z5vt49x2rRpEdu4/dDuWsphE9ZzStu2bU2eOnWqye57HAoYj8mxzuVuT6vb/9mnT5+I5/z8889R9xk2ReGc4q63+9xzz5ns9qy68+K7774zefjw4SaPHz/e5F9++aVQ44zGrWOWLFmS8mMkg55UAAAAhApFKgAAALxDkQoAAADvFKue1FRw+1rd7Orbt29C+1+3bp3JDz/8sMm7d+9OaH9h7R9LlNsr7H7+ubs+5kcffWTy2WefbXJR6w2LR1HoH0uU27tcs2bNhJ7vrju5c+fOpMfku6JyTjnjjDNMvuKKK0w+77zz3PGY7K41+dprr5m8bNkykxcvXlyYYYZacTynIHH0pAIAACBUKFIBAADgHYpUAAAAeIee1CKuqPSPuQ499FCTZ8yYYXKzZs1M3rFjh8knn3yyyb6tG5cN9I8hHkX1nILU45yCeNCTCgAAgFChSAUAAIB3KFIBAADgncQ+qBrwhLumoduD6jr22GNNzsnJSfWQAABACnElFQAAAN6hSAUAAIB3KFIBAADgHXpSEQrNmzc3+aabboq6/dixY01et25dqocEAADSiCupAAAA8A5FKgAAALxDkQoAAADvaBD8/sfe8pm44cfnbCNefM424sE5BfHinIJ4RJsnXEkFAACAdyhSAQAA4B2KVAAAAHgnak8qAAAAkA1cSQUAAIB3KFIBAADgHYpUAAAAeIciFQAAAN6hSAUAAIB3KFIBAADgnVAUqaqao6qd4tguUNWGhTxGoZ8LfzBXEA/mCeLFXEE8mCfpEYoi1VeqOkRV96rq9ny3I7I9LvhLVcuo6jJVXZPtscA/qlpNVZ9T1Z/ybkOyPSb4SVXLqupoVf1RVTer6lRVrZPtccEvqtpBVeeo6s+qmpPt8SSKIjV5E4MgqJTvtirbA4LXBorIT9keBLw1UkQqiEgDEWkpIher6mVZHRF8daOInCwiTUWktohsFZF/Z3NA8
NIOERkrua89oROqIlVVW6rqJ6q6VVXXq+qjqlrG2ewsVV2lqhtVdYSqlsj3/MtV9WtV3aKqb6lq/Qx/C8gQH+eKqh4uIheJyD+T3RdSw8N5co6IPBAEwa9BEOSIyBgRuTzJfSIFPJwrh4vIW0EQ/BgEwS4RmSAiTZLcJ5Lk2zwJgmB+EATPi0goL6CFqkgVkf0icpOI1JTcvyA7ish1zjbdRaSFiBwvIl0l7wSvqt1E5HYROVdEDhaRD0RkfEEHUdXb8iZYgTdn83Py/qtlqar2TcU3iZTwca78O2+/O5P/9pAiPs4Tdb4+pvDfHlLIt7kyRkRaq2ptVa0gIr1FZEZKvlMkw7d5Em5BEHh/E5EcEelUwP39RWRyvhyIyJn58nUiMjvv6xkickW+x0qIyK8iUj/fcxsmOK7GkvvfLCVFpJWIrBeRC7P98yrON4/nSncRmZn3dXsRWZPtn1Vxvnk8T14QkddEpLKINBSRlSKyO9s/r+J883iuVJHcAiYQkX0i8l8RqZHtn1dxvfk6T/Ltq5OI5GT755ToLVRXUlW1kapOU9UfVPUXERkmuX+t5Lc639ffSW4RKSJSX0QeyfdXxmbJvUpR6EbzIAi+CoJgXRAE+4Mg+FhEHhGRHoXdH1LHp7miqhVF5AERuaEwz0f6+DRP8vxNcq+0fysir0tuEcKb7Dzg4Vx5QkTKichBIlJRcv+44Upqlnk4T0ItVEWq5P6jXCYiRwZBUEVyL4urs81h+b6uJyLr8r5eLSLXBEFQLd+tfF5xaajq7WrfsW9uUcYXFDAeZIdPc+VIyX0jzAeq+oPkvpjUyjuJNUjVN4xC8WmeSBAEm4Mg6B0EwaFBEDSR3HP0/BR+vyg8r+aKiDQTkWfz5sxuyW0naqmqbkGEzPJtnoRa2IrUyiLyi4hsV9WjRKSgHtCBqlpdVQ+T3Hc/Tsy7f7SIDFLVJiIiqlpVVXsWdJAgCIYF9h375vbbdqraNe9YqqotJfcqyOup+3aRBJ/myhLJPSk1z7tdKSI/5n29uoDdInN8mieiqn9U1YNUtaSqdhaRq0Xk3tR9u0iCV3NFRD4TkUvy9lVacv/beF0QBBtT8+2ikLyaJ6paQlXLiUjp3KjlNPKNXN4KW5F6i4j8VUS2icjT8v+/2PxeF5GFIrJYRKZLbnO5BEEwWUSGi8iEvEvwS0Skc5Lj6SUiK/LG8x8RGR4EwXNJ7hOp4c1cCYJgXxAEP/x2k9z/wjmQl/cXdr9ICW/mSZ4/i8iXeeP5p4j0DoJgaZL7RGr4NlduEZFdktsaskFEzpLc3ndkl2/zpK3kthC9KblXbXeKyNtJ7jNjNMhtqAUAAAC8EbYrqQAAACgGKFIBAADgHYpUAAAAeIciFQAAAN4pFe1BVeVdVSEXBEFG1m1lroRfJuYK8yT8OKcgXpxTEI9o84QrqQAAAPAORSoAAAC8Q5EKAAAA71CkAgAAwDsUqQAAAPAORSoAAAC8Q5EKAAAA71CkAgAAwDsUqQAAAPAORSoAAAC8Q5EKAAAA71CkAgAAwDsUqQAAAPAORSoAAAC8Q5EKAAAA71CkAgAAwDsUqQAAAPAORSoAAAC8Q5EKAAAA75TK9gCiGTdunMknn3yyyR9++KHJS5YsMfmzzz4zOScnx+R9+/ZFHLNVq1Ymn3rqqSYPHjzY5M2bN0fsAwBQtFWqVMnkOnXqmDxp0iSTmzRpYvLixYsj9vn+++9H3WbevHkmL1u2LJ6hophr0aKFyW5tdODAAZP/+te/mjxx4sT0DCwOXEkFAACAdyhSAQAA4B2KVAAAAHhHgyD4/QdVf//BNGjTpo3Js2bNMrl06dImq6rJ0b4XEZENGzaYvH///ohtatWqFXWf3bt3N/mNN96IesxsC4JAY2+VvEzPlTBasWKFyQ888IDJTz31VCaHEyETc6WozZNy5cqZ3Lp164ht3PPa4YcfbvJZZ51l8tKlS012+xLvvvtuk7dt2xbXWFOFc0qul156yeQLLrgg6X3Gek37/vvvTR41apTJI0eOTHoMqcQ5JTvatWtn8tixY01u0KCByW5PqvtadfTRR6ducAWINk+4kgoAAADvUKQCAADAOxSpAAAA8I5XPamu+++/3+SBAwea7K5R+sEHH5js9pcecsghJrv9PSIiW7ZsMfmLL74w+YknnjD5hx9+iNiHT+gfy3XppZeavHbtWpPfeeedlB/TXZvu008/NfmFF14wuU+fPikfQyLoH4vkrn15xx13mNy1a1eT3XNOOsycOdPk888/3+Tt27en9fjF5ZxSvnx5k59//nmTzzjjDJMrVKiQ9DETfZ+F+76K9957z+SLLrrI5J9++qnwgysEzinpcfDBB5vcuHFjkydMmGByzZo1Td65c6fJ7uuh+3rpvnalGj2pAAAACBWKVAAAAHiHIhUAAADeKZXtAUQzefJkk92eVPdzi88991yTy5QpEzXv3r074ph79+5NeJzwT8+ePU121yA99dRT0z6GE044wWS336ygnmhkVokS9u/0k08+2WR3HeTq1atH3V9B5xT389hXrlyZyBAj1t90e50rVqxocrp7UouLKlWqmOyukZ2sTZs2Rdznzh93/rmvcX/4wx9M7tixo8m9e/c22bd1VFE4bdu2NfnJJ580uWrVqlGf/91335ns9i5//vnnSYwutbiSCgAAAO9QpAIAAMA7FKkAAADwjtc9qccee2xSz9+zZ0/UjKLDXaPQ/XzzHTt2mJyTk5PuIUmPHj1MXr9+vclunywyz12L+ZZbbkno+dOmTTP5yiuvjNgm2bUpBw0aZLLbqz9//nyTO3ToYPKqVauSOn5xdeaZZ0Z9/OOPPzb5scceM7l58+YmL1682GR37ohErsv7zTffmHz99deb/O2335p8xBFHmOyu5froo4+azHswwsntRS5dunRCz3fXVW3Tpo3J9KQCAAAAUVCkAgAAwDsUqQAAAPCO1z2py5cvN9ldZ9JVt25dk+vXr2+yu25l2bJlI/YxY8YMk7/44ouY40T23XjjjSY3adLEZHcduTVr1qR8DC1btjT5lFNOMdld13f16tUpHwOsQw45xORRo0aZfN5550V9/tatW03+y1/+YvInn3xisvtZ6oXhrnF43XXXmdy+ffuoz3/mmWdMzsSawEVRo0aNoj7urnPsfl66m+Ph9qDG8vLLL5t82223JXxM+O2uu+6KuG/IkCFJ7fO+++4z2e2n9glXUgEAAOAdilQAAAB4hyIVAAAA3vG6J/Wss84yOQgCk9015WbPnm1yw4YNEz6mu77mww8/bLK7RqG79t2BAwcSPiaS5/YebtmyxeR///vfaR9DuXLlTC5Vyut/XsVCr169TO7Zs2fU7X/88UeTTzrpJJPdz7xOhWbNmpnsfr56rB5U9/Pe//Wvf6ViWIjB/fddsWJFk921meNRooS9buT2xd5zzz0mx+qt//nnn03et29fwmNCZvXt29fkgvpPY9UZc+fONdmtW3zuQXVxJRUAAADeoUgFAACAdyhSAQAA4B2vm+Y6deoU9fEGDRqY7Pasrlu3zuRZ
s2aZvGTJkoh9uusg3n777VGzuy7diBEjfn/ASJkyZcqY3LVrV5Nfeuklk7/66qu0j+m0006L+rjbw+b20br9kEjeBRdcEPVxdx3UHj16mJyOHtRLLrnE5AEDBph87LHHJrS//v37m/zGG28Ualyw3NcL91zvzpVu3bqZ3KdPH5PdPsGCXt/ctb7vvffeuMb6m19//dVkd21xt29+586dCe0fqVetWjWTY/XNx8OtjebNm5f0PrOFK6kAAADwDkUqAAAAvEORCgAAAO943ZNar169hLZ3+3cefPBBk7dt2xZzH+5ne3fs2NHkiRMnmux+Bq67buo777wT85hIXNOmTU2uX7++yW6vYTqULl3aZHdNTddRRx1l8owZM0w+/vjjUzOwYqx8+fImV6pUKer2OTk5Jn/00UcJHa958+Ymu73RIpG9i+7al+48imXTpk0mP/PMMwk9H/H573//a7L7HofatWub7Pacv/jiiya7a5q6/acikT2k7vssYpkyZYrJF198cULPR/q5Pelnn322yaecckrC+3Rf7wYOHGjywoULE96nL7iSCgAAAO9QpAIAAMA7FKkAAADwjtc9qdOnTzfZ7eW4/PLLTXb7RQuzBtzevXtNnjlzpsnnn3++ya+99prJzz33nMnu524vX7484TEhktur5X6W8TnnnGPyo48+anKsNUnLli1rcps2bSK2ueaaa0x2+5fdMbm/+7Zt20YdAxK3f/9+k2N9VvnRRx9t8jfffJPQ8dxeaHf93nR44YUXTHa/Z6SG2+fXpUsXk93XBnfdY1dBPajJmj17tsnu577DP7t37za5e/fuSe9z48aNJk+ePDnpffqCK6kAAADwDkUqAAAAvEORCgAAAO9QpAIAAMA7Gm2xYFVNbCXhFKtWrZrJboPxuHHjMjiago0ePdrkq6++2mT3jVXu4sqFeXNXIoIg0NhbJS/bc8V9A0GHDh1MXrFihcmvv/66ye5ixzfffLPJLVq0SHhMkyZNMtl9051vMjFXMj1PhgwZYvLdd9+dycOLSPILtLtz1/3Qh+3btxduYIVUXM4psfzlL38x2T3XlyiR+DUg9wNnKleuHHV79/XDXRh+zpw5CY8hlYriOSVRjRs3Nvmtt94y2f1QCFdB82jJkiUmn3766SavX78+kSFmXbR5wpVUAAAAeIciFQAAAN6hSAUAAIB3vO5JDQN3AXe3T9ZdwLl58+Ymf/HFF2kZ12+KS/9Y06ZNTX7yySdNPvHEE6M+3+0b/PTTT00uaHHkM8880+R27dqZPGDAAJNHjhwZdQzZVhT7x0qVsp9Xcuutt5rs9i67/Z6LFi0yecOGDSa7C7oXZM2aNSZPmTLF5IoVK5q8adMmk0877TSTFy9eHPOY6VRczimxuB/G4fYaxvpgh8cffzziPve8dfjhh5s8atQok+vVq2eye95q1apV1DGkW1E8p8Ti9qC6H75x7LHHJrS/ZcuWRdzXu3dvk9NdR6QbPakAAAAIFYpUAAAAeIciFQAAAN4pFXsTRPPtt9+a/MEHH5h84YUXRs1h7yXxhftzvPbaa02+4YYbTF69erXJq1atMvmll14yef/+/RHHdHtQXfPnz4/6ONJv3759Jt93331R88EHH2yy24NaGO76mW4Pqsvtc812DyoKVqdOHZNj9aA++uijJg8cODBimz179pjsrof59ddfm/z555+b7PY7uq8348ePjzpGJK9+/fomJ9qDum7dOpPdNVBFwrcOajK4kgoAAADvUKQCAADAOxSpAAAA8E6oe1Jr1aplsru+oNvfkw7ff/991IzscHu1rrzyyqT216xZs4j7OnXqZPJHH31k8ieffJLUMZF5yfag3nTTTRH3devWLepz3H7q/v37JzUGZMZ1110X9fFnnnnG5JtvvtnkgvrcY1mxYoXJbp/rLbfcYvLZZ59tMj2p6XfOOecktL37WnXRRReZXJz6TwvClVQAAAB4hyIVAAAA3qFIBQAAgHdC1ZPaqFEjk+fOnWvy2rVrTXZ7uz788MO0jCs/d4woGtzP0BaJ/Fz4BQsWmHzgwIG0jgnZV758eZP79OmT8D7++c9/muz21sNPGzdujPq4u05yYXpQY3HXAXa5r0cVKlQw+ddff035mIq7a665xuRYrwPvv/++ycuWLUv5mMKMK6kAAADwDkUqAAAAvEORCgAAAO+Eqif1sssuM/nQQw81ee/evZkcToGWL19usqqa/Kc//SmTw0EGub97FH133nmnyU2bNo35nAkTJpj86quvpnRMyIyXX37Z5L/85S8mZ+L9Ce4x3H5md91eelCT97e//c3kkSNHmlyiRPRrfy+99JLJ7vq5sLiSCgAAAO9QpAIAAMA7FKkAAADwTqh6UpcsWWJyEAQmu2sWnnzyySYvXLjQ5J07dyY9ptq1a5t85plnmuyOkf6zcGrdunXMbb766qsMjATZ1LJlS5Pj6Sdz17K87777oj6OcHB/b+65/uKLLzZ53LhxJq9ZsyZin1WqVDHZXXfX/Vz3o446yuRPP/3U5EysDV7cuL/nWOuguo8PGTIk1UMq0riSCgAAAO9QpAIAAMA7FKkAAADwTqh6Ul988UWT27dvb/Lll19u8v3332/yJZdcYvKoUaMijvH0009HHUOtWrWi7sNdJ/Hrr782+fXXX4+6f/jpo48+irhvwIABJrvr9qLoef75500uW7Zsws9ZunRpSseE7HjllVdM7tatm8m9evUyOZ7fu7vGZqx+R6RfnTp1TL766qujbr9gwQKT+/bta/L333+fmoEVE1xJBQAAgHcoUgEAAOAdilQAAAB4J1Q9qa7rrrvOZHcN0qeeesrkxo0bmzx69OiIfQ4bNsxkd020MmXKmFy5cmWTt27darK7zt327dsjjomiwV1Dc+LEiVkaCVLl2muvNblhw4YJ7+Oaa65J1XDgscGDB5vcqlUrk+vVqxdzH+7rTaJWrlyZ1PMRae3atSa7dcXDDz9scs2aNU2uUKGCyXv37k3h6Io+rqQCAADAOxSpAAAA8A5FKgAAALyj0XpgVDW5Bpksq1GjhslDhw41+bzzzot4jrvWZaweoc8++8xk97O8P/7445jjTKcgCDQTxwn7XInl+OOPj7jP/d3OmzfPZHcdX99lYq74Pk+qV69usrumYcWKFaM+/4UXXoi4z12fOew4p8TnqKOOMvmtt94yuW7duhHPUbU/Wnf+TZ482WT3HDRjxgyTs/0eiKJwTnnmmWdMPvHEE01+6aWXTH777bdNXrhwYXoGVoREmydcSQUAAIB3KFIBAADgHYpUAAAAeKdI96SC/rF0GjJkiMnr1q0z2V1Pz3dFoX8sWaVLlzZ57NixJvfu3dvke+65x+RHHnkkYp9btmxJ0ej8wDkF8eKcgnjQkwoAAIBQoUgFAACAdyhSAQAA4B16Uos4+scQL/rHEA/OKYgX5xTEg55UAAAAhApFKgAAALxDkQoAAADvUKQCAADAOxSpAAAA8A5FKgAAALxDkQoAAADvRF0nFQAAAMgGrqQCAADAOxSpAAAA8A5FKgAAALxDkQo
AAADvUKQCAADAOxSpAAAA8E4oilRVzVHVTnFsF6hqw0Ieo9DPhT+YK4gH8wTxYq4gHsyT9AhFkeorzTVcVTfl3R5QVc32uOAfVR2oqktUdZuq/k9VB2Z7TPCPqnZQ1Tmq+rOq5mR7PPCXqg5R1b2quj3f7Yhsjwt+Cfs8oUhNztUi0k1EmolIUxE5W0SuyeaA4C0VkUtEpLqInCki/VS1V3aHBA/tEJGxIsIfMYjHxCAIKuW7rcr2gOCl0M6TUBWpqtpSVT9R1a2qul5VH1XVMs5mZ6nqKlXdqKojVLVEvudfrqpfq+oWVX1LVesnOaQ+IvJQEARrgiBYKyIPicilSe4TKeDbXAmC4IEgCBYFQbAvCIJvROR1EWmdzD6RPA/nyfwgCJ4XkdC8iBQXvs0V+Il5klqhKlJFZL+I3CQiNUXkZBHpKCLXOdt0F5EWInK8iHQVkctFRFS1m4jcLiLnisjBIvKBiIwv6CCqelveBCvwlm/TJiLyeb78ed59yD7f5kr+56iInCIiS5P6DpEK3s4TeMfHuXKOqm5W1aWq2jcV3ySSxjxJpSAIvL+JSI6IdCrg/v4iMjlfDkTkzHz5OhGZnff1DBG5It9jJUTkVxGpn++5DRMc134ROSpfPjJvP5rtn1lxvfk6V5yxDJXcP2jKZvvnVVxvvs8TEekkIjnZ/jlx83euiEhjEaktIiVFpJWIrBeRC7P98yquN+ZJem6hupKqqo1UdZqq/qCqv4jIMMn9ayW/1fm+/k5yfzkiIvVF5JF8f2Vsltw+wTpJDGm7iFTJl6uIyPYgb2YgezycK7+Nq5/k9qZ2CYJgd7L7Q3J8nSfwj29zJQiCr4IgWBcEwf4gCD4WkUdEpEdh94fUYJ6kVqiKVBF5QkSWiciRQRBUkdzL4u676Q/L93U9EVmX9/VqEbkmCIJq+W7l835phqrervadcOaWb9Olkvumqd80E/4L1xe+zRVR1ctF5DYR6RgEwZoUfZ9IjnfzBN7yfa4EBYwHmcc8SaGwFamVReQXEdmuqkeJSEG9FQNVtbqqHiYiN4rIxLz7R4vIIFVtIiKiqlVVtWdBBwmCYFhg3wlnbvk2/Y+I3KyqdVS1togMEJFnU/KdIllezRVV7S25f1GfFoTonZXFgG/zpISqlhOR0rlRy2nkmy6QHb7Nla55x1JVbSkif5PcN2Qiu5gnKRS2IvUWEfmriGwTkafl/3+x+b0uIgtFZLGITBeRMSIiQRBMFpHhIjIh7xL8EhHpnOR4nhSRqSLyZd7+pufdh+zzba7cKyIHichn+f7aHZ3kPpE83+ZJWxHZKSJvSu4Vlp0i8naS+0Rq+DZXeonIirzx/EdEhgdB8FyS+0TymCcppLRPAgAAwDdhu5IKAACAYoAiFQAAAN6hSAUAAIB3KFIBAADgnVLRHlRV3lUVckEQZGQ9NOZK+GVirjBPwo9zCuLFOQXxiDZPuJIKAAAA71CkAgAAwDsUqQAAAPAORSoAAAC8Q5EKAAAA71CkAgAAwDsUqQAAAPAORSoAAAC8Q5EKAAAA71CkAgAAwDsUqQAAAPBOqWwPAAAAoCgqUcJeCyxdurTJBw4cMLlcuXJRny8iUrJkyaTGtG3bNpP37t2b1P7SiSupAAAA8A5FKgAAALxDkQoAAADv0JOKYqlChQomn3baaSa3bds24X0uWrTI5Dlz5pi8bt26hPcJuHP1zjvvNPlPf/qTyeedd17axwQgPpdccknU7PaHNm3a1OQqVapE7LNatWpJjWnmzJkmu69V//vf/6Juv2PHjqSOnwiupAIAAMA7FKkAAADwDkUqAAAAvKNBEPz+g6q//2Ax1b59e5MHDx4c9fEOHTqY/N5776VhVL8vCALNxHF8nyuVK1c2edy4cSZ369bNZFX7Y4v27+T3bNiwweRjjjnG5E2bNiW8z3TKxFzxfZ74wJ0nEydONHnChAkmjxo1yuSff/45PQPLwzmlcNy1LUuVinxLSKtWrUyuX79+1H3WqlXL5PXr10fd3l2js1+/fibXrFnT5I0bN5p8wQUXmLxs2bKoxyuO5xR3XdPnn3/e5F69emVyOCkxbdo0k7t27ZrS/UebJ1xJBQAAgHcoUgEAAOAdilQAAAB4h3VSHYn2nMbirj+W7R7V4qpJkyYmuz2orldeecXkgtY4ffnll00+7rjjTH700Uej7rNnz54m+9ajisxwe1Dfeustk59++mmTH3vsMZPT3YOK+Li9iO65/rbbbjO5Xr16Efto0KCByW4PaWF64/NLtNf+0EMPNfnYY481OVZPanHUqFEjk8PYg+ravn171o7NlVQAAAB4hyIVAAAA3qFIBQAAgHeKVU/qkCFDTHb7TTPBPSY9qZnxww8/mPz++++b7Pb9jR8/PuFjzJs3z+QDBw6Y7PYSnn/++SY/8cQTCR8T4eP2HU6fPt3kkSNHmvzwww+b7M4rZIfbU/rnP//Z5NGjR5vsrkGaCV9++aXJ7txxe1Ld91B89NFHJr/xxhspHF3R5L7XIBPctZT3799vsnuOqVKlStT9ffrppyb/9NNPSYwuOVxJBQAAgHcoUgEAAOAdilQAAAB4p0j3pLr9NYmucZoJ7pjoUU2PnJwck0899dS0H/PNN9802V2jsHPnzia7fbH79u1Lz8CQVSNGjDD5+++/N/nBBx/M5HAQp4YNG5o8aNAgky+77DKTY61B6v57FxH5+uuvTXZ7A9evXx9znPmtXr3aZPqZ/eP2i7777rsmv/rqqzH3sWbNGpOTXU/XJ1xJBQAAgHcoUgEAAOAdilQAAAB4p0j1pLr9nZnoQR06dGjUx9u1a2dyrDG6+3PXdkV4uX1C1atXN9n9nG56UouGK664wuRmzZqZfMIJJ2RyOCik8847z+RLL7006vYrV640edKkSSbPnDkz4jnu+s0In5tuuinq42PHjjW5b9++JnPet7iSCgAAAO9QpAIAAMA7FKkAAADwTqh7Ut1+zcGDByf0/IL6Sd19utldxzTRdU2L0vpliK5Xr15RH1+7dq3JO3fuTOdwkAFdu3aNuO+ee+4x+d577zX5559/TuuYUDglS5Y0uXv37ia76x5PmTLF5HPPPTct44Lf3PcWuOrWrWvySSedZPLSpUtN3rJlS2oGFlJcSQUAAIB3KFIBAADgHYpUAAAAeEej9UiqqlcNlMn2oHbo0MHkRPtJUyHR78Hte0pUEATJ7SBOvs0VH7i9hhUrVjTZXTcxVg9rumVirhT1ebJgwYKI+2rUqGHyEUcckanhpEVxOadUqlTJ5Llz55rcvHlzk/v162fyE088kZZxhUlxPKdMmDDB5J49eyb0/DVr1pg8Z86cmM/58MMPTXZrmxUrViQ0hkyLNk+4kgoAAADvUKQCAADAOxSpAAAA8E6o1klt165dQtsnu6ZpOrg9qe731L59e5PdfhS3rxbZU6FCBZMnTpxocuXKlU12+7/dHjeEz913321ys2bNIrZxexURDtu3bzd58eLFJrs9qaNGjTL51FNPNfnVV181uaD+Zd97BxHbzJkzTW7durXJtWvXjvp8dx
3Viy++OOYx3W1+/fVXk999912Tn3nmGZOnTp0a8xjZwpVUAAAAeIciFQAAAN6hSAUAAIB3vF4nNdHPufdhHdRYYvWcuoYOHWqy29MaS3FZ09Dt86lXr57JXbp0ifr8H374weSFCxfGPOYtt9xicrdu3Ux217h15+P5559v8qZNm2IeM52K45qGiSpXrpzJH330kcnVq1ePeI7bu/jLL7+kfFyZVFzOKa7OnTubPG3aNJPdf++xXr927doVcZ/b9zp//nyTH3vsMZN972HlnCLSoEEDk6+44gqTjz32WJPbtGljckHnlGS5a3i/9tprJt94440m79ixI+VjyI91UgEAABAqFKkAAADwDkUqAAAAvOPVOqmJ9lu6fOxBTVZR/J7i4fbxtW3b1uQLLrjA5Fq1aplcv359kxPtb060vywes2bNMjnbPahI3L333mvycccdZ/Ktt94a8Zyw96Ai1zvvvGNykyZNTJ49e7bJhx56aNT9uf3NIiInnXSSySeffLLJXbt2Ndmdf26vIbIvJyfH5Lvuuivq9m4P6wknnBCxjdsf7c6Dpk2bRj1G1apVTb7sssuiPn7RRReZvHv37qj7TyWupAIAAMA7FKkAAADwDkUqAAAAvEORCgAAAO9kdTH/RBe2d4Vh8X6X+z26PwNXst9jWBfedt8AULFixUTHY/K6detMnjBhgslXXXWVyZUrVza5MG+ccsfgNpu7C8Gfe+65Jm/bti3hYyaDhbcjlS5d2uQlS5aY7H5oxNFHHx2xD/eNE2EX1nNKprmLsru5UaNGEc85/vjjTW7WrJnJ7nno1VdfNblPnz4m//rrr/ENNk04p2RGpUqVTHbftHf33Xeb3Lt374T2X7NmTZO3bNmS0PNjYTF/AAAAhApFKgAAALxDkQoAAADvZHQx/+LQg5rs9+j2MRZXVapUMfnAgQMmu71W7kL57qLrbj/osGHDTHZ7UN3fw549eyLG+OGHH5r85ptvRmyTX5cuXUx2P7Bg69atJrs9qu5i4tnuNysOnnjiCZMbNmxo8t/+9jeTi1r/KQrPPT+4OR6nn366yc8//7zJ7utNhQoVTOYcUTy4r0+rV682+frrrze5du3aJru1lU+4kgoAAADvUKQCAADAOxSpAAAA8E5We1JjGTp0qMlFsQfV516QbHJ7UN31AadPn27yiBEjTL711ltNPvHEE02uU6dO1P2762H2798/YoyJ/q5Hjhxpsvu7f+aZZ0x+7bXXoubLL7/c5Eyvq1oUVatWzWT3Z7xy5UqTn3766XQPCcWY+5q3adMmkw8++OAMjqZocN/vsHPnTpNr1aqV0uO5fcJuf2gquOvruuumHnPMMSk/ZqZwJRUAAADeoUgFAACAdyhSAQAA4J2M9qS2a9cuoe196EF1e04HDx4c9XGX21c7ZMiQFIyq6Pv2229NdtenPPXUU03+y1/+YnLZsmVNdntOXW6P67XXXmvy+vXroz6/MNyeVvezvdesWWNy9+7do+7Pt8/tDqNYP+P777/f5ILWzwVSxf03/ac//cnkzZs3Z3I4RcIDDzxgsrtmqLueNbKLK6kAAADwDkUqAAAAvEORCgAAAO94vU5qrO1j9ay6z4/VX1oY7hhY9zQ1OnXqZPLw4cNNvuCCCxLan9tz+o9//MPk//73vybv378/of2ngtv32qRJE5MnT55ssts/6fbhun26iOSuiThs2DCTVdXkwnz+OvzUrVs3k93f7caNGzM4mlxXXnmlyU8++aTJbm/96NGjTc7GmMPmmmuuMdldkxsi69atM3nfvn1ZGglXUgEAAOAhilQAAAB4hyIVAAAA3tFo60eqavTFJZMUa+1KH7k9pz6s5RpNEAQae6vkpXuuuBo1amSyu7bdyJEjMzmcjOjXr5/JjzzySNTtS5YsmdD+MzFXMj1PYjnqqKNMXrp0qckff/yxyR07djS5OK6TWlTOKU2bNjXZXT/ztttuM3nx4sVJHa9y5com9+rVK2IbtwfV7Yl210U97rjjTP7++++TGWLK+XhOGTRokMn33ntvSscTBm7/9fvvv2/ymDFjTM7JyUnreKLNE66kAgAAwDsUqQAAAPAORSoAAAC8k9F1Ul1uP2ei66gme7y5c+fG3Mb3ntPiavny5VFzUeT2cIexpztsxo4da3Jx7EEtLk477TSTjzzySJOnTZtmstuv/Omnn5p8+umnmzxixAiTK1WqFDEG99/09u3bTb7kkktM9q0HNQzc3uOXX37ZZHf9XPf9D7FceumlJpcqlf4ya8GCBSa7/dNTp041+d133zX5119/Tcu4UoErqQAAAPAORSoAAAC8Q5EKAAAA72R1ndRYhgwZktbti4OisqYhItdZdPslu3fvbnKivVA+rmmYbrHWSX3rrbdM7tOnj8kbNmxIz8A8VlTOKTVq1DB55syZJv/5z39OaH/umqaF6Rl3e0zdtVTdvlffFcdzSrVq1Ux250U67Nq1y+SdO3em/ZipxDqpAAAACBWKVAAAAHiHIhUAAADe8bonFckrKv1jiFShQgWT3R67NWvWJLS/4tg/duihh5o8ceJEk8eNGxf18bD1fqVCUT2nuOukTpo0yeSC1jXNL9GeVHfdVZHINTa3bNkSdR++K47nFCSOnlQAAACECkUqAAAAvEORCgAAAO/Qk1rEFdX+MaQe/WOIR3E5p9StW9fkCy64wOR+/fqZ7PakLly40OQ33njD5Oeffz7imAcOHEh4nD7jnIJ40JMKAACAUKFIBQAAgHcoUgEAAOAdelKLuOLSP4bk0T+GeHBOQbw4pyAe9KQCAAAgVChSAQAA4B2KVAAAAHiHIhUAAADeoUgFAACAdyhSAQAA4B2KVAAAAHiHIhUAAADeoUgFAACAdyhSAQAA4B2KVAAAAHhHg4CPvQUAAIBfuJIKAAAA71CkAgAAwDsUqQAAAPAORSoAAAC8Q5EKAAAA71CkAgAAwDv/Bx6z9iwB7yj7AAAAAElFTkSuQmCC", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "fig, axs = plt.subplots(5, 5, figsize=(12, 12))\n", - "for i, ax in enumerate(axs.flatten()):\n", - " ax.imshow(test_batch['image'][i, ..., 0], cmap='gray')\n", - " ax.set_title(f\"label={pred[i]}\")\n", - " ax.axis('off')" - ] - }, - { - "cell_type": "markdown", - "id": "edb528b6", - "metadata": {}, - "source": [ - "Congratulations! You made it to the end of the annotated MNIST example. You can revisit\n", - "the same example, but structured differently as a couple of Python modules, test\n", - "modules, config files, another Colab, and documentation in Flax's Git repo:\n", - "\n", - "[https://github.com/google/flax/tree/main/examples/mnist](https://github.com/google/flax/tree/main/examples/mnist)" - ] - } - ], - "metadata": { - "jupytext": { - "formats": "ipynb,md:myst", - "main_language": "python" - }, - "language_info": { - "name": "python", - "version": "3.9.6" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/docs_nnx/quick_start.md b/docs_nnx/quick_start.md deleted file mode 100644 index ac8a9fb8..00000000 --- a/docs_nnx/quick_start.md +++ /dev/null @@ -1,355 +0,0 @@ ---- -jupytext: - formats: ipynb,md:myst - main_language: python - text_representation: - extension: .md - format_name: myst - format_version: 0.13 - jupytext_version: 1.13.8 ---- - -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/quick_start.ipynb) -[![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/quick_start.ipynb) - -# Quick start - -Welcome to Flax! - -Flax is an open source Python neural network library built on top of [JAX](https://github.com/google/jax). This tutorial demonstrates how to construct a simple convolutional neural -network (CNN) using the [Flax](https://flax.readthedocs.io) Linen API and train -the network for image classification on the MNIST dataset. - -+++ - -## 1. Install Flax - -```{code-cell} -:tags: [skip-execution] - -!pip install -q flax>=0.7.5 -``` - -## 2. Loading data - -Flax can use any -data-loading pipeline and this example demonstrates how to utilize TFDS. Define a function that loads and prepares the MNIST dataset and converts the -samples to floating-point numbers. 
- -```{code-cell} -import tensorflow_datasets as tfds # TFDS for MNIST -import tensorflow as tf # TensorFlow operations - -def get_datasets(num_epochs, batch_size): - """Load MNIST train and test datasets into memory.""" - train_ds = tfds.load('mnist', split='train') - test_ds = tfds.load('mnist', split='test') - - train_ds = train_ds.map(lambda sample: {'image': tf.cast(sample['image'], - tf.float32) / 255., - 'label': sample['label']}) # normalize train set - test_ds = test_ds.map(lambda sample: {'image': tf.cast(sample['image'], - tf.float32) / 255., - 'label': sample['label']}) # normalize test set - - train_ds = train_ds.repeat(num_epochs).shuffle(1024) # create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from - train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(1) # group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency - test_ds = test_ds.shuffle(1024) # create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from - test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1) # group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency - - return train_ds, test_ds -``` - -## 3. Define network - -Create a convolutional neural network with the Linen API by subclassing -[Flax Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html). -Because the architecture in this example is relatively simple—you're just -stacking layers—you can define the inlined submodules directly within the -`__call__` method and wrap it with the -[`@compact`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/decorators.html#flax.linen.compact) -decorator. To learn more about the Flax Linen `@compact` decorator, refer to the [`setup` vs `compact`](https://flax.readthedocs.io/en/latest/guides/setup_or_nncompact.html) guide. - -```{code-cell} -from flax import linen as nn # Linen API - -class CNN(nn.Module): - """A simple CNN model.""" - - @nn.compact - def __call__(self, x): - x = nn.Conv(features=32, kernel_size=(3, 3))(x) - x = nn.relu(x) - x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) - x = nn.Conv(features=64, kernel_size=(3, 3))(x) - x = nn.relu(x) - x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) - x = x.reshape((x.shape[0], -1)) # flatten - x = nn.Dense(features=256)(x) - x = nn.relu(x) - x = nn.Dense(features=10)(x) - return x -``` - -### View model layers - -Create an instance of the Flax Module and use the [`Module.tabulate`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.tabulate) method to visualize a table of the model layers by passing an RNG key and template image input. - -```{code-cell} -:outputId: 2c580f41-bf5d-40ec-f1cf-ab7f319a84da - -import jax -import jax.numpy as jnp # JAX NumPy - -cnn = CNN() -print(cnn.tabulate(jax.random.key(0), jnp.ones((1, 28, 28, 1)), - compute_flops=True, compute_vjp_flops=True)) -``` - -## 4. Create a `TrainState` - -A common pattern in Flax is to create a single dataclass that represents the -entire training state, including step number, parameters, and optimizer state. - -Because this is such a common pattern, Flax provides the class -[`flax.training.train_state.TrainState`](https://flax.readthedocs.io/en/latest/flax.training.html#train-state) -that serves most basic usecases. 
- -```{code-cell} -:outputId: 1249b7fb-6787-41eb-b34c-61d736300844 - -!pip install -q clu -``` - -```{code-cell} -from clu import metrics -from flax.training import train_state # Useful dataclass to keep train state -from flax import struct # Flax dataclasses -import optax # Common loss functions and optimizers -``` - -We will be using the `clu` library for computing metrics. For more information on `clu`, refer to the [repo](https://github.com/google/CommonLoopUtils) and [notebook](https://colab.research.google.com/github/google/CommonLoopUtils/blob/master/clu_synopsis.ipynb#scrollTo=ueom-uBWLbeQ). - -```{code-cell} -@struct.dataclass -class Metrics(metrics.Collection): - accuracy: metrics.Accuracy - loss: metrics.Average.from_output('loss') -``` - -You can then subclass `train_state.TrainState` so that it also contains metrics. This has the advantage that we only need -to pass around a single argument to functions like `train_step()` (see below) to calculate the loss, update the parameters and compute the metrics all at once. - -```{code-cell} -class TrainState(train_state.TrainState): - metrics: Metrics - -def create_train_state(module, rng, learning_rate, momentum): - """Creates an initial `TrainState`.""" - params = module.init(rng, jnp.ones([1, 28, 28, 1]))['params'] # initialize parameters by passing a template image - tx = optax.sgd(learning_rate, momentum) - return TrainState.create( - apply_fn=module.apply, params=params, tx=tx, - metrics=Metrics.empty()) -``` - -## 5. Training step - -A function that: - -- Evaluates the neural network given the parameters and a batch of input images - with [`TrainState.apply_fn`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState) (which contains the [`Module.apply`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.apply) - method (forward pass)). -- Computes the cross entropy loss, using the predefined [`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api.html#optax.softmax_cross_entropy_with_integer_labels). Note that this function expects integer labels, so there is no need to convert labels to onehot encoding. -- Evaluates the gradient of the loss function using - [`jax.grad`](https://jax.readthedocs.io/en/latest/jax.html#jax.grad). -- Applies a - [pytree](https://jax.readthedocs.io/en/latest/pytrees.html#pytrees-and-jax-functions) - of gradients to the optimizer to update the model's parameters. - -Use JAX's [@jit](https://jax.readthedocs.io/en/latest/jax.html#jax.jit) -decorator to trace the entire `train_step` function and just-in-time compile -it with [XLA](https://www.tensorflow.org/xla) into fused device operations -that run faster and more efficiently on hardware accelerators. - -```{code-cell} -@jax.jit -def train_step(state, batch): - """Train for a single step.""" - def loss_fn(params): - logits = state.apply_fn({'params': params}, batch['image']) - loss = optax.softmax_cross_entropy_with_integer_labels( - logits=logits, labels=batch['label']).mean() - return loss - grad_fn = jax.grad(loss_fn) - grads = grad_fn(state.params) - state = state.apply_gradients(grads=grads) - return state -``` - -## 6. Metric computation - -Create a separate function for loss and accuracy metrics. Loss is calculated using the `optax.softmax_cross_entropy_with_integer_labels` function, while accuracy is calculated using `clu.metrics`. 
- -```{code-cell} -@jax.jit -def compute_metrics(*, state, batch): - logits = state.apply_fn({'params': state.params}, batch['image']) - loss = optax.softmax_cross_entropy_with_integer_labels( - logits=logits, labels=batch['label']).mean() - metric_updates = state.metrics.single_from_model_output( - logits=logits, labels=batch['label'], loss=loss) - metrics = state.metrics.merge(metric_updates) - state = state.replace(metrics=metrics) - return state -``` - -## 7. Download data - -```{code-cell} -num_epochs = 10 -batch_size = 32 - -train_ds, test_ds = get_datasets(num_epochs, batch_size) -``` - -## 8. Seed randomness - -- Set the TF random seed to ensure dataset shuffling (with `tf.data.Dataset.shuffle`) is reproducible. -- Get one - [PRNGKey](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.PRNGKey.html#jax.random.PRNGKey) - and use it for parameter initialization. (Learn - more about - [JAX PRNG design](https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html) - and [PRNG chains](https://flax.readthedocs.io/en/latest/philosophy.html#how-are-parameters-represented-and-how-do-we-handle-general-differentiable-algorithms-that-update-stateful-variables).) - -```{code-cell} -tf.random.set_seed(0) -``` - -```{code-cell} -init_rng = jax.random.key(0) -``` - -## 9. Initialize the `TrainState` - -Remember that the function `create_train_state` initializes the model parameters, optimizer and metrics -and puts them into the training state dataclass that is returned. - -```{code-cell} -learning_rate = 0.01 -momentum = 0.9 -``` - -```{code-cell} -state = create_train_state(cnn, init_rng, learning_rate, momentum) -del init_rng # Must not be used anymore. -``` - -## 10. Train and evaluate - -Create a "shuffled" dataset by: -- Repeating the dataset equal to the number of training epochs -- Allocating a buffer of size 1024 (containing the first 1024 samples in the dataset) of which to randomly sample batches from - - Everytime a sample is randomly drawn from the buffer, the next sample in the dataset is loaded into the buffer - -Define a training loop that: -- Randomly samples batches from the dataset. -- Runs an optimization step for each training batch. -- Computes the mean training metrics across each batch in an epoch. -- Computes the metrics for the test set using the updated parameters. -- Records the train and test metrics for visualization. - -Once the training and testing is done after 10 epochs, the output should show that your model was able to achieve approximately 99% accuracy. 
- -```{code-cell} -# since train_ds is replicated num_epochs times in get_datasets(), we divide by num_epochs -num_steps_per_epoch = train_ds.cardinality().numpy() // num_epochs -``` - -```{code-cell} -metrics_history = {'train_loss': [], - 'train_accuracy': [], - 'test_loss': [], - 'test_accuracy': []} -``` - -```{code-cell} -:outputId: 258a2c76-2c8f-4a9e-d48b-dde57c342a87 - -for step,batch in enumerate(train_ds.as_numpy_iterator()): - - # Run optimization steps over training batches and compute batch metrics - state = train_step(state, batch) # get updated train state (which contains the updated parameters) - state = compute_metrics(state=state, batch=batch) # aggregate batch metrics - - if (step+1) % num_steps_per_epoch == 0: # one training epoch has passed - for metric,value in state.metrics.compute().items(): # compute metrics - metrics_history[f'train_{metric}'].append(value) # record metrics - state = state.replace(metrics=state.metrics.empty()) # reset train_metrics for next training epoch - - # Compute metrics on the test set after each training epoch - test_state = state - for test_batch in test_ds.as_numpy_iterator(): - test_state = compute_metrics(state=test_state, batch=test_batch) - - for metric,value in test_state.metrics.compute().items(): - metrics_history[f'test_{metric}'].append(value) - - print(f"train epoch: {(step+1) // num_steps_per_epoch}, " - f"loss: {metrics_history['train_loss'][-1]}, " - f"accuracy: {metrics_history['train_accuracy'][-1] * 100}") - print(f"test epoch: {(step+1) // num_steps_per_epoch}, " - f"loss: {metrics_history['test_loss'][-1]}, " - f"accuracy: {metrics_history['test_accuracy'][-1] * 100}") -``` - -## 11. Visualize metrics - -```{code-cell} -:outputId: 431a2fcd-44fa-4202-f55a-906555f060ac - -import matplotlib.pyplot as plt # Visualization - -# Plot loss and accuracy in subplots -fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5)) -ax1.set_title('Loss') -ax2.set_title('Accuracy') -for dataset in ('train','test'): - ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss') - ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy') -ax1.legend() -ax2.legend() -plt.show() -plt.clf() -``` - -## 12. Perform inference on test set - -Define a jitted inference function `pred_step`. Use the learned parameters to do model inference on the test set and visualize the images and their corresponding predicted labels. - -```{code-cell} -@jax.jit -def pred_step(state, batch): - logits = state.apply_fn({'params': state.params}, test_batch['image']) - return logits.argmax(axis=1) - -test_batch = test_ds.as_numpy_iterator().next() -pred = pred_step(state, test_batch) -``` - -```{code-cell} -:outputId: 1db5a01c-9d70-4f7d-8c0d-0a3ad8252d3e - -fig, axs = plt.subplots(5, 5, figsize=(12, 12)) -for i, ax in enumerate(axs.flatten()): - ax.imshow(test_batch['image'][i, ..., 0], cmap='gray') - ax.set_title(f"label={pred[i]}") - ax.axis('off') -``` - -Congratulations! You made it to the end of the annotated MNIST example. You can revisit -the same example, but structured differently as a couple of Python modules, test -modules, config files, another Colab, and documentation in Flax's Git repo: - -[https://github.com/google/flax/tree/main/examples/mnist](https://github.com/google/flax/tree/main/examples/mnist) diff --git a/docs_nnx/why.rst b/docs_nnx/why.rst new file mode 100644 index 00000000..ec808fbd --- /dev/null +++ b/docs_nnx/why.rst @@ -0,0 +1,434 @@ +Why Flax NNX? 
+============= + +In 2020, the Flax team released the Flax Linen API to support modeling research on JAX, with a focus on scaling +and performance. We have learned a lot from users since then. The team introduced certain ideas that have proven to be beneficial to users, such as: + +* Organizing variables into `collections `_. +* Automatic and efficient `pseudorandom number generator (PRNG) management `_. +* `Variable metadata `_ + for `Single Program Multi Data (SPMD) `_ annotations, optimizer metadata, and other use cases. + +One of the choices the Flax team made was to use functional (``compact``) semantics for neural network programming via lazy initialization of parameters. +This made for concise implementation code and aligned the Flax Linen API with Haiku. + +However, this also meant that the semantics of Modules and variables in Flax were non-Pythonic and often surprising. It also led to implementation +complexity and obscured the core ideas of `transformations (transforms) `_ on neural networks. + +.. testsetup:: Linen, NNX + + import jax + from jax import random, numpy as jnp + from flax import nnx + import flax.linen as nn + +Introducing Flax NNX +-------------------- + +Fast forward to 2024, the Flax team developed Flax NNX - an attempt to retain the features that made Flax Linen useful for users, while introducing some new principles. +The central idea behind Flax NNX is to introduce reference semantics into JAX. The following are its main features: + +- **NNX is Pythonic**: Regular Python semantics for Modules, including support for mutability and shared references. +- **NNX is simple**: Many of the complex APIs in Flax Linen are either simplified using Python idioms or completely removed. +- **Better JAX integration**: Custom NNX transforms adopt the same APIs as the JAX transforms. And with NNX + it is easier to use `JAX transforms (higher-order functions) `_ directly. + +Here is an example of a simple Flax NNX program that illustrates many of the points from above: + +.. testcode:: NNX + + from flax import nnx + import optax + + + class Model(nnx.Module): + def __init__(self, din, dmid, dout, rngs: nnx.Rngs): + self.linear = nnx.Linear(din, dmid, rngs=rngs) + self.bn = nnx.BatchNorm(dmid, rngs=rngs) + self.dropout = nnx.Dropout(0.2, rngs=rngs) + self.linear_out = nnx.Linear(dmid, dout, rngs=rngs) + + def __call__(self, x): + x = nnx.relu(self.dropout(self.bn(self.linear(x)))) + return self.linear_out(x) + + model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # Eager initialization + optimizer = nnx.Optimizer(model, optax.adam(1e-3)) # Reference sharing. + + @nnx.jit # Automatic state management for JAX transforms. + def train_step(model, optimizer, x, y): + def loss_fn(model): + y_pred = model(x) # call methods directly + return ((y_pred - y) ** 2).mean() + + loss, grads = nnx.value_and_grad(loss_fn)(model) + optimizer.update(grads) # in-place updates + + return loss + +Flax NNX's improvements on Linen +-------------------------------- + +The rest of this document uses various examples that demonstrate how Flax NNX improves on Flax Linen. + +Inspection +^^^^^^^^^^ + +The first improvement is that Flax NNX Modules are regular Python objects. This means that you can easily +construct and inspect ``Module`` objects. + +On the other hand, Flax Linen Modules are not easy to inspect and debug because they are lazy, which means some attributes are not available upon construction and are only accessible at runtime. + +.. 
codediff:: + :title: Linen, NNX + :sync: + + class Block(nn.Module): + def setup(self): + self.linear = nn.Dense(10) + + block = Block() + + try: + block.linear # AttributeError: "Block" object has no attribute "linear". + except AttributeError as e: + pass + + + + + + ... + + --- + + class Block(nnx.Module): + def __init__(self, rngs): + self.linear = nnx.Linear(5, 10, rngs=rngs) + + block = Block(nnx.Rngs(0)) + + + block.linear + # Linear( + # kernel=Param( + # value=Array(shape=(5, 10), dtype=float32) + # ), + # bias=Param( + # value=Array(shape=(10,), dtype=float32) + # ), + # ... + +Notice that in the Flax NNX example above, there is no shape inference - both the input and output shapes must be provided +to the ``Linear`` ``nnx.Module``. This is a tradeoff that allows for more explicit and predictable behavior. + +Running computation +^^^^^^^^^^^^^^^^^^^ + +In Flax Linen, all top-level computation must be done through the ``flax.linen.Module.init`` or ``flax.linen.Module.apply`` methods, and the +parameters or any other type of state are handled as a separate structure. This creates an asymmetry between: 1) code that runs inside +``apply`` that can run methods and other ``Module`` objects directly; and 2) code that runs outside of ``apply`` that must use the ``apply`` method. + +In Flax NNX, there's no special context because parameters are held as attributes and methods can be called directly. That means your NNX Module's ``__init__`` and ``__call__`` methods are not treated differently from other class methods, whereas Flax Linen Module's ``setup()`` and ``__call__`` methods are special. + +.. codediff:: + :title: Linen, NNX + :sync: + + Encoder = lambda: nn.Dense(10) + Decoder = lambda: nn.Dense(2) + + class AutoEncoder(nn.Module): + def setup(self): + self.encoder = Encoder() + self.decoder = Decoder() + + def __call__(self, x) -> jax.Array: + return self.decoder(self.encoder(x)) + + def encode(self, x) -> jax.Array: + return self.encoder(x) + + x = jnp.ones((1, 2)) + model = AutoEncoder() + params = model.init(random.key(0), x)['params'] + + y = model.apply({'params': params}, x) + z = model.apply({'params': params}, x, method='encode') + y = Decoder().apply({'params': params['decoder']}, z) + + --- + + Encoder = lambda rngs: nnx.Linear(2, 10, rngs=rngs) + Decoder = lambda rngs: nnx.Linear(10, 2, rngs=rngs) + + class AutoEncoder(nnx.Module): + def __init__(self, rngs): + self.encoder = Encoder(rngs) + self.decoder = Decoder(rngs) + + def __call__(self, x) -> jax.Array: + return self.decoder(self.encoder(x)) + + def encode(self, x) -> jax.Array: + return self.encoder(x) + + x = jnp.ones((1, 2)) + model = AutoEncoder(nnx.Rngs(0)) + + + y = model(x) + z = model.encode(x) + y = model.decoder(z) + +In Flax Linen, calling sub-Modules directly is not possible because they are not initialized. +Therefore, what you must do is construct a new instance and then provide a proper parameter structure. + +But in Flax NNX you can call sub-Modules directly without any issues. + +State handling +^^^^^^^^^^^^^^ + +One of the areas where Flax Linen is notoriously complex is in state handling. When you use either a +`Dropout` layer, a `BatchNorm` layer, or both, you suddenly have to handle the new state and use it to +configure the ``flax.linen.Module.apply`` method. + +In Flax NNX, state is kept inside an ``nnx.Module`` and is mutable, which means it can just be called directly. + +.. 
codediff:: + :title: Linen, NNX + :sync: + + class Block(nn.Module): + train: bool + + def setup(self): + self.linear = nn.Dense(10) + self.bn = nn.BatchNorm(use_running_average=not self.train) + self.dropout = nn.Dropout(0.1, deterministic=not self.train) + + def __call__(self, x): + return nn.relu(self.dropout(self.bn(self.linear(x)))) + + x = jnp.ones((1, 5)) + model = Block(train=True) + vs = model.init(random.key(0), x) + params, batch_stats = vs['params'], vs['batch_stats'] + + y, updates = model.apply( + {'params': params, 'batch_stats': batch_stats}, + x, + rngs={'dropout': random.key(1)}, + mutable=['batch_stats'], + ) + batch_stats = updates['batch_stats'] + + --- + + class Block(nnx.Module): + + + def __init__(self, rngs): + self.linear = nnx.Linear(5, 10, rngs=rngs) + self.bn = nnx.BatchNorm(10, rngs=rngs) + self.dropout = nnx.Dropout(0.1, rngs=rngs) + + def __call__(self, x): + return nnx.relu(self.dropout(self.bn(self.linear(x)))) + + x = jnp.ones((1, 5)) + model = Block(nnx.Rngs(0)) + + + + y = model(x) + + + + + + ... + +The main benefit of Flax NNX's state handling is that you don't have to change the training code when you add a new stateful layer. + +In addition, in Flax NNX, layers that handle state are also very easy to implement. Below +is a simplified version of a ``BatchNorm`` layer that updates the mean and variance every time it is called. + +.. testcode:: NNX + + class BatchNorm(nnx.Module): + def __init__(self, features: int, mu: float = 0.95): + # Variables + self.scale = nnx.Param(jax.numpy.ones((features,))) + self.bias = nnx.Param(jax.numpy.zeros((features,))) + self.mean = nnx.BatchStat(jax.numpy.zeros((features,))) + self.var = nnx.BatchStat(jax.numpy.ones((features,))) + self.mu = mu # Static + + def __call__(self, x): + mean = jax.numpy.mean(x, axis=-1) + var = jax.numpy.var(x, axis=-1) + # ema updates + self.mean.value = self.mu * self.mean + (1 - self.mu) * mean + self.var.value = self.mu * self.var + (1 - self.mu) * var + # normalize and scale + x = (x - mean) / jax.numpy.sqrt(var + 1e-5) + return x * self.scale + self.bias + + +Model surgery +^^^^^^^^^^^^^ + +In Flax Linen, `model surgery `_ has historically been challenging because of two reasons: + +1. Due to lazy initialization, it is not guaranteed that you can replace a sub-``Module`` with a new one. +2. The parameter structure is separated from the ``flax.linen.Module`` structure, which means you have to manually keep them in sync. + +In Flax NNX, you can replace sub-Modules directly as per the Python semantics. Since parameters are +part of the ``nnx.Module`` structure, they are never out of sync. Below is an example of how you can +implement a LoRA layer, and then use it to replace a ``Linear`` layer in an existing model. + +.. 
codediff:: + :title: Linen, NNX + :sync: + + class LoraLinear(nn.Module): + linear: nn.Dense + rank: int + + @nn.compact + def __call__(self, x: jax.Array): + A = self.param(random.normal, (x.shape[-1], self.rank)) + B = self.param(random.normal, (self.rank, self.linear.features)) + + return self.linear(x) + x @ A @ B + + try: + model = Block(train=True) + model.linear = LoraLinear(model.linear, rank=5) # <-- ERROR + + lora_params = model.linear.init(random.key(1), x) + lora_params['linear'] = params['linear'] + params['linear'] = lora_params + + except AttributeError as e: + pass + + --- + + class LoraParam(nnx.Param): pass + + class LoraLinear(nnx.Module): + def __init__(self, linear, rank, rngs): + self.linear = linear + self.A = LoraParam(random.normal(rngs(), (linear.in_features, rank))) + self.B = LoraParam(random.normal(rngs(), (rank, linear.out_features))) + + def __call__(self, x: jax.Array): + return self.linear(x) + x @ self.A @ self.B + + rngs = nnx.Rngs(0) + model = Block(rngs) + model.linear = LoraLinear(model.linear, rank=5, rngs=rngs) + + + + + + + ... + +As shown above, in Flax Linen this doesn't really work in this case because the ``linear`` sub-``Module`` +is not available. However, the rest of the code provides an idea of how the ``params`` structure must be manually updated. + +Performing arbitrary model surgery is not easy in Flax Linen, and currently the +`intercept_methods `_ +API is the only way to do generic patching of methods. But this API is not very ergonomic. + +In Flax NNX, to do generic model surgery you can just use ``nnx.iter_graph``, which is much simpler and easier than in Linen. Below is an example of replacing all ``nnx.Linear`` layers in a model with custom-made ``LoraLinear`` NNX layers. + +.. testcode:: NNX + + rngs = nnx.Rngs(0) + model = Block(rngs) + + for path, module in nnx.iter_graph(model): + if isinstance(module, nnx.Module): + for name, value in vars(module).items(): + if isinstance(value, nnx.Linear): + setattr(module, name, LoraLinear(value, rank=5, rngs=rngs)) + +Transforms +^^^^^^^^^^ + +Flax Linen transforms are very powerful in that they enable fine-grained control over the model's state. +However, Flax Linen transforms have drawbacks, such as: + +1. They expose additional APIs that are not part of JAX, making their behavior confusing and sometimes divergent from their JAX counterparts. This also constrains your ways to interact with `JAX transforms `_ and keep up with JAX API changes. +2. They work on functions with very specific signatures, namely: + - A ``flax.linen.Module`` must be the first argument. + - They accept other ``Module`` objects as arguments but not as return values. +3. They can only be used inside ``flax.linen.Module.apply``. + +On the other hand, `Flax NNX transforms `_ +are intented to be equivalent to their corresponding `JAX transforms `_ +with an exception - they can be used on Flax NNX Modules. This means that Flax transforms: + +1) Have the same API as JAX transforms. +2) Can accept Flax NNX Modules on any argument, and ``nnx.Module`` objects can be returned from it/them. +3) Can be used anywhere including the training loop. + +Below is an example of using ``vmap`` with Flax NNX to both create a stack of weights by transforming the +``create_weights`` function, which returns some ``Weights``, and to apply that stack of weights to a batch +of inputs individually by transforming the ``vector_dot`` function, which takes ``Weights`` as the first +argument and a batch of inputs as the second argument. + +.. 
testcode:: NNX + + class Weights(nnx.Module): + def __init__(self, kernel: jax.Array, bias: jax.Array): + self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias) + + def create_weights(seed: jax.Array): + return Weights( + kernel=random.uniform(random.key(seed), (2, 3)), + bias=jnp.zeros((3,)), + ) + + def vector_dot(weights: Weights, x: jax.Array): + assert weights.kernel.ndim == 2, 'Batch dimensions not allowed' + assert x.ndim == 1, 'Batch dimensions not allowed' + return x @ weights.kernel + weights.bias + + seeds = jnp.arange(10) + weights = nnx.vmap(create_weights, in_axes=0, out_axes=0)(seeds) + + x = jax.random.normal(random.key(1), (10, 2)) + y = nnx.vmap(vector_dot, in_axes=(0, 0), out_axes=1)(weights, x) + +Contrary to Flax Linen transforms, the ``in_axes`` argument and other APIs do affect how the ``nnx.Module`` state is transformed. + +In addition, Flax NNX transforms can be used as method decorators, because ``nnx.Module`` methods are simply +functions that take a ``Module`` as the first argument. This means that the previous example can be +rewritten as follows: + +.. testcode:: NNX + + class WeightStack(nnx.Module): + @nnx.vmap(in_axes=(0, 0), out_axes=0) + def __init__(self, seed: jax.Array): + self.kernel = nnx.Param(random.uniform(random.key(seed), (2, 3))) + self.bias = nnx.Param(jnp.zeros((3,))) + + @nnx.vmap(in_axes=(0, 0), out_axes=1) + def __call__(self, x: jax.Array): + assert self.kernel.ndim == 2, 'Batch dimensions not allowed' + assert x.ndim == 1, 'Batch dimensions not allowed' + return x @ self.kernel + self.bias + + weights = WeightStack(jnp.arange(10)) + + x = jax.random.normal(random.key(1), (10, 2)) + y = weights(x) + + diff --git a/examples/gemma/sampler_test.py b/examples/gemma/sampler_test.py index 2e131dda..56870ca9 100644 --- a/examples/gemma/sampler_test.py +++ b/examples/gemma/sampler_test.py @@ -76,7 +76,7 @@ class SamplerTest(absltest.TestCase): def test_samples(self): vocab = MockVocab() - transformer_config = transformer_lib.TransformerConfig( + transformer_config = transformer_lib.TransformerConfig( # pytype: disable=wrong-arg-types num_layers=6, num_embed=vocab.GetPieceSize(), embed_dim=768, @@ -104,7 +104,7 @@ def test_samples(self): def test_forbidden_tokens(self): vocab = MockVocab() - transformer_config = transformer_lib.TransformerConfig( + transformer_config = transformer_lib.TransformerConfig( # pytype: disable=wrong-arg-types num_layers=0, num_embed=vocab.GetPieceSize(), embed_dim=32, @@ -152,7 +152,7 @@ def test_forbidden_tokens(self): def test_forward_equivalence(self): vocab = MockVocab() - transformer_config = transformer_lib.TransformerConfig( + transformer_config = transformer_lib.TransformerConfig( # pytype: disable=wrong-arg-types num_layers=2, num_embed=vocab.GetPieceSize(), embed_dim=32, @@ -211,7 +211,7 @@ def test_forward_equivalence(self): def test_sampler_init_sample_state(self): vocab = MockVocab() - transformer_config = transformer_lib.TransformerConfig( + transformer_config = transformer_lib.TransformerConfig( # pytype: disable=wrong-arg-types num_layers=0, num_embed=vocab.GetPieceSize(), embed_dim=32, @@ -247,7 +247,7 @@ def test_sampler_init_sample_state(self): def test_sampler_mask_tokens_after_eos_ids(self): vocab = MockVocab() - transformer_config = transformer_lib.TransformerConfig( + transformer_config = transformer_lib.TransformerConfig( # pytype: disable=wrong-arg-types num_layers=0, num_embed=vocab.GetPieceSize(), embed_dim=32, diff --git a/examples/nnx_toy_examples/03_train_state.py 
b/examples/nnx_toy_examples/03_train_state.py new file mode 100644 index 00000000..c67ef820 --- /dev/null +++ b/examples/nnx_toy_examples/03_train_state.py @@ -0,0 +1,123 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# %% +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpy as np +import optax + +from flax import nnx +from flax.training import train_state + +X = np.linspace(0, 1, 100)[:, None] +Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) + + +def dataset(batch_size): + while True: + idx = np.random.choice(len(X), size=batch_size) + yield X[idx], Y[idx] + + +class Linear(nnx.Module): + def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): + self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout))) + self.b = nnx.Param(jnp.zeros((dout,))) + + def __call__(self, x): + return x @ self.w.value + self.b.value + + +class Count(nnx.Variable[nnx.A]): + pass + + +class MLP(nnx.Module): + def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs): + self.count = Count(jnp.array(0)) + self.linear1 = Linear(din, dhidden, rngs=rngs) + self.linear2 = Linear(dhidden, dout, rngs=rngs) + + def __call__(self, x): + self.count.value += 1 + x = self.linear1(x) + x = jax.nn.relu(x) + x = self.linear2(x) + return x + +class TrainState(train_state.TrainState): + counts: nnx.State + graphdef: nnx.GraphDef + +model = MLP(din=1, dhidden=32, dout=1, rngs=nnx.Rngs(0)) +graphdef, params, counts = nnx.split(model, nnx.Param, Count) + +state = TrainState.create( + apply_fn=None, + graphdef=graphdef, + params=params, + tx=optax.sgd(0.1), + counts=counts, +) +del params, counts + + +@jax.jit +def train_step(state: TrainState, batch): + x, y = batch + + def loss_fn(params): + model = nnx.merge(state.graphdef, params, state.counts) + y_pred = model(x) + loss = jnp.mean((y - y_pred) ** 2) + counts = nnx.state(model, Count) + return loss, counts + + grads, counts = jax.grad(loss_fn, has_aux=True)(state.params) + # sdg update + state = state.apply_gradients(grads=grads, counts=counts) + + return state + + +@jax.jit +def test_step(state: nnx.TrainState[MLP], batch): + x, y = batch + model = nnx.merge(state.graphdef, state.params, state.counts) + y_pred = model(x) + loss = jnp.mean((y - y_pred) ** 2) + return {'loss': loss} + + +total_steps = 10_000 +for step, batch in enumerate(dataset(32)): + state = train_step(state, batch) + + if step % 1000 == 0: + logs = test_step(state, (X, Y)) + print(f"step: {step}, loss: {logs['loss']}") + + if step >= total_steps - 1: + break + +model = nnx.merge(state.graphdef, state.params, state.counts) +print('times called:', model.count.value) + +y_pred = model(X) + +plt.scatter(X, Y, color='blue') +plt.plot(X, y_pred, color='black') +plt.show() diff --git a/examples/nnx_toy_examples/04_data_parallel_with_jit.py b/examples/nnx_toy_examples/04_data_parallel_with_jit.py new file mode 100644 index 00000000..bd37ff63 --- /dev/null +++ b/examples/nnx_toy_examples/04_data_parallel_with_jit.py @@ -0,0 
+1,102 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8' + +import jax +import jax.numpy as jnp +import numpy as np +import optax +from flax import nnx +from jax.experimental import mesh_utils +import matplotlib.pyplot as plt + +# create a mesh + shardings +num_devices = jax.local_device_count() +mesh = jax.sharding.Mesh( + mesh_utils.create_device_mesh((num_devices,)), ('data',) +) +model_sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec()) +data_sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec('data')) + + +# create model +class MLP(nnx.Module): + def __init__(self, din, dmid, dout, *, rngs: nnx.Rngs): + self.linear1 = nnx.Linear(din, dmid, rngs=rngs) + self.linear2 = nnx.Linear(dmid, dout, rngs=rngs) + + def __call__(self, x): + return self.linear2(nnx.relu(self.linear1(x))) + + +model = MLP(1, 64, 1, rngs=nnx.Rngs(0)) +optimizer = nnx.Optimizer(model, optax.adamw(1e-2)) + +# replicate state +state = nnx.state((model, optimizer)) +state = jax.device_put(state, model_sharding) +nnx.update((model, optimizer), state) + +# visualize model sharding +print('model sharding') +jax.debug.visualize_array_sharding(model.linear1.kernel.value) + + +@nnx.jit +def train_step(model: MLP, optimizer: nnx.Optimizer, x, y): + def loss_fn(model: MLP): + y_pred = model(x) + return jnp.mean((y - y_pred) ** 2) + + loss, grads = nnx.value_and_grad(loss_fn)(model) + optimizer.update(grads) + return loss + + +def dataset(steps, batch_size): + for _ in range(steps): + x = np.random.uniform(-2, 2, size=(batch_size, 1)) + y = 0.8 * x**2 + 0.1 + np.random.normal(0, 0.1, size=x.shape) + yield x, y + + +for step, (x, y) in enumerate(dataset(1000, 16)): + # shard data + x, y = jax.device_put((x, y), data_sharding) + # train + loss = train_step(model, optimizer, x, y) + + if step == 0: + print('data sharding') + jax.debug.visualize_array_sharding(x) + + if step % 100 == 0: + print(f'step={step}, loss={loss}') + +# dereplicate state +state = nnx.state((model, optimizer)) +state = jax.device_get(state) +nnx.update((model, optimizer), state) + +X, Y = next(dataset(1, 1000)) +x_range = np.linspace(X.min(), X.max(), 100)[:, None] +y_pred = model(x_range) + +# plot +plt.scatter(X, Y, label='data') +plt.plot(x_range, y_pred, color='black', label='model') +plt.legend() +plt.show() \ No newline at end of file diff --git a/examples/nnx_toy_examples/10_fsdp_and_optimizer.py b/examples/nnx_toy_examples/10_fsdp_and_optimizer.py new file mode 100644 index 00000000..f5cf8002 --- /dev/null +++ b/examples/nnx_toy_examples/10_fsdp_and_optimizer.py @@ -0,0 +1,171 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +import os +os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8' + +from matplotlib import pyplot as plt +from jax.experimental import mesh_utils +from jax.sharding import Mesh, PartitionSpec as P, NamedSharding +import jax +import jax.numpy as jnp +import numpy as np +from flax import nnx +import typing as tp + +mesh = jax.sharding.Mesh( + mesh_utils.create_device_mesh((2, 4)), + ('data', 'model'), +) + + +def named_sharding(*names: str | None) -> NamedSharding: + return NamedSharding(mesh, P(*names)) + + +@dataclasses.dataclass(unsafe_hash=True) +class MeshRules: + embed: str | None = None + mlp: str | None = None + data: str | None = None + + def __call__(self, *keys: str) -> tuple[str, ...]: + return tuple(getattr(self, key) for key in keys) + + +mesh_rules = MeshRules( + embed=None, + mlp='model', + data='data', +) + + +class MLP(nnx.Module): + def __init__(self, din, dmid, dout, rngs: nnx.Rngs): + self.w1 = nnx.Param( + nnx.initializers.lecun_normal()(rngs.params(), (din, dmid)), + sharding=mesh_rules('embed', 'mlp'), + ) + self.b1 = nnx.Param( + jnp.zeros((dmid,)), + sharding=mesh_rules('mlp'), + ) + self.w2 = nnx.Param( + nnx.initializers.lecun_normal()(rngs.params(), (dmid, dout)), + sharding=mesh_rules('embed', 'mlp'), + ) + + def __call__(self, x: jax.Array): + return nnx.relu(x @ self.w1 + self.b1) @ self.w2 + + +class SGDState(nnx.Variable): + pass + + +class SGD(nnx.Object): + def __init__(self, params: nnx.State, lr, decay=0.9): + def init_optimizer_state(variable: nnx.Variable): + return SGDState( + jnp.zeros_like(variable.value), **variable.get_metadata() + ) + + self.lr = lr + self.params = params + self.momentum = jax.tree.map(init_optimizer_state, self.params) + self.decay = decay + + def update(self, grads: nnx.State): + def update_fn( + params: nnx.Variable, momentum: SGDState, grad: nnx.VariableState + ): + # v_t = β * v_{t-1} + (1 - β) * ∇J(θ_t) + momentum.value = self.decay * momentum + (1 - self.decay) * grad.value + # θ_{t+1} = θ_t - α * v_t + params.value -= self.lr * momentum + + jax.tree.map(update_fn, self.params, self.momentum, grads) + + +@nnx.jit +def create_model(): + model = MLP(1, 32, 1, rngs=nnx.Rngs(0)) + optimizer = SGD(nnx.variables(model, nnx.Param), 0.01, decay=0.9) + state = nnx.state(optimizer) + sharded_state = jax.lax.with_sharding_constraint( + state, nnx.get_named_sharding(state, mesh) + ) + + def get_named_shardings(path: tuple, value: nnx.VariableState): + if path[0] == 'params': + return value.replace(NamedSharding(mesh, P(*value.sharding))) + elif path[0] == 'momentum': + # currently the same as above but in general it could be different + return value.replace(NamedSharding(mesh, P(*value.sharding))) + else: + raise ValueError(f'Unknown path: {path}') + + named_shardings = state.map(get_named_shardings) + sharded_state = jax.lax.with_sharding_constraint(state, named_shardings) + nnx.update(optimizer, sharded_state) + return model, optimizer + + +model, optimizer = create_model() + +jax.debug.visualize_array_sharding(model.w1.value) 
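+# The momentum buffer below was created with w1's sharding metadata, so its mesh layout should match the kernel visualized above.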
+jax.debug.visualize_array_sharding(optimizer.momentum.w1.value) + + +@nnx.jit +def train_step(model: MLP, optimizer: SGD, x, y): + def loss_fn(model): + y_pred = model(x) + loss = jnp.mean((y - y_pred) ** 2) + return loss + + loss, grad = nnx.value_and_grad(loss_fn)(model) + optimizer.update(grad) + return loss + + +X = np.linspace(-2, 2, 100)[:, None] +Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) + + +def dataset(batch_size, num_steps): + for _ in range(num_steps): + idx = np.random.choice(len(X), size=batch_size) + yield X[idx], Y[idx] + + +losses = [] +for step, (x_batch, y_batch) in enumerate( + dataset(batch_size=32, num_steps=10_000) +): + x_batch, y_batch = jax.device_put((x_batch, y_batch), named_sharding('data')) + loss = train_step(model, optimizer, x_batch, y_batch) + losses.append(float(loss)) + if step % 1000 == 0: + print(f'Step {step}: Loss = {loss}') + +plt.figure() +plt.plot(losses[20:]) + +y_pred = model(X) +plt.figure() +plt.scatter(X, Y, color='blue') +plt.plot(X, y_pred, color='black') +plt.show() diff --git a/flax/configurations.py b/flax/configurations.py index 4f61170f..ba19a572 100644 --- a/flax/configurations.py +++ b/flax/configurations.py @@ -162,9 +162,6 @@ def temp_flip_flag(var_name: str, var_value: bool): # Flax Global Configuration Variables: -# Whether to use the lazy rng implementation. -flax_lazy_rng = static_bool_env('FLAX_LAZY_RNG', True) - flax_filter_frames = bool_flag( name='flax_filter_frames', default=True, diff --git a/flax/core/lift.py b/flax/core/lift.py index 8daa7187..f7b7bfb7 100644 --- a/flax/core/lift.py +++ b/flax/core/lift.py @@ -24,6 +24,7 @@ import warnings from flax import traceback_util +from flax import traverse_util from flax.typing import ( In, InOutAxis, @@ -1499,6 +1500,81 @@ def _hashable_filter(x): return x +class CountsHolder: + + def __init__(self, flat_d): + self.flat_d = flat_d + + @classmethod + def make(cls, d): + flat_d = traverse_util.flatten_dict(d) + flat_d = {k: v for k, v in flat_d.items()} + return cls(flat_d) + + def sub(self, other): + delta_flat_d = {} + new_flat_d = collections.defaultdict(int, self.flat_d) + old_flat_d = collections.defaultdict(int, other.flat_d) + for k in new_flat_d: + delta_flat_d[k] = new_flat_d[k] - old_flat_d[k] + return CountsHolder(delta_flat_d) + + def add(self, other): + delta_flat_d = {} + new_flat_d = collections.defaultdict(int, self.flat_d) + old_flat_d = collections.defaultdict(int, other.flat_d) + for k in new_flat_d: + delta_flat_d[k] = new_flat_d[k] + old_flat_d[k] + return CountsHolder(delta_flat_d) + + def unflat(self): + return traverse_util.unflatten_dict(self.flat_d) + + +def set_from_dict(original, updates): + for k in updates: + if k not in original: + original[k] = updates[k] + else: + if isinstance(updates[k], dict): + set_from_dict(original[k], updates[k]) + else: + original[k] = updates[k] + + +class _SideEffectCache(threading.local): + + def __init__(self): + self.cache = {} + + +_side_effect_cache = _SideEffectCache() + + +def _restore_rng_counters(scopes, fingerprint, capture_old_counts): + if fingerprint not in _side_effect_cache.cache: + capture_new_counts = jax.tree.map( + lambda s: CountsHolder.make(s.rng_counters), scopes + ) + capture_delta_counts = jax.tree.map( + lambda old, new: new.sub(old), + capture_old_counts, + capture_new_counts, + ) + _side_effect_cache.cache[fingerprint] = capture_delta_counts + else: + updated_counts = jax.tree.map( + lambda x, y: x.add(y).unflat(), + _side_effect_cache.cache[fingerprint], + 
capture_old_counts, + ) + jax.tree.map( + lambda s, u: set_from_dict(s.rng_counters, u), + scopes, + updated_counts, + ) + + def jit( fn: Callable[..., Any], variables: CollectionFilter = True, @@ -1599,13 +1675,18 @@ def inner( mutable = tuple(_hashable_filter(scope.mutable) for scope in scopes) rng_groups = jax.tree.map( - lambda x: x.fold() if isinstance(x, LazyRng) else x, + lambda x: x.clear_suffix() if isinstance(x, LazyRng) else x, rng_groups, is_leaf=lambda x: isinstance(x, LazyRng), ) fingerprint = (mutable, module_hash_key) - return jitted(fingerprint, variable_groups, rng_groups, *args, **kwargs) + capture_old_counts = jax.tree.map( + lambda s: CountsHolder.make(s.rng_counters), scopes + ) + res = jitted(fingerprint, variable_groups, rng_groups, *args, **kwargs) + _restore_rng_counters(scopes, fingerprint, capture_old_counts) + return res return pack( inner, (variables,), (variables,), (rngs,), name='jit', enable_kwargs=True @@ -1692,3 +1773,64 @@ def inner_loop(scope, carry): def _unzip2(xs): ys = tuple(zip(*xs)) return ys if ys else ((), ()) + + +def fold_rngs( + fn: Callable[..., Any], + variables: CollectionFilter = True, + rngs: PRNGSequenceFilter = True, +) -> Callable[..., Any]: + # Close over scope_fn & repack_fn to avoid recompilation + # this is impure but we use the fingerprint arg to differentiate between cases + # where scope_fn or repack_fn actually produce non-identical results. + fold_rngs_context = TransformContext[tuple[Callable, Callable]]() + + @functools.wraps(fn) + def wrapped_fold_rngs(fingerprint, variable_groups, rng_groups, *args, **kwargs): + scope_fn, repack_fn = fold_rngs_context.get() + hash_key = fingerprint[1] + # fingerprint is only used to differentiate the cache signature + # del fingerprint + scope = scope_fn(variable_groups, rng_groups) # pylint: disable=not-callable + y = fn(scope, hash_key, *args, **kwargs) + return y, repack_fn(scope) # pylint: disable=not-callable + + def inner_fold_rngs( + scope_fn, + repack_fn, + variable_groups, + rng_groups, + module_hash_key, + *args, + **kwargs, + ): + with fold_rngs_context.push((scope_fn, repack_fn)): + scopes: list[Scope] = jax.tree_util.tree_leaves( + scope_fn(variable_groups, rng_groups) + ) + mutable = tuple(_hashable_filter(scope.mutable) for scope in scopes) + + rng_groups = jax.tree.map( + lambda x: x.clear_suffix() if isinstance(x, LazyRng) else x, + rng_groups, + is_leaf=lambda x: isinstance(x, LazyRng), + ) + + fingerprint = (mutable, module_hash_key) + capture_old_counts = jax.tree.map( + lambda s: CountsHolder.make(s.rng_counters), scopes + ) + res = wrapped_fold_rngs( + fingerprint, variable_groups, rng_groups, *args, **kwargs + ) + _restore_rng_counters(scopes, fingerprint, capture_old_counts) + return res + + return pack( + inner_fold_rngs, + (variables,), + (variables,), + (rngs,), + name='fold_rngs', + enable_kwargs=True, + ) diff --git a/flax/core/scope.py b/flax/core/scope.py index ea8a586b..e18789f4 100644 --- a/flax/core/scope.py +++ b/flax/core/scope.py @@ -38,7 +38,6 @@ from jax import random, tree_util from flax import config as config -from flax import configurations as legacy_config # only for flax_lazy_rng from flax import errors, struct, traceback_util from flax.ids import uuid from flax.typing import ( @@ -98,37 +97,16 @@ def as_jax_rng(self) -> PRNGKey: def create( rng: Union['LazyRng', PRNGKey], *suffix: PRNGFoldable ) -> 'LazyRng': - if not legacy_config.flax_lazy_rng: - if isinstance(rng, LazyRng): - assert not rng.suffix - rng = rng.rng - return 
LazyRng(_legacy_rng_fold_in(rng, suffix), ()) if isinstance(rng, LazyRng): return LazyRng(rng.rng, rng.suffix + suffix) else: return LazyRng(rng, suffix) - def fold(self): - key = self.as_jax_rng() + def clear_suffix(self): + key = self.rng return LazyRng(key, ()) -def _legacy_rng_fold_in(rng: PRNGKey, data: Iterable[PRNGFoldable]) -> PRNGKey: - """Legacy RNG folding.""" - for x in data: - if isinstance(x, str): - m = hashlib.sha1() - m.update(x.encode('utf-8')) - d = m.digest() - hash_int = int.from_bytes(d[:4], byteorder='big') - rng = random.fold_in(rng, jnp.uint32(hash_int)) # type: ignore - elif isinstance(x, int): - rng = random.fold_in(rng, x) - else: - raise ValueError(f'Expected int or string, got: {x}') - return rng - - def _fold_in_static( rng: PRNGKey, data: typing.Collection[PRNGFoldable] ) -> PRNGKey: @@ -605,13 +583,6 @@ def default_name(self, prefix: str) -> str: return name i += 1 - def fold_rngs(self): - """Folds the rngs of this scope into the parent scope.""" - self._check_valid() - for name, rng in self.rngs.items(): - assert isinstance(rng, LazyRng) - self.rngs[name] = rng.fold() - def push( self, name: str | None = None, prefix: str = '', reuse=False ) -> 'Scope': @@ -1218,12 +1189,8 @@ def _is_valid_rng(rng: Array): return False # Handle new-style typed PRNG keys - if hasattr(jax.dtypes, 'prng_key'): # JAX 0.4.14 or newer - if jax.dtypes.issubdtype(rng.dtype, jax.dtypes.prng_key): - return rng.shape == () - elif hasattr(jax.random, 'PRNGKeyArray'): # Previous JAX versions - if isinstance(rng, jax.random.PRNGKeyArray): - return rng.shape == () + if jax.dtypes.issubdtype(rng.dtype, jax.dtypes.prng_key): + return rng.shape == () # Handle old-style raw PRNG keys expected_rng = jax.eval_shape( diff --git a/flax/core/spmd.py b/flax/core/spmd.py new file mode 100644 index 00000000..3c4efe40 --- /dev/null +++ b/flax/core/spmd.py @@ -0,0 +1,80 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import dataclasses +import threading + +from flax.typing import ( + LogicalRules, + Sharding, +) + +# Dynamic Axis Mapping Context +# ------------------------------------------------------------------------------ + + +@dataclasses.dataclass +class _AxisRules(threading.local): + """Dynamic logical axis to mesh axis binding context.""" + + rules: LogicalRules = () + + +# Global axis binding context. 
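+# _AxisRules subclasses threading.local, so each thread keeps its own rule set.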
+_axis_rules = _AxisRules() + + +def set_logical_axis_rules(rules: LogicalRules): + """Sets the global logical axis to mesh axis binding.""" + _axis_rules.rules = rules + + +def get_logical_axis_rules() -> LogicalRules: + """Returns the global logical axis to mesh axis binding.""" + return _axis_rules.rules + + +@contextlib.contextmanager +def logical_axis_rules(rules: LogicalRules): + """Context manager for setting the logical to mesh axis bindings.""" + old_rules = _axis_rules.rules + try: + _axis_rules.rules = rules + yield + finally: + _axis_rules.rules = old_rules + + +def composite_rules(rule1, rule2): + if not rule1 and not rule2: + return () + rules = {alias: value for alias, value in rule1} + for alias, value in rule2: + if alias in rules and rules[alias] != value: + raise ValueError( + f'Inconsistent logical axis annotations for {alias}: ' + f'{rules[alias]} vs {value}' + ) + rules[alias] = value + return tuple(rules.items()) + + +def from_sharding_rules( + sharding: Sharding, sharding_rules: LogicalRules +) -> Sharding: + rules = {alias: on_mesh for (alias, on_mesh) in sharding_rules} + return tuple( + rules[str(s)] if (s and str(s) in rules) else s for s in sharding + ) diff --git a/flax/core/tracers.py b/flax/core/tracers.py index fe2ff874..b380f6aa 100644 --- a/flax/core/tracers.py +++ b/flax/core/tracers.py @@ -16,8 +16,6 @@ import jax -from .. import errors - def current_trace(): """Returns the current JAX state tracer.""" @@ -31,6 +29,5 @@ def current_trace(): return jax.core.get_opaque_trace_state(convention="flax") def check_trace_level(base_level): - level = current_trace() - if level != base_level: - raise errors.JaxTransformError() + pass + # TODO: re-enable when we update flax to use stackless trace context diff --git a/flax/linen/__init__.py b/flax/linen/__init__.py index 9b80ca3c..9926af25 100644 --- a/flax/linen/__init__.py +++ b/flax/linen/__init__.py @@ -31,8 +31,12 @@ unbox as unbox, with_partitioning as with_partitioning, ) +from flax.core.spmd import ( + get_logical_axis_rules as get_logical_axis_rules, + logical_axis_rules as logical_axis_rules, + set_logical_axis_rules as set_logical_axis_rules, +) from .activation import ( - GeGLU as GeGLU, PReLU as PReLU, celu as celu, elu as elu, @@ -73,8 +77,8 @@ from .batch_apply import BatchApply as BatchApply from .combinators import Sequential as Sequential from .fp8_ops import ( - Fp8DotGeneralOp as Fp8DotGeneralOp, Fp8DirectDotGeneralOp as Fp8DirectDotGeneralOp, + Fp8DotGeneralOp as Fp8DotGeneralOp, NANOOFp8DotGeneralOp as NANOOFp8DotGeneralOp, ) from .initializers import ( @@ -96,8 +100,8 @@ Module as Module, Variable as Variable, apply as apply, - compact as compact, compact_name_scope as compact_name_scope, + compact as compact, disable_named_call as disable_named_call, enable_named_call as enable_named_call, init_with_output as init_with_output, @@ -115,28 +119,25 @@ LayerNorm as LayerNorm, RMSNorm as RMSNorm, SpectralNorm as SpectralNorm, - WeightNorm as WeightNorm + WeightNorm as WeightNorm, ) from .pooling import (avg_pool as avg_pool, max_pool as max_pool, pool as pool) from .recurrent import ( Bidirectional as Bidirectional, ConvLSTMCell as ConvLSTMCell, - SimpleCell as SimpleCell, GRUCell as GRUCell, - MGUCell as MGUCell, LSTMCell as LSTMCell, + MGUCell as MGUCell, OptimizedLSTMCell as OptimizedLSTMCell, RNNCellBase as RNNCellBase, RNN as RNN, + SimpleCell as SimpleCell, ) from .spmd import ( LogicallyPartitioned as LogicallyPartitioned, - get_logical_axis_rules as get_logical_axis_rules, - 
logical_axis_rules as logical_axis_rules, logical_to_mesh, logical_to_mesh_axes, logical_to_mesh_sharding, - set_logical_axis_rules as set_logical_axis_rules, with_logical_constraint, with_logical_partitioning as with_logical_partitioning, ) @@ -147,6 +148,8 @@ checkpoint as checkpoint, cond as cond, custom_vjp as custom_vjp, + fold_rngs as fold_rngs, + grad as grad, jit as jit, jvp as jvp, map_variables as map_variables, @@ -155,9 +158,8 @@ remat as remat, scan as scan, switch as switch, - vjp as vjp, - grad as grad, value_and_grad as value_and_grad, + vjp as vjp, vmap as vmap, while_loop as while_loop, ) diff --git a/flax/linen/activation.py b/flax/linen/activation.py index 3f36bdfc..8ccfff0d 100644 --- a/flax/linen/activation.py +++ b/flax/linen/activation.py @@ -98,44 +98,3 @@ def __call__(self, inputs: Array) -> Array: return jnp.where( inputs >= 0, inputs, jnp.asarray(negative_slope, inputs.dtype) * inputs ) - -class GeGLU(Module): - """Gated Linear Unit with GELU (GeGLU) activation function. - - GeGLU is a Flax layer that combines a linear transformation with a GELU - activation function in a gating mechanism. It is often used in Transformer models - to provide non-linear capabilities while preserving a strong linear component. - - Example usage:: - >>> import flax.linen as nn - - >>> class TransformerBlock(nn.Module): - ... @nn.compact - ... def __call__(self, x): - ... x = nn.Dense(2)(x) - ... x = nn.GeGLU()(x) # initialized - ... return x - - Attributes: - features: the number of output features (default: None). - """ - output_dim: int = -1 - - @compact - def __call__(self, inputs: Array) -> Array: - """Applies the GeGLU activation to the inputs. - - Args: - inputs: the nd-array to apply the GeGLU activation function to. - - Returns: - The transformed input. - """ - if self.output_dim == -1: - output_dim = inputs.shape[-1] - else: - output_dim = self.output_dim - - x = Dense(output_dim * 2)(inputs) - x, gate = x[..., : output_dim], x[..., output_dim :] - return x * gelu(gate) \ No newline at end of file diff --git a/flax/linen/module.py b/flax/linen/module.py index 9de568b8..f8a57b95 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -2702,11 +2702,13 @@ def perturb( if not self.scope.has_variable(collection, name): self.scope.reserve(name, collection) self._state.children[name] = collection - self.scope.put_variable(collection, name, jnp.zeros_like(value)) # type: ignore + zeros = jax.tree.map(jnp.zeros_like, value) + self.scope.put_variable(collection, name, zeros) # type: ignore if collection in self.scope.root._variables: if self.scope.has_variable(collection, name): - value += self.scope.get_variable(collection, name) # type: ignore + old_value = self.scope.get_variable(collection, name) + value = jax.tree.map(jnp.add, value, old_value) # type: ignore else: raise ValueError(f"Perturbation collection {collection} present, but " f"missing perturbation variable {name}") diff --git a/flax/linen/normalization.py b/flax/linen/normalization.py index 64ca0da4..340b5d03 100644 --- a/flax/linen/normalization.py +++ b/flax/linen/normalization.py @@ -160,6 +160,7 @@ def _normalize( use_scale: bool, bias_init: Initializer, scale_init: Initializer, + force_float32_reductions: bool = True ): """Normalizes the input of a normalization layer and optionally applies a learned scale and bias. @@ -179,6 +180,9 @@ def _normalize( use_scale: If true, scale the output. bias_init: Initialization function for the bias term. scale_init: Initialization function for the scaling function. 
+ force_float32_reductions: If false, the scale and bias parameters use the + param_dtype. Otherwise, they will have at least float32 precision due to + the mean and var being promoted to float32. Returns: The normalized input. @@ -200,6 +204,8 @@ def _normalize( scale = mdl.param( 'scale', scale_init, reduced_feature_shape, param_dtype ).reshape(feature_shape) + if not force_float32_reductions: + scale = jnp.asarray(scale, param_dtype) mul *= scale args.append(scale) y *= mul @@ -207,6 +213,8 @@ def _normalize( bias = mdl.param( 'bias', bias_init, reduced_feature_shape, param_dtype ).reshape(feature_shape) + if not force_float32_reductions: + bias = jnp.asarray(bias, param_dtype) y += bias args.append(bias) dtype = dtypes.canonicalize_dtype(*args, dtype=dtype) @@ -343,17 +351,35 @@ def __call__( feature_shape = [x.shape[ax] for ax in feature_axes] ra_mean = self.variable( - 'batch_stats', - 'mean', - lambda s: jnp.zeros(s, jnp.float32), - feature_shape, + 'batch_stats', + 'mean', + lambda s: jnp.zeros( + s, + jnp.float32 if self.force_float32_reductions else self.param_dtype, + ), + feature_shape, ) ra_var = self.variable( - 'batch_stats', 'var', lambda s: jnp.ones(s, jnp.float32), feature_shape + 'batch_stats', + 'var', + lambda s: jnp.ones( + s, + jnp.float32 if self.force_float32_reductions else self.param_dtype, + ), + feature_shape, ) if use_running_average: - mean, var = ra_mean.value, ra_var.value + mean = ( + ra_mean.value + if self.force_float32_reductions + else jnp.asarray(ra_mean.value, self.param_dtype) + ) + var = ( + ra_var.value + if self.force_float32_reductions + else jnp.asarray(ra_var.value, self.param_dtype) + ) else: mean, var = _compute_stats( x, @@ -386,6 +412,7 @@ def __call__( self.use_scale, self.bias_init, self.scale_init, + self.force_float32_reductions, ) @@ -502,6 +529,7 @@ def __call__(self, x, *, mask: jax.Array | None = None): self.use_scale, self.bias_init, self.scale_init, + self.force_float32_reductions, ) @@ -602,6 +630,7 @@ def __call__(self, x, *, mask: jax.Array | None = None): self.use_scale, initializers.zeros, self.scale_init, + self.force_float32_reductions, ) @@ -781,6 +810,7 @@ def __call__(self, x, *, mask: jax.Array | None = None): self.use_scale, self.bias_init, self.scale_init, + self.force_float32_reductions, ) @@ -905,6 +935,7 @@ def __call__(self, x, *, mask: jax.Array | None = None): self.use_scale, self.bias_init, self.scale_init, + self.force_float32_reductions, ) diff --git a/flax/linen/partitioning.py b/flax/linen/partitioning.py index 71045ba6..ab4e59c0 100644 --- a/flax/linen/partitioning.py +++ b/flax/linen/partitioning.py @@ -42,16 +42,14 @@ CollectionFilter as CollectionFilter, PRNGSequenceFilter as PRNGSequenceFilter, ) -from flax.linen.spmd import _axis_rules # pylint: disable=unused-import -from flax.linen.spmd import _AxisRules # pylint: disable=unused-import +from flax.core.spmd import logical_axis_rules as axis_rules # pylint: disable=unused-import +from flax.core.spmd import set_logical_axis_rules as set_axis_rules # pylint: disable=unused-import +from flax.core.spmd import get_logical_axis_rules as get_axis_rules # pylint: disable=unused-import from flax.linen.spmd import _is_logical_spec from flax.linen.spmd import _with_sharding_constraint # pylint: disable=unused-import -from flax.linen.spmd import get_logical_axis_rules as get_axis_rules # pylint: disable=unused-import -from flax.linen.spmd import logical_axis_rules as axis_rules # pylint: disable=unused-import from flax.linen.spmd import logical_to_mesh # pylint: 
disable=unused-import from flax.linen.spmd import logical_to_mesh_axes # pylint: disable=unused-import from flax.linen.spmd import RulesFallback -from flax.linen.spmd import set_logical_axis_rules as set_axis_rules # pylint: disable=unused-import from flax.linen.spmd import with_logical_constraint as with_sharding_constraint from flax.traverse_util import flatten_dict from flax.traverse_util import unflatten_dict diff --git a/flax/linen/spmd.py b/flax/linen/spmd.py index cd622bbd..5226218e 100644 --- a/flax/linen/spmd.py +++ b/flax/linen/spmd.py @@ -25,11 +25,9 @@ """ import collections -import contextlib import dataclasses import enum import functools -import threading from typing import Any from collections.abc import Callable, Sequence @@ -39,6 +37,9 @@ from flax import struct from flax.core import meta +from flax.core.spmd import ( + get_logical_axis_rules, +) from flax.typing import ( Array, LogicalNames, @@ -49,42 +50,6 @@ ) -# Dynamic Axis Mapping Context -# ------------------------------------------------------------------------------ - - -@dataclasses.dataclass -class _AxisRules(threading.local): - """Dynamic logical axis to mesh axis binding context.""" - - rules: LogicalRules = () - - -# Global axis binding context. -_axis_rules = _AxisRules() - - -def set_logical_axis_rules(rules: LogicalRules): - """Sets the global logical axis to mesh axis binding.""" - _axis_rules.rules = rules - - -def get_logical_axis_rules() -> LogicalRules: - """Returns the global logical axis to mesh axis binding.""" - return _axis_rules.rules - - -@contextlib.contextmanager -def logical_axis_rules(rules: LogicalRules): - """Context manager for setting the logical to mesh axis bindings.""" - old_rules = _axis_rules.rules - try: - _axis_rules.rules = rules - yield - finally: - _axis_rules.rules = old_rules - - class _UnassignedAxis: """Sentinel class for unassigned logical axis name.""" @@ -115,7 +80,7 @@ def _logical_to_mesh_axes( if array_dim_names is None: return None if rules is None: - rules = _axis_rules.rules + rules = get_logical_axis_rules() axis_name_counts = collections.Counter(array_dim_names) dups = tuple( k for k, v in axis_name_counts.items() if v > 1 and k is not None @@ -292,7 +257,7 @@ def with_logical_constraint( """Version of jit's with_sharding_constraint that uses logical axis names.""" # If no axis binding is set, this is a no-op. if rules is None: - rules = _axis_rules.rules + rules = get_logical_axis_rules() if not rules or logical_axis_resources is None: return x # Translate logical names to mesh assignments. diff --git a/flax/linen/transforms.py b/flax/linen/transforms.py index 06ffd519..80c44f99 100644 --- a/flax/linen/transforms.py +++ b/flax/linen/transforms.py @@ -23,6 +23,8 @@ function that takes a ``Module`` instance as its first argument. 
""" +from collections.abc import Callable, Iterable, Mapping, Sequence +import contextlib import dataclasses import functools import inspect @@ -32,7 +34,6 @@ Union, ) import weakref -from collections.abc import Callable, Iterable, Mapping, Sequence from flax import core from flax import errors, struct, traceback_util @@ -41,6 +42,7 @@ from flax.core.frozen_dict import FrozenDict from flax.core.scope import ( CollectionFilter, + LazyRng, PRNGSequenceFilter, ) from flax.ids import FlaxId @@ -579,60 +581,82 @@ def wrapped_fn(self: Module, *args, **kwargs): nonlocal trafo_fn state = self._state.export() - # make a scope-function to transform - def core_fn( - prewrapped_fn, - class_fn, - scopes, - module_hash, - *args, - **kwargs, - ): - # self = hash_key.obj - self: Module = module_hash.module - if not multi_scope: - scopes = [scopes] - cloned, args, kwargs = set_module_scopes(self, args, kwargs, scopes) - object.__setattr__(cloned, '_state', state.export()) - res = prewrapped_fn(cloned, *args, **kwargs) - self._state.reimport(cloned._state) - _test_transformed_return_values(res, getattr(class_fn, '__name__', None)) - return res - - core_fns = [ - functools.wraps(class_fn)( - functools.partial(core_fn, prewrapped_fn, class_fn) + # increment rng counters for all rngs in scope + with fork_rngs(self): + # make a scope-function to transform + def core_fn( + prewrapped_fn, + class_fn, + scopes, + module_hash, + *args, + **kwargs, + ): + # self = hash_key.obj + self: Module = module_hash.module + if not multi_scope: + scopes = [scopes] + cloned, args, kwargs = set_module_scopes(self, args, kwargs, scopes) + object.__setattr__(cloned, '_state', state.export()) + res = prewrapped_fn(cloned, *args, **kwargs) + self._state.reimport(cloned._state) + _test_transformed_return_values( + res, getattr(class_fn, '__name__', None) ) - for prewrapped_fn, class_fn in zip(prewrapped_fns, class_fns) - ] + return res - # here we apply the given lifting transform to the scope-ingesting fn - if trafo_fn is None: - trafo_fn = transform(*core_fns, **trafo_kwargs) + core_fns = [ + functools.wraps(class_fn)( + functools.partial(core_fn, prewrapped_fn, class_fn) + ) + for prewrapped_fn, class_fn in zip(prewrapped_fns, class_fns) + ] - module_scopes, args, kwargs = get_module_scopes(self, args, kwargs) + # here we apply the given lifting transform to the scope-ingesting fn + if trafo_fn is None: + trafo_fn = transform(*core_fns, **trafo_kwargs) - if not multi_scope: - if len(module_scopes) != 1: - # TODO(levskaya): transforms like jvp & vjp have args that follow the - # pytree structure of scopes. The user doesn't explicitly control shared - # modules passed as arguments to methods or as attributes to Module - # constructors. Therefore, there is no obvious API for specifying - # arguments per lifted Module. - raise NotImplementedError( - 'This transform does not yet support' - ' Modules that include other Modules passed as arguments.' - ) - module_scopes = module_scopes[0] + module_scopes, args, kwargs = get_module_scopes(self, args, kwargs) - # get a hashable proxy object for the Module - hash_key = _HashableProxy.from_module(self) + if not multi_scope: + if len(module_scopes) != 1: + # TODO(levskaya): transforms like jvp & vjp have args that follow the + # pytree structure of scopes. The user doesn't explicitly control shared + # modules passed as arguments to methods or as attributes to Module + # constructors. Therefore, there is no obvious API for specifying + # arguments per lifted Module. 
+ raise NotImplementedError( + 'This transform does not yet support' + ' Modules that include other Modules passed as arguments.' + ) + module_scopes = module_scopes[0] + + # get a hashable proxy object for the Module + hash_key = _HashableProxy.from_module(self) - return trafo_fn(module_scopes, hash_key, *args, **kwargs) + return trafo_fn(module_scopes, hash_key, *args, **kwargs) return wrapped_fn +@contextlib.contextmanager +def fork_rngs(module: Module): + """Context manager to fork rngs in a module.""" + if module.scope is None: + yield + return + + current_rngs = module.scope.rngs.copy() + module.scope.rngs = { + name: LazyRng.create(module.make_rng(name)) for name in current_rngs + } + + try: + yield + finally: + module.scope.rngs = current_rngs + + def module_class_lift_transform_cached( transform, module_class, methods=None, **trafo_kwargs ): @@ -674,36 +698,39 @@ def create_trans_fn(fn_name, fn_trafo_args): # we need to create a scope-function from our class for the given method @functools.wraps(fn) def wrapped_fn(self: Module, *args, **kwargs): + assert self.scope is not None nonlocal trafo_fn state = self._state.export() - # make a scope-function to transform - def core_fn(scopes, module_hash, *args, **kwargs): - self: Module = module_hash.module - # make a clone of self using its arguments - attrs = { - f.name: getattr(self, f.name) - for f in dataclasses.fields(self) - if f.name != 'parent' and f.init - } - # we reference module_class, not self.__class__ to avoid infinite loop - cloned = module_class(parent=None, **attrs) - cloned, args, kwargs = set_module_scopes(cloned, args, kwargs, scopes) - object.__setattr__(cloned, '_state', state.export()) - res = fn(cloned, *args, **kwargs) - self._state.reimport(cloned._state) - _test_transformed_return_values(res, fn_name) - return res - - # here we apply the given lifting transform to the scope-ingesting fn - trafo_fn = trafo_fn or transform(core_fn, *trafo_args, **trafo_kwargs) - module_scopes, args, kwargs = get_module_scopes(self, args, kwargs) - - # get a hash for the Module by using its repr as a proxy - hash_key = _HashableProxy.from_module(self) - - ret = trafo_fn(module_scopes, hash_key, *args, **kwargs) - return ret + # increment rng counters for all rngs in scope + with fork_rngs(self): + # make a scope-function to transform + def core_fn(scopes, module_hash, *args, **kwargs): + self: Module = module_hash.module + # make a clone of self using its arguments + attrs = { + f.name: getattr(self, f.name) + for f in dataclasses.fields(self) + if f.name != 'parent' and f.init + } + # we reference module_class, not self.__class__ to avoid infinite loop + cloned = module_class(parent=None, **attrs) + cloned, args, kwargs = set_module_scopes(cloned, args, kwargs, scopes) + object.__setattr__(cloned, '_state', state.export()) + res = fn(cloned, *args, **kwargs) + self._state.reimport(cloned._state) + _test_transformed_return_values(res, fn_name) + return res + + # here we apply the given lifting transform to the scope-ingesting fn + trafo_fn = trafo_fn or transform(core_fn, *trafo_args, **trafo_kwargs) + module_scopes, args, kwargs = get_module_scopes(self, args, kwargs) + + # get a hash for the Module by using its repr as a proxy + hash_key = _HashableProxy.from_module(self) + + ret = trafo_fn(module_scopes, hash_key, *args, **kwargs) + return ret return wrapped_fn @@ -2140,3 +2167,16 @@ def remove_fn(axis): mutable=True, ) return target + + +def fold_rngs( + target: Target, + variables: CollectionFilter = True, + rngs: 
PRNGSequenceFilter = True, +) -> Target: + return lift_transform_cached( + lift.fold_rngs, + target, + variables=variables, + rngs=rngs, + ) diff --git a/flax/nnx/README.md b/flax/nnx/README.md index 7a39eb13..14829c0f 100644 --- a/flax/nnx/README.md +++ b/flax/nnx/README.md @@ -64,10 +64,10 @@ pip install git+https://github.com/google/flax.git ### Examples -* [LM1B](https://github.com/google/flax/tree/main/flax/nnx/examples/lm1b): A language model trained on the 1 Billion Word Benchmark dataset. +* [LM1B](https://github.com/google/flax/tree/main/examples/lm1b_nnx): A language model trained on the 1 Billion Word Benchmark dataset. #### Toy Examples -* [Basic Example](https://github.com/google/flax/tree/main/flax/nnx/examples/toy_examples/02_lifted_transforms.py): Shows how to train a simple model using NNX. -* [Using the Functional API](https://github.com/google/flax/tree/main/flax/nnx/examples/toy_examples/01_functional_api.py): Shows how to train a simple model using the functional API. -* [Training a VAE](https://github.com/google/flax/tree/main/flax/nnx/examples/toy_examples/05_vae.py): Shows how to train a VAE on the binarized MNIST dataset. -* [Scan over layers](https://github.com/google/flax/tree/main/flax/nnx/examples/toy_examples/06_scan_over_layers.py): An contrived example that implements scan over layers with dropout and a share BatcNorm layer to showcase how lifted transforms can be implemented. It uses the functional API along with `jax.vmap` and `jax.lax.scan`. +* [Basic Example](https://github.com/google/flax/tree/main/examples/nnx_toy_examples/02_lifted_transforms.py): Shows how to train a simple model using NNX. +* [Using the Functional API](https://github.com/google/flax/tree/main/examples/nnx_toy_examples/01_functional_api.py): Shows how to train a simple model using the functional API. +* [Training a VAE](https://github.com/google/flax/tree/main/examples/nnx_toy_examples/05_vae.py): Shows how to train a VAE on the binarized MNIST dataset. +* [Scan over layers](https://github.com/google/flax/tree/main/examples/nnx_toy_examples/06_scan_over_layers.py): A contrived example that implements scan over layers with dropout and a shared BatchNorm layer to showcase how lifted transforms can be implemented. It uses the functional API along with `jax.vmap` and `jax.lax.scan`.
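The toy-example links above all exercise the same core NNX pattern: build a stateful module eagerly, then move between the object and its functional (graphdef, state) representation. The snippet below is a minimal sketch of that pattern for orientation only; it is not part of the patch, and it assumes the public `nnx.Linear`, `nnx.Rngs`, `nnx.split`, and `nnx.merge` APIs that also appear in docstrings added later in this changeset.

    # Minimal NNX sketch (illustration only, not part of this diff).
    from flax import nnx
    import jax.numpy as jnp

    model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))   # stateful, Pythonic module
    y = model(jnp.ones((1, 2)))                  # eager call, y.shape == (1, 3)

    # Functional API: split into a static graphdef plus a state pytree,
    # then merge them back into an equivalent module.
    graphdef, state = nnx.split(model)
    restored = nnx.merge(graphdef, state)
    assert restored.kernel.value.shape == (2, 3)

Because the graphdef is static and hashable while the state is a regular pytree of arrays, this split/merge round trip is what lets the examples combine NNX modules with `jax.jit`, `jax.vmap`, and `jax.lax.scan`.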
diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index 04554ea7..6a27b090 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -55,6 +55,7 @@ from .graph import split_context as split_context from .graph import MergeContext as MergeContext from .graph import merge_context as merge_context +from .graph import variables as variables from .nn import initializers as initializers from .nn.activations import celu as celu from .nn.activations import elu as elu @@ -85,6 +86,13 @@ from .nn.attention import dot_product_attention as dot_product_attention from .nn.attention import make_attention_mask as make_attention_mask from .nn.attention import make_causal_mask as make_causal_mask +from .nn.recurrent import RNNCellBase as RNNCellBase +from .nn.recurrent import LSTMCell as LSTMCell +from .nn.recurrent import GRUCell as GRUCell +from .nn.recurrent import OptimizedLSTMCell as OptimizedLSTMCell +from .nn.recurrent import SimpleCell as SimpleCell +from .nn.recurrent import RNN as RNN +from .nn.recurrent import Bidirectional as Bidirectional from .nn.linear import Conv as Conv from .nn.linear import ConvTranspose as ConvTranspose from .nn.linear import Embed as Embed @@ -116,7 +124,7 @@ from .spmd import with_sharding_constraint as with_sharding_constraint from .statelib import State as State from .training import metrics as metrics -from .variables import ( +from .variablelib import ( Param as Param, ) # this needs to be imported before optimizer to prevent circular import @@ -142,15 +150,19 @@ from .transforms.iteration import pmap as pmap from .transforms.transforms import eval_shape as eval_shape from .transforms.transforms import cond as cond +from .transforms.transforms import switch as switch +from .transforms.transforms import checkify as checkify +from .transforms.iteration import while_loop as while_loop +from .transforms.iteration import fori_loop as fori_loop from .transforms.iteration import StateAxes as StateAxes -from .variables import A as A -from .variables import BatchStat as BatchStat -from .variables import Cache as Cache -from .variables import Intermediate as Intermediate -from .variables import Variable as Variable -from .variables import VariableState as VariableState -from .variables import VariableMetadata as VariableMetadata -from .variables import with_metadata as with_metadata +from .variablelib import A as A +from .variablelib import BatchStat as BatchStat +from .variablelib import Cache as Cache +from .variablelib import Intermediate as Intermediate +from .variablelib import Variable as Variable +from .variablelib import VariableState as VariableState +from .variablelib import VariableMetadata as VariableMetadata +from .variablelib import with_metadata as with_metadata from .visualization import display as display from .extract import to_tree as to_tree from .extract import from_tree as from_tree diff --git a/flax/nnx/bridge/__init__.py b/flax/nnx/bridge/__init__.py index 862ab45e..7ed1b46a 100644 --- a/flax/nnx/bridge/__init__.py +++ b/flax/nnx/bridge/__init__.py @@ -13,10 +13,6 @@ # limitations under the License. 
-from .module import ModuleMeta as ModuleMeta -from .module import Module as Module -from .module import Scope as Scope -from .module import compact as compact from .wrappers import functional as functional from .wrappers import Functional as Functional from .wrappers import ToNNX as ToNNX diff --git a/flax/nnx/bridge/module.py b/flax/nnx/bridge/module.py deleted file mode 100644 index 39f66d40..00000000 --- a/flax/nnx/bridge/module.py +++ /dev/null @@ -1,226 +0,0 @@ -# Copyright 2024 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from collections import defaultdict -import dataclasses -import functools -import threading -import typing as tp -import typing_extensions as tpe - -from flax.nnx import graph, rnglib -import flax.nnx.module as nnx_module -from flax.nnx.proxy_caller import ( - CallableProxy, - DelayedAccessor, -) -from flax.nnx.object import Object - -M = tp.TypeVar('M', bound='Module') -F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) - - -@dataclasses.dataclass -class CompactContext: - module: Module - type_counter: defaultdict[type, int] = dataclasses.field( - default_factory=lambda: defaultdict(int) - ) - - -@dataclasses.dataclass -class ModuleContext(threading.local): - parent_stack: list[tp.Optional[CompactContext]] = dataclasses.field( - default_factory=lambda: [None] - ) - - -MODULE_CONTEXT = ModuleContext() - - -@dataclasses.dataclass -class Scope(Object): - rngs: rnglib.Rngs - - -@tp.runtime_checkable -class _HasSetup(tp.Protocol): - def setup(self) -> None: ... 
- - -class ModuleMeta(nnx_module.ModuleMeta): - if not tp.TYPE_CHECKING: - - def __call__(cls, *args, **kwargs): - return _module_meta_call(cls, *args, **kwargs) - - -def _module_meta_call(cls: tp.Type[M], *args, **kwargs) -> M: - # compact behavior - parent_ctx = MODULE_CONTEXT.parent_stack[-1] - parent = None - module: M - - if parent_ctx is not None: - if 'parent' in kwargs: - parent = kwargs.pop('parent') - if parent is not None: - raise ValueError( - f"'parent' can only be set to None, got {type(parent).__name__}" - ) - name = None - else: - type_index = parent_ctx.type_counter[cls] - parent_ctx.type_counter[cls] += 1 - - # define the name - if 'name' in kwargs: - name = kwargs.pop('name') - if not isinstance(name, str): - raise ValueError(f"'name' must be a 'str', got {type(name).__name__}") - else: - name = f'{cls.__name__}_{type_index}' - - parent = parent_ctx.module - - if hasattr(parent, name): - module = getattr(parent, name) - return module - else: - name = None - - module = nnx_module.ModuleMeta.__call__(cls, *args, **kwargs) - module.scope = None - - if parent is not None: - assert name is not None - setattr(parent, name, module) - # adopt the parent scope - module.scope = parent.scope - - if dataclasses.is_dataclass(module): - if isinstance(module, _HasSetup): - module.setup() - - return module - - -class ModuleBase: - if tp.TYPE_CHECKING: - scope: Scope | None - - -@tpe.dataclass_transform(field_specifiers=(dataclasses.field,)) # type: ignore[not-supported-yet] -class Module(nnx_module.Module, ModuleBase, metaclass=ModuleMeta): - def _set_scope(self, scope: Scope | None): - """Recursively sets the scope for the Module and its children.""" - for _, value in graph.iter_graph(self): - if isinstance(value, Module): - value.scope = scope - - @property - def init(self: M) -> M: - """Calls a method in initialization mode. - - When a method is called using ``init``, the ``is_initializing`` method - will return ``True``. This is useful to implement Modules that support - lazy initialization. - - Example:: - - >>> from flax import nnx - >>> from flax.nnx import bridge as nnb - >>> import jax - >>> import jax.numpy as jnp - ... - >>> class Linear(nnb.Module): - ... def __init__(self, dout, rngs: nnx.Rngs): - ... self.dout = dout - ... self.rngs = rngs - ... - ... def __call__(self, x): - ... if self.is_initializing(): - ... din = x.shape[-1] - ... if not hasattr(self, 'w'): - ... key = self.rngs.params() - ... self.w = nnx.Param(jax.random.uniform(key, (din, self.dout))) - ... if not hasattr(self, 'b'): - ... self.b = nnx.Param(jnp.zeros((self.dout,))) - ... - ... return x @ self.w + self.b - ... - >>> linear = Linear(3, nnx.Rngs(0)) - >>> x = jnp.ones((5, 2)) - >>> y = linear.init(x) - >>> linear.w.value.shape - (2, 3) - >>> linear.b.value.shape - (3,) - >>> y.shape - (5, 3) - """ - - def _init_context(accessor: DelayedAccessor, *args, **kwargs): - for _, value in graph.iter_graph(self): - if isinstance(value, Object): - value._object__state._initializing = True - - method = accessor(self) - try: - out = method(*args, **kwargs) - finally: - for _, value in graph.iter_graph(self): - if isinstance(value, Object): - value._object__state._initializing = False - - return out - - return CallableProxy(_init_context) # type: ignore - - def is_initializing(self) -> bool: - """Returns whether the Module is initializing. - - ``is_initializing`` returns ``True`` if the Module is currently being run - under ``init``. 
- """ - - return self._object__state._initializing - - def __init_subclass__(cls, experimental_pytree: bool = False) -> None: - super().__init_subclass__(experimental_pytree=experimental_pytree) - - cls = dataclasses.dataclass(repr=False)(cls) - - -def compact(f: F) -> F: - @functools.wraps(f) - def compact_wrapper(self, *args, **kwargs): - if not isinstance(self, Module): - raise ValueError( - f"Expected 'self' to be a nnx.bridge.Module, got {type(self).__name__}" - ) - - MODULE_CONTEXT.parent_stack.append(CompactContext(self)) - - try: - return f(self, *args, **kwargs) - finally: - MODULE_CONTEXT.parent_stack.pop() - - return compact_wrapper # type: ignore - - -# register Module as a dataclass_transform diff --git a/flax/nnx/bridge/variables.py b/flax/nnx/bridge/variables.py index 3e799bf4..93531bb4 100644 --- a/flax/nnx/bridge/variables.py +++ b/flax/nnx/bridge/variables.py @@ -20,7 +20,7 @@ from flax.core import meta from flax.nnx import spmd from flax.nnx import traversals -from flax.nnx import variables as variableslib +from flax.nnx import variablelib as variableslib from flax.nnx.module import GraphDef import typing as tp diff --git a/flax/nnx/bridge/wrappers.py b/flax/nnx/bridge/wrappers.py index 19c468af..eed4ba2f 100644 --- a/flax/nnx/bridge/wrappers.py +++ b/flax/nnx/bridge/wrappers.py @@ -16,15 +16,16 @@ import typing as tp from typing import Any -from flax import nnx from flax import linen +from flax import nnx +from flax.core import FrozenDict from flax.core import meta from flax.nnx import graph from flax.nnx.bridge import variables as bv from flax.nnx.module import GraphDef, Module +from flax.nnx.object import Object from flax.nnx.rnglib import Rngs from flax.nnx.statelib import State -from flax.nnx.object import Object import jax from jax import tree_util as jtu @@ -220,7 +221,7 @@ class ToLinen(linen.Module): """ nnx_class: tp.Callable[..., Module] args: tp.Sequence = () - kwargs: tp.Mapping = dataclasses.field(default_factory=dict) + kwargs: tp.Mapping[str, tp.Any] = FrozenDict({}) skip_rng: bool = False metadata_type: tp.Type = bv.NNXMeta @@ -277,4 +278,4 @@ def _update_variables(self, module): def to_linen(nnx_class: tp.Callable[..., Module], *args, name: str | None = None, **kwargs): """Shortcut of `nnx.bridge.ToLinen` if user is not changing any of its default fields.""" - return ToLinen(nnx_class, args=args, kwargs=kwargs, name=name) \ No newline at end of file + return ToLinen(nnx_class, args=args, kwargs=FrozenDict(kwargs), name=name) diff --git a/flax/nnx/extract.py b/flax/nnx/extract.py index 948d8f9f..191a0c19 100644 --- a/flax/nnx/extract.py +++ b/flax/nnx/extract.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import abc import contextlib import dataclasses import threading @@ -23,7 +24,7 @@ from flax import struct from flax.nnx.object import Object from flax.typing import Missing, PathParts -from flax.nnx import graph +from flax.nnx import graph, variablelib A = tp.TypeVar('A') @@ -119,6 +120,14 @@ def _maybe_insert(x): _maybe_insert, pytree, is_leaf=lambda x: isinstance(x, ExtractionIndex) ) +class PrefixMapping(abc.ABC): + @abc.abstractmethod + def map_prefix( + self, + path: variablelib.PathParts, + variable: variablelib.Variable, + /, + ) -> tp.Any: ... 
def check_consistent_aliasing( node: tuple[tp.Any, ...], @@ -143,11 +152,16 @@ def check_consistent_aliasing( raise ValueError( f'Cannot extract graph node from different trace level, got {value!r}' ) - if value in node_prefixes: - paths_prefixes = node_prefixes[value] - paths_prefixes.append((path, prefix)) - else: - node_prefixes[value] = [(path, prefix)] + if isinstance(prefix, PrefixMapping): + variable_prefix = prefix.map_prefix(path, value) + else: + variable_prefix = prefix + + if value in node_prefixes: + paths_prefixes = node_prefixes[value] + paths_prefixes.append((path, variable_prefix)) + else: + node_prefixes[value] = [(path, variable_prefix)] # check for inconsistent aliasing node_msgs = [] diff --git a/flax/nnx/filterlib.py b/flax/nnx/filterlib.py index 2e4de1a1..63ed371b 100644 --- a/flax/nnx/filterlib.py +++ b/flax/nnx/filterlib.py @@ -54,13 +54,32 @@ def to_predicate(filter: Filter) -> Predicate: else: raise TypeError(f'Invalid collection filter: {filter:!r}. ') +def filters_to_predicates(filters: tuple[Filter, ...]) -> tuple[Predicate, ...]: + for i, filter_ in enumerate(filters): + if filter_ in (..., True) and i != len(filters) - 1: + remaining_filters = filters[i + 1 :] + if not all(f in (..., True) for f in remaining_filters): + raise ValueError( + '`...` or `True` can only be used as the last filters, ' + f'got {filter_} it at index {i}.' + ) + return tuple(map(to_predicate, filters)) + + +class HasTag(tp.Protocol): + tag: str + + +def _has_tag(x: tp.Any) -> tp.TypeGuard[HasTag]: + return hasattr(x, 'tag') + @dataclasses.dataclass(frozen=True) class WithTag: tag: str def __call__(self, path: PathParts, x: tp.Any): - return hasattr(x, 'tag') and x.tag == self.tag + return _has_tag(x) and x.tag == self.tag def __repr__(self): return f'WithTag({self.tag!r})' @@ -77,6 +96,24 @@ def __repr__(self): return f'PathContains({self.key!r})' +class PathIn: + def __init__(self, *paths: PathParts): + self.paths = frozenset(paths) + + def __call__(self, path: PathParts, x: tp.Any): + return path in self.paths + + def __repr__(self): + paths_repr = ','.join(map(repr, self.paths)) + return f'PathIn({paths_repr})' + + def __eq__(self, other): + return isinstance(other, PathIn) and self.paths == other.paths + + def __hash__(self): + return hash(self.paths) + + @dataclasses.dataclass(frozen=True) class OfType: type: type diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index 65eccfa9..2339f5c1 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -24,7 +24,6 @@ import numpy as np import typing_extensions as tpe -from flax.core.frozen_dict import FrozenDict from flax.nnx import filterlib, reprlib from flax.nnx.proxy_caller import ( ApplyCaller, @@ -32,8 +31,9 @@ DelayedAccessor, ) from flax.nnx.statelib import FlatState, State -from flax.nnx.variables import Variable, VariableState -from flax.typing import Key, PathParts +from flax.nnx import variablelib +from flax.nnx.variablelib import Variable, VariableState +from flax.typing import Key, PathParts, is_key_like A = tp.TypeVar('A') B = tp.TypeVar('B') @@ -42,6 +42,7 @@ HA = tp.TypeVar('HA', bound=tp.Hashable) HB = tp.TypeVar('HB', bound=tp.Hashable) +KeyT = tp.TypeVar('KeyT', bound=Key) Index = int Names = tp.Sequence[int] @@ -93,9 +94,9 @@ def __str__(self) -> str: return repr(self) -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, slots=True) class NodeImplBase(tp.Generic[Node, Leaf, AuxData]): - type: type + type: type[Node] flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[Key, Leaf]], AuxData]] def 
node_dict(self, node: Node) -> dict[Key, Leaf]: @@ -103,7 +104,7 @@ def node_dict(self, node: Node) -> dict[Key, Leaf]: return dict(nodes) -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, slots=True) class GraphNodeImpl(NodeImplBase[Node, Leaf, AuxData]): set_key: tp.Callable[[Node, Key, Leaf], None] pop_key: tp.Callable[[Node, Key], Leaf] @@ -115,7 +116,7 @@ def init(self, node: Node, items: tuple[tuple[Key, Leaf], ...]): self.set_key(node, key, value) -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, slots=True) class PytreeNodeImpl(NodeImplBase[Node, Leaf, AuxData]): unflatten: tp.Callable[[tuple[tuple[Key, Leaf], ...], AuxData], Node] @@ -125,7 +126,8 @@ class PytreeNodeImpl(NodeImplBase[Node, Leaf, AuxData]): ] -_node_impl_for_type: dict[type, NodeImpl[tp.Any, tp.Any, tp.Any]] = {} +GRAPH_REGISTRY: dict[type, NodeImpl[tp.Any, tp.Any, tp.Any]] = {} +PYTREE_REGISTRY: dict[type, PytreeNodeImpl[tp.Any, tp.Any, tp.Any]] = {} def register_graph_node_type( @@ -136,7 +138,10 @@ def register_graph_node_type( create_empty: tp.Callable[[AuxData], Node], clear: tp.Callable[[Node], None], ): - _node_impl_for_type[type] = GraphNodeImpl( + if type in GRAPH_REGISTRY: + raise ValueError(f'Node type {type} is already registered.') + + GRAPH_REGISTRY[type] = GraphNodeImpl( type=type, flatten=flatten, set_key=set_key, @@ -145,19 +150,30 @@ def register_graph_node_type( clear=clear, ) +def register_pytree_node_type( + type: type, + flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[Key, Leaf]], AuxData]], + unflatten: tp.Callable[[tuple[tuple[Key, Leaf], ...], AuxData], Node], +): + if type in PYTREE_REGISTRY: + raise ValueError(f'Node type {type} is already registered.') + + PYTREE_REGISTRY[type] = PytreeNodeImpl( + type=type, flatten=flatten, unflatten=unflatten + ) def is_node(x: tp.Any) -> bool: - if type(x) in _node_impl_for_type: + if type(x) in GRAPH_REGISTRY: return True return is_pytree_node(x) def is_graph_node(x: tp.Any) -> bool: - return type(x) in _node_impl_for_type + return type(x) in GRAPH_REGISTRY def is_node_type(x: type[tp.Any]) -> bool: - return x in _node_impl_for_type or x is PytreeType + return x in GRAPH_REGISTRY or x in PYTREE_REGISTRY or x is GenericPytree def get_node_impl(x: Node) -> NodeImpl[Node, tp.Any, tp.Any]: @@ -166,22 +182,26 @@ def get_node_impl(x: Node) -> NodeImpl[Node, tp.Any, tp.Any]: node_type = type(x) - if node_type not in _node_impl_for_type: - if is_pytree_node(x): - return PYTREE_NODE_IMPL - else: - raise ValueError(f'Unknown node type: {x}') - - return _node_impl_for_type[node_type] + if node_type in GRAPH_REGISTRY: + return GRAPH_REGISTRY[node_type] + elif node_type in PYTREE_REGISTRY: + return PYTREE_REGISTRY[node_type] + elif is_pytree_node(x): + return PYTREE_NODE_IMPL # type: ignore + else: + raise ValueError(f'Unknown node type: {x}') def get_node_impl_for_type(x: type[Node]) -> NodeImpl[Node, tp.Any, tp.Any]: - if x is PytreeType: - return PYTREE_NODE_IMPL - return _node_impl_for_type[x] + if x is GenericPytree: + return PYTREE_NODE_IMPL # type: ignore + elif x in PYTREE_REGISTRY: + return PYTREE_REGISTRY[x] + else: + return GRAPH_REGISTRY[x] -class _HashableMapping(tp.Mapping[HA, HB], tp.Hashable): +class HashableMapping(tp.Mapping[HA, HB], tp.Hashable): def __init__(self, mapping: tp.Mapping[HA, HB] | tp.Iterable[tuple[HA, HB]]): self._mapping = dict(mapping) @@ -202,7 +222,7 @@ def __hash__(self) -> int: def __eq__(self, other: tp.Any) -> bool: return ( - isinstance(other, _HashableMapping) and 
self._mapping == other._mapping + isinstance(other, HashableMapping) and self._mapping == other._mapping ) def __repr__(self) -> str: @@ -210,6 +230,10 @@ def __repr__(self) -> str: class GraphDef(tp.Generic[Node]): + """A class that represents all the static, stateless, and Pythonic parts of a Flax + :class:`Module`. A ``GraphDef`` can be generated by either calling :func:`split` or + :func:`graphdef` on the :class:`Module`.""" + type: type[Node] index: int @@ -224,10 +248,9 @@ def __nnx_repr__(self): yield reprlib.Attr('type', self.type.__name__) yield reprlib.Attr('index', self.index) - def __penzai_repr__(self, path, subtree_renderer): - from penzai.treescope import repr_lib as pz_repr_lib # type: ignore[import-not-found,import-untyped] - - return pz_repr_lib.render_object_constructor( + def __treescope_repr__(self, path, subtree_renderer): + import treescope # type: ignore[import-not-found,import-untyped] + return treescope.repr_lib.render_object_constructor( object_type=type(self), attributes={'type': self.type, 'index': self.index}, path=path, @@ -237,8 +260,55 @@ def __penzai_repr__(self, path, subtree_renderer): jax.tree_util.register_static(NodeRef) - @dataclasses.dataclass(frozen=True, repr=False) +class VariableDef(reprlib.Representable): + type: type[Variable] + index: int + metadata: HashableMapping[str, tp.Any] + + def __nnx_repr__(self): + yield reprlib.Object(type=type(self)) + yield reprlib.Attr('type', self.type.__name__) + yield reprlib.Attr('index', self.index) + yield reprlib.Attr('metadata', reprlib.PrettyMapping(self.metadata)) + + def __treescope_repr__(self, path, subtree_renderer): + import treescope # type: ignore[import-not-found,import-untyped] + + return treescope.repr_lib.render_object_constructor( + object_type=type(self), + attributes={ + 'type': self.type, + 'index': self.index, + 'metadata': self.metadata, + }, + path=path, + subtree_renderer=subtree_renderer, + ) + + +jax.tree_util.register_static(VariableDef) + + +@dataclasses.dataclass(frozen=True, slots=True) +class SubGraphAttribute: + key: Key + value: NodeDef[tp.Any] | NodeRef[tp.Any] + + +@dataclasses.dataclass(frozen=True, slots=True) +class StaticAttribute: + key: Key + value: tp.Any + + +@dataclasses.dataclass(frozen=True, slots=True) +class LeafAttribute: + key: Key + value: VariableDef | NodeRef[tp.Any] + + +@dataclasses.dataclass(frozen=True, repr=False, slots=True) class NodeDef(GraphDef[Node], reprlib.Representable): """A dataclass that denotes the tree structure of a :class:`Module`. A ``GraphDef`` can be generated by either @@ -246,22 +316,16 @@ class NodeDef(GraphDef[Node], reprlib.Representable): type: tp.Type[Node] index: int - attributes: tuple[Key, ...] - subgraphs: _HashableMapping[Key, NodeDef[tp.Any] | NodeRef[tp.Any]] - static_fields: _HashableMapping[Key, tp.Any] - leaves: _HashableMapping[Key, NodeRef[tp.Any] | None] + attributes: tuple[SubGraphAttribute | StaticAttribute | LeafAttribute, ...] 
metadata: tp.Any - index_mapping: FrozenDict[Index, Index] | None + index_mapping: HashableMapping[Index, Index] | None @classmethod def create( cls, type: tp.Type[Node], index: int, - attributes: tuple[Key, ...], - subgraphs: tp.Iterable[tuple[Key, NodeDef[tp.Any] | NodeRef[tp.Any]]], - static_fields: tp.Iterable[tuple[Key, tp.Any]], - leaves: tp.Iterable[tuple[Key, NodeRef[tp.Any] | None]], + attributes: tuple[SubGraphAttribute | StaticAttribute | LeafAttribute, ...], metadata: tp.Any, index_mapping: tp.Mapping[Index, Index] | None, ): @@ -269,11 +333,8 @@ def create( type=type, index=index, attributes=attributes, - subgraphs=_HashableMapping(subgraphs), - static_fields=_HashableMapping(static_fields), - leaves=_HashableMapping(leaves), metadata=metadata, - index_mapping=FrozenDict(index_mapping) + index_mapping=HashableMapping(index_mapping) if index_mapping is not None else None, ) @@ -283,12 +344,7 @@ def __nnx_repr__(self): yield reprlib.Attr('type', self.type.__name__) yield reprlib.Attr('index', self.index) - yield reprlib.Attr('attributes', self.attributes) - yield reprlib.Attr('subgraphs', reprlib.PrettyMapping(self.subgraphs)) - yield reprlib.Attr( - 'static_fields', reprlib.PrettyMapping(self.static_fields) - ) - yield reprlib.Attr('leaves', reprlib.PrettyMapping(self.leaves)) + yield reprlib.Attr('attributes', reprlib.PrettySequence(self.attributes)) yield reprlib.Attr('metadata', self.metadata) yield reprlib.Attr( 'index_mapping', @@ -300,18 +356,15 @@ def __nnx_repr__(self): def __treescope_repr__(self, path, subtree_renderer): import treescope # type: ignore[import-not-found,import-untyped] return treescope.repr_lib.render_object_constructor( - object_type=type(self), - attributes={ - 'type': self.type, - 'index': self.index, - 'attributes': self.attributes, - 'subgraphs': dict(self.subgraphs), - 'static_fields': dict(self.static_fields), - 'leaves': dict(self.leaves), - 'metadata': self.metadata, - }, - path=path, - subtree_renderer=subtree_renderer, + object_type=type(self), + attributes={ + 'type': self.type, + 'index': self.index, + 'attributes': self.attributes, + 'metadata': self.metadata, + }, + path=path, + subtree_renderer=subtree_renderer, ) def apply( @@ -374,40 +427,39 @@ def _graph_flatten( else: index = -1 - subgraphs: list[tuple[Key, NodeDef[Node] | NodeRef]] = [] - static_fields: list[tuple[Key, tp.Any]] = [] - leaves: list[tuple[Key, NodeRef | None]] = [] + attributes: list[SubGraphAttribute | StaticAttribute | LeafAttribute] = [] values, metadata = node_impl.flatten(node) for key, value in values: if is_node(value): nodedef = _graph_flatten((*path, key), ref_index, flat_state, value) - subgraphs.append((key, nodedef)) + # subgraphs.append((key, nodedef)) + attributes.append(SubGraphAttribute(key, nodedef)) elif isinstance(value, Variable): if value in ref_index: - leaves.append((key, NodeRef(type(value), ref_index[value]))) + attributes.append( + LeafAttribute(key, NodeRef(type(value), ref_index[value])) + ) else: flat_state[(*path, key)] = value.to_state() variable_index = ref_index[value] = len(ref_index) - leaves.append((key, NodeRef(type(value), variable_index))) - elif is_state_leaf(value): - flat_state[(*path, key)] = value - leaves.append((key, None)) + variabledef = VariableDef( + type(value), variable_index, HashableMapping(value.get_metadata()) + ) + attributes.append(LeafAttribute(key, variabledef)) else: if isinstance(value, (jax.Array, np.ndarray)): path_str = '/'.join(map(str, (*path, key))) raise ValueError( f'Arrays leaves are not 
supported, at {path_str!r}: {value}' ) - static_fields.append((key, value)) + # static_fields.append((key, value)) + attributes.append(StaticAttribute(key, value)) nodedef = NodeDef.create( type=node_impl.type, index=index, - attributes=tuple(key for key, _ in values), - subgraphs=subgraphs, - static_fields=static_fields, - leaves=leaves, + attributes=tuple(attributes), metadata=metadata, index_mapping=None, ) @@ -416,7 +468,7 @@ def _graph_flatten( def unflatten( graphdef: GraphDef[Node], - state: GraphState, + state: tp.Mapping[KeyT, StateLeaf | tp.Mapping[Key, tp.Any]], /, *, index_ref: dict[Index, tp.Any] | None = None, @@ -437,17 +489,17 @@ def unflatten( existing graph nodes are mutated to have the new content/topology specified by the graphdef. """ + if isinstance(state, State): + state = state.raw_mapping # type: ignore if index_ref is None: index_ref = {} assert isinstance(graphdef, (NodeDef, NodeRef)) - node = _graph_unflatten( - graphdef, state.raw_mapping, index_ref, index_ref_cache - ) + node = _graph_unflatten(graphdef, state, index_ref, index_ref_cache) return node def _graph_unflatten( nodedef: NodeDef[Node] | NodeRef[Node], - state: tp.Mapping[Key, StateLeaf | tp.Mapping[Key, tp.Any]], + state: tp.Mapping[KeyT, StateLeaf | tp.Mapping[Key, tp.Any]], index_ref: dict[Index, tp.Any], index_ref_cache: dict[Index, tp.Any] | None, ) -> Node: @@ -476,24 +528,22 @@ def _graph_unflatten( node_impl = get_node_impl_for_type(nodedef.type) def _get_children(): - children: dict[Key, StateLeaf | Node] = {} - - # NOTE: we could allw adding new StateLeafs here - if unkown_keys := set(state) - set(nodedef.attributes): - raise ValueError(f'Unknown keys: {unkown_keys}') + children: dict[Key, NodeLeaf | Node] = {} + state_keys: set = set(state.keys()) # for every key in attributes there are 6 possible cases: # - (2) the key can either be present in the state or not # - (3) the key can be a subgraph, a leaf, or a static attribute - for key in nodedef.attributes: + for attribute in nodedef.attributes: + key = attribute.key if key not in state: - # TODO(cgarcia): maybe we shouldn't support unflattening with missing keys? 
# if key is not present create an empty types - if key in nodedef.static_fields: - children[key] = nodedef.static_fields[key] - elif key in nodedef.subgraphs: + if type(attribute) is StaticAttribute: + children[key] = attribute.value + elif type(attribute) is SubGraphAttribute: # if the key is a subgraph we create an empty node - subgraphdef = nodedef.subgraphs[key] + subgraphdef = attribute.value + assert not isinstance(subgraphdef, VariableDef) if isinstance(subgraphdef, NodeRef): # subgraph exists, take it from the cache children[key] = index_ref[subgraphdef.index] @@ -506,11 +556,11 @@ def _get_children(): children[key] = _graph_unflatten( subgraphdef, substate, index_ref, index_ref_cache ) - elif key in nodedef.leaves: - noderef = nodedef.leaves[key] - if noderef is not None and noderef.index in index_ref: + elif type(attribute) is LeafAttribute: + variabledef = attribute.value + if variabledef.index in index_ref: # variable exists, take it from the cache - children[key] = index_ref[noderef.index] + children[key] = index_ref[variabledef.index] else: # key for a variable is missing, raise an error raise ValueError( @@ -520,19 +570,21 @@ def _get_children(): else: raise RuntimeError(f'Unknown static field: {key!r}') else: + state_keys.remove(key) value = state[key] - if key in nodedef.static_fields: + # if key in nodedef.static_fields: + if type(attribute) is StaticAttribute: raise ValueError( f'Got state for static field {key!r}, this is not supported.' ) - if key in nodedef.subgraphs: + elif type(attribute) is SubGraphAttribute: if is_state_leaf(value): raise ValueError( - f'Expected value of type {nodedef.subgraphs[key]} for ' + f'Expected value of type {attribute.value} for ' f'{key!r}, but got {value!r}' ) assert isinstance(value, dict) - subgraphdef = nodedef.subgraphs[key] + subgraphdef = attribute.value if isinstance(subgraphdef, NodeRef): children[key] = index_ref[subgraphdef.index] @@ -541,45 +593,48 @@ def _get_children(): subgraphdef, value, index_ref, index_ref_cache ) - elif key in nodedef.leaves: - if not is_state_leaf(value): - raise ValueError(f'Expected a leaf for {key!r}, but got {value!r}') + elif type(attribute) is LeafAttribute: + variabledef = attribute.value - noderef = nodedef.leaves[key] - - if noderef is None: - # if the leaf is None, it means that the value was originally - # a non-VariableState leaf, however we allow providing a - # VariableState presumbly created by modifying the State - if isinstance(value, VariableState): - value = value.to_variable() - children[key] = value - elif noderef.index in index_ref: + if variabledef.index in index_ref: # add an existing variable - children[key] = index_ref[noderef.index] + assert isinstance(variabledef, NodeRef) + children[key] = index_ref[variabledef.index] else: # its a unseen variable, create a new one - if not isinstance(value, VariableState): - raise ValueError( - f'Expected a Variable type for {key!r}, but got {type(value)}.' - ) + assert isinstance(variabledef, VariableDef) # when idxmap is present, check if the Varable exists there # and update existing variables if it does - if index_ref_cache is not None and noderef.index in index_ref_cache: - variable = index_ref_cache[noderef.index] + if ( + index_ref_cache is not None + and variabledef.index in index_ref_cache + ): + # if variable exists, update it + variable = index_ref_cache[variabledef.index] if not isinstance(variable, Variable): raise ValueError( f'Expected a Variable type for {key!r}, but got {type(variable)}.' 
) - variable.update_from_state(value) + if isinstance(value, VariableState): + variable.update_from_state(value) + else: + variable.raw_value = value else: # if it doesn't, create a new variable - assert isinstance(value, VariableState) - variable = value.to_variable() + if isinstance(value, VariableState): + variable = value.to_variable() + else: + variable = variabledef.type.from_metadata( + value, variabledef.metadata + ) children[key] = variable - index_ref[noderef.index] = variable + index_ref[variabledef.index] = variable else: raise RuntimeError(f'Unknown key: {key!r}, this is a bug.') + # NOTE: we could allw adding new StateLeafs here + if state_keys: + raise ValueError(f'Unknown keys: {state_keys}') + return children if isinstance(node_impl, GraphNodeImpl): @@ -672,7 +727,7 @@ def _graph_pop( pass -def _graph_update_dynamic(node: tp.Any, state: tp.Mapping[Key, tp.Any]): +def _graph_update_dynamic(node: tp.Any, state: tp.Mapping[KeyT, tp.Any]): if not is_node(node): raise RuntimeError(f'Unsupported type: {type(node)}') @@ -699,26 +754,19 @@ def _graph_update_dynamic(node: tp.Any, state: tp.Mapping[Key, tp.Any]): if is_state_leaf(value): raise ValueError(f'Expected a subgraph for {key!r}, but got: {value!r}') _graph_update_dynamic(current_value, value) - elif isinstance(value, VariableState): + else: # case 3: state leaf is being updated if not isinstance(current_value, Variable): raise ValueError( f'Trying to update a non-Variable attribute {key!r} with a Variable: ' f'{value!r}' ) - current_value.update_from_state(value) - elif is_state_leaf(value): - # case 4: state field is being updated - if isinstance(node_impl, PytreeNodeImpl): - raise ValueError( - f'Cannot set key {key!r} on immutable node of ' - f'type {type(node).__name__}' - ) - node_impl.set_key(node, key, value) - else: - raise ValueError( - f'Unsupported update type: {type(value)} for key {key!r}' - ) + if isinstance(value, VariableState): + # updated from VariableState + current_value.update_from_state(value) + else: + # updated from raw value + current_value.raw_value = value # -------------------------------------------------------- # UpdateContext @@ -768,7 +816,7 @@ def split( if ctx.index_ref is not None and isinstance(graphdef, NodeDef): index_to_index = compose_mapping(ctx.index_ref, self.ref_index) graphdef = dataclasses.replace( - graphdef, index_mapping=FrozenDict(index_to_index) + graphdef, index_mapping=HashableMapping(index_to_index) ) return graphdef, *states @@ -958,7 +1006,7 @@ def split( if self.index_ref is not None and isinstance(graphdef, NodeDef): index_to_index = compose_mapping(self.index_ref, ref_index) graphdef = dataclasses.replace( - graphdef, index_mapping=FrozenDict(index_to_index) + graphdef, index_mapping=HashableMapping(index_to_index) ) self.flatten_end(ref_index) @@ -1247,12 +1295,11 @@ def split( states = _split_state(state, filters) return graphdef, *states - def merge( graphdef: GraphDef[A], - state: GraphState, + state: tp.Mapping[KeyT, tp.Any], /, - *states: GraphState, + *states: tp.Mapping[KeyT, tp.Any], ) -> A: """The inverse of :func:`split`. @@ -1289,13 +1336,15 @@ def merge( Returns: The merged :class:`Module`. """ - state = GraphState.merge(state, *states) + state = State.merge(state, *states) node = unflatten(graphdef, state) return node -def update(node, state: State, /, *states: State) -> None: - """Update the given graph node with a new :class:`State` in-place. 
+def update( + node, state: tp.Mapping[KeyT, tp.Any], /, *states: tp.Mapping[KeyT, tp.Any] +) -> None: + """Update the given graph node with a new state(s) in-place. Example usage:: @@ -1321,19 +1370,71 @@ def update(node, state: State, /, *states: State) -> None: *states: Additional :class:`State` objects. """ if states: - state = GraphState.merge(state, *states) + state = State.merge(state, *states) + if isinstance(state, State): + state = state.raw_mapping + _graph_update_dynamic(node, state) - _graph_update_dynamic(node, state.raw_mapping) +def _variables_generator(node) -> tp.Iterable[tuple[PathParts, Variable]]: + for path, value in iter_graph(node): + if isinstance(value, Variable): + yield path, value @tp.overload -def state(node, /) -> GraphState: ... +def variables(node, /) -> State[Key, Variable]: ... +@tp.overload +def variables(node, first: filterlib.Filter, /) -> State[Key, Variable]: ... +@tp.overload +def variables( + node, + first: filterlib.Filter, + second: filterlib.Filter, + /, + *filters: filterlib.Filter, +) -> tuple[State[Key, Variable], ...]: ... +def variables( + node, + *filters: filterlib.Filter, +) -> tp.Union[State[Key, Variable], tuple[State[Key, Variable], ...]]: + """Similar to :func:`state` but returns the current :class:`Variable` objects instead + of new :class:`VariableState` instances. + Example:: -@tp.overload -def state(node, first: filterlib.Filter, /) -> GraphState: ... + >>> from flax import nnx + ... + >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + >>> params = nnx.variables(model, nnx.Param) + ... + >>> assert params['kernel'] is model.kernel + >>> assert params['bias'] is model.bias + Args: + node: A graph node object. + *filters: One or more :class:`Variable` objects to filter by. + Returns: + One or more :class:`State` mappings containing the :class:`Variable` objects. + """ + num_filters = len(filters) + if num_filters == 0: + filters = (..., ...) + else: + filters = (*filters, ...) + variables_iterable = _variables_generator(node) + flat_states = variablelib.split_flat_state( + variables_iterable, (*filters, ...) + ) + states = tuple(State.from_flat_path(flat_state) for flat_state in flat_states) + if num_filters < 2: + return states[0] + return states + +@tp.overload +def state(node, /) -> GraphState: ... +@tp.overload +def state(node, first: filterlib.Filter, /) -> GraphState: ... @tp.overload def state( node, @@ -1342,8 +1443,6 @@ def state( /, *filters: filterlib.Filter, ) -> tuple[GraphState, ...]: ... - - def state( node, *filters: filterlib.Filter, @@ -1675,11 +1774,23 @@ class Static(tp.Generic[A]): # --------------------------------------------------------- # Pytree # --------------------------------------------------------- -class PytreeType: ... +class GenericPytree: ... def is_pytree_node(x: tp.Any) -> bool: - return not jax.tree_util.all_leaves((x,)) + t = type(x) + if t in PYTREE_REGISTRY: + return True + elif t in GRAPH_REGISTRY: + return False + # known non-pytree types + elif isinstance(x, Variable): + return False + # knon pytree types + elif isinstance(x, (VariableState, State)): + return True + else: + return not jax.tree_util.all_leaves((x,)) def _key_path_to_key(key: tp.Any) -> Key: @@ -1688,7 +1799,7 @@ def _key_path_to_key(key: tp.Any) -> Key: elif isinstance( key, (jax.tree_util.DictKey, jax.tree_util.FlattenedIndexKey) ): - if not isinstance(key.key, Key): + if not is_key_like(key.key): raise ValueError( f'Invalid key: {key.key}. May be due to its type not being hashable or comparable.' 
) @@ -1716,7 +1827,33 @@ def _unflatten_pytree( PYTREE_NODE_IMPL = PytreeNodeImpl( - type=PytreeType, + type=GenericPytree, flatten=_flatten_pytree, unflatten=_unflatten_pytree, ) + +# common pytrees +# list +register_pytree_node_type( + list, + flatten=lambda x: (list(enumerate(x)), None), + unflatten=lambda nodes, _: [value for _, value in nodes], # type: ignore +) +# tuple +register_pytree_node_type( + tuple, + flatten=lambda x: (list(enumerate(x)), None), + unflatten=lambda nodes, _: tuple(value for _, value in nodes), # type: ignore +) +# dict +register_pytree_node_type( + dict, + flatten=lambda x: (sorted(x.items()), None), + unflatten=lambda nodes, _: {key: value for key, value in nodes}, # type: ignore +) +# None +register_pytree_node_type( + type(None), + flatten=lambda x: ([], None), + unflatten=lambda _, __: None, # type: ignore +) \ No newline at end of file diff --git a/flax/nnx/module.py b/flax/nnx/module.py index efada835..795bb9a0 100644 --- a/flax/nnx/module.py +++ b/flax/nnx/module.py @@ -23,7 +23,7 @@ filterlib, graph, ) -from flax.nnx import variables as variableslib +from flax.nnx import variablelib as variableslib from flax.nnx.graph import GraphDef from flax.nnx.object import Object, ObjectMeta from flax.nnx.graph import GraphState, StateLeaf diff --git a/flax/nnx/nn/attention.py b/flax/nnx/nn/attention.py index 38d598fc..185e0bd9 100644 --- a/flax/nnx/nn/attention.py +++ b/flax/nnx/nn/attention.py @@ -569,23 +569,25 @@ def __call__( def init_cache(self, input_shape: Shape, dtype: Dtype = jnp.float32): """Initializes cache for fast autoregressive decoding. When ``decode=True``, this method must be called first before performing - forward inference. + forward inference. When in decode mode, only one token must be passed + at a time. Example usage:: >>> from flax import nnx >>> import jax.numpy as jnp ... - >>> rngs = nnx.Rngs(42) + >>> batch_size = 5 + >>> embed_dim = 3 + >>> x = jnp.ones((batch_size, 1, embed_dim)) # single token ... - >>> x = jnp.ones((1, 3)) >>> model_nnx = nnx.MultiHeadAttention( ... num_heads=2, ... in_features=3, ... qkv_features=6, ... out_features=6, ... decode=True, - ... rngs=rngs, + ... rngs=nnx.Rngs(42), ... ) ... 
>>> # out_nnx = model_nnx(x) <-- throws an error because cache isn't initialized diff --git a/flax/nnx/nn/linear.py b/flax/nnx/nn/linear.py index 1a35058b..364b5dac 100644 --- a/flax/nnx/nn/linear.py +++ b/flax/nnx/nn/linear.py @@ -23,7 +23,7 @@ from flax.core.frozen_dict import FrozenDict from flax import nnx -from flax.nnx import rnglib, variables +from flax.nnx import rnglib, variablelib from flax.nnx.module import Module, first_from from flax.nnx.nn import dtypes, initializers from flax.typing import ( @@ -193,7 +193,7 @@ def kernel_init_wrap(rng, shape, dtype): ) flat_shape = jax.tree.map(int, flat_shape) kernel = self.kernel_init(rng, flat_shape, dtype) - if isinstance(kernel, variables.VariableMetadata): + if isinstance(kernel, variablelib.VariableMetadata): kernel.raw_value = jnp.reshape(kernel.raw_value, shape) else: kernel = jnp.reshape(kernel, shape) @@ -215,7 +215,7 @@ def kernel_init_wrap(rng, shape, dtype): def bias_init_wrap(rng, shape, dtype): flat_shape = (int(np.prod(shape)),) bias = self.bias_init(rng, flat_shape, dtype) - if isinstance(bias, variables.VariableMetadata): + if isinstance(bias, variablelib.VariableMetadata): bias.raw_value = jnp.reshape(bias.raw_value, shape) else: bias = jnp.reshape(bias, shape) @@ -370,6 +370,7 @@ def __call__(self, inputs: Array) -> Array: (((inputs.ndim - 1,), (0,)), ((), ())), precision=self.precision, ) + assert self.use_bias == (bias is not None) if bias is not None: y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) return y diff --git a/flax/nnx/nn/lora.py b/flax/nnx/nn/lora.py index 6fe5984e..dbba23fd 100644 --- a/flax/nnx/nn/lora.py +++ b/flax/nnx/nn/lora.py @@ -18,7 +18,7 @@ import jax import jax.numpy as jnp -from flax.nnx import rnglib, variables +from flax.nnx import rnglib, variablelib from flax.nnx.module import Module from flax.nnx.nn import initializers from flax.nnx.nn.linear import Linear @@ -32,7 +32,7 @@ default_kernel_init = initializers.lecun_normal() -class LoRAParam(variables.Param[A]): +class LoRAParam(variablelib.Param[A]): pass @@ -84,7 +84,7 @@ def __init__( dtype: tp.Optional[Dtype] = None, param_dtype: Dtype = jnp.float32, kernel_init: Initializer = default_kernel_init, - lora_param_type: tp.Type[variables.Variable] = LoRAParam, + lora_param_type: tp.Type[variablelib.Variable] = LoRAParam, rngs: rnglib.Rngs, ): self.in_features = in_features @@ -155,7 +155,7 @@ def __init__( lora_dtype: tp.Optional[Dtype] = None, lora_param_dtype: Dtype = jnp.float32, lora_kernel_init: Initializer = default_kernel_init, - lora_param_type: tp.Type[variables.Variable] = LoRAParam, + lora_param_type: tp.Type[variablelib.Variable] = LoRAParam, rngs: rnglib.Rngs, **kwargs, ): diff --git a/flax/nnx/nn/recurrent.py b/flax/nnx/nn/recurrent.py new file mode 100644 index 00000000..ea18805d --- /dev/null +++ b/flax/nnx/nn/recurrent.py @@ -0,0 +1,924 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""RNN modules for Flax.""" + +from typing import ( + Any, + TypeVar +) +from collections.abc import Callable +from functools import partial +from typing_extensions import Protocol +from absl import logging + +import jax +import jax.numpy as jnp + +from flax import nnx +from flax.nnx import rnglib +from flax.nnx.module import Module +from flax.nnx.nn import initializers +from flax.nnx.nn.linear import Linear +from flax.nnx.nn.activations import sigmoid +from flax.nnx.nn.activations import tanh +from flax.nnx.transforms.iteration import Carry +from flax.typing import ( + Dtype, + Initializer, + Shape +) + +default_kernel_init = initializers.lecun_normal() +default_bias_init = initializers.zeros_init() + +A = TypeVar("A") +Array = jax.Array +Output = Any + + +class RNNCellBase(Module): + """RNN cell base class.""" + + def initialize_carry( + self, input_shape: tuple[int, ...], rngs: rnglib.Rngs | None = None + ) -> Carry: + """Initialize the RNN cell carry. + + Args: + rng: random number generator passed to the init_fn. + input_shape: a tuple providing the shape of the input to the cell. + + Returns: + An initialized carry for the given RNN cell. + """ + raise NotImplementedError + + def __call__( + self, + carry: Carry, + inputs: Array + ) -> tuple[Carry, Array]: + """Run the RNN cell. + + Args: + carry: the hidden state of the RNN cell. + inputs: an ndarray with the input for the current time step. + All dimensions except the final are considered batch dimensions. + + Returns: + A tuple with the new carry and the output. + """ + raise NotImplementedError + + @property + def num_feature_axes(self) -> int: + """Returns the number of feature axes of the RNN cell.""" + raise NotImplementedError + +def modified_orthogonal(key: Array, shape: Shape, dtype: Dtype = jnp.float32) -> Array: + """Modified orthogonal initializer for compatibility with half precision.""" + initializer = initializers.orthogonal() + return initializer(key, shape).astype(dtype) + +class LSTMCell(RNNCellBase): + r"""LSTM cell. + + The mathematical definition of the cell is as follows + + .. math:: + \begin{array}{ll} + i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ + f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ + g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ + o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ + c' = f * c + i * g \\ + h' = o * \tanh(c') \\ + \end{array} + + where x is the input, h is the output of the previous time step, and c is + the memory. + """ + + def __init__( + self, + in_features: int, + hidden_features: int, + *, + gate_fn: Callable[..., Any] = sigmoid, + activation_fn: Callable[..., Any] = tanh, + kernel_init: Initializer = default_kernel_init, + recurrent_kernel_init: Initializer = modified_orthogonal, + bias_init: Initializer = initializers.zeros_init(), + dtype: Dtype | None = None, + param_dtype: Dtype = jnp.float32, + carry_init: Initializer = initializers.zeros_init(), + rngs: rnglib.Rngs, + ): + self.in_features = in_features + self.hidden_features = hidden_features + self.gate_fn = gate_fn + self.activation_fn = activation_fn + self.kernel_init = kernel_init + self.recurrent_kernel_init = recurrent_kernel_init + self.bias_init = bias_init + self.dtype = dtype + self.param_dtype = param_dtype + self.carry_init = carry_init + self.rngs = rngs + + # input and recurrent layers are summed so only one needs a bias. 
+ dense_i = partial( + Linear, + in_features=in_features, + out_features=hidden_features, + use_bias=False, + kernel_init=self.kernel_init, + dtype=self.dtype, + param_dtype=self.param_dtype, + rngs=rngs, + ) + + dense_h = partial( + Linear, + in_features=hidden_features, + out_features=hidden_features, + use_bias=True, + kernel_init=self.recurrent_kernel_init, + bias_init=self.bias_init, + dtype=self.dtype, + param_dtype=self.param_dtype, + rngs=rngs, + ) + + self.ii = dense_i() + self.if_ = dense_i() + self.ig = dense_i() + self.io = dense_i() + self.hi = dense_h() + self.hf = dense_h() + self.hg = dense_h() + self.ho = dense_h() + + def __call__(self, carry: tuple[Array, Array], inputs: Array) -> tuple[tuple[Array, Array], Array]: # type: ignore[override] + r"""A long short-term memory (LSTM) cell. + + Args: + carry: the hidden state of the LSTM cell, + initialized using ``LSTMCell.initialize_carry``. + inputs: an ndarray with the input for the current time step. + All dimensions except the final are considered batch dimensions. + + Returns: + A tuple with the new carry and the output. + """ + c, h = carry + i = self.gate_fn(self.ii(inputs) + self.hi(h)) + f = self.gate_fn(self.if_(inputs) + self.hf(h)) + g = self.activation_fn(self.ig(inputs) + self.hg(h)) + o = self.gate_fn(self.io(inputs) + self.ho(h)) + new_c = f * c + i * g + new_h = o * self.activation_fn(new_c) + return (new_c, new_h), new_h + + def initialize_carry(self, input_shape: tuple[int, ...], rngs: rnglib.Rngs | None = None) -> tuple[Array, Array]: # type: ignore[override] + """Initialize the RNN cell carry. + + Args: + rng: random number generator passed to the init_fn. + input_shape: a tuple providing the shape of the input to the cell. + Returns: + An initialized carry for the given RNN cell. + """ + batch_dims = input_shape[:-1] + if rngs is None: + rngs = self.rngs + mem_shape = batch_dims + (self.hidden_features,) + c = self.carry_init(rngs.carry(), mem_shape, self.param_dtype) + h = self.carry_init(rngs.carry(), mem_shape, self.param_dtype) + return (c, h) + + @property + def num_feature_axes(self) -> int: + return 1 + + +class OptimizedLSTMCell(RNNCellBase): + r"""More efficient LSTM Cell that concatenates state components before matmul. + + The parameters are compatible with ``LSTMCell``. Note that this cell is often + faster than ``LSTMCell`` as long as the hidden size is roughly <= 2048 units. + + The mathematical definition of the cell is the same as ``LSTMCell`` and as + follows: + + .. math:: + + \begin{array}{ll} + i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ + f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ + g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ + o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ + c' = f * c + i * g \\ + h' = o * \tanh(c') \\ + \end{array} + + where x is the input, h is the output of the previous time step, and c is + the memory. + + Attributes: + gate_fn: activation function used for gates (default: sigmoid). + activation_fn: activation function used for output and memory update + (default: tanh). + kernel_init: initializer function for the kernels that transform + the input (default: lecun_normal). + recurrent_kernel_init: initializer function for the kernels that transform + the hidden state (default: initializers.orthogonal()). + bias_init: initializer for the bias parameters (default: initializers.zeros_init()). + dtype: the dtype of the computation (default: infer from inputs and params). + param_dtype: the dtype passed to parameter initializers (default: float32). 
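+
+  A minimal usage sketch (illustrative shapes; the cell is constructed, its
+  carry initialized from the input shape, then applied one step at a time)::
+
+    >>> from flax import nnx
+    >>> import jax.numpy as jnp
+    ...
+    >>> cell = OptimizedLSTMCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0))
+    >>> x = jnp.ones((2, 3))  # (batch, features)
+    >>> carry = cell.initialize_carry(x.shape)  # (c, h), each of shape (2, 4)
+    >>> (c, h), y = cell(carry, x)
+    >>> y.shape
+    (2, 4)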
+ """ + + def __init__( + self, + in_features: int, + hidden_features: int, + *, + gate_fn: Callable[..., Any] = sigmoid, + activation_fn: Callable[..., Any] = tanh, + kernel_init: Initializer = default_kernel_init, + recurrent_kernel_init: Initializer = initializers.orthogonal(), + bias_init: Initializer = initializers.zeros_init(), + dtype: Dtype | None = None, + param_dtype: Dtype = jnp.float32, + carry_init: Initializer = initializers.zeros_init(), + rngs: rnglib.Rngs, + ): + self.in_features = in_features + self.hidden_features = hidden_features + self.gate_fn = gate_fn + self.activation_fn = activation_fn + self.kernel_init = kernel_init + self.recurrent_kernel_init = recurrent_kernel_init + self.bias_init = bias_init + self.dtype = dtype + self.param_dtype = param_dtype + self.carry_init = carry_init + self.rngs = rngs + + # input and recurrent layers are summed so only one needs a bias. + self.dense_i = Linear( + in_features=in_features, + out_features=4 * hidden_features, + use_bias=False, + kernel_init=self.kernel_init, + dtype=self.dtype, + param_dtype=self.param_dtype, + rngs=rngs, + ) + + self.dense_h = Linear( + in_features=hidden_features, + out_features=4 * hidden_features, + use_bias=True, + kernel_init=self.recurrent_kernel_init, + bias_init=self.bias_init, + dtype=self.dtype, + param_dtype=self.param_dtype, + rngs=rngs, + ) + + def __call__(self, carry: tuple[Array, Array], inputs: Array) -> tuple[tuple[Array, Array], Array]: # type: ignore[override] + r"""An optimized long short-term memory (LSTM) cell. + + Args: + carry: the hidden state of the LSTM cell, initialized using + ``LSTMCell.initialize_carry``. + inputs: an ndarray with the input for the current time step. + All dimensions except the final are considered batch dimensions. + + Returns: + A tuple with the new carry and the output. + """ + c, h = carry + + # Compute combined transformations for inputs and hidden state + y = self.dense_i(inputs) + self.dense_h(h) + + # Split the combined transformations into individual gates + i, f, g, o = jnp.split(y, indices_or_sections=4, axis=-1) + + # Apply gate activations + i = self.gate_fn(i) + f = self.gate_fn(f) + g = self.activation_fn(g) + o = self.gate_fn(o) + + # Update cell state and hidden state + new_c = f * c + i * g + new_h = o * self.activation_fn(new_c) + return (new_c, new_h), new_h + + def initialize_carry(self, input_shape: tuple[int, ...], rngs: rnglib.Rngs | None = None) -> tuple[Array, Array]: # type: ignore[override] + """Initialize the RNN cell carry. + + Args: + rngs: random number generator passed to the init_fn. + input_shape: a tuple providing the shape of the input to the cell. + + Returns: + An initialized carry for the given RNN cell. + """ + batch_dims = input_shape[:-1] + if rngs is None: + rngs = self.rngs + mem_shape = batch_dims + (self.hidden_features,) + c = self.carry_init(rngs.carry(), mem_shape, self.param_dtype) + h = self.carry_init(rngs.carry(), mem_shape, self.param_dtype) + return (c, h) + + @property + def num_feature_axes(self) -> int: + return 1 + + +class SimpleCell(RNNCellBase): + r"""Simple cell. + + The mathematical definition of the cell is as follows + + .. math:: + + \begin{array}{ll} + h' = \tanh(W_i x + b_i + W_h h) + \end{array} + + where x is the input and h is the output of the previous time step. + + If `residual` is `True`, + + .. 
math:: + + \begin{array}{ll} + h' = \tanh(W_i x + b_i + W_h h + h) + \end{array} + """ + + def __init__( + self, + in_features: int, + hidden_features: int, # not inferred from carry for now + *, + dtype: Dtype = jnp.float32, + param_dtype: Dtype = jnp.float32, + carry_init: Initializer = initializers.zeros_init(), + residual: bool = False, + activation_fn: Callable[..., Any] = tanh, + kernel_init: Initializer = initializers.lecun_normal(), + recurrent_kernel_init: Initializer = initializers.orthogonal(), + bias_init: Initializer = initializers.zeros_init(), + rngs: rnglib.Rngs, + ): + self.in_features = in_features + self.hidden_features = hidden_features + self.dtype = dtype + self.param_dtype = param_dtype + self.carry_init = carry_init + self.residual = residual + self.activation_fn = activation_fn + self.kernel_init = kernel_init + self.recurrent_kernel_init = recurrent_kernel_init + self.bias_init = bias_init + self.rngs = rngs + + # self.hidden_features = carry.shape[-1] + # input and recurrent layers are summed so only one needs a bias. + self.dense_h = Linear( + in_features=self.hidden_features, + out_features=self.hidden_features, + use_bias=False, + dtype=self.dtype, + param_dtype=self.param_dtype, + kernel_init=self.recurrent_kernel_init, + rngs=rngs, + ) + self.dense_i = Linear( + in_features=self.in_features, + out_features=self.hidden_features, + use_bias=True, + dtype=self.dtype, + param_dtype=self.param_dtype, + kernel_init=self.kernel_init, + bias_init=self.bias_init, + rngs=rngs, + ) + + def __call__(self, carry: Array, inputs: Array) -> tuple[Array, Array]: # type: ignore[override] + new_carry = self.dense_i(inputs) + self.dense_h(carry) + if self.residual: + new_carry += carry + new_carry = self.activation_fn(new_carry) + return new_carry, new_carry + + def initialize_carry(self, input_shape: tuple[int, ...], rngs: rnglib.Rngs | None = None) -> Array: # type: ignore[override] + """Initialize the RNN cell carry. + + Args: + rng: random number generator passed to the init_fn. + input_shape: a tuple providing the shape of the input to the cell. + + Returns: + An initialized carry for the given RNN cell. + """ + if rngs is None: + rngs = self.rngs + batch_dims = input_shape[:-1] + mem_shape = batch_dims + (self.hidden_features,) + return self.carry_init(rngs.carry(), mem_shape, self.param_dtype) + + @property + def num_feature_axes(self) -> int: + return 1 + + +class GRUCell(RNNCellBase): + r"""GRU cell. + + The mathematical definition of the cell is as follows + + .. math:: + + \begin{array}{ll} + r = \sigma(W_{ir} x + b_{ir} + W_{hr} h) \\ + z = \sigma(W_{iz} x + b_{iz} + W_{hz} h) \\ + n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\ + h' = (1 - z) * n + z * h \\ + \end{array} + + where x is the input and h is the output of the previous time step. + + Attributes: + in_features: number of input features. + hidden_features: number of output features. + gate_fn: activation function used for gates (default: sigmoid). + activation_fn: activation function used for output and memory update + (default: tanh). + kernel_init: initializer function for the kernels that transform + the input (default: lecun_normal). + recurrent_kernel_init: initializer function for the kernels that transform + the hidden state (default: initializers.orthogonal()). + bias_init: initializer for the bias parameters (default: initializers.zeros_init()). + dtype: the dtype of the computation (default: None). + param_dtype: the dtype passed to parameter initializers (default: float32). 
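+
+  A minimal usage sketch (illustrative shapes; the carry is the hidden state
+  returned by ``initialize_carry``)::
+
+    >>> from flax import nnx
+    >>> import jax.numpy as jnp
+    ...
+    >>> cell = GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0))
+    >>> x = jnp.ones((2, 3))  # (batch, features)
+    >>> carry = cell.initialize_carry(x.shape)  # hidden state of shape (2, 4)
+    >>> carry, y = cell(carry, x)
+    >>> y.shape
+    (2, 4)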
+ """ + + def __init__( + self, + in_features: int, + hidden_features: int, + *, + gate_fn: Callable[..., Any] = sigmoid, + activation_fn: Callable[..., Any] = tanh, + kernel_init: Initializer = default_kernel_init, + recurrent_kernel_init: Initializer = initializers.orthogonal(), + bias_init: Initializer = initializers.zeros_init(), + dtype: Dtype | None = None, + param_dtype: Dtype = jnp.float32, + carry_init: Initializer = initializers.zeros_init(), + rngs: rnglib.Rngs, + ): + self.in_features = in_features + self.hidden_features = hidden_features + self.gate_fn = gate_fn + self.activation_fn = activation_fn + self.kernel_init = kernel_init + self.recurrent_kernel_init = recurrent_kernel_init + self.bias_init = bias_init + self.dtype = dtype + self.param_dtype = param_dtype + self.carry_init = carry_init + self.rngs = rngs + + # Combine input transformations into a single linear layer + self.dense_i = Linear( + in_features=in_features, + out_features=3 * hidden_features, # r, z, n + use_bias=True, + kernel_init=self.kernel_init, + bias_init=self.bias_init, + dtype=self.dtype, + param_dtype=self.param_dtype, + rngs=rngs, + ) + + self.dense_h = Linear( + in_features=hidden_features, + out_features=3 * hidden_features, # r, z, n + use_bias=False, + kernel_init=self.recurrent_kernel_init, + dtype=self.dtype, + param_dtype=self.param_dtype, + rngs=rngs, + ) + + def __call__(self, carry: Array, inputs: Array) -> tuple[Array, Array]: # type: ignore[override] + """Gated recurrent unit (GRU) cell. + + Args: + carry: the hidden state of the GRU cell, + initialized using ``GRUCell.initialize_carry``. + inputs: an ndarray with the input for the current time step. + All dimensions except the final are considered batch dimensions. + + Returns: + A tuple with the new carry and the output. + """ + h = carry + + # Compute combined transformations for inputs and hidden state + x_transformed = self.dense_i(inputs) + h_transformed = self.dense_h(h) + + # Split the combined transformations into individual components + xi_r, xi_z, xi_n = jnp.split(x_transformed, 3, axis=-1) + hh_r, hh_z, hh_n = jnp.split(h_transformed, 3, axis=-1) + + # Compute gates + r = self.gate_fn(xi_r + hh_r) + z = self.gate_fn(xi_z + hh_z) + + # Compute n with an additional linear transformation on h + n = self.activation_fn(xi_n + r * hh_n) + + # Update hidden state + new_h = (1.0 - z) * n + z * h + return new_h, new_h + + def initialize_carry(self, input_shape: tuple[int, ...], rngs: rnglib.Rngs | None = None) -> Array: # type: ignore[override] + """Initialize the RNN cell carry. + + Args: + rngs: random number generator passed to the init_fn. + input_shape: a tuple providing the shape of the input to the cell. + + Returns: + An initialized carry for the given RNN cell. + """ + batch_dims = input_shape[:-1] + if rngs is None: + rngs = self.rngs + mem_shape = batch_dims + (self.hidden_features,) + h = self.carry_init(rngs.carry(), mem_shape, self.param_dtype) + return h + + @property + def num_feature_axes(self) -> int: + return 1 + + +class RNN(Module): + """The ``RNN`` module takes any :class:`RNNCellBase` instance and applies it over a sequence + + using :func:`flax.nnx.scan`. 
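+
+  A minimal usage sketch (illustrative shapes; ``LSTMCell`` is defined above, and
+  the input is laid out as ``(batch, time, features)`` since ``time_major``
+  defaults to ``False``)::
+
+    >>> from flax import nnx
+    >>> import jax.numpy as jnp
+    ...
+    >>> cell = LSTMCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0))
+    >>> rnn = RNN(cell)
+    >>> x = jnp.ones((2, 10, 3))  # (batch, time, features)
+    >>> y = rnn(x)
+    >>> y.shape
+    (2, 10, 4)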
+ """ + + def __init__( + self, + cell: RNNCellBase, + time_major: bool = False, + return_carry: bool = False, + reverse: bool = False, + keep_order: bool = False, + unroll: int = 1, + rngs: rnglib.Rngs | None = None, + ): + self.cell = cell + self.time_major = time_major + self.return_carry = return_carry + self.reverse = reverse + self.keep_order = keep_order + self.unroll = unroll + if rngs is None: + rngs = rnglib.Rngs(0) + self.rngs = rngs + + def __call__( + self, + inputs: Array, + *, + initial_carry: Carry | None = None, + seq_lengths: Array | None = None, + return_carry: bool | None = None, + time_major: bool | None = None, + reverse: bool | None = None, + keep_order: bool | None = None, + rngs: rnglib.Rngs | None = None, + ): + if return_carry is None: + return_carry = self.return_carry + if time_major is None: + time_major = self.time_major + if reverse is None: + reverse = self.reverse + if keep_order is None: + keep_order = self.keep_order + + # Infer the number of batch dimensions from the input shape. + # Cells like ConvLSTM have additional spatial dimensions. + time_axis = 0 if time_major else inputs.ndim - (self.cell.num_feature_axes + 1) + + # make time_axis positive + if time_axis < 0: + time_axis += inputs.ndim + + if time_major: + # we add +1 because we moved the time axis to the front + batch_dims = inputs.shape[1 : -self.cell.num_feature_axes] + else: + batch_dims = inputs.shape[:time_axis] + + # maybe reverse the sequence + if reverse: + inputs = jax.tree_util.tree_map( + lambda x: flip_sequences( + x, + seq_lengths, + num_batch_dims=len(batch_dims), + time_major=time_major, # type: ignore + ), + inputs, + ) + if rngs is None: + rngs = self.rngs + carry: Carry = ( + self.cell.initialize_carry( + inputs.shape[:time_axis] + inputs.shape[time_axis + 1 :], rngs + ) + if initial_carry is None + else initial_carry + ) + + slice_carry = seq_lengths is not None and return_carry + + def scan_fn(cell: RNNCellBase, carry: Carry, x: Array) -> tuple[Carry, Array] | tuple[Carry, tuple[Carry, Array]]: + carry, y = cell(carry, x) + if slice_carry: + return carry, (carry, y) + return carry, y + state_axes = nnx.StateAxes({...: Carry}) # type: ignore[arg-type] + scan = nnx.scan( + scan_fn, + in_axes=(state_axes, Carry, time_axis), + out_axes=(Carry, (0, time_axis)) if slice_carry else (Carry, time_axis), + unroll=self.unroll, + ) + scan_output = scan(self.cell, carry, inputs) + + # Next we select the final carry. If a segmentation mask was provided and + # return_carry is True we slice the carry history and select the last valid + # carry for each sequence. Otherwise we just use the last carry. + if slice_carry: + assert seq_lengths is not None + _, (carries, outputs) = scan_output + # seq_lengths[None] expands the shape of the mask to match the + # number of dimensions of the carry. 
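+      # _select_last_carry gathers, for each batch element, the carry stored at
+      # index seq_lengths - 1, i.e. the last valid (non-padded) step of that
+      # sequence.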
+ carry = _select_last_carry(carries, seq_lengths) + else: + carry, outputs = scan_output + + if reverse and keep_order: + outputs = jax.tree_util.tree_map( + lambda x: flip_sequences( + x, + seq_lengths, + num_batch_dims=len(batch_dims), + time_major=time_major, # type: ignore + ), + outputs, + ) + + if return_carry: + return carry, outputs + else: + return outputs + + +def _select_last_carry(sequence: A, seq_lengths: jnp.ndarray) -> A: + last_idx = seq_lengths - 1 + + def _slice_array(x: jnp.ndarray): + return x[last_idx, jnp.arange(x.shape[1])] + + return jax.tree_util.tree_map(_slice_array, sequence) + + +def _expand_dims_like(x, target): + """Expands the shape of `x` to match `target`'s shape by adding singleton dimensions.""" + return x.reshape(list(x.shape) + [1] * (target.ndim - x.ndim)) + + +def flip_sequences( + inputs: Array, + seq_lengths: Array | None, + num_batch_dims: int, + time_major: bool, +) -> Array: + """Flips a sequence of inputs along the time axis. + + This function can be used to prepare inputs for the reverse direction of a + bidirectional LSTM. It solves the issue that, when naively flipping multiple + padded sequences stored in a matrix, the first elements would be padding + values for those sequences that were padded. This function keeps the padding + at the end, while flipping the rest of the elements. + + Example:: + + >>> from flax.nnx.nn.recurrent import flip_sequences + >>> from jax import numpy as jnp + >>> inputs = jnp.array([[1, 0, 0], [2, 3, 0], [4, 5, 6]]) + >>> lengths = jnp.array([1, 2, 3]) + >>> flip_sequences(inputs, lengths, 1, False) + Array([[1, 0, 0], + [3, 2, 0], + [6, 5, 4]], dtype=int32) + + + Args: + inputs: An array of input IDs [batch_size, seq_length]. + lengths: The length of each sequence [batch_size]. + + Returns: + An ndarray with the flipped inputs. + """ + # Compute the indices to put the inputs in flipped order as per above example. + time_axis = 0 if time_major else num_batch_dims + max_steps = inputs.shape[time_axis] + + if seq_lengths is None: + # reverse inputs and return + inputs = jnp.flip(inputs, axis=time_axis) + return inputs + + seq_lengths = jnp.expand_dims(seq_lengths, axis=time_axis) + + # create indexes + idxs = jnp.arange(max_steps - 1, -1, -1) # [max_steps] + if time_major: + idxs = jnp.reshape(idxs, [max_steps] + [1] * num_batch_dims) + else: + idxs = jnp.reshape( + idxs, [1] * num_batch_dims + [max_steps] + ) # [1, ..., max_steps] + idxs = (idxs + seq_lengths) % max_steps # [*batch, max_steps] + idxs = _expand_dims_like(idxs, target=inputs) # [*batch, max_steps, *features] + # Select the inputs in flipped order. + outputs = jnp.take_along_axis(inputs, idxs, axis=time_axis) + + return outputs + + +def _concatenate(a: Array, b: Array) -> Array: + """Concatenates two arrays along the last dimension.""" + return jnp.concatenate([a, b], axis=-1) + + +class RNNBase(Protocol): + def __call__( + self, + inputs: Array, + *, + initial_carry: Carry | None = None, + rngs: rnglib.Rngs | None = None, + seq_lengths: Array | None = None, + return_carry: bool | None = None, + time_major: bool | None = None, + reverse: bool | None = None, + keep_order: bool | None = None, + ) -> Output | tuple[Carry, Output]: ... + + +class Bidirectional(Module): + """Processes the input in both directions and merges the results. 
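+
+  The forward and backward RNNs are run over the same inputs; their outputs are
+  combined with ``merge_fn``, which by default concatenates along the feature
+  axis.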
+ + Example usage:: + + >>> from flax import nnx + >>> import jax + >>> import jax.numpy as jnp + + >>> # Define forward and backward RNNs + >>> forward_rnn = RNN(GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0))) + >>> backward_rnn = RNN(GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0))) + + >>> # Create Bidirectional layer + >>> layer = Bidirectional(forward_rnn=forward_rnn, backward_rnn=backward_rnn) + + >>> # Input data + >>> x = jnp.ones((2, 3, 3)) + + >>> # Apply the layer + >>> out = layer(x) + >>> print(out.shape) + (2, 3, 8) + + """ + + forward_rnn: RNNBase + backward_rnn: RNNBase + merge_fn: Callable[[Array, Array], Array] = _concatenate + time_major: bool = False + return_carry: bool = False + + def __init__( + self, + forward_rnn: RNNBase, + backward_rnn: RNNBase, + *, + merge_fn: Callable[[Array, Array], Array] = _concatenate, + time_major: bool = False, + return_carry: bool = False, + rngs: rnglib.Rngs | None = None, + ): + self.forward_rnn = forward_rnn + self.backward_rnn = backward_rnn + self.merge_fn = merge_fn + self.time_major = time_major + self.return_carry = return_carry + if rngs is None: + rngs = rnglib.Rngs(0) + self.rngs = rngs + + def __call__( + self, + inputs: Array, + *, + initial_carry: tuple[Carry, Carry] | None = None, + rngs: rnglib.Rngs | None = None, + seq_lengths: Array | None = None, + return_carry: bool | None = None, + time_major: bool | None = None, + reverse: bool | None = None, # unused + keep_order: bool | None = None, # unused + ) -> Output | tuple[tuple[Carry, Carry], Output]: + if time_major is None: + time_major = self.time_major + if return_carry is None: + return_carry = self.return_carry + if rngs is None: + rngs = self.rngs + if initial_carry is not None: + initial_carry_forward, initial_carry_backward = initial_carry + else: + initial_carry_forward = None + initial_carry_backward = None + # Throw a warning in case the user accidentally re-uses the forward RNN + # for the backward pass and does not intend for them to share parameters. + if self.forward_rnn is self.backward_rnn: + logging.warning( + "forward_rnn and backward_rnn is the same object, so " + "they will share parameters." + ) + + # Encode in the forward direction. + carry_forward, outputs_forward = self.forward_rnn( + inputs, + initial_carry=initial_carry_forward, + rngs=rngs, + seq_lengths=seq_lengths, + return_carry=True, + time_major=time_major, + reverse=False, + ) + + # Encode in the backward direction. 
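+    # reverse=True flips each sequence (respecting seq_lengths) before the scan,
+    # and keep_order=True flips the outputs back so they stay aligned with the
+    # forward outputs before merging.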
+ carry_backward, outputs_backward = self.backward_rnn( + inputs, + initial_carry=initial_carry_backward, + rngs=rngs, + seq_lengths=seq_lengths, + return_carry=True, + time_major=time_major, + reverse=True, + keep_order=True, + ) + + carry = (carry_forward, carry_backward) if return_carry else None + outputs = jax.tree_util.tree_map( + self.merge_fn, outputs_forward, outputs_backward + ) + + if return_carry: + return carry, outputs + else: + return outputs diff --git a/flax/nnx/object.py b/flax/nnx/object.py index f2714ff7..c63506fc 100644 --- a/flax/nnx/object.py +++ b/flax/nnx/object.py @@ -29,7 +29,7 @@ tracers, ) from flax.nnx import graph -from flax.nnx.variables import Variable, VariableState +from flax.nnx.variablelib import Variable, VariableState from flax.typing import Key from flax import errors diff --git a/flax/nnx/reprlib.py b/flax/nnx/reprlib.py index 855a3049..6ed7660c 100644 --- a/flax/nnx/reprlib.py +++ b/flax/nnx/reprlib.py @@ -16,7 +16,6 @@ import dataclasses import threading import typing as tp -from abc import ABC, abstractmethod A = tp.TypeVar('A') B = tp.TypeVar('B') @@ -48,10 +47,9 @@ class Attr: end: str = '' -class Representable(ABC): +class Representable: __slots__ = () - @abstractmethod def __nnx_repr__(self) -> tp.Iterator[tp.Union[Object, Attr]]: raise NotImplementedError @@ -121,4 +119,14 @@ def __nnx_repr__(self): yield Object(type='', value_sep=': ', start='{', end='}') for key, value in self.mapping.items(): - yield Attr(repr(key), value) \ No newline at end of file + yield Attr(repr(key), value) + +@dataclasses.dataclass(repr=False) +class PrettySequence(Representable): + list: tp.Sequence + + def __nnx_repr__(self): + yield Object(type='', value_sep='', start='[', end=']') + + for value in self.list: + yield Attr('', value) \ No newline at end of file diff --git a/flax/nnx/rnglib.py b/flax/nnx/rnglib.py index 01ad0698..17bbaf37 100644 --- a/flax/nnx/rnglib.py +++ b/flax/nnx/rnglib.py @@ -23,7 +23,7 @@ from flax import struct from flax.nnx import graph from flax.nnx.statelib import State -from flax.nnx.variables import Variable +from flax.nnx.variablelib import Variable from flax.nnx import filterlib from flax.nnx.filterlib import All from flax.nnx.object import Object @@ -230,6 +230,13 @@ def __len__(self) -> int: def __contains__(self, name: tp.Any) -> bool: return name in vars(self) + # pickle support + def __getstate__(self): + return vars(self).copy() + + def __setstate__(self, state): + vars(self).update(state) + class ForkStates(tp.NamedTuple): split_keys: State diff --git a/flax/nnx/scripts/run-all-examples.bash b/flax/nnx/scripts/run-all-examples.bash index ab896ebd..9fcfec02 100644 --- a/flax/nnx/scripts/run-all-examples.bash +++ b/flax/nnx/scripts/run-all-examples.bash @@ -1,9 +1,8 @@ set -e source .venv/bin/activate -cd flax/nnx -for f in $(find examples/toy_examples -name "*.py" -maxdepth 1); do +for f in $(find examples/nnx_toy_examples -name "*.py" -maxdepth 1); do echo -e "\n---------------------------------" echo "$f" echo "---------------------------------" diff --git a/flax/nnx/spmd.py b/flax/nnx/spmd.py index 822e24c4..5bd4eb2d 100644 --- a/flax/nnx/spmd.py +++ b/flax/nnx/spmd.py @@ -15,40 +15,47 @@ import functools import typing as tp -import jax -from jax.interpreters import pxla -from jax.sharding import PartitionSpec - -from flax.nnx import variables +import flax.core.spmd as core_spmd +from flax.nnx import variablelib from flax.typing import ( Array, ArrayPytree, # pylint: disable=invalid-name PartitionSpecPytree, # pylint: 
disable=invalid-name Sharding, ) +import jax +from jax.interpreters import pxla +from jax.sharding import PartitionSpec A = tp.TypeVar('A') F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) PARTITION_NAME = 'partition_name' +class HasSharding(tp.Protocol): + sharding: tuple[str | None, ...] | None + + +def _has_sharding(x: tp.Any) -> tp.TypeGuard[HasSharding]: + return hasattr(x, 'sharding') and x.sharding is not None -def add_axis(tree: A, index: int, params: tp.Mapping[tp.Any, tp.Any]) -> A: +def add_axis(tree: A, index: int, params: tp.Mapping) -> A: axis_name = _get_partition_name(params) def _add_axis(x: tp.Any): - if isinstance(x, variables.VariableState): - if hasattr(x, 'sharding') and x.sharding is not None: + if isinstance(x, variablelib.VariableState): + if _has_sharding(x) and x.sharding is not None: sharding: list[str | None] = list(x.sharding) while len(sharding) < index: sharding.append(None) sharding.insert(index, axis_name) x.sharding = tuple(sharding) # type: ignore + assert isinstance(x, variablelib.VariableState) x.add_axis(index, axis_name) return x return jax.tree.map( - _add_axis, tree, is_leaf=lambda x: isinstance(x, variables.VariableState) + _add_axis, tree, is_leaf=lambda x: isinstance(x, variablelib.VariableState) ) @@ -56,7 +63,7 @@ def remove_axis(tree: A, index: int, params: tp.Mapping[tp.Any, tp.Any]) -> A: axis_name = _get_partition_name(params) def _remove_axis(x: tp.Any): - if isinstance(x, variables.VariableState): + if isinstance(x, variablelib.VariableState): if hasattr(x, 'sharding') and x.sharding is not None: sharding = list(x.sharding) assert sharding.pop(index) == axis_name @@ -67,7 +74,7 @@ def _remove_axis(x: tp.Any): return jax.tree.map( _remove_axis, tree, - is_leaf=lambda x: isinstance(x, variables.VariableState), + is_leaf=lambda x: isinstance(x, variablelib.VariableState), ) @@ -89,15 +96,16 @@ def _maybe_replicate(x): else: return None - def from_rules(sharding, sharding_rules): - rules = {alias: on_mesh for (alias, on_mesh) in sharding_rules} - return (rules[s] if s in rules else None for s in sharding) - def f(x): - if isinstance(x, (variables.VariableState, variables.Variable)): + if isinstance(x, (variablelib.VariableState, variablelib.Variable)): if hasattr(x, 'sharding') and x.sharding: - if hasattr(x, 'sharding_rules') and x.sharding_rules: - return x.replace(PartitionSpec(*from_rules(x.sharding, x.sharding_rules))) + if core_spmd.get_logical_axis_rules() or hasattr(x, 'sharding_rules'): + context_rules = core_spmd.get_logical_axis_rules() + local_rules = getattr(x, 'sharding_rules', ()) + rules = core_spmd.composite_rules(context_rules, local_rules) + return x.replace( + PartitionSpec(*core_spmd.from_sharding_rules(x.sharding, rules)) + ) return x.replace(PartitionSpec(*x.sharding)) else: return x.replace(_maybe_replicate(x.value)) @@ -105,7 +113,7 @@ def f(x): return _maybe_replicate(x) return jax.tree.map( - f, tree, is_leaf=lambda x: isinstance(x, variables.VariableState) + f, tree, is_leaf=lambda x: isinstance(x, variablelib.VariableState) ) @@ -171,7 +179,7 @@ def with_partitioning( mesh: tp.Optional[jax.sharding.Mesh] = None, **metadata: tp.Any, ) -> F: - return variables.with_metadata( + return variablelib.with_metadata( initializer, sharding=sharding, mesh=mesh, diff --git a/flax/nnx/statelib.py b/flax/nnx/statelib.py index 9063bc81..df299ea5 100644 --- a/flax/nnx/statelib.py +++ b/flax/nnx/statelib.py @@ -21,6 +21,7 @@ from flax.nnx import traversals from flax.nnx import filterlib, reprlib +from flax.nnx import 
variablelib from flax.typing import PathParts A = tp.TypeVar('A') @@ -55,11 +56,10 @@ def __treescope_repr__(self, path, subtree_renderer): class State(MutableMapping[K, V], reprlib.Representable): - """A pytree-like structure that contains a ``Mapping`` from strings or - integers to leaves. A valid leaf type is either :class:`Variable`, - ``jax.Array``, ``numpy.ndarray`` or nested ``State``'s. A ``State`` - can be generated by either calling :func:`split` or :func:`state` on - the :class:`Module`.""" + """A pytree-like structure that contains a ``Mapping`` from hashable and + comparable keys to leaves. Leaves can be of any type but :class:`VariableState` + and :class:`Variable` are the most common. + """ def __init__( self, @@ -146,6 +146,12 @@ def __treescope_repr__(self, path, subtree_renderer): subtree_renderer=subtree_renderer, ) + def map(self, f: tp.Callable[[tuple, V], V]) -> State[K, V]: + flat_state = self.flat_state() + for path, variable_state in flat_state.items(): + flat_state[path] = f(path, variable_state) + return State.from_flat_path(flat_state) + def flat_state(self) -> FlatState[V]: return traversals.flatten_mapping(self._mapping) @@ -172,11 +178,17 @@ def to_pure_dict(self, def replace_by_pure_dict(self, pure_dict: dict[str, tp.Any], replace_fn: SetValueFn | None = None): + def try_convert_int(x): + try: + return int(x) + except ValueError: + return x # Works for nnx.Variable and nnx.VariableState if replace_fn is None: replace_fn = lambda x, v: x.replace(v) if hasattr(x, 'replace') else v current_flat = self.flat_state() for kp, v in traversals.flatten_mapping(pure_dict).items(): + kp = tuple(map(try_convert_int, kp)) if kp not in current_flat: raise ValueError(f'key in pure_dict not available in state: {kp}') current_flat[kp] = replace_fn(current_flat[kp], v) @@ -307,7 +319,9 @@ def filter( return states # type: ignore[bad-return-type] @staticmethod - def merge(state: State[K, V], /, *states: State[K, V]) -> State[K, V]: + def merge( + state: tp.Mapping[K, V], /, *states: tp.Mapping[K, V] + ) -> State[K, V]: """The inverse of :meth:`split() `. ``merge`` takes one or more ``State``'s and creates @@ -340,14 +354,16 @@ def merge(state: State[K, V], /, *states: State[K, V]) -> State[K, V]: The merged ``State``. """ if not states: - return state + if isinstance(state, State): + return state + return State(state) states = (state, *states) new_state: FlatState[V] = {} for state in states: - new_state.update(state.flat_state()) # type: ignore[attribute-error] # pytype is wrong here + new_state.update(traversals.flatten_mapping(state)) # type: ignore[attribute-error] # pytype is wrong here return State.from_flat_path(new_state) @@ -419,3 +435,13 @@ def _split_state( flat_states[-1][path] = value # type: ignore[index] # mypy is wrong here? 
return tuple(State.from_flat_path(flat_state) for flat_state in flat_states) + + +def create_path_filters(state: State): + flat_state = state.flat_state() + value_paths: dict[tp.Any, set[PathParts]] = {} + for path, value in flat_state.items(): + if isinstance(value, (variablelib.Variable, variablelib.VariableState)): + value = value.value + value_paths.setdefault(value, set()).add(path) + return {filterlib.PathIn(*value_paths[value]): value for value in value_paths} \ No newline at end of file diff --git a/flax/nnx/tracers.py b/flax/nnx/tracers.py index cc785973..c53bbd5c 100644 --- a/flax/nnx/tracers.py +++ b/flax/nnx/tracers.py @@ -39,10 +39,8 @@ def jax_trace(self): return self._jax_trace def is_valid(self) -> bool: - if jax.__version_info__ <= (0, 4, 33): - return self._jax_trace is current_jax_trace() - - return self._jax_trace == current_jax_trace() + # TODO: re-enable when we update nnx to use stackless trace context + return True def __nnx_repr__(self): yield reprlib.Object(f'{type(self).__name__}') @@ -62,3 +60,10 @@ def __eq__(self, other): return isinstance(other, TraceState) and self._jax_trace is other._jax_trace return isinstance(other, TraceState) and self._jax_trace == other._jax_trace + + # pickle support + def __getstate__(self): + return {} + + def __setstate__(self, state): + self._jax_trace = current_jax_trace() diff --git a/flax/nnx/training/metrics.py b/flax/nnx/training/metrics.py index 49269134..2073787b 100644 --- a/flax/nnx/training/metrics.py +++ b/flax/nnx/training/metrics.py @@ -20,7 +20,7 @@ from flax import struct from flax.nnx import filterlib, graph from flax.nnx.object import Object -from flax.nnx.variables import Variable +from flax.nnx.variablelib import Variable import jax, jax.numpy as jnp # TODO: add tests and docstrings diff --git a/flax/nnx/training/optimizer.py b/flax/nnx/training/optimizer.py index 281066ea..4b85d5a3 100644 --- a/flax/nnx/training/optimizer.py +++ b/flax/nnx/training/optimizer.py @@ -19,9 +19,9 @@ from flax import nnx from flax.nnx import filterlib -from flax.nnx import variables +from flax.nnx import variablelib from flax.nnx.object import Object -from flax.nnx.variables import Variable, VariableState +from flax.nnx.variablelib import Variable, VariableState # TODO: add tests and docstrings @@ -47,7 +47,7 @@ class OptVariable(OptState): def _wrap_optimizer_state(opt_state): def wrap_optimizer_state_fn(x): - if isinstance(x, variables.VariableState): + if isinstance(x, variablelib.VariableState): new_state = x.copy() new_state.source_type = x.type new_state.type = OptVariable @@ -58,7 +58,7 @@ def wrap_optimizer_state_fn(x): return jax.tree.map( wrap_optimizer_state_fn, opt_state, - is_leaf=lambda x: isinstance(x, variables.VariableState), + is_leaf=lambda x: isinstance(x, variablelib.VariableState), ) @@ -193,7 +193,7 @@ def __init__( self.opt_state = _wrap_optimizer_state(tx.init(nnx.state(model, wrt))) self.wrt = wrt - def update(self, grads): + def update(self, grads, **kwargs): """Updates ``step``, ``params``, ``opt_state`` and ``**kwargs`` in return value. The ``grads`` must be derived from ``nnx.grad(..., wrt=self.wrt)``, where the gradients are with respect to the same :class:`Variable` types as defined in @@ -249,14 +249,16 @@ def update(self, grads): Args: grads: the gradients derived from ``nnx.grad``. + **kwargs: additional keyword arguments passed to the tx.update, to support + ``GradientTransformationExtraArgs``, such as ``optax.scale_by_backtracking_linesearch``. 
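+        For example, ``optax.polyak_sgd`` expects the loss as an extra ``value``
+        argument. A minimal sketch (assuming ``model`` is an ``nnx.Module`` and
+        ``loss_fn(model)`` returns a scalar loss)::
+
+          >>> import optax
+          >>> optimizer = nnx.Optimizer(model, optax.polyak_sgd())
+          >>> loss, grads = nnx.value_and_grad(loss_fn)(model)
+          >>> optimizer.update(grads, value=loss)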
""" params = nnx.state(self.model, self.wrt) opt_state = _opt_state_variables_to_state(self.opt_state) - updates, new_opt_state = self.tx.update(grads, opt_state, params) + updates, new_opt_state = self.tx.update(grads, opt_state, params, **kwargs) new_params = optax.apply_updates(params, updates) assert isinstance(new_params, nnx.State) self.step.value += 1 nnx.update(self.model, new_params) - _update_opt_state(self.opt_state, new_opt_state) + _update_opt_state(self.opt_state, new_opt_state) \ No newline at end of file diff --git a/flax/nnx/transforms/autodiff.py b/flax/nnx/transforms/autodiff.py index 9e55f709..5ef0d183 100644 --- a/flax/nnx/transforms/autodiff.py +++ b/flax/nnx/transforms/autodiff.py @@ -23,7 +23,7 @@ extract, filterlib, graph, - variables, + variablelib, ) from flax.nnx.statelib import State import jax @@ -126,7 +126,7 @@ def _grad_general( index_filter[index] = ( dataclasses.replace(argnum, argnum=-1) if isinstance(argnum, DiffState) - else DiffState(-1, variables.Param) + else DiffState(-1, variablelib.Param) ) gradded_fn = transform( @@ -362,6 +362,22 @@ def value_and_grad( return_value=True, ) +# ----------------------------------------------- +# custom_vjp +# ----------------------------------------------- +# custom_vjp is one of the most complicated transforms as it requires +# to handle 4 different functions: +# 1. CustomVJP: the main object that runs the outer logic, converts input graph nodes +# to pytrees and output pytrees to graph nodes. +# 2. CustomVjpFnWrapper: function that wraps the user's function, it converts +# its input pytrees to graph nodes and output graph nodes to pytrees. +# 3. FwdFn: wraps the user's fwd function, it converts its input pytrees to graph nodes +# and output graph nodes to pytrees. Since it might run by itself in a separate context, +# it needs to be aware if the update_context is active or not in order to update the outer +# referenes. +# 4. BwdFn: wraps the user's bwd function, it converts its input pytrees to graph nodes +# and output graph nodes to pytrees. It doesn't need to be aware of the outer context +# since it will never update the outer references as it runs during the backward pass. def _custom_vjp_merge_fn( ctx: graph.MergeContext, @@ -381,16 +397,15 @@ def _custom_vjp_split_fn( prefix: bool | DiffState, value, *, - nondiff_states: deque[extract.GraphDefState], + nondiff_states: list[extract.GraphDefState], ): + broadcast: graph.GraphState if prefix is False: - # pure non-differentiable arg, we pass all the state through - # but we return TreeNode.from_split with a graphdef to we can call from_tree - # on the nondiff args during the backward pass - graphdef, passed = ctx.split(value) - broadcast = State({}) # type: ignore[var-annotated] - nondiff_states.append(extract.GraphDefState(graphdef, broadcast)) - return extract.NodeStates.from_split(graphdef, passed) + # pure non-differentiable arg, not supported + raise TypeError( + 'Passing integers to nondiff_argnums for graph nodes arguments in custom_vjp is not supported. ' + f'Got {prefix} at path {jax.tree_util.keystr(path)} for value {value}' + ) elif prefix is True: # pure differentiable arg, we pass all the state through # but we return a TreeNode.from_states which doesn't have a graphdef @@ -409,23 +424,28 @@ def _custom_vjp_split_fn( return extract.NodeStates.from_states(passed) -class CustomVjpMetadata(struct.PyTreeNode): + nondiff_argnums: tuple[int, ...] = struct.field(pytree_node=False) tangent_tree_node_args: tuple[tp.Any, ...] 
= struct.field(pytree_node=False) +def _extract_index_mappings(x, *, index_mappings: deque[graph.HashableMapping]): + if isinstance(x, graph.NodeDef): + assert x.index_mapping is not None + index_mappings.append(x.index_mapping) + return dataclasses.replace(x, index_mapping=None) + return x @dataclasses.dataclass(eq=False) class CustomVjpFnWrapper: f: tp.Callable[..., tp.Any] + jax_nondiff_argnums: tuple[int, ...] ctxtag: str + nondiff_states: list[extract.GraphDefState] def __post_init__(self): functools.update_wrapper(self, self.f) def __call__(self, *pure_args): - broadcast: tuple[CustomVjpMetadata, deque[extract.GraphDefState]] = ( - extract.get_broadcast_state(self.ctxtag) - ) - metadata, nondiff_states = broadcast + nondiff_states = deque(self.nondiff_states) args = extract.from_tree( pure_args, merge_fn=functools.partial( @@ -436,10 +456,24 @@ def __call__(self, *pure_args): out = self.f(*args) - args_out = extract.clear_non_graph_nodes(args) + # remove nondiff from pure_args_out_g + args_out = tuple( + x for i, x in enumerate(args) if i not in self.jax_nondiff_argnums + ) + args_out = extract.clear_non_graph_nodes(args_out) pure_args_out, pure_out = extract.to_tree( (args_out, out), ctxtag=self.ctxtag ) + # remove index_mapping from NodeDef's but store them in global context + index_mappings: deque[graph.HashableMapping] = extract.get_broadcast_state( + self.ctxtag + ) + + pure_args_out, pure_out = jax.tree.map( + functools.partial(_extract_index_mappings, index_mappings=index_mappings), + (pure_args_out, pure_out), + is_leaf=lambda x: isinstance(x, graph.NodeDef), + ) return pure_args_out, pure_out @@ -447,67 +481,90 @@ def __call__(self, *pure_args): @dataclasses.dataclass(eq=False) class FwdFn: fwd: tp.Callable[..., tp.Any] + nondiff_argnums: tuple[int, ...] 
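+  # nondiff_argnums holds the integer positions that jax.custom_vjp treats as
+  # static; those arguments are excluded from the differentiable outputs below.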
ctxtag: str + nondiff_states: list[extract.GraphDefState] def __post_init__(self): functools.update_wrapper(self, self.fwd) def __call__(self, *pure_args): - broadcast: tuple[CustomVjpMetadata, deque[extract.GraphDefState]] = ( - extract.get_broadcast_state(self.ctxtag) + # here we need to be aware if the update_context is active or not + # when its not active, index_mappings will be None + # when its active, we will remove the index_mappings from the NodeDef's and store them + # in the index_mappings deque created by CustomVjp + update_context_active = ( + self.ctxtag in graph.GRAPH_CONTEXT.update_context_stacks ) - metadata, nondiff_states = broadcast + nondiff_states = deque(self.nondiff_states) args = extract.from_tree( pure_args, merge_fn=functools.partial( _custom_vjp_merge_fn, nondiff_states=nondiff_states ), - ctxtag=self.ctxtag, + ctxtag=self.ctxtag if update_context_active else None, ) out, residual = self.fwd(*args) - args_out = extract.clear_non_graph_nodes(args) + # remove nondiff from pure_args_out_g + args_out = tuple( + x for i, x in enumerate(args) if i not in self.nondiff_argnums + ) + args_out = extract.clear_non_graph_nodes(args_out) pure_args_out, pure_out = extract.to_tree( - (args_out, out), ctxtag=self.ctxtag + (args_out, out), + ctxtag=self.ctxtag if update_context_active else None, ) pure_residual = extract.to_tree(residual) - return (pure_args_out, pure_out), (metadata, pure_residual) + if update_context_active: + # remove index_mapping from NodeDef's but store them in global context + index_mappings: deque[graph.HashableMapping] = ( + extract.get_broadcast_state(self.ctxtag) + ) + pure_args_out, pure_out = jax.tree.map( + functools.partial( + _extract_index_mappings, index_mappings=index_mappings + ), + (pure_args_out, pure_out), + is_leaf=lambda x: isinstance(x, graph.NodeDef), + ) + + return (pure_args_out, pure_out), pure_residual @dataclasses.dataclass(eq=False) class BwdFn: bwd: tp.Callable[..., tp.Any] + tree_node_args: tuple[tp.Any, ...] 
def __post_init__(self): functools.update_wrapper(self, self.bwd) def __call__(self, *args): - res: tuple[CustomVjpMetadata, tp.Any] - pure_g: tuple[tp.Any, tp.Any] - *nondiff, res, pure_g = args - metadata, pure_residual = res - nondiff = extract.from_tree(nondiff) + *nondiff, pure_residual, (pure_args_out_g, pure_out_g) = args residual = extract.from_tree(pure_residual) - pure_g = jax.tree.map( + (pure_args_out_g, pure_out_g) = jax.tree.map( lambda x: x.state if isinstance(x, extract.NodeStates) else x, - pure_g, + (pure_args_out_g, pure_out_g), is_leaf=lambda x: isinstance(x, extract.NodeStates), ) - tangent = self.bwd(*nondiff, residual, pure_g) + tangent = self.bwd(*nondiff, residual, (pure_args_out_g, pure_out_g)) - def state_to_tree_node(is_tree_node: bool, x): - if is_tree_node: - if not isinstance(x, State): + def state_to_node_states(is_differentiable: bool, x): + if is_differentiable: + if isinstance(x, jax.Array): + return x + elif not isinstance(x, State): raise ValueError(f'Expected State, got {type(x)}') return extract.NodeStates.from_states(x) return x pure_tangent = jax.tree.map( - state_to_tree_node, - metadata.tangent_tree_node_args, + state_to_node_states, + self.tree_node_args, tangent, is_leaf=lambda x: isinstance(x, State), ) @@ -521,14 +578,15 @@ def __init__( nondiff_argnums: tuple[int | DiffState, ...], ): functools.update_wrapper(self, fun) - jax_nondiff_argnums = tuple( - x.argnum if isinstance(x, DiffState) else x for x in nondiff_argnums + # first argument is metadata + self.jax_nondiff_argnums = tuple( + x for x in nondiff_argnums if isinstance(x, int) ) self.ctxtag = f'custom_vjp_{fun.__name__}_{id(fun)}' - self.custom_vjp_fn = jax.custom_vjp( - CustomVjpFnWrapper(fun, self.ctxtag), - nondiff_argnums=jax_nondiff_argnums, - ) + self.fun = fun + self.fwd: tp.Callable | None = None + self.bwd: tp.Callable | None = None + self.symbolic_zeros: bool | None = None self.nondiff_argnums = nondiff_argnums self.diff_filter: dict[int, tp.Literal[False] | DiffState] = {} for argnum in self.nondiff_argnums: @@ -541,16 +599,18 @@ def __init__( else False ) - def __getattr__(self, name: str) -> tp.Any: - return getattr(self.custom_vjp_fn, name) + # def __getattr__(self, name: str) -> tp.Any: + # if not hasattr(self.custom_vjp_fn, name): + # raise AttributeError(f'{type(self).__name__} has no attribute {name}') + # return getattr(self.custom_vjp_fn, name) def __call__( self, *args: tp.Any, **kwargs: tp.Any ) -> A: # pytype: disable=invalid-annotation with graph.update_context(self.ctxtag): - args = resolve_kwargs(self.custom_vjp_fn, args, kwargs) + args = resolve_kwargs(self.fun, args, kwargs) del kwargs - nondiff_states: deque[extract.GraphDefState] = deque() + nondiff_states: list[extract.GraphDefState] = [] arg_filters = tuple( self.diff_filter.get(i, True) for i in range(len(args)) ) @@ -562,24 +622,57 @@ def __call__( ), ctxtag=self.ctxtag, ) - tangent_args = tp.cast( - tuple[tp.Literal[True] | DiffState, ...], - tuple(x for x in arg_filters if x is not False), - ) tree_node_args = jax.tree.map( lambda x: isinstance(x, extract.NodeStates), pure_args, is_leaf=lambda x: isinstance(x, extract.NodeStates), ) - tangent_tree_node_args = tuple( - arg - for arg, is_tree_node in zip(args, tree_node_args) - if is_tree_node is not False + tree_node_args = tuple( + x + for i, x in enumerate(tree_node_args) + if i not in self.jax_nondiff_argnums + ) + index_mappings: deque[graph.HashableMapping] = deque() + with extract.broadcast_state(self.ctxtag, index_mappings): + if self.fwd is 
None or self.bwd is None or self.symbolic_zeros is None: + raise ValueError() + + custom_vjp_fn = jax.custom_vjp( + fun=CustomVjpFnWrapper( + f=self.fun, + jax_nondiff_argnums=self.jax_nondiff_argnums, + ctxtag=self.ctxtag, + nondiff_states=nondiff_states, + ), + nondiff_argnums=self.jax_nondiff_argnums, + ) + custom_vjp_fn.defvjp( + fwd=FwdFn( + fwd=self.fwd, + nondiff_argnums=self.jax_nondiff_argnums, + ctxtag=self.ctxtag, + nondiff_states=nondiff_states, + ), + bwd=BwdFn( + bwd=self.bwd, + tree_node_args=tree_node_args, + ), + symbolic_zeros=self.symbolic_zeros, + ) + pure_args_out, pure_out = custom_vjp_fn(*pure_args) + + # insert index_mappings + def _insert_index_mappings(x): + if isinstance(x, graph.NodeDef): + index_mapping: graph.HashableMapping = index_mappings.popleft() + return dataclasses.replace(x, index_mapping=index_mapping) + return x + + pure_args_out, pure_out = jax.tree_util.tree_map( + _insert_index_mappings, + (pure_args_out, pure_out), + is_leaf=lambda x: isinstance(x, graph.NodeDef), ) - metadata = CustomVjpMetadata(tangent_args) - - with extract.broadcast_state(self.ctxtag, (metadata, nondiff_states)): - pure_args_out, pure_out = self.custom_vjp_fn(*pure_args) args_out, out = extract.from_tree( (pure_args_out, pure_out), ctxtag=self.ctxtag @@ -593,86 +686,9 @@ def defvjp( bwd: tp.Callable[..., tuple[tp.Any, ...]], symbolic_zeros: bool = False, ) -> None: - """Define a custom VJP rule for the function represented by this instance. - - Args: - fwd: a Python callable representing the forward pass of the custom VJP - rule. When there are no ``nondiff_argnums``, the ``fwd`` function has - the same input signature as the underlying primal function. It should - return as output a pair, where the first element represents the primal - output and the second element represents any "residual" values to store - from the forward pass for use on the backward pass by the function - ``bwd``. Input arguments and elements of the output pair may be arrays - or nested tuples/lists/dicts thereof. - bwd: a Python callable representing the backward pass of the custom VJP - rule. When there are no ``nondiff_argnums``, the ``bwd`` function takes - two arguments, where the first is the "residual" values produced on the - forward pass by ``fwd``, and the second is the output cotangent with the - same structure as the primal function output. The output of ``bwd`` must - be a tuple of length equal to the number of arguments of the primal - function, and the tuple elements may be arrays or nested - tuples/lists/dicts thereof so as to match the structure of the primal - input arguments. - symbolic_zeros: boolean, determining whether to indicate symbolic zeros - to the ``fwd`` and ``bwd`` rules. Enabling this option allows custom - derivative rules to detect when certain inputs, and when certain - output cotangents, are not involved in differentiation. If ``True``: - - * ``fwd`` must accept, in place of each leaf value ``x`` in - the pytree comprising an argument to the original function, - an object (of type - ``jax.custom_derivatives.CustomVJPPrimal``) with two - attributes instead: ``value`` and ``perturbed``. The - ``value`` field is the original primal argument, and - ``perturbed`` is a boolean. The ``perturbed`` bit indicates - whether the argument is involved in differentiation (i.e., - if it is ``False``, then the corresponding Jacobian "column" - is zero). 
-
-      * ``bwd`` will be passed objects representing static symbolic zeros in
-        its cotangent argument in correspondence with unperturbed values;
-        otherwise, only standard JAX types (e.g. array-likes) are passed.
-
-      Setting this option to ``True`` allows these rules to detect whether
-      certain inputs and outputs are not involved in differentiation, but at
-      the cost of special handling. For instance:
-
-      * The signature of ``fwd`` changes, and the objects it is passed cannot
-        be output from the rule directly.
-
-      * The ``bwd`` rule is passed objects that are not entirely array-like,
-        and that cannot be passed to most ``jax.numpy`` functions.
-
-      * Any custom pytree nodes involved in the primal function's arguments
-        must accept, in their unflattening functions, the two-field record
-        objects that are given as input leaves to the ``fwd`` rule.
-
-      Default ``False``.
-
-    Returns:
-      None.
-
-    Examples:
-
-      @jax.custom_vjp
-      def f(x, y):
-        return jnp.sin(x) * y
-
-      def f_fwd(x, y):
-        return f(x, y), (jnp.cos(x), jnp.sin(x), y)
-
-      def f_bwd(res, g):
-        cos_x, sin_x, y = res
-        return (cos_x * g * y, sin_x * g)
-
-      f.defvjp(f_fwd, f_bwd)
-    """
-
-    self.custom_vjp_fn.defvjp(
-      FwdFn(fwd, self.ctxtag),
-      BwdFn(bwd),
-      symbolic_zeros=symbolic_zeros,
-    )
+    self.fwd = fwd
+    self.bwd = bwd
+    self.symbolic_zeros = symbolic_zeros


 @tp.overload
@@ -694,6 +710,14 @@ def custom_vjp(
   """Reference aware version of `jax.custom_vjp
   `__.

+  ``nnx.custom_vjp`` accepts Modules and other Flax NNX objects as arguments. The main difference
+  with the JAX version is that, because Modules follow reference semantics, they propagate the State
+  updates for the inputs as auxiliary outputs. This means that the incoming gradients in the ``bwd`` function
+  will have the form ``(input_updates_g, out_g)``, where ``input_updates_g`` contains the gradients with
+  respect to the propagated ``State`` updates of the inputs. All Module terms in the inputs will have an
+  associated ``State`` term in ``input_updates_g``, while all non-Module terms will appear as None. The
+  tangent returned by ``bwd`` is expected to have the same structure as the input, with ``State`` terms in
+  place of the corresponding Module terms.
+
   Example::

     >>> import jax
@@ -713,10 +737,14 @@ def custom_vjp(
     ...   return f(m), (jnp.cos(m.x), jnp.sin(m.x), m)
     ...
     >>> def f_bwd(res, g):
-    ...   inputs_g, out_g = g
+    ...   input_updates_g, out_g = g
     ...   cos_x, sin_x, m = res
-    ...   tangent_m = nnx.State(dict(x=cos_x * out_g * m.y, y=sin_x * out_g))
-    ...   return (tangent_m,)
+    ...   (m_updates_g,) = input_updates_g
+    ...   m_g = jax.tree.map(lambda x: x, m_updates_g)  # create copy
+    ...
+    ...   m_g['x'].value = cos_x * out_g * m.y
+    ...   m_g['y'].value = sin_x * out_g
+    ...   return (m_g,)
     ...
     >>> f.defvjp(f_fwd, f_bwd)
     ...
@@ -735,6 +763,63 @@ def custom_vjp(
       )
     })

+  Note that the State objects that represent Module terms on ``input_updates_g`` have the
+  same shape as the State objects expected in the output tangent. This means that you can
+  usually just copy them from ``input_updates_g`` and update them with their corresponding
+  gradient values.
+
+  You can select which substates are differentiable (have a tangent) for Modules and other
+  graph nodes by passing a ``DiffState`` to ``nondiff_argnums``. For example, if you want to
+  differentiate only the ``x`` attribute of the ``Foo`` class, you can do the following::
+
+    >>> x_attribute = nnx.PathContains('x')
+    >>> diff_state = nnx.DiffState(0, x_attribute)
+    ...
+    >>> @nnx.custom_vjp(nondiff_argnums=(diff_state,))
+    ... def f(m: Foo):
+    ...   return jnp.sin(m.x) * m.y  # type: ignore
+
+    >>> def f_fwd(m: Foo):
+    ...   y = f(m)
+    ...   res = (jnp.cos(m.x), m)  # type: ignore
+    ...   return y, res
+    ...
+    >>> def f_bwd(res, g):
+    ...   input_updates_g, out_g = g
+    ...   cos_x, m = res
+    ...   (m_updates_g,) = input_updates_g
+    ...   m_g = jax.tree.map(lambda x: x, m_updates_g)  # create copy
+    ...
+    ...   m_g.x.value = cos_x * out_g * m.y
+    ...   del m_g['y']  # y is not differentiable
+    ...   return (m_g,)
+
+    >>> f.defvjp(f_fwd, f_bwd)
+    ...
+    >>> m = Foo(x=jnp.array(1.), y=jnp.array(2.))
+    >>> grad = nnx.grad(f, argnums=nnx.DiffState(0, x_attribute))(m)
+    ...
+    >>> jax.tree.map(jnp.shape, grad)
+    State({
+      'x': VariableState(
+        type=Param,
+        value=()
+      )
+    })
+
+  Note that ``grad`` cannot calculate gradients for states that don't have a tangent
+  defined by ``custom_vjp``; in the example above we reuse the same ``x_attribute``
+  filter to keep ``custom_vjp`` and ``grad`` in sync.
+
+  Args:
+    fun: Callable base function.
+    nondiff_argnums: Tuple of integers or DiffState objects specifying the
+      argument indices that are not differentiated. By default all arguments are
+      differentiated. Integers cannot be used to mark graph nodes such as Modules
+      as non-differentiable; in that case use a DiffState object. Note that, contrary
+      to what the argument name suggests, DiffState objects define the set of
+      *differentiable* substates; this is done for compatibility with ``grad``.
+  """
   if isinstance(fun, Missing):
     return functools.partial(custom_vjp, nondiff_argnums=nondiff_argnums)
@@ -789,3 +874,18 @@ def remat(
       ),
     )
   )
+  """A 'lifted' version of the
+  `jax.checkpoint `__
+  (a.k.a. ``jax.remat``).
+
+  ``flax.nnx.remat``, similar to ``jax.checkpoint``, can provide control over, for
+  example, how ``flax.nnx.grad`` values are computed and saved during the forward pass versus
+  how they are recomputed during the backward pass, trading off memory and FLOPs.
+
+  Learn more in `Flax NNX vs JAX Transformations `_.
+
+  To learn about ``jax.remat``, go to JAX's
+  `fundamentals of jax.checkpoint `_
+  and `practical notes `_.
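A minimal illustrative sketch of the ``nnx.remat`` usage described just above, assuming the standard ``nnx.Linear`` and ``nnx.Rngs`` constructors (this sketch is an annotation, not part of the diff):

```python
# Minimal sketch: mark a forward function for rematerialization so that
# intermediate activations are recomputed in the backward pass rather than stored.
import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(4, 4, rngs=nnx.Rngs(0))

@nnx.remat
def forward(model, x):
  return jnp.sum(jnp.tanh(model(x)))

# Gradients w.r.t. the model's Params, trading extra FLOPs for lower memory.
grads = nnx.grad(forward)(model, jnp.ones((1, 4)))
```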
+ """ + diff --git a/flax/nnx/transforms/compilation.py b/flax/nnx/transforms/compilation.py index 88d99e8f..e5ce20f8 100644 --- a/flax/nnx/transforms/compilation.py +++ b/flax/nnx/transforms/compilation.py @@ -21,6 +21,8 @@ extract, filterlib, graph, + statelib, + variablelib, ) import jax import jax.core @@ -36,13 +38,17 @@ # ------------------------------- -class StateSharding: +class StateSharding(extract.PrefixMapping): def __init__( self, - filter_sharding: tp.Mapping[filterlib.Filter, tp.Any] + filter_sharding: statelib.State + | tp.Mapping[filterlib.Filter, tp.Any] | tp.Iterable[tuple[filterlib.Filter, tp.Any]], /, ): + if isinstance(filter_sharding, statelib.State): + filter_sharding = statelib.create_path_filters(filter_sharding) # type: ignore + iterable = tuple( filter_sharding.items() if isinstance(filter_sharding, tp.Mapping) @@ -59,6 +65,15 @@ def filters(self) -> tuple[filterlib.Filter, ...]: def shardings(self) -> tuple[tp.Any, ...]: return self._shardings + def map_prefix( + self, path: variablelib.PathParts, variable: variablelib.Variable + ) -> tp.Any: + for filter, sharding in zip(self.filters, self.shardings): + predicate = filterlib.to_predicate(filter) + if predicate(path, variable): + return sharding + raise ValueError(f'No axis found for {path=}, {variable=}') + def __repr__(self): return f'StateSharding({dict(zip(self.filters, self.shardings))})' diff --git a/flax/nnx/transforms/deprecated.py b/flax/nnx/transforms/deprecated.py index f0191fc0..844cea48 100644 --- a/flax/nnx/transforms/deprecated.py +++ b/flax/nnx/transforms/deprecated.py @@ -20,7 +20,7 @@ from flax import struct from flax.core.frozen_dict import FrozenDict -from flax.nnx import extract, filterlib, graph, rnglib, spmd, variables +from flax.nnx import extract, filterlib, graph, rnglib, spmd, variablelib from flax.nnx.module import GraphDef, Module from flax.nnx.proxy_caller import DelayedAccessor from flax.nnx.statelib import State @@ -1685,7 +1685,7 @@ def grad( allow_int: bool = False, reduce_axes: tp.Sequence[AxisName] = (), *, - wrt: filterlib.Filter = variables.Param, + wrt: filterlib.Filter = variablelib.Param, ) -> tp.Callable[..., tp.Any]: """Lifted version of ``jax.grad`` that can handle Modules / graph nodes as arguments. 
@@ -1770,7 +1770,7 @@ def value_and_grad( allow_int: bool = False, reduce_axes: tp.Sequence[AxisName] = (), *, - wrt: filterlib.Filter = variables.Param, + wrt: filterlib.Filter = variablelib.Param, ) -> tp.Callable[..., tp.Any]: return _grad_general( f, @@ -1794,7 +1794,7 @@ def constructor( reduce_axes: tp.Sequence[AxisName] = (), return_value: bool = False, *, - wrt: filterlib.Filter = variables.Param, + wrt: filterlib.Filter = variablelib.Param, ) -> tp.Callable[..., Grad[MA]]: def _create_grad(*args, **kwargs): return Grad( @@ -1821,7 +1821,7 @@ def __init__( allow_int: bool = False, reduce_axes: tp.Sequence[AxisName] = (), *, - wrt: filterlib.Filter = variables.Param, + wrt: filterlib.Filter = variablelib.Param, # submodule args module_init_args: tuple[tp.Any, ...], module_init_kwargs: dict[str, tp.Any], diff --git a/flax/nnx/transforms/iteration.py b/flax/nnx/transforms/iteration.py index c77e3195..994e5828 100644 --- a/flax/nnx/transforms/iteration.py +++ b/flax/nnx/transforms/iteration.py @@ -20,7 +20,8 @@ from flax import struct from flax.core.frozen_dict import FrozenDict -from flax.nnx import extract, filterlib, graph, spmd +from flax.nnx import extract, filterlib, graph, spmd, variablelib +from flax.nnx import statelib from flax.nnx.module import Module from flax.nnx.statelib import State from flax.nnx.transforms.transforms import resolve_kwargs @@ -39,6 +40,7 @@ M = tp.TypeVar('M', bound=Module) MA = tp.TypeVar('MA', bound=Module) N = tp.TypeVar('N', bound=Module) +T = tp.TypeVar('T') StrInt = tp.TypeVar('StrInt', str, int) AxisName = tp.Hashable Leaves = tp.List[Leaf] @@ -54,16 +56,20 @@ class Carry: # ------------------------------- -class StateAxes: +class StateAxes(extract.PrefixMapping): def __init__( - self, - filter_axes: ( - tp.Mapping[filterlib.Filter, Index | type[Carry] | None] - | tp.Iterable[tuple[filterlib.Filter, Index | type[Carry] | None]] - ), - /, + self, + filter_axes: ( + statelib.State + | tp.Mapping[filterlib.Filter, Index | type[Carry] | None] + | tp.Iterable[tuple[filterlib.Filter, Index | type[Carry] | None]] + ), + /, ): + if isinstance(filter_axes, statelib.State): + filter_axes = statelib.create_path_filters(filter_axes) # type: ignore + iterable = tuple( filter_axes.items() if isinstance(filter_axes, tp.Mapping) @@ -80,6 +86,15 @@ def filters(self) -> tuple[filterlib.Filter, ...]: def axes(self) -> tuple[Index | type[Carry] | None, ...]: return self._axes + def map_prefix( + self, path: variablelib.PathParts, variable: variablelib.Variable + ) -> tp.Any: + for filter, axis in zip(self.filters, self.axes): + predicate = filterlib.to_predicate(filter) + if predicate(path, variable): + return axis + raise ValueError(f'No axis found for {path=}, {variable=}') + def __repr__(self): return f'StateAxes({dict(self.items())})' @@ -635,7 +650,7 @@ def check_carry_same_references(key_path, arg, out): def _extract_index_mappings( pure_carry_arg_out, - carry_index_mappings: list[FrozenDict[int, int]], + carry_index_mappings: list[graph.HashableMapping[int, int]], /, ): def extract_index_mappings(x): @@ -660,7 +675,7 @@ def extract_index_mappings(x): def _insert_index_mappings( pure_carry_arg_out, - carry_index_mappings: deque[FrozenDict[int, int]], + carry_index_mappings: deque[graph.HashableMapping[int, int]], /, ): def insert_index_mappings(x): @@ -1081,7 +1096,7 @@ def __call__( # next we have to remove all the index_mappings from the NodeDefs # in the carry outputs because they are not present in the inputs - carry_index_mappings: list[FrozenDict[int, 
int]] = [] + carry_index_mappings: list[graph.HashableMapping[int, int]] = [] pure_carry_arg_out = _extract_index_mappings( pure_carry_arg_out, carry_index_mappings ) @@ -1290,3 +1305,245 @@ def scan_wrapper(*args, **kwargs): return out return scan_wrapper # type: ignore + + + + + +# ------------------------------- +# while_loop +# ------------------------------- + + +@dataclasses.dataclass(eq=False) +class WhileLoopCondFn: + f: tp.Callable[..., tp.Any] + + def __post_init__(self): + functools.update_wrapper(self, self.f) + + def __call__(self, pure_val): + val = extract.from_tree(pure_val) + out = self.f(val) + return out + + +def _add_fake_index_mapping(tree: tp.Any): + global_index_mapping = {} # for the whole context, over all inputs + def per_node_state(ns: extract.NodeStates | tp.Any): + if not isinstance(ns, extract.NodeStates) or not isinstance( + ns._graphdef, graph.NodeDef + ): + return ns + + def per_node_def(nd: graph.NodeDef | graph.NodeRef): + if nd.index >= 0: + global_index_mapping[nd.index] = nd.index + if isinstance(nd, graph.NodeRef): + return + + for attribute in nd.attributes: + if type(attribute) is graph.SubGraphAttribute: + per_node_def(attribute.value) + elif ( + type(attribute) is graph.LeafAttribute + and isinstance(attribute.value, (graph.VariableDef, graph.NodeRef)) + and attribute.value.index >= 0 + ): + global_index_mapping[attribute.value.index] = attribute.value.index + return + + per_node_def(ns._graphdef) + return dataclasses.replace( + ns, + _graphdef=dataclasses.replace( + ns._graphdef, index_mapping=graph.HashableMapping(global_index_mapping) + ), + ) + + return jax.tree.map(per_node_state, tree, + is_leaf=lambda x: isinstance(x, extract.NodeStates)) + + +def _remove_index_mapping(tree: tp.Any): + '''Remove a fake index_mapping for the input to match that of the output.''' + def per_node_state(ns: extract.NodeStates | tp.Any): + if not isinstance(ns, extract.NodeStates) or not isinstance( + ns._graphdef, graph.NodeDef + ): + return ns + assert isinstance(ns._graphdef, graph.NodeDef) + return dataclasses.replace(ns, _graphdef=dataclasses.replace( + ns._graphdef, index_mapping=None + )) + + return jax.tree.map(per_node_state, tree, + is_leaf=lambda x: isinstance(x, extract.NodeStates)) + + +@dataclasses.dataclass(eq=False) +class WhileLoopBodyFn: + f: tp.Callable[..., tp.Any] + + def __post_init__(self): + functools.update_wrapper(self, self.f) + + @graph.update_context('while_loop_body') + def __call__(self, pure_val): + # Removing the dummy index mapping being added outside of body function. + pure_val_in = _remove_index_mapping(pure_val) + + val = extract.from_tree(pure_val_in, ctxtag='while_loop_body') + out = self.f(val) + pure_out = extract.to_tree(out, ctxtag='while_loop_body') + + try: + jax.tree.map(lambda a, b: None, pure_val, pure_out) + except ValueError as e: + msg = ("nnx.while_loop requires body function's input and output to " + "have the same reference and pytree structure, but they differ. " + "If the mismatch comes from `index_mapping` field, you might " + "have modified reference structure within the body function, " + "which is not allowed." + f"Detail of the mismatch: \n {str(e)}") + raise ValueError(msg) + + return pure_out + + +@graph.update_context('while_loop') +def while_loop(cond_fun: tp.Callable[[T], tp.Any], + body_fun: tp.Callable[[T], T], + init_val: T) -> T: + """A Flax NNX transformation of `jax.lax.while_loop `_. 
+ + Caution: for the NNX internal reference tracing mechanism to work, you cannot + change the variable reference structure of ``init_val`` inside ``body_fun``. + + Example:: + + >>> import jax + >>> from flax import nnx + >>> def fwd_fn(input): + ... module, x, count = input + ... return module, module(x), count - 1.0 + + >>> module = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) + >>> x = jax.random.normal(jax.random.key(0), (10,)) + >>> # `module` will be called three times + >>> _, y, _ = nnx.while_loop( + ... lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0)) + + + Args: + cond_fun: A function for the continue condition of the while loop, taking a + single input of type ``T`` and outputting a boolean. + body_fun: A function that takes an input of type ``T`` and outputs an ``T``. + Note that both data and modules of ``T`` must have the same reference + structure between inputs and outputs. + init_val: The initial input for ``cond_fun`` and ``body_fun``. Must be of type ``T``. + + """ + + pure_init_val = extract.to_tree(init_val, ctxtag='while_loop') + + # Adding the expected reference mapping to `pure_init_val` to match + # `body_fun`'s output pytree structure, to make JAX while_loop happy. + pure_init_val = _add_fake_index_mapping(pure_init_val) + + pure_out = jax.lax.while_loop( + WhileLoopCondFn(cond_fun), + WhileLoopBodyFn(body_fun), + pure_init_val, + ) + out = extract.from_tree(pure_out, ctxtag='while_loop') + return out + + +@dataclasses.dataclass(eq=False) +class ForiLoopBodyFn: + f: tp.Callable[..., tp.Any] + + def __post_init__(self): + functools.update_wrapper(self, self.f) + + @graph.update_context('fori_loop_body') + def __call__(self, i, pure_val): + # Removing the dummy index mapping being added outside of body function. + pure_val_in = _remove_index_mapping(pure_val) + + val = extract.from_tree(pure_val_in, ctxtag='fori_loop_body') + out = self.f(i, val) + pure_out = extract.to_tree(out, ctxtag='fori_loop_body') + + try: + jax.tree.map(lambda a, b: None, pure_val, pure_out) + except ValueError as e: + msg = ("nnx.fori_loop requires body function's input and output to " + "have the same reference and pytree structure, but they differ. " + "If the mismatch comes from `index_mapping` field, you might " + "have modified reference structure within the body function, " + "which is not allowed. " + f"Detail of the mismatch: \n {str(e)}") + raise ValueError(msg) + + return pure_out + + +@graph.update_context('fori_loop') +def fori_loop(lower: int, upper: int, + body_fun: tp.Callable[[int, T], T], + init_val: T, + *, + unroll: int | bool | None = None) -> T: + """A Flax NNX transformation of `jax.lax.fori_loop `_. + + Caution: for the NNX internal reference tracing mechanism to work, you cannot + change the variable reference structure of `init_val` inside `body_fun`. + + Example:: + + >>> import jax + >>> from flax import nnx + + >>> def fwd_fn(i, input): + ... m, x = input + ... m.kernel.value = jnp.identity(10) * i + ... return m, m(x) + + >>> module = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) + >>> x = jax.random.normal(jax.random.key(0), (10,)) + >>> _, y = nnx.fori_loop(2, 4, fwd_fn, (module, x)) + >>> np.testing.assert_array_equal(y, x * 2 * 3) + + + Args: + lower: An integer representing the loop index lower bound (inclusive). + upper: An integer representing the loop index upper bound (exclusive). + body_fun: a function that takes an input of type ``T`` and outputs an ``T``. 
+ Note that both data and modules of ``T`` must have the same reference + structure between inputs and outputs. + init_val: the initial input for body_fun. Must be of type ``T``. + unroll: An optional integer or boolean that determines how much to unroll + the loop. If an integer is provided, it determines how many unrolled + loop iterations to run within a single rolled iteration of the loop. If a + boolean is provided, it will determine if the loop is competely unrolled + (i.e. ``unroll=True``) or left completely unrolled (i.e. ``unroll=False``). + This argument is only applicable if the loop bounds are statically known. + + Returns: + A loop value from the final iteration, of type ``T``. + + """ + + pure_init_val = extract.to_tree(init_val, ctxtag='fori_loop') + + # Adding the expected reference mapping to `pure_init_val` to match + # `body_fun`'s output pytree structure, to make JAX happy. + pure_init_val = _add_fake_index_mapping(pure_init_val) + + pure_out = jax.lax.fori_loop(lower, upper, + ForiLoopBodyFn(body_fun), pure_init_val, + unroll=unroll) + out = extract.from_tree(pure_out, ctxtag='fori_loop') + return out diff --git a/flax/nnx/transforms/transforms.py b/flax/nnx/transforms/transforms.py index 558584dd..8a83a026 100644 --- a/flax/nnx/transforms/transforms.py +++ b/flax/nnx/transforms/transforms.py @@ -15,12 +15,16 @@ from __future__ import annotations from abc import abstractmethod +import dataclasses import functools import inspect import typing as tp +from jax._src import checkify as checkify_lib + from flax.nnx import ( extract, + graph, ) from flax.nnx.module import Module from flax.nnx.proxy_caller import ( @@ -119,7 +123,7 @@ def check_and_call(accessor: DelayedAccessor, *args, **kwargs): # ------------------------------- -# eval_shape +# simple transforms # ------------------------------- @@ -138,11 +142,85 @@ def _eval_shape_fn(*args, **kwargs): out = jax.eval_shape(_eval_shape_fn, *args, **kwargs) return extract.from_tree(out) + """A "lifted" version of `jax.eval_shape `_ + that can handle `flax.nnx.Module `_ + / graph nodes as arguments. + + Similar to ``jax.eval_shape``, it computes the shape/dtype of a function `f` without + performing any floating point operations (FLOPs) which can be expensive. This can be + useful for performing shape inference, for example. + """ + +@dataclasses.dataclass(eq=False) +class CheckifyFn: + f: tp.Callable[..., tp.Any] + + def __post_init__(self): + functools.update_wrapper(self, self.f) + + def __call__(self, *pure_args, **pure_kwargs): + args, kwargs = extract.from_tree( + (pure_args, pure_kwargs), ctxtag='checkify' + ) + out = self.f(*args, **kwargs) + + args_out, kwargs_out = extract.clear_non_graph_nodes((args, kwargs)) + pure_args_out, pure_kwargs_out, pure_out = extract.to_tree( + (args, kwargs, out), ctxtag='checkify' + ) + return pure_args_out, pure_kwargs_out, pure_out + +def checkify( + f: tp.Callable[..., checkify_lib.Out], + errors: frozenset[type[checkify_lib.JaxException]] = checkify_lib.user_checks, # type: ignore +) -> tp.Callable[..., tuple[checkify_lib.Error, checkify_lib.Out]]: + """Reference-aware version of `jax.experimental.checkify + `_. + + Example:: + + >>> import jax + >>> import jax.numpy as jnp + >>> from jax.experimental import checkify + >>> import dataclasses + >>> from flax import nnx + ... + >>> @dataclasses.dataclass + ... class Foo(nnx.Module): + ... a: nnx.Param + ... + >>> @nnx.jit + ... def f(m): + ... y = jnp.sin(m.a.value) # error + ... return m.a + y + ... 
+ >>> m = Foo(a=nnx.Param(jnp.inf)) + >>> err, out = nnx.checkify(f, errors=checkify.float_checks)(m) + >>> # err.throw() + >>> print(err) + Error(nan generated by primitive: sin.) + """ + checkify_fn = checkify_lib.checkify(CheckifyFn(f), errors) + @functools.wraps(f) + @graph.update_context('checkify') + def jit_wrapper(*args, **kwargs): + pure_args, pure_kwargs = extract.to_tree( + (args, kwargs), + ctxtag='checkify', + ) + error, (pure_args_out, pure_kwargs_out, pure_out) = checkify_fn( + *pure_args, **pure_kwargs + ) -# ------------------------------- -# cond -# ------------------------------- + args_out, kwargs_out, out = extract.from_tree( + (pure_args_out, pure_kwargs_out, pure_out), + ctxtag='checkify', + ) + + return error, out + + return jit_wrapper # type: ignore @general.split_inputs(ctxtag='cond') @@ -160,3 +238,17 @@ def cond( *operands, **kwargs, ) + + +@general.split_inputs(ctxtag='switch') +def switch( + index, + branches: tp.Sequence[tp.Callable[..., A]], + *operands, +) -> A: + return jax.lax.switch( + index, + [general.merge_inputs(f, ctxtag='switch') for f in branches], + *operands, + ) + diff --git a/flax/nnx/variables.py b/flax/nnx/variablelib.py similarity index 76% rename from flax/nnx/variables.py rename to flax/nnx/variablelib.py index ee6c8a00..7af20cdb 100644 --- a/flax/nnx/variables.py +++ b/flax/nnx/variablelib.py @@ -23,8 +23,8 @@ import jax from flax import errors -from flax.nnx import reprlib, tracers -from flax.typing import Missing +from flax.nnx import filterlib, reprlib, tracers +from flax.typing import Missing, PathParts import jax.tree_util as jtu A = tp.TypeVar('A') @@ -120,131 +120,87 @@ class Variable(tp.Generic[A], reprlib.Representable): """ raw_value: A - set_value_hooks: tuple[SetValueHook[A], ...] - get_value_hooks: tuple[GetValueHook[A], ...] - create_value_hooks: tuple[CreateValueHook[A], ...] - add_axis_hooks: tuple[AddAxisHook[Variable[A]], ...] - remove_axis_hooks: tuple[RemoveAxisHook[Variable[A]], ...] 
_trace_state: tracers.TraceState + _var_metadata: dict[str, tp.Any] def __init__( self, value: tp.Union[A, VariableMetadata[A]], - *, - set_value_hooks: tp.Union[ - SetValueHook[A], tp.Sequence[SetValueHook[A]] - ] = (), - get_value_hooks: tp.Union[ - GetValueHook[A], tp.Sequence[GetValueHook[A]] - ] = (), - create_value_hooks: tp.Union[ - CreateValueHook[A], tp.Sequence[CreateValueHook[A]] - ] = (), - add_axis_hooks: tp.Union[ - AddAxisHook[Variable[A]], tp.Sequence[AddAxisHook[Variable[A]]] - ] = (), - remove_axis_hooks: tp.Union[ - RemoveAxisHook[Variable[A]], - tp.Sequence[RemoveAxisHook[Variable[A]]], - ] = (), **metadata: tp.Any, ): - vars(self)['_trace_state'] = tracers.TraceState() - if callable(set_value_hooks): - set_value_hooks = (set_value_hooks,) - else: - set_value_hooks = tuple(set_value_hooks) - - if callable(get_value_hooks): - get_value_hooks = (get_value_hooks,) - else: - get_value_hooks = tuple(get_value_hooks) - - if callable(create_value_hooks): - create_value_hooks = (create_value_hooks,) - else: - create_value_hooks = tuple(create_value_hooks) - - if callable(add_axis_hooks): - add_axis_hooks = (add_axis_hooks,) - else: - add_axis_hooks = tuple(add_axis_hooks) - - if callable(remove_axis_hooks): - remove_axis_hooks = (remove_axis_hooks,) - else: - remove_axis_hooks = tuple(remove_axis_hooks) + type_vars = vars(type(self)) + vars_self = vars(self) + vars_self['_trace_state'] = tracers.TraceState() if isinstance(value, VariableMetadata): - value_metadata = dict(value.metadata) - if value.set_value_hooks: - set_value_hooks = set_value_hooks + value.set_value_hooks - if value.get_value_hooks: - get_value_hooks = get_value_hooks + value.get_value_hooks - if value.create_value_hooks: - create_value_hooks = create_value_hooks + value.create_value_hooks - if value.add_axis_hooks: - add_axis_hooks = add_axis_hooks + value.add_axis_hooks - if value.remove_axis_hooks: - remove_axis_hooks = remove_axis_hooks + value.remove_axis_hooks - - metadata.update(value_metadata) + metadata.update(value.metadata) value = tp.cast(A, value.raw_value) - self.raw_value = value + object.__setattr__(self, 'raw_value', value) - if 'on_get_value' in vars(type(self)): - on_get_value = getattr(type(self), 'on_get_value') - if on_get_value not in get_value_hooks: - get_value_hooks = (on_get_value, *get_value_hooks) + if 'on_get_value' in type_vars and 'on_get_value' not in metadata: + metadata['get_value'] = getattr(type(self), 'on_get_value') - if 'on_set_value' in vars(type(self)): - on_set_value = getattr(type(self), 'on_set_value') - if on_set_value not in set_value_hooks: - set_value_hooks = (on_set_value, *set_value_hooks) + if 'on_set_value' in type_vars and 'on_set_value' not in metadata: + metadata['set_value'] = getattr(type(self), 'on_set_value') - if 'on_create_value' in vars(type(self)): - on_create_value = getattr(type(self), 'on_create_value') - if on_create_value not in create_value_hooks: - create_value_hooks = (on_create_value, *create_value_hooks) + if 'on_create_value' in type_vars and 'on_create_value' not in metadata: + metadata['create_value'] = getattr(type(self), 'on_create_value') - if 'on_add_axis' in vars(type(self)): - on_add_axis = getattr(type(self), 'on_add_axis') - if on_add_axis not in add_axis_hooks: - add_axis_hooks = (on_add_axis, *add_axis_hooks) + if 'on_add_axis' in type_vars and 'on_add_axis' not in metadata: + metadata['add_axis'] = getattr(type(self), 'on_add_axis') - if 'on_remove_axis' in vars(type(self)): - on_remove_axis = getattr(type(self), 
'on_remove_axis') - if on_remove_axis not in remove_axis_hooks: - remove_axis_hooks = (on_remove_axis, *remove_axis_hooks) - - self.get_value_hooks = get_value_hooks - self.set_value_hooks = set_value_hooks - self.create_value_hooks = create_value_hooks - self.add_axis_hooks = add_axis_hooks - self.remove_axis_hooks = remove_axis_hooks - vars(self).update(metadata) + if 'on_remove_axis' in type_vars and 'on_remove_axis' not in metadata: + metadata['remove_axis'] = getattr(type(self), 'on_remove_axis') + vars_self['_var_metadata'] = metadata # run create_value hooks - self.raw_value = self.create_value(self.raw_value) + vars_self['raw_value'] = self.create_value(self.raw_value) - if not tp.TYPE_CHECKING: + def __getattr__(self, name: str) -> tp.Any: + if name in vars(self)['_var_metadata']: + return self._var_metadata[name] + return getattr(self.value, name) - def __setattr__(self, name: str, value: Any) -> None: - return self._setattr(name, value) + def __setattr__(self, name: str, value: tp.Any): + if not self._trace_state.is_valid(): + raise errors.TraceContextError( + f'Cannot mutate {type(self).__name__} from a different trace level' + ) - def _setattr(self, name: str, value: tp.Any): + if ( + name == 'value' + or name == 'raw_value' + or name == '_var_metadata' + or name == '_trace_state' + ): + object.__setattr__(self, name, value) + else: + self._var_metadata[name] = value + + def __delattr__(self, name: str): if not self._trace_state.is_valid(): raise errors.TraceContextError( f'Cannot mutate {type(self).__name__} from a different trace level' ) - object.__setattr__(self, name, value) + if ( + name == 'value' + or name == 'raw_value' + or name == '_var_metadata' + or name == '_trace_state' + ): + object.__delattr__(self, name) + else: + del self._var_metadata[name] @classmethod def state(cls, value: A, **metadata) -> VariableState[A]: return cls(value, **metadata).to_state() + def get_metadata(self): + return self._var_metadata + def copy_from(self, other: Variable[A]) -> None: if type(self) is not type(other): raise ValueError( @@ -253,29 +209,20 @@ def copy_from(self, other: Variable[A]) -> None: ) if self is other: return - trace_state = self._trace_state - vars_dict = vars(self) - other_vars = vars(other).copy() - del other_vars['_trace_state'] - vars_dict.clear() - vars_dict.update(other_vars, _trace_state=trace_state) + self.raw_value = other.raw_value + self._var_metadata.clear() + self._var_metadata.update(other.get_metadata()) def update_from_state(self, variable_state: VariableState[A]): - trace_state = self._trace_state - variable_vars = vars(self) - variable_vars.clear() - variable_vars.update( - variable_state.get_metadata(), - raw_value=variable_state.value, - _trace_state=trace_state, - ) + vars_self = vars(self) + vars_self['raw_value'] = variable_state.value + vars_self['_var_metadata'] = variable_state.get_metadata().copy() @property def value(self) -> A: value = self.raw_value - if self.get_value_hooks: - for hook in self.get_value_hooks: - value = hook(self, value) + if 'on_get_value' in self._var_metadata: + value = self._var_metadata['on_get_value'](self, value) return value @value.setter @@ -284,23 +231,22 @@ def value(self, value: A): raise ValueError( 'Cannot set value to a Variable, ' 'use `copy_from` method instead' ) - if self.set_value_hooks: - for hook in self.set_value_hooks: - value = hook(self, value) - self.raw_value = value + if 'on_set_value' in self._var_metadata: + value = self._var_metadata['on_set_value'](self, value) + 
vars(self)['raw_value'] = value def create_value(self, value: A): - for hook in self.create_value_hooks: - value = hook(self, value) + if 'on_create_value' in self._var_metadata: + value = self._var_metadata['on_create_value'](self, value) return value def add_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): - for hook in self.add_axis_hooks: - hook(self, axis_index, axis_name) + if 'on_add_axis' in self._var_metadata: + self._var_metadata['on_add_axis'](self, axis_index, axis_name) def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): - for hook in self.remove_axis_hooks: - hook(self, axis_index, axis_name) + if 'on_remove_axis' in self._var_metadata: + self._var_metadata['on_remove_axis'](self, axis_index, axis_name) def __eq__(self, other: object) -> bool: return type(self) is type(other) and vars(other) == vars(self) @@ -338,45 +284,43 @@ def replace(self, value: tp.Any = Missing, **kwargs) -> Variable[tp.Any]: return value # get and update attributes - attributes = vars(self).copy() - attributes.update(**kwargs) # return new instance with updated attributes obj = object.__new__(type(self)) - vars(obj).update(attributes) + object.__setattr__(obj, '_trace_state', self._trace_state) + object.__setattr__(obj, 'raw_value', kwargs.pop('raw_value')) + object.__setattr__(obj, '_var_metadata', self.get_metadata()) + obj._var_metadata.update(kwargs) + return obj + + @classmethod + def from_metadata(cls, value: A, attributes: tp.Mapping[str, tp.Any]): + obj = object.__new__(cls) + object.__setattr__(obj, '_trace_state', tracers.TraceState()) + object.__setattr__(obj, 'raw_value', value) + object.__setattr__(obj, '_var_metadata', attributes) return obj def copy(self: Variable[A]) -> Variable[A]: obj = object.__new__(type(self)) - attributes = vars(self).copy() - attributes['_trace_state'] = tracers.TraceState() - vars(obj).update(attributes) + object.__setattr__(obj, '_trace_state', self._trace_state) + object.__setattr__(obj, 'raw_value', self.raw_value) + object.__setattr__(obj, '_var_metadata', self.get_metadata().copy()) return obj def to_state(self: Variable[A]) -> VariableState[A]: - metadata = vars(self).copy() - del metadata['raw_value'] - del metadata['_trace_state'] + metadata = self.get_metadata() return VariableState(type(self), self.raw_value, **metadata) def __nnx_repr__(self): yield reprlib.Object(type=type(self)) - for name, value in vars(self).items(): - if name == 'raw_value': - name = 'value' - if name.endswith('_hooks') or name == '_trace_state': - continue + yield reprlib.Attr('value', self.raw_value) + for name, value in self._var_metadata.items(): yield reprlib.Attr(name, repr(value)) def __treescope_repr__(self, path, subtree_renderer): import treescope # type: ignore[import-not-found,import-untyped] - children = {} - for name, value in vars(self).items(): - if name == 'raw_value': - name = 'value' - if name.endswith('_hooks') or name == '_trace_state': - continue - children[name] = value + children = {'value': self.raw_value, **self._var_metadata} return treescope.repr_lib.render_object_constructor( object_type=type(self), attributes=children, @@ -404,13 +348,16 @@ def on_remove_axis( def __jax_array__(self): return self.value + # pickle support + def __getstate__(self): + return vars(self).copy() + + def __setstate__(self, state): + vars(self).update(state) + # -------------------------------------------- # proxy methods # -------------------------------------------- - # NOTE: we dont override __setattr__ to avoid cases where - # you 
need to set an attribute on the variable instance - def __getattr__(self, name: str) -> tp.Any: - return getattr(self.value, name) def __getitem__(self, key) -> tp.Any: return self.value[key] # type: ignore @@ -784,39 +731,51 @@ class Intermediate(Variable[A]): class VariableState(tp.Generic[A], reprlib.Representable): + __slots__ = ('type', 'value', '_var_metadata') + type: type[Variable[A]] + value: A + _var_metadata: dict[str, tp.Any] + def __init__( self, - type: type[Variable[tp.Any]], + type: type[Variable[A]], # type: ignore [valid-type] value: A, **metadata, ): - self.type = type - self.value = value - vars(self).update(metadata) - - if tp.TYPE_CHECKING: + object.__setattr__(self, 'type', type) + object.__setattr__(self, 'value', value) + object.__setattr__(self, '_var_metadata', metadata) + + def __getattr__(self, name: str) -> None: + var_metadata = object.__getattribute__(self, '_var_metadata') + if name not in var_metadata: + raise AttributeError(f"'VariableState' object has no attribute '{name}'") + return var_metadata[name] + + def __setattr__(self, name: str, value: Any) -> None: + if name == 'type' or name == 'value' or name == '_var_metadata': + object.__setattr__(self, name, value) + else: + self._var_metadata[name] = value - def __getattr__(self, name: str) -> None: ... - def __setattr__(self, name: str, value: Any) -> None: ... - def __delattr__(self, name: str) -> None: ... + def __delattr__(self, name: str) -> None: + if name == 'type' or name == 'value' or name == '_var_metadata': + object.__delattr__(self, name) + else: + del self._var_metadata[name] def __nnx_repr__(self): yield reprlib.Object(type=type(self)) yield reprlib.Attr('type', self.type.__name__) + yield reprlib.Attr('value', self.value) - for name, value in vars(self).items(): - if name == 'type' or name.endswith('_hooks'): - continue + for name, value in self._var_metadata.items(): yield reprlib.Attr(name, repr(value)) def __treescope_repr__(self, path, subtree_renderer): import treescope # type: ignore[import-not-found,import-untyped] - children = {'type': self.type} - for name, value in vars(self).items(): - if name == 'type' or name.endswith('_hooks'): - continue - children[name] = value + children = {'type': self.type, 'value': self.value, **self._var_metadata} return treescope.repr_lib.render_object_constructor( object_type=type(self), attributes=children, @@ -830,29 +789,25 @@ def replace(self, value: B) -> VariableState[B]: def to_variable(self) -> Variable[A]: # we use object.__new__ to avoid calling __init__ and bypass the # __init__ logic which should not be called twice - metadata = self.get_metadata() - variables = object.__new__(self.type) - vars(variables).update( - metadata, raw_value=self.value, _trace_state=tracers.TraceState() - ) - return variables + variable = object.__new__(self.type) + object.__setattr__(variable, '_trace_state', tracers.TraceState()) + object.__setattr__(variable, 'raw_value', self.value) + object.__setattr__(variable, '_var_metadata', self.get_metadata().copy()) + return variable def copy(self: VariableState[A]) -> VariableState[A]: return jax.tree.map(lambda x: x, self) def get_metadata(self) -> dict[str, tp.Any]: - metadata = vars(self).copy() - del metadata['type'] - del metadata['value'] - return metadata + return self._var_metadata def add_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): - for hook in self.add_axis_hooks: - hook(self, axis_index, axis_name) + if 'on_add_axis' in self._var_metadata: + self._var_metadata['on_add_axis'](self, 
axis_index, axis_name) def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): - for hook in self.remove_axis_hooks: - hook(self, axis_index, axis_name) + if 'on_remove_axis' in self._var_metadata: + self._var_metadata['on_remove_axis'](self, axis_index, axis_name) def _variable_state_flatten(x: VariableState[tp.Any], *, with_keys: bool): @@ -953,3 +908,29 @@ def wrapper(*args): ) return wrapper # type: ignore + + +def split_flat_state( + flat_state: tp.Iterable[tuple[PathParts, Variable | VariableState]], + filters: tuple[filterlib.Filter, ...], +) -> tuple[list[tuple[PathParts, Variable | VariableState]], ...]: + predicates = filterlib.filters_to_predicates(filters) + # we have n + 1 states, where n is the number of predicates + # the last state is for values that don't match any predicate + flat_states: tuple[list[tuple[PathParts, Variable | VariableState]], ...] = ( + tuple([] for _ in predicates) + ) + + for path, value in flat_state: + for i, predicate in enumerate(predicates): + if predicate(path, value): + flat_states[i].append((path, value)) + break + else: + raise ValueError( + 'Non-exhaustive filters, got a non-empty remainder: ' + f'{path} -> {value}.' + '\nUse `...` to match all remaining elements.' + ) + + return flat_states diff --git a/flax/struct.py b/flax/struct.py index b4d242c4..4e8de0a7 100644 --- a/flax/struct.py +++ b/flax/struct.py @@ -14,8 +14,10 @@ """Utilities for defining custom classes that can be used with jax transformations.""" +from collections.abc import Callable import dataclasses -from typing import TypeVar +import functools +from typing import TypeVar, overload import jax from typing_extensions import ( @@ -33,7 +35,22 @@ def field(pytree_node=True, *, metadata=None, **kwargs): @dataclass_transform(field_specifiers=(field,)) # type: ignore[literal-required] +@overload def dataclass(clz: _T, **kwargs) -> _T: + ... + + +@dataclass_transform(field_specifiers=(field,)) # type: ignore[literal-required] +@overload +def dataclass(**kwargs) -> Callable[[_T], _T]: + ... + + +@dataclass_transform(field_specifiers=(field,)) # type: ignore[literal-required] +def dataclass( + clz: _T | None = None, + **kwargs, +) -> _T | Callable[[_T], _T]: """Create a class which can be passed to functional transformations. .. note:: @@ -99,9 +116,15 @@ class method that provides the smart constructor. Args: clz: the class that will be transformed by the decorator. + **kwargs: arguments to pass to the dataclass constructor. + Returns: The new class. """ + # Support passing arguments to the decorator (e.g. @dataclass(kw_only=True)) + if clz is None: + return functools.partial(dataclass, **kwargs) + # check if already a flax dataclass if '_flax_dataclass' in clz.__dict__: return clz @@ -119,46 +142,12 @@ class method that provides the smart constructor. meta_fields.append(field_info.name) def replace(self, **updates): - """ "Returns a new object replacing the specified fields with new values.""" + """Returns a new object replacing the specified fields with new values.""" return dataclasses.replace(self, **updates) data_clz.replace = replace - # Remove this guard once minimux JAX version is >0.4.26. 
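The new overloads above let ``flax.struct.dataclass`` be called with arguments (the in-diff comment cites ``@dataclass(kw_only=True)``); a small hypothetical sketch of that usage, not taken from the diff:

```python
# Hypothetical sketch: flax.struct.dataclass used with decorator arguments.
import jax
from flax import struct

@struct.dataclass(kw_only=True)  # kwargs are forwarded to dataclasses.dataclass
class Box:
  value: jax.Array
  name: str = struct.field(pytree_node=False, default='box')

b = Box(value=jax.numpy.ones(3))
b2 = b.replace(name='renamed')          # replace() is attached by the decorator
leaves = jax.tree_util.tree_leaves(b2)  # only `value` is a pytree leaf
```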
- try: - if hasattr(jax.tree_util, 'register_dataclass'): - jax.tree_util.register_dataclass( - data_clz, data_fields, meta_fields - ) - else: - raise NotImplementedError - except NotImplementedError: - - def iterate_clz(x): - meta = tuple(getattr(x, name) for name in meta_fields) - data = tuple(getattr(x, name) for name in data_fields) - return data, meta - - def iterate_clz_with_keys(x): - meta = tuple(getattr(x, name) for name in meta_fields) - data = tuple( - (jax.tree_util.GetAttrKey(name), getattr(x, name)) - for name in data_fields - ) - return data, meta - - def clz_from_iterable(meta, data): - meta_args = tuple(zip(meta_fields, meta)) - data_args = tuple(zip(data_fields, data)) - kwargs = dict(meta_args + data_args) - return data_clz(**kwargs) - - jax.tree_util.register_pytree_with_keys( - data_clz, - iterate_clz_with_keys, - clz_from_iterable, - iterate_clz, - ) + jax.tree_util.register_dataclass(data_clz, data_fields, meta_fields) def to_state_dict(x): state_dict = { diff --git a/flax/typing.py b/flax/typing.py index 964de057..a630a357 100644 --- a/flax/typing.py +++ b/flax/typing.py @@ -19,9 +19,9 @@ Generic, Optional, Protocol, + TypeGuard, TypeVar, Union, - runtime_checkable, ) from collections.abc import Callable, Hashable, Mapping, Sequence @@ -41,11 +41,12 @@ Shape = Sequence[int] K = TypeVar('K') -@runtime_checkable class Key(Hashable, Protocol): def __lt__(self: K, value: K, /) -> bool: ... +def is_key_like(x: Any) -> TypeGuard[Key]: + return hasattr(x, '__hash__') and hasattr(x, '__lt__') Path = str PathParts = tuple[Key, ...] @@ -117,16 +118,17 @@ class Out(Generic[T]): # SPMD LogicalNames = tuple[Union[str, None], ...] +AxisName = str | tuple[str, ...] | None # Maps each logical axis to physical mesh, can be either None (replicated), # one physical axis or a tuple of physical axes. -LogicalRules = Sequence[tuple[str, Union[str, tuple[str, ...], None]]] +LogicalRules = Sequence[tuple[str, AxisName]] ArrayPytree = Any # pylint: disable=invalid-name LogicalPartitionSpec = Any # pylint: disable=invalid-name LogicalPartitionSpecPytree = Any # pylint: disable=invalid-name PartitionSpecPytree = Any # pylint: disable=invalid-name -Sharding = tuple[Optional[str], ...] +Sharding = tuple[AxisName, ...] A = TypeVar('A') @@ -158,4 +160,4 @@ class Missing: pass -MISSING = Missing() \ No newline at end of file +MISSING = Missing() diff --git a/flax/version.py b/flax/version.py index dd9609eb..bdba069c 100644 --- a/flax/version.py +++ b/flax/version.py @@ -13,4 +13,4 @@ # limitations under the License. """Current Flax version at head on Github.""" -__version__ = '0.9.0' +__version__ = '0.10.2' diff --git a/flaxlib/README.md b/flaxlib/README.md deleted file mode 100644 index 66910f7e..00000000 --- a/flaxlib/README.md +++ /dev/null @@ -1 +0,0 @@ -# flaxlib \ No newline at end of file diff --git a/flaxlib/flaxlib/__init__.py b/flaxlib/flaxlib/__init__.py deleted file mode 100644 index 435dad41..00000000 --- a/flaxlib/flaxlib/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2024 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from flaxlib.flaxlib import sum_as_string as sum_as_string diff --git a/flaxlib/.gitignore b/flaxlib_src/.gitignore similarity index 100% rename from flaxlib/.gitignore rename to flaxlib_src/.gitignore diff --git a/flaxlib/Cargo.lock b/flaxlib_src/Cargo.lock similarity index 94% rename from flaxlib/Cargo.lock rename to flaxlib_src/Cargo.lock index fb772bb9..6a6decf9 100644 --- a/flaxlib/Cargo.lock +++ b/flaxlib_src/Cargo.lock @@ -110,9 +110,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.20.3" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53bdbb96d49157e65d45cc287af5f32ffadd5f4761438b527b055fb0d4bb8233" +checksum = "a5e00b96a521718e08e03b1a622f01c8a8deb50719335de3f60b3b3950f069d8" dependencies = [ "cfg-if", "indoc", @@ -128,9 +128,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.20.3" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "deaa5745de3f5231ce10517a1f5dd97d53e5a2fd77aa6b5842292085831d48d7" +checksum = "7883df5835fafdad87c0d888b266c8ec0f4c9ca48a5bed6bbb592e8dedee1b50" dependencies = [ "once_cell", "target-lexicon", @@ -138,9 +138,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.20.3" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62b42531d03e08d4ef1f6e85a2ed422eb678b8cd62b762e53891c05faf0d4afa" +checksum = "01be5843dc60b916ab4dad1dca6d20b9b4e6ddc8e15f50c47fe6d85f1fb97403" dependencies = [ "libc", "pyo3-build-config", @@ -148,9 +148,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.20.3" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7305c720fa01b8055ec95e484a6eca7a83c841267f0dd5280f0c8b8551d2c158" +checksum = "77b34069fc0682e11b31dbd10321cbf94808394c56fd996796ce45217dfac53c" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -160,9 +160,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.20.3" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c7e9b68bb9c3149c5b0cade5d07f953d6d125eb4337723c4ccdb665f1f96185" +checksum = "08260721f32db5e1a5beae69a55553f56b99bd0e1c3e6e0a5e8851a9d0f5a85c" dependencies = [ "heck", "proc-macro2", diff --git a/flaxlib/Cargo.toml b/flaxlib_src/Cargo.toml similarity index 93% rename from flaxlib/Cargo.toml rename to flaxlib_src/Cargo.toml index b6445579..80e95152 100644 --- a/flaxlib/Cargo.toml +++ b/flaxlib_src/Cargo.toml @@ -9,4 +9,4 @@ name = "flaxlib" crate-type = ["cdylib"] [dependencies] -pyo3 = "0.20.3" +pyo3 = "0.21.2" diff --git a/flaxlib/LICENSE b/flaxlib_src/LICENSE similarity index 100% rename from flaxlib/LICENSE rename to flaxlib_src/LICENSE diff --git a/flaxlib_src/README.md b/flaxlib_src/README.md new file mode 100644 index 00000000..29b4a837 --- /dev/null +++ b/flaxlib_src/README.md @@ -0,0 +1,34 @@ +# flaxlib + +## Build flaxlib from source + +Install necessary dependencies to build the C++ based package. + +```shell +pip install meson-python ninja build +``` + +Clone the Flax repository, navigate to the flaxlib source directory. + +```shell +git clone git@github.com:google/flax.git +cd flax/flaxlib_src +``` + +Configure the build. + +```shell +mkdir -p subprojects +meson wrap install robin-map +meson wrap install nanobind +meson setup builddir +``` + +Compile the code. 
You'll need to run this repeatedly if you modify the source +code. Note that the actual wheel name will differ depending on your system. + +```shell +meson compile -C builddir +python -m build . -w +pip install dist/flaxlib-0.0.1-cp311-cp311-macosx_14_0_arm64.whl --force-reinstall +``` diff --git a/flaxlib/flaxlib/flaxlib.pyi b/flaxlib_src/flaxlib.pyi similarity index 100% rename from flaxlib/flaxlib/flaxlib.pyi rename to flaxlib_src/flaxlib.pyi diff --git a/flaxlib_src/meson.build b/flaxlib_src/meson.build new file mode 100644 index 00000000..0d78d943 --- /dev/null +++ b/flaxlib_src/meson.build @@ -0,0 +1,14 @@ +project( + 'flaxlib', + 'cpp', + version: '0.0.1', + default_options: ['cpp_std=c++17'], +) +py = import('python').find_installation() +nanobind_dep = dependency('nanobind', static: true) +py.extension_module( + 'flaxlib', + sources: ['src/lib.cc'], + dependencies: [nanobind_dep], + install: true, +) \ No newline at end of file diff --git a/flaxlib/pyproject.toml b/flaxlib_src/pyproject.toml similarity index 67% rename from flaxlib/pyproject.toml rename to flaxlib_src/pyproject.toml index 993b9703..0afc7699 100644 --- a/flaxlib/pyproject.toml +++ b/flaxlib_src/pyproject.toml @@ -1,12 +1,12 @@ [build-system] -requires = ["maturin>=1.7,<2.0"] -build-backend = "maturin" +requires = ['meson-python'] +build-backend = 'mesonpy' [project] name = "flaxlib" requires-python = ">=3.10" classifiers = [ - "Programming Language :: Rust", + "Programming Language :: C++", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] @@ -15,5 +15,3 @@ dynamic = ["version"] tests = [ "pytest", ] -[tool.maturin] -features = ["pyo3/extension-module"] diff --git a/flaxlib_src/src/lib.cc b/flaxlib_src/src/lib.cc new file mode 100644 index 00000000..c7145881 --- /dev/null +++ b/flaxlib_src/src/lib.cc @@ -0,0 +1,14 @@ +#include + +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" + +namespace flaxlib { +std::string sum_as_string(int a, int b) { + return std::to_string(a + b); +} + +NB_MODULE(flaxlib, m) { + m.def("sum_as_string", &sum_as_string); +} +} // namespace flaxlib \ No newline at end of file diff --git a/flaxlib/src/lib.rs b/flaxlib_src/src/lib.rs similarity index 93% rename from flaxlib/src/lib.rs rename to flaxlib_src/src/lib.rs index 81180d86..cadab2ef 100644 --- a/flaxlib/src/lib.rs +++ b/flaxlib_src/src/lib.rs @@ -22,7 +22,7 @@ fn sum_as_string(a: usize, b: usize) -> PyResult { /// A Python module implemented in Rust. 
#[pymodule] -fn flaxlib(_py: Python, m: &PyModule) -> PyResult<()> { +fn flaxlib(_py: Python, m: &Bound) -> PyResult<()> { m.add_function(wrap_pyfunction!(sum_as_string, m)?)?; Ok(()) } diff --git a/flaxlib/uv.lock b/flaxlib_src/uv.lock similarity index 100% rename from flaxlib/uv.lock rename to flaxlib_src/uv.lock diff --git a/pyproject.toml b/pyproject.toml index 0b21a5c2..658b2f15 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ "rich>=11.1", "typing_extensions>=4.2", "PyYAML>=5.4.1", + "treescope>=0.1.2", ] classifiers = [ "Development Status :: 3 - Alpha", @@ -62,6 +63,7 @@ testing = [ "tensorflow>=2.12.0", # to fix Numpy np.bool8 deprecation error "torch", "treescope>=0.1.1; python_version>='3.10'", + "cloudpickle>=3.0.0", ] docs = [ "sphinx>=3.3.1", @@ -75,11 +77,9 @@ docs = [ "sphinx-design", "jupytext==1.13.8", "dm-haiku", - # Need to pin docutils to 0.16 to make bulleted lists appear correctly on # ReadTheDocs: https://stackoverflow.com/a/68008428 "docutils==0.16", - # The next packages are for notebooks. "matplotlib", "scikit-learn", @@ -87,6 +87,8 @@ docs = [ "ml_collections", # notebooks "einops", + "kagglehub>=0.3.3", + "ipywidgets>=8.1.5", ] dev = [ "pre-commit>=3.8.0", @@ -181,6 +183,8 @@ filterwarnings = [ "ignore:.*invalid value encountered in cast.*:RuntimeWarning", # RuntimeWarning: divide by zero encountered in equal/not_equal "ignore:.*divide by zero encountered in.*:RuntimeWarning", + # DeprecationWarning: numpy.core is deprecated + "ignore:.*numpy.core is deprecated.*:DeprecationWarning", ] [tool.coverage.report] @@ -221,3 +225,7 @@ unfixable = [] indent-style = "space" quote-style = "single" + +[tool.uv] +# Ignore uv.lock and always upgrade the package to the latest +upgrade-package = ["jax", "jaxlib", "orbax-checkpoint"] diff --git a/tests/core/core_lift_test.py b/tests/core/core_lift_test.py index 27f74fe3..5ff7e3e6 100644 --- a/tests/core/core_lift_test.py +++ b/tests/core/core_lift_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import operator - import jax import numpy as np from absl.testing import absltest @@ -23,10 +21,6 @@ from flax import errors from flax.core import FrozenDict, apply, copy, init, lift, nn -# TODO(jakevdp): use jax.debug_key_reuse directly once min jax version is 0.4.26 -jax_debug_key_reuse = (jax.debug_key_reuse if hasattr(jax, 'debug_key_reuse') - else jax.enable_key_reuse_checks) - class LiftTest(absltest.TestCase): def test_aliasing(self): @@ -128,11 +122,6 @@ def f(scope, x): np.testing.assert_allclose(y_t, jnp.ones_like(x)) def test_while_loop(self): - def clone(key): - if hasattr(jax.random, "clone"): - # jax v0.4.26+ - return jax.random.clone(key) - return key def f(scope, x): key_zero = random.key(0) @@ -140,7 +129,7 @@ def f(scope, x): scope.param('inc', lambda _: 1) scope.put_variable('state', 'acc', 0) scope.put_variable('state', 'rng_params', key_zero) - scope.put_variable('state', 'rng_loop', clone(key_zero)) + scope.put_variable('state', 'rng_loop', jax.random.clone(key_zero)) def cond_fn(scope, c): acc = scope.get_variable('state', 'acc') @@ -179,12 +168,11 @@ def body_fn(scope, c): ) self.assertEqual(vars['state']['acc'], x) self.assertEqual(c, 2 * x) - np.testing.assert_array_equal( + self.assertEqual( vars['state']['rng_params'][0], vars['state']['rng_params'][1] ) - with jax_debug_key_reuse(False): - np.testing.assert_array_compare( - operator.__ne__, + with jax.debug_key_reuse(False): + self.assertNotEqual( vars['state']['rng_loop'][0], vars['state']['rng_loop'][1], ) diff --git a/tests/flaxlib_test.py b/tests/flaxlib_test.py index dc36f6a2..c23f70ba 100644 --- a/tests/flaxlib_test.py +++ b/tests/flaxlib_test.py @@ -12,11 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from absl.testing import absltest -import flaxlib +# TODO: Re-enable this test after setting up CI build for flaxlib CC. -class TestFlaxlib(absltest.TestCase): +# from absl.testing import absltest +# import flaxlib - def test_flaxlib(self): - self.assertEqual(flaxlib.sum_as_string(1, 2), '3') + +# class TestFlaxlib(absltest.TestCase): + +# def test_flaxlib(self): +# self.assertEqual(flaxlib.sum_as_string(1, 2), '3') diff --git a/tests/linen/linen_activation_test.py b/tests/linen/linen_activation_test.py index 5f8369c2..6d4d0eb4 100644 --- a/tests/linen/linen_activation_test.py +++ b/tests/linen/linen_activation_test.py @@ -14,13 +14,13 @@ """Tests for flax.linen.activation.""" +from absl.testing import absltest +from flax import linen as nn import jax +from jax import random import jax.numpy as jnp import numpy as np -from absl.testing import absltest -from jax import random -from flax import linen as nn # Parse absl flags test_srcdir and test_tmpdir. 
jax.config.parse_flags_with_absl() @@ -44,32 +44,6 @@ def test_prelu(self): np.testing.assert_array_almost_equal(expected_y, y) np.testing.assert_array_equal(init_negative_slope, expected_negative_slope) - def test_geglu(self): - rng = random.key(0) - x = jnp.array([[0.123,0.234], [0.456,0.789]]) - act = nn.GeGLU() - expected_result = jnp.array([[0.00024275, -0.00208032], - [0.00336634, -0.02307648]]) - y, _ = act.init_with_output(rng, x) - np.testing.assert_array_almost_equal(y, expected_result) - - def test_geglu_with_dim_expansion(self): - rng = random.key(0) - x = jnp.array([[0.123,0.234], [0.456,0.789]]) - act = nn.GeGLU(3) - expected_result = jnp.array([[-0.02157649, -0.00018928, -0.01176354], - [-0.08777858, 0.00258885, -0.18744925]]) - y, _ = act.init_with_output(rng, x) - np.testing.assert_array_almost_equal(y, expected_result) - - def test_geglu_with_dim_contraction(self): - rng = random.key(0) - x = jnp.array([[0.123,0.234], [0.456,0.789]]) - act = nn.GeGLU(1) - expected_result = jnp.array([[0.00224223], [0.0307451 ]]) - y, _ = act.init_with_output(rng, x) - np.testing.assert_array_almost_equal(y, expected_result) - if __name__ == '__main__': absltest.main() diff --git a/tests/linen/linen_attention_test.py b/tests/linen/linen_attention_test.py index 2e7a6789..661c0474 100644 --- a/tests/linen/linen_attention_test.py +++ b/tests/linen/linen_attention_test.py @@ -113,10 +113,7 @@ def test_multihead_self_attention_w_dropout(self): def test_multihead_self_attention_explicit_dropout(self): def clone(key): - if hasattr(jax.random, "clone"): - # JAX v0.4.26+ - return jax.tree.map(jax.random.clone, key) - return key + return jax.tree.map(jax.random.clone, key) class Foo(nn.Module): attention_kwargs: dict diff --git a/tests/linen/linen_module_test.py b/tests/linen/linen_module_test.py index fdee1443..c70a08e8 100644 --- a/tests/linen/linen_module_test.py +++ b/tests/linen/linen_module_test.py @@ -1783,7 +1783,7 @@ class MyModule2(nn.Module): submodule: MyComponent2[jnp.ndarray] def test_jit_rng_equivalance(self): - model = nn.Dense(1, use_bias=False) + model = nn.fold_rngs(nn.Dense)(1, use_bias=False) jit_model = nn.jit(nn.Dense)(1, use_bias=False) param = model.init(random.key(0), np.ones((1, 1)))['params']['kernel'] param_2 = jit_model.init(random.key(0), np.ones((1, 1)))['params']['kernel'] diff --git a/tests/linen/linen_test.py b/tests/linen/linen_test.py index ff31efdd..96539d89 100644 --- a/tests/linen/linen_test.py +++ b/tests/linen/linen_test.py @@ -1040,17 +1040,12 @@ def test_dropout_broadcast( self.assertTrue(slice_fn(out, i).sum() in (0, summed_total)) def test_dropout_manual_rng(self): - def clone(key): - if hasattr(jax.random, 'clone'): - # JAX v0.4.26+ - return jax.random.clone(key) - return key class Foo(nn.Module): @nn.compact def __call__(self, x): key = self.make_rng('dropout') x1 = nn.Dropout(rate=0.5, deterministic=False)(x, rng=key) - x2 = nn.Dropout(rate=0.5, deterministic=False)(x, rng=clone(key)) + x2 = nn.Dropout(rate=0.5, deterministic=False)(x, rng=jax.random.clone(key)) return x1, x2 module = Foo() diff --git a/tests/linen/linen_transforms_test.py b/tests/linen/linen_transforms_test.py index 7720aa7e..d5634a01 100644 --- a/tests/linen/linen_transforms_test.py +++ b/tests/linen/linen_transforms_test.py @@ -32,10 +32,6 @@ import jax.numpy as jnp import numpy as np -# TODO(jakevdp): use jax.debug_key_reuse directly once min jax version is 0.4.26 -jax_debug_key_reuse = (jax.debug_key_reuse if hasattr(jax, 'debug_key_reuse') - else jax.enable_key_reuse_checks) - # 
Parse absl flags test_srcdir and test_tmpdir. jax.config.parse_flags_with_absl() @@ -90,6 +86,10 @@ def __call__(self, inputs): class TransformTest(parameterized.TestCase): + def assert_keys_equal(self, key1, key2): + self.assertEqual(key1.dtype, key2.dtype) + np.testing.assert_array_equal(random.key_data(key1), random.key_data(key2)) + def test_jit(self): key1, key2 = random.split(random.key(3), 2) x = random.uniform(key1, (4, 4)) @@ -114,6 +114,19 @@ def test_jit_decorated(self): self.assertTrue(np.all(y1 == y2)) + def test_jit_init_fn(self): + class Foo(nn.Module): + @nn.compact + def __call__(self, x): + return nn.Dense(2)(x) + + @nn.jit + def init_with_output(self, rngs, *args, **kwargs): + return super().init_with_output(rngs, *args, **kwargs) + + Foo().init_with_output(random.key(0), jnp.ones((2, 3))) + + def test_remat(self): key1, key2 = random.split(random.key(3), 2) x = random.uniform(key1, (4, 4)) @@ -1581,6 +1594,7 @@ def setup(self): def helper(self, x, ms): return ms[0](x) + ms[1](x) + @nn.fold_rngs def __call__(self, x): return self.helper(x, self.inners) @@ -1589,7 +1603,6 @@ class JitFoo(nn.Module): def setup(self): self.inners = [nn.Dense(2), nn.Dense(2)] - @nn.jit def helper(self, x, ms): return ms[0](x) + ms[1](x) @@ -1745,6 +1758,7 @@ def setup(self): def setup_helper(self): self.b = nn.Dense(2) + @nn.fold_rngs def __call__(self, x): return self.b(self.a(x)) @@ -1824,6 +1838,61 @@ def __call__(self, x): ) np.testing.assert_array_equal(y, 2) + def test_fold_rngs(self): + class Foo(nn.Module): + + def __call__(self, use_jit: bool): + def f(foo: Foo): + return foo.make_rng('params') + + if use_jit: + key = nn.jit(f)(self) + else: + key = nn.fold_rngs(f)(self) + + return key + + foo = Foo() + key_jit = foo.apply({}, True, rngs={'params': random.key(0)}) + key_fold_rngs = foo.apply({}, False, rngs={'params': random.key(0)}) + + self.assert_keys_equal(key_jit, key_fold_rngs) + + def test_same_key(self): + + class Block(nn.Module): + + @nn.jit + @nn.compact + def __call__(self, carry, inputs): + # dump_rng_info(self) + key = self.make_rng('params') + # y = jax.random.uniform(self.make_rng('params'), (2,)) + return carry, key + + class Transformer(nn.Module): + + @nn.compact + def __call__(self): + num_blocks = 10 + carry, key = nn.scan( + Block, + variable_axes={'params': 0}, + split_rngs={'params': True}, + # length=num_blocks, + )()(None, jnp.arange(num_blocks)) + return key + + model = Transformer() + keys1, _ = model.init_with_output(jax.random.key(1)) + keys2, _ = model.init_with_output(jax.random.key(1)) + keys3, _ = model.init_with_output(jax.random.key(1)) + keys4, _ = model.init_with_output(jax.random.key(1)) + + self.assert_keys_equal(keys1, keys2) + self.assert_keys_equal(keys2, keys3) + self.assert_keys_equal(keys2, keys3) + def test_jit_repr_hash(self): n = 0 @@ -2141,7 +2210,7 @@ def body_fn(mdl, c): self.assertTrue( jnp.equal(vars['state']['rng_params'][0], vars['state']['rng_params'][1]) ) - with jax_debug_key_reuse(False): + with jax.debug_key_reuse(False): self.assertFalse( jnp.equal( vars['state']['rng_loop'][0], diff --git a/tests/linen/partitioning_test.py b/tests/linen/partitioning_test.py index 1aa9d388..56263e1c 100644 --- a/tests/linen/partitioning_test.py +++ b/tests/linen/partitioning_test.py @@ -36,9 +36,9 @@ class PartitioningTest(parameterized.TestCase): def test_axis_rules(self): - self.assertEqual(partitioning._axis_rules.rules, ()) + self.assertEqual(nn.spmd.get_logical_axis_rules(), ()) partitioning.set_axis_rules(AXIS_RULES_1) - 
self.assertEqual(partitioning._axis_rules.rules, AXIS_RULES_1) + self.assertEqual(nn.spmd.get_logical_axis_rules(), AXIS_RULES_1) self.assertEqual(partitioning.get_axis_rules(), AXIS_RULES_1) partitioning.set_axis_rules(()) diff --git a/tests/nnx/bridge/module_test.py b/tests/nnx/bridge/module_test.py deleted file mode 100644 index 4c171026..00000000 --- a/tests/nnx/bridge/module_test.py +++ /dev/null @@ -1,138 +0,0 @@ -# Copyright 2024 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import dataclasses - -from absl.testing import absltest -import jax -import jax.numpy as jnp - -from flax import nnx -from flax.nnx import bridge - - -class TestCompatModule(absltest.TestCase): - def test_compact_basic(self): - class Linear(bridge.Module): - dout: int - - def setup(self): - self.count = 0 - - def __call__(self, x): - self.count += 1 - if not hasattr(self, 'w'): - assert self.scope is not None - rngs = self.scope.rngs - self.w = nnx.Param( - jax.random.uniform(rngs(), (x.shape[-1], self.dout)) - ) - self.b = nnx.Param(jnp.zeros((self.dout,))) - return x @ self.w + self.b[None] - - @dataclasses.dataclass - class Foo(bridge.Module): - dout: int - - @bridge.compact - def __call__(self, x): - din = x.shape[-1] - self.linear = Linear(self.dout) - x = self.linear(x) - return x - - foo = Foo(5) - x = jnp.ones((3, 2)) - rngs = nnx.Rngs(0) - - foo._set_scope(bridge.Scope(rngs)) - y = foo(x) - foo._set_scope(None) - - assert y.shape == (3, 5) - assert hasattr(foo, 'Linear_0') - - assert foo.linear is foo.Linear_0 - assert foo.linear.count == 1 - assert rngs.default.count.value == 1 - - foo._set_scope(bridge.Scope(rngs)) - y = foo(x) - foo._set_scope(None) - - assert foo.linear is foo.Linear_0 - assert foo.linear.count == 2 - - # Rngs not called again - assert rngs.default.count.value == 1 - - def test_compact_parent_none(self): - class Foo(bridge.Module): - pass - - class Bar(bridge.Module): - @bridge.compact - def __call__(self): - return Foo().scope - - rngs = nnx.Rngs(0) - bar = Bar() - bar._set_scope(bridge.Scope(rngs)) - scope = bar() - bar._set_scope(None) - assert bar.scope is None - assert scope.rngs is rngs - - class Baz(bridge.Module): - @bridge.compact - def __call__(self): - return Foo(parent=None).scope - - baz = Baz() - baz._set_scope(bridge.Scope(rngs)) - scope = baz() - baz._set_scope(None) - assert scope is None - - def test_name(self): - class Foo(bridge.Module): - dout: int - - def __call__(self, x): - if not hasattr(self, 'w'): - assert self.scope is not None - rngs = self.scope.rngs - self.w = nnx.Param( - jax.random.uniform(rngs(), (x.shape[-1], self.dout)) - ) - return x @ self.w - - class Bar(bridge.Module): - @bridge.compact - def __call__(self, x): - return Foo(5, name='foo')(x) - - bar = Bar() - x = jnp.ones((1, 2)) - rngs = nnx.Rngs(0) - bar._set_scope(bridge.Scope(rngs)) - y = bar(x) - bar._set_scope(None) - assert y.shape == (1, 5) - - assert hasattr(bar, 'foo') - assert isinstance(bar.foo, Foo) - -if __name__ == '__main__': - absltest.main() diff --git 
a/tests/nnx/containers_test.py b/tests/nnx/containers_test.py index 97785e76..92345abc 100644 --- a/tests/nnx/containers_test.py +++ b/tests/nnx/containers_test.py @@ -21,15 +21,15 @@ class TestContainers(absltest.TestCase): def test_unbox(self): x = nnx.Param( 1, - get_value_hooks=[lambda c, x: x + 1, lambda c, x: x * 2], # type: ignore + on_get_value=lambda c, x: x + 3, # type: ignore ) assert x.value == 4 - def test_box(self): + def test_on_set_value(self): x: nnx.Param[int] = nnx.Param( 1, # type: ignore - set_value_hooks=[lambda c, x: x + 1, lambda c, x: x * 2], # type: ignore + on_set_value=lambda c, x: x + 7, # type: ignore ) x.value = 5 @@ -38,9 +38,7 @@ def test_box(self): def test_module_unbox(self): class Foo(nnx.Module): def __init__(self) -> None: - self.x = nnx.Param( - 1, get_value_hooks=[lambda c, x: x + 1, lambda c, x: x * 2] - ) + self.x = nnx.Param(1, on_get_value=lambda c, x: x + 3) module = Foo() @@ -51,7 +49,8 @@ def test_module_box(self): class Foo(nnx.Module): def __init__(self) -> None: self.x = nnx.Param( - 1, set_value_hooks=[lambda c, x: x + 1, lambda c, x: x * 2] + 1, + on_set_value=lambda c, x: x + 7, # type: ignore ) module = Foo() diff --git a/tests/nnx/graph_utils_test.py b/tests/nnx/graph_utils_test.py index bfbb7046..a7bbf178 100644 --- a/tests/nnx/graph_utils_test.py +++ b/tests/nnx/graph_utils_test.py @@ -82,6 +82,17 @@ def test_unflatten(self): assert g[0] is g[2] + def test_unflatten_pure_dict(self): + a = Dict(a=1, b=nnx.Param(2)) + g = List([a, 3, a, nnx.Param(4)]) + + graphdef, state = nnx.split(g) + pure_state = state.to_pure_dict() + + g = nnx.merge(graphdef, pure_state) + + assert g[0] is g[2] + def test_unflatten_pytree(self): a = {'a': 1, 'b': nnx.Param(2)} g = [a, 3, a, nnx.Param(4)] @@ -107,7 +118,20 @@ def test_update_dynamic(self): graphdef, state = nnx.split(g) state[0]['b'].value = 3 - nnx.graph.update(g, state) + nnx.update(g, state) + + assert g[0]['b'].value == 3 + assert g[2]['b'].value == 3 + + def test_update_from_pure_dict(self): + a = {'a': 1, 'b': nnx.Param(2)} + g = [a, 3, a, nnx.Param(4)] + + graphdef, state = nnx.split(g) + pure_state = state.to_pure_dict() + + pure_state[0]['b'] = 3 + nnx.update(g, pure_state) assert g[0]['b'].value == 3 assert g[2]['b'].value == 3 @@ -279,7 +303,7 @@ def __init__(self): assert 'tree' in state assert 'a' in state.tree - assert graphdef.subgraphs['tree'].type is nnx.graph.PytreeType + assert graphdef.attributes[0].value.type is nnx.graph.GenericPytree m2 = nnx.merge(graphdef, state) diff --git a/tests/nnx/integration_test.py b/tests/nnx/integration_test.py index 1742e379..7b572f4b 100644 --- a/tests/nnx/integration_test.py +++ b/tests/nnx/integration_test.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
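The pure-dict changes above (state.to_pure_dict(), plus nnx.merge and nnx.update accepting plain dicts) reduce to a short workflow. A minimal sketch, separate from the patch itself and using only the calls these tests exercise; the toy objects are hypothetical:

from flax import nnx

# Shared sub-structure, mirroring test_update_from_pure_dict above.
a = {'a': 1, 'b': nnx.Param(2)}
g = [a, 3, a, nnx.Param(4)]

graphdef, state = nnx.split(g)
pure_state = state.to_pure_dict()  # plain nested dict, easy to edit or checkpoint
pure_state[0]['b'] = 3             # mutate a leaf directly
nnx.update(g, pure_state)          # push the edit back into the live objects
assert g[0]['b'].value == 3 and g[2]['b'].value == 3

# A pure dict can also be merged back into a fresh object graph.
g2 = nnx.merge(graphdef, pure_state)
assert g2[0] is g2[2]              # reference sharing is preserved, as the tests assert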
+import tempfile import typing as tp from absl.testing import absltest import jax import jax.numpy as jnp import numpy as np +import orbax.checkpoint as ocp from flax import nnx @@ -259,6 +261,44 @@ def __call__(self, x): assert 'y' in intermediates + def test_replace_by_pure_dict(self): + class MLPs(nnx.Module): + def __init__(self, dim, rngs: nnx.Rngs): + self.layers = [] + for _ in range(4): + self.layers.append(nnx.Linear(dim, dim, rngs=rngs, use_bias=False)) + + def __call__(self, x): + for layer in self.layers: + x = layer(x) + return x + + model = MLPs(4, rngs=nnx.Rngs(0)) + x = jax.random.normal(jax.random.key(42), (3, 4)) + assert model(x).shape == (3, 4) + + _, state = nnx.split(model) + pure_dict_state = state.to_pure_dict() + nnx.display(pure_dict_state) + + with tempfile.TemporaryDirectory() as tmpdir: + ckpt_dir = ocp.test_utils.erase_and_create_empty( + tmpdir + '/my-checkpoints/' + ) + checkpointer = ocp.StandardCheckpointer() + # checkpointer.save(ckpt_dir / 'state', state) + checkpointer.save(ckpt_dir / 'pure_dict', pure_dict_state) + + # Restore as a pure dictionary. + restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict') + nnx.display(restored_pure_dict) + + abstract_model = nnx.eval_shape(lambda: MLPs(4, rngs=nnx.Rngs(0))) + graphdef, abstract_state = nnx.split(abstract_model) + abstract_state.replace_by_pure_dict(restored_pure_dict) + model = nnx.merge(graphdef, abstract_state) + assert model(x).shape == (3, 4) # The model still works! + if __name__ == '__main__': absltest.main() diff --git a/tests/nnx/module_test.py b/tests/nnx/module_test.py index 2aff69a1..ce65186d 100644 --- a/tests/nnx/module_test.py +++ b/tests/nnx/module_test.py @@ -14,9 +14,12 @@ from copy import deepcopy import dataclasses +import pickle +import tempfile from typing import TypeVar from absl.testing import absltest +import cloudpickle from flax import nnx, errors import jax import jax.numpy as jnp @@ -37,7 +40,7 @@ def __setitem__(self, idx, value): class Dict(nnx.Module): def __init__(self, *args, **kwargs): - self.items = dict(*args, **kwargs) + vars(self)['items'] = dict(*args, **kwargs) def __getitem__(self, key): return vars(self)['items'][key] @@ -45,6 +48,12 @@ def __getitem__(self, key): def __setitem__(self, key, value): vars(self)['items'][key] = value + def __setattr__(self, key, value): + if key == 'items': + object.__setattr__(self, key, value) + else: + vars(self)['items'][key] = value + def __getattr__(self, key): attrs = vars(self) if 'items' not in attrs: @@ -62,6 +71,7 @@ class Foo(nnx.Module): ... 
assert hasattr(foo, '_object__state') + @absltest.skip("Context checking doesn't work yet with stackless") def test_trace_level(self): m = Dict(a=nnx.Param(1)) @@ -512,6 +522,34 @@ def __init__(self, din, dout, *, rngs: nnx.Rngs): raise_if_not_found=False, ) + def test_cloud_pickle(self): + class Model(nnx.Module): + def __init__(self, din, dmid, dout, rngs: nnx.Rngs): + self.linear = nnx.Linear(din, dmid, rngs=rngs) + self.bn = nnx.BatchNorm(dmid, rngs=rngs) + self.dropout = nnx.Dropout(0.1, rngs=rngs) + self.linear_out = nnx.Linear(dmid, dout, rngs=rngs) + + def __call__(self, x): + x = nnx.relu(self.dropout(self.bn(self.linear(x)))) + return self.linear_out(x) + + model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization + model.eval() + + y1 = model(jnp.ones((5, 2))) + with tempfile.TemporaryDirectory() as tmpdir: + path = f'{tmpdir}/model.pkl' + with open(path, 'wb') as f: + cloudpickle.dump(model, f) + del model + with open(path, 'rb') as f: + model = pickle.load(f) + + self.assertIsInstance(model, Model) + y2 = model(jnp.ones((5, 2))) + np.testing.assert_allclose(y1, y2) + class TestModulePytree: def test_tree_map(self): diff --git a/tests/nnx/nn/recurrent_test.py b/tests/nnx/nn/recurrent_test.py new file mode 100644 index 00000000..b724b69d --- /dev/null +++ b/tests/nnx/nn/recurrent_test.py @@ -0,0 +1,543 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
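test_cloud_pickle above checks that a live nnx module can be dumped with cloudpickle and loaded back with plain pickle. A condensed sketch of that round trip, with a bare nnx.Linear standing in (hypothetically) for the test's Model class:

import pickle
import tempfile

import cloudpickle
import jax.numpy as jnp
import numpy as np
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))  # eagerly initialized, like Model in the test
y1 = model(jnp.ones((5, 2)))

with tempfile.TemporaryDirectory() as tmpdir:
  path = f'{tmpdir}/model.pkl'
  with open(path, 'wb') as f:
    cloudpickle.dump(model, f)  # cloudpickle serializes the module together with its params
  with open(path, 'rb') as f:
    restored = pickle.load(f)   # standard pickle is enough to load it back

np.testing.assert_allclose(restored(jnp.ones((5, 2))), y1)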
+ +import jax, jax.numpy as jnp +from jax import random + +from flax import linen +from flax import nnx +from flax.nnx.nn import initializers + +import numpy as np + +from absl.testing import absltest + +class TestLSTMCell(absltest.TestCase): + def test_basic(self): + module = nnx.LSTMCell( + in_features=3, + hidden_features=4, + rngs=nnx.Rngs(0), + ) + x = jnp.ones((2, 3)) + carry = module.initialize_carry(x.shape, module.rngs) + new_carry, y = module(carry, x) + self.assertEqual(y.shape, (2, 4)) + + def test_lstm_sequence(self): + """Test LSTMCell over a sequence of inputs.""" + module = nnx.LSTMCell( + in_features=3, + hidden_features=4, + rngs=nnx.Rngs(0), + ) + x = random.normal(random.PRNGKey(1), (5, 2, 3)) # seq_len, batch, feature + carry = module.initialize_carry(x.shape[1:], module.rngs) + outputs = [] + for t in range(x.shape[0]): + carry, y = module(carry, x[t]) + outputs.append(y) + outputs = jnp.stack(outputs) + self.assertEqual(outputs.shape, (5, 2, 4)) + + def test_lstm_with_different_dtypes(self): + """Test LSTMCell with different data types.""" + module = nnx.LSTMCell( + in_features=3, + hidden_features=4, + dtype=jnp.bfloat16, + param_dtype=jnp.bfloat16, + rngs=nnx.Rngs(0), + ) + x = jnp.ones((2, 3), dtype=jnp.bfloat16) + carry = module.initialize_carry(x.shape, module.rngs) + new_carry, y = module(carry, x) + self.assertEqual(y.dtype, jnp.bfloat16) + self.assertEqual(y.shape, (2, 4)) + + def test_lstm_with_custom_activations(self): + """Test LSTMCell with custom activation functions.""" + module = nnx.LSTMCell( + in_features=3, + hidden_features=4, + gate_fn=jax.nn.relu, + activation_fn=jax.nn.elu, + rngs=nnx.Rngs(0), + ) + x = jnp.ones((1, 3)) + carry = module.initialize_carry(x.shape, module.rngs) + new_carry, y = module(carry, x) + self.assertEqual(y.shape, (1, 4)) + + def test_lstm_initialize_carry(self): + """Test the initialize_carry method.""" + module = nnx.LSTMCell( + in_features=3, + hidden_features=4, + carry_init=initializers.ones, + rngs=nnx.Rngs(0), + ) + x_shape = (1, 3) + carry = module.initialize_carry(x_shape, module.rngs) + c, h = carry + self.assertTrue(jnp.all(c == 1.0)) + self.assertTrue(jnp.all(h == 1.0)) + self.assertEqual(c.shape, (1, 4)) + self.assertEqual(h.shape, (1, 4)) + + def test_lstm_with_variable_sequence_length(self): + """Test LSTMCell with variable sequence lengths.""" + module = nnx.LSTMCell( + in_features=3, + hidden_features=4, + rngs=nnx.Rngs(0) + ) + + # Simulate a batch with variable sequence lengths + x = jnp.array([ + [[1, 2, 3], [4, 5, 6], [0, 0, 0]], # Sequence length 2 + [[7, 8, 9], [10, 11, 12], [13, 14, 15]], # Sequence length 3 + ]) # Shape: (batch_size=2, max_seq_length=3, features=3) + + seq_lengths = jnp.array([2, 3]) # Actual lengths for each sequence + batch_size = x.shape[0] + max_seq_length = x.shape[1] + carry = module.initialize_carry((batch_size, 3), module.rngs) + outputs = [] + for t in range(max_seq_length): + input_t = x[:, t, :] + carry, y = module(carry, input_t) + outputs.append(y) + outputs = jnp.stack(outputs, axis=1) # Shape: (batch_size, max_seq_length, hidden_features) + + # Zero out outputs beyond the actual sequence lengths + mask = (jnp.arange(max_seq_length)[None, :] < seq_lengths[:, None]) + outputs = outputs * mask[:, :, None] + self.assertEqual(outputs.shape, (2, 3, 4)) + + def test_lstm_stateful(self): + """Test that LSTMCell maintains state across calls.""" + module = nnx.LSTMCell( + in_features=3, + hidden_features=4, + rngs=nnx.Rngs(0), + ) + x1 = jnp.ones((1, 3)) + x2 = jnp.ones((1, 3)) 
* 2 + carry = module.initialize_carry(x1.shape) + carry, y1 = module(carry, x1) + carry, y2 = module(carry, x2) + self.assertEqual(y1.shape, (1, 4)) + self.assertEqual(y2.shape, (1, 4)) + + def test_lstm_equivalence_with_flax_linen(self): + """Test that nnx.LSTMCell produces the same outputs as flax.linen.LSTMCell.""" + in_features = 3 + hidden_features = 4 + key = random.PRNGKey(42) + x = random.normal(key, (1, in_features)) + + # Initialize nnx.LSTMCell + rngs_nnx = nnx.Rngs(0) + module_nnx = nnx.LSTMCell( + in_features=in_features, + hidden_features=hidden_features, + rngs=rngs_nnx, + ) + carry_nnx = module_nnx.initialize_carry(x.shape, rngs_nnx) + # Initialize flax.linen.LSTMCell + module_linen = linen.LSTMCell( + features=hidden_features, + ) + carry_linen = module_linen.initialize_carry(random.PRNGKey(0), x.shape) + variables_linen = module_linen.init(random.PRNGKey(1), carry_linen, x) + + # Copy parameters from flax.linen.LSTMCell to nnx.LSTMCell + params_linen = variables_linen['params'] + # Map the parameters from linen to nnx + # Assuming the parameter names and shapes are compatible + # For a precise mapping, you might need to adjust parameter names + # Get the parameters from nnx module + nnx_params = module_nnx.__dict__ + + # Map parameters from linen to nnx + for gate in ['i', 'f', 'g', 'o']: + # Input kernels (input to gate) + if gate == 'f': + nnx_layer = getattr(module_nnx, f'if_') + else: + nnx_layer = getattr(module_nnx, f'i{gate}') + linen_params = params_linen[f'i{gate}'] + nnx_layer.kernel.value = linen_params['kernel'] + if nnx_layer.use_bias: + nnx_layer.bias.value = linen_params['bias'] + # Hidden kernels (hidden state to gate) + nnx_layer = getattr(module_nnx, f'h{gate}') + linen_params = params_linen[f'h{gate}'] + nnx_layer.kernel.value = linen_params['kernel'] + if nnx_layer.use_bias: + nnx_layer.bias.value = linen_params['bias'] + + # Run both modules + new_carry_nnx, y_nnx = module_nnx(carry_nnx, x) + new_carry_linen, y_linen = module_linen.apply(variables_linen, carry_linen, x) + + # Compare outputs + np.testing.assert_allclose(y_nnx, y_linen, atol=1e-5) + # Compare carries + for c_nnx, c_linen in zip(new_carry_nnx, new_carry_linen): + np.testing.assert_allclose(c_nnx, c_linen, atol=1e-5) + +class TestRNN(absltest.TestCase): + + def test_rnn_with_lstm_cell(self): + """Test RNN module using LSTMCell.""" + # Initialize the LSTMCell + cell = nnx.LSTMCell( + in_features=3, + hidden_features=4, + rngs=nnx.Rngs(0), + ) + + # Initialize the RNN module with the LSTMCell + rnn = nnx.RNN(cell) + + # Create input data (batch_size=2, seq_length=5, features=3) + x = jnp.ones((2, 5, 3)) + + # Initialize the carry + carry = cell.initialize_carry((2, 3), cell.rngs) + + # Run the RNN module + outputs = rnn(x, initial_carry=carry) + + self.assertEqual(outputs.shape, (2, 5, 4)) # Output features should match hidden_features + + def test_rnn_with_gru_cell(self): + """Test RNN module using GRUCell.""" + # Initialize the GRUCell + cell = nnx.GRUCell( + in_features=3, + hidden_features=4, + rngs=nnx.Rngs(1), + ) + + # Initialize the RNN module with the GRUCell + rnn = nnx.RNN(cell) + + # Create input data (batch_size=2, seq_length=5, features=3) + x = jnp.ones((2, 5, 3)) + + # Initialize the carry + carry = cell.initialize_carry((2, 3), cell.rngs) + + # Run the RNN module + outputs = rnn(x, initial_carry=carry) + + self.assertEqual(outputs.shape, (2, 5, 4)) # Output features should match hidden_features + + def test_rnn_time_major(self): + """Test RNN module with 
time_major=True.""" + # Initialize the LSTMCell + cell = nnx.LSTMCell( + in_features=3, + hidden_features=4, + rngs=nnx.Rngs(2), + ) + + # Initialize the RNN module with time_major=True + rnn = nnx.RNN(cell, time_major=True) + + # Create input data (seq_length=5, batch_size=2, features=3) + x = jnp.ones((5, 2, 3)) + + # Initialize the carry + carry = cell.initialize_carry(x.shape[1:2] + x.shape[2:], cell.rngs) + + # Run the RNN module + outputs = rnn(x, initial_carry=carry) + + self.assertEqual(outputs.shape, (5, 2, 4)) # Output features should match hidden_features + + def test_rnn_reverse(self): + """Test RNN module with reverse=True.""" + # Initialize the LSTMCell + cell = nnx.LSTMCell( + in_features=3, + hidden_features=4, + rngs=nnx.Rngs(3), + ) + + # Initialize the RNN module with reverse=True + rnn = nnx.RNN(cell, reverse=True) + + # Create input data (batch_size=2, seq_length=5, features=3) + x = jnp.tile(jnp.arange(5), (2, 1)).reshape(2, 5, 1) # Distinct values to check reversal + x = jnp.concatenate([x, x, x], axis=-1) # Shape: (2, 5, 3) + + # Run the RNN module + outputs = rnn(x) + + # Check if the outputs are in reverse order + outputs_reversed = outputs[:, ::-1, :] + # Since we used distinct input values, we can compare outputs to check reversal + # For simplicity, just check the shapes here + self.assertEqual(outputs.shape, (2, 5, 4)) + self.assertEqual(outputs_reversed.shape, (2, 5, 4)) + + def test_rnn_with_seq_lengths(self): + """Test RNN module with variable sequence lengths.""" + # Initialize the LSTMCell + cell = nnx.LSTMCell( + in_features=3, + hidden_features=4, + rngs=nnx.Rngs(4), + ) + + # Initialize the RNN module + rnn = nnx.RNN(cell, return_carry=True) + + # Create input data with padding (batch_size=2, seq_length=5, features=3) + x = jnp.array([ + [[1, 1, 1], [2, 2, 2], [3, 3, 3], [0, 0, 0], [0, 0, 0]], # Sequence length 3 + [[4, 4, 4], [5, 5, 5], [6, 6, 6], [7, 7, 7], [8, 8, 8]], # Sequence length 5 + ]) # Shape: (2, 5, 3) + + seq_lengths = jnp.array([3, 5]) # Actual lengths for each sequence + + # Initialize the carry + carry = cell.initialize_carry((2, 3), cell.rngs) + + # Run the RNN module + final_carry, outputs = rnn(x, initial_carry=carry, seq_lengths=seq_lengths) + + self.assertEqual(outputs.shape, (2, 5, 4)) + + self.assertEqual(final_carry[0].shape, (2, 4)) # c: (batch_size, hidden_features) + self.assertEqual(final_carry[1].shape, (2, 4)) # h: (batch_size, hidden_features) + + # Todo: a better test by matching the outputs with the expected values + + def test_rnn_with_keep_order(self): + """Test RNN module with reverse=True and keep_order=True.""" + # Initialize the LSTMCell + cell = nnx.LSTMCell( + in_features=3, + hidden_features=4, + rngs=nnx.Rngs(5), + ) + + # Initialize the RNN module with reverse=True and keep_order=True + rnn = nnx.RNN(cell, reverse=True, keep_order=True) + + # Create input data (batch_size=2, seq_length=5, features=3) + x = jnp.tile(jnp.arange(5), (2, 1)).reshape(2, 5, 1) # Distinct values to check reversal + x = jnp.concatenate([x, x, x], axis=-1) # Shape: (2, 5, 3) + + # Initialize the carry + carry = cell.initialize_carry((2, 3), cell.rngs) + + # Run the RNN module + outputs = rnn(x, initial_carry=carry) + + # Check if the outputs are in the original order despite processing in reverse + self.assertEqual(outputs.shape, (2, 5, 4)) + + def test_rnn_equivalence_with_flax_linen(self): + """Test that nnx.RNN produces the same outputs as flax.linen.RNN.""" + in_features = 3 + hidden_features = 4 + seq_length = 5 + batch_size = 2 
+ key = random.PRNGKey(42) + + # Create input data + x = random.normal(key, (batch_size, seq_length, in_features)) + + # Initialize nnx.LSTMCell and RNN + rngs_nnx = nnx.Rngs(0) + cell_nnx = nnx.LSTMCell( + in_features=in_features, + hidden_features=hidden_features, + rngs=rngs_nnx, + ) + rnn_nnx = nnx.RNN(cell_nnx) + + # Initialize flax.linen.LSTMCell and RNN + cell_linen = linen.LSTMCell(features=hidden_features) + rnn_linen = linen.RNN(cell_linen) + carry_linen = cell_linen.initialize_carry(random.PRNGKey(0), x[:, 0].shape) + variables_linen = rnn_linen.init(random.PRNGKey(1), x) + + # Copy parameters from flax.linen to nnx + params_linen = variables_linen['params']['cell'] + # Copy cell parameters + for gate in ['i', 'f', 'g', 'o']: + # Input kernels + if gate == 'f': + nnx_layer = getattr(cell_nnx, f'if_') + else: + nnx_layer = getattr(cell_nnx, f'i{gate}') + linen_params = params_linen[f'i{gate}'] + nnx_layer.kernel.value = linen_params['kernel'] + if nnx_layer.use_bias: + nnx_layer.bias.value = linen_params['bias'] + # Hidden kernels + nnx_layer = getattr(cell_nnx, f'h{gate}') + linen_params = params_linen[f'h{gate}'] + nnx_layer.kernel.value = linen_params['kernel'] + if nnx_layer.use_bias: + nnx_layer.bias.value = linen_params['bias'] + + # Initialize carries + carry_nnx = cell_nnx.initialize_carry((batch_size, in_features), rngs_nnx) + + # Run nnx.RNN + outputs_nnx = rnn_nnx(x, initial_carry=carry_nnx) + + # Run flax.linen.RNN + outputs_linen = rnn_linen.apply(variables_linen, x, initial_carry=carry_linen) + + # Compare outputs + np.testing.assert_allclose(outputs_nnx, outputs_linen, atol=1e-5) + + def test_rnn_with_unroll(self): + """Test RNN module with unroll parameter.""" + # Initialize the LSTMCell + cell = nnx.LSTMCell( + in_features=3, + hidden_features=4, + rngs=nnx.Rngs(6) + ) + + # Initialize the RNN module with unroll=2 + rnn = nnx.RNN(cell, unroll=2) + + # Create input data (batch_size=2, seq_length=6, features=3) + x = jnp.ones((2, 6, 3)) + + # Initialize the carry + carry = cell.initialize_carry((2, 3), cell.rngs) + + # Run the RNN module + outputs = rnn(x, initial_carry=carry) + + self.assertEqual(outputs.shape, (2, 6, 4)) # Output features should match hidden_features + + def test_rnn_with_custom_cell(self): + """Test RNN module with a custom RNN cell.""" + class CustomRNNCell(nnx.Module): + """A simple custom RNN cell.""" + + in_features: int + hidden_features: int + + def __init__(self, in_features, hidden_features, rngs): + self.in_features = in_features + self.hidden_features = hidden_features + self.rngs = rngs + self.dense = nnx.Linear( + in_features=in_features + hidden_features, + out_features=hidden_features, + rngs=rngs, + ) + + def __call__(self, carry, inputs): + h = carry + x = jnp.concatenate([inputs, h], axis=-1) + new_h = jax.nn.tanh(self.dense(x)) + return new_h, new_h + + def initialize_carry(self, input_shape, rngs): + batch_size = input_shape[0] + h = jnp.zeros((batch_size, self.hidden_features)) + return h + + @property + def num_feature_axes(self) -> int: + return 1 + + # Initialize the custom RNN cell + cell = CustomRNNCell( + in_features=3, + hidden_features=4, + rngs=nnx.Rngs(7) + ) + + # Initialize the RNN module + rnn = nnx.RNN(cell) + + # Create input data (batch_size=2, seq_length=5, features=3) + x = jnp.ones((2, 5, 3)) + + # Initialize the carry + carry = cell.initialize_carry((2, 3), cell.rngs) + + # Run the RNN module + outputs = rnn(x, initial_carry=carry) + + self.assertEqual(outputs.shape, (2, 5, 4)) # Output features should 
match hidden_features + + def test_rnn_with_different_dtypes(self): + """Test RNN module with different data types.""" + # Initialize the LSTMCell with float16 + cell = nnx.LSTMCell( + in_features=3, + hidden_features=4, + dtype=jnp.float16, + param_dtype=jnp.float16, + rngs=nnx.Rngs(8), + ) + + # Initialize the RNN module + rnn = nnx.RNN(cell) + + # Create input data (batch_size=2, seq_length=5, features=3) + x = jnp.ones((2, 5, 3), dtype=jnp.float16) + + # Initialize the carry + carry = cell.initialize_carry((2, 3), cell.rngs) + + # Run the RNN module + outputs = rnn(x, initial_carry=carry) + + self.assertEqual(outputs.dtype, jnp.float16) + self.assertEqual(outputs.shape, (2, 5, 4)) + + def test_rnn_with_variable_batch_size(self): + """Test RNN module with variable batch sizes.""" + # Initialize the LSTMCell + cell = nnx.LSTMCell( + in_features=3, + hidden_features=4, + rngs=nnx.Rngs(9), + ) + + # Initialize the RNN module + rnn = nnx.RNN(cell) + + for batch_size in [1, 2, 5]: + # Create input data (batch_size, seq_length=5, features=3) + x = jnp.ones((batch_size, 5, 3)) + + # Initialize the carry + carry = cell.initialize_carry((batch_size, 3), cell.rngs) + + # Run the RNN module + outputs = rnn(x, initial_carry=carry) + + self.assertEqual(outputs.shape, (batch_size, 5, 4)) + +if __name__ == '__main__': + absltest.main() diff --git a/tests/nnx/optimizer_test.py b/tests/nnx/optimizer_test.py index c74160cc..5c28e727 100644 --- a/tests/nnx/optimizer_test.py +++ b/tests/nnx/optimizer_test.py @@ -128,6 +128,58 @@ def nnx_jit_train_step(optimizer: nnx.Optimizer, x, y): self.assertTrue(new_loss < initial_loss) + + @parameterized.product( + module_cls=[nnx.Linear, Model], + jit_decorator=[lambda f: f, nnx.jit, jax.jit], + optimizer=[optax.lbfgs], + ) + def test_jit_linesearch(self, module_cls, jit_decorator, optimizer): + x = jax.random.normal(jax.random.key(0), (1, 2)) + y = jnp.ones((1, 4)) + model = module_cls(2, 4, rngs=nnx.Rngs(0)) + tx = optimizer( + 1e-3 + ) + state = nnx.Optimizer(model, tx) + + if jit_decorator == jax.jit: + model_static, model_state = nnx.split(state.model) + loss_fn = lambda graphdef, state, x, y: ( + (nnx.merge(graphdef, state)(x) - y) ** 2 + ).mean() + initial_loss = loss_fn(model_static, model_state, x, y) + + def jax_jit_train_step(graphdef, state, x, y): + state = nnx.merge(graphdef, state) + model_static, model_state = nnx.split(state.model) + grads = jax.grad(loss_fn, argnums=1)(model_static, model_state, x, y) + state.update(grads, grad = grads, value = initial_loss, value_fn = lambda state: loss_fn(model_static, state, x, y)) + return nnx.split(state) + + graphdef, state = jit_decorator(jax_jit_train_step)( + *nnx.split(state), x, y + ) + state = nnx.merge(graphdef, state) + new_loss = loss_fn(*nnx.split(state.model), x, y) + + else: + graphdef = nnx.graphdef(model) + loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean() + + loss_fn_split = lambda state: loss_fn(nnx.merge(graphdef, state), x, y) + + initial_loss = loss_fn(state.model, x, y) + + def nnx_jit_train_step(optimizer: nnx.Optimizer, x, y): + grads = nnx.grad(loss_fn)(optimizer.model, x, y) + optimizer.update(grads, grad = grads, value = initial_loss, value_fn = loss_fn_split) + + jit_decorator(nnx_jit_train_step)(state, x, y) + new_loss = loss_fn(state.model, x, y) + + self.assertTrue(new_loss < initial_loss) + @parameterized.product( module_cls=[nnx.Linear, Model], optimizer=[optax.sgd, optax.adam], @@ -203,6 +255,55 @@ def test_wrt_update(self, variable): ) ) + @parameterized.parameters( + 
{'variable': nnx.Param}, + #{'variable': nnx.LoRAParam}, + {'variable': (nnx.Param, nnx.LoRAParam)}, + ) + def test_wrt_update_linesearch(self, variable): + in_features = 4 + out_features = 10 + model = nnx.LoRA( + in_features=in_features, + lora_rank=2, + out_features=out_features, + base_module=Model( + in_features=in_features, out_features=out_features, rngs=nnx.Rngs(0) + ), + rngs=nnx.Rngs(1), + ) + state = nnx.Optimizer(model, optax.lbfgs(), wrt=variable) + prev_variables, prev_other_variables = nnx.state(model, variable, ...) + + x = jnp.ones((1, 4)) + y = jnp.ones((1, 10)) + loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean() + + grads = nnx.grad(loss_fn, argnums=nnx.DiffState(0, variable))( + state.model, x, y + ) + initial_loss = loss_fn(model, x, y) + graphdef = nnx.graphdef(model) + loss_fn_split = lambda state: loss_fn(nnx.merge(graphdef, state), x, y) + + state.update(grads, grad=grads, value_fn = loss_fn_split, value = initial_loss) + self.assertTrue(loss_fn(model, x, y) < initial_loss) + + # make sure only the Variable's filtered in `wrt` are changed, and the others are unchanged + variables, other_variables = nnx.state(model, variable, ...) + self.assertTrue( + jax.tree.all( + jax.tree.map(lambda x, y: (x != y).all(), prev_variables, variables) + ) + ) + if other_variables: + self.assertTrue( + jax.tree.all( + jax.tree.map( + lambda x, y: (x == y).all(), prev_other_variables, other_variables + ) + ) + ) if __name__ == '__main__': absltest.main() diff --git a/tests/nnx/rngs_test.py b/tests/nnx/rngs_test.py index eeb65cca..d3eb2197 100644 --- a/tests/nnx/rngs_test.py +++ b/tests/nnx/rngs_test.py @@ -53,6 +53,7 @@ def test_rng_stream(self): self.assertIs(rngs.params.key.value, key0) self.assertFalse(jnp.allclose(key1, key2)) + @absltest.skip("Context checking doesn't work yet with stackless") def test_rng_trace_level_constraints(self): rngs = nnx.Rngs(0) diff --git a/tests/nnx/spmd_test.py b/tests/nnx/spmd_test.py index 6a202e81..2372fbad 100644 --- a/tests/nnx/spmd_test.py +++ b/tests/nnx/spmd_test.py @@ -13,13 +13,13 @@ # limitations under the License. from absl.testing import absltest +import flax +from flax import nnx import jax -import jax.numpy as jnp -import optax from jax.experimental import mesh_utils +import jax.numpy as jnp from jax.sharding import Mesh, PartitionSpec - -from flax import nnx +import optax class TestSPMD(absltest.TestCase): @@ -112,19 +112,20 @@ class MLP(nnx.Module): ) def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear( - 3, - 3, - kernel_init=nnx.with_metadata( - nnx.initializers.lecun_normal(), sharding=('din', 'dout'), - add_axis_hooks=lambda _, idx, name: kadds.append((idx, name)), - remove_axis_hooks=lambda _, idx, name: kremoves.append((idx, name)), - ), - bias_init=nnx.with_metadata( - nnx.initializers.zeros_init(), # no sharding annotation here! - add_axis_hooks=lambda _, idx, name: badds.append((idx, name)), - remove_axis_hooks=lambda _, idx, name: bremoves.append((idx, name)), - ), - rngs=rngs, + 3, + 3, + kernel_init=nnx.with_metadata( + nnx.initializers.lecun_normal(), + sharding=('din', 'dout'), + on_add_axis=lambda _, idx, name: kadds.append((idx, name)), + on_remove_axis=lambda _, idx, name: kremoves.append((idx, name)), + ), + bias_init=nnx.with_metadata( + nnx.initializers.zeros_init(), # no sharding annotation here! 
+ on_add_axis=lambda _, idx, name: badds.append((idx, name)), + on_remove_axis=lambda _, idx, name: bremoves.append((idx, name)), + ), + rngs=rngs, ) @nnx.scan( @@ -158,7 +159,39 @@ def __call__(self, x: jax.Array): self.assertEqual(badds, [(0, 'layers'), (0, 'layers')]) self.assertEqual(bremoves, [(0, 'layers')]) + def test_logical_rules(self): + class Foo(nnx.Module): + + def __init__(self): + self.w = nnx.Param( + nnx.with_partitioning( + lambda: jnp.ones((8, 2)), + sharding=('row-alias', 'col-alias'), + sharding_rules=(('row-alias', 'row'),), + )() + ) + self.b = nnx.Param( + nnx.with_partitioning( + lambda: jnp.zeros((2,)), sharding=('col-alias',) + )() + ) + + def __call__(self, x): + return x @ self.w + self.b + + graphdef, params = nnx.split(Foo()) + state = nnx.TrainState.create( + graphdef, + params=params, + tx=optax.adam(1e-3), + ) + with flax.core.spmd.logical_axis_rules((('col-alias', 'col'),)): + state_spec = nnx.get_partition_spec(state) + + assert state_spec.params['w'].value == PartitionSpec('row', 'col') + assert state_spec.opt_state[0].mu['w'].value == PartitionSpec('row', 'col') + assert state_spec.opt_state[0].nu['w'].value == PartitionSpec('row', 'col') + if __name__ == '__main__': absltest.main() - diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py index b29faeb4..736da9ac 100644 --- a/tests/nnx/transforms_test.py +++ b/tests/nnx/transforms_test.py @@ -21,11 +21,12 @@ from flax import nnx from flax.nnx.transforms import general import jax -from jax.experimental import mesh_utils +from jax.experimental import mesh_utils, checkify import jax.numpy as jnp import numpy as np + class List(nnx.Module): def __init__(self, items): vars(self).update({str(i): item for i, item in enumerate(items)}) @@ -598,7 +599,7 @@ def loss_fn(l1: list[nnx.Linear], l2: list[nnx.Linear]): self.assertIn('bias', grads_m2[0]) -class TestCustomVJP(absltest.TestCase): +class TestCustomVJP(parameterized.TestCase): def test_basic_call(self): m1 = nnx.Linear(1, 1, rngs=nnx.Rngs(0)) @@ -644,16 +645,16 @@ def f_fwd(m: Foo): return y, res def f_bwd(res, g): - inputs_g, out_g = g + (m_g,), out_g = g cos_x, sin_x, m = res - self.assertIsInstance(inputs_g, tuple) - self.assertLen(inputs_g, 1) - self.assertIsInstance(inputs_g[0], nnx.State) + self.assertIsInstance(m_g, nnx.State) self.assertEqual(out_g.shape, ()) self.assertIsInstance(m, Foo) - m_g = nnx.State({'x': cos_x * out_g * m.y, 'y': sin_x * out_g}) + # m_g = nnx.State({'x': cos_x * out_g * m.y, 'y': sin_x * out_g}) + m_g.x.value = cos_x * out_g * m.y + m_g.y.value = sin_x * out_g return (m_g,) f.defvjp(f_fwd, f_bwd) @@ -666,6 +667,92 @@ def f_bwd(res, g): np.testing.assert_allclose(grad['y'].value, jnp.sin(1.0)) # type: ignore self.assertEqual(m.z, 1) + def test_diff_state(self): + @dataclasses.dataclass + class Foo(nnx.Module): + x: nnx.Param[jax.Array] + y: nnx.Param[jax.Array] + z: int + + x_in_path = nnx.PathContains('x') + diff_state = nnx.DiffState(0, x_in_path) + + @nnx.custom_vjp(nondiff_argnums=(diff_state,)) + def f(m: Foo): + m.z += 1 + return jnp.sin(m.x) * m.y # type: ignore + + def f_fwd(m: Foo): + y = f(m) + res = (jnp.cos(m.x), m) # type: ignore + return y, res + + def f_bwd(res, g): + (m_g,), out_g = g + cos_x, m = res + + self.assertIsInstance(m_g, nnx.State) + self.assertEqual(out_g.shape, ()) + self.assertIsInstance(m, Foo) + + m_g.x.value = cos_x * out_g * m.y + del m_g['y'] + return (m_g,) + + f.defvjp(f_fwd, f_bwd) + + m = Foo(nnx.Param(jnp.array(1.0)), nnx.Param(jnp.array(2.0)), 0) + + grad: nnx.State 
= nnx.grad(f, argnums=nnx.DiffState(0, x_in_path))(m) + + np.testing.assert_allclose(grad['x'].value, jnp.cos(1.0) * 2.0) # type: ignore + self.assertEqual(m.z, 1) + + def test_jax_example_with_remat(self): + @dataclasses.dataclass + class Foo(nnx.Module): + x: nnx.Param[jax.Array] + y: nnx.Param[jax.Array] + z: int + + @nnx.custom_vjp + @nnx.remat + def f(m: Foo): + m.z += 1 + return jnp.sin(m.x.value) * m.y # type: ignore + + def f_fwd(m: Foo): + y = f(m) + res = (jnp.cos(m.x.value), jnp.sin(m.x.value), m) # type: ignore + return y, res + + def f_bwd(res, g): + (m_g,), out_g = g + cos_x, sin_x, m = res + + self.assertIsInstance(m_g, nnx.State) + self.assertEqual(out_g.shape, ()) + self.assertIsInstance(m, Foo) + + # m_g = nnx.State({'x': cos_x * out_g * m.y, 'y': sin_x * out_g}) + m_g.x.value = cos_x * out_g * m.y + m_g.y.value = sin_x * out_g + return (m_g,) + + f.defvjp(f_fwd, f_bwd) + + m = Foo(nnx.Param(jnp.array(1.0)), nnx.Param(jnp.array(2.0)), 0) + + @nnx.jit + def loss_fn(m): + return f(m) + + grad: nnx.State = nnx.grad(loss_fn, argnums=nnx.DiffState(0, ...))(m) + + np.testing.assert_allclose(grad['x'].value, jnp.cos(1.0) * 2.0) # type: ignore + np.testing.assert_allclose(grad['y'].value, jnp.sin(1.0)) # type: ignore + self.assertEqual(m.z, 1) + def test_two_args(self): @dataclasses.dataclass class Foo(nnx.Module): @@ -726,45 +813,49 @@ class Foo(nnx.Module): y: nnx.Param[jax.Array] z: int - @nnx.custom_vjp(nondiff_argnums=(1, 2)) - def f(m1: Foo, m2: Foo, m3): - m1.z += 1 - y = jnp.sin(m1.x) * m1.y # type: ignore - return y, m2 + @nnx.custom_vjp(nondiff_argnums=(0, 2)) + def f(a, m: Foo, b): + self.assertEqual(a, 1) + self.assertEqual(b, 2) + m.z += 1 + return jnp.sin(m.x) * m.y # type: ignore - def f_fwd(m1: Foo, m2: Foo, m3): - y, m2 = f(m1, m2, m3) - res = (jnp.cos(m1.x), jnp.sin(m1.x), m1) # type: ignore - return (y, m2), res + def f_fwd(a, m: Foo, b): + self.assertEqual(a, 1) + self.assertEqual(b, 2) + y = f(a, m, b) + res = (jnp.cos(m.x), jnp.sin(m.x), m) # type: ignore + return y, res - def f_bwd(m2, m3, res, g): - (m1_g, m2_g, m3_g), (y_g, _) = g + def f_bwd(a, b, res, g): + (m_g,), out_g = g cos_x, sin_x, m = res - self.assertIsInstance(m1_g, nnx.State) - self.assertIsInstance(m2_g, nnx.State) - self.assertEqual(y_g.shape, ()) + self.assertEqual(a, 1) + self.assertEqual(b, 2) + self.assertIsInstance(m_g, nnx.State) + self.assertEqual(out_g.shape, ()) self.assertIsInstance(m, Foo) - m1_g = nnx.State(dict(x=cos_x * y_g * m.y, y=sin_x * y_g)) - - return (m1_g,) + # m_g = nnx.State({'x': cos_x * out_g * m.y, 'y': sin_x * out_g}) + m_g.x.value = cos_x * out_g * m.y + m_g.y.value = sin_x * out_g + return (m_g,) f.defvjp(f_fwd, f_bwd) - m1 = Foo(nnx.Param(jnp.array(1.0)), nnx.Param(jnp.array(2.0)), 0) - m2 = Foo(nnx.Param(jnp.array(3.0)), nnx.Param(jnp.array(4.0)), 0) + m = Foo(nnx.Param(jnp.array(1.0)), nnx.Param(jnp.array(2.0)), 0) - def loss_fn(m1, m2, m3): - y, m2 = f(m1, m2, m3) - return y + m2.x * m2.y + def loss_fn(m): + a = 1 + b = 2 + return f(a, m, b) - m1_grad: nnx.State - m1_grad = nnx.grad(loss_fn, argnums=nnx.DiffState(0, ...))(m1, m2, m2) + grad: nnx.State = nnx.grad(loss_fn, argnums=nnx.DiffState(0, ...))(m) - np.testing.assert_allclose(m1_grad['x'].value, jnp.cos(1.0) * 2.0) # type: ignore - np.testing.assert_allclose(m1_grad['y'].value, jnp.sin(1.0)) # type: ignore - self.assertEqual(m1.z, 1) + np.testing.assert_allclose(grad['x'].value, jnp.cos(1.0) * 2.0) # type: ignore + np.testing.assert_allclose(grad['y'].value, jnp.sin(1.0)) # type: ignore + 
self.assertEqual(m.z, 1) def test_docs_example(self): import jax.numpy as jnp @@ -794,6 +885,60 @@ def f_bwd(res, g): m = Foo(x=jnp.array(1.0), y=jnp.array(2.0)) grads = nnx.grad(f)(m) + @parameterized.parameters( + {'use_custom_vjp': False}, + {'use_custom_vjp': True}, + ) + def test_issue(self, use_custom_vjp: bool): + class MyLinear(nnx.Module): + def __init__( + self, in_features: int, out_features: int, *, rngs: nnx.Rngs + ): + kernel_init = nnx.initializers.normal(in_features**-0.5) + self.kernel = nnx.Param( + kernel_init(rngs.params(), (in_features, out_features), jnp.float32) + ) + self.bias = nnx.Param(jnp.zeros((out_features,), jnp.float32)) + self.n = nnx.BatchStat(jnp.array(0, jnp.uint32)) + + def linear(m: MyLinear, x: jax.Array) -> jax.Array: + m.n.value += 1 + y = x @ m.kernel + m.bias + return y + + def linear_fwd(m: MyLinear, x: jax.Array): + return linear(m, x), (m, x) + + def linear_bwd(res, g): + m, x = res + (m_g, _x_grad), outputs_g = g + kernel_grad = outputs_g[None, :] * x[:, None] + bias_grad = outputs_g + x_grad = m.kernel @ outputs_g + assert x_grad.shape == x.shape, 'Shape mismatch for x' + assert ( + m.kernel.value.shape == kernel_grad.shape + ), 'Shape mismatch for kernel' + assert m.bias.value.shape == bias_grad.shape, 'Shape mismatch for bias' + return (m_g, x_grad) + + if use_custom_vjp: + linear = nnx.custom_vjp(linear) + linear.defvjp(linear_fwd, linear_bwd) + + @nnx.jit + def loss_fn(x, mod): + y = linear(mod, x) + return y.mean() + + mod = MyLinear(10, 5, rngs=nnx.Rngs(0)) + self.assertEqual(mod.n.value, 0) + x = jax.random.normal(jax.random.key(0), (10,)) + loss, grad = nnx.value_and_grad(loss_fn)(x, mod) + self.assertEqual(loss.shape, ()) + self.assertEqual(grad.shape, (10,)) + self.assertEqual(mod.n.value, 1) + class TestScan(absltest.TestCase): def test_basic(self): @@ -1673,7 +1818,6 @@ def unroll(cell: RNNCell, carry, x) -> tuple[jax.Array, jax.Array]: x = jnp.ones((16, 10, 20)) y = rnn_forward(cell, x) - print(y.shape) class TestRemat(absltest.TestCase): @@ -2135,7 +2279,6 @@ def f(m): def test_consistent_aliasing_shared(self): class Shared(nnx.Module): - def __init__(self): self.a = nnx.Param(jnp.zeros((3, 3))) @@ -2148,17 +2291,46 @@ def __init__(self, shared: Shared): m1 = Foo(shared) m2 = Foo(shared) - @partial(nnx.vmap, in_axes=(0, 1)) + @nnx.vmap(in_axes=(0, 1)) def f(m1, m2): pass with self.assertRaisesRegex( - ValueError, - r'Inconsistent aliasing detected([\s\S]*)Shared([\s\S]*)a:' - r' 0([\s\S]*)a: 1', + ValueError, + r'Inconsistent aliasing detected([\s\S]*)Param([\s\S]*)a:' + r' 0([\s\S]*)a: 1', ): f(m1, m2) + def test_equivalent_state_axes_mapping(self): + m = nnx.Linear(3, 3, rngs=nnx.Rngs(0)) + + sa1 = nnx.StateAxes({...: 0}) + sa2 = nnx.StateAxes({nnx.Param: 0}) + + @nnx.vmap(in_axes=(0, sa1, sa2)) + def f(m1, m2, m3): + pass + + f(m, m, m) + + def test_equivalent_state_sharding_mapping(self): + m = nnx.Linear(3, 3, rngs=nnx.Rngs(0)) + + mesh = jax.sharding.Mesh(jax.devices(), ('mp',)) + sharding = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec('mp') + ) + + sa1 = nnx.StateSharding({...: sharding}) + sa2 = nnx.StateSharding({nnx.Param: sharding}) + + @nnx.jit(in_shardings=(sharding, sa1, sa2)) + def f(m1, m2, m3): + pass + + f(m, m, m) + @absltest.skip('Enable once jax#19586 resolved') def test_captured_module_in_return_error(self): class Foo(nnx.Module): @@ -2317,6 +2489,44 @@ def create_block(rngs: nnx.Rngs): self.assertEqual(m.kernel.value.shape, (5, 16, 32)) self.assertEqual(m.kernel.sharding, ('c', 'a', 'b')) 
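The StateAxes tests above, and test_state_axes_from_state just below, revolve around one mechanism: a filter-to-axis mapping that tells nnx.vmap how each part of a module's state is batched. A compressed sketch of the pattern, assuming the same nnx.vmap / nnx.StateAxes behavior these tests rely on:

from flax import nnx

# Map all of the module's state to axis 0; per-filter mappings such as
# nnx.StateAxes({nnx.Param: 0, ...: None}) follow the same pattern.
state_axes = nnx.StateAxes({...: 0})

@nnx.vmap(out_axes=state_axes, axis_size=5)
def create_block():
  return nnx.Linear(2, 3, rngs=nnx.Rngs(0))

stacked = create_block()
assert stacked.kernel.value.shape == (5, 2, 3)  # params gain a leading stacking axis
assert stacked.bias.value.shape == (5, 3)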
+ def test_state_axes_from_state(self): + class Model(nnx.Module): + def __init__(self, din, dout, *, rngs: nnx.Rngs): + self.linear = nnx.Linear(din, dout, rngs=rngs) + self.bn = nnx.BatchNorm(dout, rngs=rngs) + + model = Model(2, 3, rngs=nnx.Rngs(0)) + state = nnx.state(model) + + state['linear']['kernel'] = 0 + state['linear']['bias'] = 1 + state['bn']['scale'] = 0 + state['bn']['mean'] = 1 + state['bn']['var'] = 0 + state['bn']['bias'] = None + + state_axes = nnx.StateAxes(state) + + self.assertEqual(state_axes.map_prefix(('linear', 'kernel'), None), 0) + self.assertEqual(state_axes.map_prefix(('linear', 'bias'), None), 1) + self.assertEqual(state_axes.map_prefix(('bn', 'scale'), None), 0) + self.assertEqual(state_axes.map_prefix(('bn', 'mean'), None), 1) + self.assertEqual(state_axes.map_prefix(('bn', 'var'), None), 0) + self.assertEqual(state_axes.map_prefix(('bn', 'bias'), None), None) + + @nnx.vmap(out_axes=state_axes, axis_size=5) + def create_block(): + return Model(2, 3, rngs=nnx.Rngs(0)) + + model = create_block() + + self.assertEqual(model.linear.kernel.shape, (5, 2, 3)) + self.assertEqual(model.linear.bias.shape, (3, 5)) + self.assertEqual(model.bn.scale.shape, (5, 3)) + self.assertEqual(model.bn.mean.shape, (3, 5)) + self.assertEqual(model.bn.var.shape, (5, 3)) + self.assertEqual(model.bn.bias.shape, (3,)) + class TestPmap(absltest.TestCase): @@ -2546,6 +2756,236 @@ def no_nothing(env: Env): ) +class TestSwitch(absltest.TestCase): + def test_basic(self): + class RoundTable(nnx.Module): + def __init__(self): + self.next_index = 0 + self.linear = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) + self.linear.kernel.value = jnp.identity(10) + self.rounds_count = nnx.Variable(jnp.array(0)) + + def __call__(self, x): + def fn0(m, x): + m.rounds_count += 1 + return m.linear(x) + def fn1(m, x): + return m.linear(x) * 2 + def fn2(m, x): + m.linear.kernel.value = jnp.zeros((10, 10)) + return m.linear(x) + + # y = nnx.cond(self.next_index.value == 0, fn0, fn1, self, x) + y = nnx.switch(self.next_index, (fn0, fn1, fn2), self, x) + self.next_index = (self.next_index + 1) % 3 + return y + + model = RoundTable() + x = jnp.ones((10,)) + np.testing.assert_array_equal(model(x), x) + assert model.rounds_count.value == 1 + assert model.next_index == 1 + np.testing.assert_array_equal(model(x), x * 2) + assert model.rounds_count.value == 1 + assert model.next_index == 2 + np.testing.assert_array_equal(model(x), jnp.zeros((10,))) + assert model.rounds_count.value == 1 + assert model.next_index == 0 + np.testing.assert_array_equal(model(x), jnp.zeros((10,))) + assert model.rounds_count.value == 2 + assert model.next_index == 1 + + +class TestWhileLoop(absltest.TestCase): + def test_basic(self): + def fwd_fn(input): + m, x, c = input + y = m(x) + return m, y, c - 1.0 + + module = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) + module.kernel.value = jnp.identity(10) * 2 + x = 1e1 * jax.random.normal(jax.random.key(0), (10,)) + + _, y, _ = nnx.while_loop( + lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0)) + np.testing.assert_array_equal(y, x * 8) + + def test_multiple_objects(self): + def fwd_fn(input): + m1, (w2,), x, c = input + y = m1(x) @ w2 + return m1, (w2,), y, c - 1.0 + + m1 = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) + m1.kernel.value = jnp.identity(10) * 2 + w2 = nnx.Variable(jnp.identity(10) * 0.5) + x = 1e1 * jax.random.normal(jax.random.key(0), (10,)) + + _, _, y, _ = nnx.while_loop( + lambda input: input[-1] > 0, fwd_fn, (m1, (w2,), x, 3.0)) + np.testing.assert_allclose(y, x) + + def 
test_nested_module(self): + def fwd_fn(input): + m, x, c = input + y = m(x) + return m, y, c - 1.0 + + module = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) + module.kernel.value = jnp.identity(10) * 2 + module = nnx.Sequential(module) + x = 1e1 * jax.random.normal(jax.random.key(0), (10,)) + + _, y, _ = nnx.while_loop( + lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0)) + np.testing.assert_array_equal(y, x * 8) + + + def test_shared_module(self): + m1 = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) + m2 = nnx.Linear(10, 10, use_bias=False, rngs=nnx.Rngs(0)) + m2.kernel = m1.kernel + module = nnx.Sequential(m1, m2) + self.assertLen(jax.tree.leaves(nnx.state(module)), 2) # only m1 params + + def fwd_fn(input): + m, x, c = input + y = m(x) + m.layers[0].kernel.value = jnp.zeros_like(m.layers[0].kernel.value) + return m, y, c - 1.0 + + x = 1e1 * jax.random.normal(jax.random.key(0), (10,)) + _, y, _ = nnx.while_loop( + lambda input: input[-1] > 0, fwd_fn, (module, x, 2.0)) + self.assertLen(jax.tree.leaves(nnx.state(module)), 2) # only m1 params + np.testing.assert_array_equal(m1.kernel.value, jnp.zeros((10, 10,))) + np.testing.assert_array_equal(m2.kernel.value, jnp.zeros((10, 10,))) + np.testing.assert_array_equal(y, jnp.zeros((10,))) + + + def test_value_changed(self): + def fwd_fn(input): + m, x, c = input + m.kernel.value = jnp.zeros_like(m.kernel.value) + y = m(x) + return m, y, c - 1.0 + + module = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) + x = 1e1 * jax.random.normal(jax.random.key(0), (10,)) + + _, y, _ = nnx.while_loop( + lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0)) + np.testing.assert_array_equal(module.kernel.value, jnp.zeros((10, 10,))) + np.testing.assert_array_equal(y, jnp.zeros((10,))) + + + def test_ref_changed(self): + def fwd_fn(input): + m, x, c = input + y = m(x) + m.kernel = nnx.Param(jnp.zeros_like(m.kernel.value)) + return m, y, c - 1.0 + + module = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) + x = 1e1 * jax.random.normal(jax.random.key(0), (10,)) + + with self.assertRaises(ValueError): + _, y, _ = nnx.while_loop( + lambda input: input[-1] > 0, fwd_fn, (module, x, 2.0)) + + + def test_structure_changed(self): + def fwd_fn(input): + m, x, c = input + m = nnx.Linear(10, 10, use_bias=False, rngs=nnx.Rngs(1)) + m.kernel.value = jnp.identity(10) * 2 + y = m(x) + return m, y, c - 1.0 + + module = nnx.Linear(10, 10, use_bias=True, rngs=nnx.Rngs(0)) + x = 1e1 * jax.random.normal(jax.random.key(0), (10,)) + + with self.assertRaises(ValueError): + _, y, _ = nnx.while_loop( + lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0)) + + def test_repeated_object(self): + m = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) + + def body_fn(val): + count, m, _ = val + return count + 1, m, m + + count, m, _ = nnx.while_loop( + lambda val: val[0] < 2, + body_fn, + (0, m, m), + ) + + def test_fori_loop_basic(self): + def fwd_fn(i, input): + m, x = input + m.kernel.value = jnp.identity(10) * i + return m, m(x) + + module = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) + x = jax.random.normal(jax.random.key(0), (10,)) + + _, y = nnx.fori_loop(2, 4, fwd_fn, (module, x)) + np.testing.assert_array_equal(y, x * 2 * 3) + + def test_fori_loop_with_sharing(self): + class A(nnx.Object): + def __init__(self): + self.params = nnx.Param(jnp.zeros((10,), dtype=int)) + + class B(nnx.Object): + def __init__(self, a: A): + self.a = a + + class C(nnx.Object): + def __init__(self, a: A): + self.a = a + + class D(nnx.Object): + def __init__(self): + self.a = A() + self.b = B(self.a) + self.c = C(self.a) + + def increment(_, d: D) -> D: + 
d.a.params += 1 + return d + + @nnx.jit + def rollout(d: D): + nnx.fori_loop(0, 10, increment, d) + + d = D() + rollout(d) + + np.testing.assert_array_equal( + d.a.params.value, np.full((10,), 10, dtype=int) + ) + + def test_loops_multiple_modules(self): + class Foo(nnx.Module): + def __init__(self): + self.param = nnx.Param(jnp.zeros((1,))) + def __call__(self, x): + return self.param + + def loop_fn(inputs): + return inputs + while_loop_fn = lambda inputs: (*loop_fn(inputs[:-1]), inputs[-1]-1) + fori_loop_fn = lambda i, inputs: loop_fn(inputs) + a = Foo() + b = Foo() + nnx.while_loop(lambda input: input[-1] > 0, while_loop_fn, (a, b, 2)) + nnx.fori_loop(0, 2, fori_loop_fn, (a, b)) + + class TestSplitMergeInputs(absltest.TestCase): def test_split_inputs(self): class StatefulLinear(nnx.Linear): @@ -2635,6 +3075,22 @@ def no_nothing(env: Env): env.step.value, np.array([1, 0, 1, 0, 1, 0, 1, 0], np.uint32) ) +class TestCheckify(absltest.TestCase): + def test_basic(self): + @dataclasses.dataclass + class Foo(nnx.Module): + a: nnx.Param + + @nnx.jit + def f(m): + y = jnp.sin(m.a.value) # error + return m.a + y + + m = Foo(a=nnx.Param(jnp.inf)) + err, out = nnx.checkify(f, errors=checkify.float_checks)(m) + + with self.assertRaisesRegex(ValueError, 'nan generated by primitive: sin'): + err.throw() if __name__ == '__main__': absltest.main() diff --git a/tests/run_all_tests.sh b/tests/run_all_tests.sh index e2ded604..920d7101 100755 --- a/tests/run_all_tests.sh +++ b/tests/run_all_tests.sh @@ -108,7 +108,6 @@ assert_error="flax is not running on editable mode." # env vars must be set after doctest export JAX_NUMPY_RANK_PROMOTION=raise export FLAX_PROFILE=1 -export FLAX_LAZY_RNG=1 if $RUN_PYTEST; then echo "=== RUNNING PYTESTS ===" diff --git a/tests/struct_test.py b/tests/struct_test.py index 6c80caea..9bb4986c 100644 --- a/tests/struct_test.py +++ b/tests/struct_test.py @@ -15,7 +15,6 @@ """Tests for flax.struct.""" import dataclasses -import functools from typing import Any import jax @@ -49,8 +48,8 @@ def test_mutation(self): p.y = 3 def test_slots(self): - slots_dataclass = functools.partial(struct.dataclass, frozen=False, slots=True) - @slots_dataclass + + @struct.dataclass(frozen=False, slots=True) class SlotsPoint: x: float y: float @@ -100,7 +99,7 @@ def test_kw_only(self, mode): class A: a: int = 1 - @functools.partial(struct.dataclass, kw_only=True) + @struct.dataclass(kw_only=True) class B(A): b: int elif mode == 'pytreenode': @@ -139,7 +138,7 @@ def test_mutable(self, mode): class A: a: int = 1 - @functools.partial(struct.dataclass, frozen=False) + @struct.dataclass(frozen=False) class B: b: int = 1 elif mode == 'pytreenode': diff --git a/uv.lock b/uv.lock index 5dbc9e80..a3015511 100644 --- a/uv.lock +++ b/uv.lock @@ -81,6 +81,12 @@ dependencies = [ { name = "etils", version = "1.9.2", source = { registry = "https://pypi.org/simple" }, extra = ["epath"], marker = "python_full_version >= '3.11' and platform_system != 'Darwin'" }, ] wheels = [ + { url = "https://files.pythonhosted.org/packages/ff/9b/fe3cc94350cf082d3fb70a1393b259cd1d9bce5212f14f53deea1008b94b/array_record-0.5.1-1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:3dbfac79589b53ad765d247b4b6b6c108623053950a8ae36d8a5f2bfec396bd1", size = 2140349 }, + { url = "https://files.pythonhosted.org/packages/ce/fd/a241172b054f0c496cc575a6081e2b457ef3cf520e652ee22f3035714535/array_record-0.5.1-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = 
"sha256:a0911cca3f71aa6724ae08c351e486acc2dcdc098df0e4ae9aa920f16aee2385", size = 2200584 }, + { url = "https://files.pythonhosted.org/packages/7f/b9/ab118be4efaae976db4dbffbf4d9479151509668261d95beaa80a956a757/array_record-0.5.1-1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e6b297b9241d10f072f00a85e97c8743c9e623be20e413ab3403b9326ed98890", size = 2140482 }, + { url = "https://files.pythonhosted.org/packages/d8/5e/9379b00e5b17ea280845b82c492cab9298eb658ea9d40c21f0fd064a4dd5/array_record-0.5.1-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e39b2001fbed6f6d621a5f2079609037167ee06bf977fd6c37d225043c39a015", size = 2200598 }, + { url = "https://files.pythonhosted.org/packages/45/9b/74eb64c839871cb3adfb254246e42be8a7ce636debe9ab9a3748cb0c484b/array_record-0.5.1-1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:c4e6e5cef45a82641f4bb008c2a1409cd043f46dd3f0e5a2e7f232416435186d", size = 2140093 }, + { url = "https://files.pythonhosted.org/packages/b7/4d/8ed8fbef16144db66b92e3fcbcb4656edaa5cf538d20fe7913c1caa78b68/array_record-0.5.1-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:927f5f0bdbb141e75d370ade9ce784514babcb78f86d23badbab2d7fd6b7cd48", size = 2200996 }, { url = "https://files.pythonhosted.org/packages/76/85/f8e77e0ee6644ab3585de1b73a183e6831ded6e7b791f21a3de5f6e29aeb/array_record-0.5.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b9f2e304e59a17af9f5bf2a86b93ad4700d0eeb85d742a884aa38dc0b54dda5b", size = 2135133 }, { url = "https://files.pythonhosted.org/packages/9e/da/a7c513f35d4878888ca5d1e8548324e90414106ece7b44908002c800a22f/array_record-0.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:897362036f2920093eff3d729c2a6e1844e3077f513d6bd29640cd02f98e07c7", size = 2195378 }, { url = "https://files.pythonhosted.org/packages/61/7f/e0329a2aad1cf96e2b797e55e744af94c3d8d1969240c0153660214477c0/array_record-0.5.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ebe99f37e3a797322f4f5cfc6902b5e852012ba2729fac628aad6affb225247", size = 2135268 }, @@ -504,7 +510,7 @@ wheels = [ [package.optional-dependencies] toml = [ - { name = "tomli", marker = "python_full_version == '3.11'" }, + { name = "tomli", marker = "python_full_version <= '3.11'" }, ] [[package]] @@ -767,7 +773,7 @@ wheels = [ [[package]] name = "flax" -version = "0.8.6" +version = "0.10.2" source = { editable = "." 
} dependencies = [ { name = "jax" }, @@ -778,6 +784,7 @@ dependencies = [ { name = "pyyaml" }, { name = "rich" }, { name = "tensorstore" }, + { name = "treescope" }, { name = "typing-extensions" }, ] @@ -794,7 +801,9 @@ docs = [ { name = "einops" }, { name = "ipykernel" }, { name = "ipython-genutils" }, + { name = "ipywidgets" }, { name = "jupytext" }, + { name = "kagglehub" }, { name = "matplotlib" }, { name = "ml-collections" }, { name = "myst-nb" }, @@ -807,6 +816,7 @@ docs = [ { name = "sphinx-design" }, ] testing = [ + { name = "cloudpickle" }, { name = "clu" }, { name = "einops" }, { name = "gymnasium", extra = ["accept-rom-license", "atari"] }, @@ -831,6 +841,7 @@ testing = [ [package.metadata] requires-dist = [ + { name = "cloudpickle", marker = "extra == 'testing'", specifier = ">=3.0.0" }, { name = "clu", marker = "python_full_version < '3.10' and extra == 'testing'", specifier = "<=0.0.9" }, { name = "clu", marker = "extra == 'testing'" }, { name = "dm-haiku", marker = "extra == 'docs'" }, @@ -840,11 +851,13 @@ requires-dist = [ { name = "gymnasium", extras = ["accept-rom-license", "atari"], marker = "extra == 'testing'" }, { name = "ipykernel", marker = "extra == 'docs'" }, { name = "ipython-genutils", marker = "extra == 'docs'" }, + { name = "ipywidgets", marker = "extra == 'docs'", specifier = ">=8.1.5" }, { name = "jax", specifier = ">=0.4.27" }, { name = "jaxlib", marker = "extra == 'testing'" }, { name = "jaxtyping", marker = "extra == 'testing'" }, { name = "jraph", marker = "extra == 'testing'", specifier = ">=0.0.6.dev0" }, { name = "jupytext", marker = "extra == 'docs'", specifier = "==1.13.8" }, + { name = "kagglehub", marker = "extra == 'docs'", specifier = ">=0.3.3" }, { name = "matplotlib", marker = "extra == 'all'" }, { name = "matplotlib", marker = "extra == 'docs'" }, { name = "ml-collections", marker = "extra == 'docs'" }, @@ -878,6 +891,7 @@ requires-dist = [ { name = "tensorflow-text", marker = "platform_system != 'Darwin' and extra == 'testing'", specifier = ">=2.11.0" }, { name = "tensorstore" }, { name = "torch", marker = "extra == 'testing'" }, + { name = "treescope", specifier = ">=0.1.2" }, { name = "treescope", marker = "python_full_version >= '3.10' and extra == 'testing'", specifier = ">=0.1.1" }, { name = "typing-extensions", specifier = ">=4.2" }, ] @@ -1215,9 +1229,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fa/bc/9bd3b5c2b4774d5f33b2d544f1460be9df7df2fe42f352135381c347c69a/ipython_genutils-0.2.0-py2.py3-none-any.whl", hash = "sha256:72dd37233799e619666c9f639a9da83c34013a73e8bbc79a7a6348d93c61fab8", size = 26343 }, ] +[[package]] +name = "ipywidgets" +version = "8.1.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "comm" }, + { name = "ipython" }, + { name = "jupyterlab-widgets" }, + { name = "traitlets" }, + { name = "widgetsnbextension" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c7/4c/dab2a281b07596a5fc220d49827fe6c794c66f1493d7a74f1df0640f2cc5/ipywidgets-8.1.5.tar.gz", hash = "sha256:870e43b1a35656a80c18c9503bbf2d16802db1cb487eec6fab27d683381dde17", size = 116723 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/22/2d/9c0b76f2f9cc0ebede1b9371b6f317243028ed60b90705863d493bae622e/ipywidgets-8.1.5-py3-none-any.whl", hash = "sha256:3290f526f87ae6e77655555baba4f36681c555b8bdbbff430b70e52c34c86245", size = 139767 }, +] + [[package]] name = "jax" -version = "0.4.31" +version = "0.4.35" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = 
"jaxlib" }, @@ -1226,14 +1256,14 @@ dependencies = [ { name = "opt-einsum" }, { name = "scipy" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/73/e4/c1a4c0e7dafbc53fff9f42f9c1bf5918dabd1f91325512d6b382bea8750b/jax-0.4.31.tar.gz", hash = "sha256:fd2d470643a0073d822737f0788f71391656af7e62cc5b2e7995ee390ceac287", size = 1743359 } +sdist = { url = "https://files.pythonhosted.org/packages/e3/34/21da583b9596e72bb8e95b6197dee0a44b96b9ea2c147fccabd43ca5515b/jax-0.4.35.tar.gz", hash = "sha256:c0c986993026b10bf6f607fecb7417377460254640766ce40f1fef3fd139c12e", size = 1861189 } wheels = [ - { url = "https://files.pythonhosted.org/packages/7e/cf/5f51b43bd692e90585c0ef6e8d1b0db5d254fe0224a6570daa59a1be014f/jax-0.4.31-py3-none-any.whl", hash = "sha256:5688703735133d0dc537e99a1d646198a49c9d472d4715fde4bd437c44151bd7", size = 2038969 }, + { url = "https://files.pythonhosted.org/packages/62/20/6c57c50c0ccc645fea1895950f1e5cd02f961ee44b3ffe83617fa46b0c1d/jax-0.4.35-py3-none-any.whl", hash = "sha256:fa99e909a31424abfec750019a6dd36f6acc18a6e7d40e2c0086b932cc351325", size = 2158621 }, ] [[package]] name = "jaxlib" -version = "0.4.31" +version = "0.4.35" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "ml-dtypes" }, @@ -1241,21 +1271,26 @@ dependencies = [ { name = "scipy" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/b7/72/12c267f6775aac7e3ca6ed882c9816883cce44d73169d25d0e0b0f1f6972/jaxlib-0.4.31-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:48ea73cb78341bd4aabbb15e1a076ed61505ec80ab8eb4810e2d34758c400f80", size = 88767265 }, - { url = "https://files.pythonhosted.org/packages/b2/c9/0a6a964a852b66cff6108b8d8bc17115b69171fa6a22a916bc911d9f0a61/jaxlib-0.4.31-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:bacb86012f9104dd71706266420fd1e5d179d826d0635c95fe31506d605b4537", size = 70040016 }, - { url = "https://files.pythonhosted.org/packages/ae/4d/71e6286f88bf2c516e8af26a4245b8a68b12fcf1bbb42a4b3b7958575407/jaxlib-0.4.31-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:d019023f71dba65127a3016ddc755de4b30f5bc9bd5b632a716a5fb3b00c5e53", size = 73050144 }, - { url = "https://files.pythonhosted.org/packages/cd/d7/918ac5477d1c32329c43bc2eb40473baa1c244851c825904430e8911f15a/jaxlib-0.4.31-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:1b8e9e6970ecc08bd8b4d80c03d882f4dcd4ac119cb2959811ebc58fce1c263d", size = 88131641 }, - { url = "https://files.pythonhosted.org/packages/ed/ea/2ba944ba4365cf8f043ff34cdb9704e29a37478b75592d03672fbba4d0df/jaxlib-0.4.31-cp310-cp310-win_amd64.whl", hash = "sha256:d3540a557c188d23ef93760da482b158ca910124a0445263c3b17c09c114538a", size = 56281724 }, - { url = "https://files.pythonhosted.org/packages/46/d0/100199575992545940afc17e62dea5a79c15ef738c1ae9784a1838962aa4/jaxlib-0.4.31-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:1fd838ff91ea58ec2bdc7b4ecbb921ad501a318fafdeae120e6e7f88f5c20b17", size = 88768971 }, - { url = "https://files.pythonhosted.org/packages/18/ea/eddfae920bf689314aa0302a4c841cfac01b6cfd77f60f1a3f3dd355fddc/jaxlib-0.4.31-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:86340df8b37729f6fc5742f17761857bb9e59c418c9453e9b090f49f6194cdf9", size = 70038216 }, - { url = "https://files.pythonhosted.org/packages/a6/ce/ce7d3ba4790e18f67cfcb4552056dd04350085116f4754333f481516d97c/jaxlib-0.4.31-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:2d2639d210b0b1918dfaabbcc504fc668326e1a6fd1f0eb427c40b039188bbce", size = 73050770 }, - { url = 
"https://files.pythonhosted.org/packages/32/33/6d30bf3ec7d590a8dc0f1e30ea4c531b6f6a33116eb2066e354b485066de/jaxlib-0.4.31-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:1db6f8ea35b884f9e7761b006ee9c60ed05be6c75d2e527551f74579cbe11677", size = 88130221 }, - { url = "https://files.pythonhosted.org/packages/9f/e2/5b7d20ed550d156311587eee6e44c48971fe6c3b43f39e82dacda3875396/jaxlib-0.4.31-cp311-cp311-win_amd64.whl", hash = "sha256:ceec494df08aaf65b8bbcbd40dd21a6579fa76ca5b851cce46fd7ce0388c0449", size = 56279795 }, - { url = "https://files.pythonhosted.org/packages/fa/27/3eee15d1b168d434498c388780114d7629f715e19c2d08754ab4be82ad2d/jaxlib-0.4.31-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:185fb615ab6bd95315fbcbd951d84e71f9835d603db8c03c91faee98ce95ff4d", size = 88818529 }, - { url = "https://files.pythonhosted.org/packages/68/cf/28895a4a89d88d18592507d7a35218b6bb2d8bced13615065c9f925f2ae1/jaxlib-0.4.31-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c9f89c185287e40ee8173a7142d6495311e772cd139a93dca93f0d99c1872832", size = 70079551 }, - { url = "https://files.pythonhosted.org/packages/e0/af/10b49f8de2acc7abc871478823579d7241be52ca0d6bb0d2b2c476cc1b68/jaxlib-0.4.31-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:4d867a1a0565b31cfdaabbec81e0302c6461bb2ac4b92c04670328d795819803", size = 73053401 }, - { url = "https://files.pythonhosted.org/packages/b1/09/58d35465d48c8bee1d9a4e7a3c5db2edaabfc7ac94f4576c9f8c51b83e70/jaxlib-0.4.31-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:1f1afa5fd58a60f67f0ca586e26714aece62eaa2c8334c24d0e8285afc4a7ccd", size = 88162291 }, - { url = "https://files.pythonhosted.org/packages/c8/13/1bb2bcb4d9f4719dd5f3d98f5c2fc2235f961ced576366b040372eebdb17/jaxlib-0.4.31-cp312-cp312-win_amd64.whl", hash = "sha256:c4bfd15315e30525514b7262d555bea00745b09ac9818bb14c20ef8afbbab072", size = 56299104 }, + { url = "https://files.pythonhosted.org/packages/f4/67/c025520d2c548569f73cd68b885862e56e8946a10c9d43834460007c2671/jaxlib-0.4.35-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:907e548ad6ce53b242a55c5f36c2a2a4c37d38f6cd8c356fc550a2f18ab0e82f", size = 87876323 }, + { url = "https://files.pythonhosted.org/packages/a8/e7/7962830da208ad3fa6596dc2df77824da9bc0196b549ae549ce53d1d1de1/jaxlib-0.4.35-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8f8c499644660aefd0ae2ee31039da6d4df0f26d0ee67ba9fb316183a5304288", size = 68025360 }, + { url = "https://files.pythonhosted.org/packages/fa/91/2a1a1551845dd634bb1647fd37157f6f4ea71481e63f4100d08923c29d22/jaxlib-0.4.35-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:5d2d8a5b89d334b875ede98d7fcee946bebef1a1b5abd118ff543bcef4ab09f5", size = 70588250 }, + { url = "https://files.pythonhosted.org/packages/d7/16/6a9053d8b4b2790e330f9143030ab9d456556da5d98887b7e071bd08ffed/jaxlib-0.4.35-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:91a283a72263feebe0d110d1136df96950744e47530f12df42c03f36888c971e", size = 87282292 }, + { url = "https://files.pythonhosted.org/packages/6c/a9/b6bdff31e21a485190985dccbdd5ae1130fe2e4af826c83c10ae1d0d14a9/jaxlib-0.4.35-cp310-cp310-win_amd64.whl", hash = "sha256:d210bab7e1ce0b2f2e568548b3903ea6aec349019fc1398cd2a0c069e8342e62", size = 56484115 }, + { url = "https://files.pythonhosted.org/packages/ee/01/4be899cf8d05920877b46b8acf51083dedaba206e951d88ddf7b098bed80/jaxlib-0.4.35-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:7f8bfc90f68857b223b7e38a9bdf466a4f1cb405c9a4aa11698dc9ab7b35c29b", size = 87895891 }, + { url = 
"https://files.pythonhosted.org/packages/55/77/ca1e70bc3a161c1043d2e169a618263f865bf959433e5bf40ea56ec13e5e/jaxlib-0.4.35-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:261570c94b169dc90f3af903282eeec856b52736c0944d243504ced93d19b217", size = 68045181 }, + { url = "https://files.pythonhosted.org/packages/cd/2f/a8f4c441718558406cf27749415d1aa14bdac9dbd06fadb7bb4742c53637/jaxlib-0.4.35-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:e1cee6dc291251f3fb6b0127fdd96c0439ac1ea97e01571d06910df72d6ac6e1", size = 70614621 }, + { url = "https://files.pythonhosted.org/packages/c8/a6/1abe8d682d46cf2989f9c4928866ae80c30a54d607221a262cff8a5d9366/jaxlib-0.4.35-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:bc9eafba001ff8569cfa252fe7f04ba553622702b4b473b656dd0866edf6b8d4", size = 87309681 }, + { url = "https://files.pythonhosted.org/packages/7d/7c/73a4c4a34f2bbfce63e8baefee11753b0d58a71e0d2c33f210e00edba3cb/jaxlib-0.4.35-cp311-cp311-win_amd64.whl", hash = "sha256:0fd990354d5623d3a34493fcd7213493390dbf5039bea19b62e2aaee1049eda9", size = 56520062 }, + { url = "https://files.pythonhosted.org/packages/ef/1c/901a59d9bc051b2a991163c46f58c50724d18ab25e71fa5556e5f68b84a4/jaxlib-0.4.35-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:b44f3e6e9fb748bb43df914356cf9d0d0c9a6e446a12c21fe843db25ed0df65f", size = 87936215 }, + { url = "https://files.pythonhosted.org/packages/da/ff/38030bc3c96fae50f629830afe9c63a8a040aae332f6e28cd529397ba114/jaxlib-0.4.35-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:504d0a2e2117724359d99d7e3663022686dcdddd85aa14bdad02008d444481ad", size = 68063993 }, + { url = "https://files.pythonhosted.org/packages/55/27/83b6d2a1b380e20610e1449231c30c948cc4352c9a7e74a0d0d01bff8339/jaxlib-0.4.35-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:187cb6929dc139b75d952d67c33118473c1b4105525a3e5607f064e7b8efdc74", size = 70629159 }, + { url = "https://files.pythonhosted.org/packages/6d/3f/5ac6dfef795f4f58645ccff0ebd65234cb77d7dbf1bdd2b6c49a677b64b0/jaxlib-0.4.35-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:04d1db3bf0050d120238bfb9b686b58fefcc4d9dd9e2d96aecd3f68a1f1f5e0a", size = 87349348 }, + { url = "https://files.pythonhosted.org/packages/97/05/093b3c511837ba514f0b97581f7b21e1bb79768b8b9c29013a406b00d484/jaxlib-0.4.35-cp312-cp312-win_amd64.whl", hash = "sha256:dddffce48d7e6057008999aed2d8a9daecc57a48c45a4f8c475e00880eb2e41d", size = 56561679 }, + { url = "https://files.pythonhosted.org/packages/99/40/aedef37c44797779a01bf71a392145724e3e0fc369e5f08f55c3c82de733/jaxlib-0.4.35-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:14aeac3fea2ca1d5afb1878f72470b159cc89adb2633c5f0686f5d7c39f2ac18", size = 87934299 }, + { url = "https://files.pythonhosted.org/packages/94/42/62d4d13078886f4d22ca95ca07135f740cf9dd925f4cdb23d7b7d432403b/jaxlib-0.4.35-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e8c9579e20d5ecdc4f61336cdd032710cb8c38d5ae9c4fce0cf9ea031cef21cb", size = 68065641 }, + { url = "https://files.pythonhosted.org/packages/4d/a0/87a4eae3811ce7014ce2c59b811ad930273bfbbb8252ba78079606f9ec40/jaxlib-0.4.35-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:7b11ad7c13f7f96f36efd303711ecac425f19ca2ddf65cf1be1541167a959ee5", size = 70629568 }, + { url = "https://files.pythonhosted.org/packages/b3/89/59d6fe10e30ff5a48a73319bafa9a11cd999f91a47e4f08f7dc3651c899c/jaxlib-0.4.35-cp313-cp313-manylinux2014_x86_64.whl", hash = "sha256:0be3cf9df879d9ae1b5b92fc281f77d21f522fcbae1a48a02661026bbd9b9309", size = 87350315 }, + { url = 
"https://files.pythonhosted.org/packages/79/d7/d7600c65fe0412a6584d84ca172816a8cf19965219ee3dd59542447ffe2f/jaxlib-0.4.35-cp313-cp313-win_amd64.whl", hash = "sha256:330c090bb9af413f552d8a92d097e50baec6b75823430fb2966a49f5298d4c43", size = 56562022 }, ] [[package]] @@ -1405,6 +1440,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c9/fb/108ecd1fe961941959ad0ee4e12ee7b8b1477247f30b1fdfd83ceaf017f0/jupyter_core-5.7.2-py3-none-any.whl", hash = "sha256:4f7315d2f6b4bcf2e3e7cb6e46772eba760ae459cd1f59d29eb57b0a01bd7409", size = 28965 }, ] +[[package]] +name = "jupyterlab-widgets" +version = "3.0.13" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/59/73/fa26bbb747a9ea4fca6b01453aa22990d52ab62dd61384f1ac0dc9d4e7ba/jupyterlab_widgets-3.0.13.tar.gz", hash = "sha256:a2966d385328c1942b683a8cd96b89b8dd82c8b8f81dda902bb2bc06d46f5bed", size = 203556 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a9/93/858e87edc634d628e5d752ba944c2833133a28fa87bb093e6832ced36a3e/jupyterlab_widgets-3.0.13-py3-none-any.whl", hash = "sha256:e3cda2c233ce144192f1e29914ad522b2f4c40e77214b0cc97377ca3d323db54", size = 214392 }, +] + [[package]] name = "jupytext" version = "1.13.8" @@ -1421,6 +1465,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f9/e3/538509410372acd6d41f12c028dfc75ebddfbc4f7544f933bff7b5cc3e97/jupytext-1.13.8-py3-none-any.whl", hash = "sha256:625d2d2012763cc87d3f0dd60383516cec442c11894f53ad0c5ee5aa2a52caa2", size = 297592 }, ] +[[package]] +name = "kagglehub" +version = "0.3.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "packaging" }, + { name = "requests" }, + { name = "tqdm" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b0/69/3e3d9533b44535903011157102bcf08ad4124f12b5d2c294850e6fad5032/kagglehub-0.3.3.tar.gz", hash = "sha256:0777d4d1ee1e59d4125b14ba62a46b2eadedb68bc6517479f6fb02a522a262f8", size = 60620 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c4/d1/4ab25019a168f5c414202f124d156e11ac79f07845d67288929311f1b1b2/kagglehub-0.3.3-py3-none-any.whl", hash = "sha256:5370acde855d04b6d8a7bc242edff339266913fffc8b198d31859b25b7d095f7", size = 42852 }, +] + [[package]] name = "keras" version = "3.5.0" @@ -2007,7 +2065,6 @@ version = "12.1.3.1" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/37/6d/121efd7382d5b0284239f4ab1fc1590d86d34ed4a4a2fdb13b30ca8e5740/nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728", size = 410594774 }, - { url = "https://files.pythonhosted.org/packages/c5/ef/32a375b74bea706c93deea5613552f7c9104f961b21df423f5887eca713b/nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906", size = 439918445 }, ] [[package]] @@ -2016,7 +2073,6 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/7e/00/6b218edd739ecfc60524e585ba8e6b00554dd908de2c9c66c1af3e44e18d/nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e", size = 14109015 }, - { url = "https://files.pythonhosted.org/packages/d0/56/0021e32ea2848c24242f6b56790bd0ccc8bf99f973ca790569c6ca028107/nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = 
"sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4", size = 10154340 }, ] [[package]] @@ -2025,7 +2081,6 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/b6/9f/c64c03f49d6fbc56196664d05dba14e3a561038a81a638eeb47f4d4cfd48/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2", size = 23671734 }, - { url = "https://files.pythonhosted.org/packages/ad/1d/f76987c4f454eb86e0b9a0e4f57c3bf1ac1d13ad13cd1a4da4eb0e0c0ce9/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed", size = 19331863 }, ] [[package]] @@ -2034,7 +2089,6 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/eb/d5/c68b1d2cdfcc59e72e8a5949a37ddb22ae6cade80cd4a57a84d4c8b55472/nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40", size = 823596 }, - { url = "https://files.pythonhosted.org/packages/9f/e2/7a2b4b5064af56ea8ea2d8b2776c0f2960d95c88716138806121ae52a9c9/nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344", size = 821226 }, ] [[package]] @@ -2046,7 +2100,6 @@ dependencies = [ ] wheels = [ { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 }, - { url = "https://files.pythonhosted.org/packages/3f/d0/f90ee6956a628f9f04bf467932c0a25e5a7e706a684b896593c06c82f460/nvidia_cudnn_cu12-9.1.0.70-py3-none-win_amd64.whl", hash = "sha256:6278562929433d68365a07a4a1546c237ba2849852c0d4b2262a486e805b977a", size = 679925892 }, ] [[package]] @@ -2055,7 +2108,6 @@ version = "11.0.2.54" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/86/94/eb540db023ce1d162e7bea9f8f5aa781d57c65aed513c33ee9a5123ead4d/nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56", size = 121635161 }, - { url = "https://files.pythonhosted.org/packages/f7/57/7927a3aa0e19927dfed30256d1c854caf991655d847a4e7c01fe87e3d4ac/nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253", size = 121344196 }, ] [[package]] @@ -2064,7 +2116,6 @@ version = "10.3.2.106" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/44/31/4890b1c9abc496303412947fc7dcea3d14861720642b49e8ceed89636705/nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0", size = 56467784 }, - { url = "https://files.pythonhosted.org/packages/5c/97/4c9c7c79efcdf5b70374241d48cf03b94ef6707fd18ea0c0f53684931d0b/nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a", size = 55995813 }, ] [[package]] @@ -2078,7 +2129,6 @@ dependencies = [ ] wheels = [ { url = 
"https://files.pythonhosted.org/packages/bc/1d/8de1e5c67099015c834315e333911273a8c6aaba78923dd1d1e25fc5f217/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd", size = 124161928 }, - { url = "https://files.pythonhosted.org/packages/b8/80/8fca0bf819122a631c3976b6fc517c1b10741b643b94046bd8dd451522c5/nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5", size = 121643081 }, ] [[package]] @@ -2090,7 +2140,6 @@ dependencies = [ ] wheels = [ { url = "https://files.pythonhosted.org/packages/65/5b/cfaeebf25cd9fdec14338ccb16f6b2c4c7fa9163aefcf057d86b9cc248bb/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c", size = 195958278 }, - { url = "https://files.pythonhosted.org/packages/0f/95/48fdbba24c93614d1ecd35bc6bdc6087bd17cbacc3abc4b05a9c2a1ca232/nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a", size = 195414588 }, ] [[package]] @@ -2109,7 +2158,6 @@ source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/81/b3/e456a1b2d499bb84bdc6670bfbcf41ff3bac58bd2fae6880d62834641558/nvidia_nvjitlink_cu12-12.6.20-py3-none-manylinux2014_aarch64.whl", hash = "sha256:84fb38465a5bc7c70cbc320cfd0963eb302ee25a5e939e9f512bbba55b6072fb", size = 19252608 }, { url = "https://files.pythonhosted.org/packages/59/65/7ff0569494fbaea45ad2814972cc88da843d53cc96eb8554fcd0908941d9/nvidia_nvjitlink_cu12-12.6.20-py3-none-manylinux2014_x86_64.whl", hash = "sha256:562ab97ea2c23164823b2a89cb328d01d45cb99634b8c65fe7cd60d14562bd79", size = 19724950 }, - { url = "https://files.pythonhosted.org/packages/cb/ef/8f96c82e1cfcf6d5b770f7b043c3cc24841fc247b37629a7cc643dbf72a1/nvidia_nvjitlink_cu12-12.6.20-py3-none-win_amd64.whl", hash = "sha256:ed3c43a17f37b0c922a919203d2d36cbef24d41cc3e6b625182f8b58203644f6", size = 162012830 }, ] [[package]] @@ -2118,7 +2166,6 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/da/d3/8057f0587683ed2fcd4dbfbdfdfa807b9160b809976099d36b8f60d08f03/nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5", size = 99138 }, - { url = "https://files.pythonhosted.org/packages/b8/d7/bd7cb2d95ac6ac6e8d05bfa96cdce69619f1ef2808e072919044c2d47a8c/nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82", size = 66307 }, ] [[package]] @@ -2216,7 +2263,7 @@ wheels = [ [[package]] name = "orbax-checkpoint" -version = "0.6.0" +version = "0.10.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "absl-py" }, @@ -2224,19 +2271,19 @@ dependencies = [ { name = "etils", version = "1.9.2", source = { registry = "https://pypi.org/simple" }, extra = ["epath", "epy"], marker = "python_full_version >= '3.11'" }, { name = "humanize" }, { name = "jax" }, - { name = "jaxlib" }, { name = "msgpack" }, { name = "nest-asyncio" }, { name = "numpy" }, { name = "protobuf", version = "3.20.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "protobuf", version = "4.25.4", source = { registry = "https://pypi.org/simple" }, marker 
= "python_full_version >= '3.11'" }, { name = "pyyaml" }, + { name = "simplejson" }, { name = "tensorstore" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/07/4f/f6b372e70fb3785656d31edd9b99a151dc1b4955486e85a1935e9e0273c5/orbax_checkpoint-0.6.0.tar.gz", hash = "sha256:313586128267e0923d6d2095855da5edcd45acee1f9d2e86d1e8330f69acb110", size = 187560 } +sdist = { url = "https://files.pythonhosted.org/packages/07/24/f13f75810a00873f779625b4fff9419d09f95a56bedb01453ac2b4990ce8/orbax_checkpoint-0.10.1.tar.gz", hash = "sha256:aaf44f5a10ced74badc7fcaf8a2396e9047a20a61487ad5e8514e539d7992cd8", size = 230081 } wheels = [ - { url = "https://files.pythonhosted.org/packages/3c/e0/194d62674be60e3bf2cb764f653e8f06db86b02b6c9c9243ea9af0f48bf1/orbax_checkpoint-0.6.0-py3-none-any.whl", hash = "sha256:fce1d61b1a378939f55b03fb4ac9922ad0def0b846822b1f5e70f4a81d24dbc2", size = 253044 }, + { url = "https://files.pythonhosted.org/packages/b3/67/a175072cd7e5a215b12f39f4d9d891881a6220d75e30ae6480d05647bdf4/orbax_checkpoint-0.10.1-py3-none-any.whl", hash = "sha256:b4d7ae295d89a329c39109f945ff690d47c1db04eac644fa5316b2f42b5fa9e5", size = 328311 }, ] [[package]] @@ -3056,6 +3103,67 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b0/12/c657047c11a47e1c3e51bdc26bd6f2661a268fd0384bd8ed56b227530486/simple_parsing-0.1.5-py3-none-any.whl", hash = "sha256:46f35ed7002f9bb25dca3a49eac491cc78d2140e4adcbe156225ae643c2874ea", size = 113568 }, ] +[[package]] +name = "simplejson" +version = "3.19.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/3d/29/085111f19717f865eceaf0d4397bf3e76b08d60428b076b64e2a1903706d/simplejson-3.19.3.tar.gz", hash = "sha256:8e086896c36210ab6050f2f9f095a5f1e03c83fa0e7f296d6cba425411364680", size = 85237 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/39/24/260ad03435ce8ef2436031951134659c7161776ec3a78094b35b9375ceea/simplejson-3.19.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:50d8b742d74c449c4dcac570d08ce0f21f6a149d2d9cf7652dbf2ba9a1bc729a", size = 93660 }, + { url = "https://files.pythonhosted.org/packages/63/a1/dee207f357bcd6b106f2ca5129ee916c24993ba08b7dfbf9a37c22442ea9/simplejson-3.19.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:dd011fc3c1d88b779645495fdb8189fb318a26981eebcce14109460e062f209b", size = 75546 }, + { url = "https://files.pythonhosted.org/packages/80/7b/45ef1da43f54d209ce2ef59b7356cda13f810186c381f38ae23a4d2b1337/simplejson-3.19.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:637c4d4b81825c1f4d651e56210bd35b5604034b192b02d2d8f17f7ce8c18f42", size = 75602 }, + { url = "https://files.pythonhosted.org/packages/7f/4b/9a132382982f8127bc7ce5212a5585d83c174707c9dd698d0cb6a0d41882/simplejson-3.19.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2f56eb03bc9e432bb81adc8ecff2486d39feb371abb442964ffb44f6db23b332", size = 138632 }, + { url = "https://files.pythonhosted.org/packages/76/37/012f5ad2f38afa28f8a6ad9da01dc0b64492ffbaf2a3f2f8a0e1fddf9c1d/simplejson-3.19.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ef59a53be400c1fad2c914b8d74c9d42384fed5174f9321dd021b7017fd40270", size = 146740 }, + { url = "https://files.pythonhosted.org/packages/69/b3/89640bd676e26ea2315b5aaf80712a6fbbb4338e4caf872d91448502a19b/simplejson-3.19.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:72e8abbc86fcac83629a030888b45fed3a404d54161118be52cb491cd6975d3e", size = 134440 }, + { url = "https://files.pythonhosted.org/packages/61/20/0035a288deaff05397d6cc0145b33f3dd2429b99cdc880de4c5eca41ca72/simplejson-3.19.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8efb03ca77bd7725dfacc9254df00d73e6f43013cf39bd37ef1a8ed0ebb5165", size = 137949 }, + { url = "https://files.pythonhosted.org/packages/5d/de/5b03fafe3003e32d179588953d38183af6c3747e95c7dcc668c4f9eb886a/simplejson-3.19.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:add8850db04b98507a8b62d248a326ecc8561e6d24336d1ca5c605bbfaab4cad", size = 139992 }, + { url = "https://files.pythonhosted.org/packages/d1/ce/e493116ff49fd215f7baa25195b8f684c91e65c153e2a57e04dc3f3a466b/simplejson-3.19.3-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:fc3dc9fb413fc34c396f52f4c87de18d0bd5023804afa8ab5cc224deeb6a9900", size = 140320 }, + { url = "https://files.pythonhosted.org/packages/86/f3/a18b98a7a27548829f672754dd3940fb637a27981399838128d3e560087f/simplejson-3.19.3-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:4dfa420bb9225dd33b6efdabde7c6a671b51150b9b1d9c4e5cd74d3b420b3fe1", size = 148625 }, + { url = "https://files.pythonhosted.org/packages/0f/55/d3da33ee3e708133da079b9d537693d7fef281e6f0d27921cc7e5b3ec523/simplejson-3.19.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:7b5c472099b39b274dcde27f1113db8d818c9aa3ba8f78cbb8ad04a4c1ac2118", size = 141287 }, + { url = "https://files.pythonhosted.org/packages/17/e8/56184ab4d66bb64a6ff569f069b3796dfd943f9b961268fe0d403526fc17/simplejson-3.19.3-cp310-cp310-win32.whl", hash = "sha256:817abad79241ed4a507b3caf4d3f2be5079f39d35d4c550a061988986bffd2ec", size = 74143 }, + { url = "https://files.pythonhosted.org/packages/be/8f/a0089eff060f10a925f08b0a0f50854321484f1ac54b1895bbf4c9213dfe/simplejson-3.19.3-cp310-cp310-win_amd64.whl", hash = "sha256:dd5b9b1783e14803e362a558680d88939e830db2466f3fa22df5c9319f8eea94", size = 75643 }, + { url = "https://files.pythonhosted.org/packages/8c/bb/9ee3959e6929d228cf669b3f13f0edd43c5261b6cd69598640748b19ca35/simplejson-3.19.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e88abff510dcff903a18d11c2a75f9964e768d99c8d147839913886144b2065e", size = 91930 }, + { url = "https://files.pythonhosted.org/packages/ac/ae/a06523928af3a6783e2638cd4f6035c3e32de1c1063d563d9060c8d2f1ad/simplejson-3.19.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:934a50a614fb831614db5dbfba35127ee277624dda4d15895c957d2f5d48610c", size = 74787 }, + { url = "https://files.pythonhosted.org/packages/c3/58/fea732e48a7540035fe46d39e6fd77679f5810311d31da8661ce7a18210a/simplejson-3.19.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:212fce86a22188b0c7f53533b0f693ea9605c1a0f02c84c475a30616f55a744d", size = 74612 }, + { url = "https://files.pythonhosted.org/packages/ab/4d/15718f20cb0e3875b8af9597d6bb3bfbcf1383834b82b6385ee9ac0b72a9/simplejson-3.19.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d9e8f836688a8fabe6a6b41b334aa550a6823f7b4ac3d3712fc0ad8655be9a8", size = 143550 }, + { url = "https://files.pythonhosted.org/packages/93/44/815a4343774760f7a82459c8f6a4d8268b4b6d23f81e7b922a5e2ca79171/simplejson-3.19.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:23228037dc5d41c36666384062904d74409a62f52283d9858fa12f4c22cffad1", size = 153284 }, + { url = 
"https://files.pythonhosted.org/packages/9d/52/d3202d9bba95444090d1c98e43da3c10907875babf63ed3c134d1b9437e3/simplejson-3.19.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0791f64fed7d4abad639491f8a6b1ba56d3c604eb94b50f8697359b92d983f36", size = 141518 }, + { url = "https://files.pythonhosted.org/packages/b7/d4/850948bcbcfe0b4a6c69dfde10e245d3a1ea45252f16a1e2308a3b06b1da/simplejson-3.19.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c4f614581b61a26fbbba232a1391f6cee82bc26f2abbb6a0b44a9bba25c56a1c", size = 144688 }, + { url = "https://files.pythonhosted.org/packages/58/d2/b8dcb0a07d9cd54c47f9fe8733dbb83891d1efe4fc786d9dfc8781cc04f9/simplejson-3.19.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1df0aaf1cb787fdf34484ed4a1f0c545efd8811f6028623290fef1a53694e597", size = 144534 }, + { url = "https://files.pythonhosted.org/packages/a9/95/1e92d99039041f596e0923ec4f9153244acaf3830944dc69a7c11b23ceaa/simplejson-3.19.3-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:951095be8d4451a7182403354c22ec2de3e513e0cc40408b689af08d02611588", size = 146565 }, + { url = "https://files.pythonhosted.org/packages/21/04/c96aeb3a74031255e4cbcc0ca1b6ebfb5549902f0a065f06d65ce8447c0c/simplejson-3.19.3-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:2a954b30810988feeabde843e3263bf187697e0eb5037396276db3612434049b", size = 155014 }, + { url = "https://files.pythonhosted.org/packages/b7/41/e28a28593afc4a75d8999d057bfb7c73a103e35f927e66f4bb92571787ae/simplejson-3.19.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c40df31a75de98db2cdfead6074d4449cd009e79f54c1ebe5e5f1f153c68ad20", size = 148092 }, + { url = "https://files.pythonhosted.org/packages/2b/82/1c81a3af06f937afb6d2e9d74a465c0e0ae6db444d1bf2a436ea26de1965/simplejson-3.19.3-cp311-cp311-win32.whl", hash = "sha256:7e2a098c21ad8924076a12b6c178965d88a0ad75d1de67e1afa0a66878f277a5", size = 73942 }, + { url = "https://files.pythonhosted.org/packages/65/be/d8ab9717f471be3c114f16abd8be21d9a6a0a09b9b49177d93d64d3717d9/simplejson-3.19.3-cp311-cp311-win_amd64.whl", hash = "sha256:c9bedebdc5fdad48af8783022bae307746d54006b783007d1d3c38e10872a2c6", size = 75469 }, + { url = "https://files.pythonhosted.org/packages/20/15/513fea93fafbdd4993eacfcb762965b2ff3d29e618c029e2956174d68c4b/simplejson-3.19.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:66a0399e21c2112acacfebf3d832ebe2884f823b1c7e6d1363f2944f1db31a99", size = 92921 }, + { url = "https://files.pythonhosted.org/packages/a4/4f/998a907ae1a6c104dc0ee48aa248c2478490152808d34d8e07af57f396c3/simplejson-3.19.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6ef9383c5e05f445be60f1735c1816163c874c0b1ede8bb4390aff2ced34f333", size = 75311 }, + { url = "https://files.pythonhosted.org/packages/db/44/acd6122201e927451869d45952b9ab1d3025cdb5e61548d286d08fbccc08/simplejson-3.19.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:42e5acf80d4d971238d4df97811286a044d720693092b20a56d5e56b7dcc5d09", size = 74964 }, + { url = "https://files.pythonhosted.org/packages/27/ca/d0a1e8f16e1bbdc0b8c6d88166f45f565ed7285f53928cfef3b6ce78f14d/simplejson-3.19.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d0b0efc7279d768db7c74d3d07f0b5c81280d16ae3fb14e9081dc903e8360771", size = 150106 }, + { url = 
"https://files.pythonhosted.org/packages/63/59/0554b78cf26c98e2b9cae3f44723bd72c2394e2afec1a14eedc6211f7187/simplejson-3.19.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0552eb06e7234da892e1d02365cd2b7b2b1f8233aa5aabdb2981587b7cc92ea0", size = 158347 }, + { url = "https://files.pythonhosted.org/packages/b2/fe/9f30890352e431e8508cc569912d3322147d3e7e4f321e48c0adfcb4c97d/simplejson-3.19.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5bf6a3b9a7d7191471b464fe38f684df10eb491ec9ea454003edb45a011ab187", size = 148456 }, + { url = "https://files.pythonhosted.org/packages/37/e3/663a09542ee021d4131162f7a164cb2e7f04ef48433a67591738afbf12ea/simplejson-3.19.3-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7017329ca8d4dca94ad5e59f496e5fc77630aecfc39df381ffc1d37fb6b25832", size = 152190 }, + { url = "https://files.pythonhosted.org/packages/31/20/4e0c4d35e10ff6465003bec304316d822a559a1c38c66ef6892ca199c207/simplejson-3.19.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:67a20641afebf4cfbcff50061f07daad1eace6e7b31d7622b6fa2c40d43900ba", size = 149846 }, + { url = "https://files.pythonhosted.org/packages/08/7a/46e2e072cac3987cbb05946f25167f0ad2fe536748e7405953fd6661a486/simplejson-3.19.3-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:dd6a7dabcc4c32daf601bc45e01b79175dde4b52548becea4f9545b0a4428169", size = 151714 }, + { url = "https://files.pythonhosted.org/packages/7f/7d/dbeeac10eb61d5d8858d0bb51121a21050d281dc83af4c557f86da28746c/simplejson-3.19.3-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:08f9b443a94e72dd02c87098c96886d35790e79e46b24e67accafbf13b73d43b", size = 158777 }, + { url = "https://files.pythonhosted.org/packages/fc/8f/a98bdbb799c6a4a884b5823db31785a96ba895b4b0f4d8ac345d6fe98bbf/simplejson-3.19.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fa97278ae6614346b5ca41a45a911f37a3261b57dbe4a00602048652c862c28b", size = 154230 }, + { url = "https://files.pythonhosted.org/packages/b1/db/852eebceb85f969ae40e06babed1a93d3bacb536f187d7a80ff5823a5979/simplejson-3.19.3-cp312-cp312-win32.whl", hash = "sha256:ef28c3b328d29b5e2756903aed888960bc5df39b4c2eab157ae212f70ed5bf74", size = 74002 }, + { url = "https://files.pythonhosted.org/packages/fe/68/9f0e5df0651cb79ef83cba1378765a00ee8038e6201cc82b8e7178a7778e/simplejson-3.19.3-cp312-cp312-win_amd64.whl", hash = "sha256:1e662336db50ad665777e6548b5076329a94a0c3d4a0472971c588b3ef27de3a", size = 75596 }, + { url = "https://files.pythonhosted.org/packages/93/3a/5896821ed543899fcb9c4256c7e71bb110048047349a00f42bc8b8fb379f/simplejson-3.19.3-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:0959e6cb62e3994b5a40e31047ff97ef5c4138875fae31659bead691bed55896", size = 92931 }, + { url = "https://files.pythonhosted.org/packages/39/15/5d33d269440912ee40d856db0c8be2b91aba7a219690ab01f86cb0edd590/simplejson-3.19.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:7a7bfad839c624e139a4863007233a3f194e7c51551081f9789cba52e4da5167", size = 75318 }, + { url = "https://files.pythonhosted.org/packages/2a/8d/2e7483a2bf7ec53acf7e012bafbda79d7b34f90471dda8e424544a59d484/simplejson-3.19.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:afab2f7f2486a866ff04d6d905e9386ca6a231379181a3838abce1f32fbdcc37", size = 74971 }, + { url = 
"https://files.pythonhosted.org/packages/4d/9d/9bdf34437c8834a7cf7246f85e9d5122e30579f512c10a0c2560e994294f/simplejson-3.19.3-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d00313681015ac498e1736b304446ee6d1c72c5b287cd196996dad84369998f7", size = 150112 }, + { url = "https://files.pythonhosted.org/packages/a7/e2/1f2ae2d89eaf85f6163c82150180aae5eaa18085cfaf892f8a57d4c51cbd/simplejson-3.19.3-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d936ae682d5b878af9d9eb4d8bb1fdd5e41275c8eb59ceddb0aeed857bb264a2", size = 158354 }, + { url = "https://files.pythonhosted.org/packages/60/83/26f610adf234c8492b3f30501e12f2271e67790f946c6898fe0c58aefe99/simplejson-3.19.3-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:01c6657485393f2e9b8177c77a7634f13ebe70d5e6de150aae1677d91516ce6b", size = 148455 }, + { url = "https://files.pythonhosted.org/packages/b5/4b/109af50006af77133653c55b5b91b4bd2d579ff8254ce11216c0b75f911b/simplejson-3.19.3-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2a6a750d3c7461b1c47cfc6bba8d9e57a455e7c5f80057d2a82f738040dd1129", size = 152191 }, + { url = "https://files.pythonhosted.org/packages/75/dc/108872a8825cbd99ae6f4334e0490ff1580367baf12198bcaf988f6820ba/simplejson-3.19.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ea7a4a998c87c5674a27089e022110a1a08a7753f21af3baf09efe9915c23c3c", size = 149954 }, + { url = "https://files.pythonhosted.org/packages/eb/be/deec1d947a5d0472276ab4a4d1a9378dc5ee27f3dc9e54d4f62ffbad7a08/simplejson-3.19.3-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:6300680d83a399be2b8f3b0ef7ef90b35d2a29fe6e9c21438097e0938bbc1564", size = 151812 }, + { url = "https://files.pythonhosted.org/packages/e9/58/4ee130702d36b1551ef66e7587eefe56651f3669255bf748cd71691e2434/simplejson-3.19.3-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:ab69f811a660c362651ae395eba8ce84f84c944cea0df5718ea0ba9d1e4e7252", size = 158880 }, + { url = "https://files.pythonhosted.org/packages/0f/e1/59cc6a371b60f89e3498d9f4c8109f6b7359094d453f5fe80b2677b777b0/simplejson-3.19.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:256e09d0f94d9c3d177d9e95fd27a68c875a4baa2046633df387b86b652f5747", size = 154344 }, + { url = "https://files.pythonhosted.org/packages/79/45/1b36044670016f5cb25ebd92497427d2d1711ecb454d00f71eb9a00b77cc/simplejson-3.19.3-cp313-cp313-win32.whl", hash = "sha256:2c78293470313aefa9cfc5e3f75ca0635721fb016fb1121c1c5b0cb8cc74712a", size = 74002 }, + { url = "https://files.pythonhosted.org/packages/e2/58/b06226e6b0612f2b1fa13d5273551da259f894566b1eef32249ddfdcce44/simplejson-3.19.3-cp313-cp313-win_amd64.whl", hash = "sha256:3bbcdc438dc1683b35f7a8dc100960c721f922f9ede8127f63bed7dfded4c64c", size = 75599 }, + { url = "https://files.pythonhosted.org/packages/0d/e7/f9fafbd4f39793a20cc52e77bbd766f7384312526d402c382928dc7667f6/simplejson-3.19.3-py3-none-any.whl", hash = "sha256:49cc4c7b940d43bd12bf87ec63f28cbc4964fc4e12c031cc8cd01650f43eb94e", size = 57004 }, +] + [[package]] name = "six" version = "1.16.0" @@ -3408,37 +3516,35 @@ dependencies = [ { name = "tensorflow", marker = "platform_system != 'Darwin'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/a2/e3/33fc5957790cf4710e0a9116cf37c0a881eda673e5f8b569bfff5654a48c/tensorflow_text-2.17.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8eba0b5804235519b571c827c97337c332de270107f06af6d2171cdefdc4c6a0", size = 
6109587 }, { url = "https://files.pythonhosted.org/packages/61/59/2090318555d98dc9dc868b3c585ada2e1139be538d954340726aa3d3899a/tensorflow_text-2.17.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89f04c3f478f1885ad4c7380643a768a72a3de79e1f8f40d50b48cc1fbf73893", size = 5205819 }, - { url = "https://files.pythonhosted.org/packages/92/65/e2d3d9300173a0927e8b7e3cf9a35f9539e9269786c1e1d9d945223fe21a/tensorflow_text-2.17.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a9b9f9c8a06878714a14f4e086fa8122beb2e141f82d0aa5a8f6b8f9b694db51", size = 6109684 }, { url = "https://files.pythonhosted.org/packages/de/32/182ecf4eb1432942876d9b0b089625564084c5ed4d03c02ddf2872177e95/tensorflow_text-2.17.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:161c09380b090774ed721cdcce973194458708250d7dfbac7cb9ea8a3e9ac762", size = 5205866 }, ] [[package]] name = "tensorstore" -version = "0.1.64" +version = "0.1.68" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "ml-dtypes" }, { name = "numpy" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ce/b7/04d19901451da377f03a6e1ae3d9edf0b43af93309f558abf28b2e5aaceb/tensorstore-0.1.64.tar.gz", hash = "sha256:7fa89e90876fb5377efc54f3f37326a6fb83ec9e1326565819a75a4e80949886", size = 6510000 } +sdist = { url = "https://files.pythonhosted.org/packages/93/c4/477fc183721128feb97a3427940457ace4c4f063da6f6f7a0b374f6b6d4c/tensorstore-0.1.68.tar.gz", hash = "sha256:6e13d3e3c8fb6ed67712835a343821536b38d6bdb517db554d41cebfe5947ab7", size = 6585632 } wheels = [ - { url = "https://files.pythonhosted.org/packages/2f/a8/63876bab9ca44d0b57bca6893927df90b08ff0123697216fe7b297036015/tensorstore-0.1.64-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:c369088c74c0dda30398290724513a0289f25ccc01865ed5aec62e57f1930709", size = 15366638 }, - { url = "https://files.pythonhosted.org/packages/90/3d/28b0ee2d792842d2e27be9fea5c541a77d1f8f4d4c1a3a981306acb69818/tensorstore-0.1.64-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:40cae39aca2992fdac0ed5fbcef71f72cd38a759b1a61c37d95ad395606697b4", size = 13563010 }, - { url = "https://files.pythonhosted.org/packages/b8/26/40a8cc7ffcc4abeacd196560f8d54ca2e24d2bb8ca540360bf4c7b1b5e70/tensorstore-0.1.64-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8cf64ee03c7cd62a0dde2f4d1f3f8784d50aea3a2e85a65686be0fe33ea18ed5", size = 13650288 }, - { url = "https://files.pythonhosted.org/packages/f1/3b/9e539c9d22f4eda48a9e5788d76e761f0627f249c3018d396bcdf17c7a54/tensorstore-0.1.64-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1a78aedbddccc09ea283b145496da03dbc7eb8693ae4e01074ed791d72b7eac2", size = 14926295 }, - { url = "https://files.pythonhosted.org/packages/66/f4/fb0bab70e472ce78f290222b5b1631c589a8fe9043148c0882150b28b527/tensorstore-0.1.64-cp310-cp310-win_amd64.whl", hash = "sha256:72517af8c5f9c49d0343acb7c6b0cc250f8077ca989285d471d3a64dbbfcc36b", size = 11523913 }, - { url = "https://files.pythonhosted.org/packages/4d/9c/e1ef8f867de64f36c2ec3a1cb803693736a4dcb91d5afd0741c8e11e71df/tensorstore-0.1.64-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:2b0a1e3294d2e690a9c269ea50d62f2f60f7935ca507243d8b56b2871b0e201f", size = 15367232 }, - { url = "https://files.pythonhosted.org/packages/46/a7/e6adff4ec3f622bd28a79bfa339aea3dc9d66508e87bc739f730b970098e/tensorstore-0.1.64-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3da6fa00ddf312e1b502d2ee9de39b858a78a02b396114201c67c01bc03fc382", size = 13567261 
}, - { url = "https://files.pythonhosted.org/packages/19/c4/e74f4c288b429221fd2f128eb57bed62ebf4bf69739970e404d8a5b63712/tensorstore-0.1.64-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c32976f5a0e881a097b52a488fb16d33a1d94a86393115098da87894fc9c5abf", size = 13652088 }, - { url = "https://files.pythonhosted.org/packages/c8/5a/2df005251df903de0fda4d8da7e7a5081a6854d40b62b8eeaf88a86a1c7a/tensorstore-0.1.64-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:55af5ec5bd78056e4df18f4af107bac7ea84d2bdc34ff6ab6642b3a036f99390", size = 14926070 }, - { url = "https://files.pythonhosted.org/packages/e5/68/07d792f014fc3ad886a2498ebbfdaf5d6807c09c65fad5534969620846b4/tensorstore-0.1.64-cp311-cp311-win_amd64.whl", hash = "sha256:24a4cebaf9d0e75d494342948f68edc971d6bb90e23192ddf8d98397fb1ff3cb", size = 11523737 }, - { url = "https://files.pythonhosted.org/packages/00/32/e9b22f4c05ae910940fbc6c304b6570b8cf8d35b1d2e8600d8118c42a80d/tensorstore-0.1.64-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:80c510024cc31c4dee7f478ea67a0b4b4cacf5a6bffe8c4e446188fdbe2d7b4c", size = 15404886 }, - { url = "https://files.pythonhosted.org/packages/df/9d/01e43143ac82cdc7b87e55818f0052a63b3414bd9f731a2c991dd68ca4ba/tensorstore-0.1.64-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c90d38b552c79f0d688cc3d502a9023e3dee9821881d6727d8aa06482ccdc0c1", size = 13594439 }, - { url = "https://files.pythonhosted.org/packages/44/7e/1522b9092e396d64d84ea799ef1f9c1d7e7da3514277fa8b908e1d8d26d1/tensorstore-0.1.64-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9968f9a9b9cd7c669bfae5244307e105c006038e8dd156eebbf2146f771ba369", size = 13646074 }, - { url = "https://files.pythonhosted.org/packages/0a/eb/09210bb4a8afc991eb9cb794269ff276a62f15936aef2b64335b61412f7a/tensorstore-0.1.64-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:806774968ee4cc8809114281730e9fad5970a94a7ef9104bc54fa35a32068b2f", size = 14923761 }, - { url = "https://files.pythonhosted.org/packages/c7/70/27281fb67817d69dddc5eec9827513f8e341e3a52cb85f066a84e9274a47/tensorstore-0.1.64-cp312-cp312-win_amd64.whl", hash = "sha256:cc315029f49c0f294f0721462c221e0ef4c15360a526cc34392ac81565fd63b8", size = 11523992 }, + { url = "https://files.pythonhosted.org/packages/ee/22/ec1298e7b01ff2a8dbcfa5f5cffbcc323c7f040ef511bd343a25ee1e2511/tensorstore-0.1.68-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:c9ca5a5dc1e13760f024c3607219e60c3b8338f1b4f7413e1a13115a132ac7d9", size = 13996437 }, + { url = "https://files.pythonhosted.org/packages/c8/44/a50067d4964c5106f72752d3b778839edf7156695ece64bc51aefe5b5476/tensorstore-0.1.68-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:425c56cdd7f76af8be0c056933da9bf8b8812c00e4fef08888465e2f126d53eb", size = 12230834 }, + { url = "https://files.pythonhosted.org/packages/99/9e/cbffcabdfc5a471d4ddadddccac32d028f7b76375c1c3204edaeb44cef6a/tensorstore-0.1.68-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1348768a5aae514b440212eedb50d246a1a4b39f8e74d275ef0bead688c562b", size = 13983781 }, + { url = "https://files.pythonhosted.org/packages/18/35/1b4f581767982c539b5558b6bcd14e885cd1d6f5589b614fbae7bf2846d4/tensorstore-0.1.68-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d5fa0e47b42eb58ddea81763cb0de4a92c4ab0da530d2a27f1928539980a781a", size = 15320730 }, + { url = 
"https://files.pythonhosted.org/packages/5e/eb/db4df6fbd35fd8dc2312f5cdb4832505864f6ecb1bdbcc55d8b44a17e979/tensorstore-0.1.68-cp310-cp310-win_amd64.whl", hash = "sha256:76ebad6762d226c9621d256d8703381963e407d0361cd33f0f89409a31acb57e", size = 11996126 }, + { url = "https://files.pythonhosted.org/packages/28/f1/4e0212dd514f8aabf739ad4e599e73e28c9dd3fe69431ba52b0e06e89894/tensorstore-0.1.68-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:23dc88d5188267529beb72012f72ce892ee25d40daf9dd533413bfc818b1d030", size = 13999180 }, + { url = "https://files.pythonhosted.org/packages/38/d9/7f60328b202cc3a263fe1975e9de77d7f2f0f5b9ba0e4c8eedc85d60fe46/tensorstore-0.1.68-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9d62c4288e68b4640de878f8393a5779440b2de8e84cf7b717f91a01a4e6b4be", size = 12231469 }, + { url = "https://files.pythonhosted.org/packages/a8/2d/3028f63cdbe8c421301e54b6bf6d4f806906b0ed63493667c930e12185bc/tensorstore-0.1.68-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b6e51188a82c93563440c805bd501b12f0dc30267667f664091b3a2b8b108017", size = 13985840 }, + { url = "https://files.pythonhosted.org/packages/57/fc/00412d0acf5e51d17e14d39da30a91d53cec83b5a717064af781021f65db/tensorstore-0.1.68-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:889900ee6a9ffba4635f44f663b41f5b43f67b1e74bd507fa4a30f0f02704c80", size = 15319741 }, + { url = "https://files.pythonhosted.org/packages/b3/8b/e2e8f4b2a8e682d3cbd25a32df2da06441a4091d7cf262cf7f10237627af/tensorstore-0.1.68-cp311-cp311-win_amd64.whl", hash = "sha256:c65460ac90f8db49ad35779964ea5983332fe63e60b4d94ba66640c68ef73091", size = 11997197 }, + { url = "https://files.pythonhosted.org/packages/f7/e5/047b97ac501a0bc48928897f8e3516856bc7488301df4af462750acff37d/tensorstore-0.1.68-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:d80f9b48b057fda9aea0407e576324354b054aae02fa08fc0a8e6b11acf7ae3a", size = 14038026 }, + { url = "https://files.pythonhosted.org/packages/a2/1e/3f368ed9e1dd09e022851be2a1ffdc2e410ababdd16333332bcbf4eda1de/tensorstore-0.1.68-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5902d7c36e6119b761d02260b68646585b315202397e2a6c016e3f5d81d39a43", size = 12260517 }, + { url = "https://files.pythonhosted.org/packages/36/32/26afb8fd2dd5cff758a01708e10a3b59e1756d72e5f00c1fe91f0c971057/tensorstore-0.1.68-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a93fe05708acb9d9e3813f7f7ecd807c8ff34ec3fa30e2baa37e9270d128dcf0", size = 13972573 }, + { url = "https://files.pythonhosted.org/packages/93/e9/9674f5d59161325f350acc51b19db749a1b2b381b09417843ff09eba4d67/tensorstore-0.1.68-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6672b2047df3f772350ac75d6780f31201a82383c5b7c0c1986903b88e6f341a", size = 15312725 }, + { url = "https://files.pythonhosted.org/packages/7b/78/4d855796274f90bd869e2f79f20b84d12ff50b9b612970a128815f7b375e/tensorstore-0.1.68-cp312-cp312-win_amd64.whl", hash = "sha256:172420ec1c4e925a8ec3c386e31b4f81eae403bdca71b6258e7f775a69c3bfb3", size = 11998145 }, ] [[package]] @@ -3587,9 +3693,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/45/27/14cc3101409b9b4b9241d2ba7deaa93535a217a211c86c4cc7151fb12181/triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a", size = 209376304 }, { url = 
"https://files.pythonhosted.org/packages/33/3e/a2f59384587eff6aeb7d37b6780de7fedd2214935e27520430ca9f5b7975/triton-3.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5ce8520437c602fb633f1324cc3871c47bee3b67acf9756c1a66309b60e3216c", size = 209438883 }, { url = "https://files.pythonhosted.org/packages/fe/7b/7757205dee3628f75e7991021d15cd1bd0c9b044ca9affe99b50879fc0e1/triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb", size = 209464695 }, - { url = "https://files.pythonhosted.org/packages/15/67/84e5a4b7b45bdeb11da26a67dfa2b988c512abbcbcad8cbc30aa579051b2/triton-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b052da883351fdf6be3d93cedae6db3b8e3988d3b09ed221bccecfa9612230", size = 209380247 }, - { url = "https://files.pythonhosted.org/packages/ea/6b/1d72cc8a7379822dadf050474add7d8b73b02c35057446b6f17d27cb9ea2/triton-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd34f19a8582af96e6291d4afce25dac08cb2a5d218c599163761e8e0827208e", size = 209442823 }, - { url = "https://files.pythonhosted.org/packages/ae/b2/048c9ecfdba0e6b0ae3c02eed2d9dd3e9e990a6d46da98555cf0c2232168/triton-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d5e10de8c011adeb7c878c6ce0dd6073b14367749e34467f1cff2bde1b78253", size = 209468633 }, ] [[package]] @@ -3663,6 +3766,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1b/d1/9babe2ccaecff775992753d8686970b1e2755d21c8a63be73aba7a4e7d77/wheel-0.44.0-py3-none-any.whl", hash = "sha256:2376a90c98cc337d18623527a97c31797bd02bad0033d41547043a1cbfbe448f", size = 67059 }, ] +[[package]] +name = "widgetsnbextension" +version = "4.0.13" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/56/fc/238c424fd7f4ebb25f8b1da9a934a3ad7c848286732ae04263661eb0fc03/widgetsnbextension-4.0.13.tar.gz", hash = "sha256:ffcb67bc9febd10234a362795f643927f4e0c05d9342c727b65d2384f8feacb6", size = 1164730 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/21/02/88b65cc394961a60c43c70517066b6b679738caf78506a5da7b88ffcb643/widgetsnbextension-4.0.13-py3-none-any.whl", hash = "sha256:74b2692e8500525cc38c2b877236ba51d34541e6385eeed5aec15a70f88a6c71", size = 2335872 }, +] + [[package]] name = "wrapt" version = "1.16.0"