From 8e9df2795c1b20b1ae54591ec5c2013231649d40 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Thu, 9 Nov 2023 01:21:44 +0000 Subject: [PATCH] add experimental nnx --- .gitignore | 5 +- docs/conf.py | 3 +- docs/experimental/index.rst | 6 + docs/index.rst | 1 + flax/core/flax_functional_engine.ipynb | 92 +- flax/experimental/__init__.py | 0 flax/experimental/nnx/.gitignore | 133 ++ flax/experimental/nnx/README.md | 440 +++++ flax/experimental/nnx/__init__.py | 98 + flax/experimental/nnx/docs/blog.md | 7 + .../nnx/docs/images/stateful-transforms.png | Bin 0 -> 304812 bytes flax/experimental/nnx/docs/quick_start.ipynb | 568 ++++++ flax/experimental/nnx/docs/tiny_nnx.ipynb | 465 +++++ flax/experimental/nnx/docs/why.ipynb | 391 ++++ flax/experimental/nnx/examples/00_demo.ipynb | 288 +++ .../nnx/examples/01_functional_api.py | 108 ++ .../nnx/examples/02_lifted_transforms.py | 106 ++ .../nnx/examples/03_train_state.py | 117 ++ flax/experimental/nnx/examples/05_vae.py | 218 +++ .../nnx/examples/06_scan_over_layers.py | 87 + .../nnx/examples/07_transformer.py | 414 +++++ .../nnx/examples/08_save_load_checkpoints.py | 67 + .../nnx/examples/09_parameter_surgery.py | 56 + .../nnx/examples/10_quantization.py | 437 +++++ .../nnx/examples/requirements.txt | 2 + .../experimental/nnx/ideas/shape_inference.py | 210 +++ flax/experimental/nnx/nnx/__init__.py | 13 + flax/experimental/nnx/nnx/compatibility.py | 93 + flax/experimental/nnx/nnx/dataclasses.py | 188 ++ flax/experimental/nnx/nnx/errors.py | 17 + flax/experimental/nnx/nnx/filterlib.py | 100 + flax/experimental/nnx/nnx/flaglib.py | 55 + flax/experimental/nnx/nnx/helpers.py | 177 ++ flax/experimental/nnx/nnx/ids.py | 79 + flax/experimental/nnx/nnx/module.py | 932 ++++++++++ flax/experimental/nnx/nnx/nn/__init__.py | 13 + flax/experimental/nnx/nnx/nn/activations.py | 69 + flax/experimental/nnx/nnx/nn/dtypes.py | 80 + flax/experimental/nnx/nnx/nn/initializers.py | 73 + flax/experimental/nnx/nnx/nn/linear.py | 445 +++++ flax/experimental/nnx/nnx/nn/normalization.py | 401 ++++ flax/experimental/nnx/nnx/nn/stochastic.py | 86 + flax/experimental/nnx/nnx/pytreelib.py | 291 +++ flax/experimental/nnx/nnx/reprlib.py | 108 ++ flax/experimental/nnx/nnx/rnglib.py | 227 +++ flax/experimental/nnx/nnx/spmd.py | 223 +++ flax/experimental/nnx/nnx/state.py | 226 +++ flax/experimental/nnx/nnx/tracers.py | 113 ++ flax/experimental/nnx/nnx/transforms.py | 1639 +++++++++++++++++ flax/experimental/nnx/nnx/variables.py | 485 +++++ .../experimental/nnx/scripts/requirements.txt | 1 + .../nnx/scripts/run-all-examples.bash | 12 + flax/experimental/nnx/tests/__init__.py | 13 + .../nnx/tests/test_compatibility.py | 22 + .../experimental/nnx/tests/test_containers.py | 91 + flax/experimental/nnx/tests/test_helpers.py | 73 + flax/experimental/nnx/tests/test_ids.py | 30 + .../nnx/tests/test_integration.py | 250 +++ flax/experimental/nnx/tests/test_module.py | 621 +++++++ .../nnx/tests/test_partitioning.py | 159 ++ flax/experimental/nnx/tests/test_pytree.py | 264 +++ flax/experimental/nnx/tests/test_rngs.py | 180 ++ flax/experimental/nnx/tests/test_spmd.py | 75 + .../experimental/nnx/tests/test_transforms.py | 656 +++++++ flax/experimental/nnx/tests/test_variable.py | 33 + flax/linen/kw_only_dataclasses.py | 4 +- flax/linen/module.py | 10 +- pyproject.toml | 8 + tests/run_all_tests.sh | 12 +- 69 files changed, 12930 insertions(+), 36 deletions(-) create mode 100644 docs/experimental/index.rst create mode 100644 flax/experimental/__init__.py create mode 100644 
flax/experimental/nnx/.gitignore create mode 100644 flax/experimental/nnx/README.md create mode 100644 flax/experimental/nnx/__init__.py create mode 100644 flax/experimental/nnx/docs/blog.md create mode 100644 flax/experimental/nnx/docs/images/stateful-transforms.png create mode 100644 flax/experimental/nnx/docs/quick_start.ipynb create mode 100644 flax/experimental/nnx/docs/tiny_nnx.ipynb create mode 100644 flax/experimental/nnx/docs/why.ipynb create mode 100644 flax/experimental/nnx/examples/00_demo.ipynb create mode 100644 flax/experimental/nnx/examples/01_functional_api.py create mode 100644 flax/experimental/nnx/examples/02_lifted_transforms.py create mode 100644 flax/experimental/nnx/examples/03_train_state.py create mode 100644 flax/experimental/nnx/examples/05_vae.py create mode 100644 flax/experimental/nnx/examples/06_scan_over_layers.py create mode 100644 flax/experimental/nnx/examples/07_transformer.py create mode 100644 flax/experimental/nnx/examples/08_save_load_checkpoints.py create mode 100644 flax/experimental/nnx/examples/09_parameter_surgery.py create mode 100644 flax/experimental/nnx/examples/10_quantization.py create mode 100644 flax/experimental/nnx/examples/requirements.txt create mode 100644 flax/experimental/nnx/ideas/shape_inference.py create mode 100644 flax/experimental/nnx/nnx/__init__.py create mode 100644 flax/experimental/nnx/nnx/compatibility.py create mode 100644 flax/experimental/nnx/nnx/dataclasses.py create mode 100644 flax/experimental/nnx/nnx/errors.py create mode 100644 flax/experimental/nnx/nnx/filterlib.py create mode 100644 flax/experimental/nnx/nnx/flaglib.py create mode 100644 flax/experimental/nnx/nnx/helpers.py create mode 100644 flax/experimental/nnx/nnx/ids.py create mode 100644 flax/experimental/nnx/nnx/module.py create mode 100644 flax/experimental/nnx/nnx/nn/__init__.py create mode 100644 flax/experimental/nnx/nnx/nn/activations.py create mode 100644 flax/experimental/nnx/nnx/nn/dtypes.py create mode 100644 flax/experimental/nnx/nnx/nn/initializers.py create mode 100644 flax/experimental/nnx/nnx/nn/linear.py create mode 100644 flax/experimental/nnx/nnx/nn/normalization.py create mode 100644 flax/experimental/nnx/nnx/nn/stochastic.py create mode 100644 flax/experimental/nnx/nnx/pytreelib.py create mode 100644 flax/experimental/nnx/nnx/reprlib.py create mode 100644 flax/experimental/nnx/nnx/rnglib.py create mode 100644 flax/experimental/nnx/nnx/spmd.py create mode 100644 flax/experimental/nnx/nnx/state.py create mode 100644 flax/experimental/nnx/nnx/tracers.py create mode 100644 flax/experimental/nnx/nnx/transforms.py create mode 100644 flax/experimental/nnx/nnx/variables.py create mode 100644 flax/experimental/nnx/scripts/requirements.txt create mode 100644 flax/experimental/nnx/scripts/run-all-examples.bash create mode 100644 flax/experimental/nnx/tests/__init__.py create mode 100644 flax/experimental/nnx/tests/test_compatibility.py create mode 100644 flax/experimental/nnx/tests/test_containers.py create mode 100644 flax/experimental/nnx/tests/test_helpers.py create mode 100644 flax/experimental/nnx/tests/test_ids.py create mode 100644 flax/experimental/nnx/tests/test_integration.py create mode 100644 flax/experimental/nnx/tests/test_module.py create mode 100644 flax/experimental/nnx/tests/test_partitioning.py create mode 100644 flax/experimental/nnx/tests/test_pytree.py create mode 100644 flax/experimental/nnx/tests/test_rngs.py create mode 100644 flax/experimental/nnx/tests/test_spmd.py create mode 100644 
flax/experimental/nnx/tests/test_transforms.py create mode 100644 flax/experimental/nnx/tests/test_variable.py diff --git a/.gitignore b/.gitignore index 5648751009..ab2066a79c 100644 --- a/.gitignore +++ b/.gitignore @@ -15,4 +15,7 @@ build/ docs/**/tmp # used by direnv -.envrc \ No newline at end of file +.envrc + +# custom +/tmp-files diff --git a/docs/conf.py b/docs/conf.py index 14d6a8526a..0e908f1394 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -124,7 +124,7 @@ # -- Options for myst ---------------------------------------------- # uncomment line below to avoid running notebooks during development -# nb_execution_mode = 'off' +nb_execution_mode = 'off' # Notebook cell execution timeout; defaults to 30. nb_execution_timeout = 100 # List of patterns, relative to source directory, that match notebook @@ -133,6 +133,7 @@ nb_execution_excludepatterns = [ 'quick_start.ipynb', # <-- times out 'transfer_learning.ipynb', # <-- transformers requires flax<=0.7.0 + 'flax/experimental/nnx', # exclude nnx ] # raise exceptions on execution so CI can catch errors nb_execution_allow_errors = False diff --git a/docs/experimental/index.rst b/docs/experimental/index.rst new file mode 100644 index 0000000000..5ad44ee587 --- /dev/null +++ b/docs/experimental/index.rst @@ -0,0 +1,6 @@ + + +.. toctree:: + :maxdepth: 2 + + nnx \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 846de11efe..2749a64321 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -323,4 +323,5 @@ Notable examples in Flax include: developer_notes/index philosophy contributing + experimental api_reference/index diff --git a/flax/core/flax_functional_engine.ipynb b/flax/core/flax_functional_engine.ipynb index ce83035cff..8949e0dca5 100644 --- a/flax/core/flax_functional_engine.ipynb +++ b/flax/core/flax_functional_engine.ipynb @@ -12,7 +12,7 @@ "import functools\n", "import jax\n", "from jax import numpy as jnp, random, lax\n", - "import numpy as np\n" + "import numpy as np" ] }, { @@ -78,26 +78,34 @@ } ], "source": [ - "def dense(scope: Scope, inputs: Array, features: int, bias: bool = True,\n", - " kernel_init=nn.linear.default_kernel_init,\n", - " bias_init=nn.initializers.zeros_init()):\n", + "def dense(\n", + " scope: Scope,\n", + " inputs: Array,\n", + " features: int,\n", + " bias: bool = True,\n", + " kernel_init=nn.linear.default_kernel_init,\n", + " bias_init=nn.initializers.zeros_init(),\n", + "):\n", " kernel = scope.param('kernel', kernel_init, (inputs.shape[-1], features))\n", " y = jnp.dot(inputs, kernel)\n", " if bias:\n", " y += scope.param('bias', bias_init, (features,))\n", " return y\n", "\n", + "\n", "model_fn = functools.partial(dense, features=3)\n", "\n", "x = jnp.ones((1, 2))\n", "y, params = init(model_fn)(random.key(0), x)\n", "print(params)\n", "\n", + "\n", "def mlp(scope: Scope, inputs: Array, features: int):\n", " hidden = scope.child(dense, 'hidden')(inputs, features)\n", " hidden = nn.relu(hidden)\n", " return dense(scope.push('out'), hidden, 1)\n", "\n", + "\n", "init(mlp)(random.key(0), x, features=3)" ] }, @@ -138,16 +146,31 @@ " def attend(self, query):\n", " return jnp.dot(query, self.table.T)\n", "\n", + "\n", "# all the embedding module does is provide a convenient initializers\n", "\n", - "def embedding(scope: Scope, num_embeddings: int, features: int, init_fn=nn.linear.default_embed_init) -> Embedding:\n", + "\n", + "def embedding(\n", + " scope: Scope,\n", + " num_embeddings: int,\n", + " features: int,\n", + " init_fn=nn.linear.default_embed_init,\n", + ") -> 
Embedding:\n", " table = scope.param('table', init_fn, (num_embeddings, features))\n", " return Embedding(table)\n", "\n", + "\n", "embedding, _ = init(embedding)(random.key(0), num_embeddings=2, features=3)\n", "print(embedding.table)\n", "print(embedding.lookup(1))\n", - "print(embedding.attend(jnp.ones((1, 3,))))" + "print(\n", + " embedding.attend(\n", + " jnp.ones((\n", + " 1,\n", + " 3,\n", + " ))\n", + " )\n", + ")" ] }, { @@ -177,11 +200,16 @@ } ], "source": [ - "def lstm(scope, carry, inputs,\n", - " gate_fn=nn.activation.sigmoid, activation_fn=nn.activation.tanh,\n", - " kernel_init=nn.linear.default_kernel_init,\n", - " recurrent_kernel_init=nn.initializers.orthogonal(),\n", - " bias_init=nn.initializers.zeros_init()):\n", + "def lstm(\n", + " scope,\n", + " carry,\n", + " inputs,\n", + " gate_fn=nn.activation.sigmoid,\n", + " activation_fn=nn.activation.tanh,\n", + " kernel_init=nn.linear.default_kernel_init,\n", + " recurrent_kernel_init=nn.initializers.orthogonal(),\n", + " bias_init=nn.initializers.zeros_init(),\n", + "):\n", " r\"\"\"A long short-term memory (LSTM) cell.\n", "\n", " the mathematical definition of the cell is as follows\n", @@ -217,11 +245,15 @@ " hidden_features = h.shape[-1]\n", " # input and recurrent layers are summed so only one needs a bias.\n", " dense_h = lambda name: scope.child(dense, name)(\n", - " h, features=hidden_features, bias=True,\n", - " kernel_init=recurrent_kernel_init, bias_init=bias_init)\n", + " h,\n", + " features=hidden_features,\n", + " bias=True,\n", + " kernel_init=recurrent_kernel_init,\n", + " bias_init=bias_init,\n", + " )\n", " dense_i = lambda name: scope.child(dense, name)(\n", - " inputs, features=hidden_features, bias=False,\n", - " kernel_init=kernel_init)\n", + " inputs, features=hidden_features, bias=False, kernel_init=kernel_init\n", + " )\n", " i = gate_fn(dense_i(name='ii') + dense_h(name='hi'))\n", " f = gate_fn(dense_i(name='if') + dense_h(name='hf'))\n", " g = activation_fn(dense_i(name='ig') + dense_h(name='hg'))\n", @@ -230,10 +262,12 @@ " new_h = o * activation_fn(new_c)\n", " return (new_c, new_h), new_h\n", "\n", + "\n", "def lstm_init_carry(batch_dims, size, init_fn=jnp.zeros):\n", " shape = batch_dims + (size,)\n", " return init_fn(shape), init_fn(shape)\n", "\n", + "\n", "x = jnp.ones((1, 2))\n", "carry = lstm_init_carry((1,), 3)\n", "y, variables = init(lstm)(random.key(0), carry, x)\n", @@ -259,23 +293,33 @@ "source": [ "def simple_scan(scope: Scope, xs):\n", " init_carry = lstm_init_carry(xs.shape[:1], xs.shape[-1])\n", - "# cell = scope.child(lstm, 'cell')\n", - "# ys = []\n", - "# for i in range(xs.shape[1]):\n", - "# x = xs[:, i]\n", - "# init_carry, y = cell(init_carry, x)\n", - "# ys.append(y)\n", - "# return init_carry, ys\n", - " lstm_scan = lift.scan(lstm, in_axes=1, out_axes=1, variable_broadcast='params', split_rngs={'params': False})\n", + " # cell = scope.child(lstm, 'cell')\n", + " # ys = []\n", + " # for i in range(xs.shape[1]):\n", + " # x = xs[:, i]\n", + " # init_carry, y = cell(init_carry, x)\n", + " # ys.append(y)\n", + " # return init_carry, ys\n", + " lstm_scan = lift.scan(\n", + " lstm,\n", + " in_axes=1,\n", + " out_axes=1,\n", + " variable_broadcast='params',\n", + " split_rngs={'params': False},\n", + " )\n", " return lstm_scan(scope, init_carry, xs)\n", "\n", + "\n", "key1, key2 = random.split(random.key(0), 2)\n", "xs = random.uniform(key1, (1, 5, 2))\n", "\n", "\n", "y, init_variables = init(simple_scan)(key2, xs)\n", "\n", - "print('initialized parameter shapes:\\n', 
jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n" + "print(\n", + " 'initialized parameter shapes:\\n',\n", + " jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)),\n", + ")" ] }, { diff --git a/flax/experimental/__init__.py b/flax/experimental/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flax/experimental/nnx/.gitignore b/flax/experimental/nnx/.gitignore new file mode 100644 index 0000000000..2a90c3eca6 --- /dev/null +++ b/flax/experimental/nnx/.gitignore @@ -0,0 +1,133 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# project specific +.vscode +/tmp \ No newline at end of file diff --git a/flax/experimental/nnx/README.md b/flax/experimental/nnx/README.md new file mode 100644 index 0000000000..5cc033c3aa --- /dev/null +++ b/flax/experimental/nnx/README.md @@ -0,0 +1,440 @@ +[![codecov](https://codecov.io/gh/cgarciae/nnx/branch/main/graph/badge.svg?token=VqJjL474Z7)](https://codecov.io/gh/cgarciae/nnx) + +# NNX + +_**N**eural **N**etworks for JA**X**_ + +NNX is a Neural Networks library for JAX that provides a simple yet powerful module system that adheres to standard Python semantics. Its aim is to combine the robustness of Linen with a simplified, Pythonic API akin to that of [PyTorch](https://pytorch.org/). + +* **Pythonic**: Modules are just regular python classes, they contain their own state, are fully mutable, and allow sharing references between Modules. +* **Compatible**: Easily convert back and forth between Modules and pytrees using the Functional API to integrate with any JAX API. 
+* **Safe**: NNX incorporates mechanisms to try to prevent tracer leakage, avoid stale RNGs, and ensure proper state propagation in order to help produce correct JAX programs.
+* **Semantic**: Partition a Module's state into different semantic collections, allowing for fine-grained control when applying JAX transformations.
+
+#### Table of Contents
+* [Installation](#installation)
+* [Getting Started](#getting-started)
+* [FAQs](#faqs)
+* [Examples](#examples)
+* [User Guide](#user-guide)
+
+## Installation
+
+To get started with `nnx`, install the package via pip from GitHub:
+
+```
+pip install git+https://github.com/google/flax.git@nnx
+```
+
+## Getting Started
+
+The following example guides you through creating a basic `Linear` model with NNX and executing a forward pass. It also demonstrates how to handle mutable state by keeping track of the number of times the model has been called.
+
+```python
+from flax.experimental import nnx
+import jax
+import jax.numpy as jnp
+
+class Count(nnx.Variable): pass  # typed Variable collections
+
+class Linear(nnx.Module):
+  def __init__(self, din, dout, *, rngs: nnx.Rngs):  # explicit RNG management
+    key = rngs()
+    # put dynamic state in Variable types
+    self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
+    self.b = nnx.Param(jnp.zeros((dout,)))
+    self.count = Count(0)
+    # other types are treated as static
+    self.din = din
+    self.dout = dout
+
+  def __call__(self, x):
+    self.count += 1  # inplace stateful updates
+    return x @ self.w + self.b
+
+model = Linear(din=12, dout=2, rngs=nnx.Rngs(0))  # no special `init` method
+y = model(jnp.ones((8, 12)))  # call methods directly
+
+assert model.count == 1
+```
+
+In this example `nnx.Rngs(0)` creates a `random.key` for `params` with seed `0`, which is used by `rngs()` inside `__init__` to generate a random key to initialize the parameters.
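+
+As a side note on the `Rngs` API, here is a small sketch based on the usages that appear later in this README (the stream names and seeds are just examples): `Rngs` can hold multiple independent named RNG streams, and calling a stream produces a fresh key each time.
+
+```python
+# two named streams, each with its own integer seed
+rngs = nnx.Rngs(params=0, dropout=1)
+params_key = rngs.params()    # fresh key from the `params` stream
+dropout_key = rngs.dropout()  # fresh key from the `dropout` stream
+```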
+
+### Training with the Functional API
+
+The [Functional API](#functional-api) converts an NNX Module from python semantics into a pure pytree object with functional semantics. It is the recommended way to use NNX as it provides tight control over the state, allows you to use regular JAX transformations, and minimizes overhead. In this example the model will be trained using Stochastic Gradient Descent (SGD).
+
+```python
+params, counts, moduledef = model.split(nnx.Param, Count)
+
+@jax.jit
+def train_step(params, counts, x, y):
+  def loss_fn(params):
+    model = moduledef.merge(params, counts)
+    y_pred = model(x)
+    loss = jax.numpy.mean((y_pred - y) ** 2)
+    return loss, model.extract(Count)
+
+  # compute gradient
+  grads, counts = jax.grad(loss_fn, has_aux=True)(params)
+  # SGD update
+  params = jax.tree_map(lambda w, g: w - 0.1 * g, params, grads)
+
+  return params, counts
+
+# execute the training step
+params, counts = train_step(params, counts, x, y)
+model = moduledef.merge(params, counts)
+assert model.count == 2
+```
+
+### Training with Lifted Transforms
+
+[Lifted Transforms](#lifted-transforms) provide a convenient way to interact with NNX Modules. In this example, we use the `nnx.jit` and `nnx.grad` lifted transforms to define the training step. The model is trained using Stochastic Gradient Descent (SGD). Because lifted transforms automatically update the Module's state, `train_step` doesn't require a return statement.
+
+```python
+@nnx.jit
+def train_step(model, x, y):
+  def loss_fn(model):
+    y_pred = model(x)
+    return jax.numpy.mean((y_pred - y) ** 2)
+
+  # compute gradient
+  grads: nnx.State = nnx.grad(loss_fn, wrt=nnx.Param)(model)
+  # SGD update
+  params: nnx.State = model.extract(nnx.Param)
+  model.update(
+    jax.tree_map(lambda w, g: w - 0.1 * g, params, grads)
+  )
+
+# execute the training step
+train_step(model, x, y)
+assert model.count == 2
+```
+
+**Note**: Using `nnx.jit` introduces some overhead when compared to using `jax.jit` directly. Use `nnx.jit` for simple prototypes, but for production code use `jax.jit` directly.
+
+## Examples
+
+* [Using the Functional API](https://github.com/cgarciae/nnx/blob/main/examples/01_functional_api.py): Shows how to train a simple model using the functional API.
+* [Using Lifted Transforms](https://github.com/cgarciae/nnx/blob/main/examples/02_lifted_transforms.py): Shows how to train a simple model using lifted transforms.
+* [Using TrainState](https://github.com/cgarciae/nnx/blob/main/examples/03_train_state.py): Shows how to train a simple model using the functional API with the help of `TrainState`.
+* [Training a VAE](https://github.com/cgarciae/nnx/blob/main/examples/05_vae.py): Shows how to train a VAE on the binarized MNIST dataset; uses the functional API and `TrainState`, and shows how to capture intermediate values to retrieve `kl_loss`.
+* [Scan over layers](https://github.com/cgarciae/nnx/blob/main/examples/06_scan_over_layers.py): A contrived example that implements scan over layers with dropout and a shared BatchNorm layer to showcase how lifted transforms can be implemented. It uses the functional API along with `jax.vmap` and `jax.lax.scan`.
+* [Creating a Transformer](https://github.com/cgarciae/nnx/blob/main/examples/07_transformer.py): Shows how to create a Transformer with an auto-regressive decoder that uses scan over layers and a kv-cache for fast inference. Credits to @levskaya.
+
+## FAQs
+
+### Status
+NNX is still in early development so expect bugs and breaking changes.
+
+### How is it different from Flax?
+NNX takes the best features that allow Flax to scale to large projects and integrates them into a much simpler Module system with pythonic semantics.
+
+One place in which NNX strongly deviates from Flax is that (currently) it avoids shape inference in favor of static initialization. This is not a technical limitation but rather a design choice. It both simplifies the internal implementation and makes it easier for the user to reason about the code, at the cost of being more verbose at times. On the other hand, Pytorch users will feel right at home.
+
+### How is it different from Equinox?
+While they might look similar at a surface level, NNX's Module system is more powerful and flexible than Equinox's; it contains the following additional features:
+
+* Uses regular python classes (no mandatory dataclass behavior).
+* Modules are mutable.
+* Reference sharing between Modules is allowed.
+* Mutable state lives inside the Module (no need for a separate [State container](https://docs.kidger.site/equinox/examples/stateful/)).
+* Supports node metadata and semantic partitioning.
+
+One major difference between the two frameworks is that, by design, NNX Modules are not Pytrees. This adds a safety layer as it prevents state updates from being lost by accident due to referential transparency.
It also removes the need to thread a separate [State container](https://docs.kidger.site/equinox/examples/stateful/) throughout the code in order to propagate state. In NNX, state updates are either always preserved or explicitly discarded by the user.
+
+## User Guide
+
+### Modules
+
+NNX Modules are normal python classes that obey regular python semantics such as mutability and reference sharing, including reference cycles. They can contain 2 types of attributes: node attributes and static attributes. Node attributes include NNX `Variable`s (e.g. `nnx.Param`) and sub-Modules. All other types are treated as static attributes. For convenience, `jax.Array`s and `np.ndarray`s are cast to `nnx.Param`.
+
+```python
+class Foo(nnx.Module):
+  def __init__(self, rngs: nnx.Rngs):
+    # node attributes
+    self.variable = nnx.Param(jnp.array(1))
+    self.implicit_param = jnp.array(3)
+    self.submodule = nnx.Linear(2, 4, rngs=rngs)
+    # static attributes
+    self.int = 1
+    self.float = 2.0
+    self.str = "hello"
+    self.list = [1, 2, 3]
+
+model = Foo(rngs=nnx.Rngs(0))
+```
+As shown above, python container types such as `list`, `tuple`, and `dict` are treated as static attributes; if similar functionality is needed, NNX provides the `Sequence` and `Dict` Modules.
+
+### Functional API
+
+NNX Modules are not pytrees so they cannot be passed to JAX transformations. In order to interact with JAX, a Module must be partitioned into `State` and `ModuleDef` objects. The `State` object is a flat dictionary-like pytree structure that contains all the deduplicated node attributes, and the `ModuleDef` contains the static attributes and structural information needed to reconstruct the Module.
+
+```python
+state, moduledef = model.split()
+```
+```
+State({
+  'implicit_param': Param(value=Array(3)),
+  'submodule/bias': Param(value=Array(...)),
+  'submodule/kernel': Param(value=Array(...)),
+  'variable': Param(value=Array(1))
+})
+```
+
+`State` and `ModuleDef` are pytrees so they can be passed to JAX transformations. Moreover, `ModuleDef` provides 2 very important methods: `merge` and `apply`. The `merge` method can be used to create a new `Module` from a `State` object:
+
+```python
+model = moduledef.merge(state)
+```
+This can be used to e.g. recreate a module inside a JAX transformation. The `apply` method provides a functional interface to the module; it can be used to call any method or submodule and get the output and the updated state:
+
+```python
+# run __call__
+y, (state, moduledef) = moduledef.apply(state)(x)
+# run some_method
+y, (state, moduledef) = moduledef.apply(state).some_method(x)
+# run submodule
+y, (state, moduledef) = moduledef.apply(state).submodule(x)
+```
+
+`apply` can call any nested method or submodule as long as it can be accessed via the `.` or `[]` operators.
+
+### Partitioning State
+In NNX you can filter based on any node type; most commonly you will want to filter based on `nnx.Variable` subclasses such as `nnx.Param` or `nnx.BatchStat`.
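+
+Because filters are just types, user-defined `Variable` subclasses work as filters too. A minimal sketch (the `LossLog` collection and the `Tracked` module below are hypothetical, defined only for illustration):
+
+```python
+class LossLog(nnx.Variable):  # hypothetical user-defined collection
+  pass
+
+class Tracked(nnx.Module):
+  def __init__(self):
+    self.w = nnx.Param(jnp.ones((2, 2)))
+    self.last_loss = LossLog(0.0)
+
+# partition by the custom collection just like any builtin one
+loss_logs, params, moduledef = Tracked().split(LossLog, nnx.Param)
+```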
+
+Here are various examples of how you can use the `split` method to split a module into multiple substates:
+
+```python
+# split the module into the state with all the nodes and the moduledef
+state, moduledef = model.split()
+# verify that the state contains only params, else raise an error
+params, moduledef = model.split(nnx.Param)
+# split the state into params and batch_stats, verify no nodes are left
+params, batch_stats, moduledef = model.split(nnx.Param, nnx.BatchStat)
+# if there are any nodes left, use the `...` filter to capture them
+params, batch_stats, rest, moduledef = model.split(nnx.Param, nnx.BatchStat, ...)
+# using `...` as the only filter is equivalent to not passing any filters,
+# i.e. model.split(...) behaves the same as model.split()
+```
+`split` will make sure all nodes are matched by at least one filter, otherwise it will raise an error. You can use the `...` filter, which will match any (remaining) nodes. For a more general filter you can pass a predicate function that can use both the path and value of the node:
+
+```python
+(path: Tuple[str, ...], value: Any) -> bool
+```
+To reconstruct the module from a set of substates, you can use `merge` as usual but passing the substates as additional arguments:
+
+```python
+model = moduledef.merge(params, batch_stats, rest)
+```
+
+The same is true for `apply`:
+
+```python
+y, (state, moduledef) = moduledef.apply(params, batch_stats, rest)(x)
+```
+
+Note that `apply` will return a single `state` object; if you need to `split` the state you can use `State`'s own `split` method:
+
+```python
+params, batch_stats, rest = state.split(nnx.Param, nnx.BatchStat, ...)
+```
+
+Alternatively, if you are just interested in a subset of partitions, you can use the `State.extract` method, which will not raise an error if some nodes are not matched by any filter:
+
+```python
+# only get params
+params = state.extract(nnx.Param)
+# get params and batch_stats
+params, batch_stats = state.extract(nnx.Param, nnx.BatchStat)
+```
+
+### Filters
+
+Filters let you select subsets of nodes based on some criteria. These are used throughout the API in methods like `split`, `extract`, and `pop`. There are 4 types of filters:
+
+* `type`: matches all node instances of the given type.
+* `...`: matches all nodes.
+* `(path, any) -> bool`: a predicate function that takes a node path and value and returns a boolean.
+* `Tuple[Filter, ...]`: a tuple of filters; matches all nodes that match any of the filters.
+
+NNX also provides the following custom filters:
+
+* `nnx.Not(filter)`: matches all nodes that do not match the given filter.
+* `nnx.All(*filters)`: matches nodes that match all filters.
+
+Here is an example of how to use `Not`:
+```python
+non_params = module.extract(nnx.Not(nnx.Param))
+```
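+
+`All` can be used in the same way, for example to combine a type filter with a path predicate. A small sketch (the `submodule` path is hypothetical and chosen only for illustration):
+
+```python
+# params that live under the `submodule` attribute
+submodule_params = module.extract(
+  nnx.All(nnx.Param, lambda path, value: path[0] == 'submodule')
+)
+```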
+
+### Capturing Intermediate Values
+In NNX you can easily propagate intermediate values by simply assigning them to an attribute at runtime. For convenience, you should assign them to a `Variable` subclass such as `nnx.Intermediate` so you can easily retrieve them later.
+
+Here is an example of how to create a `Linear` module that captures its output into an `Intermediate` attribute:
+
+```python
+class Linear(nnx.Module):
+  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
+    key = rngs.params()
+    self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
+    self.b = nnx.Param(jnp.zeros((dout,)))
+
+  def __call__(self, x):
+    y = x @ self.w + self.b
+    self.y = nnx.Intermediate(y)
+    return y
+
+model = Linear(12, 2, rngs=nnx.Rngs(0))
+```
+Since `y` is only created when the module is called, it is not available upon initialization. However, once you call the module, `y` will be created. It is recommended that you use `pop` to retrieve temporary collections like `Intermediate`:
+
+```python
+y = model(jnp.ones((8, 12)))
+intermediates = model.pop(nnx.Intermediate)
+```
+`pop` will return a `State` object with the nodes that match the given filter and remove them from the module's attributes.
+
+```
+State({
+  'y': Intermediate(value=Array(...))
+})
+```
+
+If you use the functional API to call the module instead, the `Intermediate` nodes will be present in the output `state`. To retrieve the `Intermediate` nodes and optionally separate them from the output `state` you can use `State.split`:
+
+```python
+state, moduledef = model.split()
+y, (state, moduledef) = moduledef.apply(state)(jnp.ones((8, 12)))
+# "pop" the intermediates from the state
+intermediates, state = state.split(nnx.Intermediate, ...)
+```
+
+Alternatively, you can use `State.extract` to retrieve the `Intermediate` nodes without removing them from the `state`.
+
+
+### Lifted Transforms
+
+NNX lifted transforms are analogous to JAX transforms, but they know how to work with Modules. They usually perform the following tasks:
+
+* Handle the Module's substates and the `Rngs`'s RNG streams according to the transform's semantics.
+* Properly propagate state in and out of the transform, including updating the input Module's state with updates that happen inside the transform.
+
+Here's a diagram illustrating how lifted transformations work:
+
+![lifted-transforms](https://raw.githubusercontent.com/cgarciae/nnx/main/docs/images/stateful-transforms.png)
+
+Currently NNX provides the `jit`, `grad`, `scan`, and `remat` lifted transforms.
+
+#### Manual Lifting
+
+In case you want to use JAX transforms directly you can always use the functional API to manually lift your Modules.
+
+Here we will create an example of how to implement an MLP that uses "scan over layers" to efficiently process a sequence of inputs, assuming that each layer has the same parameters and input/output dimensions. The first thing we need to do is create a `Block` module that represents a single layer; this block will just contain a `Linear` layer, a `Dropout` layer, and a `GELU` activation function:
+
+```python
+class Block(nnx.Module):
+  def __init__(self, dim: int, *, rngs: nnx.Rngs):
+    self.linear = nnx.Linear(dim, dim, rngs=rngs)
+    self.dropout = nnx.Dropout(0.5)
+
+  def __call__(self, x: jax.Array, *, train: bool, rngs: nnx.Rngs) -> jax.Array:
+    x = self.linear(x)
+    x = self.dropout(x, deterministic=not train, rngs=rngs)
+    x = jax.nn.gelu(x)
+    return x
+```
+
+Now we will define `ScanMLP`. During `__init__`, instead of creating a list of `Block`s, we will use `jax.vmap` to create a single `Block` whose parameters have an additional `layer` axis. This will allow us to pass the parameters as inputs to scan so it will apply a layer at each step.
+
+```python
+class ScanMLP(nnx.Module):
+  def __init__(self, dim: int, *, n_layers: int, rngs: nnx.Rngs):
+    params_key = jax.random.split(rngs.params(), n_layers)
+    self.n_layers = n_layers
+    state, moduledef = jax.vmap(
+      lambda key: Block(dim, rngs=nnx.Rngs(params=key)).split()
+    )(params_key)
+    self.layers = moduledef.merge(state)
+```
+Note that we split the `params` key into `n_layers` keys so each layer has different parameters.
+
+Now we will define `__call__`. Here we need to split the `dropout` key into `n_layers` keys so each layer has a different dropout mask, and `split` the layers to get their `params`. Both `params` and `dropout_key` will be passed as inputs; `x` will be the carry value. Inside the `scan_fn` we will merge the `params` back into a `Block` module and apply it to the input `x`, passing the sliced `dropout_key` as part of the `Rngs`.
+
+```python
+  def __call__(self, x: jax.Array, *, train: bool, rngs: nnx.Rngs) -> jax.Array:
+    dropout_key = jax.random.split(rngs.dropout(), self.n_layers)
+    params, moduledef = self.layers.split(nnx.Param)
+
+    def scan_fn(x, inputs):
+      params, dropout_key = inputs
+      module = moduledef.merge(params)
+      x = module(x, train=train, rngs=nnx.Rngs(dropout=dropout_key))
+      return x, module.extract(nnx.Param)
+
+    x, params = jax.lax.scan(scan_fn, x, (params, dropout_key))
+    self.layers.update(params)
+    return x
+```
+Finally we apply `jax.lax.scan`, update the `layers` state with the new `params`, and return the final `x` value.
+
+Here is a simple way to test our `ScanMLP`:
+
+```python
+model = ScanMLP(10, n_layers=5, rngs=nnx.Rngs(0))
+
+x = jnp.ones((3, 10))
+y = model(x, train=True, rngs=nnx.Rngs(dropout=1))
+```
+
+For a more robust implementation with comments take a look at the [Scan over layers](https://github.com/cgarciae/nnx/blob/main/examples/06_scan_over_layers.py) example.
+
+### Case Studies
+#### Shared State
+
+In NNX, you can create modules that share state between them. This is useful when designing complex neural network architectures, as it allows you to reuse certain layers and reduce the number of learnable parameters.
+
+Here's an example of creating a module with shared state:
+
+```python
+class Block(nnx.Module):
+  def __init__(self, linear: nnx.Linear, *, rngs: nnx.Rngs):
+    self.linear = linear
+    self.bn = nnx.BatchNorm(2, rngs=rngs)
+
+  def __call__(self, x, *, rngs: nnx.Rngs):
+    x = self.linear(x)
+    x = self.bn(x, rngs=rngs)
+    x = nnx.relu(x)
+    return x
+
+class Model(nnx.Module):
+  def __init__(self, *, rngs: nnx.Rngs):
+    shared = nnx.Linear(2, 2, rngs=rngs)
+    self.block1 = Block(shared, rngs=rngs)
+    self.block2 = Block(shared, rngs=rngs)
+
+  def __call__(self, x, *, rngs: nnx.Rngs):
+    x = self.block1(x, rngs=rngs)
+    x = self.block2(x, rngs=rngs)
+    return x
+```
+
+In this example, the `Model` module contains two instances of the `Block` module. Each instance shares the same `nnx.Linear` module. To run the model, you can use the `nnx.flags` context manager to set the `use_running_average` flag for all `BatchNorm` modules.
+
+Here's an example of computing the loss for a `Model` instance:
+
+```python
+def loss_fn(model: Model, x: jax.Array, y: jax.Array, rngs: nnx.Rngs):
+  with nnx.flags(use_running_average=True):
+    y_pred = model(x, rngs=rngs)
+  return jnp.mean((y - y_pred) ** 2)
+```
+
+It's important to note that the state for the shared `nnx.Linear` module will be kept in sync at all times on both `Block` instances, including during gradient updates.
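+
+To see why, recall from the [Functional API](#functional-api) section that `State` deduplicates shared nodes. A small sketch (assuming, as the deduplicated `State` makes possible, that `merge` restores shared references):
+
+```python
+model = Model(rngs=nnx.Rngs(0))
+state, moduledef = model.split()
+# the shared Linear's params are stored only once in the flat State, so any
+# update to them is seen by both block1 and block2 after merging
+model = moduledef.merge(state)
+assert model.block1.linear is model.block2.linear
+```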
diff --git a/flax/experimental/nnx/__init__.py b/flax/experimental/nnx/__init__.py new file mode 100644 index 0000000000..adfad6b761 --- /dev/null +++ b/flax/experimental/nnx/__init__.py @@ -0,0 +1,98 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from flax.linen.pooling import avg_pool as avg_pool +from flax.linen.pooling import max_pool as max_pool +from flax.linen.pooling import min_pool as min_pool +from flax.linen.pooling import pool as pool + +from .nnx import compatibility as compatibility +from .nnx.dataclasses import dataclass as dataclass +from .nnx.dataclasses import field as field +from .nnx.dataclasses import param_field as param_field +from .nnx.dataclasses import treenode_field as treenode_field +from .nnx.dataclasses import variable_field as variable_field +from .nnx.errors import TraceContextError as TraceContextError +from .nnx.flaglib import flags as flags +from .nnx.helpers import Dict as Dict +from .nnx.helpers import Sequence as Sequence +from .nnx.helpers import TrainState as TrainState +from .nnx.module import M as M +from .nnx.module import Module as Module +from .nnx.module import ModuleDef as ModuleDef +from .nnx.module import merge as merge +from .nnx.nn import initializers as initializers +from .nnx.nn.activations import celu as celu +from .nnx.nn.activations import elu as elu +from .nnx.nn.activations import gelu as gelu +from .nnx.nn.activations import glu as glu +from .nnx.nn.activations import hard_sigmoid as hard_sigmoid +from .nnx.nn.activations import hard_silu as hard_silu +from .nnx.nn.activations import hard_swish as hard_swish +from .nnx.nn.activations import hard_tanh as hard_tanh +from .nnx.nn.activations import leaky_relu as leaky_relu +from .nnx.nn.activations import log_sigmoid as log_sigmoid +from .nnx.nn.activations import log_softmax as log_softmax +from .nnx.nn.activations import logsumexp as logsumexp +from .nnx.nn.activations import normalize as normalize +from .nnx.nn.activations import one_hot as one_hot +from .nnx.nn.activations import relu as relu +from .nnx.nn.activations import relu6 as relu6 +from .nnx.nn.activations import selu as selu +from .nnx.nn.activations import sigmoid as sigmoid +from .nnx.nn.activations import silu as silu +from .nnx.nn.activations import soft_sign as soft_sign +from .nnx.nn.activations import softmax as softmax +from .nnx.nn.activations import softplus as softplus +from .nnx.nn.activations import standardize as standardize +from .nnx.nn.activations import swish as swish +from .nnx.nn.activations import tanh as tanh +from .nnx.nn.linear import Conv as Conv +from .nnx.nn.linear import Embed as Embed +from .nnx.nn.linear import Linear as Linear +from .nnx.nn.normalization import BatchNorm as BatchNorm +from .nnx.nn.normalization import LayerNorm as LayerNorm +from .nnx.nn.stochastic import Dropout as Dropout +from .nnx.filterlib import All as All +from .nnx.filterlib import Not as Not +from .nnx.pytreelib import Pytree as Pytree +from .nnx.pytreelib import TreeNode as 
TreeNode +from .nnx.rnglib import Rngs as Rngs +from .nnx.rnglib import RngStream as RngStream +from .nnx.spmd import PARTITION_NAME as PARTITION_NAME +from .nnx.spmd import get_partition_spec as get_partition_spec +from .nnx.spmd import with_partitioning as with_partitioning +from .nnx.spmd import with_sharding_constraint as with_sharding_constraint +from .nnx.state import State as State +from .nnx.transforms import JIT as JIT +from .nnx.transforms import Remat as Remat +from .nnx.transforms import Scan as Scan +from .nnx.transforms import Vmap as Vmap +from .nnx.transforms import grad as grad +from .nnx.transforms import jit as jit +from .nnx.transforms import remat as remat +from .nnx.transforms import scan as scan +from .nnx.transforms import value_and_grad as value_and_grad +from .nnx.transforms import vmap as vmap +from .nnx.variables import EMPTY as EMPTY +from .nnx.variables import A as A +from .nnx.variables import BatchStat as BatchStat +from .nnx.variables import Cache as Cache +from .nnx.variables import Empty as Empty +from .nnx.variables import Intermediate as Intermediate +from .nnx.variables import Param as Param +from .nnx.variables import Rng as Rng +from .nnx.variables import Variable as Variable +from .nnx.variables import VariableMetadata as VariableMetadata +from .nnx.variables import with_metadata as with_metadata diff --git a/flax/experimental/nnx/docs/blog.md b/flax/experimental/nnx/docs/blog.md new file mode 100644 index 0000000000..5c3b2437e1 --- /dev/null +++ b/flax/experimental/nnx/docs/blog.md @@ -0,0 +1,7 @@ +### Do we need another JAX NN library? + +Hello, today I want to talk to you about a new JAX library that I have been working on, but before I do that, I wanted to discuss the topic: Do we need another JAX NN library? + +### JAX Libraries + +JAX NN libraries come in a wide variety ranging from functional like Flax and Haiku, to Pytree-based like Equinox. 
\ No newline at end of file diff --git a/flax/experimental/nnx/docs/images/stateful-transforms.png b/flax/experimental/nnx/docs/images/stateful-transforms.png new file mode 100644 index 0000000000000000000000000000000000000000..d7002fc163631c1a6eb16bee4ef2eeaa9b28fa60 GIT binary patch literal 304812 [base85-encoded PNG data omitted: the "stateful-transforms" diagram referenced by the README, illustrating how lifted transforms propagate Module state]
z?ZL(0gNLsn8%b?td?(x*0PuMJ&)pFBX0h&z-qHC`5)nE*AityqI_y^|?fkW6eehGx z*?=-Vf|BRUS33OiuF-nk{T7O&CrI}ZG%QaRQ?3QmefyEzNg9nPL~!frDF#e5zyYqr~l(gH_77; zWJy0CdRm3BQebWg=xRqed-V5{rMFQzDruq*x$X(w)f@au&u%mTyJQm~?WjP9SIjl= ztCrlKD0Kt4l>mi}zaJ%nw<}ZsE^Bn@$C}#5sv*|J=Vi!Vn0QaS1r0e$TQop$QsM1Y5f7H03oQ~R z$s(O$#37PJ_~h&fAB(`Pkf{5e^akUb03&KS0IhfA5(QP-P`_%R=@`gW7^?c?*>bu1Y%+?o90A$i&j!pJxD>yTuDBGq@N6 zeo{D$>X3cmwx9fdx!kEC%R;bt4hH4UA?~kKaO>gm)|&-<8kE ziiQ+~aMLVjtG1&S>qrhJINyKegX?wO!Gny7h%w{G5<@DRHg|6JdQ>brF2<=bhV zKLzO2zW=2`luvqf(kZbHVw=}?%#@TcD3F}=hPrfzDjizgzt52&<7_tMWLfHBQaW@= z+8F4&iE#LXOsn0u=#l<*pMpx>-(~7P97MmCD(Y=-*kaVVL$Uqoa|fXJ_8!PRzFBJ1 zmw}}P?B(_PmFwsl$2n}IdkmI>*2W4Vt7HX=m({pl7;yxZQ&Di5rNiZevGf$dcy;@u zEj_m$Cm4wFjeP%&GgtR||%hC^EdGS8~g3NIP@r9gih6XR9N&syl z&W6FFkrnwi9=Kzwp_=ndCyKGvKAEPxwyZCtb`@Iok`W(`3GS$n#|ivIPq9xwt8 zVHc3*Owp|y3ov|xaxTYVrb%K)iy1@^a$r1Qk&8Ll}R@uZu}gP2)Y1Gb}MXkLZKhS}CZg6sNqlFS`8v zp4cSM&wxiR$>I3~p$jpghhpd+uf%){7g6x(uS|y^%X+bk z6r9(uw8!o>j=s>#XHa)P>CcpBFj0WmR4@D69=C>64K)lEK*+goQijsQ6L8Oy1V9n$ zD0zzzmka`xzl^@E7{H4QAw&Ox=lSa&iCaK!ToxAEy!9T}!GQ$-rP){b{y@yjRwPRA z&aj|se2!~6oH`TkgtQVbUqjO2$9k|Carb9ycI+dHS(nuzt$yf+6>baKNpnwMmgHy) zW~iVCU{iyNT@vvfssW5-3C^J1=b_P*kS}`F5p>d`!F( z(%@-C4tWg0B_BaIY}r1O+@}a=kmUV7#KRx}bWu^d@7l@Hw#t`8u0=7N`?W$qR8GKj z6Lz-w_rOcGT^gcI4~6ITUJkpG29Dz^Roy@L4dN`w6H_=?qQZPh=ng>jJGUQK@0ehg zgSAHf>yvArp&5d{>^%Sqwj&=T?4~8k*a55pZm<2wZN8nz1MFG>9#Kr-I=gs<*Cv55 z+mZ|DDk88&D0Y1K>(L=0DKJzXR2*@b!7XhLgnt@TfJPB?`VroOix!3km63~q+9L39 zt+Z$9m5Qk2It!d+dz1Xq0#z3<7t47?^%;P`_d|i<^`aaGI2NTXQkc0wXU?HvePx}#!GuVPTDrvYx= zcdq9$)Nq^EiZCYvjlfu;>bg;pO=rXaDf357AeI%3dBAwo@shE=7$Pi;Pg0RuPja~I z!Fc@>eJfzZ1M?lP5R&-3iq)PE0EUCG!lmb^AAWz}8&d$%qE}boffqcjzan|9>X3Bz zQ7Z;jpMSWK<9u|VA{y_Yjh#Bs=+s>$z(P`z()|)fBl5?0;LrOBt=~84zN`P7_?8r* z@%LV2VilSDR027_7ZP2kZ*tu<%(nc|03CX-fqfn6)_s2dLO6w{x%mg<@M6x78&Cv= z^;!d=i=GYRK$`XZAfkR#kg;Y$K}Wdpq`N4&@g_(#UW>x;g%HNkM(|N;d~{uZNz>*x zk6C}bYrwTH`kP|!E-m^=k5VhW{ba4=*-c=Jvcc!u?wr#8I<4)9n(PL+$3O*x7n0B< ziws8D-E=1NfLR-UlkRgD4hRe1|o+_oE1@_efk}_)$%YZ(_tda zW>*5YZZ;Y^#)E!EbHbmJ98=FWl{RBO-NbcQfpgoS@a#JmU}U5!ll0G6OFA>$7Iboa z@a|)XHO>CiWDY5!-eX5@&=d&*FY@Ub=jcSqXDrsRiBPgztZ^}+=YGLh<+G2ZKzm1S zEX8&mx4b=!%jJ00_%s9Xp+u8Hm~ujbL-6EN<^aHdTOa00%+rJ^?D8I6RVy=wby>?RRE_3joVG9)zrE;%Eox*gDR%<3%ISco>m#e*hWC z{KF$Ca0M9+GnYz-eiHu0rA3a&gOP^cUboSD21e}+z*wx8sRYZ5pal$qr(EzI2U^^W zQR$nSIpVeVF0|3J<}XfpP&?exM{J=WA6z0uM;Su>bV8x_{iCH9)AdI%g=8AhTh!98 z+J1x<=qq`_^S_(V{}pC^b`L4*tavk;Oy()|(fulihqL*H0!}|-f4=0R1ktcY zP726SR15D{2iUBK17|HT_Tk}k!*a95znp)HuJ?D(;GEGZGBV6bgxAvlm7{^pjA~sr z2F|i?8nSI!l95_5L4eVi@}piF;yTxnKKS`npJQM2|5c1e3C zoSRK`ufR%|zsLFwjI)ye*!%YGCB!DZ>S|Uxuj{h^*JP^vdsxESf2s`BJ_(}uDRL!? 
z0@&c;SlJOa&9=cC1n+ASaVkNFC&$njG-&`R$Utl2*rlV5p0AM##Lnm;>V911bL{a@ z1-=Q4VTj*kylj63exCial=c}5iO~+75JilDCitpYGg%JK2{)TQ&VLa_V4xevg!y&6 zBchAuCfv1oo*+qPNbQID*h?_bi{vVEm?NOc3^r+ls&1%ZWTXeMlkyYVyNS=q5*MM+ zQdSY(3ZOVjW>h45w9afBT;$Yu#*@Qe#rtHL2@U=+i!|?{#ne0u~ zU52^RLjk|U?1DpG5Cr03Sv^UeYXUixt{)$p4RzeBW!EW8{=Me4lAEn0p?7lhJMNKJ z*~%%9{_g;%&`-g0zW{SC+&f?O(V`U;W}u(^k#(+?>|x_6)MT%;w~n}QWK`a8%Ug-T zak?{L{aCGlu>FGn@Lp9MOpj81ubCV0{l{8ZvfsV3;K^-7>k3Wi<@F&hp%Yss`u4$v z%0x@vkANkqfp=bHZY6WZ5}XGMgtk}lKKFC1zEQA3gmiqyY>}?2$1hh5orNZNe`gKHxXjwq zOP-Eb@aIKart3m43|)e$xkf40V>J|nbZ(oG+hJV%W_M}9*$L=KCh#QV7sYZGA7Mm+ zEq+DQYLo)v??8q}o!OM>G$qWN%7P9IvvB^G4HARpF^a4dUbD#qrQtg2e}4y2U4k{` zJ|Bc8krIF}7~egCNOu4nXE2w~Eu-25|G4w^z?PYtPi3A2`wRi6;9R~FCU^dOp|c5<>u zWUQKU?PRjEm}aS_?MsZ^hJ(MF0`X&S8StuYo4JB!DIHW#s(u()FH`V;ob@?z$|9+R z-0x|UVUjFI_HlbVh)WF_6osfuFR6XN_HEEj$aufRg!TYWvn7mNl>Cg}W+58pWBu(t zBlqNAXn26lD6V7?zcZo7d1J|~ATYzd4fZ7|>0B&G5_?!sw zMVwiZt=@xOMY3=+5}h2=8I}Ps9uQ~XY&M+1{yM4PR0ZFb6+{?VV~Fc^piqDRrI!!QMX zi2-$@*UtJTBvA;N*C7;OBpGi^&j23C8M$4$QKekRDMZ z5Q#Pl5=511I`(E;f)17nR@i7{{CH^EM$bX!8FIN>;jMaEW?pX|j415?xI7H^?^ho(IAe?X zyhbXz5409ASHOAU3+dFu++i5bp}kL$Qf3Mz0$3cf0c7RMO^OtCTOEo3kZnGvkUj4I zv;!89JDM>J{#!qVMEAf6p^3Eft|xQC$TUqrYi`yoUVHnGO+hd_1O=I4FY7yF@JlP6 zSDW-b^*)$?rBZe!EdE}5uy_6Gi6^4Li>uCt(JMgI$AZc`_d!=e!W9#`zPY)(pW2cZG)2ffy|h5MzQE zn%T_SA``$bnKaM{lsQfOsx6e*dfoll8HxxUFv0|I%3uET?Y|?k;>0@7rbKvy&?7`O zj13eBYl|A>sA15ONrC2|*lk4RV)v{djDdivC@zf>CD3_xFp={S@d`DBOB;ChT)%Xl zmy(rylFfot*n2eDQ8rTwM9kmqSdYkXkxd%sBZXmC)NK$|WyY#1y5>V@>=Fzx`lTZV zvu`iBcrpb?VZMf5pig`n%udY|GTwnQL!22-#zO8`|6|Ac-~VQB6ddM3m%OGsk~-di z&udc`k^W^gZZxG}kE)tmY{5)=g-6PYf|^i7CWlyU&9=xJx*R=)yn_b4vQXi*{7^)} z3xi&9-Q18V5-@upVEr7$3Jc7SaUs2_ePO<0%E<^b=5U?O|K1Wi`e6F`CCg9G&K8EO zx{JWvrFj1BWI1Sa1N6;C#ev*P6U41E!y*+{7!Bv#&1`<`piNF7eE9}j^J2P792Q00 ztflMPh#JfGyD!zfr~**BYJ`#93NXB(`+*j?3Y`&A6asGV79eAHV46^XouJPt>{#oA z*_ty|J}fuc?^i|ZvfWg7Hzl2XAd>|s$K6YyOv_{GCX?h42oe+_nU}?|{%5(?U%~)C zIiv^9>@-V=n1jkZ6J}~Fj;vmw;OqrdBA^r#*Nyu6XHW;I!RVI&Pt)S)FZl#;C3QQS zNAsldEVgGqDaD0LAAKl6Rzc-RhY%y`GGs~^DU0eSHN(m{=gS_)kf~!uy*(+KuLk=8R^}U;b6iVG%9DjV9GsA;4LmpaL zbEpmS@&+dE7zP@Fh&1W$;6FR6{_hRF2CbUc^-W4m3}}zjp=wfp-vd%Z4FFo+jrD$W z2bFHKgfaaBG8hOhT5!TaIT7i-S^)atHePR74^6FdT8xa9vtAo1My5da{S+yu0**7* zo?6KO8Zi~GK~SBYTd<)^=8?{XPZl8P_LI$oQZq?>mbIIuv8fHn=J0bH?s9pNTuL)2pEW!7-0{3$_YhbpQ+P{^@}qo+(e*A6#PGLS!ztIrYpVfXc~ z{)Na83m{vqplU@j+qGZ&Iy9x`B~D_;Xp%T~L%~UvO8cLab&8yO`o_NSDPRP%B24zJ z?UB(bIy!#j_(JKmi4F~(wkT*aVH~4D@0-s2yyQpv|Mk-I#vdpzW3-Ig5`RY zD-=!1IeU&QdE1-$z5*CcZnZGFo~+WMUlQ%%D2iLR^{pHMXA*o|N_?#QO~nJcvrOri zz63I|MR284L7*WryOTBEM2}g|khV6Q(#CGd98l}2Z_wJt*3kZo3&Kn z%%$KKV16IHqNIY^UiT(OY@nxRMvQ-#!T|DyDc8m4xFd){;;1Y9Oz2|id z4DpSgCzq$HnyJXxMM0QLe2WkvLH_h&9#O1~x@!hI!ipi0u-3g5pFpL%% z@&fMXcR5cvg#dE=s{y;dJ>GHlssRjgN!Hw{jm%#n0o7f$+#vx{?$XnMd?j+d|6N#J zDH#t~@syxbDd8tMkOHVLmh9>EAzErbw+M(6yfht>7% zS|I&cp=0*$|G5k~L;_R_p`d7`=N9ECQNkCPZ42hwxXhyZ3yR#!6yzm6q4PmwE+jBlX*Ar}gchP}ZseFq0{wM$^f@ z9xZ(i^nTYTwN3y=Z7Ogt4%-ByEn6&OIF@yCoUF?X%`6<(Bb2$oD{>2Te{nme=V9w- z`YNdFrtZEgf(UhevV7b&@lUZ&e?j8hL2c3rEH#;~^63wSVN!RV3&XFNf{rCi^F%KJ z&2y1>SEhVl100MN&2#CtF~8tj_d+haW%7RT2lQqA6!e;6R6WT8QeYY-q+A5x7kuE7 zk`=d!5!tJV5PXhg8Y^hKcDwZl-C>Q|5)6TpL z-T8Y$Ini@0G*x;C>Fn~|qb_c~jf8~7gWZ#A2e-T^x%91eB~dL51Ed}fR~HYF1)|ISKC0+*q*@X$&vWy?4rRbLP^&n z;FOwYrTbpeNnxAxA<&C-lAJj+<&`86PuJNq(na`Umr}iV^-y-s9in}ZJdqKRy;rafA z!X_$H$Z%3s_*xCsP(~+VK;h5P;{F*zPn8qI_AK!QE;=9eEo@9-rcGqTwirqE<`j`$ zY}CJ*umIsc{SpEtCx+#q&Hmc$z-X$IQZ5^!!A==N#JpdlM(Xgxx@S)iFrC>4CY4}L z7c7G*6_w@{JBwgV0t@_y0t+p_K#q70;|f92epA?vQTUP#2hpe^!o^p`!Prec&?Pie 
zDAuwe6Zhh78E9AG6LgxKeSny8#y6WjsJ)d{KU)GA1T8_brXJG#>;^*w#=j&mIGW9W za|(JJhP8$sPyt^HS@H%`uFp>pc6A7!>ZNx|K1%rl+O=Hp4XR_{O}QgA6DWF%Hr}F< zYm5u_3&iu~huL(K6g-RM@0Ipx*6|8|&6%;ac$A*k8|J;Rb`{9Xh`8~{$l~a7 zP}fJUVCSg=$^3_(dm{HcKPHTIvjVVZ|4BICg*YS%RozhitrVQ53EuxGpw$-cJmCI! zJoigzi(X>lItgHamJ69GY2x=XG{clO@s%SDzr!>B+tN%)8t^_su<$tjCQiI)1fE=X zq%9e5asui@A36Tv0)xANo>vwu7BwOE1)yn9?%^Eo!xj!i{i$$ZSYbl}6U>$64del= zN3TH|adapjUI5Y3tD5)^;*nghb31()xy(mblMi<#k8(fuhPh?d!@A*sKA;jp5HQ7t z!!$>o7A7%vDoFXheMVfe3`8M1HVFiIoYYJ2keHE36g|fp4XAroK(gFH4M#i*VJmp z+`KYxf)hk$=_J+1l`vaVBb;?$Dd0Rw4sb@C%FW|If4zXgdi#o!`^q^1P^l`!1FbIOY@*{I-HNP*gwENOf+WE8+fiy7jJm9Yt(87K zINQwtTrL@(9*`pjboojMQP}|d+_1W7#gO3_j7Z8}7O>@t$%}F5mj&~OB`eK5I0UIr z2REwubxSZcY|h(voY{qmt5SC(sZr3D;#NL({(IjOP{cz3_O3 zR)_3OE(>tr=~pIgktWL*nk?Pho3K!ML})d;6Mp zj@RLB9EEqL2vu&DIb=Q1G7*2se9{wLD(L9S90TMVvWX-y*ou=XCD3S%&0emNN+X!; zq~Ie6GE{i4JI)P2fu$5P3~0ba0^|FytI{XOuJG;|1=$ZI(0f_{?I05&GbTLD_kiOi zwqSKLfE2FC_VifwJd{qsN%(O+0rRC{47Af15$f&l!PkSqT}4|DGwX^7CjX43}4d;^DPj5|$Q(V#KPiSo!GNC-`eFNTra<-ZAJ6 zQ1W9#P8tw<|A0dIm_k-c2t%>dQ0&x$9m6^A!(D|_|Hj-Y^r0Q?+=v@*6{vc}74g&% zQ9KE+H1C zztx`^E3^oLL$h)y;tBYf$$P_(6?LUPkHaZ>LP3|ix@faWa1Dmb!RSKybuP1gFM`njv!>B$>CEsQ6cTj1DRA zkw$I`leBl<&sj+%bkA}m!yF`({-RVi=PKI| z+Qj}B?@wF^YB4K5#=d1jixf4A_T{(jCg&$}k6|a~FPi>Y7w81Df*%0AscGNW6&H+i%*;Yo>CIVHM*>G zz=*NgS_p^L6)x7aNfa%^Q~p@^IUu~MkCO0?>tLU8DwB($M+5!i(IgBLS6E9vZt(;0 zI4!lZoYyHblx7M^pMgXKW~Y>-StCX8JyyMvOyn~;v?4n@6Gm~5HMfp#Tj(QG{NI(R z>I-yWAoe)@6k42J|3JkVbs688<6NzbF{f7Pl8FFb>n881;iU)5= z^-w9gfwC6^m!AQKcz5$JYloc<@mcn!yH|qt`+>P}gxE8_f71ZQNPQwyN`Ll@3^P2X z*fhfaHS=lsg?q(8ADHLoxLtjscuJaCXx`FEZ+hJiqE#NTU~u?Ea(W>q_JfL73W&Iy zDJ8;E_sxfwK)_)NgWTvA8X7k6e4}bvLBdPE`xK}Qqw2Vv2~Z=pP&RQs6eD>;!)v3P zQBil|0P$i3rfsW+Yy|aoaJfZXxGE`1X@c3~r7*S$XA|h9#D99^17M?klJ>!9`y|dgKq46!@Wwn`zpaHt5ms4A6xhd=wP$z@!@er-Y?9kQ9d`$BH5vD=_lB8 z*7@KeQSiDZ;!&oMX>O3wU?uf*>Jlajz$?jokX@#mKX>hJ=VqC2h>;ht6g^lCK5eBv zm=N~d&2KQLq4~qz-bm!EiE$PS6NZFb9S0r#_7v-HY7o>*@U z&}aXf!tNJgH1*%J0BAQ)M?aF&eq1>qnD!N)r=!ErcNmh!^3$Xg5-s6P@{h#3?4{oRU{CB7fkH!pAlQIP#HY-CQf zPN7<7i{c~OTUMr758{z~Ps z8}4`6y;DVl*3;VV)RZ?X&yQcr`_@MECS+7&_Q$3j2eK_yXk~TR)DxDPSM9(aw{`t^ zUs(G(Xpt~}l8_=CNwHs39m}}1li1_M6z@daJi*(%T}1s;icL>DxGE8(VnS;JZFSsL6&s`?nG#a>j?N?hGLt>xZ{s-#QgK zw8u80yTg3Uw6(`5`6Cfq=ow!w?>Dw7(0FhSZ*#h8!(BjcgH^iQ7P}$G^C;47?(6y4 zZTC)bvSEkPZzXXR8@{K9648hBON_8iQU2RTdVItAr&|3!xafPZq||Nb=~emOn<~nO z|N8>M%LWhPmyFi2$?^S~a0ewXEi-aFmG3#+hJBu#G>*+g%diJl8KzV?cua1m#;0zD zTV!U7tXo-usZIfIy>*>RIn|c>I@9Ze23V8GqI{C)UFul_mFoa(vU~+QPK{maPCtMC zybNn%C5NL(aB{Gm`ZZF2#YRSHeH>*^fnA2VyruerFqC)z#437~?OG@DU_?~ zdB951SC31#xP>Rx?$X3PM!gPN<=!X4FBv$ zHqK@Pw9~L~{LP&DUiodKuj4MB&mrQAw0L^wDcZ^wdC(Y4*k$d^9qc~8b~LS|e5-*h z&9+d*y+7D3YTfu14NGSl zQk2KXpeWgw6$E^Ce9Sc~;xVdEGl}MQGs^FR0Z|Y3$1;(Yr7l zfE!soKVbZGd3&GQyH^6I+@*6Yerf(!{SSLdgCk!0UOvC7-Oy84vah|ps+2S|>juWu zRZsoK^0#+Rc?GIvj*h9Rsgu1{nE%(9_Dnnwc2=`$cCubvP&7VJpR$j=QDR{mgwjmn zVOy`;n-p37fvoOc^!WVUhNj4S+~+u`VRR^@qRMK!yvW=3?YB3Nqp0}APPTwzu3J{# zaWK}P>MH36A#Z9u+{5fDs{(RVWJh5we7r&&0_L>E-&`1xIMfc}sUMSFjLY?6)z=Mv zS{PbXw8Hq7?5V@Lsjm}a{(i_BT)`g+ruy}X>~_47RH|a9zk9Mde73?$T^bfDaE-mU zlcsEL1`K{_{6dTly(0s=!{18X{A^uihgnarO5R{iq@4`IEW*d79T>VBL6ZEg>Ni*S zq7l2lQQaW>k7CR^%Hek`5=Bc^nD^g_O4xZ)wOHaMe8x|F%%sw zOC#TG6@52OBBxm3E>xP9cuRwrxc*a1p7vCtZkHu3mYsuLrp=i?->%-vThd%QLa(^` z<_0{(>7MLpJ@Jxmh{$TU^GF~|r&yl|Xj)rZblsdGGS>Ir7=LqF$4$QPvbKPqW|_gvNnb=-eNsh{GO z8YncUYfG{5#u0rf2a((C=7&pEat#64`U*r9L%hi4qKL>qHru$FPe+Gly+?!jFT1=J z!@Lqpf1vYHyCg82X<;SJ2a~o!C9023I+=nEtASZ6Ua3hiMjTQ^=kRE;Dz~bAF2*U5 zNE^q4rR8fwKQb@Oq0fZcH8pv3BBw?lMn@Z;D2ODuWSSYI^=;&K10ED9ea5hxaOvnZ 
z+&Y!_$ZNB2=*Qj5fm2=&Cmd!YOBhV+W>!^F6Es-zl=)^(%~MfS1UoT!&;KZ2Zaki6 zwZXKQ}sCL8&2{$(~#r@`)i*!S9OOz7ub3lUe&!Vm4M}6ryxWsYUsJBp)T^w zrL}d92<`@)ww*k0-~F1*PrvUT7-#H-;J!!X=pcPebaSzKCxh~FnDeAtNd?|Sh=T7X zHaN6{#eK!2K$W;o7bVcVR_wrL+W5};*u$th*`T$^d+AcM)3eAqR@=jKpN_r{tfBMB zd?WcOxE5-{Z7l~})q@s3-x;=oE(P+~)z${dSu1ob(w-lMdA}W>cws&|;Z^6{U;;J2 zhD-g}N{){gW?G}N_lQY--^bukiuGA%_ub8*_>0BEeu{3YlXJ$ae#I1}C*`C3^uJN? zjD5TzV4{~@JQ4&o9b_G5{1zsM{q(cU)tvMsaz49os5G-7HpzAu$flZhW}B$p%D?h4 z@gD0s=^%2pP^l2U`u?XS&K4;?Pl$^iiy#b<@2s z)7{@ZqBE9L^8IC8LpP6Vs5ho9zSHh|^qZ{HIa0pU3I0bcExWk-zOsmId`a!(sFs^Y zbd|LNjGX4d$F_jwa1&scmX_8Y%#I?xSTgnu^VsZbX5W;c0}mp`E%qXp&=>cK*a)m* zxT2d{?Hp&>Y;Ihmw`O5N7YKjOpvxyjIJ>9{5iU%g!w2YJ#}~zK&zcdS5cl;HM?v8G zqG=1P6nVvl$Ij4-g8zr7Z0Eyf3;d^Nt_(FEm3UNc`Uo(%I?@&&HPRm(HCjq>+j*lIUgA$&M)V#)>JxqRaq9<7gKgzSlbaRhStqAs>5UQ1Y}s?hCsauG2vIB8G1h34 z3psCkUU2-mDyz$=72AD5XbIc~=aJfWHs^=xOiOZ3^Y=<~Al`X`J5iLOb-AVbAc&!c ztA<83Fk1dqqeFJh(r5k(EF(ZkCNxGz%S3OmPBN?h4Nj#{9+`vOA{ZJI&*jMF{BX2 zDM)52VA;RG%7J2+k&3*HC{DBgHJ^vdrkS05R#xF@2S3% zRW-KJcBGb4L-(no;8e2f<+=LHzd+9U8L7hvg5j*9!QP0qwep~oiPM#vWtH-<7KYj{ zo+%$vBp>6;aaRZu9DfJlMRiuanQgTKBG&QSSL~}v`K^UC19$ne3o1Y_KIKdZ9OHX+ z-l>_*K55rsZs`-J_q!SO0pY6Xp{?)e`<$MPbFhmrAlz_d^8~v0Gvj1IwM9#37#NBX z`Odj~HM!HEUT>fdo?XkG_b9F?j?c;rlk&!~G5Ec%*Gk6}hn>4T0-6J{iVRmD5n|tl ztllX|D?cbd#%s05!qc;SMz7cVa-YyWQi9eI*W1ZI^~sxECVG%PyTxcm8mQ zg;!{cwo?aW%Aq$%&yk$*l`)V`yRm4qhRq%H%NqKdDtBfoQ^j^h#Tu>c(ikgJL_X-7 zkP>`mzd*4h2tK@z%DblTIo~AZ0Zz$own5|F?{+(;4|?qIQG%j26JtjpibbGa1Nzy( zW>!YhUECZLj{_5A8m&IqqpkS z@KJ@a9De$-yPe0mCMjQGolkG?0kH7GlfVGzR~zGYu*6MDW*tW&?!BNR(+GM*Cv}x~y=pD( z)>iJ4QPI&UQlzrAaSIB*G)gU+Gn1Wi*{!ER6K4>8K~1aWO!Mie)_}jSS7|uMx2--mO?df;hLi)luYXziDX`{TbegJ>EiiD9~m2^ zE-3Eky9(}uTls;h$H`ux6qw`_5cD~Aj(O`WAizyL=({5iEPqU|#MasdlpraHto@W2 z{5%^8f?z@&Z=g$uY=8UGn$Paq1=|2@`!P8a3jU~14Oi#gn><=!??=9Et(cJu)Bp8S zhd-HU_sz*Ha~{KMS5>QS=VYc+$+(XlN_>3d-5zU%O6wzl>b+6yX<1D z`}e4oeecJtP&*q>Lp{%aQdk2skolfSfa12`ZD4Bt5mUVwKFqYoq#vq-TH$ zx)eML3eo_<{ylNY>+brOSLqQVt2OM`i6E^(ST;&O*Jb~5TccK(MnUWSj}e+&Pb1L z{BWy`rv6IaVovW+t$Ob)(*io<_I7&9w4Oz1bfR&P&c+yrz&?=a@r&aY1>_g zu>H>RhrBlnoEy|Npqs7}CZiIQfeDVuvh8cYyxY~k7lZY=Of+|aGD7pMA3H~3|7a+q zcq6ZG#VYk+Uyf=D(7h3=g3ZsX7_!cODQEv5(b!@N)%-=o{GZDHaI(LAO)L0VUY zw7|7btq4x?Y2_843vM!}NCzZRcJB6%z9yorYfnUlN9s>uN^bE@@>w9D`TH+qP)>Xk z7W}5&C3{G9bPd?jFWD>G;He6a5+Pi?Szj>2vQhcRvf1~09P~fD6%Tl+ud^c^#6;J` z&zV{UMg@#aka;ioSA?>D&~xy&Zg}Rf{_VjUO3dAb!B+8c)Cnw0RqCd{`PC^kc>Vh% z_z@r3lX}iKl^XAht^12zhsr5WLEx*~>h)K1LrjStRy9EwnJU!8&Pv*7h){Nxjm`iZ zv7?iV&_`Ho3)jvZ94ua31uj_&OFrt%DfZ3C81R~|2CW6#d5)gGfcF5UEiEH6!?T4N zW9FRdVAK&+#3zNutNlx^4w5eYzMAdKSM{MX-UZvI6>8~CG)k!^oZ7G0veTbBWsW(8 z_3k3a`_TsrC@WX8?0x}QS=QV$rCd6uWt9)!&0+aVt+gC3A0!b);ikTBVSVQKH zEytPf?y)UP?}fS*-yAnDxHUVjC$!eO8kWQty*EZ&CcFw4iR#8Op zsxX#8Jw3P(aNMv(-_zk*hH+4o$taIAJU*}%5-;t6%$b>*i0+RrRvGfT9z{ZrTrz52 z+HAiH5lqk8jiZQ1p1IcDloFQE1^pWamrk~h@u~Y9TY8pHj}uP(YE7?SLT1pxYAGCr zR|#~OYQtIkHU<6^c=xQGoY-hT9E73zhd(FG%JGc{s64w!FbK9ROiyn)c1 z0pIUqR*vMGU4@1%m=bb+-x_yvvC*NtxRuP1tBzS{A*+GA^AD@c@y(dqkn3vkS6OtG!fWJkE}9*V#`@IO1O!LDX;b2by~KDc-oz zK2$b>5zHV1tb0ngOallQ47CYZ}hY4tzsg&bAquSJkcgc-U)^XYkdCnQqg^Q8+e;7<$>4 zcJsCj{e-=){HLMb_0y+)tU#i9S7-p{bfs(5>El=>5&lQBw_lo!D3$`3ubA>YZZcZr z%q|JK@<{deQxe3&Ymasz@6RhPx=J{{i@Rc2qXWrE4^c<=;*SWR!d&4!9Y6a#>=bxpapR*I5bpg#Q_w3fiY41 zE7&>ZDQ<({hQyNFf_MN-Y;tLm1}Ds5&;YB^6_|Vc`xpP|o0M(RxY!Fj7OR))!NwzODaDsW78&;|IihEm?? 
z@C1!BjB`J|@$Y{JISp2$mi9vWDvlcwk=KIl)g@nW_D|4AtbgT%`&POzrjFzJJh;D) z3rpc@MM$H7!abWDi+fCIu|B=M^*Zjg8U-Leiub-PIhxhk-8ls3t)67Sm}bYgO5^;# z^Vm`6nrVHl0Xh4G}pHSnMPr=Wvb^ ze9ATq`E*)^G46LkutBz~V<&^#M!-4c!-uuqQit>t$M8JY2a`lI^X|j_GRSmdFpz@W z&r-N!zR zj>dTz2#rEl&o-#Mo{nya7ZW2rd^kV1#AVMho6LoQi9pW9;k>IdKJweQZ@gSwH|Ohq z-3RBmT7*X8ydEqRIRu6#K#6-soW#M*jvVhH-{alnuvGc<(v&dHnMn`;tN0$1iQAkR z#1n3VbJ>38ri*h~h?9XQLp+>Q-pe$4$;r`{2fm)ec?~l65Ic0I1L?IlgLAMUn)3AN z(_B(R#Z@f}L*oU0Ks6S$441dMLH&cY1$92vnClzA6}5zIg@ zUKheS3%?O|d-71L*SA?^WO)UPW9C)+y%Ad^hg- z9fbK^>FjI4{TzHEUT`ZNyy29%O)>?xGFZJ1k>WIe@G8I7DrP$1Ms2PqU_+$mBOA`Q z1?wYb1Q+*^dlDbFC9;6~5;1fo$LYTI7_nb|bO(qvlRk6Y?&Kc<3df*{D}h5oe3)g} z>qBt-?yv7w3nNii*n%6C?%}ksKSIWVaje{f9`=9osyK(@3Cwx8&8z!3FTjbaetV(_ zq&L%t^89hX0EX?7uyD+|bLXaGe@i&G=bpU64f^0)-a=(;*AbdH+&4|fBLAfMYOhb4#DXMMi9X!d$T7TyBk!QEmWoh{ z;Lv#ibr^EZR(jaZ?Ajm{bNnNgG=WBJsc{~A%!%?}$tg^K$#kmWSW4^q`uZ*Un$MqG zhlhvToyc&1B>tbHGzN%lzeHa@{reX~8*w`TM;NfW@U}+WPx}O^0jsGV^@?0`#?vRR zIG6Ze!UWhkK3*!(oY6eqB7pl=+2Bt~9DAv7T2w+LT=6MW{+0eA1mDqCMVwynZ#;qg zui^>ltq={B0xEXUd&3)WVy&)*qys(_mNW{^Wah)U(h+CnJqRy+M<_IG5U9t{e8F|2$>D9b&N7l8_8bo|z-Y7PsHt%cvkd|8swg zb8df`((`Y?1Y}9l`fy%`RL2i^@SyGE$7^0;iHV84f`Ywu-#eN28rp<<_$4aO8)_z!@2J-Bm;jlkiXg5 z-$>)XiZqbWVju-NIOPsjz*%^()X0+TfDuR>DAP_LaHk%Vh2VIK?$7PDfh8Ude#GWO z&mQ{*S7BWqJ)sbRJE_F~9sp>+|8?F?+&eZ9aqHdpSpMP(3GgNa9FJk=4cBF*43ORu z+$zIO6q7SNc>o_x7qlvYGvJ)ZULi1@9I=NRaZ)z;2eozPzeR2R?~hYB{rj(D{pTA; zf0CohVFO9)hpEdrACy3egib)o09?BGwcucy;J16nw0rtX_n1n2B4^+63b>8ob}I#1EO!t@ zf?%qJy}kW^idUZi?eI6_3UmDDtM&hZZC>?E&8HymsXN~f-TB)NvBW4hgiLoF!E5u^ zH;Z*q*Ur83<;cuk+#!}hnH1KZWqZqcqT_mHe*fqNfZNrxp8r>p%63IXa*uUQj&?kP z`$n);`5Ih<1j#EMTlR;v%9g^vH(w#9{=NAM&gXBk^uJ4% z{$25fb*+DaZ~1%k6<2H_W_pxi$M>*7roUFY&0!|?xJb3W$9opa7`Cr(AxdYh#?;Tn&-#^a= z{_o@;h;H40b!tSwM}u>Cq-(?WX?cv!cCT5;vY`5YKU1#l~T z3MNuGuMwZZI@;S;58$MV0hOkMS5afL(#4$~$HHCqY}*DXO&dSFcU?8Srg6MprF&BtKa<1Psi?cpBx{!uQBM6S)=2UGU?iiH0KrCi7z;=^QQ zX@P-(2$q0=fY!-Lb049x$vww*7D(KJ0}ohmUWBv+*Qb`3Exkvb?qAqJ{`HgqHSVoX z{J&7MAnpdUSUO`hkMnK82n4aEPPdv@abANwO7zcS-?rb1eUGBGPx9OLM}mC17{NP( z_kJJ%UjCmlGe9p={c^VDEr`wVii(PM)cm~<3l{Ky^d!jN6<@#cdw+$*K!QSMDu+&C zdmjjb#sf~|?&?scK~72J?Sy(gKG9nOWiAMZXubk=KB&VGLwRet@&szwMdLeY0|*M};rLG(yh?=0TK`^a$u z1&ldLjr&0f7I2r;H<{P>;7b$!H2D8{9W=NwQ;P22N-Eo_p-I|x(6)*GPtP6f!7;;+ z{x=j6?H%C%PcY@MIs4mTnxQ+t?xOxzNMf)nrdty9dsrvY|EvQahGxvi?nX`ThK|zr zdvHV~vs32YdUn-BzK%2K1nae3q8?Tc^QXBFp5@CC%OwS0?vcWQv3GflZ*fB4~7~d zC{*N9{c@kdcOM=R&OgSK_xtG!I2)C)|Mctsyisi|9taaW1%=Opx}GX3&koZH-U7W= zD*J|DKX5{IprLc1^rBvA{$Xm~>qn0s#S6I!&-+0Pdq!OmgA~7eqkrD=U;js>0m!)6 zFJI;leulkD_qb-BX|OeD#{4Lmpwltysr0FQe*Wf%w?Ck z4EBNLKyM){-IOj`$Ss@Ks*~MGY$h6Z=?ZqmmpTRkBZQSM371Z;cN=s2hs0nOTSuwJpIQL@Y`r1i^&-#DNnY2 zWccGTWhf#_PD`tY=1p^lG(bP(qf<{01zON`HWCB|>XvJ)l$z_=-m~E?UVtwXoq*PD zo{gtpN#hv!Uo7b=1w1i`+UECW_515j;-d|B=tL0M-V>gPP}gQFkOl_rW{n;Y%xJno zgxIiu>o4Aa2tK*&!t!BgQTShYB(~?b?~{nQb1x-wotEhOnT+mNHfc-cVIX@n> z`!h}VJSf*(|G=MiukmD%A%W%5(pLY8@eik1Tp0V>X+1`nL?7d!N6~sT4=oq`6~iwb zyw`HN#{|QBoUa$xdb>z;jcF1DQ+#Bhxe=*M7r@w(snqDhq3$GS3X;0uw5JKkJPq1< z>ARS`g`UP$X^-JM>~H%fL~MSxjX@?voK8mv?)3&AQ{i>HV@^un|KqDllo^gnRWqh+!Dsf|K86 zU?w^t!lMxVc+Y#jgnMd6Q+)RStkkQWb8ZmA};Qt&-O|LjiB>& z6&00a$Yh?y#IQquiPdt`nko=pVg~WZ2sp*jv&1+_*~=a)?5=WzPIxlPW1Ah?S(m8f zo8JV{OeJ$#>EiH%R;zP;ESSm&CaIw7fk&B>#gTXvq9g;1zT%3Mu&J(U_F^WNEn&?C zWUm8kb0dRsvKG#$Q=5DVG+8F$_9lC&>N5uUgJRQJ{wpgoK9L>J*jJt z#J1$4Fc$D!e`|JR1%vHrZC$B2e;@RzdzZDEE5#jQr$@bp| z?-$o%bj<*ykR8#JH{)B(?K40yP1J;1`z4ogKsmG~!FIvlq-l!OCUOq zp`i!aK4{wTDeZ=vb3bU5kes)JQv8MRj(4dfNxCB&=psC)@vF2dWeo%NA^>t(+fG!YEp!-6l=mHbL_RZ@M!>LkSwpF~vK^9@suS 
zxvSnTP9M2C#JkV|?6C!?_?UT4f5003Dez(3Dv9ia_ze{@(BO>|x=^f@(2)L^K{Y@C zC7oYH3!$IMUwn#fUlJBVE7VJDX){=^m3kwJsWdlnQ0C4+ekDlzsX)71;x{@)*>Ixh z+;GJ?t8L}Mw3L)DaOg*PjSw~uB0v@LwCgM5md$pdAH~&9KK4B(5mcRdI>pvK+)bfJ z*1POllLJ{!<&1-5@6(gFJ#XrJt*FI&q(jr7IyDVF=2#)O7l54!`z0}JQv+S}w>94y zAU#1LAs3#gRq2gVs?=7>QJrtmF0$o-&W?GCX(1 zPN{9`=bJyB@YZiKTwNnXnRN*iLtlImpPVm^^m#;6S1s=Su|II&0N0(J7W)3oA4LRm zs{N$trwAT9hwmpnYC)8#6fer{zH~>1Wfr8Ag|`-}lgLb^joJvEKmo|TW^_T!4FnJG zEDXPQ?9y}V@itf6TwkoAPXR58oq;d0Zi7?q_L}t8b1knP+=#rym=NNF40yHWufOBs zHKBGQ_2B^FI#J^~|XrSOxVOG=Al5Ay-D4*uDgbMG#Cm zJ!lnOu>1nF{8qSKXEoobWxOb=ZOI7u_Sk!$0pPYc@`xRK|LpTtC+;rA-SFZ4HkB6t zqR^%f@d!_Ur3MdO!;n7+((MOAqAATM!KlES77#4#Q=+@io zLnx0TGv@OwWY!oGi&L_avWd2s;he&vV+oFqNA!Ya56FU?T6UJ{dp~{ZC2BG6@~jG` zn^--$CFq#B;cP;QZVl(^EwYE&AjEgeL+wFywpd<45{Z#}6u=gfzd>3_t6 z2G9`2L>cDNJJHO(6#0`f!C&c=y9nP_u8)wkEWGTUW_3!m6)c4B9EzjtzdGmdO2qJF z86VChlnZlL?64kPEb@HN)Sl#Ipc?In<`4py9ceWTh{umTW_kfn^tA3Az+@v1if3(}0*e+K-H z$f|e8%*NWMoCRVL(u!%ZPup)jZO?a>l@|*m$&Qo1#t0cWkX@Sm-h*`bL75lS zG{t!srfpAJ#zW2$+KocSRYHrqq)~^ZN}s%hE@*j&^^>m7K3N$qR0A`Xr?9ex6eSvM zYHmW#lMJP;Z+5>sOL3x_IvwsW8&mEWH!D{(BD4KILoE1?N45Kyy+$#){!P1cQzhF> zC0Cm1)@dj0&++KT;YWI4YS^C0EU|nKze=!t-EqfL<@LVZ!rAa|ls~1FwR5me95-2f zXMUbRZhYQKmxb%;TP|l6j?gQu%^m_vx z7w)dolAO*$@<768y72B{2e6_p3E}H>{5(tLSz%K-Y+kU(;!V8bnUYV`GNEjw0}E#b~44u z75d^y%{q;X-6*gEkl@3$Ni=4O3`=XO-KkU=UwFqWDOItG4uFSCaT>E z-OTHLXOC?W*FXFGQ(7u6j0NTsJ;t*?|92!_8u!kDtsnkx92w4xeKOKBG-{>UFNVaU zO3WFNIwtJHePY*0lHof+%t4g%%6*9Ibs${IfvH~w38^q3(!YSjqF?smh40aJh~8d_ zmfOqpJf@FL)I=(?FDe6cDjD`8NroneaSsH2Hs|=2#H~RkqL5A$>twMJ?8O8(2}0{K zw?HwdQ@TTQN6!$@euxm=)N*qkKfjv=Zml&XUf zs&@#-i|H+b7O{5R42WSTW93psch-$JKn<0nZ~8s36bSWIV)d5U{VN7@Xs0*EZ5{lO zl%6T|3s|MJZfnBJkiu6Txq{_G#LJ29Es3Njp3|F>uatM2>t%*UAMaLZElrZV5E`HQ zN}SRrx8>f>LBqcjdWYH@S0UsgeuVmWbofSR558gZ(|7#B>2C5RIDkus`3 z{WpU?0q=0ZOx4i-oV%!lpMxR$jfq|)$;8l)A3HW;(f*jP=ukH)V(P$C>Jz2oFLVwE zwzjqkv$<|O2+dXRH0Xqh9S70JMs1pf9e~qI!p>Co!|`}t;;xg;iLJRN@)w|av=(0z zD+MPC{R|X^oAX2Hn%}m8j3LL&6}WNG&C-z%fSfMC+#a1zS0X)q`8cWOh8^{63!&R~ z5x{^tz@A%=NY`QyqPBz}G(8;~yR&Q5w393wmJD*d0X17B)rSx%si_Ec$F_yop49kP zAS~|j*|?A?Lev+;<0=!_)Ei<1RKca1xUwphv=O#L^1y~sb_P92FHDD3+kv}u7=Dar=(J_rfrE#ioKtv-pNUuII4)-~SY9$3{t~e|dqFbb zFgBcQ+FNQ)!BQ(IS8De-LP|!^#~YKKmR zE&C-bz}xRDA?;Ggu&9-i6jaxFaD-Mkk^7Y(=E*Ebpo?S>oucJuZBwg3zcsBWDHi#X;<$PN7<|diylrVJ2oR(1iCE`nz#lSCs+8td<;A$!uV!i zf)BBCwlgJ>P`s2)64@;sf_C(NTdxES4S$}#SW4^0NIo2+kz7juUL^?8A{h$b;2oZu1urN_Pq^7;b0? z=k?CwqtK$}AaNZ>>6lkKLq=H+DN5YtZm1}xmL)TVEl1zs6J=^z_N_P1%nI@WQ$=?! z(lu|CKr74gveX<`KYc;PZ3?11h24vnK)UrE&&4^(8Q1CdFPTD3p4jHfsWk$S`yam& z^X#I~THLeac=`ekh?3M(C(pl#(WraT&=eZ@X_jpC`v z8}ZXR;B}=-jA%T|W4ir9QxFMt09P;Xm>MMIuOwfXlm>Mw@DJS4o@k(&IiDY!paBI@ z!eqxM>LZ%a!lr~oh``AhsDtWVwj4{&U}4wF{n#nls?}Es_8w=T)|s0*zw+KHgV1G8 zIAIWINB6mcu3T5EjU6Baj(U-6)6?UXXG`J91)U4Q? 
z0q`al!lm7@cn#D92BdS?HSks~7Mel=27skCcdq4`z!>j0&8VlooJ5Yia-zF}Wt8ya zk+jc?l$V_G9RsS|sl?8zm)3+N?u9kDj18NYT$&`?mjBFg0Im z(f&RV*E1+lX$M$g0=mWf=kmohDBC=)XM{_Vo;E!`ovuAz7tTXkrL%4Cw@Sr6$p`Yj zZJXqIT-%U-zhDc=)_nINZ{($xN1kKerh(nIfdr?Xdod_G4^HpTdH(^ipEURQq}-=c zqjxQry)e!#mCJ7#pM(*pv|aNaaK9Q7m%n59JjS}JfRWq^Pl|Z#=_9$5!!;j`Y>ZBk z{p4%O8DYnQ!Y}#f-2n!50@qm)4v*cD2z6r@?^7B-C~L!>E~el0eijou>-j0ZKjU+F z*#!g0I6KP0LQOW`_hpW@-!85MdNc2YNk!$ifYW23G}qMMgfXr1-JgWSf2E*7WZqx= zu7y#@q!i1&N#!9NM+BJ$a;W z&^#@bK0ejPuEd~J&H^pCQ!LcRF5})pPP^I)Hj~ zU^dhH0P$pekga7XHA=GW()Y+h&-U=B=(S;o!QcYL%e4nuW~sR<{h%lPya?T~N_J%g^U+l6z&E$V2+XLiv3A9CY!%gsP+g z7Vj|V9_L$SBt%e8q3i&|-{?C{8_8`F$UGq@pgCFH4Wb;sK13awVhc~ zK9{P@PjlaS2>V{`lTWg^Ndu#@5U$%RUG+d?8D`M4GLnblu_(qqRiI+bQQca06jr1+ zZv2KvaCEtrIb~U1KntBx)#w5+=GBo zXAeCW;J+ys8}{zw$@cCk{N>Id#PbD&3`lc;U-FM5G~Oor(UT|RJe~@@?-iU= z{Gik97vM^m`S{eJ-wE5s(tWyT?Hykl=-N9BUS^jX6cPY_&$?yOY1X;aPls-$?7%7r zW}X$DUz_VQ!|W?ml6n*ipe{N)B$ZKQyZp4xdpcz?pnoS1v?&&$(avp6PJq`!{Q3?W*15DH|Yi{yvm+K7-%^U(#%XiQeU=yCD92+@o6Da&5&+e>WKf> zbTPe_JiuZHv#w=&cUQedN_<-!tx>4f#6my4LtWN+#E2<3zktFR+ft*9?0(DqDK89A zgY&M-)9yotmkn01vB!5@(x?hdS!)VF|1 z>2IJq@YM_N5Y?`*ddZFbU*hk0l7($vD>nK${+VI-OsSS@)qdW-*^gJb?I3rAPJ!fn z8cdmy((s38fhtMo4RgAqqx$)n@*o-201+fI*o`_ubq9lya?@+7+H9lj%g#IuQECfEMHu9D99ZLQ`yBI$AV2 zE5n=0MxE@nP65w!ePzR}k56f7WH$F_?1~{mNXb3(#xg3t9c$*0pGisXDDvwi%4+0u z_r?JSIn`H`)52KKt=Ho+KZx4zQtCR#3hJTS;w2|wTk?yunv$MPz4haO*r?-|&6~v` zY@6Oik9`Sy_Wb#syA7<9ly#QL)RC;Z#c#ir*j8+qZlK$gNJD!#bW0pl+A}no-j=F0 z%lj~lK$mJB>$L4D5!2#6RV5LVQbaG7ser^?p*Sr;1dof^ehE#rVX=9Lag0*6?SldN zrQ_Edh&=RM4JBm9;v-1C&R_d!9gN%r!cn-9+W6-U6+fb_iuwFmf#=l8Cmv9UGrYiO zjf+3~iP&sQ!~HUTu51;%IkdELisG|U*L;P6Xpsz_E5J(_1faMb$F5f!KfNeyKkN%x z)U9)hG3P-o?((d@#*s2e@kv8164jd7j;fc40HY;`bqCG2HD)gfnRZttHP15I+Bi-& zGhWcnHx~!E(Vj)Ob7ZnMAp!hqE=Umio+%ncWWdn|mi^<$8OYON*GYNY{J}oFZ^W8j zp^&xrl?`Gyea%#{2$V6~^MW#cgfiwa*c2V8tMQ?ad7$>_X#r&U zH!ha@rPNv3a#s0MBG=&Taqy8qwdNU+oC*V|e7cZ_t=(zL>ru-psQA_cUu@=F>l&3^ z>(Z%ZM%7sac2?Qw-_i*_8DvMy9v=~4Z7V4U=t|mmkD8r#_VHDGBXI`4(W2_h)Ej1* zC{3()Dz}y%nb*$34t6+a_Ki8jQ|{A0**f59w>^3XJnVw zEkw3`z^u%8h#r%kbCx3Vs_pa1Lx_8|zrM+BP_=LOT;ZBs`J%KEQA?%7E9u~|q1~M# zEGKi1J%Q}~OdV(G6fuhk0!>uHNiD5347;+wWL0s>8Xq#j%)g7;kv}BdpE2K}?0=>7 zgE~l-lXCjNR5-45Xq=|YV{IR9TO0DBSN_C0Y0@bz)WQ`9x@r5j1vu_5f8{p;{gHsL zqFW1+C7|$eI8=)Or5JOC5u1uxo2eHYyvMTu)bs5duwYVS9#a1L)ay$_;MGEfmmWT* z73!=EG`p>FKlmf@2aqz1M$dFLEzrQpfM=jh)3b2uKGDZlRh(38?zAHA=ja|ZMpD#v z;C*_wEcb~d#;a(;^3p}RpCIBxLx2%n!UBhX1;hXNnLTS&PL?3$D^a|6%uiyT%dnc= zLv{Z;VO@_nE6$UNpq0BNOShKlv4W5LwJ3mrRgAEns=eNB z6!zJM!RF_|n6LzmeTN7PiQgt^MEdr$djdumW$YI9!bE1j4pWdE%O}F8e2bwY<-24< zi3T(uf1Fcac(QnlE%JULjGj8%eazbY=Zb7~EN@<&4Lr74;}tT8$?dQ~87dnP>3TO6 zq#?cXwXQ*3uz(n$HP9Co%kpfT`kco^tQ%7gvlrtN8@w{I)Vjt4X?b&7F{k$9e>iqt zgY{-md^ov2Oa9MJiM-Yo#!BB#I!?IfX#a;r)-)_ZPPdJoQQpz+byY_hg?`YzAO#(OPJu z(>KALZvH)a+j7^n@HbNJa0YA=tHB7^-EGi`{1tQ$>cw}-hMkKN?VnNQGByL?of=O| zIzDCA=)6qn5_L45~$e z?1vV1qImn-qPM1+xL{Hgkbo6@L=HAqw`|8jetqV?hnDHuyl7sPrik!)>{J! 
zaWQ837tcjJR`H*2o`kip0$c0?Z5?r2GY?AXVH#j9Cv{OZzUcGU2N7OMy_@wud7A=@ zc4mcQ0PfoiI>9fqt@S%8iK~J^4|&Xsv=F|L9000~RYoWcyk<<$CuuJ~{dZhskC}vj zaVZDd-jSmP9Y*dQXV<(W9CJyB?9?T3USE^^C<0VEFdvv*ofun`$Jz$P>;6pXR?TY+3~Iw4*;asS7lB4F$WN>bA&BJ}^xz2j=5C1XzkoCIEPj zp^1l+7^Jc9Gx-sp>V!lNgE&nH$$tI{8Dqvtm*J}Y7xc^T7B)&iGs`A8@2e#K5)Qy~+M=r5GppGFKeam7pZgb^!RBi>~zobW^B5F5gs`L2r&Eb6`w zKhRJNMyr7c&h|Y$^a{){x704^h5hc%_IQ+89G9?OvcJA*hRxKMBHIC0Y#L1SAnYnC z4m8a#Cjbbkw^r-+Rn%oI21H{5@ARwD$@tP!JPP44lVH%s%bcdY-#QN5u=8%A+Uq^XTP~#VJd#!|ql0cPd8gJJQ6VuWrChX2DW#&Zy2kd;39vNQN z-ESEa9Ky~i+}YXM*gvkFkiq&jPNZ)k@@8TH(z5;aH!X(@0Dg{BDG@;jzch-h>Bdh3 z1_;2)kcyfN(t@`>jpp_^r8U9wj18d>bLA!oz3-pT-z?9*U8#>=9<3}{^etV3g1z#i z<)u&(P;X=ChGCf;4`D+|Yg7f#w`dX}-WbIRQ&qMbz74;iDh*8{Z0yhGN>Mj=Z5q7# zIy=bPPKm<$53DB)R%FKPW}pX!`W5C|(E49}Ag$5LBFQ_8~P-Hur^1Z#~h>R5~{Lb8*cLSR7HU_Wu@k$dGMsn1*&;g+&k6 zpQjM7Diq)ODC{V;fmb-AN`bWhWt;Q`$8u%>rQ+_?)YLQ1+oIJMon6lBG1vJ-i+Yvj zTXi*7cAHhg{+A;VK=mMyI^zEN7JVw8NAqPXNxek_;0m-JaSt6s*9|6U03wp2@4eiM z-C_}(c-kH3xg<-)XZhSv1NhDtDk*XVG)|-YUFKzpa9iI|bn9jS_&I97I8iU{qepY+ z35A(x(FlN03L0*)0#2z=R#VsfUN*5ut_mA2yrTtO(!W_$Ymvkq51q}{O5G;CX_@`s@W?3To-mt-vD^=rJ8su z3s}J9wGMdU^(UT;;eD<=5`*hvOX$}o5yixvSQb-xXL?@w3{8(MfF zYB~CP6}vl|zK(Y9Mb`SyK<4`b@;>)R)|IeJHOh@Zn^$9?`YO+i`BJq%)&0^BTHv!u z$vCz1uT%5eTHkH3P3-Z<(rV$*y0>TA6jr24(zQd?rI>xzuy;oaRf4K2D%UguF7@(0 zPx~e;)Yo6)lmaj!KXa-)3$W#T(z+xiA!9*7}zG#nx9U2 zEiOJGLc>p^2u!GiV^(n&AJy-W(ltq3yeLA=Cc&8(Zl$KFk&gm_&ZI}N30&TXdLH?nvz5SN#jT^ky7L@2|poYqUUJ96V=DTfE52O1TZ_4ai z#pcT0&%||~Yc7C6uT!3X{#@oFTMb#@r}@L&KSi3_uMj64muH>keotdQIfL)ri@2(y zxO{uW>Ej(o3ZnHx&jp1f>YiWnuKk+BT1Iq_1AGLxjzj&ILlX^AXJGG~-qdZkfsWE1D)=0C!N9)CFANuvW@f^Fe3kLTsrD zt>dj1FJ7F3N`Ww_48zZ7Q%QbEBkFl2L6@ZeuK0illb^o+A^m|9&AYS~kvI8CvtHik z8lWTG7CN>-hFF>F(|}FV@yuhEXP$h;v-J%94E3nfgG;Ej!DFO4CS50{61uQlhmX0$ z0RyO7Ae&L?0+xo)vi;I*HKY;+t(xF?@4)UrfR_rCI5PQ|W{wpw8%zCSmr;0X^(UVs zU-Q1sZq|5FzEls>#_#zs; z(VTq%)pGC1_5j2fgSQ!zuy`V`@r=9I^+*2&Z1GBS|$~tU#1@4IZ zW%mMK;@-5a0(F$*AeV~X`VuczY3K$PIwd>ZnZ*XxMFrnBpfA{+(q}n_2P&^-gqJ{8 zAhwMTgL=MCIYpIj1-kR!B)6Xu*`<~CE-ZYj7V+UsS5Cb8UHJJ0)z8cNK9_t+k#~=G zS4`y33dK-|N}4qQ4=I_T`{vpOZ*{-mm$!aq?Z_W8Lk94U_w-|eeH z@z5nIL1%LyYOou}sbsCxmlxO$7ivd}-pD~R<#CijpTyq$r;uybSZj)oZ?4U+3FFVI zt~Z4o{NdbjD1(EUS?XC>7-`0&e}i7BGner9+gIFv=(ctEL4&DZ=yxoqOZ^d4lMC82 zwc1`s&E*0h*D!5}G}QC~B;N}4K-I>-;06i_&K&$47=g1839sM(|5$tPc&gw3fBaR^ z&{DE8OCmxNvZ+vHk7FK{q!8I0D^x;8X38pLJ7px}pn)PB3dzdegkz8MdpwO^$NTg7 z-frLT`}X_i^Lf3WInU>HJ+8;RKgN~qM742FQpIcC3ucGSg*#>7ICHXundvQZyTrAW!yqjk?)bJ=h%ghH+<@`QyyKHN2W>$YI;3UBn zeRpr(O05*UzXWINls@GrZLr7UC+0$6oFo;>_|_X(hSBe~<9pOsi!UwhPx+5BXiDgJ z!ri!Wl#Xl*IzjooY+^~dI`-mnx7lq>?gds!GOkD-*01E3G3AvO<#U`6A}T#U_Y?wi zA3M**_<=Id9s64Wk8VyZYN8Y!=n@YWh@>A zLyY^X`AYZEp`o@^mZdu{=OF_@6si-BaH^L5(F77v#juF!!xhqyJ3&e>_NiFIJ>AuC zbtY(h<(FBXRpI*T?`cpPdC0K5_V${^bY8QV;485Kcgb;umA_bX)eaLB6cwj4RLyBV zXY29*s1?`CqA|T=mW)!J96QxK<4YYYJd$2hOBdT zv==$j#%Sc!K<(CD5I&AYNAd)v7@$6)?F2*_bFZRS@(NGpX=(_Obu0b8Dpz7Y*55NJ zWKrS&&-MdQv(o>xx{86w5aex{0@Z<}s+!US)~dF1z>f6ced!bAknvDIkd!E&t+;sM zA5S_eP9v@>wG!M$PXhe9dM|AQKps7@5kLhcC=xfYyVrNWX$#XqVT$;9_Y;Q)m1um{9q=A@GrfILtf)?qV`qO|71PFJ$ZMn0^xn8 zUH>QE7o}a?u@k{*hH^J63NUS>Wl(+6>&KmS#&FCIA1`wjJo23WZaKW7$}Q@&hDK|t zgj|ZU2JgJcX}yY6_{xe?e&Px=|&(ZfcgWDiY-z)ohdW%D^D@2 zbQJ?y*Z=woXxA@L0dDMsR0X|y$&ioCG9eg@G!z}KEKX*SoY40A$?7Q{Vb9nB12P~l9;X!osX7iMVSTdI* z=XdDbX()~fgWHwMix>!6s8io}g}|MYJH8iw*nV^Eke;;6?A#mY+&Ybvs_ILlMs&eK z^vgD7>Nl(1Kb~{Uit>xxTHJVQ&>leC>c&?DwDKnvD9Vi))14o|mL|d#lSpFryO5Z} zXfQYq&zK##2lL!lFe?AljDE)8puTMKNoPi2L^6bBaguJAP&|8)1j-$koJMP$TPf)1 zS1ZrjM=&GKMs+t&J0@}mL}d!4&s^9)>0AdLH!Y{`4c>mfyjYn!%D8<%?jKLi-J0pD 
zs)qwS>NqG&2mZ4a=_TYKC}abWgyz5E1*zJMRvQEL?>BCq*|!54qYiPRPsmio$m6T2 z4Mb-S?S?3+YO$;zR4;NP_tc${(gg04f>s@E=X#S?0KVh1SzQ=d5}p|y*)evjGS^G% z^0(4YfV(<_oxY2`FasqfRdAbhU#C)kNh_o>Q_Y5dq(%)?FUql;b%WY2b|c9Vt<~d) z2wYleVMA4i1a5r`1{#|(J?~n&=Y@t_y2Ahpr=K0+`LWBg$ZqKHkBcm)n=wZ}3$&h7 zP3q$3)<+!AK%X_o6X8bBXs+KCbGJ`J@oUGU{R}5B`W(oMzFNd@h>k`__+tEQ5@r}7~b z)gbi|hMZUf&0mYh${1Z24D^nWL&Zn&r}{JM&|1(toQ|EPaYhT@{dEEf5ui?KDQs!t zwzNFUaPF@RFS%dJC}Agn?Gr(jKWZOLqgWhS#Fi|kjrqvuB_LIkP|3J&>Boh9Z*-8< zs*|b|ZWeq93^P}wsvd+4ZWRo^fVa2Z%ELU6{GkV`rJl9D1Lq=faG?W1`g(SA3??z{ zXOV21X3bSeb{7*RE|%r)%%h;?9ktbz3W7U(W06#8a+`;|w+|Ap7~K0nzUt#T+m1ND z(8UYR?^+?XP1>4Oq9KU7E#{tsED}tyye<*|PT+J~oF+s=T$(Arsm`G|_Rj{<3shtu z#ChoEnT2a+8MI`_zL3H8R0dUtqtK@qge`L!y_tWw$QzF3U-CVieXHiEE*M3;&?+?$ zE;k@%>COx^y$pzZM81cx1nF;yO8!oLAM%}d>u@O@1A2F%G$yMJ5+T}J=O0O5-OG4B znoEB!I&g^p?0GWE-A2Ddo_Od>a`$vTg9F+UHFwK{)4cqRnr$mGv7M~RLX=rmZrk?V zy{A2MD0;&cs}2*YbaRfrKtZi%AI=MO1@!KOAk8sI^NRjK=iwc(P{dYHwaFI>Pw3c%&hiegMx@5`jHws_HxAg zBcz?DQ(1iDQ4sQ(000an;$z})$98P^_v|Lnm*Eha9A(yrp~Fpbyl25AkO8o_2ViH9#__#_q}C71{<%WLirJDK9WDWGJRI%fLXi}o_ZCiipmTJJ1$ z>z4+vb@5t>?e855GJ122vTpm50Ngnbu-){Lvkqhydo;zRuJ}>ELo;<~%WM8e8`Wc^ z?JGsi_V^a##5%P-W02Ze19OSBiOlaN+zok%;Fovj>(ib+?1Ozrx}oAf5=!65GKz0D z1!ZLjx-UEBd^#2h?&zmj)#p21mxbNF*w)2Lr~|4?ycp{SGP3w4qxNB_63Z>6%a^}c z9QKWGhJy7Z*iMw&PL~AF3>C?R&&hi4W%$NMp#~-lw*82)ZJIv#L*Szp!0^oPop|#F ziC#>PZA331;mK>kvcC*pOeq|twq-#dO{?dzWZRP4^~di(5k=-L%Uw7-N+hRcYI3^h zm|(i+LB{i{a!=^4>m>b+Y^M2sEOj9@6yTi=d2k+ILJA3?7FsTmPgI6g4B~T>WXSG$C>aD6$Si@YP>lYO`v{|0%Eka{oZQNVg`zqRhs3)G3T0eYQM43&+Z`@O?%|8pDij|~O31FV7p z2S@yr7p0R!eMAE=!=0t29MVZW^-xLd9t!hR-;H4sbGSN=n?u*f zhs7wN8XBnciuDmHM(+G_?nu6MkDOI=1Qu#0LqOZ#YVy^@SGT_0`Df=ox@#`t9S*D5 zMapdV*P)@ID+hJQix3p^!3K(X#_3Zb-HmDgCbh;40BHf%##b zqA!+)#uau+1ZC|@`jD+EwEvU&Z$(sueWChYa45vLVZjzXuraH{y{TWZ-N*$Uo4J!* zuOhY5yHM|iv`9LP{ycV+&llDPq%1& zr*uvy@t}n^WJq%3pin~7tSRG^^HlGoknhitEW@HpbrUN&NS*Yj1f)4K8lVsaCq>#d z`W`V@{&~WFgQK=0(2o2ubUaqaO1NA?Ej5@o`nb)gK!;@#z&&d8F~2;(-n~lCKssyB zQ6>*GTO%2^kY~^6!mXWDUiIL2*?Y`)?Fy=axb6~h@Y9FIKW`+@Wq1&5og1}K&MzkY zq_aT7hNmaT`9l{q9VLD#73@b_PIq`m{_Xi8z!74XA{BWHKIe32bW|b;nXQz@`+Hq< zboPWtZ)^vY>u^c+^l&ZsNkDuBk;0#gK1S{fNv?nn#Jb2sb7um#?@*`74i2#*>N)5{E@BZ$oe$=Wgth;OSWs!NEoTanf4Irv$VJ%ig;l+(BON0Q#rH)DFM2H zz6J;ilG=39E^W~IVkg)RN>Txa_Ums9o8TyXyLzgxT7a(*Tq1R-Yy0K~1$zH%uRpR;L3rwSm`PRGQ$QfK zLPA2e2ezC#4-MAf>gt}uo8%@XC0&6ea=-yK2)K{I6bOn#fwy*ljsh3x{EIV>09zqB zuA8u+Zex=05B~M@TjWn;#UBOUSAwGwdYMA=4xYT(GlV#^kDa$^6Z{rVon`zDmqy=& z8kwj{C_`g|p0wJwf&K7uN+nMb+kRw1&3=^vyCL;E_;}B*x7QvbS`>g0dTsNcZ(h1W zH2A0z>ajfWpsycV$}5s}BR%lqH8jRsz2jBj{Bn+MTaBbKEX-Qmc^}iZ7jPgJ7u0wx z;K_BW2JoR-DW{P`a`)S}>oUGN_~8Xev7Gb_=SKNuQQXOebt-p&FTYk5Zu4@s0JQff z44kwvwZcVuibK0qU=>d8xpJ&%SChVq2cFjw65E9>#x-j1gL? zY_7_AQzHA?EiZtabs+Mu^*2rExn!`X>53seVL^M|emD%8wL9$?l`rb;Ob2mu-p z*=xU&cYZ)%Q%v*@*bCyi;@fUNlrBbiFrVJInB@4N5*e9-K)&UXG_nG}p0CD(yMG>G z9WGfF__oam>axL5SpeIe@`yt4HDKNW?@~7(N zyuK3y5t#;3!UlDZ;gC3c-o$`_WE(~Q|7?GWytX|sV)y4zSf$f*o{FtF-7?6Zzshz} zhGV6y1p0M5P0@*N7Xk{kJb0z|6P|oLnYaU?R6k<9myggN2;L|&4-VKawK|`mnW-l@ zxEqok>=5#WLjSR|b}y9j1|xl02&4$yw)fX6Xyjlhuee;|y5QB61Rp*Ro9%0RLIX8S z$O@wcE4IvsC@_S>wj`3>TOp+&Wizrw%p#rm_IGe_EjM1NlfnO&{pTq!LO?o%fIU<< z&x&v8G3j>{;!tyP*6asobA>9T$dNX3q$)Mf7SF%^-)|JiB43NTlSM_!%r+sgvgmU! 
z4EB=JsaF*Exxlij;I{!MC^1@PjF9138LWlZKrSp*`UsFo_K=uY^6^1xnX83>bw1vM z3}r!l58!2gow$xbJN7qFGiHPY#ys7S2{8aKv~N3?;c-fD*x!wN>U4)mzsO2?36@96 zdSGi)`ddV+Mwt4%JqCU0Ug)RRiIKnI=UxnI01i<;c7$h>BHlnR+YUOx?FBNwVZH^@ zu1*hC(96)kaHyacO4Yx9SCQgD^3Ov-%1}va0Q$LLzdYraEAT3t1`8V=Qu+tO?GBi} zh^T&1(jID7dC+B5Prp~m9{Y_X_+jh{Ejcfapxy$5vC8NH6Jm#AJ{xI~b!nL7`14Jm z?#s~kWQiSEs0=}d^sMq8rw#lmHQ1PMWmMm2`4c%!okREu8bi;K)a2N+%VM%`tZ7cE zUjocW@=kMB40y0&I`)BV_oW{o46r-IlJc=N?3iCtX5T^5L!>+d#5K8p;m}r!^1=4n z;!+cQ4w{JXbKy#)VrTD*>^&}f(TrhCjzfhxGAu!!G492vf1Zrf+-IsSPy2b8z?kE*vP`ib5%nH_H*`L2QVs(Ms%>c7iCD$gy8#@+!!!TDRmXJgUM?CfzeO_dLnHV zjS4XWGy3Ja2s94&RG(HSFkZQQ`Cgl4_j|SMYqNy-G2TJ~1-+R2Ej&hZ9+_7|)ctc{ z3n?DcB0OXfc!|ln+F(F8i8uZomK`Ts^A#Gjb zamr6X^=Osain!89AxDFr;w~wkam!XMW`jUVx#Dc)KR(TzxP@wX#ATWg8;0fW%TUwWT% zE9sAR`t!-jX7yB?yxl=u+3jGFzh*XTg?ZDL%8l3+jny`tb^UfTxh7>o8suQd>9oJD zjR}#xd)H*^DV+6 z@(t)OC_a^8Lx+k@l+;=4gWvzvLKrymg`Xd(WZjqkhWKnzsn=MyJic!WM^;-C{cgAL z;LNoNPa6P?qhp2AW=i#O7Yf^2ITLLU6DWLZ)Y#S{ioZIv_dNVg#8RPvb=?}?xBwfP zz5{PbGL1fC|MEx165jwv&<+1nSP5=y$(eJB0&DMSlq16#^d;GDv7u$u@UpBo4W~BG zdd)4NWP+#TVm=H6E>zdyh1>O_bGwdK&{v4hyeXC%%L~SUGCuqw0ba;xtwz7eh)`1j zXn8u!reK>cgxY_0F#j9w$@*LUz)2l2156>WaBka>RIlvXGK(y;cw60Dk2KKl!94uj zZwA}VnraZ)vY_9{PffB;CY6r^;ptq=C+INUZR)z5rfLsTKoM1&_Xy53A<1~lOUeY1Gp%pn-%7#~gjfM=fbna*EB>}%;l1R&ziS`c zuG!!gf#e7Hsfm-|{lCs>2Pfbz7uR3$AmQ0{&xT7e?=m?~?_GfrGs~iSM67Yo7|h5? z94h4vmniZ$=q!!A#y0bq3}5r@A*infg}?cK?!io^Nb_8~zr4H_OuR?*uUpS$0QbWL zBvKTf0PS4`u=L{w?H0_G>L})dOpvu*pZ*YnmlB z9`&;Kc(&b)GNME!chl-;7|j_KV6|4>)Ph+(Er~ixMX-4ouDtcUP{?^zDGreA7d*Wp z#F)%9cCb+s3QL<4R>HLb|)8}+!e$Sn83=YA+PTSZE{q9-H|2B8#_nNZF@JN$` zgls!b$i(wUT65U~tAWL2RgOy4k1Ko+>~{i&*~evYoP|e^wu#PGAKq4Qx1s<|O1Lj@ zfiTYI<7u1yq{|C7N7gJg^OOP;eo7zR@A~e;$~Lj{pMon3zj%oVFy`*gP{4EJPc*PY z-zSfP$NF!3(Vl=}>mu#>|3CsV9FnR7=zUp?n}_tJ4^ibF{~V_$n6>n~JG%}l*6!+3 z_lL%-f|)`;n?GeGZ)$%#;73BLwEEhEO81yQDaeH3X>CG127II}*Gq;ds_QfbJGX2l z45)b@0}*@58Yi71Jxkj=P_jDUjIB+X48E^APpMFUydjX)jbY|C%-%PlRJb z8NUJ;wUT#gZ=yDCXYH=;3;w1)bD>e+IUcmzrpWLSsgh$Qd72H-KT36mf0%P(m7`cx zl|2jKx3&y{Tf_0UT<}G0(?`}@>!^jmPPRx>77*_tWdUH zee@eKwn1v4>-Yn^G34y8He-j)KxUK3(!7W>>l(twL?}$Rr=>qLYBI3H_uk(9T7nsd za_WIh&=&xw>M+s9tHhvrfUa)CrP%%7la`OEiXg(BXcM0Pkx|R3CTMvk`CN*YXtDC@ z9VI;X@h1^6pt^gzSU1@2JhH%l{L}uf_y7`$xO%~gxzTtD- z2?bmZ0_kzf#}ojISeHE$ougxIXFhu5za+3C8LYMO&MJMXxPx@x-_(aDUYQnxPNruNMxM>@+D$63td!}%)i>;K(a~;seC;C0 z*8apZ=`bWB!0qm9@wTz?cvJ3mGjb7+vD6-%mfovdGOAInaqdo&9L=|=>$>s>!Rt^s z;z0fvB0!eB%vR=yo?K9a61bQr%o3uue>Td*Ovk5pZfc<2;O#})3UbI`BTv>w_>5aI z6HYkhjUV6f2*;Qkt52;QE6y8`b&Bs)W#F^eR|XFiPZa#o&F=CMd~IL63zFChN=o6a zS*4c`F#jAF6Plj!erTTbQ8rv)&Y22=Ak;lJv;UA91nz3O^* zu<@6!qd<8G9b?*fahVj*w|hDPJ_j>8)z;kz$`N=NT^Af0-|WY7fNa0&&~6HL+NVkB z?vW2ZUmDoqC+*KCMZB~q`LiKAnc2;!YCy4HmS=FLyn(y{Yo30eKgQN_&w^dKRR|P` z-Z%d5P8Ptzv5u7@@fPewynw?vzjUF*InLV>!d%~CM6_a${;bI!`H+B)WNHvQ_AZt$TUDEdi}Pn!Yh+S9jy9ak^x6WN2wTwF+b!rUY;iE&iNQw6NEDP3+X z1NBa+U%HY}OIaq(EGtVc^n`nrDR5B&=nqHXOq2H?bs}R`E;o1V;Hj+oAPt}JJQ=q2 zgphA=H<>=3+^X(qfS{-E9k=^(1wZmiL96%Cb6PS9+o-uO{1s3;9cyp^Zj7 zIirg?Rt{eXr}+L~ceNjrxw5!-=vUXfc#aKa*cU%@+;G!~MYL;|EGtAwx+LY6j$oR- z(LM*6!p50iRMm}?3NU{E$S0)eJMBeWD!SfdU~n0*1{KN3Ej=#QNMnv=a-io{U-LNF zPT~&aT+Rh9Q~}p*+-2>b`7V)ddc+&myhM8q4Tu_F7cVYT zdZ8oK>&4(pn5O8BKqnfZo}}rP!L6)K;8KXh9ZQzMLIMbJqPvDQb?eZa1&>b~&J>+E zaqZBu6$U$ph=+9a~^Kx7v_d7YyXp zz1~mQGd23hJ;E>$LMZ2yfnv}tyd7u~8Mkbj5GG#wBwYp@ zbn9wJvd^p#Q>8a^yv-NZ>h4I!*xCKBc*?Xwv1c9ai zJtVF+CV9oaz1lK$*5xY(nemA)e42)z9N#(h$+1(%VXpj_gnG(7h*}~P>LuX73Ol0L z>wP_=5m-plG?H}E?#6+=RP{o4W-Y3ZTe;ZcLj_BVLd8XH7~nBa)9&?RrIKVS%#5}5 z14L)~X70pc(h#0(`Sa;$`f#kVw_Nc_#|}3mJVuwB@n^pLbK$I8@1LWph32YzE*e$o 
z#mpT*SOO_^B|R)kxWOKGbdcGB^5TNAqou5*lK6Nw%p;#>~0nL$AKbGvF=i z9)?3yyWSj_X;T@Xb-JjqZ^(qIL{XEH7K@ zH;ZvxVI>Qg&(+KTD@L8FSg#P3c17YZNh4@egAuQN??(7Sym!s&#=ZeRaAEWXQ^>Qh zo4S1JyQHF~v-Y`XqiFNCL-8uU@83altSA+atMx9YEUyZ~mjo^Yddsm*}E)KiXOr^ZO z?Js1od`jE-EL^v83{LO%(mNxQ)(E5C#hv*kv5~@E2t3{y%)Iq{ZX+t;-T8sPuJ_tM zA4gDLuV*I>A3~Q()}vVu>HfM{%Z3r6UviBeZ=|9143WaXDc`MfL7f7uODn%KSj#yG9MOqPGrI_%urQ;i6kdPL_~T^KHi1gNWnEI6$af} zkc1`|ak>`QcV(vXz2wp@pA#q9b({9i{celW=u7o%KV9T(_U(nuNX*>LCSN%`_w>zT z-xQ$YW19y0t&)g9IBM7XPmqe|{rh4dpKaDPzs%utebqDyal6K=euZ}}1sbx1Gp{_N zXZF|Z>Yc3nV$OFUe{_5!`iO$lGEu*|#PAgg3NyH9hgFvxdA7{bKGs3 zlB7RBoEJZHA@hynoN__tDtkl?L%Tzy_)7P=I8id3>-lao9~dC5MC-lPSct<{=UM;V z8uDLAGPQc;>s}xs8^h!!;i&1ZAMxIqV9_4(R6S#pe* zsGM9(EYQTI1*-7N@UaA1!1%B5%;o66_{6`_`QvBzVX9P9=$)U6M5xKR2wf35%oIx! zph+Hd!3Zu3ht*8&5xDv%@JD8+PN$sfh?C=vg{~s2(@TaWmT^Zo2bY}L22HNedOGt@HS(B5FcaY z+h|*-KbzBA^P)K!c;|mPrGL4bl7CyW4Bv$BUTQ7^akPC;w4uG5@an{a;KuvGmr`_h zZsTcq{py(U6&G!a_~3ekkcPH;BZF-WT1G|Ne0-9O_DQ}2XdAiV@Um=>f?<2>p0Y0`^uInB||oDYC#H5$&gkWU7@ z!1Cx1?(clhabPte)nq_RpsUNr{fMoKXkjnAZQZ=iH+{RV`yyWG9hdbSF5uHm2s9_9 z?{B>#=7(0i0iJvAqf&u%^Qnir16q(u(rH=0XCR%07juR8n!{XBQ#rofk>!j$-u#~{X39v7pB4o> z2}PyRV9!*#wvc#;ijrXFS_d?QhA@}nyT2a%%$-aMDId?7`260uPerjr&+01!fuRlB zoWF^?5}n3dW>)0f=W?g*u|fK!`NNJJi;R0PGaqhGzKlotGhB22A3Kd9h=l3sB1Q&t*pPJ2Lw^s<*=Ubq3K zV$K|R-{4_fX>Wpc^?nc3yPpgt)+&t4@*35(={G-mtPIAU9}9Dgh!&j~t*dAgsN*>= zk56Y+R8*$Gf}wEIHAqU+b*X`XUbsBW(A<7 zIkpUC^(Xyjo)^0Ore7RZ9?$JKcd&R`>23dH9Ia7`%Za&hy??nh$=hhPsJc6in~zIm zrO?CRkk8c@YWD);Q8MaYM=CaF0Yqp9{yZR+oT0}KoRb0e!-;)sG`M3{bt?$ZP3r|+O4&bAXamTYT9u{^R zA##iI4r+zH9@@bqW)w)Yy?(d!f!hai+RreG(?G3;BhJ-Ec6w1I?N$Kv4MQ7GXdIzb zwq*eBuJLQhF^!Cge4#pW)XKGop+_TbQB&4@Uw>paOkV2uYjT3Byl!Z-OTBGLywXUx z)H#DA4ydFK)XeK;+zQYurE$YW4o{W6HN&Wn#C>SYcsX&%d%;$6uDT3Yx*G0)j?#%YuX+kRMUr{^7Ff7faS#E8xcV{!``Ni1ktbz8FK8_w*Nr|ws%S*l zY(X*xY=IWqdC1xNJ1QcaXP%{G_0zn!Fq57GI(5%nrwi)al7Ek)lCQ0#v$N};PvxI3 zdZF1Do{Mr0hQ}CP_~zDqvyYZ6eGmP)9@-D)h6(hg51u8J2=~x>rR&1~pwk;Ux)@89 zvix!)T)*#@_O$6eVwp;((LXeVv|0(JoTWKj!>{xOy^NWO4zs3|nY!O*tL`c{{)x9L z-8nteF(;W6m_BqR(+&HYV^Q+G)PUkl;l#6&xkgTp=x}1-_qi$DxLAUk+r-ku+meyS zk@pQt10T^7&cwKK#n!o$=7zi0aTrkHIxOO=nUm9v2P5qm1Tk0J?1|6mge)$7b?)1V zEu8uY^MQCyiHucPbud^hf66}+!v@@Ta&m_k+Slsoc2c8yMjDjYKfX|>`R{FFmxBO& z=yB1)#y5W=+IxIkF{%vXdEdsJ9D$RWd!Q(3eXDIYz=Km{rij~v2D^asl|__e(w#+s zXW~(LadtKH4}T+@r}fI2QhZ#nDmY9_w&ZT9`JHOpKNG!8bGFHC)&eCo2tl{q*K7Ww zdvaw2H^H|Gxu>O3l?fmVH_;Z2=N2P2Wu3hGdSVeh7Wd~0aS7M3qB(7qrAHLY0`}A+ z)@aony|Ro!h^63ys&hsvXWEJ|O!hz2>iJ!O-*bw;OR{gn?H75+zb;=RxGbS2mZl3X|B~D8b zk+S{-(+|)e@wwP}I!(WAG-5Y0Faw4FzA$oWNW8)iJ-*A+tGn`6SI{8nnZ$n3Mz@wl zq@H=$n)G$rHxd$?Kr9e0$ItaIub`DwTygSs9zp%zxSUTc)H%EHb$)0YSb`)(wMxc4 zIra~4p#)oQoT)evF8%;V>RabR6m9-qOrIc^Sc)m$eC)ZranoCkD&FOigUimYW|NUs zwcJ8~y|K4A&6|yS7Ao_;bB9uFUQN?rDkEt<|-XccJ~9 ztq#3itUMTQVE1E?-@19l{Kn5Whu84wSSPF0HR^IHtI9oz{lSfVns4-L!j-rvxoC4;q_Rcc(*9Pz=@ zJmBi2aL)nHbJ|rK)+)t{i`=S>ZkaxShRWCulMEu240iQ+UtF`;;7m|6pFiJeN$luH zMeK_9xAVg>CId5ZZ{Y5c^W%?qbQhR5&#_cVd9282bVsyK)JgF`i?0KZQs#(2O4UaDBi1U|AYkAp7+n`@JwXZO+3m|VEcA3KEDErnKI4E^$P%{w+ zR-~Q;xegk#5AZI}Cs|h+@tzjp5at}zir{=pMTR@vdi3F@yMt)v_!QWs+R+_b)O5+3 zn$`qgDC<{x1e(!kwYxj}p%F2U<#X7@!m8R4v92sJtAJ@5*OZCWZI!eh;}gdHORIKe z3#$+yEc)k;LliJFnt652vSy!z3iaAR3QSUtRoZ$?^bm2P`_QM;+KL(*h#CPY_6h}OFepy>+psx=L*GiyXf?SX8R84aUl8~XTQg2NcQxi^8C+YsQ?c{YnM)q9f2Ny%CJ3JF2$Jthg~Xw;@;z=6eF59 z5L0wOf`Gq}j~$MNo45@mLKzMb7Jn%j`GM=I`R*Avwm|TJ6y{coGmh$=RV$6v$aA_E zLM1<3%*rY+zGkLrZl7A&Hn11Chbu885GXV>7|KP;ns^l4S&n^VOuFgBtMF`ce>xBn zHMy%};kC~6e4QdqXoK&C#Rx zVCdLCSE?cO%6roBagCO&w%{HOSwQuQ=+bMmwNI&Dj`zT)humw1ekWt_t|caZ3t|X^ 
z7xFWQV52h@0lP60E1wI`kEaNS@$|2=s`g@g!D2`6s>_6*vxl`SLftyzX!d;B*W?4| zjRqeeiop&PgMxgV<`!L30-tviWTYYq)mcs8S%8%en-{uybu7txCVOh4$$q8Tol@h# zP}=R#&V{C$$*7>n)tAcyB~FTrwxdg2g>msH*{e!XCk7jj-11-ZPdwj&E>|#pK@RdL zthM$5&kNgEywMN)duw5A7%vi~$2v~Hl{&6(?|6&y4Qfdk!8Z%IlqD8^It_X*DbSoV zBRtcYoG5+V^+`S=x>?L=a>4YoR43-*#CI>2QpX8n{#n5#8paV1b$oj1WlZc|Oj0t6 zYfDUv1KoeZa0y~! zsh)5Vf3%NY-KN8~PMq4~?20}6K(S-fsmZA7tWzS%6Y=pVqktyny_meF&f#^QR77}k z^z~MrRK=S^_qfDFn?2kOLQ|M{eSi#?mV8$eKN6K3``D4mTDrxHYlV5DY65OKigs52 zx@73ok~cZ7>!EF<+$kf2Ejc0hrlQ3nyS8gqE{-aimlb7=*3+)WH|COq0D75GEe4u`{tvfOj&v4^+p$YP;PFCN|A44Ercr-k5mXW4z!NR2h8iB1NssqOZMm zo0=8p`6idCM2*_Vwo-1g6W{$LDKJ`}e}r9uIWE`V?cVSkS>FO@`rZxrxPDw|K)5t3 zBQ|ezoz#*Yqsvcs6MszIl8SE~Z^7%q;-c|l!CL)4FOfi%z1q?*ztmj&E07KN#6NKW z3@Qebq)y6s1n@{_6y{0qnD#sZ&jd^EAcS&ir@b-M7k9eIccgR?%imXH#dCflU3k^d zb-X(Ww2*{zyHwWWi?3@2-Sq0^Ww3tFo}ayfPe1jI=GaE?dqNu!wzrvTbZh&$oCkVn zX3YF17?T>W@5+ytv{w~dK8~+UnBMsOSdDxjX1Y|xFWj+R?Ow)Quf#zKCN_Co+T0yk z7BGT06nam{U@yIsRU_pt&Qsuhg;cVV`q?2oh*4fWcXlw(p(cTUW`x(n3WVzZ<@p_VDA;wg@0UFcR|?T zLvDNwOoYW+VHg@|+i}Na@6yM{&c|$lB9>*vv>ww2#hK>a_ov-dikBXQ_I#|yj<}?~ z>apqc8OME1<}KjraBT2#iCcX34|srN=VbPf?juDVRDi;n@0%Y-{w??f;wm~OJO%G{ zCVct6IBk=jduI>-%`d@*LdzpYtH%R9mOd0kj}ULx-toV&TFE=)Q8;QVgU!CA^a4n5 z&TQ|dpLX~ZjCOmZQ{-^BM~*kvUH>!HeO%icxo3hQgehP>+H7bUJH1_ixT$gX-H=}-82Ctn@= zKkaHwD3CAyk^|~3Y72`FhNs`Ir_br}arDTyu%44xw%EraX%p$ZSL!VH_{xVV!`d0Q zvtVQAS651(wEOHIaVdNSI?M0bVUa6%?x%Ovj*v2v0*FF*FaH+tK(59XERhO0fRUI? z7o6-c?$B{yd)N^*5;mt2)F|biGMM*ote%8~(L?7Puj11SI<(nH5TP|X<$5FXU}!0a zA93kiHFKScE->-e-+%zV)wy&Pn0AkOh>Jj^^I?~ylUcM!SN4UACUxMVvMY>|$zVUf zleqrSIsxhGYbN3T0f++xVnfFewT)~eUn^Lej=v5dQMbS6w{w+ zIDPw@sDr{+b7iR!uk&}^ha8u=9CpZHZ=9`>;u~yqwdE!aY}O@$7m)~`7XYaaJsjwH zc898p#_{9m!AkN-ba>p{n7)dSlsal_i?X~X4ijX&+?t*IKQvwh{f@!>6%Q^(_1^-FNFxznSkot zsKx>hdz-cuT&n`jc~Z8B%Kx29)m;)!R3|*gRR#u4bWGodyAyH;qtyYZO0o-)I5`+T zx|*^jw0#I}3PtMqEnr~ST4<|QgSfE0*zWgiMn(<-$quDr&GhXP4dE`Xf9CTm>QGV! 
zgy+pH=h!U^R#sp2^a0yc?Gdyxz^6MmX&?U|Y~zn?AjVP~Ke8D^S$KLlt}xXL_S~I( zVcJ|Nc3|YR2=SsV?PsrP+Q?`PGcaRe%Gs4lF3ZKeHbV3$dkjUv4tmHgTy z^@evMXZ&33*#@ua&qteBHX8r;|?< z{T0rWU7;q5ZLpEiG<{9=dO9E=lMinz+lF$BKMMz36}?z%b_+kUtG#o_>LB#H6uExr z;VvP}=-5^Gl2>9>VnvBm))l(61j)>ygm{#ah3|VP(>iolm5fwJrh%1A^Q*o|dIl#H zkrl7p)`sXi)Y-O^YOd4K@vBlFI?4WfMWjNl$V85!@$A>+@MsN7;LK>%=?Ze(bJM%* z>uh5WtIKfwxw^@d;XryV%SN%BC>d-GpVgC3o##p%hhDL8FSt+gbe2e=iAx%O9SI@u zGlj4a5y&h>xv5c;E>Zz}(C`NS&9g(CQO^f!id?1!aR>t#dS0nSOAaczO>|V=7&V_| z`r$@~+gXlk_9w>;R}I}-JEQnaaKC;Pt_Z1Y@W6U@EzmFaoWNi>_n^Q)?9tAGv3@@@ zZlNLjUCc*}h~pkK3|aG_h{`p6UjQ2RK7#WxDQB9%8>FlZ<(winOpzqw&o+srLdj?!7x2X@3p)m5L6$n*TGzGG_bw-9;JfueOkT z;UHFp@jFS6RFeZ3GbXZylvyu_0k8G6)_;S}Q{hbu`&`(e%cQxY$aF=wrg;v{IoQZu zzZm3&&YbY1SX+s-_8MXZ$;!Wo73&Z<1&xj=Ei#ZGM+X`col23U>#+Yy9JgiB2(jYx z*QI7xw~2>dXx!nUC7@$TY$~`>|1J>94t8zs6Io*LjWrd9g*R?`FB#6pj>OxBqm#op z^OBOe3o=-=Z<7Z^w{8bD*Xf@PU`}`~ zH#&?MjojY0qoS9F;7WOo+LX-~{TJ)18XJrzx~XRrj|8R(=ih8mnw)j4Nn0$uKnwV_ z`kfjPSlN61@b#6cAuBWd-&gkQoxOSSN@nN{koq*<`C69MQo7O!^YoG|Je12n)NHU; z#&a{dsk4)uB-tVolQiC22maR2)H4j^SFq%I=5CDJL?E`DgSI9WV>>5KuzSKdla_|% zw)?*Vf;ZWW8vN;Se<@q-HcYj>7Pa`;n}nV;k9kY2=CJ25MOJ~M#n=n|bv$U8l%rHH zgH8+fvn4&=%f5&JFf{tbH$S5Fpps|NYv;}^S$F=^m#%x9l90iwi`Fm!KAn}{xb`;b z2$1-O7Q_FA`vJ2izMnAa`BpHImG)bOI1|}DPem+FGgnXo)HW$}6bc}2&dJ-8u7L9d zh~?#Sk$)wDYxJ%n@LZU8 zuuflQAg9P3zxi0s(M1r=TgM`fZ~#*JH0=GHXl*Sjzuke*~q#EXewIn+v<8zIqkp+G=fG z9$BJR-=@F;)#F^@#m;b+BffrH&JwZR(ZSZFR6DVb!Q9cpn}2^4+5IJdRHJ^U)SnL; zJJ(n6$Z5gW=zx^~F$AuQa1v**g|uL%E*o4?*OlkQ5>eH&ZxEs2w5|g(}2s zQJq!ZFgI&vi)_toiO82Y$00g0dvCj4tD80cgB7nB_W2 V5}kEL`IzJkR^ef%>) zL8p<&%|`alu9JJ?cwDZqUASdJdpBnfOy6y~*Sd)u5}@j3cvttS^{=ba!r2YevuuDf z)Q0mcQtfo$yHM9;-c%naX9yJQzK;~Xm?UFe|9qvBxD}K=ofdUw6D`{Kir~vckdG_O z-lQCR8F_ZR^%W;6@ih^*Cn{8ztl>1M;QsN{m<^42#tByPfuV56`ohkrLPj{=SAZ;= z6aKC|>lYOXxpa{lER4@s(8h1!j6AObwM+_eLh|wQP;PV!Em(3Hr!WNqyD2%?nUJ1cOgA_(?F)z_q<%equ(0>^0ahx?ML z1M2gy*sVLhLSTQ*zE`vgk?g$%K2s&fGUg?H3YMTQx{icy+5H+GB3mO!)&O45|DCUK zL|A2rj0;SA8fDcvQr2Ey2!eQWy*~R?8bR>X+a=3M)FS>haKYVQPOT&d#>xQCS_Tsg zDLH{RwG=AX0W?~7% z-%olkY)S92)g97Z|3IXIExMZIJFLNE@QV|6GW*Gna)Z3-7&)$E)?;A98kO7Mp*CJ1 zeFy%oK~iqdO(G{ zNt%@&e*1OMn``$uJ@`TW$MMYwVNGQ(0v9)S{$L^z6;}K?Si);H{m0VPu;#x7k-;LK zFmbN}hBsnn1&u47JQKD;_s>2n-!B1^Gp*q=iGh5FZeNZHnO#g_9^;#K;r}P8G9fQc zKH9|jj$-C)l{BT;^&yMpp(+s#F-3Gyk3ycMjc!<6GM z!0#wtUfpmJst{z2U7;5+tiOmZ>kPax)#~Mb$`r)cI({;g^eBQk;RToSF7lHcJmq>2 z&X=<%5phEU3^fGSdG|GFi?d7WSuhgG=0e5#3O)aO1v$$oa^gVr`Qnc)e-)4)oGEqw zm8r+^sF0!rp*0Q-MtcsRe;o>vDiFjmsETacPQ8b5{@3-l1B2~c&&>~$=peG)CoaN* z>D^SxNuQb>)|vTP{nkxhDYEeKKb~Q`M+GsLM~QmSP?f6~wYG8AdLWXorAmD_8P2bP zZX>O-^Vpvk`^CEJIe`6T2GzHxQqHf8(;jlsn)CN{Cbc;Cr7=kK^_@GM6wbSLzbwe( zt%#77i%FR?$8O3mfo1RcFDBS_w&&y=6WhS@^}t3qo?VpTfl541N!kqsCQF9?1$9BQ z={&M>F(;*2xwr39VLrs><9O3VY<5Eq*Hs{r9X0b8YCS=+V_oX`7CzxhMaY?E|1w%S z9!&Mq5&M$=dNi1;|Bt=*j;Hbu8^=q82n{JA4GlsvGQ!c6GD0M)p@>lS)~V1UDT%C# zgd&;A`XnP|BqLiz_R7lsUN`4H=e|$+JkRgb@U)S}%_C-=Z zPxixxukN~##|>oi*g|;@(yqi}UBese;%kDZyldAJB{(EoF5gXfv}?$8u#~25j<3wY z%N$-m-}TT5i>74-dLz$mG|>&Mvwmsi_$EHq|6*N8Ypwo0?gzL5r_(D&aEMg z0YB|HLq$9*e1e~&KX(lq9Opb76J4;BM;lL8A$wr`F3DcC97_{?MCVKI> zO4)UP`q|ihlkk=)!>Pho^5hk5J@e7hMH8e$tc6=CHgm#qbauG#yELU4X|h$jhwJQJ z%v+>;l`j*_XdBC>GFwF&n$C=n=wXg=5N!}iDAq6rm2#1j%~?P!aLeQiZQ9x)z965G zW)0gzF^f@$VuQV+uU6=i-+z7#rqzh*p)o;x%qh#l)thQ6*K=KR7Nk?u3KxY5ou z-O1#tvyX7uShR(3l7pyyhjTgok^ZZ{_0@KxXPfjm=bozE8fZpt>DK1OZFd}T8B-AQ zX!|rzB0UzD7$Q^w!(3%GM6gd6GrZ8v77QIP9+}&E zCXMAo$=&(4hpzq}vQult;s^qJVhdAuw;NwVJ=2YzuYr>%jsbO^rOEZQdw;*$`}u{} zI=B~nB>}76C+}0hK7DR*QrAAe>ziQ)SA0uj0!=l9jAjO)iA zr8MEB>FaQb9pF5_7Q4BxHB%M_^1W`_@h=2sjeGa<$C4 
z)UTFJr=nlzf5r~Z+b;uCHt9&>pt+gWZJ?#fVI-0}@arvQHj`b%MQ)bIGscgp8}wA! zPs~ws@$$5ul)BRsBTZR!NTFm1YK+x6tm}o+9WNb?ZffRod4ohLyCSE^a9yQMuVM+k zdAU}?8y+}|h+u_EpsrK^cKSpVl&Gd)2QypY6x;OO-gWtw9n+69U?&~RU?(P{u`hc= zL$4fDniF**R8E!zOC{KLSvD_P6TzVLDEp{Nl~4pwP@j{;s$XTNH-&CLbG_l!_xU|F`j<35|; zQZC<}p7PGF+LsStx!v}Q0XnGnisKU%DaU8NTQ)JKT#8Kt;b?fdLqk^X={}JL%MQRE z6j;cIKtwOvx7HkYk1e8;Fn8nAim&5}cQ_MIlq_LW^vx-3a9{hY)?D$*P4y+D6F&q~ zA4#cvhEBb7nL~HA?)NcogF`RYr&%b4`_Qb&@8y~3l}+-ow-LZXnau==3W$|C>56?S zrgk#-WBmK4sy5k))jo@bzE{yQRBYc7GZB)}xa0gj2FNjL-!u0@_!5|$|J4yPm%lEq zYH`ch;8%%(96Q6NW6jCVeLUH_v%JGQWNPd^b2qbIvBkV5EBzI+G@x{rDhMxddu~K_&5ko~6b{_E1IJ zq#im(jWq;2e7|w2Oo6^NOa?t>G4guwR|AR|9P5X|8gtF6ZtJ8se0{wk{(=hJ;U* zOB@cyo96k6xI;=1LvcdoBWLlT+9}5S?5pLH zHnKSM^sknx7r5@9x04i3Gq?V205GNVKkm<&pyw6~asDW9`*>T`38ISv*A_-G#gpw8%T^RDy_Mi> zv(5n}vJ;3@w&AyekU+b*A|Zez#@GURzSRcul(6iYK7d?m%F9+KJ0Cn!YwVC=8gWfq zBS(;Q=5trTDv z%W3p$*Ri*Z4&pvFAtcR+A2}D=!-Y0r(-PaWhO`<5B8Nopb>p|$*^`s`A#?>Hn+W03 z#JdfBjmZV=4X3+#_MS1aFN6CPOH#EU1jJ4`Srp(-hu_yZAYNJ5F8^7v%$dQe-a8(Fn zA~vb*(@mE0RW?itMu zdA%-ss^@p{VCm>~d5g*4k`_NdY5A9rblKTtI~^DLB>qCCp1gO`FiSx0U{}9{3pfBZ zQtc!(?z%LP$58+UCp`Ig{|d+JU^oc{qqu+VnZ4oOv2)>RWg;SUH8N?rYekElXH;uHL! zqsHXyobxcD#LJMR-VANPZb6ad{ACBZmKldSo)b8nfHmzYFK=HrbRzBIDH~f`@$qv6NC?$l;V4|yfp|heb%-yY&2Wm>QLz5-_n9-*W zn0NWXTpsi3!*=u$@jC!3&-^xd!r2eR2t-zVxv>+KwMk#o(ORT^^E?h6p4^O}Echh| zFa1fmA>c>6k6^f(?pS>kvx2}TTzSOY4E>t~r+;Alr6r}6*iFI7ZKHv5@>;tGKZ%Yd zKT`?egf_w|({IR%20;HcjOSmRRf{^w!vkT$;);}2=*i5{^lAuIt!|#w!?Jt$1*cGd z&LY9SycC>4D*g^NSf};ny#ql#@K_89;vTMIPh#t?}SiUZB z+%2Pu9yoB7<n?OdED1+rrba4PDo z1U3_EVyXM|;6U?L)}QPU<_n!p*b|LQOYV?m4lHPpnOuSD#!V$L=4RyS2gFoo=dOtj zUs_0GzWo)HSUwJIXbi%4wNy@s?AifYCC8y_j++k_qcCK}J6|3Fh%*-e{M7zSUI0I| zJut3!_~GqHY`hGK-;Kn)TO0?Cjg9*>)=^mU%9Z56<_n}mE`9=Es}*GxF_yqT@cH89U7q|8rRQINNww8Xgm`9HQm;% zw4apmoA*o_W8o!u(n80mWx^T#drE!seYn>aa9JPmv@;gKAPqv*x(HQwMm6BhgK&Tr zWjAiSrgHVih`m>#se7JnwSM*J{WBC4#@n`WW1u_cJqk12F!ef(Jq<8k=cv#0 zKk-lvGw-n%&k36>3{a>eC{Tuv3a*b#`*MyHD4e#d|}e zu~C)@N?0ElHc*>+N%5Q@p87Fu%ZHT?zdmgF0j1}QIqbil^2&prm6(`UJcwo?Xaqh& z0h!Ds=)j5jw(xq+F>QVJ(Rypung+xO=QiGPsp9BC|I89v^oPWemHj^8IiL8K@Eh!y zB%r=VxLt@rEh7#hMud})EShlgpDSOJ&TEufjC;Kqw@8>d4kQ8HUDEniVq~jd06E8- zn%rRLiqXGG)_~AMBcA~?I5N!CT!g4-ewZCe;?#qdS?oXX0X3POv+YQ?Ht!pV_T`BcY+)N8e3@v>Ag$$D( z#P3_u3+XK61TXS;m=6^Yp>yn3kQsVBY0@2@Xf1%bpAi<(R>d2xgVsugO>fG5wVP%W zZ$U|G?Zec0Aj~Go>`RB3dusb2N=tK_WQM?EvmH)xRFJh@nFbB{e&nyU{M3|Sd0D{y zb&kS}<1G$hA4q?p5e*@mGfa$u>DkyBLX3>7hBzKJm_s6aRq&wY$kPMv#$ znJim*zy&Y6E_N(?qmSf<61d<^g+v0-1&1%4Jt!mGK@m4GzBrz+!NB30N2njln*541 zn5dRg2B;s)>bVz1=((@U_=)0XRqw~z!om!*X5QMg((_BM7B|(@vv06v|0L!?o;qZ9 zYIEwT$%3XIDrov?hMqtVk94j?Ily7p9bkUOvK*ooWBtY4r3o%$zk&k#)-mfQnH=SE zYJV)=R7$zOYd{QBGFHVTT*12iK*{ruuy{^h-5gBkiwmV?5Zm7jwcT%mEcxtbGF+ah zxKuo9y&eOi4`3TnIKRx<=4+TG1boau^yGDgznMHGlbJ-5@*^X`Cj043FUO_JFR zoeomFSkTXNKP9eO@j?d&ir$QZmIqy$iGdvQ1jX|?%Pky+A+KJ)q`bI2LXc!KTqN?oXI^}35)vGW>qhu zO_HIzIBzfLLzDGiRAdB4pyA!*pgKLYIJ+2q>dapzI|Y+5@2bYx4+ygfnSo}}ODPf$ za{jC65Sgah(B0L|jKtV*@;iE}vr~o76GaYV>;ys6uZH6oZsY~?)@Pxtc>W*?5Uc-8 zngKw8hN3OC{s;N>R$8ri18y1n-tNGZwovDMxb%zkPY0u%3J#wc5c_QSjTAifn8Q_E zDe*GpD2~vm>71~W7OnoW?HtL}Yqav&63k1S{1-vmFMysT4T<^hZ>%^2M#81~rNtCc zcJ}zGygle*Ob|oIm*;}}NEboj3;6q-uGwEsboRA1KvzYI8_n^C7;3>c9EnjVR!doo z;WlKbkv{zISd?{JiMOC`%FuQ++(Gm~ay~B|G%I2=tBxyfI{QbY(|jAKL13UF+=JH} zcOp{7x1Y6~q*uZZ(KLRJJyd3EaE@daoz(v`PMR7j@bX^-H6(^Aw>Wt&+hLYIA^U+= zzwIn-W+!u}-8=)47#tGB$CoChgbkRapeVB13M7Qyz^^#-4#YubzC`Z>lqFGEh)E=1 zh`{hmDjgj|o+f_qi+kZ+8Mni=LlhdQEJ&1*czsVIG`8#D`M)pR2`mnkhhM)LT}BCM z#V?#R-b*CPIWjAm3hgmxrwsa<>2jNpfMr85qR;E_JOcAO+U3Y8HvioEXM`FmEvX5^ 
z!krFDjLHFu52v8V&3lUmn$=!M+C>_hL7OS>Gd@zAZ`CDj!H_G!$rp)~I+FBxGX`c< zaE=53H#IJ*c@gXG9YX2tPli^;<5DakQtBK~9yZt# z1M)e#uhjrV3UTGY#J7UBj~ts{`eUf_j9r)-5@D2d2hXRl%Oz|kv%s2FK7KAzkE^)f#&`WCH1h z(wXQ+kPn<&`zHrW3&7-b*YMQL)F`B~A87gSHLf8z1^v|Iva7&O^fWP16mQ3=L88pg z8TQ;Ht*ZUwt&O^XI}DS2rZjMe_D| z#)N~PFE>E1d#GFmhkm@dFxnc6xR)sOf)J16rgmvzC~{`3mP$TCl}}a2h%~%IdBK zOGY>Ju-2)6CX1fe+-VYzAUeVmZwXZ2|KqsWwt`*xSnZbRr_?&GBp3?MG^E?g(yasQ z<)zR2IU=JRaM6~agGc~rm?nDG3E^-PLsMm;aM&H)Z!r}~fHwUQK2@o3-sr{<0H)H#np|HKvu<10cW0!DnEr8gu{m}Vrf;|CyQ!mOQq z9dbDQr$W3=Bnlsc;7Lfl{#uEK2RQD+n+z+BHV=7+dx}T(*NEG4n1w8!gb%KkS@Cc=xB?#7D7cs*XP^8cgG9YgOa0@A(&~hAy373SY zd4b8~5RV>2_Y8R9iu)B62T4L|iirEzVN9CHFU&y{F%^g>+jBg?G|d10p*Xh47@n+n zeb0cbz{Ls)LYO`Csv#2k_fIFHNKc7&j6{9G7KXP!*qO7--6qU@2i}a79&ZwNr&WI+pa2) z$io3zefpI7pCg=0b!&+&^!{#p!jES!n45$k16NJ$V^h_?|H267zq>TzA1kOJb?w*R zr6pb@#fR4OIo;ofl7(M5?YR`kAZNw~Pf=Z}+_R-d}&g z76?RTwCMaRE`|8Tk2A;?BsGnbqrDBMPM?1Ay79}G6(ij>8iV>wxbuW_HioNHx~tFT z@8H`EClRytX7%R@ZR~T5U5x ziksPxMxU)(^LHrq%PP+W7QwQAZgP=z3}3|ClA(Mg6E~BZ5j@?XYS@Q+ik)d8%~2*j zt|3k?F6Kz}ZJj~v9}~o)ASG{zZe|Ye zh(~}pcb>9COOt^F)K04sEV+z@TUzs_8Mv%Z4R==JSkoLbzeasDBw&08*xWQ;#ZSdp z#bGU`<;QUlBg2|V2a~VLNkHx5q+gK;uZU`>K_H^ysPAjYuj4*c0JX+!F+x*}Nx1PI z902U4+tKJ=ltN8G$g7Muf0VJ(97mIL~E8-q_%5K#ignya= zC(&wtx)b%Ezu!_;UCr_A*)x6*$Ya{Z%*^~=I|7@VZ}NsP)+8JzZ5#8tugG6uucf8s z$I@3=Uo5iFGEPAMD#=bSuu(m{~M&3_P#z=h)s&VknK45^A-3@?6}W1ty;n* z>m=6Ql%1Xa1Nv>g19xQXsL*^XCQ(s2!mNxRNaLu8cntuvoF|U#EL_9*`-Z_9h1A~< z2L=Xq?B1={V0-c6j{8cHotaGYqck!+(FPz{9vDewCpks@)%4x{cmFq2FVMi9)5bev zaci8T(UombvoY8jU&{x#k#KI{XgzAlTkvFqVIsUpt)s%lOWgLqz0g$J?l=E!{PVQ5P3;;`qoy8b`pF`cRi57(~Bg^nbFwjY~5?3Vz17R47umN7Y!K{v^jS*RZG25~=;q$Ht&C6k_Ld!mN z;dd9lf+jxWmy)DI|HY1v`#Me{)hfm`caup1!xezXIl=tH z*IwYRkOf2sAk6Q9fA~6-VW=RnzYz3EL~WtFxSdL_D}TApBe+? zQlQQ1v{}KV7EkqPO`Ot#m#u*ZB9N3$tiT7&Zra{oBoKe?Dm|-~ zO#K^cW2A%yc8w$&{jSF`bkkZWh6ASzv#_@c*iqBTk3tsnHry{9OXtam<;Snc_{!)HrK{g!m;+t57j8 z%vl*B1^XDSKeB=pB^A9Cy%juGe9d&WfB5j>^PvWkg+p@DQY!A80erCata|6)&shwy z`_LND%d-!4@@73nN526t^6gYK$8ie6e})E;8YOo$NHt2pR*CdsOZ)kl6VC!(fnU9P zk2G2C#_X3or415YarL(kDkN<}08z-BjM4Q00sk|m$%C+l-73n(q*;h*z)QXdd6`ns z9ho#|4{`x?9H+Yf?lK^n2A7;m+KJ+x)FvJzViT_ol!KzGWF41|(wh;IcodnEh##JS z^n@rL8wn1^8Zmfg2OoSkekX3KDR<5(geY8ZL}wD!Mvy7m1|zHXb9Teyo!j` zAVjAV9}S1Tr~`M}+2|f}^4c6nDnqtx?Use26baV?0o%{M>t@5_D2Y@zf$hM*UQrxx zm0?7J+ji@bec-e6QX7tmDvTKe`D?#Tjt%<%(1u7;2-|FkRu4vl1pw&?o`4l=PRmdk zLB$Zq$fj^yQyuv6D)_O^Rv9!1(l>=*XlVE_C@4nk&G|3ScI?>UYnPjj!Qb!=WThsK zq||&M(--N(U4YbnX-X^v=_CSSDp7GgGDFdW8>zbW`4~R)2nF{XWfAP)r5YU>DpQdj zq;yrE{W6A{PYKP77Agq4uf)HMV$+<5;yt*Y*+*Td618agVS;I1@)pF96tWlLeJIgsg=TOAX`hPU=!H*VZe zYLbmgc(&Ak9zDn!Z?8f=bKnU~!mup+DgGE;80?blF*2CVXv7r(wbwO7bgn{2-% zaKY@CK9Nc%kR6uc%itC`a6f~};EW@Z@I{CnWHCMc-e}EE0-q0A?kS8#_Piy^0=MU_ zPvGgYs5oQXQ=BAA)I&t5(?@QiL+wKt)35`s8zx*X=MZcB_ls1E`@ZsFa4>Vhopn&< z0CLfhG>`w*+x>q;2ePK|f0^h&-q-AhLTH>KaM;`!uZ9tkws=8O#YFm>p6sk) zuLB-1`3YZtoL3DA>8D{eM9qA}RX~i%4CNA-J)*!sC!?)Jf)+O)ufy9Dh+2p(ME#2a z=~d!s_?z9jH=Rmi#Ex8u)Yfo2hi^K>w)GV}C3?l9_E5qtIII)2fddkxE8nT}zjEbM z$2ox45)mh^7zi#-#%C{Mcg~AcyS?u9^v~k0D85077g#1=U=L&&)Huof-IGv;R6|nf zlzD`g+c2~FBJryS$^_*Ol*Q(Sr(ncVGEblTun$xY_JND{@hS%v#3ueQRP3h2255Vr zYZ-+n19a)gLZlJ7`H-tIS!~6uXAYfqA{q1y>3Ick&Z{f_C4J>O7@#)r_>_||=8q3A z5#0ATN|^xdp{oalDd8p!M6fVLsfNWRr9fWX#Rl8z$gj^r@vjq>10fx{eH2?y0>4(G z2P|OK$PL?LWdcpX))pyuY6?Z5TBL1|+nFs?kk>-^hf+CJFle|p&Aof~zCn-C$`#P=h2w#c{t28B z@4p&(KniOZelbs21FrzFr!Lj&m*KvPT`yBTbuHfTA?3GDL5i0Sw_zIe<6W}LifsNl z0r90+@6+5_5mS(Tc~@amOkwAwT-hXYnem)freGO+A;`mbC|L|y(Atn#&GB>hXt*PB zw8>pQV``+X(i{2|uYDkFXn=*lobPZw%7;nE$jbu<<8bbe;7!N}irwr@j$I-$(?K41 
zr4QyG%(Q~GM}0haWd3s!(FpXZ@D3I_@&1&+dRKRc5o2j=zqQv?RoXB)t4&zEQC+=8Y>IAbW7ZC9wrt!!wqRMaC22~dzkH3YoQ5JSkTm)xnI$k4EU>{~Z ztBjavu><|~aQ9&AHW)M;KxNL6%M#RM$Osuov7wG zX$H75HYOV&oMYtX{u!SRb=>?H;vH|IpxF6v*pxllQ>>x-fYZmKTH;_Ylkt;{T`-DV7+!NvV-BZ(oOsTCgs-F#UovVm z>x?llKatlYZ`WCH6&d?K@23@}+pW9K0}L1D;C{eeE>d?{=`%qcZ;;Zw(z7$b|^a0wO`*C6J&BbUO2 zowA-jfBu>%C|N&rW1;Y*5Una?1l5=!SYv8u%G)bXB1V~@!pio|NC*ga2EhEdmUv`K zpwf>X1JoB;7-tV?2AckoO@}tGmO9)^YfZsbo+vKn`2)Njd+;7=Wh_ZdLzT8oF^iCG zkZs(d$$i-0ar}A6w}lXynMa4@G8&Ve3C0@JBaug#wAFVT%koA}@uk@hwsoVN+UT_~ zws*y<^$}=tind}a=u+bCE8fk|vvR&9s?hEMsXVLfFee5$B=34N5JgRIayp3iI_$%L z0RiSE!VsN{y;Kt-|F*~@L!{3#5X4m4t`Rp=nMQ*jE_Uw;#$u4l&Q_)R&_A*N6RZ<3 zjGX@JSN3D#o*GtgE^$HX{t}NG_H}A>ZU~1CNU2KFH2ypkTdhG!)Bfoa^Zx~oeI11s z`2%!lp}FJS&Batx|BD~P%@YMSBfjjnZMQd_Do?{(!y$rRJo>5?b(Spnwh*$tjy*Ik z7mW!%pAIXqU}48}D|0Cci^;Q(e}qPt5C>pL_qtJb;DkK?gTe1Myx}^XSdlvtS65oU zF4n`-1FmHD@fG`}T#Fz6=ivfSR(@+4e++CK%t>Zt{!ULPsb>d~4XBk2;7k7harWME zJ@5bf_>rP;lCp);lBSYMBq^GrP-)T<+G&>(l0w?0C6Y*~XfI@>v-_H1h$W)k zPFl{1`1zB00qnb`cK(`G_p62H{83=%6$kic8E8Mf{-2roBTmb2EwYF9tnOgBut@5j zz7$%*5_Nn}dAT18QlN=*dGhy0$@-OVgmZLHPzSkW#fCuqhL=}~6u5tMUYa6f|9sa= zytiF(yJ&!|(@5;BRUbhq_&V*OpY!U8f7Kh^Et|V?W229?^ zAE7nAh0vQ|KSG1i+obvI`SZWK8QJb(780*%V{X>x%*Q)yU>m;vEkP83O0t?E5mf~crqH~O3;PQyz7)vG6Y92FKT+^$Q5&R zUkUnuNF*{s{p8sJ2*t76_tldOR;vL$5yp#o&h~yfg29rj8f=b)Vw&TL>?!YMhgChT z=M8$^2W%*vBSf_9j=8PUtrKv z{2|p?KF+Hm_TrnB-YXTKdCF``H4bs*R@|&07u$2I>im_<@uLA8cMO{r8kB~PO*D3d z)lIsMOdGqU%>PBUCIG61ji5OWwT|Q%fuF3h?vuC~;;=8{UF)=P*Zn(Ysnlc8? z0?)Vb$m%1!6>=InDcmJ5(Yda^LHW(&C7m=L=-oX?Fc6!&;UDhOndhs@4Sf%iDBi(f z$tM*Xa(dmbmS!U=b|;V()SLo3+qmt|JI;pS&06?)?ERFX2yG@KOJzyW603jWCHz!q zLxW#f@BF-!atKlnN(BUND`vM!4LPvie%)8lXW^hsKKJB)se%wqAjW@`*rk$AVo-+^ zNndiE2v~@+|?yBAoFvBA33)=g6y>){R zQ2Q^hkSRuAeyA&u+~c`xaxFhUGcHx99~zOh|4T^@g8!n);K{s;XPhsOg=#wb-oyR$ zJjOijx{9CGKW-+LN**KKsH&SX0xDTdrE@gHCbB#keOwRr6bAj}t(#x?OAzTXR~B>z10JMjtR4f%JLDJ^&Q@LaewyPi*h$rr;xTaU&baB<1?WfoV3 zP~sV9R%Jsc^iG-8?PNW~eVZ|FTX8pLllX(X{wZmfl_lalRg7LNVjUbM6DJ_wadWt6 zVN~<}SoIbO|LIpmNChKuYne0Ae)qCTRIhH08JHKbGe_A2(QKj!ngZi}$!`HN0@0aQ zd8s_uv8xLHUi;?uTFwJ+#9jD3CW+n2vEr@%mRVai>W7adZG|-#QOVQ#bp}z>EB|5S z+`Sh0b++9Sy<|%9%?_8OpdC);NfMF0W1Nsk6(^$qqb>Ae!7uI#3X-`Gm|@+kNnoiI ztw6Ke%7WXszyBa_2A$C0qsr-ZpwMz8%~O#({N|oaGISf{h=6#l*(&l&+Yp1MAsK#$ zNSn{-(wICVUQjZ6I?VHakw{Zx${8h;rZ8%unXBh9$Vt1)B(qzdlrHq5hEce6vsm@_w5ZU6e0o#stuzH~tJ?C&^0TyN zp3{@9O470F&qAL*eM(YQAI@c@Hq6=+vitga3FDg%ajS_?Jw$@Gz2N`q3cpU8CU`se zb#pZ^f$_MvvvjQLK6a)dfjAjjcN^jDvN&Y)8}4$L(fc(CIg6{x&8sa@pUWdt^6hw;+0vV z?|3N;Z5BQ%1QQ`h%NdSVyB}jt|Ax2tJ=??^?)4-ufMF%%3^e$m1Yj^f%@{*27iSRexj-wfKd${ z{F3U?!1HZ;K2)YA0i!9wBb#~WtgC3K*``^Vm8rloctIv^=)h1_(`{JP9ZZ==aDGPO z5R(Mw6MZ_A{Ipt#X$)ygH`I_~qDAg?R{ZSOK*zt86H$9e*BrU)iBY^d#PA--LXGMr zzftx?tiVdV+P9=sPHJXYN9&=OpRbm%YBjopI%k&NYr7`V$KcAn;Kd+K1XG^Vj~>%p z!#9H_WlQGJ>%suyEn9P**E7iN#!6L{```Heiu~nK3NONwb|;I%D0Ej})VWG;e579h zL6~rd9w;4^4L7K$sC+*_mV=KBrot40r#)Z3vRE|cs>aID6{8#Zc{J=x3H&RuF{E3$ zIC^)%1J<3Al8`}OBXTOxD2m7tPBAK+DuEV|EK__$iR(xM7hXBf<^z}g5Sc7UbU}ij zWHVI^Zh+jsqPlY58d38IjsvkF7mGUP6CS2$qr-D;6SUD^ZQU%~%&{;fB_&q=JSNTx znk07L$E;n=ktg%9V!hW=H&7?y_Gl_!{srqzvo}X{zJpi}8A7{z*`uR#$LjT8SQVka zf>@UgNa3fQ19t^E=Kg8c;Igd2$NuwTDIj)X0S%ciffj7hWe`eaL*?yBPidSt&CFQ& zIYW?NUY6t1gm&B+l42?uy85Q^80JeJAveZoJbX!QqNcE7eS=TJL0s%x z;zq9Nq7l;{B_wg2oF0dZx(_6ic#G5CPanfH#hjq`7AP(w+M6 z^F9}jHJLR8DcN%kwn@i^2sG)uzf^b*o8BV&lZwayu$?h)a$}Lk#Yfx*~SU@k^DKV)Dth@8^fpTiNjRH0%SJ1HvAUDK9r-m{5_R>el zNW#PnnTZj0;&RIRReY0w_fcJVzg$8}{E^@zTPw2ywRYqXU_lVqcDB z)LZ+cJW$uy*T1AEv5VglV=K^^6fHi(EN;EpJ03~^U0ZWsUE32I@~Uwy=Ow9ZI=CJ4 zz$L#F10CXEO>`aWLTz`@c6OAGJKD4m4^o025Js<3gQ+vRyVY#tJ53(Ks949J4xd=E 
zT2N}p(er`-E@wN;`d+lc`kgFdN?wFNsb&J2sdr#>U!`BwucGQGk6tzGC6hF2hzKOM+LNN6+q8h(jfo$Y`t{c8V#(uYgX{D(2 z261Z{62L=>Ek0tG^HI_;5h9?MzWzQ6PKe@m)or(+Jr(;2V_W&ura-SFvbQO)`km_( z_wL=x*MOL3%ec|14*w``1(luOF1SFqhCakHNIy7J4SUe_d|I^0Mfe*>y(F2WSdSs+ z!PA6ZcwKV*HRA}00eCYw)Tqj4evsi8s4E9>*B&U59crSTlCW+mqO%D_fuE?szRgbF zfD`t>KKLx_z;&$kjE3c}#fSE%x9+qz#C$TqYQJ(22cS~#WjA*r*Y4;^J0r&EiD z;tbJXms(UU@yMin!6m!(BuikWRgI%4O5D1M1qWgq9Z!bhV(RtNKla`&B{ACNgW&;a! zaBN)Esi;hbsM@rRYaYRzC(Vw(=IJq2L}qOZ>qZtxujq`YlMS%e4hm#uW_6@&OMTzGAE7(TrWv&A`OOMqE4qXbJ7xW_9U8_I=u3t5;CFC?gQj%ekGn zaGo<3x<}{opG_7~yg12RP3zldvlSz%@X=J)3A@2@0#Fc1yF{DjV^U1g#K(0BLJW|I zy>j%K>q9;2d&SM~|+2Q_Fx%O zJoP^IZYdU1ESmq1Tfcf-DBfv62A}TymL$kUzTSg(MMc9I7+mu*R=s?em*(3M;TrqS zucd-yFoSZwHGlTakiM@$pIT1U2lmJj@qv$+(R3x=58MTBm>w|!9>zQp`edhRp1RU^ zW*gv}k_61|3BiriYO(Xy4eOnaVvb1_EFXNsll&S(uzXL$C)R)1%2UnhcYd*{s(Fh} z7eWTZLic&9fJv77wlwG`FG3b(KCy&jQhHu zcNZYO$;UoOgjIh9R_&DZPdf2jplC~$YKngAZpu`VL1qwy%N6ly@A8B3f}6NFSWZHF~4dD2p#_Z74#4&Q+%w-(a@! z9SHu=9l1X02sgAw%rd5^s3>%>Qbujc|Jv>qaM)euX-ChISJCU=+5KzO5fG39>JtVA z!~g~ny-Qy0-9EcqMjs)5uV(GjmN$)s^l6qMtXtXhJ6f(}knGe*Q;x_6Z(Z}PMkl$H zOJUB#+t!9BlIdrP5oG-t9D8FUvWGGRcrS{8U&BZx`^-l%iTSwvPz>Gg*#y44nz;DK z63<*H%bI$aMN(2X+pc%ZxRi#XhDIn+Ua|FBeZ`q7d%j=a`S#rT!BH4obj-2Eo2!1^ z=0il0t(Y&P-hcD=^3S-faj5ysP*2<2mhT$T$Y+RA-cf+z+?wM#i!iR8l7lMR9OZC& z8?{RdZIt?9Dys8bznp|NC7+J%-zY0&3=hou7hbjb;}^r}m^edgL6}H5`79n&=fFJ{YV|MKi?ypjsp;5_7yKCQ;~4i-tJh~ zN4ImVH6ZoWyDf?)uIt)KuRoIyZzL0qc5i#o;aLB^XXyaoWo}Q748W{g*O_H#wj@yr z8V74->%slWO}jf zb8%8;K&Md%Z=9Y|h1ZMX73$u)PL06X?+{M@-^U-%K)98^Np1pExe_*r5=va-8YN{4 zxq7RpW{etYMO1K>UakDEc~d&oO#v5#P8898Su9F7oH*(T6tuG770I6jqCrz*lNUv3 zZCn~nKP@@qrR|gTRYlSGA?;|SltbC1s-XR{0xi+Sb7sh6XjMgI36ot2_o3JePzPt5 z5_onQDUs{XiSIMMIY= zjU#%Sh0PY$-+V5`{gJrFr#XuvmFwO?M?jVjd7Ph1_>tcS!8ABD%H-+q%VzT9+}72n zS;m3tO)$UC7Gzpy%~NZJCjo=LMh_LQ+1iSXVA$8ZbhGfkj7K<+=H>ToIeJT)rf!h& ze0Gi&oVP5~+ny^D-n!Xr5t5T<`C%j&BloLvB*+NtvPU4R3bxUmOMYzouU+m-x_jD+ zANOCxZl%pIr?$#Kb-_7klx7dsT#YQNADc>vkYjm4MTSe9v{si+&@0FX7mM>&c-QXo z=-mQ95mvZ9w~?-FIv#m@H&F05AJrWk>t6z3GRloKyn(amh$+aBk6m$Oo_{HCO< zQ)=|8)4T0`+$58tZ+G!)_XQ)^6X&|oGif7 z_qDUEV{5AyK1VUbwlnx}th)3oWSv!}4wAF(4JvpQGp zl$;t;%j!R6 zAfmp4nVC!-N||c8__25?eVN`awhsR2@18-aGSrjmESvPW>FDexn-)z!c=~3};sd2a zWkmM64Q$VC>hT}FJkV}WTe912(g7;#nmeWi{0_&4xFy9&yr^5z5Let6cihgk!5ff@ zJ&E@1v%_8Mm4~i)dvDb}7%b+k6nz?+qA!?zJRdL49CvX;PQ8y&?^ALiEv>_~!i^?B zO0H&Vl!q}0@A=EacE#&u%eTDv-V~fG$**-xWKL+Yen~y-gr6YicAYYxV zh=AJ>DktCF59)M1dt~V0{qe%1$qyxHdcg`oAk{0B#rsXHkKHM2{JeCF)sTjLXI1yY z1q;HJlQyv5Tt96;fb#X~OCqOg{D7v9fn-7kUq4b7R1{NZvcIbo=6~?)nDB!sNw5nu z0{8+HwQaKG*7~oooX)VqXwtBJ=U~nar>yTyue)(@XhbJK7P{C04BzXbYfmJ-Y3i66 zs2fhRq@LWTP-gT2hDQ{cci+LRhI;Rx8ddW?9Hs>z?DSi)3(>Yq7d?oTd(G4#}kiZ)O0||3-)yw^Wc)HDF;RZj%jd$w4Z{t>A z3hT@Ka?S8njT^W&=TfKqyJNhK2?E%I>2~-qKqfi^!PlQ|EI_=wt zc~ggn7W0M(%nZih&eVT@xz2)&iVzzAuDjJOx2WW^nzO~UeXe?Gi2NpZsW;TA3opuj zn~YPlW6b*%4XwiE$Z}_eBi7`Mi7r{XVdn?yDNg0^*WcSp_`cntD&}zNGI@iMGf&o^ zfCQB)G27{Z)X91U<~^A!HujH-E29K*qjVUqvyMJQJ60X66eX0o-zV$!TjuANCZbM3 zzjcf0s>_%yq-CD<;zyh%NZ|@09rUKr9BeFpl2og1m%Ryl9ndAb`Zm(%S9JA^x=Pj^ zufJFgo{;`Mf&|e(o?M#8OCM4`cH0!QRA<DSiD~j`YrD z=Po-E0MgpAz}SWhwj=jR*yMC$JjMl$Ah8n09Q?Quk>WP{kfWJR2#S5va7wh$aYbld z`sjg-Km&~##%CKB1TvGUGarugWkpsnUJna$+p$!6^LSrQKchVzk#&lQ00DtR@1#(GL@hgK1yO;V1n&uzqh|V?J^dRJOx?gRy)}IJam3EJ5zlY zUo7{uK4Y};nZ{358AES(PR5g|?!gm%IVwm9!Z`+;W4W0s)?X*U!r&cXFj-U9%!BOo z3-dNCIQDx2AHy+d zwQad;d(2E9&T@6tU5{4AyvI}UqSj*Sdj4?&iHjF4vcBwW-3HIF-1TXQPG$LdQ0`1? 
z16w|zwnH+2;K896mgwMnQq)UD5&`IVyr{ZUEsB3y zmQ?!x5%!squ5Qx0i8<4d2h_FAcoZ18$8*6N`j4q`J)V<^KdOfjHEH$a_uO!pI+eJz zpL-4403u6ow*)1hzMUZsUdg_hvH4ro3x-l5T?-Fb47 zsuxWayw6Yihvl#4ed4V9ca>SZ zDxHsG&sjGJ>zz?a0P1h{Jy@p$-I*l?)58KeoJ82I(#+HZ>qeSFj*q_y z*H}nJd-(!&N8n|bxR`$|=PZmO4fB~haCfT_#e5>TUGLfGI-+Aky0ysM;?f-pyVzou zYH%rsvlBQ3LJXV|ClN&XiYgWiJX?(Ekd=lLT<<<(BlpEtV~S_MSL4UKriQ9xz4^v3 za`BBDh1YO8b&XN#1ZWX_nWU5bJWQvu8{@7k&sYah=Fxlm_y{v2r;RqG3Up(YZrbq) zlUkC0RC~3{w?nEX`Ew|J3_!5nmMbQ=e(?R!(33dQvtZl_3ZWLDwkvg zG7>sFji~-!yG>A1_-s^qU$zKbdWU7|S(GLb74JfdsdDC3IBJ$4|2<8A8ySaEF;mexJSZG{VA5J|)ZKTbF?$0E zF=tk6j_5c_x43vAw)R({2S2&5B2dsp(W~4PLd2Z{_k6lRIHRo#qafCM(tioZ7iwVC zq91b|(1Kz|E`(QH1<^tHux3I?})!i1aZ-U6OiTtn2k+FrJ8injhciu9Mn*2*<=VznuAt5!!7x}kB0M+QG&2?5 znPhwFo&iGV10dfSjK+=hTOA$mTE%YH-^fQ{9~?NVG*ImvT-ES~k$EgEw-5|463H6x zcTF?zs_MGw5F;_;iaI$F{4EJN4yFqax-jLxX+^xHn<#!-thm4J`x8aC)MLTk?&oNi zFX7qr`PSH!K}sUy+#U$u zG1703bl0j3r4>G2?s}sR2bOOyH4TloTUnTY`Hi|Kr%i6!-V<0AQhl-lUo*I%A{tx$bJT`|q2H(9A+NooPD@9M46+rI7hk;w%_97(8ZQl*oGRT=w8rD&eEjC>6Om!s1PObkN-HL7_?UtcAipNA?$;7C^d2q<)g}MYl!lmHgs}%+U~S?`ulgObBtBS5S@FtO|A! zU+nZkJM5>Nx&x`?G}B!J?53*ZWSa6PKx^Q!~F_zIFMS-@n>?9f7 zdhc#$e144uvh(zglT=}nGxGA49s9`in_aR2D^Qp{Q@4*{p2hOZdCS=Lmu(n$nILRJ z&0xv+dL_f=xFL=c^FBr#yZ2{R0;^4)-8w#$A`K*UY0+czhGQ))ExxA-Jq%cT)U8?r zOKN{Y(Gq92YiU%*uF6O#dam=3=+9Ln zC8e$e0fq<+K_|rtiAz!z*G$CE?S9;n;Pqs-M?3eLMhh`dHR7yF{@o1Jm-o!3;e(*kXv9uNe?M%*jl zVRu0?U34TzAEtg8W`VWj%Ydrn>&_${lf}3_;fJESEq-Wl0iI|NR8c zZf)QhrRY^%{FY^L&^Js4&#*)AEOqWNDLodu*3x-d%S@C*hrV9i>PS5)$Fl$6z2zn6 zZf<-S8u#ZB%;!dl$)ZvH=Tew%Z0wuvHP zWmAVb(G;)A!4M@}u)I`<{Q-aZY^NgPhF8S?m90~X4Y6+&HSRqoWnWOWzM=PxK*8ju zY){ow3JQ#Qq|Xjo=(l=}w#}2be{7!-vHP4Z;Y@Y)MQqy_Tf19koGKrxrwT5)Q%OT0 z7zZL$nz1>*Za-m|BTxSDR`t$3Oq|^g^}J;A?yxP*L8fs91jVGD$p>} z%H~mNvP#9OAH8Su_6{|tQYedrU5{WkiYRJ;A0iLiR|;fyYT1fB?FKZP?g~pt;_={D z8_#JYh&QkE(*H_*!lLMPw8^=GuYedoI|0N8P^!B3pm zg7~!FxAotY_yXbE1DHl`QFj)#%yfCT^hml*&9YP`{Lz6v>g7PS!?sY--urodm;8L{ zIXAMA0oQk-`6flmD$Ct01XW6=X(;p{sdRnu(3g&krZv0>JE?iAOeBD@LZInnRdBGo z`(I-S=y)3pPMl~(QB!EVw=u8ohtCQp|NOTaO&{uF_RL3bTZ->q7FTNi{d)cZtFGTI z*Y^eMbq&8CAIrfjZ3W8b=gYgiwLJ4nDT&r04Bw+`_dL`*m5R)|N1a(~9Q^9W`a_LX zo%>igymfbILyfx)hArKGI@`i6dq|n%zq??_y3Ou{q zy(}sxb-9-xt&!j?;r#_s%&wn5YIi)En|4s5nbB*&v*@;#+&Y2txa^92m0PnKOwWUyZTszsCHsPiMUU3<`V^ zHxo?U3t!x`vF5CD+y7)qCi7mz0$XUPs8Z!+@>TuWF1OYmX0%@WGc<0lM#?%i)H_`A z!jw@3KHaEyoN>D{sII6O)*Vxwyh-O;oqWc)$uB;lo#2?1F7Bz6?F5iaA>Z45q=9UqvFjRkI_+wN~<1)W9LkI zU$S3vkEbnq@rndk9PRra$18b*>g~SRBSPG|K0sHg%_RCl~g6cd;yp{5*h-NrXZA1(&I1~1Mvqy z@Ma0RdD$s9)mye-6V`2h_D!pods8B#QoS^!2^J?;-GBI*7;;R11dZ&6 z-$pOpHRR_q(VJl$IUl467XIl=ldjX3wXf&3pH?*3MLX&%b@H0_K7Vz0>e1404cedk zJKNLZPV22ZxpFpswp5D&|AG-2ipEJHX>^ZU?qN}+>0!)6WWsrFYFItE%*SUU*`3qT zG&MwmgM~%PykN?XEvuVUmI0XL*sOIGI@oMI#36q6n}v75*v3VO#ji=}r?VJvB{vuL zL&%fhs!u- zL*q8)^ct%q%V;q`Jn&Z{U6)b?09{pHjkx}D?NJ>JPSSe3o*#Iq1G=bal+QxF=c)|?yC zFj!vBn7Lsc9~b|IpwM-dQ6gK7J>8=hwU*w{GB0Lg{YXVMb=ckLANGDG4`BvrqSfc4 z4wK>#4bO7w+(Y^ZgI@Oa+u6+gtVN%;<_Qj$NxtKsyv^et=6&)fyMjtkH@0%~K`F_x zR;Ch1d7TFUo~!O3;VBWHX6T&DLF^Jhu48CW=1Cyq9K0d9YfHCuX-K)`L@A?+atxSE zM3SH+0erYviRXtH1{Qh@s+YgcDv6&wo{+d~UNXbxstf*Dq9m^|Q=1>zq99(Bx0_Nf3nDOu|(4K$`K} zGvu&wR(1pTgH$V1-DeE4Iz(vxW;Yk0uW1X<(ZolOGv;D6ZmUmhk!_0782qN?@4C zFDQ7B;E{cYzuZ!#Cy}fbiproQ;hLLN zMh-?O%E?_$)Xn=Oe&)Rh_giyOu7*1HXTb$db3iR3D{ne*ry7Ke(&IIy)kH@O1X3Y{LzK(74>l^A&LWY4_Vat{1b-VhNo?muuEJH zz7gCfpn=hxLJW~GM9Rmxx8daZryhK^9f*h=w#X2^XGg=g02Xj2cI&*XmW?D z6oVa*IFRm-;LHVJ2^wuZ59B-u(MDtXUJOf%gMegk`W?n1m?^e#M`NcIuEsUVdb{gB z6QBngMU*2Wu(;^ui#J(3@$5oh&kG0yJYn@~qY7_c5iFd1qxr?+qSQ${!GBVC@)HZ; z-Mra$il0J_cHo_6$tyQF*hdjI{n$81kaxcWBx{BuCzABPw!%0oe3GLi^% 
z`$5<_AJVg!l+0plk+G8;f^p+Ui>%Vy*Cn?g1j{JmLIjP9P+Aef=*_SsiRhXkjSRj7 zLY0j!`LIG=UQ5sA+eN~;>swBYX&Xv{8D@&gmXK_%%CWboQ$xceKe$(ElR zI=;%SwXIG)BHljN2lFnXyFf^aMNsQeD{UF;M^92}v9uToA_Fq1cn4bklfBoIsjRsQvijGIuYPA?*@6^9NT3LZt#iUXCfpNGkl zF#oi39V!#KBsDP66c>og>h*_1%JxTX_(7Qtzz|M3G}_px-S0-4?tg=A|EH(|P+bwW z#hR_#Lp`a?@^ySg*2;B8E=exE?HwhAQnPMew=By$s_~Xbh`HpGm_LiowC0a}wAR1> zs^w=d&Ckv6UxnqH=x61jd?3d)&sD1xXqj*2_l|M_xLn2f8*3JjAkEDK^_kb8v0S|f zqSXOpb?s9=Cuq#+>NCS8&rY?^A6sX&RdlAPadK}n&CpFw)jMY`!0CfP+PoaCzH+4F3=12~5O(#^+a zXX?@)KJS0~Glp~RM?zpb)02vNVXXWWJ51$ZBfvrq=LuRCOEs49_L8C#7H;Xx_Je6{ zA2Qda=WtP7bT6Qzrlzi+)D|CTkM0rOZ9tzr##J%eWkH{~>d#wS_ymSN$?PQMC=LjF zfOYU1jTB}{PUr(d(clotL-CH~RO3ch;8HxBlXChPl_z6oBeVS7N*HK_bNsv0k}8xSRKsf%MZC)fj69$Np$KaCnGi9TQOh-$ z#A|j_;^`g6i?&|lV8OAA`*NP@MdNW57dyKQs%K^|A-K31O+LmtOKnZAAp!LMarOwh zXgCH{Cg<+s43nJgV@!g>9lO4aKq?p%E_kT>+|nmH6LI!g;8syc1?#tO*|Md0!rYm0 z&mYhC&*QSt;kEq>Z&6YR&@+dgoyqP9j zV0lZ=Jmx=JjU&@|wjh*`Dt|NK_+r;{rf_|WWis8dvi#`>0{ ziJJ?i#~(aje1maEH@#|hz+u&k2n=nf>waK1^czx$uRF9mpeFK(t%-tidRm=pFVZ@Z zsKF2PewXw~S1eNn`h$u3j_K+!wSOhDzc%;*9_=X8CW@VN<0ajbl?&e>iDDKTGNEnB#d2Cas?;mii=g!+R( z?S@W(RtjE_BS>oKgTZzAXh~PYqRY#O{6vWh>DvMa=t&XNLeQoYkDp^+wTg?>sq|il zWU4<)IzvP)d^hacVHMywSg2S(^T6>NH*N%fMhRSU3}t^lD-w8NH~#a7 z)G?en1)mKa6bx=Oq7XoeA?bn6Caq%QmSN=9ji$M~S)D*QzF-f^nZwk%BDNEF&~(Zh+>Ftnt=Q+3 zqWM%*t)x?#)RjR1l8qyCeQ!pNor&RW;;^J9FAR(|tTHjXynCe3BM~P+&PQjCHXlh?q z>|XYFU*G+9mney0&~@~NWHs!JfbqZNx3DuS`gKe3{J*`KM-9pbLC3_)C~kH@^H(@D zO!_#$l-h(LTcCH4x7MTHe&z!UQYfL22l9G-E>-jFQ;%hsroZvd>7*OkJx9&g*RT+fSO)*Ya$7Viwj6B=wA+FrJ@`;H)*!~bGlZWE z7W-|L|JXy8hrlJ`j`$5yOpJ@Fl>P8ml*Q{at)6Q{bBbaH_#sTvZVBk7w2r$+N})cb&@-Q`4f+?G&wUlJ$2J<#V&g09P9I|!zB|pEiZDe66MQxhauuLP zn|S9i$tBAOxbnvcob&PVQT&lgx>*q&NPKkrgBaM6`M9~I`gU`vT2Az^A`yI+o>e4%900|O~K;O%h{8;#sm4?8=(GghwIPQ>clSX z1J}siy@1vQ3;dE^m#+KWo#g!+9NBNZQTbzeemRs6Gj|nds}qrrr65jVi1a@_@<8Mc z4tTg!GsP{Tx}Z;Ag|4|ayh`pgW((G3CgUE6c>7@{Fr z`|~B2$*fpMqJ-e3n`V6V2E)lZIfnnSQ*xegkfw!S4pW}5X=O#^8KQy--_tHV0pCSP zz}*xpl27rT(3yx&k!koi&|u#b1ZL{kcYjSvciAu3L_uvGJ9;@POCDZHd`jn7p;RuDjQ+Qhm_43&sT#nfLt!2{7tv@FH(e(HjwgJ{_MQJ))}MbZBwdgErYC33iWF(t4uU@^zw;WD3FM#$ zpekYZk}`bRjqOE0wP672-rZuQBx-CO12vvqP_UvZ_%M09p&i+`L;8A*)Q5~L(Nvsq+IkIC130dhEKsJdnW<~L zk7B*>k(`38wE|X?;1;Q(77ob`2jpSaXfeMz8%_V*r{C_E$Ix=P?_DBm7L%h|$VEnj z7_&R;6u32LH~krL4~TdmMfaJWh?gxmJctlU;pEAC$$SCBzik+0>zcojU7O@PaXw-B zhaCIl!-*P|KYBo=Qves`G{_1Ly`Mx(>oZfb_PVyX5o#`ASq^#*XEF&br~dOa4qqVh zGcJjJE%cg^ox=ddao6^OmH62KKsx
{Dq8F-8tY+0H9ew>CcxAgF8@f*R+W%T~%=*0oPZPzdqGSky6*PqUvI6&Q&`K zGU-X{$oL=_cvIrai;6Vv9z`Sb8cdwwh}tkOFN)KaOxhKyM-lzGN!-d6f82$CZiQbI zG4t`929%I&^l>4C6pTQoYG3Pl{=Bc(We$ZvL4#TA!Q$s2ibkAQs7=O!otgh=+xvfh z&+Jd2vB18c&G4m*mug1&jZyhS3QxRBbTCnH5m}~q+daGc+7UVc@ip5&^WWO_z|9=O z5Bq#qA9r?i71tv{Q(e{Cn%2F_(zg7t%tRh;Z;j7Cc>kYYXc!0^v4*20ld|-ZTnwvL znXub|&0VXlg|;vviB3qXvsNNkv(Rzj-iEA$_JVi*TyXr~QtwGoOr%d+WRM!7k6Y^+ z?jX2$*`d0|c3q7fJWP8h37Plz#k@!JA{n+NUVg!Iu7t;QziT=T983yHDiiXq!H6+Q zv(>L#|Hn`Nx+WczYy@om8$|o+^xyRK!BM-M-{{yk7|lKS5%-4N+~KmV7XSK^Z{yf-8wa$2znwE%j*(t{M@L*UxRGQNQ5+vz(A)o{&iJxm|}`{uXfbA zX~!;?DrKZG@#z#GO*|4W77a5$=y>!0ch@d z&dBBGQ|M9@& zvpX3$DP4ST94IHnjg?G2C3gkF9975j$ZgCWZSi4TyY?$XJUML+*-}bMzQ76;Q4umY zaHOE-APng&aPr^pSi}0rZOSPDf)YlZ@chLxVTQaNH4|Ccd; zEjO4j?ad??CWTPV2v#jPU`V(SU``cd`TatuYs5iK_C!c)oK00^ zzR=e(pFTVAf+b3-y#8Lti4MN>BD(+FxVep5L--K;?g?=c`{dq43t5Je(QjXjNK7aq zda^<)%DGRR6g|-ILta5%s>$zx-ch;D=rSgz6sjNRq%TKPw)NkC7~XS#pp!@V0)#v& zH1ug&Ff+eXeukZ&R97%!p4Doz+BJL=4?#7hFD{aYSQO^#?Tn2+Gz|>?JMx{Jh`oT` z#royiQ*9QlBNLEFW=m4%&q@t%p`kwAoTj>P#hg^szu(gD2CS>m9ex8!nQioHA@p<& zfv5%#gKmE~+qR1+pr*{d+8!28&V3}0G%1VN8ea{QZ;B~5|3Rp*3b;HTr3|B(8ODd+ z@f7V-h|jE{YjdoHS{3aH6Q-L2@?#D}>tM<>yV3kRYW2wh7uu>idd?_-D6?4vDx7N#jT}Y~mD}&I5^zD)?}` zW|XH(TplW6pG4$U+Lx2cFUoK?{)uXSvturM^`U3L2*b_GxpSwst;upq+9XefX9@&P ztz~z_so7(pA)1k*wjD{1@R59_kz@CePyhhE1{s|evue2jdfX1#+S(=$zz=sine8RI z6(8A}q=2*JrS?I&%)^uX&i`IpT0?(Wg)r*3Migr_xtC@8_M_Ms!=0P20>$LbKJntp zzGv#B5@WVzSti;Gko~YJZv0?el)V~Pu4=zNmQM(>0A5cCOO@zv%)0vDtlPlf9uQtQ zg!v`rymygZRz1`bZPHl%0H-;cQ3FI>6bQ;w@1nq=h}0e+Q6pDOxk^u8%7Y+e3a?3H z{Z~_G;2oN%9VSFA*rg z$VTR)_at0F%+Wf3{ycNb;aoMTfFB?!21L4mC~F9QC211os1p@g4_j!sGu}#a1ek=d zZhZ%H{46hV&U*Z|tA>y9MQiv>p6u!PTETcx^eqwpg*7sp%Oss_|-PRvf2|tlPiQwUy!a{|tnJ$wB6L zdv^99A0udb0Ll$l9<0x2{L)#+4Lr)IXZ!NCYq!vyJMZav)f{d3erT0gyEJo{)&9(x zGcDC-*N1!KVetxPD>_meKjF-M(oW};J|%spAP{9(;xHkgT&r!3u0xW70lh6<{cjj5 zpP(824QfX(xC<=q?Dz_4#z)T?MaEFtjdds}L?J;|FF0NxB?k=KtQZ)DV({*_5HOxh zr8nNmMvv98|H9$o{fhr+V(28!jB1RON73QU(VUaM}n-?C#p-%h{274_-2juVW@% z_GrBldbOCA)(?!RWBC>OcFnV9_n)BOgxYph1V+FZ>>^N%ZSSXdChyxz7FSYVMj5nP z`vj^+j+#CRPd$^cn?#IQ^&R!^U~Pz%-MLc;Ra_ ze#s8m82k@&#o~8_xRaorwUpBSz-bSZ_0|^aHn~kaWZq%Bxq(6f0csTjc&C_jrWlAm zvRxP9!7#9lP$Je^c)x24+WWs3oobBr6SPT4@;KMS)Q zw81fdtb9&4%bKJ3-3oqf4YZ;2f|RC`y74M43kwM*2m5PwjnuQD3iEyO$_9B$2>H;T zHiIB6gkekn`?&TL!l?5cB=bLmSV~NYye+xrp4HrB)eS6y&=vjoTa z^;cB17Q+sO{!Ksu;Ig4-^>(4b&`T*)@ab_!lm>!;TS*k%`{&&$KaCO*xJJiwP$4Os zc$xS=4)3fpuOZpTT9L9@)Sz%a^??X6R?@Cb;I>d@0RBqh+vFE5-`qSe&0rC&!`(vY zEpdWSBgjdC})AW+L9 z?_J_V)gDC9q&N9HD5B!L4-nZ@A$Ou%vScnC{oP5r;)`+zH-Ll{IdWC=A5zl4)|C7; z0a!4P2sDQ&KEU0Jgp7&^;@ZIKZTBToMPs!E(?+D;?s1M9vI~4@ix;9oDiMsxXC!*x z#vR#^K!h9G2XE8z=>EbD4pOAo-0r(IWPASbETfbr0>|Y4hduP9Ce(FyTXrSxbJLST zh{TGK)=@%Llz4%U4J|gyt_}(YX;2<}b=5#`JQ*aYVx&9lkgh+jsRFTHLf9vs=+9-C zCPJ0=jTU6d^0%J;*Z%(TR~8YH-tW6=GrQOWf*#0SOCX$OP<9^?8`%&_Qga#hLyo9d z1V{%&Z5u0Ae}k+}W`^$rGTrC?@ECER24j%ZdbAGw@dba)n<0fYiGK0_hw~IFg)Z0> z;YvKy+pt}OrNUK^WF*fs%`XVq=Z}w$dlnBsW=S4F2?uzn#-&j6i>`Md=@69>1MCl2S0pUByLBIP`^i_QiG{)JKObiR!$OzM0gb zIe@g`6I+0Q0D8}@Q11|?9qN|{2-EuZrn;?3^1_WS6+EV}4ea5qaXF6Ss~SEM%^5~U z#+G^+$_TrE+()4b&5ynY9xE%WmtUVhf6lRK(=!#0+|wV+(21xT0w~(Kb`$Eb|Jw6E z{_0NTinCtGPz0C=Ay^**g@m7IlbWqo2?FoOx389eF4UtzceDV)jI&vjj2sH9bDWO= zJu}*cGFeD&*@MZcryK&t;|AwYLAK}r8cOf1t3EgTygse#gXi`~6Fyej(fZqrq1nSZ&vtV9&wQa9qF_|nr+$go1NBtk`|0lft z>)|K|^4A|O2k=7CIdQy{ai4$9yMf3?k+?|U>YV!_u|aO7ainM9VwrvW6f`t6-l(+t z@~#lGZoNFz-_#WG9CD2;I@z{2!2EfQAfpXYG?TKQ?GXx3Y=leLeymD-2GoFk=W9W) z*H%`xH78skQVQXpsQ7Qj)JW8{?VENRQ}T-c{`>5H-;sf14Ln&n$4e1l97QZ{2v}VS}Fda zLHzBsW48^SlqD&}A4q$zye38N$wP&kH2+_FPaY2C`nIMNT8Pk2r$m-0>meDP&_)VJ 
zgsDW?$(psP!->!$)F6_jD2nWB9ZQkOF3X6r?^}d?_iM=*-svpgb*|s<`>KEDGVeU^ zeLv5=Z_guzCjma5C01*L2bYDe+LWyxdTWi0@~@|^zQi0{>~g~K`L>W0KQ>}_l1O+` zm1L4%K~=5!z0qVr>#m=h;eTNBU{%K2x{EXn?21nK4z`ELP_*fhhWrEm`I04R#cknbli%Z6@EM~i8P^J(zV`D~baa%xG%NN37b6G6eJ=JEl|RCE+?~YQ-z}3hV>o-b&o`r{?-f+X zHl}qum_R*h_@`vAUW`+}s1p-}VovVuiIp0SEeMi#d?6i5{k30eLNMOAmQXP~l#n6m z_<_}Nixj_jC0ykqI1mS^SY58jmh`)cUkMed+ri3Xg_^nRQBhHMZ9=V4Hhe)?Y2IAE z6jZ?*WXCc2^&LGt7(K$Uo%ehoJeDG|bm`ItcR}b7>#R^Wk1Ng~LfiSw4?*TG@ z24(PLFCiz`sLeKdViIn)uRXtSt;V-ka#q7QL8k$^>j6GPqj1So4md_>_5#}*8xbZ~2D%_-en7!Jl zeORC_5?FC~9QwKMz^vRjI+c}k-*`tFUsqSRcGD)K^j+0Y_Eg8!{7z|;@oh9TwkqMQ z+*>B)i{N%sh#cmC55E3)dbYPjWJHQhu0-8xNlD3<0!ewF>F?S1y-xh#vY~G|)GK;` zi^{QEG_Z~KY+9bJRYyO}tds&qvhFD5p9@yY0XRHQ!xR}N^ld9WfGvuYEjZ3S`BN^X z`i!sb84Y26#%@T|_(%#E$HQ^B(YO=PmYnNEgj*mIdvi8i5BS0{SAK5dVQ(pW_9Ve{o=E)$H=Y1nQx0i9NNCNKD_1V|E?&I&-FYFv ziCiC&fZPhF7g}PMy$QvLeiN|;;Dw1h2j6N>Z4L8lc(IJ_jX`X+y21xCAcd%5Pey(RwN;q@ zi12N%$Ug`$11CIsc+q9@i)}kiD>SXMabwxItjfyDL~jTU?{#bocZ_&obubTO3f_g? z;j3KBZFHogwF7x+Du_ghq~!G_Gx~Tk7@<9XFr>>;=^_TEH@R*wPQ1@?$EGu`A_sze`F5^% z=c<1!22{|XVK4wU8m5oRg_+7gX<@-l>n5KllExsGNn?Dq@iiz@y!cWDXH})sTLNJ8 zeyPZCeAz3m1@mc6*|#3Xxk1!3g=3F}A)l|tg@YRA1;V{x+lu~Sni~Z}?WDjgSkJ0q z`r_GP$8a_FYRfKB$41dJ;Cqkw)%E!#WFf@d+`MWo&=CG+Vi<^M#$EnvV4udQy=q6K zlHZGPUoBg`mt{Dc56(cs4KTpg9ht~sbJUw3^2E=R?5d#o6DGzuIRUr0^N21tP3sTn zc89>yCvCQg;oL@O*n3sz%XgB6ZM;RUkvZB%#H^UO{2#y z=m;MpEGFC8`U!+{M|}Spm0CQ^gy7~_qtsVz`%gSLR}+W!lg@OY0F0KV#}rk-R8kEa z9P`q9Fplq)6<;+3>QZ4O4U>;Wx!O5lYrfHEY@^S_8r=oeyIxyOG+{2_P9;!fw)-2Iy zwl1f!4M@MgZt~=?a&-0}0))FaSxYY9c|TepA{)^=x#oqEzPP}aQ;>o#sP&1_Oy8N@65J#z@z&}MA*v9X zAhf~MLP)%sc<$cC?J@044v?RzUU6P0<$@+~2~+Lv4{?URu~0fC$>s-kodz^unze@d zHdSoAz4gW*G6~JS@fR(G)itbZEkC7Z@{+&?a)JFANG@BEUyCB|IVcIwJsuKuc8VJ~ zTj^N;im4|A?q`Bc>6Du>mY*SX1y>cN<~4pc#Zl))G}yUWaO-A&H^dPe{LSCGfiOaAy6fx_8H@uJm roHAxH1ZY_v?CW5!F^@7v7gNBbC_# zV20>elgZbxr>Z+Id|AfioT@eYz9MI(Xbl5wNmFY^<_a>5P9tL_1R_jp}*Hux@mkDSvk zxb0v&D-o-+CE2h!UON!X>1;%nCDitIR*pMPx)4;?wj9%#nFT+ad3aP{Sf2}a-q80emEU{wy8X7d zKR5foKKgN<_`t3N<*i>S#6dwpfCClc)Y^xjlHzp56A%61~NT`GuWg8yhpK zjMX6Ob4{>tk6=t~qUZQd_2*oW4o?+>$gLl-LoCHaM5=1N<7aCpQ7Gc`@;!Yn+W$=P zGN?Gg+P5hMx^+8#P8ksdU4dANu@xgS^9&sPegKrw=&=sE`>`FI4FXj*$FzvVuEVGu z==_82t0`pbL<4s&4;maA%JHik2pc0@>BYbp6Awg-dw{(dsUuMRsMFf9+4`r^!D%5h z%LEpevGI0ey%M?SQYz0)^xUK;Ip?`-bqQ+=BA|7!)NSh5csg37pl0#}kRdNV5m4mK zfFSFaaw~3ym|j=Qkn2(VpEg6pkHv$x#=G;QC0Nr58rJ(z>4rq$e0v6(f~2F@B%fb7 z1r9cpI>yOzg#K$ieFF?@H+s~6dHK$(;A+8q0U0pa_g;if`39({^n_T(edp8A_h*LK zD?lS97iwwI(IYqMfFSUQopQW2#B!Rly;;N{e>7Hx8&|0@%iY9h`PNUA-?7!#aYx-) zlW2lQca|4);b$TCDAG-o-$(p9l#p9EF;hH zJ)!;VxqGP1?%lf?abJpk8+dRdd5M~-Zy`;Z2<;J#a6$*nTFdu;JJUb&;gmnf5HUv9 z8CosjH-+hj5V1&Sz#$4C4f4u5W(C3(S)smK;>Nt*%Mm$^C_&s#wmw{dWtqdU|>iEze32CXK$xsAT<` zJIl$K3oM)g;9xMw%!9kSx>~xsrN$2y%sUrSD#7XjKZYJ7T%D#|0bx9pDDQ+4bQZ`C z5))~`QrLhTTjM#$rgMnrx?WVvcR5UUy3X&|d##s#qK6(CTI&;22xgp3b#)y}H-rVw zz>pt4ZH1X_pN13P2dt16D>O|I3j&63MTFEU74sQCIE=j~i|J@q^WM)%wALwk7{#iL z=i@#JsT_z4zvIo`{_AJ)Q4k3eg9hokVI*?0cHO#ls#!X|nkjRa<3 zxd7raZ%}9h1!l&_AkVM}vrRu!FGRiBrXJt43I_@$z%-1v#hW@kre2lVIy0{=9ro=E z`2W4&0v%0s;jJ9JmN9(onQ?70ZLa3Dmi@LLd8`Zr+{99MvICqWKJq|Aax0Z+%m93A z38=e`h{eQaU1KJPbbK(q6^EqZ@O`!~_lIw?jogHiCnJBmt|X>9Q6^UUjbaQy&F8!7 zgrQ<;IPY;&Ah-hVfy1tWSD^-nx$Yb}2XWzoqNDkVW_D22^9&y8!=1?~rjD5?hpC+L z_H~eU%!a&E|6g#fmoToe0prFmMY|RzK=)=$jAn3Zet?LZ-~DLbVZ>|nsonwV`l@G^QS&dT3!Vto2&olx^N7FFPyEsr=|`+cABvMnNKzn!_RS3tMJuFz$JR?n z-;$P>Bl~(*Ts(XV65HADfzY(D$=r`oqiToxM%QS^!N>xJ^8KIvJJeY+#d{^dkwaoV zh#(&+@C^R~liVy*TjfpMc%oBx!@Mi0GQB?yM74n=d!oL)0{elk1C4hsP z5VT|l4rS`}kph#PU+5**K(US~Qiih&#C6%DumKI0wfR=C 
zGYDP=!|mU%FxSCHb=bX=VU@S=Ga%zBEHyAOcjfJax3{$j({4+c-Dpxdh;Sa`YKX;# z-1M1>#nLcvG1@?UxAE?|dwPyI2+lP{;T}AAz?T6TW>>_JgAwDEW4gIjZ#5v$t_3lc zkde_m)%4vkBtvyzeZ`D~pW$COa6eU2rKvyB+A?P*qT9U(6hUpi2vcW7l}z>f!J1)` zX*1i|qH=xjqg!*t#y1%Q9~52S2-mqakz+dhjIstqRvT7wa~&;;q{`7jKaI|lpdqlt`=bd0R|6QZ9<){Ix56e_$m;f1bK&Ubhny;-INh=gX-6!gSBlt zxKby6;iC{NXWs;iSlH@ukdfXjet0X`zRus=(VYkyhC3XvirVGR=qP3KLgcalldL+U z8y7l^JMNj#GNTOwYUByDnO}vh)mRN#$b^(0u;PeCPSL4FMQie)9zFshliv_-h<+3l z4ivsBM))q=_t(E>($C5*tUPeI{(CVUO1n#n^6a@x0d~U^Fc43B&iHpR&_|*;B{@6% z$XEVG?55lpDB(W`CcbS*Yie2F2`^aLY}t zPtnc6RrK_^VCnY7nY@3#1LiUW1WHiYy906{mjkra(cQlw#Y|5vXCd#+Yp3jy)Esar z&kL;|Ng@xn&Jr(-kOMte+#2dldh{gU1BucB9nk4T{K$X%)^B7- z{y|v@#67`5Pt`j}Rbopc`o{kN>OXU*PQe@_D55CmHHMo~-j$Rn73@?R_shE8XCh_|;?S&5QHNI$mvn7G+Ba8?+dhNc?H5531p1@ff%#H@@yA znc+sNydriA?AfyslBoIlHBf^HQ^~}otjU>-1kQg!vru?DeDx=!%K`^Wf`OtEoFPF( zp$hd3N%v=gwL-Q6$tmzZK^M>8!GHh}ld|V#dFW;w)Br(qStE=`Dv*Qg9ye-gYKHeE z+EF-1u(osoFTJHAlBnPZ(P7fk(sn^P4sU2miUdT#A{BksFM#3fZVlhN6yz&K79}D< z54#KuI7_S--%WJ>3n~Y#7-t;VfDXPFq#6nSUMZAddE*0qmf>qVvT6$wc0IKeI6{}9 zI+D^dN2EcI)1BXp9&{{~M}h{jva$pnlbXa`&`hC`z+rI~?warNfSWIx4TvyF+``fl z3NR?V==(Mx*P5pfRJX{7E^~2lIb&)n5N&qq6g#A4=9L7@9tR|*zHgcD13K78I7AQ5 zS+b-Xy6U`**m&IGuhEexLsA;Lxc4d?EoBO?&Q9( zn_n5;biDJiCAvM3173lLs2QcwJuC+pyWH=V%8)YYTw_-kHpW2qPTJqq7mG02AV9N4 zWM|Psf7yc12yBG~1YdTbyYr~$F^vir6iRW&@A|tAEfszpxWt6jy8P($^BtV9&%l^@ z=CK;}K;dQ%3F|I*7$_oC5es*HwzRb!(luhC1Tt9?Z2=XA7nxC+sgYAEvwF|zI}1Gc zZ3FRl5fAU6Q8&7W_XaXp6}&I(_`dyU{s;U)FcRb$^Ndif6uda%vnJW;qx-B3ElWTb zD73^(6|c>I%OcDsesGR^oiJ1j63&#S=iUj5qIc!^4ux?piQ92D^gg};X($A-|Q z=B`(;s{n(fRbVPuvxyocvw0y!{11&w?0txkBc7wcbbXx1-w2Q|fa)69ZRTQSTaX&X zoyRsVv{RD@5u^+odi-a3^2q%x@je6*rz$*J9K|giZM824-R7BFw71YP{{U*rxcjf& zDX0w>gjT!ajX0=&2aA!p&LBvfG<>XppDspT==d49=2ETC3aWsM0EE_i zT)>?XK^2cw_(Nc-s|2rfBe9t0Dk|U$0nnQUsjfT=!h`$}|9!1V^L^!;83MX@yZk3v zbZ{481fgzOIJT}1o!9Qj12~{~{Q=^{gnS;AeYd4LLkE?JUo-QIq`H}rmY*Bt722~pdEe5SC9`xGh0Q$x$8UO>U zCmiiXS5mJb!X*FpwI=El0D$gcFi+z@rum|P3I^;IB-C-WQ-#tpVF+tc3tdMy6WX>H zRaCHU-MY0P1_lPL^%FITs%8|Ase{WBD&Z!_^r62~IyPUu)2yMZd>45mtave891rbw zcPhjt3qiL2b2m9?GPDe_soKri=n<~JU^_4l=yRU+N8R}E+l=75{{RYT;9uUUnHr|- zCu*_Vp+!AH?O}xa5``7EMpZE!FtM9>;W`rqZ6|~z^IIOMpz{jNHZCMxJC(e+%oDc_ zeYHi>{7#E0aV9sClX#cIbq$ry1Xdzmz$N`d6yAc-0aw@i{`pm?&tN41t+scZDik~p ziy)xaWg@k=&|)}$s+F}28G5n-u0}4*ule7}8!cpU%u3$>hT5U|`jqr9|3Zcrue3T5 z1~fIoO@jq;zdk;(+C#Qcd9yWoY7yDr5~wp3F31c+kGCpI2PGtk%F4<4g<4o!%Q-nY zG3Om$jH)=1MTv;17*2}*+{as}Kyj`*AHzX`P)i{l5?S7H;v0m-z+Oxr@)C zBzzSkSkQbIyU|Oy-T+5btTeBZ3 z-#i~%+t`Rf@sdxp0}R}L22->CX_&?tV>Q4X@@OEls7P_es#Q(Gt5&So|KQ!cHt;#v zd_TLdh>YtvpwtQnb0|S>YDB{<|IG?6{BiKF4t)|@H)nlPISOMXE^C!_98QTK?Xg?H zO!xuRf9Ot4Tj;uG(-yiuOGE?rOlw6T(QED0~r2W`e0u2d$InRvBD~m9+bpC4x z)~{WARrbNn`a^uKv-CwyR$!$Q(Z7Yzts4>sdcG?UJu`LW3gSNm5zq^;v6rC9?y{B+ z7vT!(yAYr97F2H=GFZ?g49U=jD0k!h|QVnvlA`@+;%Ajn*Jd{LGpQ~3!CIGI3rUJdaF;j_H zaqryP!({-&Xb#beV*!`N1VdK}gc=8Nr3lN>HAiRHowEiMULh@cAN^}6waJ}|2T*hm*HI|p z5=5GAuT*MLT!s#KPYK(P9-0m^#yB`QKo?6f?@ZHEGT;5J?T~~uF(jr-;zr>bli7_nvh{grtSsP#?vSqZY<=zN9mN@d{} zjNSaqme|V1V>=!mZ;c*o+`BEPWpIm4G`fSyz}^#fDY(KNRP%mYtPE-FNJAypCF)rDld)LBB@1rIFcV z!LKLk2W`Fk_jD(?H=b}gofznVej*Gla@co~fxNZo9s&d1uw3+w?Su7Vo-y5{X^rtl zI)t#%s%5AkQQ;kztV#Rwx>z3FkiCB?SoMIHZBUE1_~v}{(<0Du*$RKMSc@u}*u!C= zw;K8c{0-x#wMD)SAFSGD!%N<9Jd4PA?x2>k+a;$_ zsR-{63;k^V$SAZ$g=sf>i#p>(lKbZPGi^p2?_<@#k6I%9Oc TuHB!g;NL+Rh5b?cPF?;V7xAnr literal 0 HcmV?d00001 diff --git a/flax/experimental/nnx/docs/quick_start.ipynb b/flax/experimental/nnx/docs/quick_start.ipynb new file mode 100644 index 0000000000..540aec36f4 --- /dev/null +++ b/flax/experimental/nnx/docs/quick_start.ipynb @@ -0,0 +1,568 @@ 
+{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# NNX\n", + "\n", + "Welcome to NNX!\n", + "\n", + "NNX is an open source Python library for **N**eural **N**etwork in JA**X**. Its main feature is, much like Pytorch, allowing Python object semantics and reference sharing, which brings simplicty and familiarity, and easily crossing over into the functional world with through a set of simple APIs.\n", + "\n", + "This tutorial demonstrates how to construct a simple convolutional neural network (CNN) using NNX and train the network for image classification on the MNIST dataset." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Installation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "! pip install -q nnx" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load the MNIST dataset\n", + "We will use the `datasets` library to load MNIST and convert it to NumPy arrays." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/cris/nnx/.venv/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "Found cached dataset mnist (/home/cris/.cache/huggingface/datasets/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332)\n", + "100%|██████████| 2/2 [00:00<00:00, 499.95it/s]\n" + ] + } + ], + "source": [ + "import datasets\n", + "import numpy as np\n", + "\n", + "dataset = datasets.load_dataset(\"mnist\")\n", + "X_train = np.array(np.stack(dataset[\"train\"][\"image\"]), dtype=np.uint8)[\n", + " ..., None\n", + "]\n", + "y_train = np.array(dataset[\"train\"][\"label\"], dtype=np.uint8)\n", + "X_test = np.array(np.stack(dataset[\"test\"][\"image\"]), dtype=np.uint8)[..., None]\n", + "y_test = np.array(dataset[\"test\"][\"label\"], dtype=np.uint8)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Lets visualize a few examples from the dataset using matplotlib:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAeQAAAH4CAYAAACbup4ZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA58klEQVR4nO3df3zNdf/H8dcxs83P+f0rTWtcfi3UjESGvi3R1apFP7TQD3XxbYmkK2zlilQi+VkpikTzI+RSucxVXNpIiJIRlVVsFiM/Ztvn+0dfuzrn9WFn29n2PmeP++3mj/fT+3zO23q3l4/z2vvjsCzLEgAAUK4qlfcCAAAABRkAACNQkAEAMAAFGQAAA1CQAQAwAAUZAAADUJABADAABRkAAANQkAEAMAAFWUQOHTokDodDXn75ZY9dc+PGjeJwOGTjxo0euyYqBvYjTMJ+LDteW5Dnz58vDodDtm3bVt5LKTXr16+Xnj17Sr169SQ4OFgiIyPl3XffLe9lwYav78fvvvtORowYIV27dpXAwEBxOBxy6NCh8l4WLsLX92NiYqI4HA71KzAwsLyXViKVy3sBsLdq1SqJiYmRa6+9tmDzLV26VOLi4iQzM1NGjBhR3ktEBbJlyxaZPn26tGnTRlq3bi07duwo7yUBMnv2bKlevXrB2M/PrxxXU3IUZEPNmDFDGjduLBs2bJCAgAARERk6dKi0atVK5s+fT0FGmfrrX/8qx48flxo1asjLL79MQYYRYmNjpV69euW9DI/x2n+ydkdOTo6MHz9errnmGqlVq5ZUq1ZNunfvLsnJyRd9zdSpUyUkJESCgoKkR48esnv3bjVn7969EhsbK3Xq1JHAwECJiIiQVatWFbqe06dPy969eyUzM7PQudnZ2VK7du2CYiwiUrlyZalXr54EBQUV+nqYx5v3Y506daRGjRqFzoP38Ob9eIFlWZKdnS2+8tBCny7I2dnZ8uabb0pUVJRMnjxZEhMTJSMjQ6Kjo23/hv/OO+/I9OnTZdiwYfL000/L7t27pVevXnLkyJGCOXv27JEuXbrIt99+K2PGjJEpU6ZItWrVJCYmRlasWHHJ9aSmpkrr1q1lxowZha49KipK9uzZI+PGjZP9+/fLgQMHZMKECbJt2zYZPXp0kb8WKH/evB/he3xhP4aGhkqtWrWkRo0aMnDgQKe1eCXLS7399tuWiFhbt2696Jzc3Fzr3LlzTtlvv/1mNWzY0BoyZEhBdvDgQUtErKCgIOvw4cMFeUpKiiUi1ogRIwqy3r17W+Hh4dbZs2cLsvz8fKtr165WixYtCrLk5GRLRKzk5GSVJSQkFPrnO3XqlNW/f3/L4XBYImKJiFW1alVr5cqVhb4WZc/X9+OfvfTSS5aIWAcPHizS61B2fH0/Tps2zRo+fLi1aNEiKykpyYqPj7cqV65stWjRwjpx4kShrzeVT98h+/n5SZUqVUREJD8/X7KysiQ3N1ciIiJk+/btan5MTIw0bdq0YBwZGSmdO3eWtWvXiohIVlaWbNiwQfr37y8nT56UzMxMyczMlGPHjkl0dLSkpaVJenr6RdcTFRUllmVJYmJioWsPCAiQli1bSmxsrCxevFgWLlwoERERMnDgQPniiy+K+JWACbx5P8L3ePN+jI+Pl9dee03uueceueOOO2TatGmyYMECSUtLk1mzZhXxK2EOny7IIiILFiyQq666SgIDA6Vu3bpSv359+eijj+TEiRNqbosWLVTWsmXLgh/v2L9/v1iWJePGjZP69es7/UpISBARkaNHj3pk3cOHD5fVq1fL+++/L3fddZfce++9sn79emncuLHEx8d75D1Q9rx1P8I3+dJ+vOeee6RRo0ayfv36UnuP0ubTXdYLFy6UQYMGSUxMjDz55JPSoEED8fPzk0mTJsmBAweKfL38/HwRERk1apRER0fbzgkLCyvRmkX+aLaYN2+ejB49WipV+u/fmfz9/aVPnz4yY8YMycnJKfjbLbyDt+5H+CZf3I/NmjWTrKysUn2P0uTTBTkpKUlCQ0Nl+fLl4nA4CvILf1tzlZaWprJ9+/ZJ8+bNReSPBgKRPwrjDTfc4PkF/79jx45Jbm6u5OXlqd87f/685Ofn2/4ezOat+xG+ydf2o2VZcujQIenYsWOZv7en+PQ/WV/4IXHrTy3xKSkpsmXLFtv5K1eudPqMIzU1VVJSUqRPnz4iItKgQQOJioqSuXPnyi+//KJen5GRccn1uNvW36BBAwkODpYVK1ZITk5OQX7q1ClZvXq1tGrVih998kLeuh/hm7x5P9pda/bs2ZKRkSE33XRToa83ldffIb/11luybt06lcfHx0u/fv1k+fLlctttt0nfvn3l4MGDMmfOHGnTpo2cOnVKvSYsLEy6desmjz76qJw7d06mTZsmdevWdfoxo5kzZ0q3bt0kPDxcHnroIQkNDZUjR47Ili1b5PDhw7Jz586LrjU1NVV69uwpCQkJl2xc8PPzk1GjRsnYsWOlS5cuEhcXJ3l5eTJv3jw5fPiwLFy4sGhfJJQZX9yPIiInTpyQ1157TURENm/eLCJ/HF4THBwswcHBMnz4cHe+PChjvrofQ0JCZMCAARIeHi6BgYGyadMmef/996VDhw4ydOhQ979Apimv9u6SutDWf7FfP/30k5Wfn29NnDjRCgkJsQICAqyOHTtaa9asse6//34rJCSk4FoX2vpfeukla8qUKVazZs2sgIAAq3v37tbOnTvVex84cMCKi4uzGjVqZPn7+1tNmza1+vXrZyUlJRXM8cSPmSxatMiKjIy0goODraCgIKtz585O7wFz+Pp+vLAmu19/XjvM4Ov78cEHH7TatGlj1ahRw/L397fCwsKsp556ysrOzi7Jl63cOSzLR444AQDAi/n0Z8gAAHgLCjIAAAagIAMAYAAKMgAABqAgAwBgAAoyAAAGoCADAGAAt0/q+vNZp4Crsv5xdvYjLoX9CJO4ux+5QwYAwAAUZAAADEBBBgDAABRkAAAMQEEGAMAAFGQAAAxAQQYAwAAUZAAADEBBBgDAABRkAAAMQEEGAMAAFGQAAAxAQQYAwAAUZAAADEBBBgDAABRkAAAMQEEGAMAAFGQAAAxQubwXAKBoPv/8c5Vdd911Ktu0aZPKbrvtNpUdO3bMMwsDUCLcIQMAYAAKMgAABqAgAwBgAAoyAAAG8NmmrgYNGqjsySefdBqPGjXKrWtNnz5dZc8884zKTp065ebqgOKzLMutrFu3biqbMWOGyu6++27PLAwoR127dlVZYmKiyiIiIlQWGRmpsv3793tkXUXBHTIAAAagIAMAYAAKMgAABqAgAwBgAJ9t6oqPj1fZyJEjncZ2jTB2/vd//1dloaGhKhswYIDKTp8+7dZ7AGWhSZMm5b0EoEDNmjVV5nrqXHh4uJozePBglTVv3lxlVapUcWsddk3ANHUBAFBBUZABADAABRkAAANQkAEAMIBPNHXZnTT08MMPF/q6pKQklc2dO1dlderUUdmcOXNUtmTJEpXFxsaq7Ny5c4WuDQDKUlhYmNO4bdu2xb5Ww4
YNVdanTx+V2Z2Q1ahRo0Kv73A4VGb3ffXDDz9U2eLFi1W2d+/eQt+zLHCHDACAASjIAAAYgIIMAIABKMgAABjAJ5q60tPTVVa3bl2Vbd261Wlsd7KWu6d3tWnTRmUJCQkq+9vf/qayqVOnuvUeAFBSwcHBKvvoo49UdtVVVzmNq1atWuz3tGu6cvd769GjR53GKSkpas7mzZtV9sEHH6js0KFDbr2nKbhDBgDAABRkAAAMQEEGAMAAFGQAAAzgE01dV155ZbFe526TgZ3Jkyer7K677lJZpUr8nQdA+Zk3b57Krr32WpW58/1w0aJFKnP35MHly5er7MSJEyr74YcfnMZ2Tbu+imoBAIABKMgAABiAggwAgAEoyAAAGMAnmrree+89ld10000qu/rqq53GHTp0UHN27Njh1nuePXtWZd98843KittwBgCeYNc4ZXeS1sqVK53Gt99+e2ktCRfBHTIAAAagIAMAYAAKMgAABvCJz5DtfjD9mWeeUVlycrLT+LPPPlNzbr311kJfJyJSs2ZNlf31r39V2caNG1UGlAW7zwnh25o0aaKyqKgoldkdArJhwwansWvPjYjI5ZdfrrJ9+/apzPWJTReTnZ2tspycHLde64u4QwYAwAAUZAAADEBBBgDAABRkAAAM4BNNXXb279+vsri4OKfxqlWr1JwVK1aozK7Rq3nz5irz8/NT2bp16y61TOCSBg0apLLw8HC3XmvXuFORnpzj6xo3bqwyu+83ISEhbl3v1VdfLfGaLrBrKLTbj6mpqSr76KOPnMazZs1Sc7KyskqwOnNxhwwAgAEoyAAAGICCDACAASjIAAAYwGebuuy4nrg1ZMgQNWfJkiUqc30Kit21gJKKiYlR2ezZs1VWpUoVt65ndxLd8OHDi7wulL9WrVqpbOnSpSpr27atW9f75ZdfVHbkyBGn8eLFi91cnTZ48GCV2TV1tWnTRmWRkZFO4379+qk5H3zwgcpee+01lXnbqV/cIQMAYAAKMgAABqAgAwBgAAoyAAAGqFBNXa6WLVumsquuukpldqfJ2DXg2PHVE2XgeZ07d1ZZYGCgyuyaY+zYPXYP5rP7HrR582aVVa1aVWXHjx9X2QMPPKCyLVu2qMy1qaskXnrpJbfm2TWrRUdHO42feuopt64fFhamskcffdStdZiCO2QAAAxAQQYAwAAUZAAADEBBBgDAABW6qSs/P19lu3fvVllSUpLKBg4c6NZ7pKWlFX1hqJDsmrXs9qgdu8eNwnfk5eWp7P3331eZXQPX2bNnS2VNnrB3795Cs3/+859qzptvvqmy+++/X2Vz585V2Y4dO4qwwrLFHTIAAAagIAMAYAAKMgAABqAgAwBggArd1OWuatWqFfu1M2fOVNnjjz/uNOZRjiipqVOnlvcS4CG7du1Smd3pXT/++GNZLKfc7du3T2WTJk1S2Zo1a1Tm+r1WRGTQoEGeWFap4A4ZAAADUJABADAABRkAAANQkAEAMABNXW7o3bu3W/Pmz5+vsjvvvFNlQ4YMcRrT1AXgUipKA5e7srOz3ZpXs2bNUl6JZ3GHDACAASjIAAAYgIIMAIAB+AzZReXK+kvicDhUlpOTo7Lp06erzO6H2hMTE53Gzz//vJpj9xQU+JY6deo4jaOjo916XXp6uso+//xzj6wJ8AadOnUq7yWUCu6QAQAwAAUZAAADUJABADAABRkAAAPQ1OXilltuUVmNGjVUtmfPHpXt2LFDZQcPHlTZTTfd5DSeMmWKmtO3b99LLRM+ICsry2n88ccfqzkdOnRQWdOmTVXWvXt3ldntUXgnu6c9NWjQQGXr168vi+WUu3bt2rk1b+3ataW8Es/iDhkAAANQkAEAMAAFGQAAA1CQAQAwAE1dxeRu88SJEydUtmTJEqfxK6+8ouaEhYWpbP/+/W6uDhXNrbfeqrI5c+aUw0pQGqpXr66yTz75RGW9evVS2caNG0tjSWXmjjvuUNkDDzygMrsnQNl9jUzGHTIAAAagIAMAYAAKMgAABqAgAwBgAJq6XLRs2dKteZ5ssAoICFBZeHh4qb4nzLNr1y6VnT9/XmX+/v4qu/HGG1W2evVqlb333nsqy83NdRp/8MEHl1wnyt7OnTtVtnz5cpW9//77KrNr9Prmm288s7BSMGDAAKfxm2++qebYPf7Wrnntxx9/9Ni6ygJ3yAAAGICCDACAASjIAAAYgIIMAIABaOpysW/fPrfm2T3ububMmSqzO2GnX79+TuPMzEw1x9tP10HR2TXk1K9fX2UTJ05UWdWqVVVm9wjPm2++WWV5eXlO4zp16qg5c+fOVRnKzu+//66y4cOHq8zuEZ5ff/21yuwapV599VWnsacbv3r37q2yNm3aqOzFF190GlepUkXNsft/5d577y3B6szAHTIAAAagIAMAYAAKMgAABqAgAwBgAIdlWZZbEx2O0l6LEYKCglT21Vdfqczu8YgLFy5UWdOmTVXm2txgdzKS62k1pnNzG3lMRdmPdmJiYlT2+OOPq8z1MZ8i7v13stvvKSkpbq3NFBV1PzZp0kRln3/+ucquuOIKlbk2jp05c6bY67D7etSqVUtlfn5+Kjt69KjTePz48WrOu+++q7KzZ88WZYllyt39yB0yAAAGoCADAGAACjIAAAagIAMAYACautxgd9pWQkKCymJjY1VWs2ZNlS1btsxp/Mwzz6g5GRkZRVliuauoTTQwE/vxv9q1a6ey+Ph4lbk+8rVTp07Ffk+7r0dSUpLKdu/erbJ58+Y5jdPT04u9DlPQ1AUAgBehIAMAYAAKMgAABuAzZHgEn9nBJOxHmITPkAEA8CIUZAAADEBBBgDAABRkAAAMQEEGAMAAFGQAAAxAQQYAwAAUZAAADEBBBgDAABRkAAAMQEEGAMAAFGQAAAxAQQYAwAAUZAAADOD24xcBAEDp4Q4ZAAADUJABADAABRkAAANQkAEAMAAFWUQOHTokDodDXn75ZY9dc+PGjeJwOGTjxo0euyYqBvYjTMJ+LDteW5Dnz58vDodDtm3bVt5LKRWJiYnicDjUr8DAwPJeGmz4+n5csWKFREdHS5MmTSQgIEAuu+wyiY2Nld27d5f30mDD1/ejiMj69eulZ8+eUq9ePQkODpbIyEh59913y3tZJVK5vBeAS5s9e7ZUr169YOzn51eOq0FF9fXXX0vt2rUlPj5e6tWrJ7/++qu89dZbEhkZKVu2bJH27duX9xJRgaxatUpiYmLk2muvLbh5Wbp0qcTFxUlmZqaMGDGivJdYLBRkw8XGxkq9evXKexmo4MaPH6+yBx98UC677DKZPXu2zJkzpxxWhYpqxowZ0rhxY9mwYYMEBASIiMjQoUOlVatWMn/+fK8tyF77T9buyMnJkfHjx8s111wjtWrVkmrVqkn37t0lOTn5oq+ZOnWqhISESFBQkPTo0cP2n+T27t0rsbGxUqdOHQkMDJSIiAhZtWpVoes5ffq07N27VzIzM93+M1iWJdnZ2cL5Ld7PF/bjnzVo0ECqVq0qx48fL9brUb68eT9mZ2dL7dq1C4qxiEjly
pWlXr16EhQUVOjrTeXTBTk7O1vefPNNiYqKksmTJ0tiYqJkZGRIdHS07NixQ81/5513ZPr06TJs2DB5+umnZffu3dKrVy85cuRIwZw9e/ZIly5d5Ntvv5UxY8bIlClTpFq1ahITEyMrVqy45HpSU1OldevWMmPGDLf/DKGhoVKrVi2pUaOGDBw40Gkt8C6+sB+PHz8uGRkZ8vXXX8uDDz4o2dnZ0rt3b7dfD3N4836MioqSPXv2yLhx42T//v1y4MABmTBhgmzbtk1Gjx5d5K+FMSwv9fbbb1siYm3duvWic3Jzc61z5845Zb/99pvVsGFDa8iQIQXZwYMHLRGxgoKCrMOHDxfkKSkplohYI0aMKMh69+5thYeHW2fPni3I8vPzra5du1otWrQoyJKTky0RsZKTk1WWkJBQ6J9v2rRp1vDhw61FixZZSUlJVnx8vFW5cmWrRYsW1okTJwp9PcqWr+/HC/7yl79YImKJiFW9enVr7NixVl5entuvR9nw9f146tQpq3///pbD4SjYj1WrVrVWrlxZ6GtN5tN3yH5+flKlShUREcnPz5esrCzJzc2ViIgI2b59u5ofExMjTZs2LRhHRkZK586dZe3atSIikpWVJRs2bJD+/fvLyZMnJTMzUzIzM+XYsWMSHR0taWlpkp6eftH1REVFiWVZkpiYWOja4+Pj5bXXXpN77rlH7rjjDpk2bZosWLBA0tLSZNasWUX8SsAE3rwfL3j77bdl3bp1MmvWLGndurWcOXNG8vLy3H49zOHN+zEgIEBatmwpsbGxsnjxYlm4cKFERETIwIED5YsvvijiV8Ig5fwXgmJz52+AlmVZ8+fPt8LDwy1/f/+Cv0mJiHXFFVcUzLnwN8Dx48er1993331WQECAZVn//RvhpX5t377dsiz7vwF6QqNGjazevXt79JoouYq4H7OysqyGDRtaI0eO9Ng14Rm+vh+HDh1qtW/f3ulfZ3JycqwWLVpYkZGRxbqmCXy6y3rhwoUyaNAgiYmJkSeffFIaNGggfn5+MmnSJDlw4ECRr5efny8iIqNGjZLo6GjbOWFhYSVac2GaNWsmWVlZpfoeKB2+th9r164tvXr1kkWLFnn00AiUDW/djzk5OTJv3jwZPXq0VKr033/k9ff3lz59+siMGTMkJyen4O7fm/h0QU5KSpLQ0FBZvny5OByOgjwhIcF2flpamsr27dsnzZs3F5E/GqxE/vgPf8MNN3h+wYWwLEsOHTokHTt2LPP3Rsn52n4UETlz5oycOHGiXN4bJeOt+/HYsWOSm5tr+1HJ+fPnJT8/32s/RvH5z5BFxOlHhlJSUmTLli2281euXOn0GUdqaqqkpKRInz59ROSPH/OIioqSuXPnyi+//KJen5GRccn1FKWt3+5as2fPloyMDLnpppsKfT3M48378ejRoyo7dOiQ/Otf/5KIiIhCXw/zeOt+bNCggQQHB8uKFSskJyenID916pSsXr1aWrVq5bU/+uT1d8hvvfWWrFu3TuXx8fHSr18/Wb58udx2223St29fOXjwoMyZM0fatGkjp06dUq8JCwuTbt26yaOPPirnzp2TadOmSd26dZ3a6GfOnCndunWT8PBweeihhyQ0NFSOHDkiW7ZskcOHD8vOnTsvutbU1FTp2bOnJCQkFNq4EBISIgMGDJDw8HAJDAyUTZs2yfvvvy8dOnSQoUOHuv8FQpny1f0YHh4uvXv3lg4dOkjt2rUlLS1N5s2bJ+fPn5cXXnjB/S8QypQv7kc/Pz8ZNWqUjB07Vrp06SJxcXGSl5cn8+bNk8OHD8vChQuL9kUySfl+hF18F5oWLvbrp59+svLz862JEydaISEhVkBAgNWxY0drzZo11v3332+FhIQUXOtC08JLL71kTZkyxWrWrJkVEBBgde/e3dq5c6d67wMHDlhxcXFWo0aNLH9/f6tp06ZWv379rKSkpII5JW3rf/DBB602bdpYNWrUsPz9/a2wsDDrqaeesrKzs0vyZUMp8fX9mJCQYEVERFi1a9e2KleubDVp0sS66667rF27dpXky4ZS4uv70bIsa9GiRVZkZKQVHBxsBQUFWZ07d3Z6D2/ksCyOgAIAoLz59GfIAAB4CwoyAAAGoCADAGAACjIAAAagIAMAYAAKMgAABqAgAwBgALdP6vrzWaeAq7L+cXb2Iy6F/QiTuLsfuUMGAMAAFGQAAAxAQQYAwAAUZAAADEBBBgDAABRkAAAMQEEGAMAAFGQAAAxAQQYAwAAUZAAADEBBBgDAABRkAAAMQEEGAMAAFGQAAAxAQQYAwAAUZAAADEBBBgDAABRkAAAMULm8FwBUBA888IDK3njjDY9d3+FwqOzw4cMqmzhxospmz57tsXUAKD7ukAEAMAAFGQAAA1CQAQAwAAUZAAAD0NQFlIE+ffqozLIsj13f7lpNmjRR2fTp01XWoUMHp/HQoUM9ti6gqPz8/JzGHTt2VHOef/55ldnNu+eee1S2fv36EqyudHGHDACAASjIAAAYgIIMAIABKMgAABiApq4KrG7duk7jVq1aufW6zZs3l8ZyfNqXX36psmuuucZpfPnll5f6OipV0n8Hdz1FbM+ePWqOXTMYUFKuDYUiIgsWLHAaX3XVVcW+frNmzYr92vLAHTIAAAagIAMAYAAKMgAABqAgAwBgAJq6fFCNGjVUNmTIEJWNHDnSady0aVO3ru96kg4KN2nSJJV98MEHTuNGjRp59D1feeUVlbk2konoRze6NvsBnjB48GCVPfPMMyq78sorncb/+c9/1JzTp0+r7OzZsypbuHBhUZZY7rhDBgDAABRkAAAMQEEGAMAAFGQAAAxAU5cXadu2rcoee+wxlUVHR6usuCfW/PDDD8V6HQq3f//+S46Lwu6/77lz54p9PcBddk2e48ePV9kTTzyhssqVdQl6+umnncavv/66mrNx40aVNWzYUGVVqlRR2fnz51VmCu6QAQAwAAUZAAADUJABADAABRkAAAPQ1GUAu8ceDho0SGUPPfSQyoKDg1U2b948lS1btkxlgYGBTuPevXurOX//+99VhrJj15QyYcIEldmdguTuiVu///6703j16tVurg4QeeSRR1Rm19S1detWldmdIPjdd985jZcvX67mhIeHq+yNN95QmeveNh13yAAAGICCDACAASjIAAAYgIIMAIABaOoqZS1btlTZmDFjnMZ33nmnmlO1alWVLVq0SGV2DQ8rV64swgr/68MPPyzW61B6VqxYobKbbrrJo+/h+gi8bdu2efT68B133323yl544QWV/etf/1LZAw88oDK7kwATExOdxv369VNzsrOzVbZkyRKVeRvukAEAMAAFGQAAA1CQAQAwQIX+DNnu6SDPP/+8ys6cOaOyhIQEld1zzz0qc/18TkR/Pvzuu++qOXafDa9fv15l8A59+/ZVmevTu5577jk1x9/fv9jvmZGRobI5c+aobNasWcV+D1QsdocY2X2e+/jjj6vM7vPiW265RWWuPTY5OTlqTmxsrMrsPrf2NtwhAwBgAAoyAAAGoCADAGAACjIA
AAaoUE1drk+/WbdunZrTvn17lVmWpbJbb71VZfXr11eZ3SELzz77rNOYZi3v1bRpU5Xdd999Khs3bpzKXJ+25a5Dhw6pbMCAASr78ccfVXb06NFivScgIlKnTh2VJSUlqWz37t0qi46OVtmUKVMKfU+7xthPP/200Nd5I+6QAQAwAAUZAAADUJABADAABRkAAANUqKau0aNHO42vuuqqYl/L4XCo7Pbbb1fZP//5z2K/B8zXsWNHldmd9lZcmzZtUtmLL76oMp7QhLJw8uRJlT388MMqs2v+snuqnZ3Bgwc7jRcvXuzm6rwfd8gAABiAggwAgAEoyAAAGICCDACAAXy2qatPnz4qe+KJJ5zGx48fV3Nq166tsm+++UZld911l8r27NlThBXCF9g9eu73339XWbVq1Yp1/W7duqksLCxMZXZ7Lz09XWWvvvpqofPsHtsIiNh/z3Q9AVFEZODAgW5db9CgQSqrSE1crrhDBgDAABRkAAAMQEEGAMAAFGQAAAzgs01dMTExKqtUyfnvH8uXL1dz3n77bZV99dVXKjtz5kzxFwef8dlnn6nM7sS2ESNGqMz1pLgmTZq49Z6NGjVyK7MTFxensu3btzuN9+3bp+aMHDlSZb/++qtb7wmIiPzjH/9QWUVu4LLDHTIAAAagIAMAYAAKMgAABqAgAwBgAJ9t6tq/f7/KXB+Z2L17dzXnoYceKrU1oWJYv369W1mVKlWcxkOHDlVzJk+erLKAgIASrE67+uqrLzkWEWnXrp3KXnvtNZW99dZbKsvPzy/B6lBemjdvrrIHH3yw2Nf7+uuvVZaTk1Ps6/ki7pABADAABRkAAANQkAEAMAAFGQAAA/hsU9fu3btVlpeX5zRu0aKFmhMREaGybdu2eW5hwP9zbWixa5LaunWryuweLWrH7vSu4jbl2DV1zZ07V2X169dX2aRJk4r1nihbrk2GL7/8sppj9z3z448/Vtn111+vsmuvvVZlS5cuLcoSfR53yAAAGICCDACAASjIAAAYwGFZluXWRJdDNbzRCy+84DR+8skn1ZzffvtNZe3bt1dZenq65xbmA9zcRh5jyn60+zy3R48eKnv99ddV9v3335fKmi7w8/NTWdWqVVV26623Oo2HDx+u5nTq1Mmt93Tt0xARueWWW1Rm97mjJ1XU/VgSTz31lNPY9fuliMj8+fNVZnegzXPPPacyu/9X7L63+iJ39yN3yAAAGICCDACAASjIAAAYgIIMAIABfPZgEDsvvvii0/jOO+9Uc6644gqVVa9evdTWBO9m1zhl1yyYkpKistJu6rJrsDp58qTKFi5c6DS2e1Lahg0bVGb31Cm7r0flyhXq24xXuOGGG1Q2YcIEp3FaWpqaM2bMGJXZPbHp9OnTJVhdxcUdMgAABqAgAwBgAAoyAAAGoCADAGCACtVtkZWV5TT++eef1Ry7pi6gpOxOPXJ9cs6SJUvUnB9++MGt619zzTUqs2tGrFWrlsqefvppp3HTpk3VHLsGLjtfffWVyjZv3uzWa1E6mjdvrjK7k+NcT5MaOXKkmnPkyBGV2T0J7JFHHlHZJ598cqllQrhDBgDACBRkAAAMQEEGAMAAFGQAAAxQoZq6WrVq5TS+7rrr1BxfeIways7WrVtVtnbtWpV1795dZZMmTXIaDx48WM2xOzXLTrdu3VRWs2ZNt15bXHanfrmehicicvz48VJdBy7tvvvuU5ld86rrIxNXr17t1vUffvhhlTVu3Fhl//nPf9y6XkXGHTIAAAagIAMAYAAKMgAABqAgAwBgAIflejzLxSYa0uw0evRolZ09e1Zl7733nspmzJjhNLZ7/KLdCTbx8fEqs3vkWEXm5jbyGFP2o7tiY2NVtmDBAqdxYGBgWS3nkjIyMlT2xRdfqOzVV19VWXJycqmsqajYj/81f/58lfXt21dljRo1chq3bt1azXnjjTdUFhkZqTK777/333+/yvLz81Xmi9zdj9whAwBgAAoyAAAGoCADAGAACjIAAAbwuqaugwcPquzyyy8v1rXsHm0XGhparGtVdDTRFN0tt9ziNF65cmWxrzVt2jSV/fLLL269Ni8vz2k8derUYq/DFOzH/3rooYdUZte8euDAAadxSEiImmPXzDpx4kSVTZ48WWW5ubmXXKcvo6kLAAAvQkEGAMAAFGQAAAxAQQYAwABe9/jF8PBwlQ0ZMkRlEyZMUJlrk8vNN9/suYUBReT6eDs/P79yWgl82eHDh92ad+WVVzqNd+3apebYnba1Y8eOYq0LGnfIAAAYgIIMAIABKMgAABjA6w4GcVerVq1UdvLkSadxenp6WS3H53EQA0zCfoRJOBgEAAAvQkEGAMAAFGQAAAxAQQYAwAA+29SFskUTDUzCfoRJaOoCAMCLUJABADAABRkAAANQkAEAMAAFGQAAA1CQAQAwAAUZAAADUJABADAABRkAAAO4fVIXAAAoPdwhAwBgAAoyAAAGoCADAGAACjIAAAagIIvIoUOHxOFwyMsvv+yxa27cuFEcDods3LjRY9dExcB+hEnYj2XHawvy/PnzxeFwyLZt28p7KaUmPT1d+vfvL8HBwVKzZk259dZb5fvvvy/vZcFGRdiP69evl549e0q9evUkODhYIiMj5d133y3vZcGGr+/H5s2bi8PhsP3VokWL8l5esVUu7wXA3qlTp6Rnz55y4sQJ+fvf/y7+/v4ydepU6dGjh+zYsUPq1q1b3ktEBbJq1SqJiYmRa6+9VhITE8XhcMjSpUslLi5OMjMzZcSIEeW9RFQg06ZNk1OnTjllP/zwg4wdO1ZuvPHGclpVyVGQDTVr1ixJS0uT1NRU6dSpk4iI9OnTR9q1aydTpkyRiRMnlvMKUZHMmDFDGjduLBs2bJCAgAARERk6dKi0atVK5s+fT0FGmYqJiVHZP/7xDxERuffee8t4NZ7jtf9k7Y6cnBwZP368XHPNNVKrVi2pVq2adO/eXZKTky/6mqlTp0pISIgEBQVJjx49ZPfu3WrO3r17JTY2VurUqSOBgYESEREhq1atKnQ9p0+flr1790pmZmahc5OSkqRTp04FxVhEpFWrVtK7d29ZunRpoa+Hebx5P2ZnZ0vt2rULirGISOXKlaVevXoSFBRU6OthHm/ej3bee+89ueKKK6Rr167Fer0JfLogZ2dny5tvvilRUVEyefJkSUxMlIyMDImOjpYdO3ao+e+8845Mnz5dhg0bJk8//bTs3r1bevXqJUeOHCmYs2fPHunSpYt8++23MmbMGJkyZYpUq1ZNYmJiZMWKFZdcT2pqqrRu3VpmzJhxyXn5+fmya9cuiYiIUL8XGRkpBw4ckJMnT7r3RYAxvHU/iohERUXJnj17ZNy4cbJ//345cOCATJgwQbZt2yajR48u8tcC5c+b96Orr776Sr799lu55557ivxao1he6u2337ZExNq6detF5+Tm5lrnzp1zyn777TerYcOG1pAhQwqygwcPWiJiBQUFWYcPHy7IU1JSLBGxRowYUZD17t3bCg8Pt86ePVuQ5efnW127drVatGhRkCUnJ1s
iYiUnJ6ssISHhkn+2jIwMS0Ss5557Tv3ezJkzLRGx9u7de8lroGz58n60LMs6deqU1b9/f8vhcFgiYomIVbVqVWvlypWFvhZlz9f3o6uRI0daImJ98803RX6tSXz6DtnPz0+qVKkiIn/cdWZlZUlubq5ERETI9u3b1fyYmBhp2rRpwTgyMlI6d+4sa9euFRGRrKws2bBhg/Tv319OnjwpmZmZkpmZKceOHZPo6GhJS0uT9PT0i64nKipKLMuSxMTES677zJkzIiJO/zx4QWBgoNMceA9v3Y8if+zFli1bSmxsrCxevFgWLlwoERERMnDgQPniiy+K+JWACbx5P/5Zfn6+vP/++9KxY0dp3bp1kV5rGp9v6lqwYIFMmTJF9u7dK+fPny/Ir7jiCjXXrl2+ZcuWBZ/Z7t+/XyzLknHjxsm4ceNs3+/o0aNOm7Y4Lnwmd+7cOfV7Z8+edZoD7+KN+1FEZPjw4fLFF1/I9u3bpVKlP/4e379/f2nbtq3Ex8dLSkpKid8DZc9b9+Of/fvf/5b09HSfaCz06YK8cOFCGTRokMTExMiTTz4pDRo0ED8/P5k0aZIcOHCgyNfLz88XEZFRo0ZJdHS07ZywsLASrVlEpE6dOhIQECC//PKL+r0LWZMmTUr8Pihb3rofc3JyZN68eTJ69OiCYiwi4u/vL3369JEZM2ZITk5Owd0WvIO37kdXixYtkkqVKsndd9/t8WuXNZ8uyElJSRIaGirLly8Xh8NRkCckJNjOT0tLU9m+ffukefPmIiISGhoqIn98I7rhhhs8v+D/V6lSJQkPD7f9of6UlBQJDQ2VGjVqlNr7o3R46348duyY5ObmSl5envq98+fPS35+vu3vwWzeuh//7Ny5c7Js2TKJioryiZsUn/8MWUTE+tMjn1NSUmTLli2281euXOn0GUdqaqqkpKRInz59RESkQYMGEhUVJXPnzrW9e83IyLjkeorS1h8bGytbt251KsrfffedbNiwQe68885CXw/zeOt+bNCggQQHB8uKFSskJyenID916pSsXr1aWrVqxUcoXshb9+OfrV27Vo4fP+7VP3v8Z15/h/zWW2/JunXrVB4fHy/9+vWT5cuXy2233SZ9+/aVgwcPypw5c6RNmzbqlBeRP/45pVu3bvLoo4/KuXPnZNq0aVK3bl2nH+uYOXOmdOvWTcLDw+Whhx6S0NBQOXLkiGzZskUOHz4sO3fuvOhaU1NTpWfPnpKQkFBo48Lf/vY3eeONN6Rv374yatQo8ff3l1deeUUaNmwoI0eOdP8LhDLli/vRz89PRo0aJWPHjpUuXbpIXFyc5OXlybx58+Tw4cOycOHCon2RUGZ8cT/+2aJFiyQgIEDuuOMOt+Ybr9z6u0voQlv/xX799NNPVn5+vjVx4kQrJCTECggIsDp27GitWbPGuv/++62QkJCCa11o63/ppZesKVOmWM2aNbMCAgKs7t27Wzt37lTvfeDAASsuLs5q1KiR5e/vbzVt2tTq16+flZSUVDDHE239P/30kxUbG2vVrFnTql69utWvXz8rLS2tuF8ylKKKsB8XLVpkRUZGWsHBwVZQUJDVuXNnp/eAOSrCfjxx4oQVGBho3X777cX9MhnHYVl/+vcKAABQLnz6M2QAALwFBRkAAANQkAEAMAAFGQAAA1CQAQAwAAUZAAADUJABADCA2yd1/fmsU8BVWf84O/sRl8J+hEnc3Y/cIQMAYAAKMgAABqAgAwBgAAoyAAAGoCADAGAACjIAAAagIAMAYAAKMgAABqAgAwBgAAoyAAAGoCADAGAACjIAAAagIAMAYAAKMgAABqAgAwBgAAoyAAAGoCADAGAACjIAAAaoXN4LAHBxdevWVdl9992nsttvv11l3bt3V1lSUpLKPvvsM6fxa6+9VpQlAvAQ7pABADAABRkAAANQkAEAMAAFGQAAAzgsy7LcmuhwlPZa4MXc3EYeU1H249ixY1X27LPPquzs2bMqW7JkicoGDhyosjNnzjiNQ0JC1Jzjx49fapnGYT/CJO7uR+6QAQAwAAUZAAADUJABADAABRkAAAPQ1FXKwsLCVPbOO++U6nvefPPNKvv9999Vdv78eY+9J000JdeuXTuVffzxxypr1KiRyuwavZ577jmVffPNNyr7y1/+4jSePn26mtO8eXOVPf744yr74YcfVFYeKsJ+vP7661Vm99/u559/VtmyZctUNm/ePM8sDApNXQAAeBEKMgAABqAgAwBgAAoyAAAGoKnLgzp06KCyTz75RGV2j9QrbXanNj322GNO48zMzGJfvyI00ZS2CRMmqOzvf/+7ytLS0lQWHh6uMrumvWHDhqls4sSJTuPq1atfcp0XJCcnq+yGG25w67WlrSLsx/T0dJXZNfy567fffnMaf/DBB2qOXTOY6+tERL788stir8MX0dQFAIAXoSADAGAACjIAAAagIAMAYIDK5b0AX2J3clF5NHDZGTBggMreeOMNp7Fdkw7KTmBgoFvz8vLyVObuqWszZ85UmWtD0uTJk91a2/r16916T5SOTZs2qax9+/Yq8/PzU1loaKjK6tSp4zR++OGH1Ry7zG4/Hjx4UGV2jW9bt25VmWvT4o033qjmnDx5UmVPPvmkynbt2qUyk3GHDACAASjIAAAYgIIMAIABKMgAABjAZ5u6qlWrprL69es7je2asOyaBexOrGnbtq3K+vXr59bazp496zS2ezTi/v37Vfbpp5+qbPTo0SqrUqWKW+uAWTIyMlRm1wjj6VOhFi9e7DS2OzEsKChIZcuXL/foOlA0do2adipX1t/mr776apWNGjXKaWx36lqtWrVUZtc0ZvfYWbt9e+WVV6qsuFzXLyISFxfnseuXBe6QAQAwAAUZAAADUJABADAABRkAAAP4RFNXkyZNVDZ37lyV3XzzzcW6vt0j8Ny1evVqla1Zs8Zp/Oabbxb7+nbNGcX9c8I8do9t8/SjBV955RWncc2aNdWcpKQkle3bt8+j60DpyM3NVVlqaqrK+vfv7zQODg5Wc7p06aKy6OholdmdUGjX1BUbG6uy4jalmnIqYklwhwwAgAEoyAAAGICCDACAAbzuM2S7zxdcn1okInLTTTeVxXKczJ8/X2WPPfaYyuwOAikP8fHxTmOe9uQdatSooTLXJ/WIiGRlZamsXbt2KrM7AMIVh4BUPMePH1fZunXr3Mrcdd9996ksMTHRaTx+/Hi3rjVs2LBir8MU3CEDAGAACjIAAAagIAMAYAAKMgAABjC6qWvEiBEqe/TRR1XmySeG2LH7wfoFCxaozO5pI55s4LJr5gkMDCz29ewOe0D5WbRokcomTZqksqZNm6rs2WefVZndgTOuh9KIiDRq1Mhp/N5776k5y5YtUxlQGq655hqnsd1BOHZP4Dt06FBpLanMcIcMAIABKMgAABiAggwAgAEoyAAAGMCYpq7GjRurLCoqSmWl3cBlx66B6+GHHy7zdTz44IMq69WrV7Gvl56eXpLlwMOOHj2qsk2bNqmsW7duKrvjjjtU1r59e5XZPRnthx9+cBonJCSoOe
fPn1cZUFJ2J3X9z//8T6Gv+/LLL4v9nnanPV522WUq+/7774v9HsXFHTIAAAagIAMAYAAKMgAABqAgAwBgAGOautq0aaOyfv36lfk67B6haHcCV2m7/PLLVXbXXXeV+TpQduwap1555RWVde/eXWWup22J2DdK5uTkqCwuLs5pXB7NLCi6iRMnqiwiIkJldif82Z1qtWXLFqex3QlZJeFwOFR22223qczf37/Qa/Xs2VNl586dU5lds2PNmjVVVq9ePZU1a9as0HV4GnfIAAAYgIIMAIABKMgAABiAggwAgAEclpuf3Nt9IO9JGRkZKqtTp45H3+Prr792Gvft21fNOX78uMo8+QhFd9k1tH344YfFvt706dNVNmbMGKexXVOEuzzdAFKY0t6PprD7f2D//v0qq1WrlsrsvkY33nijytavX1/M1ZmrIuxHT3/PdP0zlEVTlyff4+eff1ZZcnKyylzrgIjI6tWrVbZ3717PLEzc/3NyhwwAgAEoyAAAGICCDACAASjIAAAYwJiTuuxOSsnPzy/29TZv3qyyu+++22lcXo8fbNu2rdPY7uSll156qdjX//XXX1W2YcMGlZWkiQtl495771WZXQOXu3766aeSLAcGcT1ZS8T90w3tvvcdO3bMaexuI5JdM5XdiWENGjRw63qu7L6Xv/rqqypLSkoq1vVNwh0yAAAGoCADAGAACjIAAAagIAMAYABjmro8rXr16iqrXLl0/7jDhw9X2fXXX6+y0NBQp3HHjh09uo7Bgwer7JNPPvHoe8DzoqOjVfbiiy+qzO4RillZWSqze/ziDTfcoLLvvvvO3SXCIHaPYw0MDHTrtXZ7yO7xn+6waw61e+Tjjh07VNa8eXOVpaamOo3tHrWYm5vr/gK9CHfIAAAYgIIMAIABKMgAABjAmM+Q7Z5g4/pZa1G0b99eZa6fYeTl5RX7+naqVq2qsoCAAI9d/+jRoyqz+6H8f//73x57T5Sehg0bOo1feOEFNadKlSoqe+SRR1T27bffquyzzz5T2cSJE1Xm+qSbH3/8US8Wxjl9+rRbWXmIj49X2RVXXKEyu8NHli5d6jT21c+L7XCHDACAASjIAAAYgIIMAIABKMgAABjAmKYuu0YVTx9mUbNmTY9erzTNmjVLZZ9++qnKVq1aVRbLQSl4/vnnncZ2jYivv/66yt544w23ru9wOFRmd2BDhw4dnMY0daEounXrprInnnjCrdcuX75cZbNnzy7xmrwVd8gAABiAggwAgAEoyAAAGICCDACAAYxp6tq5c6fKXE9sERHp379/WSynVK1du9ZpbNfAtX79epUV92ksKH/33nuvyu644w6ncXp6upozatQot65v95Qfu1OQ7DLAXXaNsW+99ZbKatWqpTK7/f3000+r7OzZs8VcnffjDhkAAANQkAEAMAAFGQAAA1CQAQAwgDFNXZmZmSq77777VGZ3oteaNWtKZU1F9eKLL6rM7nQt18c+0qzl++z2rWvji11zzODBg1XWq1cvlV1++eUlWB3gnmHDhqksLCzMrdfaNXDZPXa3IuMOGQAAA1CQAQAwAAUZAAADUJABADCAw3Lz6B67R7kBF5T1CVDeth8///xzlV133XVO4zNnzqg5didwuevXX39Vmd2j7VybEXNycor9nqZgP5ac3aM/Bw0apDI/Pz+Vbdu2TWXXX3+9ys6dO1e8xXkZd/cjd8gAABiAggwAgAEoyAAAGICCDACAAWjqgkfQRHNp7dq1U1nfvn2dxrfddpua06lTJ5Vt3bpVZV9//bXKxo4dq7IjR45ccp2+gv1YdD169HAa250yaNfA5XryoIj9Xv7oo49KsDrvRlMXAABehIIMAIABKMgAABiAggwAgAFo6oJH0EQDk7Afi65Pnz5OY3cfa5uYmKiyCRMmeGJJPoOmLgAAvAgFGQAAA1CQAQAwQOXyXgAAoPy5PqHp5MmTak5ycrLK3nvvvVJbU0XDHTIAAAagIAMAYAAKMgAABqAgAwBgAA4GgUdwEANMwn6ESTgYBAAAL0JBBgDAABRkAAAMQEEGAMAAFGQAAAxAQQYAwAAUZAAADEBBBgDAABRkAAAM4PZJXQAAoPRwhwwAgAEoyAAAGICCDACAASjIAAAYgIIMAIABKMgAABiAggwAgAEoyAAAGICCDACAAf4PkLEsNK/INnsAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "# plot a 3x3 grid of MNIST digits\n", + "idxs = np.random.randint(0, len(X_train), size=(3, 3))\n", + "fig, axes = plt.subplots(3, 3, figsize=(3 * 2, 3 * 2))\n", + "\n", + "for i in range(3):\n", + " for j in range(3):\n", + " axes[i, j].imshow(X_train[idxs[i, j]], cmap=\"gray\")\n", + " axes[i, j].axis(\"off\")\n", + " axes[i, j].set_title(f\"Label: {y_train[idxs[i, j]]}\")\n", + "\n", + "plt.show()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Defining the Model\n", + "\n", + "To create a convolutional neural network using NNX define a `nnx.Module` subclass. We define the model by subclassing `nnx.Module` and defining a `forward` method that returns the model output. Like in PyTorch, the `__init__` method instantiates all the modules that will be used in the model. The `__call__` in this case\n", + "will define the forward computation. " + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" + ] + }, + { + "data": { + "text/plain": [ + "(1, 10)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "from flax.experimental import nnx\n", + "\n", + "\n", + "class CNN(nnx.Module):\n", + "\n", + " def __init__(self, *, rngs: nnx.Rngs):\n", + " self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)\n", + " self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)\n", + " self.linear1 = nnx.Linear(7 * 7 * 64, 256, rngs=rngs)\n", + " self.linear2 = nnx.Linear(256, 10, rngs=rngs)\n", + " self.num_calls = nnx.var(\"counts\", 0)\n", + "\n", + " def __call__(self, x: jax.Array) -> jax.Array:\n", + " self.num_calls += 1\n", + " x = self.conv1(x)\n", + " x = nnx.relu(x)\n", + " x = nnx.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n", + " x = self.conv2(x)\n", + " x = nnx.relu(x)\n", + " x = nnx.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n", + " x = x.reshape((x.shape[0], -1)) # flatten\n", + " x = self.linear1(x)\n", + " x = nnx.relu(x)\n", + " x = self.linear2(x)\n", + " return x\n", + "\n", + "\n", + "model = CNN(rngs=nnx.Rngs(0))\n", + "\n", + "y = model(X_train[:1])\n", + "y.shape" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "One notable difference with other frameworks is that `__init__`, by convention, accepts a `rngs: nnx.Rngs` keyword-only argument. This object is passed around to generate PRNG keys as random state is explicit in JAX.\n", + "\n", + "One of the nice things about NNX is that Module contain their own state, are fully inspectable, and you can run them eargerly. For example, we can easily check out the kernel shape of the first `Conv` layer:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(3, 3, 1, 32)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.conv1.kernel.shape" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also view the entire `State` of the model using the `.filter()` method. TODO: talk about collections." 
+ { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also view the entire `State` of the model using the `.extract()` method. TODO: talk about collections." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "State({\n", + "  'conv1/bias': Variable(\n", + "    collection='params',\n", + "    value=(32,)\n", + "  ),\n", + "  'conv1/kernel': Variable(\n", + "    collection='params',\n", + "    value=(3, 3, 1, 32)\n", + "  ),\n", + "  'conv2/bias': Variable(\n", + "    collection='params',\n", + "    value=(64,)\n", + "  ),\n", + "  'conv2/kernel': Variable(\n", + "    collection='params',\n", + "    value=(3, 3, 32, 64)\n", + "  ),\n", + "  'linear1/bias': Variable(\n", + "    collection='params',\n", + "    value=(256,)\n", + "  ),\n", + "  'linear1/kernel': Variable(\n", + "    collection='params',\n", + "    value=(3136, 256)\n", + "  ),\n", + "  'linear2/bias': Variable(\n", + "    collection='params',\n", + "    value=(10,)\n", + "  ),\n", + "  'linear2/kernel': Variable(\n", + "    collection='params',\n", + "    value=(256, 10)\n", + "  )\n", + "})" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "jax.tree_map(jnp.shape, model.extract(nnx.Param))" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training in eager mode\n", + "\n", + "For pedagogical purposes, we first train the model in eager mode. This is useful for taking a look at some of NNX's features; it's more approachable for new users and great for debugging, but it is not the recommended way to train models in JAX.\n", + "\n", + "Here we will run a simple `for` loop for just 10 iterations. At each step we will sample a batch of data, define a `loss_fn` to compute the loss, and use `nnx.value_and_grad` to compute the gradients of the loss with respect to the model parameters. Using the gradients, we will update the parameters with stochastic gradient descent (SGD) via a simple `tree_map` operation. Finally, we will write the new parameters back into the model using the `.update` method." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step 0: loss=58.7676\n", + "Step 1: loss=80.0420\n", + "Step 2: loss=108.3005\n", + "Step 3: loss=26.6188\n", + "Step 4: loss=10.7236\n", + "Step 5: loss=4.7499\n", + "Step 6: loss=3.9177\n", + "Step 7: loss=2.9419\n", + "Step 8: loss=2.4733\n", + "Step 9: loss=1.8060\n" + ] + } + ], + "source": [ + "import optax\n", + "\n", + "for step in range(10):\n", + "    idxs = np.random.randint(0, len(X_train), size=32)\n", + "    x = jnp.array(X_train[idxs])\n", + "    y = jnp.array(y_train[idxs])\n", + "\n", + "    def loss_fn(model: CNN):\n", + "        logits = model(x)\n", + "        return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()\n", + "\n", + "    loss, grads = nnx.value_and_grad(loss_fn, wrt=\"params\")(model)\n", + "    params = model.extract(\"params\")\n", + "    params = jax.tree_map(lambda w, g: w - 0.001 * g, params, grads)\n", + "\n", + "    model.update(params)\n", + "    print(f\"Step {step}: loss={loss:.4f}\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The loss is going down 🎉.\n", + "\n", + "### Training with the Functional API\n", + "\n", + "Now that we have a working model, let's see how to train it with `jax.jit` using NNX's Functional API. 
The `Module.split` method allows you to convert a Module into pytrees with functional semantics, which lets you integrate with JAX's functional APIs like `jax.jit` and `jax.grad`.\n", + "\n", + "In this next example we will use the `.split` method to split the model into `params: State` and `moduledef: ModuleDef` objects. We pass the `\"params\"` filter to check that the Module's state only contains `Variable`s from the `params` collection. Having `params` and `moduledef`, it's pretty easy to implement a jitted `train_step` much like you would in Flax or Haiku. `ModuleDef` exposes an `apply` method which accepts some `State` and creates a function that runs the Module's `__call__` method. This function then returns the output of the Module along with the updated state." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "params, moduledef = model.split(\"params\")\n", + "\n", + "\n", + "@jax.jit\n", + "def train_step(params: nnx.State, x, y):\n", + " def loss_fn(params):\n", + " logits, _updates = moduledef.apply(params)(x)\n", + " return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()\n", + "\n", + " loss, grads = jax.value_and_grad(loss_fn)(params)\n", + " params = jax.tree_map(lambda w, g: w - 0.001 * g, params, grads)\n", + "\n", + " return loss, params" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Using `train_step` we can run a few more iterations and see that the loss is still going down; this time, however, execution should be much faster." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step 0: loss=1.4396\n", + "Step 1: loss=1.4127\n", + "Step 2: loss=1.8718\n", + "Step 3: loss=1.7080\n", + "Step 4: loss=1.7984\n", + "Step 5: loss=1.0350\n", + "Step 6: loss=1.2076\n", + "Step 7: loss=0.9081\n", + "Step 8: loss=0.8217\n", + "Step 9: loss=0.6687\n" + ] + } + ], + "source": [ + "for step in range(10):\n", + " idxs = np.random.randint(0, len(X_train), size=32)\n", + " x = jnp.array(X_train[idxs])\n", + " y = jnp.array(y_train[idxs])\n", + "\n", + " loss, params = train_step(params, x, y)\n", + " print(f\"Step {step}: loss={loss:.4f}\")\n", + "\n", + "model.update(params)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Realistic Training using TrainState\n", + "\n", + "For real training scenarios, we recommend using `TrainState` to manage the state of your training loop. `TrainState` manages the `params` of your network along with other types of state, and uses `optax` to update the parameters according to the gradients.\n", + "\n", + "Next, we will define a `train_step` function that accepts a `TrainState` and a batch of data and uses the `apply_gradients` method to return a new `TrainState` with updated parameters. Flax users should be familiar with this API. We will also define an `eval_step` function that evaluates the model on the test set and returns some metrics."
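+ ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Before wiring this up, here is a hedged sketch of the same jitted `train_step` with the manual `tree_map` update replaced by an explicit `optax` optimizer; `optax.sgd(0.001)` produces exactly the `w - 0.001 * g` update used above (a sketch only, not executed):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tx = optax.sgd(0.001) # same rule as the tree_map update above\n", + "opt_state = tx.init(params)\n", + "\n", + "\n", + "def optax_train_step(params, opt_state, x, y):\n", + " def loss_fn(params):\n", + " logits, _updates = moduledef.apply(params)(x)\n", + " return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()\n", + "\n", + " loss, grads = jax.value_and_grad(loss_fn)(params)\n", + " updates, opt_state = tx.update(grads, opt_state)\n", + " params = optax.apply_updates(params, updates)\n", + " return loss, params, opt_state"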
+ ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "state = nnx.TrainState(\n", + " apply_fn=moduledef.apply,\n", + " params=params,\n", + " tx=optax.adam(0.001),\n", + ")\n", + "\n", + "\n", + "@jax.jit\n", + "def train_step(state: nnx.TrainState, x, y):\n", + " def loss_fn(params):\n", + " logits, _updates = state.apply_fn(params)(x)\n", + " return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()\n", + "\n", + " grads = jax.grad(loss_fn)(state.params)\n", + "\n", + " state = state.apply_gradients(grads=grads)\n", + "\n", + " return state\n", + "\n", + "\n", + "@jax.jit\n", + "def eval_step(state: nnx.TrainState, x, y):\n", + " logits, _updates = state.apply_fn(state.params)(x)\n", + " metrics = {\n", + " 'accuracy': jnp.mean(jnp.argmax(logits, axis=-1) == y),\n", + " 'loss': optax.softmax_cross_entropy_with_integer_labels(logits, y).mean(),\n", + " }\n", + " return metrics" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now let's create a simple training loop that runs for 1000 iterations and prints the metrics every 100 steps. At the end of training we will compute the final metrics." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step 0: {'accuracy': Array(0.63119996, dtype=float32), 'loss': Array(1.1837534, dtype=float32)}\n", + "Step 100: {'accuracy': Array(0.9492, dtype=float32), 'loss': Array(0.16359854, dtype=float32)}\n", + "Step 200: {'accuracy': Array(0.9564, dtype=float32), 'loss': Array(0.14198248, dtype=float32)}\n", + "Step 300: {'accuracy': Array(0.96279997, dtype=float32), 'loss': Array(0.12757339, dtype=float32)}\n", + "Step 400: {'accuracy': Array(0.97169995, dtype=float32), 'loss': Array(0.09900841, dtype=float32)}\n", + "Step 500: {'accuracy': Array(0.96889997, dtype=float32), 'loss': Array(0.10143881, dtype=float32)}\n", + "Step 600: {'accuracy': Array(0.9745, dtype=float32), 'loss': Array(0.08513925, dtype=float32)}\n", + "Step 700: {'accuracy': Array(0.96379995, dtype=float32), 'loss': Array(0.11632324, dtype=float32)}\n", + "Step 800: {'accuracy': Array(0.97679996, dtype=float32), 'loss': Array(0.07204168, dtype=float32)}\n", + "Step 900: {'accuracy': Array(0.9765, dtype=float32), 'loss': Array(0.08413408, dtype=float32)}\n", + "Final metrics: {'accuracy': Array(0.9819, dtype=float32), 'loss': Array(0.05711861, dtype=float32)}\n" + ] + } + ], + "source": [ + "total_steps = 1000\n", + "eval_every = 100\n", + "\n", + "for step in range(total_steps):\n", + " if step % eval_every == 0:\n", + " metrics = eval_step(state, jnp.array(X_test), jnp.array(y_test))\n", + " print(f\"Step {step}: {metrics}\")\n", + "\n", + " idxs = np.random.randint(0, len(X_train), size=32)\n", + " x = jnp.array(X_train[idxs])\n", + " y = jnp.array(y_train[idxs])\n", + "\n", + " state = train_step(state, x, y)\n", + "\n", + "metrics = eval_step(state, jnp.array(X_test), jnp.array(y_test))\n", + "print(f\"Final metrics: {metrics}\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Inference\n", + "\n", + "Finally, now that we have a trained model, let's use it to make some predictions. We will update the `model` object with the trained parameters and use it to make predictions on the test set."
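+ ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As a quick sketch (reusing `state` and the test arrays from the cells above), the trained parameters can be merged back into `model` and evaluated eagerly:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# sketch: eager evaluation of the trained model (not executed here)\n", + "model.update(state.params)\n", + "test_logits = model(jnp.array(X_test))\n", + "test_accuracy = jnp.mean(jnp.argmax(test_logits, axis=-1) == jnp.array(y_test))"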
+ ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "<base64-encoded PNG of a 3x3 grid of MNIST test digits with predicted labels omitted>", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "model.update(state.params)\n", + "\n", + "# plot a 3x3 grid of MNIST digits\n", + "idxs = np.random.randint(0, len(X_test), size=(3, 3))\n", + "fig, axes = plt.subplots(3, 3, figsize=(3 * 2, 3 * 2))\n", + "\n", + "for i in range(3):\n", + " for j in range(3):\n", + " logits = model(jnp.array([X_test[idxs[i, j]]]))\n", + " axes[i, j].imshow(X_test[idxs[i, j]], cmap=\"gray\")\n", + " axes[i, j].axis(\"off\")\n", + " axes[i, j].set_title(f\"Prediction: {jnp.argmax(logits)}\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Awesome! We hope you've enjoyed this tutorial and learned the basics of NNX." + ] + } + ], + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/flax/experimental/nnx/docs/tiny_nnx.ipynb b/flax/experimental/nnx/docs/tiny_nnx.ipynb new file mode 100644 index 0000000000..05d35fd34a --- /dev/null +++ b/flax/experimental/nnx/docs/tiny_nnx.ipynb @@ -0,0 +1,465 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Tiny NNX\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/cgarciae/nnx/blob/main/docs/tiny_nnx.ipynb)\n", + "\n", + "A pedagogical implementation of NNX's core APIs.\n", + "\n", + "## Core APIs" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "metadata": {}, + "outputs": [], + "source": [ + "import hashlib\n", + "import typing as tp\n", + "import jax\n", + "import jax.numpy as jnp\n", + "from jax import random\n", + "import dataclasses\n", + "\n", + "A = tp.TypeVar(\"A\")\n", + "M = tp.TypeVar(\"M\", bound=\"Module\")\n", + "Sharding = tp.Tuple[tp.Optional[str], ...]\n", + "Array = random.Array\n", + "\n", + "\n", + "class Variable(tp.Generic[A]):\n", + "\n", + " def __init__(\n", + " self,\n", + " value: A,\n", + " *,\n", + " sharding: tp.Optional[Sharding] = None,\n", + " ):\n", + " self.value = value\n", + " self.sharding = sharding\n", + "\n", + " def __repr__(self) -> str:\n", + " return (\n", + " f\"{type(self).__name__}(value={self.value}, sharding={self.sharding})\"\n", + " )\n", + "\n", + " def __init_subclass__(cls):\n", + " super().__init_subclass__()\n", + " jax.tree_util.register_pytree_node(\n", + " cls,\n", + " lambda x: ((x.value,), (x.sharding,)),\n", + " lambda metadata, value: Variable(value[0], sharding=metadata[0]),\n", + " )\n", + "\n", + "\n", + "class State(dict[str, Variable[tp.Any]]):\n", + "\n", + " def filter(self, variable_type: tp.Type[Variable]) -> \"State\":\n", + " return State(\n", + " {\n", + " path: variable\n", + " for path, variable in self.items()\n", + " if isinstance(variable, variable_type)\n", + " }\n", + " )\n", + "\n", + " def __repr__(self) -> str:\n", + " elems = \",\\n \".join(\n", + " f\"'{path}': {variable}\".replace(\"\\n\", \"\\n \")\n", + " for path, variable in self.items()\n", + " )\n", + " return f\"State({{\\n {elems}\\n}})\"\n", + "\n", + "\n", + "jax.tree_util.register_pytree_node(\n", + " State,\n", + " # in reality, values and paths should be sorted by path\n", + " lambda x: (tuple(x.values()), tuple(x.keys())),\n", + " lambda paths, values: State(dict(zip(paths, 
values))),\n", + ")\n", + "\n", + "\n", + "@dataclasses.dataclass\n", + "class ModuleDef(tp.Generic[M]):\n", + " type: tp.Type[M]\n", + " index: int\n", + " submodules: tp.Dict[str, tp.Union[\"ModuleDef[Module]\", int]]\n", + " static_fields: tp.Dict[str, tp.Any]\n", + "\n", + " def merge(self, state: State) -> M:\n", + " module = ModuleDef._build_module_recursive(self, {})\n", + " module.update(state)\n", + " return module\n", + "\n", + " @staticmethod\n", + " def _build_module_recursive(\n", + " moduledef: tp.Union[\"ModuleDef[M]\", int],\n", + " index_to_module: tp.Dict[int, \"Module\"],\n", + " ) -> M:\n", + " if isinstance(moduledef, int):\n", + " return index_to_module[moduledef] # type: ignore\n", + "\n", + " assert moduledef.index not in index_to_module\n", + "\n", + " # add a dummy module to the index to avoid infinite recursion\n", + " module = object.__new__(moduledef.type)\n", + " index_to_module[moduledef.index] = module\n", + "\n", + " submodules = {\n", + " name: ModuleDef._build_module_recursive(submodule, index_to_module)\n", + " for name, submodule in moduledef.submodules.items()\n", + " }\n", + " vars(module).update(moduledef.static_fields)\n", + " vars(module).update(submodules)\n", + " return module\n", + "\n", + " def apply(\n", + " self, state: State\n", + " ) -> tp.Callable[..., tuple[tp.Any, tuple[State, \"ModuleDef[M]\"]]]:\n", + " def _apply(*args, **kwargs):\n", + " module = self.merge(state)\n", + " out = module(*args, **kwargs) # type: ignore\n", + " return out, module.split()\n", + "\n", + " return _apply\n", + "\n", + "\n", + "class Module:\n", + "\n", + " def split(self: M) -> tp.Tuple[State, ModuleDef[M]]:\n", + " state = State()\n", + " moduledef = Module._partition_recursive(\n", + " module=self, module_id_to_index={}, path_parts=(), state=state\n", + " )\n", + " assert isinstance(moduledef, ModuleDef)\n", + " return state, moduledef\n", + "\n", + " @staticmethod\n", + " def _partition_recursive(\n", + " module: M,\n", + " module_id_to_index: tp.Dict[int, int],\n", + " path_parts: tp.Tuple[str, ...],\n", + " state: State,\n", + " ) -> tp.Union[ModuleDef[M], int]:\n", + " if id(module) in module_id_to_index:\n", + " return module_id_to_index[id(module)]\n", + "\n", + " index = len(module_id_to_index)\n", + " module_id_to_index[id(module)] = index\n", + "\n", + " submodules = {}\n", + " static_fields = {}\n", + "\n", + " # iterate fields sorted by name to ensure deterministic order\n", + " for name, value in sorted(vars(module).items(), key=lambda x: x[0]):\n", + " value_path = (*path_parts, name)\n", + " # if value is a Module, recurse\n", + " if isinstance(value, Module):\n", + " submoduledef = Module._partition_recursive(\n", + " value, module_id_to_index, value_path, state\n", + " )\n", + " submodules[name] = submoduledef\n", + " # if value is a Variable, add to state\n", + " elif isinstance(value, Variable):\n", + " state[\"/\".join(value_path)] = value\n", + " else: # otherwise, add to static fields\n", + " static_fields[name] = value\n", + "\n", + " return ModuleDef(\n", + " type=type(module),\n", + " index=index,\n", + " submodules=submodules,\n", + " static_fields=static_fields,\n", + " )\n", + "\n", + " def update(self, state: State) -> None:\n", + " for path, value in state.items():\n", + " path_parts = path.split(\"/\")\n", + " Module._set_value_at_path(self, path_parts, value)\n", + "\n", + " @staticmethod\n", + " def _set_value_at_path(\n", + " module: \"Module\", path_parts: tp.Sequence[str], value: Variable[tp.Any]\n", + " ) -> 
None:\n", + " if len(path_parts) == 1:\n", + " setattr(module, path_parts[0], value)\n", + " else:\n", + " Module._set_value_at_path(\n", + " getattr(module, path_parts[0]), path_parts[1:], value\n", + " )\n", + "\n", + "\n", + "@dataclasses.dataclass\n", + "class Rngs:\n", + " key: jax.Array\n", + " count: int = 0\n", + " count_path: tuple[int, ...] = ()\n", + "\n", + " def fork(self) -> \"Rngs\":\n", + " \"\"\"Forks the context, guaranteeing that all the random numbers generated\n", + " will be different from the ones generated in the original context. Fork is\n", + " used to create a new Rngs that can be passed to a JAX transform\"\"\"\n", + " count_path = self.count_path + (self.count,)\n", + " self.count += 1\n", + " return Rngs(self.key, count_path=count_path)\n", + "\n", + " def make_rng(self) -> jax.Array:\n", + " fold_data = self._stable_hash(self.count_path + (self.count,))\n", + " self.count += 1\n", + " return random.fold_in(self.key, fold_data) # type: ignore\n", + "\n", + " @staticmethod\n", + " def _stable_hash(data: tuple[int, ...]) -> int:\n", + " hash_str = \" \".join(str(x) for x in data)\n", + " _hash = hashlib.blake2s(hash_str.encode())\n", + " hash_bytes = _hash.digest()\n", + " # uint32 is represented as 4 bytes in big endian\n", + " return int.from_bytes(hash_bytes[:4], byteorder=\"big\")\n", + "\n", + "\n", + "# in the real NNX Rngs is not a pytree, instead\n", + "# it has a split/merge API similar to Module\n", + "# but for simplicity we use a pytree here\n", + "jax.tree_util.register_pytree_node(\n", + " Rngs,\n", + " lambda x: ((x.key,), (x.count, x.count_path)),\n", + " lambda metadata, value: Rngs(value[0], *metadata),\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Basic Layers" + ] + }, + { + "cell_type": "code", + "execution_count": 84, + "metadata": {}, + "outputs": [], + "source": [ + "class Param(Variable[A]):\n", + " pass\n", + "\n", + "\n", + "class BatchStat(Variable[A]):\n", + " pass\n", + "\n", + "\n", + "class Linear(Module):\n", + "\n", + " def __init__(self, din: int, dout: int, *, rngs: Rngs):\n", + " self.din = din\n", + " self.dout = dout\n", + " key = rngs.make_rng()\n", + " self.w = Param(random.uniform(key, (din, dout)))\n", + " self.b = Param(jnp.zeros((dout,)))\n", + "\n", + " def __call__(self, x: jax.Array) -> jax.Array:\n", + " return x @ self.w.value + self.b.value\n", + "\n", + "\n", + "class BatchNorm(Module):\n", + "\n", + " def __init__(self, din: int, mu: float = 0.95):\n", + " self.mu = mu\n", + " self.scale = Param(jax.numpy.ones((din,)))\n", + " self.bias = Param(jax.numpy.zeros((din,)))\n", + " self.mean = BatchStat(jax.numpy.zeros((din,)))\n", + " self.var = BatchStat(jax.numpy.ones((din,)))\n", + "\n", + " def __call__(self, x, train: bool) -> jax.Array:\n", + " if train:\n", + " axis = tuple(range(x.ndim - 1))\n", + " mean = jax.numpy.mean(x, axis=axis)\n", + " var = jax.numpy.var(x, axis=axis)\n", + " # ema update\n", + " self.mean.value = self.mu * self.mean.value + (1 - self.mu) * mean\n", + " self.var.value = self.mu * self.var.value + (1 - self.mu) * var\n", + " else:\n", + " mean, var = self.mean.value, self.var.value\n", + "\n", + " scale, bias = self.scale.value, self.bias.value\n", + " x = (x - mean) / jax.numpy.sqrt(var + 1e-5) * scale + bias\n", + " return x\n", + "\n", + "\n", + "class Dropout(Module):\n", + "\n", + " def __init__(self, rate: float):\n", + " self.rate = rate\n", + "\n", + " def __call__(self, x: jax.Array, *, train: bool, rngs: Rngs) 
-> jax.Array:\n", + " if train:\n", + " mask = random.bernoulli(rngs.make_rng(), (1 - self.rate), x.shape)\n", + " x = x * mask / (1 - self.rate)\n", + " return x" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Scan Over Layers Example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class Block(Module):\n", + "\n", + " def __init__(self, din: int, dout: int, *, rngs: Rngs):\n", + " self.linear = Linear(din, dout, rngs=rngs)\n", + " self.bn = BatchNorm(dout)\n", + " self.dropout = Dropout(0.1)\n", + "\n", + " def __call__(self, x: jax.Array, *, train: bool, rngs: Rngs) -> jax.Array:\n", + " x = self.linear(x)\n", + " x = self.bn(x, train=train)\n", + " x = jax.nn.gelu(x)\n", + " x = self.dropout(x, train=train, rngs=rngs)\n", + " return x\n", + "\n", + "\n", + "class ScanMLP(Module):\n", + "\n", + " def __init__(self, hidden_size: int, n_layers: int, *, rngs: Rngs):\n", + " self.n_layers = n_layers\n", + "\n", + " # lift init\n", + " key = random.split(rngs.make_rng(), n_layers - 1)\n", + " moduledef: ModuleDef[Block] = None # type: ignore\n", + "\n", + " def init_fn(key):\n", + " nonlocal moduledef\n", + " state, moduledef = Block(\n", + " hidden_size, hidden_size, rngs=Rngs(key)\n", + " ).split()\n", + " return state\n", + "\n", + " state = jax.vmap(init_fn)(key)\n", + " self.layers = moduledef.merge(state)\n", + " self.linear = Linear(hidden_size, hidden_size, rngs=rngs)\n", + "\n", + " def __call__(self, x: jax.Array, *, train: bool, rngs: Rngs) -> jax.Array:\n", + " # lift call\n", + " key: jax.Array = random.split(rngs.make_rng(), self.n_layers - 1) # type: ignore\n", + " state, moduledef = self.layers.split()\n", + "\n", + " def scan_fn(x, inputs: tuple[jax.Array, State]):\n", + " key, state = inputs\n", + " x, (state, _) = moduledef.apply(state)(x, train=train, rngs=Rngs(key))\n", + " return x, state\n", + "\n", + " x, state = jax.lax.scan(scan_fn, x, (key, state))\n", + " self.layers.update(state)\n", + " x = self.linear(x)\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 87, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "state = State({\n", + " 'layers/bn/bias': Variable(value=(4, 10), collection=params, sharding=None),\n", + " 'layers/bn/mean': Variable(value=(4, 1, 10), collection=batch_stats, sharding=None),\n", + " 'layers/bn/scale': Variable(value=(4, 10), collection=params, sharding=None),\n", + " 'layers/bn/var': Variable(value=(4, 1, 10), collection=batch_stats, sharding=None),\n", + " 'layers/linear/b': Variable(value=(4, 10), collection=params, sharding=None),\n", + " 'layers/linear/w': Variable(value=(4, 10, 10), collection=params, sharding=None),\n", + " 'linear/b': Variable(value=(10,), collection=params, sharding=None),\n", + " 'linear/w': Variable(value=(10, 10), collection=params, sharding=None)\n", + "})\n", + "moduledef = ModuleDef(type=<class '__main__.ScanMLP'>, index=0, submodules={'layers': ModuleDef(type=<class '__main__.Block'>, index=1, submodules={'bn': ModuleDef(type=<class '__main__.BatchNorm'>, index=2, submodules={}, static_fields={'mu': 0.95}), 'dropout': ModuleDef(type=<class '__main__.Dropout'>, index=3, submodules={}, static_fields={'rate': 0.1}), 'linear': ModuleDef(type=<class '__main__.Linear'>, index=4, submodules={}, static_fields={'din': 10, 'dout': 10})}, static_fields={}), 'linear': ModuleDef(type=<class '__main__.Linear'>, index=5, submodules={}, static_fields={'din': 10, 'dout': 10})}, static_fields={'n_layers': 5})\n" + ] + } + ], + "source": [ + "module = ScanMLP(hidden_size=10, n_layers=5, 
rngs=Rngs(random.key(0)))\n", + "x = jax.random.normal(random.key(0), (2, 10))\n", + "y = module(x, train=True, rngs=Rngs(random.key(1)))\n", + "\n", + "state, moduledef = module.split()\n", + "print(\"state =\", jax.tree_map(jnp.shape, state))\n", + "print(\"moduledef =\", moduledef)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Filtering State" + ] + }, + { + "cell_type": "code", + "execution_count": 89, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "params = State({\n", + " 'layers/bn/bias': Variable(value=(4, 10), collection=params, sharding=None),\n", + " 'layers/bn/scale': Variable(value=(4, 10), collection=params, sharding=None),\n", + " 'layers/linear/b': Variable(value=(4, 10), collection=params, sharding=None),\n", + " 'layers/linear/w': Variable(value=(4, 10, 10), collection=params, sharding=None),\n", + " 'linear/b': Variable(value=(10,), collection=params, sharding=None),\n", + " 'linear/w': Variable(value=(10, 10), collection=params, sharding=None)\n", + "})\n", + "batch_stats = State({\n", + " 'layers/bn/mean': Variable(value=(4, 1, 10), collection=batch_stats, sharding=None),\n", + " 'layers/bn/var': Variable(value=(4, 1, 10), collection=batch_stats, sharding=None)\n", + "})\n" + ] + } + ], + "source": [ + "# split\n", + "params = state.filter(Param)\n", + "batch_stats = state.filter(BatchStat)\n", + "# merge\n", + "state = State({**params, **batch_stats})\n", + "\n", + "print(\"params =\", jax.tree_map(jnp.shape, params))\n", + "print(\"batch_stats =\", jax.tree_map(jnp.shape, batch_stats))" + ] + } + ], + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/flax/experimental/nnx/docs/why.ipynb b/flax/experimental/nnx/docs/why.ipynb new file mode 100644 index 0000000000..54269e982e --- /dev/null +++ b/flax/experimental/nnx/docs/why.ipynb @@ -0,0 +1,391 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Why NNX?\n", + "\n", + "Four years ago we developed the Flax \"Linen\" API to support modeling research on JAX, with a focus on scaling and performance. We've learned a lot from our users over these years.\n", + "\n", + "We introduced some ideas that have proven to be good:\n", + " - Organizing variables into [collections](https://flax.readthedocs.io/en/latest/glossary.html#term-Variable-collections) or types to support JAX transforms and segregation of different data types in training loops.\n", + " - Automatic and efficient [PRNG management](https://flax.readthedocs.io/en/latest/glossary.html#term-RNG-sequences) (with support for splitting/broadcast control across map transforms).\n", + " - [Variable Metadata](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/_autosummary/flax.linen.with_partitioning.html#flax.linen.with_partitioning) for SPMD annotations, optimizer metadata, and other uses.\n", + "\n", + "One choice we made was to use functional \"define by call\" semantics for NN programming via the lazy (i.e. just-in-time) initialization of parameters. This made for concise (`compact`) implementation code and allowed for a single specification when transforming a layer. It also aligned our API more closely with Haiku. 
However, that lazy-init meant that the semantics of variables in Flax were non-pythonic and often surprising. It also led to implementation complexity and obscured the core ideas of transformations on neural nets.\n", + "\n", + "NNX is an attempt to keep the features that made Linen great while introducing some new principles:\n", + "\n", + "- Regular Python semantics for Modules, including (within JIT boundaries) support for mutability and shared references.\n", + "- A simple API for interacting directly with JAX; this includes the ability to easily implement custom lifted Modules and other purely functional tricks." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### NNX is Pythonic\n", + "The main feature of the NNX Module is that it adheres to Python semantics. This means that:\n", + "\n", + "* fields are mutable so you can perform in-place updates\n", + "* Module references can be shared between multiple Modules\n", + "* Module construction implies parameter initialization\n", + "* Module methods can be called directly" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "A Google TPU may be present on this machine, but either a TPU-enabled jaxlib or libtpu is not installed. Falling back to cpu.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "model = CounterLinear(\n", + " linear=Linear(\n", + " in_features=4,\n", + " out_features=4,\n", + " use_bias=True,\n", + " dtype=None,\n", + " param_dtype=,\n", + " precision=None,\n", + " kernel_init=<function variance_scaling.<locals>.init at 0x7f5d3c57baf0>,\n", + " bias_init=,\n", + " dot_general=\n", + " )\n", + ")\n" + ] + } + ], + "source": [ + "from flax.experimental import nnx\n", + "import jax\n", + "from jax import random, numpy as jnp\n", + "\n", + "class Count(nnx.Variable): pass\n", + "\n", + "class CounterLinear(nnx.Module):\n", + " def __init__(self, din, dout, *, rngs): # explicit RNG threading\n", + " self.linear = nnx.Linear(din, dout, rngs=rngs)\n", + " self.count = Count(jnp.zeros((), jnp.int32)) # typed Variable collections\n", + "\n", + " def __call__(self, x):\n", + " self.count += 1 # inplace stateful updates\n", + " return self.linear(x)\n", + "\n", + "\n", + "model = CounterLinear(4, 4, rngs=nnx.Rngs(0)) # no special `init` method\n", + "y = model(jnp.ones((2, 4))) # call methods directly\n", + "\n", + "print(f'{model = }')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Because NNX Modules contain their own state, they are very easy to inspect:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "model.count = Array(1, dtype=int32)\n", + "model.linear.kernel = Array([[ 0.4541089 , -0.5264876 , -0.36505195, -0.57566494],\n", + " [ 0.38802508, 0.5655534 , 0.4870657 , 0.2267774 ],\n", + " [-0.9015767 , 0.24465278, -0.5844707 , 0.18421966],\n", + " [-0.06992685, -0.64693886, 0.20232596, 1.1200062 ]], dtype=float32)\n" + ] + } + ], + "source": [ + "print(f'{model.count = }')\n", + "print(f'{model.linear.kernel = }')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Intuitive Surgery\n", + "\n", + "In NNX, surgery can be done at the Module level by simply updating or replacing existing fields."
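+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Reference semantics also make it trivial to share one Module between several fields, which is what enables the \"share two Modules that were not shared before\" case mentioned below. A minimal sketch (hypothetical `Tied` Module):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# hypothetical: tie one Linear instance to two fields\n", + "shared_linear = nnx.Linear(4, 4, rngs=nnx.Rngs(0))\n", + "\n", + "class Tied(nnx.Module):\n", + " def __init__(self, linear: nnx.Linear):\n", + " self.a = linear\n", + " self.b = linear # same reference, parameters stored once\n", + "\n", + "tied = Tied(shared_linear)\n", + "assert tied.a is tied.b"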
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "def load_pretrained():\n", + " return nnx.Linear(4, 4, rngs=nnx.Rngs(42)) # pretend this is pretrained\n", + "\n", + "model.linear = load_pretrained() # you can replace modules\n", + "\n", + "y = model(jnp.ones((2, 4)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The benefit of this is not only that it's easier than messing with dictionary structures, but that you can even replace a field with a completely different Module type, or even change the architecture (e.g. share two Modules that were not shared before)." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "from functools import partial\n", + "\n", + "rngs = nnx.Rngs(0)\n", + "model = nnx.Sequence(\n", + " [\n", + " nnx.Conv(1, 16, [3, 3], padding='SAME', rngs=rngs),\n", + " partial(nnx.max_pool, window_shape=(2, 2), strides=(2, 2)),\n", + " nnx.Conv(16, 32, [3, 3], padding='SAME', rngs=rngs),\n", + " partial(nnx.max_pool, window_shape=(2, 2), strides=(2, 2)),\n", + " lambda x: x.reshape((x.shape[0], -1)), # flatten\n", + " nnx.Linear(32 * 7 * 7, 10, rngs=rngs),\n", + " ]\n", + ")\n", + "\n", + "y = model(jnp.ones((2, 28, 28, 1)))\n", + "\n", + "for i, layer in enumerate(model):\n", + " if isinstance(layer, nnx.Conv):\n", + " model[i] = nnx.Linear(layer.in_features, layer.out_features, rngs=rngs)\n", + "\n", + "y = model(jnp.ones((2, 28, 28, 1)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that here we are replacing `Conv` with `Linear` as a silly example, but in reality you would do things like replacing a layer with its quantized version, or swapping a layer for an optimized version, etc." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Interacting with JAX is easy\n", + "\n", + "While NNX Modules inherently follow reference semantics, they can be easily converted into a pure functional representation that can be used with JAX transformations. NNX has two very simple APIs to interact with JAX: `split` and `merge`.\n", + "\n", + "The `Module.split` method allows you to convert a Module into a `State` dict-like object that contains the dynamic state of the Module, and a `ModuleDef` object that contains the static structure of the Module." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "state = State({\n", + " 'count': Array(0, dtype=int32),\n", + " 'linear/bias': Array([0., 0., 0., 0.], dtype=float32),\n", + " 'linear/kernel': Array([[ 0.4541089 , -0.5264876 , -0.36505195, -0.57566494],\n", + " [ 0.38802508, 0.5655534 , 0.4870657 , 0.2267774 ],\n", + " [-0.9015767 , 0.24465278, -0.5844707 , 0.18421966],\n", + " [-0.06992685, -0.64693886, 0.20232596, 1.1200062 ]], dtype=float32)\n", + "})\n" + ] + } + ], + "source": [ + "model = CounterLinear(4, 4, rngs=nnx.Rngs(0))\n", + "\n", + "state, static = model.split()\n", + "\n", + "print(f'{state = }')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `ModuleDef.merge` method allows you to take a `ModuleDef` and one or more `State` objects and merge them back into a `Module` object. \n", + "\n", + "Using `split` and `merge` in conjunction allows you to carry your Module in and out of any JAX transformation. 
Here is a simple jitted `forward` function as an example:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "y.shape = (2, 4)\n", + "state[\"count\"] = Array(1, dtype=int32)\n" + ] + } + ], + "source": [ + "@jax.jit\n", + "def forward(state: nnx.State, x: jax.Array):\n", + " model = static.merge(state)\n", + " y = model(x)\n", + " state, _ = model.split()\n", + " return y, state\n", + "\n", + "x = jnp.ones((2, 4))\n", + "y, state = forward(state, x)\n", + "\n", + "print(f'{y.shape = }')\n", + "print(f'{state[\"count\"] = }')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Custom lifted Modules\n", + "\n", + "By using the same mechanism inside Module methods you can implement lifted Modules, that is, Modules that use a JAX transformation to implement a distinct behavior. One of Linen's current pain points is that it is not easy to interact with JAX transformations that are not currently supported by the framework. NNX makes this so easy that it's realistic to implement custom lifted Modules for specific use cases.\n", + "\n", + "As an example, here we will create a `LinearEnsemble` Module that uses `jax.vmap` both during `__init__` and `__call__` to vectorize the computation over multiple `CounterLinear` models (defined above). The example is a little bit longer, but notice how each method is conceptually very simple:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "y.shape = (8, 4)\n", + "ensemble.models.count = Array(1, dtype=int32)\n", + "state = State({\n", + " 'models/count': (),\n", + " 'models/linear/bias': (8, 4),\n", + " 'models/linear/kernel': (8, 4, 4)\n", + "})\n" + ] + } + ], + "source": [ + "class LinearEnsemble(nnx.Module):\n", + " def __init__(self, din, dout, *, num_models, rngs: nnx.Rngs):\n", + " # get raw rng keys\n", + " keys = rngs.fork(num_models) # fork into `num_models` keys\n", + "\n", + " # define pure init fn and vmap\n", + " def vmap_init(keys):\n", + " return CounterLinear(din, dout, rngs=nnx.Rngs(keys)).split(\n", + " nnx.Param, Count\n", + " )\n", + "\n", + " params, counts, static = jax.vmap(\n", + " vmap_init, in_axes=(0,), out_axes=(0, None, None)\n", + " )(keys)\n", + " # update wrapped submodule reference\n", + " self.models = static.merge(params, counts)\n", + "\n", + " def __call__(self, x):\n", + " # get module values, define pure fn\n", + " params, counts, static = self.models.split(nnx.Param, Count)\n", + "\n", + " def vmap_apply(x, params, counts, static):\n", + " model = static.merge(params, counts)\n", + " y = model(x)\n", + " params, counts, static = model.split(nnx.Param, Count)\n", + " return y, params, counts, static\n", + "\n", + " # vmap and call\n", + " y, params, counts, static = jax.vmap(\n", + " vmap_apply, in_axes=(None, 0, None, None), out_axes=(0, 0, None, None)\n", + " )(x, params, counts, static)\n", + " # update wrapped module\n", + " self.models.update(params, counts, static) # use `update` to integrate the new state\n", + " return y\n", + "\n", + "x = jnp.ones((4,))\n", + "ensemble = LinearEnsemble(4, 4, num_models=8, rngs=nnx.Rngs(0))\n", + "\n", + "# forward pass\n", + "y = ensemble(x)\n", + "\n", + "print(f'{y.shape = }')\n", + "print(f'{ensemble.models.count = }')\n", + "print(f'state = {jax.tree_map(jnp.shape, ensemble.get_state())}')" + ] + }, + { + "cell_type": "markdown", + 
"metadata": {}, + "source": [ + "### Why Modules are not Pytrees?\n", + "\n", + "Finally one of the most common questions we get is why NNX Modules are not Pytrees? Given the existance of Pytree-based NN frameworks like Equinox, Treex, [PytreeClass](https://github.com/ASEM000/PyTreeClass), it is a fair question.\n", + "\n", + "The answer is that Pytrees assume value semantics (referencial transparency) while Modules assume reference semantics, and therefore its not a good idea for Modules to be Pytrees. As an example, lets take a look at what would happen if we allowed this very simple program to be valid:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "@jax.jit\n", + "def f(m1: nnx.Module, m2: nnx.Module):\n", + " return m1, m2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we are just creating a jitted function `f` that takes in two Modules `(m1, m2)` and returns them as is. What could go wrong?\n", + "\n", + "There are two main problems with this:\n", + "* Shared references are not maintained, that is, if `m1.shared is m2.shared` outside `f`, this will NOT be true both inside `f`, and at the output of `f`.\n", + "* Even if you accept this fact and added code to compensate for this, `f` would now behave differently depending on whether its being `jit`ted or not, this is an undisired asymmetry and `jit` would no longer be a no-op." + ] + } + ], + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/flax/experimental/nnx/examples/00_demo.ipynb b/flax/experimental/nnx/examples/00_demo.ipynb new file mode 100644 index 0000000000..97da6426bf --- /dev/null +++ b/flax/experimental/nnx/examples/00_demo.ipynb @@ -0,0 +1,288 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Linear(\n", + " din=2,\n", + " dout=2\n", + ")\n", + "[[0.63114893 1.2928092 ]\n", + " [0.63114893 1.2928092 ]]\n" + ] + } + ], + "source": [ + "from flax.experimental import nnx\n", + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "\n", + "class Linear(nnx.Module):\n", + "\n", + " def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):\n", + " # static attributes\n", + " self.din = din\n", + " self.dout = dout\n", + " # variables\n", + " self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout)))\n", + " self.b = nnx.Param(jnp.zeros((dout,)))\n", + "\n", + " def __call__(self, x):\n", + " return x @ self.w + self.b\n", + "\n", + "\n", + "linear = Linear(2, 2, rngs=nnx.Rngs(0))\n", + "\n", + "y = linear(jnp.ones((2, 2)))\n", + "\n", + "print(linear)\n", + "print(y)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "State({\n", + " 'b': Param(\n", + " value=Array([0., 0.], dtype=float32)\n", + " ),\n", + " 'w': Param(\n", + " value=Array([[0.31696808, 0.55285215],\n", + " [0.31418085, 0.7399571 ]], dtype=float32)\n", + " )\n", + "})\n", + "ModuleDef(\n", + " type=Linear,\n", + " index=0,\n", + " static_fields=(('din', 2), ('dout', 2)),\n", + " variables=(('b', Param(\n", + " value=Empty\n", + " )), ('w', Param(\n", + 
" value=Empty\n", + " ))),\n", + " submodules=()\n", + ")\n" + ] + } + ], + "source": [ + "state, moduledef = linear.split()\n", + "\n", + "print(state)\n", + "print(moduledef)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Linear(\n", + " din=2,\n", + " dout=2,\n", + " submodule=Linear(...)\n", + ")\n", + "[[0.63114893 1.2928092 ]\n", + " [0.63114893 1.2928092 ]]\n" + ] + } + ], + "source": [ + "class Linear(nnx.Module):\n", + "\n", + " def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):\n", + " self.din = din\n", + " self.dout = dout\n", + " self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout)))\n", + " self.b = nnx.Param(jnp.zeros((dout,)))\n", + " # introduce a self-reference\n", + " self.submodule = self\n", + "\n", + " def __call__(self, x):\n", + " return x @ self.submodule.w + self.submodule.b\n", + "\n", + "\n", + "linear = Linear(2, 2, rngs=nnx.Rngs(0))\n", + "\n", + "y = linear(jnp.ones((2, 2)))\n", + "\n", + "print(linear)\n", + "print(y)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "State({\n", + " 'b': Param(\n", + " value=Array([0., 0.], dtype=float32)\n", + " ),\n", + " 'w': Param(\n", + " value=Array([[0.31696808, 0.55285215],\n", + " [0.31418085, 0.7399571 ]], dtype=float32)\n", + " )\n", + "})\n", + "ModuleDef(\n", + " type=Linear,\n", + " index=0,\n", + " static_fields=(('din', 2), ('dout', 2)),\n", + " variables=(('b', Param(\n", + " value=Empty\n", + " )), ('w', Param(\n", + " value=Empty\n", + " ))),\n", + " submodules=(\n", + " ('submodule', 0)\n", + " )\n", + ")\n" + ] + } + ], + "source": [ + "state, moduledef = linear.split()\n", + "\n", + "print(state)\n", + "print(moduledef)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "linear2 = moduledef.merge(state)\n", + "\n", + "linear2.submodule is linear2" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Linear(\n", + " din=2,\n", + " dout=2\n", + ")\n", + "[[0.63114893 1.2928092 ]\n", + " [0.63114893 1.2928092 ]]\n" + ] + } + ], + "source": [ + "class Linear(nnx.Module):\n", + "\n", + " def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):\n", + " # static attributes\n", + " self.din = din\n", + " self.dout = dout\n", + " # variables\n", + " self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout)))\n", + " self.b = nnx.Param(jnp.zeros((dout,)))\n", + "\n", + " def __call__(self, x):\n", + " y = x @ self.w + self.b\n", + " self.y = nnx.Intermediate(y)\n", + " return y\n", + "\n", + "\n", + "linear = Linear(2, 2, rngs=nnx.Rngs(0))\n", + "\n", + "y = linear(jnp.ones((2, 2)))\n", + "\n", + "print(linear)\n", + "print(y)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "State({\n", + " 'y': Intermediate(\n", + " value=Array([[0.63114893, 1.2928092 ],\n", + " [0.63114893, 1.2928092 ]], dtype=float32)\n", + " )\n", + "})\n", + "State({\n", + " 'b': Param(\n", + " value=Array([0., 0.], dtype=float32)\n", + " ),\n", + " 'w': Param(\n", + " 
value=Array([[0.31696808, 0.55285215],\n", + " [0.31418085, 0.7399571 ]], dtype=float32)\n", + " )\n", + "})\n" + ] + } + ], + "source": [ + "intermediates = linear.pop(nnx.Intermediate)\n", + "state, moduledef = linear.split()\n", + "\n", + "print(intermediates)\n", + "print(state)" + ] + } + ], + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/flax/experimental/nnx/examples/01_functional_api.py b/flax/experimental/nnx/examples/01_functional_api.py new file mode 100644 index 0000000000..80ad492002 --- /dev/null +++ b/flax/experimental/nnx/examples/01_functional_api.py @@ -0,0 +1,108 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# %% +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpy as np + +from flax.experimental import nnx + +X = np.linspace(0, 1, 100)[:, None] +Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) + + +def dataset(batch_size): + while True: + idx = np.random.choice(len(X), size=batch_size) + yield X[idx], Y[idx] + + +class Linear(nnx.Module): + def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): + self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout))) + self.b = nnx.Param(jnp.zeros((dout,))) + + def __call__(self, x): + return x @ self.w + self.b + + +class Count(nnx.Variable[nnx.A]): + pass + + +class MLP(nnx.Module): + def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs): + self.count = Count(jnp.array(0)) + self.linear1 = Linear(din, dhidden, rngs=rngs) + self.linear2 = Linear(dhidden, dout, rngs=rngs) + + def __call__(self, x): + self.count += 1 + x = self.linear1(x) + x = jax.nn.relu(x) + x = self.linear2(x) + return x + + +params, counts, modeldef = MLP( + din=1, dhidden=32, dout=1, rngs=nnx.Rngs(0) +).split(nnx.Param, Count) + + +@jax.jit +def train_step(params, counts, batch): + x, y = batch + + def loss_fn(params): + y_pred, (updates, _) = modeldef.apply(params, counts)(x) + counts_ = updates.extract(Count) + loss = jnp.mean((y - y_pred) ** 2) + return loss, counts_ + + grad, counts = jax.grad(loss_fn, has_aux=True)(params) + # |-------- sgd ---------| + params = jax.tree_map(lambda w, g: w - 0.1 * g, params, grad) + + return params, counts + + +@jax.jit +def test_step(params: nnx.State, counts: nnx.State, batch): + x, y = batch + y_pred, _ = modeldef.apply(params, counts)(x) + loss = jnp.mean((y - y_pred) ** 2) + return {'loss': loss} + + +total_steps = 10_000 +for step, batch in enumerate(dataset(32)): + params, counts = train_step(params, counts, batch) + + if step % 1000 == 0: + logs = test_step(params, counts, (X, Y)) + print(f"step: {step}, loss: {logs['loss']}") + + if step >= total_steps - 1: + break + +model = modeldef.merge(params, counts) +print('times called:', 
model.count) + +y_pred = model(X) + +plt.scatter(X, Y, color='blue') +plt.plot(X, y_pred, color='black') +plt.show() diff --git a/flax/experimental/nnx/examples/02_lifted_transforms.py b/flax/experimental/nnx/examples/02_lifted_transforms.py new file mode 100644 index 0000000000..d74db70d82 --- /dev/null +++ b/flax/experimental/nnx/examples/02_lifted_transforms.py @@ -0,0 +1,106 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# %% +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpy as np + +from flax.experimental import nnx + +X = np.linspace(0, 1, 100)[:, None] +Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) + + +def dataset(batch_size): + while True: + idx = np.random.choice(len(X), size=batch_size) + yield X[idx], Y[idx] + + +class Linear(nnx.Module): + def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): + self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout))) + self.b = nnx.Param(jnp.zeros((dout,))) + + def __call__(self, x): + return x @ self.w + self.b + + +class Count(nnx.Variable): + pass + + +class MLP(nnx.Module): + def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs): + self.count = Count(jnp.array(0)) + self.linear1 = Linear(din, dhidden, rngs=rngs) + self.linear2 = Linear(dhidden, dout, rngs=rngs) + + def __call__(self, x): + self.count += 1 + x = self.linear1(x) + x = jax.nn.relu(x) + x = self.linear2(x) + return x + + +model = MLP(din=1, dhidden=32, dout=1, rngs=nnx.Rngs(0)) + + +@nnx.jit +def train_step(model: MLP, batch): + x, y = batch + + def loss_fn(model: MLP): + y_pred = model(x) + return jnp.mean((y - y_pred) ** 2) + + # |--default--| + grad: nnx.State = nnx.grad(loss_fn, wrt=nnx.Param)(model) + # sgd update + model.update( + jax.tree_map(lambda w, g: w - 0.1 * g, model.extract(nnx.Param), grad) + ) + + # no return!!! + + +@nnx.jit +def test_step(model: MLP, batch): + x, y = batch + y_pred = model(x) + loss = jnp.mean((y - y_pred) ** 2) + return {'loss': loss} + + +total_steps = 10_000 +for step, batch in enumerate(dataset(32)): + train_step(model, batch) + + if step % 1000 == 0: + logs = test_step(model, (X, Y)) + print(f"step: {step}, loss: {logs['loss']}") + + if step >= total_steps - 1: + break + +print('times called:', model.count) + +y_pred = model(X) + +plt.scatter(X, Y, color='blue') +plt.plot(X, y_pred, color='black') +plt.show() diff --git a/flax/experimental/nnx/examples/03_train_state.py b/flax/experimental/nnx/examples/03_train_state.py new file mode 100644 index 0000000000..bc65d4165c --- /dev/null +++ b/flax/experimental/nnx/examples/03_train_state.py @@ -0,0 +1,117 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# %% +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpy as np +import optax + +from flax.experimental import nnx + +X = np.linspace(0, 1, 100)[:, None] +Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) + + +def dataset(batch_size): + while True: + idx = np.random.choice(len(X), size=batch_size) + yield X[idx], Y[idx] + + +class Linear(nnx.Module): + def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): + self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout))) + self.b = nnx.Param(jnp.zeros((dout,))) + + def __call__(self, x): + return x @ self.w + self.b + + +class Count(nnx.Variable[nnx.A]): + pass + + +class MLP(nnx.Module): + def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs): + self.count = Count(jnp.array(0)) + self.linear1 = Linear(din, dhidden, rngs=rngs) + self.linear2 = Linear(dhidden, dout, rngs=rngs) + + def __call__(self, x): + self.count = self.count + 1 + x = self.linear1(x) + x = jax.nn.relu(x) + x = self.linear2(x) + return x + + +params, counts, moduledef = MLP( + din=1, dhidden=32, dout=1, rngs=nnx.Rngs(0) +).split(nnx.Param, ...) + +state = nnx.TrainState( + moduledef, + params=params, + tx=optax.sgd(0.1), + counts=counts, +) +del params, counts + + +@jax.jit +def train_step(state: nnx.TrainState[MLP], batch): + x, y = batch + + def loss_fn(params): + y_pred, (updates, _) = state.apply(params, 'counts')(x) + counts = updates.extract(Count) + loss = jnp.mean((y - y_pred) ** 2) + return loss, counts + + grads, counts = jax.grad(loss_fn, has_aux=True)(state.params) + # sgd update + state = state.apply_gradients(grads=grads, counts=counts) + + return state + + +@jax.jit +def test_step(state: nnx.TrainState[MLP], batch): + x, y = batch + y_pred, _ = state.apply('params', 'counts')(x) + loss = jnp.mean((y - y_pred) ** 2) + return {'loss': loss} + + +total_steps = 10_000 +for step, batch in enumerate(dataset(32)): + state = train_step(state, batch) + + if step % 1000 == 0: + logs = test_step(state, (X, Y)) + print(f"step: {step}, loss: {logs['loss']}") + + if step >= total_steps - 1: + break + +model = moduledef.merge(state.params, state.counts) +print('times called:', model.count) + +y_pred = model(X) + +plt.scatter(X, Y, color='blue') +plt.plot(X, y_pred, color='black') +plt.show() diff --git a/flax/experimental/nnx/examples/05_vae.py b/flax/experimental/nnx/examples/05_vae.py new file mode 100644 index 0000000000..719bab9ef4 --- /dev/null +++ b/flax/experimental/nnx/examples/05_vae.py @@ -0,0 +1,218 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
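The train-state example above compresses a pattern that the rest of these examples reuse: split the model into a pytree-compatible `State`, run a pure function under `jax.jit`, and fold the results back. For reference, the same loop written in plain JAX + optax, with no NNX involved, looks like this; a minimal sketch, where `apply_fn` and `params` are illustrative stand-ins and not part of this patch:

```python
import jax
import jax.numpy as jnp
import optax

# a plain pytree standing in for nnx.State in this sketch
params = {'w': jnp.zeros((1, 1)), 'b': jnp.zeros((1,))}
tx = optax.sgd(0.1)
opt_state = tx.init(params)

def apply_fn(params, x):
  # stand-in for calling moduledef.apply(params)(x)
  return x @ params['w'] + params['b']

@jax.jit
def train_step(params, opt_state, x, y):
  def loss_fn(params):
    return jnp.mean((apply_fn(params, x) - y) ** 2)

  grads = jax.grad(loss_fn)(params)
  updates, opt_state = tx.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  return params, opt_state

params, opt_state = train_step(
  params, opt_state, jnp.ones((8, 1)), jnp.ones((8, 1))
)
```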
+ +# %% +import typing as tp +from functools import partial + +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpy as np +import optax +from datasets import load_dataset + +from flax.experimental import nnx + +np.random.seed(42) +latent_size = 32 +image_shape: tp.Sequence[int] = (28, 28) +steps_per_epoch: int = 200 +batch_size: int = 64 +epochs: int = 20 + + +dataset = load_dataset('mnist') +X_train = np.array(np.stack(dataset['train']['image']), dtype=np.uint8) +X_test = np.array(np.stack(dataset['test']['image']), dtype=np.uint8) +# Now binarize data +X_train = (X_train > 0).astype(jnp.float32) +X_test = (X_test > 0).astype(jnp.float32) + +print('X_train:', X_train.shape, X_train.dtype) +print('X_test:', X_test.shape, X_test.dtype) + + +class Loss(nnx.Variable): + pass + + +# %% +class Encoder(nnx.Module): + def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs): + self.linear1 = nnx.Linear(din, dmid, rngs=rngs) + self.linear_mean = nnx.Linear(dmid, dout, rngs=rngs) + self.linear_std = nnx.Linear(dmid, dout, rngs=rngs) + + def __call__(self, x: jax.Array, *, rngs: nnx.Rngs) -> jax.Array: + x = x.reshape((x.shape[0], -1)) # flatten + x = self.linear1(x) + x = jax.nn.relu(x) + + mean = self.linear_mean(x) + std = jnp.exp(self.linear_std(x)) + + self.kl_loss = Loss( + jnp.mean( + 0.5 * jnp.mean(-jnp.log(std**2) - 1.0 + std**2 + mean**2, axis=-1) + ) + ) + key = rngs.noise() + z = mean + std * jax.random.normal(key, mean.shape) + return z + + +class Decoder(nnx.Module): + def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs): + self.linear1 = nnx.Linear(din, dmid, rngs=rngs) + self.linear2 = nnx.Linear(dmid, dout, rngs=rngs) + + def __call__(self, z: jax.Array) -> jax.Array: + z = self.linear1(z) + z = jax.nn.relu(z) + logits = self.linear2(z) + return logits + + +class VAE(nnx.Module): + def __init__( + self, + din: int, + hidden_size: int, + latent_size: int, + output_shape: tp.Sequence[int], + *, + rngs: nnx.Rngs, + ): + self.output_shape = output_shape + self.encoder = Encoder(din, hidden_size, latent_size, rngs=rngs) + self.decoder = Decoder( + latent_size, hidden_size, int(np.prod(output_shape)), rngs=rngs + ) + + def __call__(self, x: jax.Array, *, rngs: nnx.Rngs) -> jax.Array: + z = self.encoder(x, rngs=rngs) + logits = self.decoder(z) + logits = jnp.reshape(logits, (-1, *self.output_shape)) + return logits + + def generate(self, z): + logits = self.decoder(z) + logits = jnp.reshape(logits, (-1, *self.output_shape)) + return nnx.sigmoid(logits) + + +params, moduledef = VAE( + din=int(np.prod(image_shape)), + hidden_size=256, + latent_size=latent_size, + output_shape=image_shape, + rngs=nnx.Rngs(0), +).split(nnx.Param) + +state = nnx.TrainState( + moduledef, + params=params, + tx=optax.adam(1e-3), +) + + +# %% +@jax.jit +def train_step(state: nnx.TrainState[VAE], x: jax.Array, key: jax.Array): + def loss_fn(params: nnx.State): + rngs = nnx.Rngs(noise=jax.random.fold_in(key, state.step)) + logits, (updates, _) = state.apply(params)(x, rngs=rngs) + + losses = updates.extract(Loss) + kl_loss = sum(jax.tree_util.tree_leaves(losses), 0.0) + reconstruction_loss = jnp.mean( + optax.sigmoid_binary_cross_entropy(logits, x) + ) + # jax.debug.print("kl_loss={kl_loss}", kl_loss=kl_loss) + + loss = reconstruction_loss + 0.1 * kl_loss + return loss + + loss, grads = jax.value_and_grad(loss_fn)(state.params) + state = state.apply_gradients(grads=grads) + + return state, loss + + +@partial(jax.jit, donate_argnums=(0,)) +def forward( + 
state: nnx.TrainState[VAE], x: jax.Array, key: jax.Array +) -> jax.Array: + rngs = nnx.Rngs(noise=key) + y_pred = state.apply('params')(x, rngs=rngs)[0] + return jax.nn.sigmoid(y_pred) + + +@jax.jit +def sample(state: nnx.TrainState[VAE], z: jax.Array) -> jax.Array: + return state.apply('params').generate(z)[0] + + +# %% +key = jax.random.key(0) + +for epoch in range(epochs): + losses = [] + for step in range(steps_per_epoch): + idxs = np.random.randint(0, len(X_train), size=(batch_size,)) + x_batch = X_train[idxs] + + state, loss = train_step(state, x_batch, key) + losses.append(np.asarray(loss)) + + print(f'Epoch {epoch} loss: {np.mean(losses)}') + +# exit() +# %% +# get random samples +idxs = np.random.randint(0, len(X_test), size=(5,)) +x_sample = X_test[idxs] + +# get predictions +y_pred = forward(state, x_sample, key) + +# plot reconstruction +figure = plt.figure(figsize=(3 * 5, 3 * 2)) +plt.title('Reconstruction Samples') +for i in range(5): + plt.subplot(2, 5, i + 1) + plt.imshow(x_sample[i], cmap='gray') + plt.subplot(2, 5, 5 + i + 1) + plt.imshow(y_pred[i], cmap='gray') + # # tbwriter.add_figure("VAE Example", figure, epochs) + +plt.show() + +# %% +# plot generative samples +z_samples = np.random.normal(scale=1.5, size=(12, latent_size)) +samples = sample(state, z_samples) + +figure = plt.figure(figsize=(3 * 5, 3 * 2)) +plt.title('Generative Samples') +for i in range(5): + plt.subplot(2, 5, 2 * i + 1) + plt.imshow(samples[i], cmap='gray') + plt.subplot(2, 5, 2 * i + 2) + plt.imshow(samples[i + 1], cmap='gray') + +plt.show() + +# %% diff --git a/flax/experimental/nnx/examples/06_scan_over_layers.py b/flax/experimental/nnx/examples/06_scan_over_layers.py new file mode 100644 index 0000000000..24dcfdb22c --- /dev/null +++ b/flax/experimental/nnx/examples/06_scan_over_layers.py @@ -0,0 +1,87 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +import jax +import jax.numpy as jnp + +from flax.experimental import nnx + + +class Block(nnx.Module): + def __init__(self, dim: int, *, rngs: nnx.Rngs): + self.linear = nnx.Linear(dim, dim, rngs=rngs) + self.dropout = nnx.Dropout(0.5) + + def __call__(self, x: jax.Array, *, rngs: nnx.Rngs) -> jax.Array: + x = self.linear(x) + x = self.dropout(x, rngs=rngs) + x = jax.nn.gelu(x) + return x + + +class ScanMLP(nnx.Module): + """ + An MLP that uses `vmap` during `__init__` to create a Block instance + with an additional `layer` axis, and `scan` during `__call__` to apply + the sequence of layers iteratively over the input / output `x`. 
+ """ + + def __init__(self, dim: int, *, n_layers: int, rngs: nnx.Rngs): + self.n_layers = n_layers + # fork Rngs, split keys into `n_layers` + keys = rngs.fork(n_layers) + + def create_block(keys): + # create Block instance and return its split + return Block(dim, rngs=nnx.Rngs(keys)).split() + + # call vmap over create_block, passing the split `params` key + # and immediately merge to get a Block instance + self.layers = nnx.merge(jax.vmap(create_block)(keys)) + + def __call__(self, x: jax.Array, *, rngs: nnx.Rngs) -> jax.Array: + # fork Rngs, split keys into `n_layers` + keys = rngs.fork(self.n_layers) + # split Module to get params + params, moduledef = self.layers.split(nnx.Param) + + def scan_fn( + x: jax.Array, inputs: Tuple[nnx.State, dict[str, nnx.RngStream]] + ) -> Tuple[jax.Array, nnx.State]: + params, keys = inputs + # merge back Module and Rngs + module = moduledef.merge(params) + # forward pass + x = module(x, rngs=nnx.Rngs(keys)) + # split state and return + params, _ = module.split(nnx.Param) + return x, params + + # call scan passing x as the carry, and params + keys as the input + x, params = jax.lax.scan(scan_fn, x, (params, keys)) + # update layers state and return + self.layers.update(params) + return x + + +model = ScanMLP(10, n_layers=5, rngs=nnx.Rngs(0)) + +x = jnp.ones((3, 10)) +with nnx.flags(deterministic=False): + y = model(x, rngs=nnx.Rngs(dropout=1)) + +print(jax.tree_map(jnp.shape, model.get_state())) +print(y.shape) diff --git a/flax/experimental/nnx/examples/07_transformer.py b/flax/experimental/nnx/examples/07_transformer.py new file mode 100644 index 0000000000..d0352e32dd --- /dev/null +++ b/flax/experimental/nnx/examples/07_transformer.py @@ -0,0 +1,414 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import dataclasses +import typing as tp + +import jax +import jax.numpy as jnp +import numpy as np +from jax.sharding import PartitionSpec as P + +from flax.experimental import nnx + +ShardSpec = tp.Union[str, tp.Tuple[str, ...], None] + + +# Sharding +@dataclasses.dataclass +class Sharding: + batch: ShardSpec = 'data' + sequence: ShardSpec = None + layers: ShardSpec = None + vocab: ShardSpec = 'model' + embed: ShardSpec = None + heads: ShardSpec = 'model' + depth: ShardSpec = None + hidden: ShardSpec = 'model' + + +# Config +@dataclasses.dataclass +class Config: + # mode + decode: bool = False + # shapes + batch: int = 16 + layers: int = 2 + vocab: int = 1024 + embed: int = 64 + heads: int = 12 + depth: int = 64 + hidden: int = 256 + max_length: int = 256 + # dtypes + param_dtype: tp.Any = jnp.float32 + dtype: tp.Any = jnp.float32 + # sharding (use default_factory so instances don't share one Sharding object) + sharding: Sharding = dataclasses.field(default_factory=Sharding) + scanned: bool = False + # layer params + epsilon: float = 1e-6 + dropout_rate: float = 0.0 + rp_num_buckets: int = 32 + rp_max_distance: int = 128 + + +cfg = Config() + + +def nd_dense_init(scale, mode, distribution): + """Initializer with in_axis, out_axis set at call time.""" + + def init_fn(key, shape, dtype, in_axis, out_axis) -> jax.Array: + fn = jax.nn.initializers.variance_scaling( + scale, mode, distribution, in_axis, out_axis + ) + return fn(key, shape, dtype) + + return init_fn + + +dense_init = nd_dense_init(1.0, 'fan_in', 'truncated_normal') +embed_init = nd_dense_init(1.0, 'fan_in', 'normal') + + +def make_attention_mask( + query_input: tp.Any, + key_input: tp.Any, + pairwise_fn: tp.Callable = jnp.multiply, + dtype: tp.Any = jnp.float32, +): + mask = pairwise_fn( + jnp.expand_dims(query_input, axis=-1), jnp.expand_dims(key_input, axis=-2) + ) + return jnp.expand_dims(mask, axis=-3).astype(dtype) + + +def make_causal_mask(x, dtype=jnp.float32): + idxs = jnp.broadcast_to(jnp.arange(x.shape[-1], dtype=jnp.int32), x.shape) + return make_attention_mask(idxs, idxs, jnp.greater_equal, dtype=dtype) + + +# padding mask +# make_attention_mask(decoder_target_tokens > 0, decoder_target_tokens > 0, dtype=dtype) +# packing mask +# make_attention_mask(decoder_segment_ids, decoder_segment_ids, jnp.equal, dtype=dtype) + + +def sine_table(features, length, min_timescale=1.0, max_timescale=10000.0): + fraction = jnp.arange(0, features, 2, dtype=jnp.float32) / features + timescale = min_timescale * (max_timescale / min_timescale) ** fraction + rotational_frequency = 1.0 / timescale + # Must use high precision einsum here, bfloat16 rounding is catastrophic. 
+ sinusoid_inp = jnp.einsum( + 'i,j->ij', + jnp.arange(length), + rotational_frequency, + precision=jax.lax.Precision.HIGHEST, + ) + sinusoid_inp = jnp.concatenate([sinusoid_inp, sinusoid_inp], axis=-1) + return jnp.sin(sinusoid_inp), jnp.cos(sinusoid_inp) + + +def rotate_half(x): + x1, x2 = jnp.split(x, 2, axis=-1) + x = jnp.concatenate([-x2, x1], axis=-1) + return x + + +def apply_rotary_embedding(q, k, cos, sin, index=None): + """Helper function to apply Rotary Embeddings.""" + batch, qlen, qheads, d = q.shape + kbatch, klen, kheads, kd = k.shape + if index is not None: + qcos = jax.lax.broadcast_in_dim( + cos[index, :], (batch, qlen, qheads, d), (3,) + ) + qsin = jax.lax.broadcast_in_dim( + sin[index, :], (batch, qlen, qheads, d), (3,) + ) + else: + qcos = jax.lax.broadcast_in_dim( + cos[:qlen, :], (batch, qlen, qheads, d), (1, 3) + ) + qsin = jax.lax.broadcast_in_dim( + sin[:qlen, :], (batch, qlen, qheads, d), (1, 3) + ) + kcos = jax.lax.broadcast_in_dim( + cos[:klen, :], (batch, klen, kheads, d), (1, 3) + ) + ksin = jax.lax.broadcast_in_dim( + sin[:klen, :], (batch, klen, kheads, d), (1, 3) + ) + out_q = (q * qcos) + (rotate_half(q) * qsin) + out_k = (k * kcos) + (rotate_half(k) * ksin) + return out_q, out_k + + +def rms_norm(cfg, scale, x): + x = jnp.asarray(x, jnp.float32) + mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True) + y = jnp.asarray(x * jax.lax.rsqrt(mean2 + cfg.epsilon), cfg.dtype) + return y * jnp.asarray(scale, cfg.dtype) + + +def dropout(cfg: Config, x, broadcast_dims=(-2,), *, rngs: nnx.Rngs): + if cfg.dropout_rate == 0.0: + return x + broadcast_shape = list(x.shape) + for dim in broadcast_dims: + broadcast_shape[dim] = 1 + keep_rate = 1.0 - cfg.dropout_rate + key = rngs.dropout() + mask = jax.random.bernoulli(key, p=keep_rate, shape=broadcast_shape) + return jax.lax.select( + jnp.broadcast_to(mask, x.shape), x / keep_rate, jnp.zeros_like(x) + ) + + +class Attention(nnx.Module): + def __init__(self, cfg: Config, *, rngs: nnx.Rngs): + sharding = cfg.sharding + + key = rngs.params() + self.WQ = nnx.Param( + dense_init( + key, (cfg.embed, cfg.heads, cfg.depth), cfg.param_dtype, 0, (1, 2) + ), + P(sharding.embed, sharding.heads, sharding.depth), + ) + key = rngs.params() + self.WK = nnx.Param( + dense_init( + key, (cfg.embed, cfg.heads, cfg.depth), cfg.param_dtype, 0, (1, 2) + ), + P(sharding.embed, sharding.heads, sharding.depth), + ) + key = rngs.params() + self.WV = nnx.Param( + dense_init( + key, (cfg.embed, cfg.heads, cfg.depth), cfg.param_dtype, 0, (1, 2) + ), + P(sharding.embed, sharding.heads, sharding.depth), + ) + key = rngs.params() + self.WO = nnx.Param( + dense_init( + key, (cfg.heads, cfg.depth, cfg.embed), cfg.param_dtype, (0, 1), 2 + ), + P(sharding.heads, sharding.depth, sharding.embed), + ) + # cache + self.index = nnx.variable('cache', jnp.array(0, dtype=jnp.int32), P()) + self.key = nnx.variable( + 'cache', + jnp.zeros( + (cfg.batch, cfg.heads, cfg.depth, cfg.max_length), + jnp.bfloat16, + ), + P(sharding.batch, sharding.heads, sharding.depth, None), + ) + self.value = nnx.variable( + 'cache', + jnp.zeros( + (cfg.batch, cfg.heads, cfg.depth, cfg.max_length), + jnp.bfloat16, + ), + P(sharding.batch, sharding.heads, sharding.depth, None), + ) + + # We combine the cache and params into "vs", but it would be no harder at all + # to thread through a separate "cache" argument storing cache entries. 
+ def __call__(self, cfg: Config, x_q, x_kv, mask=None, *, rngs: nnx.Rngs): + q = jnp.einsum('bse,enh->bsnh', x_q, self.WQ.astype(cfg.dtype)).astype( + jnp.float32 + ) + k = jnp.einsum('bte,enh->btnh', x_kv, self.WK.astype(cfg.dtype)).astype( + jnp.float32 + ) + v = jnp.einsum('bte,enh->btnh', x_kv, self.WV.astype(cfg.dtype)) + + index = None + if cfg.decode: + index = self.index + one_hot_indices = jax.nn.one_hot( + self.index, cfg.max_length, dtype=cfg.dtype + ) + self.key = self.key + jnp.moveaxis(k, -3, -1) * one_hot_indices + self.value = self.value + jnp.moveaxis(v, -3, -1) * one_hot_indices + k = jnp.moveaxis(self.key, -1, -3) + v = jnp.moveaxis(self.value, -1, -3) + cache_mask = jnp.broadcast_to( + jnp.arange(cfg.max_length) <= self.index, + (cfg.batch, 1, 1, cfg.max_length), + ) + mask = jnp.logical_and( + cache_mask if mask is None else mask, cache_mask + ).astype(cfg.dtype) + self.index = self.index + 1 + + attention_bias = 0.0 + if mask is None: # Hack in lieu of general mask routing. + mask = make_causal_mask(x_q, jnp.float32) + if mask is not None: + attention_bias = jax.lax.select( + mask > 0, + jnp.full(mask.shape, 0.0, cfg.dtype), + jnp.full(mask.shape, -1e10, cfg.dtype), + ) + + sin, cos = sine_table(q.shape[-1], max(q.shape[1], k.shape[1])) + q, k = apply_rotary_embedding(q, k, cos, sin, index=index) + + l = ( + jnp.einsum('bsnh,btnh->bnst', q, k) / np.sqrt(cfg.depth) + attention_bias + ) + s = jax.nn.softmax(l).astype(cfg.dtype) + s = dropout(cfg, s, rngs=rngs) + a = jnp.einsum('bnst,btnh->bsnh', s, v) + o = jnp.einsum('bsnh,nhe->bse', a, self.WO.astype(cfg.dtype)) + + return o + + +class MLP(nnx.Module): + def __init__(self, cfg: Config, *, rngs: nnx.Rngs): + sharding = cfg.sharding + self.Win1 = nnx.Param( + dense_init( + rngs.params(), + (cfg.embed, cfg.hidden), + cfg.param_dtype, + 0, + 1, + ), + P(sharding.embed, sharding.hidden), + ) + self.Win2 = nnx.Param( + dense_init( + rngs.params(), + (cfg.embed, cfg.hidden), + cfg.param_dtype, + 0, + 1, + ), + P(sharding.embed, sharding.hidden), + ) + self.Wout = nnx.Param( + dense_init( + rngs.params(), + (cfg.hidden, cfg.embed), + cfg.param_dtype, + 0, + 1, + ), + P(sharding.hidden, sharding.embed), + ) + + def __call__(self, cfg: Config, x, *, rngs: nnx.Rngs): + h1 = jnp.einsum('bse,eh->bsh', x, self.Win1.astype(cfg.dtype)) + h2 = jnp.einsum('bse,eh->bsh', x, self.Win2.astype(cfg.dtype)) + h = jax.nn.gelu(h1) * h2 + h = dropout(cfg, h, rngs=rngs) + o = jnp.einsum('bsh,he->bse', h, self.Wout.astype(cfg.dtype)) + return o + + +class DecoderBlock(nnx.Module): + def __init__(self, cfg: Config, *, rngs: nnx.Rngs): + sharding = cfg.sharding + self.attn = Attention(cfg, rngs=rngs) + self.mlp = MLP(cfg, rngs=rngs) + self.scale1 = nnx.Param( + jnp.ones((cfg.embed,), cfg.param_dtype), P(sharding.embed) + ) + self.scale2 = nnx.Param( + jnp.ones((cfg.embed,), cfg.param_dtype), P(sharding.embed) + ) + + def __call__(self, cfg: Config, input, *, rngs: nnx.Rngs): + x = rms_norm(cfg, self.scale1, input) + x = self.attn(cfg, x, x, mask=None, rngs=rngs) + x = dropout(cfg, x, rngs=rngs) + x = x + input + y = rms_norm(cfg, self.scale2, x) + y = self.mlp(cfg, y, rngs=rngs) + y = dropout(cfg, y, rngs=rngs) + return y + x + + +class Decoder(nnx.Module): + def __init__(self, cfg: Config, *, rngs: nnx.Rngs): + sharding = cfg.sharding + self.embed = nnx.Param( + embed_init( + rngs.params(), + (cfg.vocab, cfg.embed), + cfg.param_dtype, + 1, + 0, + ), + P(sharding.vocab, sharding.embed), + ) + self.unembed = nnx.Param( + dense_init(rngs.params(), (cfg.embed, 
cfg.vocab), jnp.float32, 0, 1), + P(sharding.embed, sharding.vocab), + ) + self.scale1 = nnx.Param( + jnp.ones((cfg.embed,), cfg.param_dtype), P(sharding.embed) + ) + + if cfg.scanned: + self.layers = nnx.merge( + jax.vmap(lambda key: DecoderBlock(cfg, rngs=nnx.Rngs(key)).split())( + jax.random.split(rngs.params(), cfg.layers) + ) + ) + else: + self.layers = nnx.Sequence( + DecoderBlock(cfg, rngs=rngs) for _ in range(cfg.layers) + ) + + def __call__(self, cfg: Config, x, *, rngs: nnx.Rngs): + # TODO: handle right-shifting for training: here or in train loop. + # TODO: handle general mask routing. + x = self.embed.astype(cfg.dtype)[x] + + if cfg.scanned: + assert isinstance(self.layers, DecoderBlock) + + state, moduledef = self.layers.split() + rngs, rngsdef = rngs.fork() + dropout_key = jax.random.split(rngs['dropout'], cfg.layers) + + def scan_fn(x, s: tp.Tuple[jax.Array, nnx.State]): + dropout_key, state = s + rngs = rngsdef.merge({'dropout': dropout_key}) + y, (state, _) = moduledef.apply(state)(cfg, x, rngs=rngs) + return y, state + + x, state = jax.lax.scan( + scan_fn, + x, + (dropout_key, state), + ) + self.layers.update(state) + else: + assert isinstance(self.layers, nnx.Sequence) + for decoder_block in self.layers: + x = decoder_block(cfg, x, rngs=rngs) + + x = jnp.einsum('bse,ev->bsv', x, self.unembed) + return x diff --git a/flax/experimental/nnx/examples/08_save_load_checkpoints.py b/flax/experimental/nnx/examples/08_save_load_checkpoints.py new file mode 100644 index 0000000000..4e958a8ad1 --- /dev/null +++ b/flax/experimental/nnx/examples/08_save_load_checkpoints.py @@ -0,0 +1,67 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
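The checkpointing example that follows restores into an abstract target built with `jax.eval_shape`, so no real parameters are ever materialized before loading. A quick standalone reminder of what that standard JAX API produces:

```python
import jax
import jax.numpy as jnp

# evaluates the function abstractly: no FLOPs, no device memory is allocated
abstract = jax.eval_shape(lambda: {'w': jnp.zeros((10, 20), jnp.float32)})
print(abstract)  # {'w': ShapeDtypeStruct(shape=(10, 20), dtype=float32)}
```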
+ +from tempfile import TemporaryDirectory + +import jax +import jax.numpy as jnp +import orbax.checkpoint as orbax + +from flax.experimental import nnx + + +class MLP(nnx.Module): + def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs): + self.dense1 = nnx.Linear(din, dmid, rngs=rngs) + self.dense2 = nnx.Linear(dmid, dout, rngs=rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + x = self.dense1(x) + x = jax.nn.relu(x) + x = self.dense2(x) + return x + + +def create_model(seed: int): + return MLP(10, 20, 30, rngs=nnx.Rngs(seed)) + + +def create_and_save(seed: int, path: str): + model = create_model(seed) + state = model.get_state() + # Save the parameters + checkpointer = orbax.PyTreeCheckpointer() + checkpointer.save(f'{path}/state', state) + + +def load_model(path: str) -> MLP: + # create the model with abstract shapes + state, moduledef = jax.eval_shape(lambda: create_model(0).split()) + # Load the parameters + checkpointer = orbax.PyTreeCheckpointer() + state = checkpointer.restore(f'{path}/state', item=state) + # Merge the parameters into the model + model = moduledef.merge(state) + return model + + +with TemporaryDirectory() as tmpdir: + # create a checkpoint + create_and_save(42, tmpdir) + # load model from checkpoint + model = load_model(tmpdir) + # run the model + y = model(jnp.ones((1, 10))) + print(model) + print(y) diff --git a/flax/experimental/nnx/examples/09_parameter_surgery.py b/flax/experimental/nnx/examples/09_parameter_surgery.py new file mode 100644 index 0000000000..cbdeae3eed --- /dev/null +++ b/flax/experimental/nnx/examples/09_parameter_surgery.py @@ -0,0 +1,56 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import jax + +from flax.experimental import nnx + + +# let's pretend this function loads a pretrained model from a checkpoint +def load_pretrained(): + return nnx.Linear(784, 128, rngs=nnx.Rngs(0)) + + +# create a simple linear classifier using a pretrained backbone +class Classifier(nnx.Module): + def __init__(self, *, rngs: nnx.Rngs): + self.backbone = nnx.Linear(784, 128, rngs=nnx.Rngs(0)) + self.head = nnx.Linear(128, 10, rngs=rngs) + + def __call__(self, x): + x = self.backbone(x) + x = nnx.relu(x) + x = self.head(x) + return x + + +# create the classifier using the pretrained backbone; here we are technically +# doing "parameter surgery", however, compared to Haiku/Flax, where you must +# manually construct the parameter structure, in NNX this is done automatically +model = Classifier(rngs=nnx.Rngs(42)) +model.backbone = load_pretrained() + + +# create a filter to select all the parameters that are not part of the +# backbone, i.e. the classifier parameters +is_trainable = lambda path, node: ( + not path.startswith('backbone') and isinstance(node, nnx.Param) +) + +# split the parameters into trainable and non-trainable parameters +trainable_params, non_trainable, moduledef = model.split(is_trainable, ...) 
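A hypothetical continuation of the surgery above shows what the split buys you: only the trainable leaves participate in the update, and `merge` reassembles the full model. Here `grads` is assumed to be an `nnx.State` with the same structure as `trainable_params` (e.g. from `nnx.grad` over a loss); it is not computed in this patch:

```python
# hypothetical sgd step over only the trainable (classifier) parameters
trainable_params = jax.tree_map(
  lambda w, g: w - 1e-3 * g, trainable_params, grads
)
model = moduledef.merge(trainable_params, non_trainable)
```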
+ +print('trainable_params =', jax.tree_map(jax.numpy.shape, trainable_params)) +print('non_trainable = ', jax.tree_map(jax.numpy.shape, non_trainable)) diff --git a/flax/experimental/nnx/examples/10_quantization.py b/flax/experimental/nnx/examples/10_quantization.py new file mode 100644 index 0000000000..0ac4ac7de2 --- /dev/null +++ b/flax/experimental/nnx/examples/10_quantization.py @@ -0,0 +1,437 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# %% +import typing as tp +from functools import partial + +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpy as np +import optax +from datasets import load_dataset + +from flax.experimental import nnx + +np.random.seed(42) +image_shape: tp.Sequence[int] = (28, 28) +steps_per_epoch: int = 200 +batch_size: int = 64 +epochs: int = 20 + + +@jax.custom_vjp +def diff_round(x) -> jax.Array: + y = jnp.round(x) + return y + + +def diff_round_fwd(x): + return diff_round(x), None + + +def diff_round_bwd(_, g): + return (g,) + + +diff_round.defvjp(diff_round_fwd, diff_round_bwd) + + +@partial(jax.custom_vjp, nondiff_argnums=(1, 2)) +def diff_clip(x, low, high) -> jax.Array: + return jnp.clip(x, low, high) + + +def diff_clip_fwd(x, low, high): + return diff_clip(x, low, high), None + + +def diff_clip_bwd(_, _1, _2, dy): + return (dy,) + + +diff_clip.defvjp(diff_clip_fwd, diff_clip_bwd) + + +# %% +def f(x): + return diff_clip(diff_round(x * 128) + 128, 0, 255) + + +df = jax.vmap(jax.grad(f)) + +x = jnp.linspace(-1.5, 1.5, 100) +dx = df(x) + +plt.plot(x, dx) + +# %% +dataset = load_dataset('mnist') +X_train = np.array(np.stack(dataset['train']['image']), dtype=np.float32) +Y_train = np.array(dataset['train']['label'], dtype=np.int32) +X_test = np.array(np.stack(dataset['test']['image']), dtype=np.float32) +Y_test = np.array(dataset['test']['label'], dtype=np.int32) +# normalize data +X_train = X_train / 255.0 +X_test = X_test / 255.0 + + +print('X_train:', X_train.shape, X_train.dtype) +print('X_test:', X_test.shape, X_test.dtype) + + +# %% +class MLP(nnx.Module): + def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs): + self.linear1 = nnx.Linear(din, dmid, rngs=rngs) + self.linear2 = nnx.Linear(dmid, dout, rngs=rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + x = x.reshape((x.shape[0], -1)) + x = self.linear1(x) + x = jax.nn.gelu(x) + x = self.linear2(x) + return x + + +params, moduledef = MLP( + din=np.prod(image_shape), dmid=256, dout=10, rngs=nnx.Rngs(0) +).split(nnx.Param) + +state = nnx.TrainState( + moduledef, + params=params, + tx=optax.adam(1e-3), +) + + +# %% +@jax.jit +def train_step( + state: nnx.TrainState[MLP], + inputs: jax.Array, + labels: jax.Array, +): + def loss_fn(params: nnx.State): + logits, _ = state.apply(params)(inputs) + loss = jnp.mean( + optax.softmax_cross_entropy_with_integer_labels(logits, labels) + ) + return loss + + grad_fn = jax.value_and_grad(loss_fn) + loss, grads = grad_fn(state.params) + state = 
state.apply_gradients(grads=grads) + + return state, loss + + +@jax.jit +def eval_step(state: nnx.TrainState[MLP], inputs: jax.Array, labels: jax.Array): + logits, _ = state.apply('params')(inputs) + loss = jnp.mean( + optax.softmax_cross_entropy_with_integer_labels(logits, labels) + ) + acc = jnp.mean(jnp.argmax(logits, axis=-1) == labels) + return {'loss': loss, 'accuracy': acc} + + +@partial(jax.jit, donate_argnums=(0,)) +def forward(state: nnx.TrainState[MLP], inputs: jax.Array) -> jax.Array: + y_pred = state.apply('params')(inputs)[0] + return jnp.argmax(y_pred, axis=-1) + + +# %% +key = jax.random.key(0) + +for epoch in range(epochs): + for step in range(steps_per_epoch): + idxs = np.random.randint(0, len(X_train), size=(batch_size,)) + x_batch = X_train[idxs] + y_batch = Y_train[idxs] + + state, loss = train_step(state, x_batch, y_batch) + + metrics = eval_step(state, X_test, Y_test) + metrics = jax.tree_map(lambda x: x.item(), metrics) + print(f'Epoch {epoch} - {metrics}') + +# %% +# get random samples +idxs = np.random.randint(0, len(X_test), size=(10,)) +x_sample = X_test[idxs] +y_sample = Y_test[idxs] + +# get predictions +y_pred = forward(state, x_sample) + +# plot predictions +figure = plt.figure(figsize=(3 * 5, 3 * 2)) + +for i in range(10): + plt.subplot(2, 5, i + 1) + plt.imshow(x_sample[i].reshape(image_shape), cmap='gray') + plt.title(f'{y_pred[i]}') + +plt.show() + +model = state.moduledef.merge(state.params) +# %% +# Quantization + +A = tp.TypeVar('A') + + +class QParam(nnx.Variable[A]): + pass + + +class QHParam(nnx.Variable[A]): + pass + + +class QLinear(nnx.Module): + def __init__(self, din: int, dout: int): + self.scale = QHParam(jnp.array(0.5)) + self.zero_point = QHParam(jnp.array(0.5)) + self.qkernel = QParam(jnp.zeros((din, dout))) + self.qbias = QParam(jnp.zeros((dout,))) + + def __call__(self, x: jax.Array) -> jax.Array: + x = self.quantize(x, 8, jnp.uint8) + print(x.shape, self.qkernel.shape, self.qbias.shape) + x = jnp.dot(x, self.qkernel, preferred_element_type=jnp.uint16) + x = (x + self.qbias).astype(jnp.uint32) + x = self.dequantize(x) + return x + + def quantize(self, x: jax.Array, b: int, dtype: jnp.dtype) -> jax.Array: + return jnp.clip( + diff_round(x / self.scale) + self.zero_point, 0, 2**b - 1 + ).astype(dtype) + + def dequantize(self, x: jax.Array) -> jax.Array: + return (x - self.zero_point) * self.scale + + def optimize( + self, + pretrained: nnx.Linear, + x: jax.Array, + *, + num_steps: int = 100, + debug: bool = False, + ): + q_hparams, rest, moduledef = self.split(QHParam, ...) 
+ tx = optax.adam(1e-3) + opt_state = tx.init(q_hparams) + + print(jax.tree_map(lambda x: x.shape, q_hparams)) + + @jax.jit + def optimization_step( + q_hparams: nnx.State, + rest: nnx.State, + opt_state: optax.OptState, + x: jax.Array, + ): + print('JITTING') + + def loss_fn(q_hparams: nnx.State): + model = moduledef.merge(q_hparams, rest) + model.qkernel = model.quantize(pretrained.kernel, 8, jnp.uint8) + assert pretrained.bias is not None + model.qbias = model.quantize(pretrained.bias, 16, jnp.uint16) + + y_quant = model(x) + y_unquant = pretrained(x) + loss = jnp.mean((y_unquant - y_quant) ** 2) + return loss + + loss, grads = jax.value_and_grad(loss_fn)(q_hparams) + + updates, opt_state = tx.update(grads, opt_state, q_hparams) + q_hparams = optax.apply_updates(q_hparams, updates) # type: ignore + + return q_hparams, opt_state, loss + + for step in range(num_steps): + q_hparams, opt_state, loss = optimization_step( + q_hparams, rest, opt_state, x + ) + if debug and step % (num_steps / 10) == 0: + print(f'Step {step} - loss: {loss}') + + self.update(q_hparams) + + self.qkernel = self.quantize(pretrained.kernel, 8, jnp.uint8) + assert pretrained.bias is not None + self.qbias = self.quantize(pretrained.bias, 16, jnp.uint16) + + +def optimize2( + self, + pretrained: nnx.Linear, + X: jax.Array, +): + W = pretrained.kernel + b = pretrained.bias + assert b is not None + + # X + alpha_X = jnp.min(X) + beta_X = jnp.max(X) + s_X, z_X = generate_quantization_int8_constants(alpha=alpha_X, beta=beta_X) + X_q = quantization_int8(x=X, s=s_X, z=z_X) + X_q_dq = dequantization(x_q=X_q, s=s_X, z=z_X) + + # W + alpha_W = jnp.min(W) + beta_W = jnp.max(W) + s_W, z_W = generate_quantization_int8_constants(alpha=alpha_W, beta=beta_W) + W_q = quantization_int8(x=W, s=s_W, z=z_W) + W_q_dq = dequantization(x_q=W_q, s=s_W, z=z_W) + + # b + alpha_b = jnp.min(b) + beta_b = jnp.max(b) + s_b, z_b = generate_quantization_int8_constants(alpha=alpha_b, beta=beta_b) + b_q = quantization_int8(x=b, s=s_b, z=z_b) + b_q_dq = dequantization(x_q=b_q, s=s_b, z=z_b) + + # Y + Y = jnp.matmul(X, W) + b + alpha_Y = jnp.min(Y) + beta_Y = jnp.max(Y) + s_Y, z_Y = generate_quantization_int8_constants(alpha=alpha_Y, beta=beta_Y) + Y_q = quantization_int8(x=Y, s=s_Y, z=z_Y) + + Y_prime = jnp.matmul(X_q_dq, W_q_dq) + b_q_dq + Y_prime_q = quantization_int8(x=Y_prime, s=s_Y, z=z_Y) + Y_prime_q_dq = dequantization(x_q=Y_prime_q, s=s_Y, z=z_Y) + + print('Expected FP32 Y:') + print(Y) + print('Expected FP32 Y Quantized:') + print(Y_q) + + Y_q_simulated = quantization_matrix_multiplication_int8( + X_q=X_q, + W_q=W_q, + b_q=b_q, + s_X=s_X, + z_X=z_X, + s_W=s_W, + z_W=z_W, + s_b=s_b, + z_b=z_b, + s_Y=s_Y, + z_Y=z_Y, + ) + Y_simulated = dequantization(x_q=Y_q_simulated, s=s_Y, z=z_Y) + + print('Expected Quantized Y_q from Quantized Matrix Multiplication:') + print(Y_q_simulated) + print( + 'Expected Quantized Y_q from Quantized Matrix Multiplication Dequantized:' + ) + print(Y_simulated) + + # Ensure the algorithm implementation is correct + assert jnp.array_equal(Y_simulated, Y_prime_q_dq) + assert jnp.array_equal(Y_q_simulated, Y_prime_q) + + +def quantization(x, s, z, alpha_q, beta_q): + x_q = jnp.round(1 / s * x + z, decimals=0) + x_q = jnp.clip(x_q, a_min=alpha_q, a_max=beta_q) + + return x_q + + +def quantization_int8(x, s, z): + x_q = quantization(x, s, z, alpha_q=-128, beta_q=127) + x_q = x_q.astype(jnp.int8) + + return x_q + + +def dequantization(x_q, s, z): + # x_q - z might go outside the quantization range. 
+ x_q = x_q.astype(jnp.int32) + x = s * (x_q - z) + x = x.astype(jnp.float32) + + return x + + +def generate_quantization_constants(alpha, beta, alpha_q, beta_q): + # Affine quantization mapping + s = (beta - alpha) / (beta_q - alpha_q) + z = int((beta * alpha_q - alpha * beta_q) / (beta - alpha)) + + return s, z + + +def generate_quantization_int8_constants(alpha, beta): + b = 8 + alpha_q = -(2 ** (b - 1)) + beta_q = 2 ** (b - 1) - 1 + + s, z = generate_quantization_constants( + alpha=alpha, beta=beta, alpha_q=alpha_q, beta_q=beta_q + ) + + return s, z + + +def quantization_matrix_multiplication_int8( + X_q, W_q, b_q, s_X, z_X, s_W, z_W, s_b, z_b, s_Y, z_Y +): + p = W_q.shape[0] + + # Y_q_simulated is FP32 + Y_q_simulated = ( + z_Y + + (s_b / s_Y * (b_q.astype(jnp.int32) - z_b)) + + ( + (s_X * s_W / s_Y) + * ( + jnp.matmul(X_q.astype(jnp.int32), W_q.astype(jnp.int32)) + - z_W * jnp.sum(X_q.astype(jnp.int32), axis=1, keepdims=True) + - z_X * jnp.sum(W_q.astype(jnp.int32), axis=0, keepdims=True) + + p * z_X * z_W + ) + ) + ) + + Y_q_simulated = jnp.round(Y_q_simulated, decimals=0) + Y_q_simulated = jnp.clip(Y_q_simulated, a_min=-128, a_max=127) + Y_q_simulated = Y_q_simulated.astype(jnp.int8) + + return Y_q_simulated + + +# %% +qlinear1 = QLinear(din=np.prod(image_shape), dout=256) +# qlinear2 = QLinear(din=256, dout=10) + +idxs = np.random.randint(0, len(X_test), size=(100,)) +x_optimize = jnp.asarray(X_test[idxs], dtype=jnp.float32) +x_optimize = x_optimize.reshape((x_optimize.shape[0], -1)) +print(x_optimize.shape) +qlinear1.optimize(model.linear1, x_optimize, num_steps=1000, debug=True) + +# %% diff --git a/flax/experimental/nnx/examples/requirements.txt b/flax/experimental/nnx/examples/requirements.txt new file mode 100644 index 0000000000..e44f155e47 --- /dev/null +++ b/flax/experimental/nnx/examples/requirements.txt @@ -0,0 +1,2 @@ +matplotlib>=3.7.1 +datasets>=2.12.0 \ No newline at end of file diff --git a/flax/experimental/nnx/ideas/shape_inference.py b/flax/experimental/nnx/ideas/shape_inference.py new file mode 100644 index 0000000000..272616a33f --- /dev/null +++ b/flax/experimental/nnx/ideas/shape_inference.py @@ -0,0 +1,210 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing as tp + +import jax +import jax.numpy as jnp +from jax import random + +from flax.experimental import nnx + + +class Linear(nnx.Module): + @tp.overload + def __init__(self, *, din: int, dout: int, rngs: nnx.Rngs): + ... + + @tp.overload + def __init__(self, *, dout: int): + ... + + @tp.overload + def __init__( + self, + *, + din: tp.Optional[int] = None, + dout: int, + rngs: tp.Optional[nnx.Rngs] = None, + ): + ... 
+ + def __init__( + self, + *, + din: tp.Optional[int] = None, + dout: int, + rngs: tp.Optional[nnx.Rngs] = None, + ): + self.dout = dout + if din is not None: + if rngs is None: + raise ValueError('rngs must be provided if din is provided') + self.init_variables(din, rngs) + + def init_variables(self, din: int, rngs: nnx.Rngs): + key = rngs.params() + self.w = nnx.Param(random.uniform(key, (din, self.dout))) + self.b = nnx.Param(jnp.zeros((self.dout,))) + + def __call__( + self, x: jax.Array, *, rngs: tp.Optional[nnx.Rngs] = None + ) -> jax.Array: + if self.is_initializing and not hasattr(self, 'w'): + if rngs is None: + raise ValueError('rngs must be provided to initialize module') + self.init_variables(x.shape[-1], rngs) + + return x @ self.w + self.b + + +class BatchNorm(nnx.Module): + @tp.overload + def __init__(self, *, mu: float = 0.95): + ... + + @tp.overload + def __init__(self, *, din: int, mu: float = 0.95, rngs: nnx.Rngs): + ... + + @tp.overload + def __init__( + self, + *, + din: tp.Optional[int] = None, + mu: float = 0.95, + rngs: tp.Optional[nnx.Rngs] = None, + ): + ... + + def __init__( + self, + *, + din: tp.Optional[int] = None, + mu: float = 0.95, + rngs: tp.Optional[nnx.Rngs] = None, + ): + self.mu = mu + + if din is not None: + if rngs is None: + raise ValueError('rngs must be provided if din is provided') + self.init_variables(din, rngs) + + def init_variables(self, din: int, rngs: nnx.Rngs): + self.scale = nnx.Param(jax.numpy.ones((din,))) + self.bias = nnx.Param(jax.numpy.zeros((din,))) + self.mean = nnx.BatchStat(jax.numpy.zeros((din,))) + self.var = nnx.BatchStat(jax.numpy.ones((din,))) + + def __call__( + self, x, *, train: bool, rngs: tp.Optional[nnx.Rngs] = None + ) -> jax.Array: + if self.is_initializing and not hasattr(self, 'scale'): + if rngs is None: + raise ValueError('rngs must be provided to initialize module') + self.init_variables(x.shape[-1], rngs) + + if train: + axis = tuple(range(x.ndim - 1)) + mean = jax.numpy.mean(x, axis=axis) + var = jax.numpy.var(x, axis=axis) + # ema update + self.mean = self.mu * self.mean + (1 - self.mu) * mean + self.var = self.mu * self.var + (1 - self.mu) * var + else: + mean, var = self.mean, self.var + + scale, bias = self.scale, self.bias + x = (x - mean) / jax.numpy.sqrt(var + 1e-5) * scale + bias + return x + + +class Dropout(nnx.Module): + def __init__(self, rate: float): + self.rate = rate + + def __call__(self, x: jax.Array, *, train: bool, rngs: nnx.Rngs) -> jax.Array: + if train: + mask = random.bernoulli(rngs.dropout(), (1 - self.rate), x.shape) + x = x * mask / (1 - self.rate) + return x + + +# ---------------------------- +# test Linear +# ---------------------------- +print('test Linear') + +# eager +m1 = Linear(din=32, dout=10, rngs=nnx.Rngs(params=0)) +y = m1(x=jnp.ones((1, 32))) +print(jax.tree_map(jnp.shape, m1.get_state())) + +# lazy +m2 = Linear(dout=10) +y = m2.init(x=jnp.ones((1, 32)), rngs=nnx.Rngs(params=0)) +print(jax.tree_map(jnp.shape, m2.get_state())) + +# usage +y1 = m1(x=jnp.ones((1, 32))) +y2 = m2(x=jnp.ones((1, 32))) + +# ---------------------------- +# Test scan +# ---------------------------- +print('\ntest scan') + + +class Block(nnx.Module): + def __init__( + self, + din: tp.Optional[int] = None, + dout: int = 10, + rngs: tp.Optional[nnx.Rngs] = None, + ): + self.linear = Linear(din=din, dout=dout, rngs=rngs) + self.bn = BatchNorm(din=dout if din is not None else None, rngs=rngs) + self.dropout = Dropout(0.5) + + def __call__(self, x: jax.Array, _, *, train: bool, rngs: nnx.Rngs): + x 
= self.linear(x, rngs=rngs) + x = self.bn(x, train=train, rngs=rngs) + x = self.dropout(x, train=train, rngs=rngs) + x = jax.nn.gelu(x) + return x, None + + +MLP = nnx.Scan( + Block, + variable_axes={nnx.Param: 0}, + variable_carry=nnx.BatchStat, + split_rngs={'params': True, 'dropout': True}, + length=5, +) + + +# eager +mlp = MLP(din=10, dout=10, rngs=nnx.Rngs(params=0)) +y, _ = mlp.call(jnp.ones((1, 10)), None, train=True, rngs=nnx.Rngs(dropout=1)) +print(f'{y.shape=}') +print('state =', jax.tree_map(jnp.shape, mlp.get_state())) +print() + +# lazy +mlp = MLP(dout=10) +mlp.init(jnp.ones((1, 10)), None, train=False, rngs=nnx.Rngs(params=0)) +y, _ = mlp.call(jnp.ones((1, 10)), None, train=True, rngs=nnx.Rngs(dropout=1)) +print(f'{y.shape=}') +print('state =', jax.tree_map(jnp.shape, mlp.get_state())) diff --git a/flax/experimental/nnx/nnx/__init__.py b/flax/experimental/nnx/nnx/__init__.py new file mode 100644 index 0000000000..e80ba0b35f --- /dev/null +++ b/flax/experimental/nnx/nnx/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/flax/experimental/nnx/nnx/compatibility.py b/flax/experimental/nnx/nnx/compatibility.py new file mode 100644 index 0000000000..919b69bb31 --- /dev/null +++ b/flax/experimental/nnx/nnx/compatibility.py @@ -0,0 +1,93 @@ +import dataclasses +import typing as tp +from typing import Any + +from flax import linen +from flax.experimental.nnx.nnx import helpers +from flax.experimental.nnx.nnx import variables as variableslib +from flax.experimental.nnx.nnx.module import Module, ModuleDef +from flax.experimental.nnx.nnx.rnglib import Rngs +from flax.experimental.nnx.nnx.state import State + +M = tp.TypeVar('M', bound=Module) + + +# Flax-like style is NNX +@dataclasses.dataclass +class Functional(tp.Generic[M]): + module_type: tp.Type[M] + moduledef: tp.Optional[ModuleDef[M]] + args: tuple[tp.Any, ...] 
+ kwargs: dict[str, tp.Any] + + def init(self, *, rngs: tp.Optional[Rngs] = None) -> State: + kwargs = {} + if rngs is not None: + kwargs['rngs'] = rngs + module = self.module_type(*self.args, **self.kwargs, **kwargs) + state, moduledef = module.split() + self.moduledef = moduledef + return state + + def apply(self, *states: tp.Any): + assert self.moduledef is not None + return self.moduledef.apply(*states) + + +def functional(cls: tp.Type[M]) -> tp.Callable[..., Functional[M]]: + def _functional_constructor(*args: tp.Any, **kwargs: tp.Any) -> Functional[M]: + return Functional(cls, None, args, kwargs) + + return _functional_constructor + + +class LinenWrapper(Module): + def __init__( + self, + module: linen.Module, + *args: tp.Any, + rngs: tp.Optional[Rngs] = None, + **kwargs: tp.Any, + ): + self.module = module + + _rngs = ( + {name: stream.key for name, stream in rngs._rngs.items()} if rngs else {} + ) + # rename default to params + if 'params' not in _rngs and 'default' in _rngs: + _rngs['params'] = _rngs['default'] + del _rngs['default'] + + variables = module.init(_rngs, *args, **kwargs) + + self.states = helpers.Dict( + (collection, variableslib.variable_type(collection)(value)) + for collection, value in variables.items() + ) + + def __call__( + self, *args: Any, rngs: tp.Optional[Rngs] = None, **kwargs: Any + ) -> Any: + _rngs = ( + {name: stream.key for name, stream in rngs._rngs.items()} if rngs else {} + ) + + variables = {collection: value for collection, value in self.states.items()} + out = self.module.apply(variables, *args, rngs=_rngs, **kwargs) + + if kwargs.get('mutable', False) != False: + out, updates = out + for collection, value in updates.items(): + if collection in self.states: + self.states[collection] = value + else: + self.states[collection] = variableslib.variable_type(collection)( + value + ) + + return out + + +class NNXWrapper(linen.Module): + ... diff --git a/flax/experimental/nnx/nnx/dataclasses.py b/flax/experimental/nnx/nnx/dataclasses.py new file mode 100644 index 0000000000..9a37b2f667 --- /dev/null +++ b/flax/experimental/nnx/nnx/dataclasses.py @@ -0,0 +1,188 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
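The `Functional` wrapper above gives an NNX Module a linen-style init/apply pair. A minimal usage sketch (hypothetical sizes; `nnx.Linear` is the NNX layer added later in this patch):

import jax.numpy as jnp
from flax.experimental import nnx
from flax.experimental.nnx.nnx.compatibility import functional

# capture constructor arguments now, build the module at init time
linear = functional(nnx.Linear)(in_features=4, out_features=2)
state = linear.init(rngs=nnx.Rngs(params=0))  # State pytree of Params
# apply returns the output plus the updated (State, ModuleDef) pair
y, (state, moduledef) = linear.apply(state)(jnp.ones((1, 4)))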
+ +import dataclasses +import typing as tp + +import typing_extensions as tpe + +from flax.experimental import nnx +from flax.experimental.nnx.nnx import pytreelib, variables + +A = tp.TypeVar('A') + + +def field( + *, + default: tp.Any = dataclasses.MISSING, + default_factory: tp.Any = dataclasses.MISSING, + init: bool = True, + repr: bool = True, + hash: tp.Optional[bool] = None, + compare: bool = True, + metadata: tp.Optional[tp.Mapping[str, tp.Any]] = None, +): + return dataclasses.field( # type: ignore + default=default, + default_factory=default_factory, + init=init, + repr=repr, + hash=hash, + compare=compare, + metadata=metadata, + ) + + +def treenode_field( + *, + default: tp.Any = dataclasses.MISSING, + default_factory: tp.Any = dataclasses.MISSING, + init: bool = True, + repr: bool = True, + hash: tp.Optional[bool] = None, + compare: bool = True, + metadata: tp.Optional[tp.Mapping[str, tp.Any]] = None, +): + if metadata is None: + metadata = {} + else: + metadata = dict(metadata) + + if 'nnx_variable_constructor' in metadata: + raise ValueError("'nnx_variable_constructor' found in metadata") + + metadata['nnx_variable_constructor'] = lambda value: pytreelib.TreeNode(value) + + return field( + default=default, + default_factory=default_factory, + init=init, + repr=repr, + hash=hash, + compare=compare, + metadata=metadata, + ) + + +def variable_field( + variable_type: tp.Type[variables.Variable[tp.Any]], + *, + default: tp.Any = dataclasses.MISSING, + default_factory: tp.Any = dataclasses.MISSING, + init: bool = True, + repr: bool = True, + hash: tp.Optional[bool] = None, + compare: bool = True, + metadata: tp.Optional[tp.Mapping[str, tp.Any]] = None, +) -> tp.Any: + if metadata is None: + metadata = {} + else: + metadata = dict(metadata) + + if 'nnx_variable_constructor' in metadata: + raise ValueError("'nnx_variable_constructor' found in metadata") + + metadata['nnx_variable_constructor'] = lambda value: variable_type(value) + + return field( + default=default, + default_factory=default_factory, + init=init, + repr=repr, + hash=hash, + compare=compare, + metadata=metadata, + ) + + +def param_field( + default: tp.Any = dataclasses.MISSING, + *, + default_factory: tp.Any = dataclasses.MISSING, + init: bool = True, + repr: bool = True, + hash: tp.Optional[bool] = None, + compare: bool = True, + metadata: tp.Optional[tp.Mapping[str, tp.Any]] = None, +) -> tp.Any: + return variable_field( + variables.Param, + default=default, + default_factory=default_factory, + init=init, + repr=repr, + hash=hash, + compare=compare, + metadata=metadata, + ) + + +@tp.overload +def dataclass(cls: tp.Type[A]) -> tp.Type[A]: + ... + + +@tp.overload +def dataclass( + *, + init: bool = True, + repr: bool = True, + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + frozen: bool = False, +) -> tp.Callable[[tp.Type[A]], tp.Type[A]]: + ... 
+ + +@tpe.dataclass_transform( + field_specifiers=( + field, + treenode_field, + variable_field, + param_field, + ) +) +def dataclass( + cls: tp.Optional[tp.Type[A]] = None, + init: bool = True, + repr: bool = True, + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + frozen: bool = False, +) -> tp.Union[tp.Type[A], tp.Callable[[tp.Type[A]], tp.Type[A]]]: + def decorator(cls: tp.Type[A]) -> tp.Type[A]: + cls = dataclasses.dataclass( + init=init, + repr=repr, + eq=eq, + order=order, + unsafe_hash=unsafe_hash, + frozen=frozen, + )(cls) + if issubclass(cls, nnx.Module): + + def hash_fn(module: nnx.Module): + return hash(module._module__state.id) + + cls.__hash__ = hash_fn + + return cls + + if cls is None: + return decorator + + return decorator(cls) diff --git a/flax/experimental/nnx/nnx/errors.py b/flax/experimental/nnx/nnx/errors.py new file mode 100644 index 0000000000..c72305e62d --- /dev/null +++ b/flax/experimental/nnx/nnx/errors.py @@ -0,0 +1,17 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class TraceContextError(Exception): + pass diff --git a/flax/experimental/nnx/nnx/filterlib.py b/flax/experimental/nnx/nnx/filterlib.py new file mode 100644 index 0000000000..c80184935a --- /dev/null +++ b/flax/experimental/nnx/nnx/filterlib.py @@ -0,0 +1,100 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import builtins +import dataclasses +import typing as tp + +if tp.TYPE_CHECKING: + ellipsis = builtins.ellipsis +else: + ellipsis = tp.Any + +Path = str +Predicate = tp.Callable[[Path, tp.Any], bool] +FilterLiteral = tp.Union[type, str, Predicate, bool, ellipsis, None] +Filter = tp.Union[FilterLiteral, tuple[FilterLiteral, ...], list[FilterLiteral]] + + +def to_predicate(filter: Filter) -> Predicate: + if isinstance(filter, str): + return AtPath(filter) + elif isinstance(filter, type): + return OfType(filter) + elif isinstance(filter, bool): + return Everything() if filter else Nothing() + elif filter is Ellipsis: + return Everything() + elif filter is None: + return Nothing() + elif callable(filter): + return filter + elif isinstance(filter, (list, tuple)): + return Any(*filter) + else: + raise TypeError(f'Invalid collection filter: {filter!r}. 
') + + +@dataclasses.dataclass +class AtPath: + path: str + + def __call__(self, path: Path, x: tp.Any): + return self.path == path + + +@dataclasses.dataclass +class OfType: + type: type + + def __call__(self, path: Path, x: tp.Any): + return isinstance(x, self.type) + + +class Any: + def __init__(self, *filters: Filter): + self.predicates = tuple( + to_predicate(collection_filter) for collection_filter in filters + ) + + def __call__(self, path: Path, x: tp.Any): + return any(predicate(path, x) for predicate in self.predicates) + + +class All: + def __init__(self, *filters: Filter): + self.predicates = tuple( + to_predicate(collection_filter) for collection_filter in filters + ) + + def __call__(self, path: Path, x: tp.Any): + return all(predicate(path, x) for predicate in self.predicates) + + +class Not: + def __init__(self, collection_filter: Filter): + self.predicate = to_predicate(collection_filter) + + def __call__(self, path: Path, x: tp.Any): + return not self.predicate(path, x) + + +class Everything: + def __call__(self, path: Path, x: tp.Any): + return True + + +class Nothing: + def __call__(self, path: Path, x: tp.Any): + return False diff --git a/flax/experimental/nnx/nnx/flaglib.py b/flax/experimental/nnx/nnx/flaglib.py new file mode 100644 index 0000000000..45d19d9be0 --- /dev/null +++ b/flax/experimental/nnx/nnx/flaglib.py @@ -0,0 +1,55 @@ +import dataclasses +import threading +import typing as tp +from contextlib import contextmanager +from types import MappingProxyType + + +@dataclasses.dataclass +class FlagsContext(threading.local): + flags_stack: tp.List[MappingProxyType[str, tp.Hashable]] = dataclasses.field( + default_factory=lambda: [MappingProxyType({})] + ) + + +FLAGS_CONTEXT = FlagsContext() + + +class Flags(tp.Mapping[str, tp.Hashable]): + __slots__ = () + + def __getitem__(self, name: str) -> tp.Hashable: + current_flags = FLAGS_CONTEXT.flags_stack[-1] + if name not in current_flags: + raise ValueError(f'Unknown Flag: {name}') + return current_flags[name] + + __getattr__ = __getitem__ + + def __iter__(self) -> tp.Iterator[str]: + return iter(FLAGS_CONTEXT.flags_stack[-1]) + + def __len__(self) -> int: + return len(FLAGS_CONTEXT.flags_stack[-1]) + + def __contains__(self, name: tp.Any) -> bool: + return name in FLAGS_CONTEXT.flags_stack[-1] + + @contextmanager + def __call__(self, **kwargs: tp.Hashable): + current_flags = FLAGS_CONTEXT.flags_stack[-1] + FLAGS_CONTEXT.flags_stack.append( + MappingProxyType(dict(current_flags, **kwargs)) + ) + try: + yield + finally: + FLAGS_CONTEXT.flags_stack.pop() + + def get( + self, name: str, default: tp.Hashable = None + ) -> tp.Optional[tp.Hashable]: + return FLAGS_CONTEXT.flags_stack[-1].get(name, default) + + +flags = Flags() diff --git a/flax/experimental/nnx/nnx/helpers.py b/flax/experimental/nnx/nnx/helpers.py new file mode 100644 index 0000000000..42cae73e91 --- /dev/null +++ b/flax/experimental/nnx/nnx/helpers.py @@ -0,0 +1,177 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
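To illustrate the filter DSL defined by `to_predicate` above, a small sketch (arbitrary paths and values):

from flax.experimental.nnx.nnx import filterlib
from flax.experimental.nnx.nnx.flaglib import flags

# strings match paths, types match values, lists/tuples mean logical OR
predicate = filterlib.to_predicate(['linear/kernel', int])
assert predicate('linear/kernel', object())  # matched by path
assert predicate('counter', 42)              # matched by type
assert not predicate('linear/bias', 'hi')    # matched by neither

# flags is a thread-local stack of mappings, scoped via a context manager
with flags(deterministic=True):
  assert flags['deterministic'] is True
assert flags.get('deterministic') is None    # popped on exit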
+from __future__ import annotations + +import inspect +import typing as tp + +import jax +import jax.numpy as jnp +import numpy as np +import optax + +from flax.experimental.nnx.nnx import pytreelib +from flax.experimental.nnx.nnx.module import ApplyCaller, Module, ModuleDef +from flax.experimental.nnx.nnx.rnglib import Rngs +from flax.experimental.nnx.nnx.state import State + +A = tp.TypeVar('A') +M = tp.TypeVar('M', bound=Module) + + +class Dict(Module, tp.Mapping[str, A]): + @tp.overload + def __init__(self, __iterable: tp.Iterable[tp.Tuple[str, A]]): + ... + + @tp.overload + def __init__( + self, __mapping: tp.Optional[tp.Mapping[str, A]] = None, **kwargs: A + ): + ... + + def __init__(self, *args, **kwargs): + for name, value in dict(*args, **kwargs).items(): + setattr(self, name, value) + + def __getitem__(self, key) -> A: + return getattr(self, key) + + def __setitem__(self, key, value): + setattr(self, key, value) + + def __getattr__(self, key) -> A: + return super().__getattribute__(key) + + def __setattr__(self, key, value): + super().__setattr__(key, value) + + def __iter__(self) -> tp.Iterator[str]: + return (k for k in vars(self) if k != '_module__state') + + def __len__(self) -> int: + # exclude the internal '_module__state' entry, mirroring __iter__ + return sum(1 for k in vars(self) if k != '_module__state') + + +class Sequence(Module, tp.Generic[A]): + def __init__(self, iterable: tp.Iterable[A]): + # handle the empty-iterable case: length stays 0 if the loop never runs + self._length = 0 + for i, value in enumerate(iterable): + setattr(self, str(i), value) + self._length = i + 1 + + def __getitem__(self, key: int) -> A: + if key >= len(self): + raise IndexError(f'index {key} out of range for {self}') + return getattr(self, str(key)) + + def __setitem__(self, key: int, value: A): + if key >= len(self): + raise IndexError(f'index {key} out of range for {self}') + setattr(self, str(key), value) + + def __iter__(self) -> tp.Iterator[A]: + for i in range(len(self)): + yield getattr(self, str(i)) + + def __len__(self) -> int: + return self._length + + def __call__(self, *args, rngs: tp.Optional[Rngs] = None, **kwargs) -> tp.Any: + output: tp.Any = None + + for i, f in enumerate(self): + if not callable(f): + raise TypeError(f'Sequence[{i}] is not callable: {f}') + if i > 0: + if isinstance(output, tp.Tuple): + args = output + kwargs = {} + elif isinstance(output, tp.Dict): + args = () + kwargs = output + else: + args = (output,) + kwargs = {} + if rngs is not None and has_keyword_arg(f, 'rngs'): + kwargs['rngs'] = rngs + + output = f(*args, **kwargs) + + return output + + +class ModuleDefApply(tp.Protocol, tp.Generic[M]): + def __call__( + self, state: State, *states: State + ) -> ApplyCaller[tuple[State, ModuleDef[M]]]: + ... + + +class TrainState(pytreelib.Pytree, tp.Generic[M]): + def __init__( + self, + moduledef: ModuleDef[M], + *, + params: State, + tx: optax.GradientTransformation, + step: int = 0, + **kwargs, + ): + self.moduledef = moduledef + self.params: State = pytreelib.TreeNode(params) + self.tx = tx + self.opt_state = pytreelib.TreeNode(tx.init(self.params)) + self.step = pytreelib.TreeNode(jnp.asarray(step)) + for name, value in kwargs.items(): + if isinstance(value, (jax.Array, np.ndarray, State)): + value = pytreelib.TreeNode(value) + setattr(self, name, value) + + if tp.TYPE_CHECKING: + + def __getattr__(self, key: str) -> tp.Any: + ...
+ + def apply( + self, state: tp.Union[State, str], *states: tp.Union[State, str] + ) -> ApplyCaller[tuple[State, ModuleDef[M]]]: + states = (state, *states) + + _states = ( + getattr(self, state) if isinstance(state, str) else state + for state in states + ) + + return self.moduledef.apply(*_states) + + def apply_gradients(self, grads: State, **kwargs) -> 'TrainState[M]': + updates, opt_state = self.tx.update(grads, self.opt_state, self.params) + params = optax.apply_updates(self.params, updates) # type: ignore + step = self.step + 1 + return self.replace( + params=params, + opt_state=opt_state, + step=step, + **kwargs, + ) + + +def has_keyword_arg(func: tp.Callable[..., tp.Any], name: str) -> bool: + """Return True if func accepts a keyword argument with the given name.""" + return any( + param.name == name + and param.kind in (param.KEYWORD_ONLY, param.POSITIONAL_OR_KEYWORD) + for param in inspect.signature(func).parameters.values() + ) diff --git a/flax/experimental/nnx/nnx/ids.py b/flax/experimental/nnx/nnx/ids.py new file mode 100644 index 0000000000..40db11605f --- /dev/null +++ b/flax/experimental/nnx/nnx/ids.py @@ -0,0 +1,65 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""UUIDs for Flax internals.""" + +import threading + + +class UUIDManager: + """Globally unique counter-based id manager. + + We need globally unique key ids for Module and Variable object instances + to preserve and recreate sharing-by-reference relationship when lifting + transforms and adopting outside Modules. + - Use of id() is unacceptable because these identifiers are literally + pointers which can be recycled, so we rely on a globally unique counter id + instead. + - We need to handle copy/deepcopy uniqueness via a wrapped type.
+ """ + + def __init__(self): + self._lock = threading.Lock() + self._id = 0 + + def __call__(self): + with self._lock: + self._id += 1 + return UUID(self._id) + + +uuid = UUIDManager() + + +class UUID: + """Hashable wrapper for ids that handles uniqueness of copies.""" + + def __init__(self, rawid): + self.id = rawid + + def __eq__(self, other): + return isinstance(other, UUID) and other.id == self.id + + def __hash__(self): + return hash(self.id) + + def __repr__(self): + return f'UUID({self.id})' + + def __deepcopy__(self, memo): + del memo + return uuid() + + def __copy__(self): + return uuid() diff --git a/flax/experimental/nnx/nnx/module.py b/flax/experimental/nnx/nnx/module.py new file mode 100644 index 0000000000..e7103680f5 --- /dev/null +++ b/flax/experimental/nnx/nnx/module.py @@ -0,0 +1,932 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import dataclasses +import enum +import typing as tp +from abc import ABCMeta +from copy import deepcopy +from functools import partial + +import jax +import jax.tree_util as jtu +import numpy as np +import typing_extensions as tpe + +from flax.experimental.nnx.nnx import ( + errors, + filterlib, + ids, + reprlib, + tracers, +) +from flax.experimental.nnx.nnx import variables as variableslib +from flax.experimental.nnx.nnx.rnglib import Rngs +from flax.experimental.nnx.nnx.state import State +from flax.experimental.nnx.nnx.variables import Variable + +A = tp.TypeVar('A') +B = tp.TypeVar('B') +M = tp.TypeVar('M', bound='Module') +S = tp.TypeVar('S', bound=tp.Union[State, tuple[State, ...]]) +V = tp.TypeVar('V', bound=variableslib.Variable[tp.Any]) + +Path = str +PathParts = tuple[str, ...] +StateDict = tp.Dict[Path, tp.Any] +StateMapping = tp.Mapping[Path, tp.Any] + + +class _ProxyContext(tp.Protocol): + def __call__(self, accessor: 'DelayedAccessor', /, *args, **kwargs) -> tp.Any: + ... + + +@tp.runtime_checkable +class _HasSetup(tp.Protocol): + def setup(self) -> None: + ... + + +@dataclasses.dataclass +class CallableProxy: + _proxy_context: _ProxyContext + _proxy_callable: tp.Callable[..., tp.Any] + + def __call__(self, *args, **kwargs): + return self._proxy_context(self._proxy_callable, *args, **kwargs) + + def __getattr__(self, name) -> 'CallableProxy': + return CallableProxy( + self._proxy_context, getattr(self._proxy_callable, name) + ) + + def __getitem__(self, key) -> 'CallableProxy': + return CallableProxy(self._proxy_context, self._proxy_callable[key]) + + +def _identity(x): + return x + + +@dataclasses.dataclass +class DelayedAccessor: + accessor: tp.Callable[[tp.Any], tp.Any] = _identity + + def __call__(self, x): + return self.accessor(x) + + def __getattr__(self, name): + return DelayedAccessor(lambda x: getattr(x, name)) + + def __getitem__(self, key): + return DelayedAccessor(lambda x: x[key]) + + +class ApplyCaller(tp.Protocol, tp.Generic[A]): + def __getattr__(self, __name) -> 'ApplyCaller[A]': + ... 
+ + def __getitem__(self, __name) -> 'ApplyCaller[A]': + ... + + def __call__(self, *args, **kwargs) -> tuple[tp.Any, A]: + ... + + +@dataclasses.dataclass(repr=False) +class _SubmodulesRepr(reprlib.Representable): + submodules: tuple[tuple[str, tp.Union['ModuleDef[Module]', int]], ...] + + def __nnx_repr__(self): + yield reprlib.Object(type='', value_sep=', ') + + for name, submodule in self.submodules: + yield reprlib.Attr(repr(name), submodule, start='(', end=')') + + +class ModuleDef(tp.Generic[M], reprlib.Representable): + __slots__ = ( + '_type', + '_index', + '_submodules', + '_static_fields', + '_variables', + '_module_state', + ) + + def __init__( + self, + type: tp.Type[M], + index: int, + submodules: tuple[tuple[str, tp.Union['ModuleDef[Module]', int]], ...], + static_fields: tuple[tuple[str, tp.Any], ...], + variables: tuple[ + tuple[str, variableslib.Variable[variableslib.Empty]], ... + ], + module_state: 'ModuleStateTuple', + ): + self._type = type + self._index = index + self._submodules = submodules + self._static_fields = static_fields + self._variables = variables + self._module_state = module_state + + def __nnx_repr__(self): + yield reprlib.Object(type=type(self)) + + yield reprlib.Attr('type', self._type.__name__) + yield reprlib.Attr('index', self._index) + yield reprlib.Attr('static_fields', self._static_fields) + yield reprlib.Attr('variables', self._variables) + yield reprlib.Attr('submodules', _SubmodulesRepr(self._submodules)) + + def __hash__(self) -> int: + return hash( + (self._type, self._submodules, self._static_fields, self._variables) + ) + + def __eq__(self, other: tp.Any) -> bool: + if not isinstance(other, ModuleDef): + return False + return ( + self._type == other._type + and self._submodules == other._submodules + and self._static_fields == other._static_fields + ) + + @property + def type(self) -> tp.Type[M]: + return self._type + + @property + def index(self) -> int: + return self._index + + @property + def submodules( + self, + ) -> tuple[tuple[str, tp.Union['ModuleDef[Module]', int]], ...]: + return self._submodules + + @property + def static_fields(self) -> tuple[tuple[str, tp.Any], ...]: + return self._static_fields + + @property + def variables( + self, + ) -> tuple[tuple[str, variableslib.Variable[variableslib.Empty]], ...]: + return self._variables + + @property + def module_state(self) -> 'ModuleStateTuple': + return self._module_state + + def make_module(self) -> M: + return _build_module(self) + + def merge(self, state: State, *states: State) -> M: + states = (state, *states) + module = self.make_module() + _update_module_dynamic_state(module, states) + return module + + def apply( + self, state: State, *states: State + ) -> ApplyCaller[tuple[State, 'ModuleDef[M]']]: + accessesor = DelayedAccessor() + + def _apply( + accessesor, *args, **kwargs + ) -> tuple[tp.Any, tuple[State, ModuleDef[M]]]: + module = self.merge(state, *states) + fn = accessesor(module) + out = fn(*args, **kwargs) + return out, module.split() + + return CallableProxy(_apply, accessesor) # type: ignore + + +def _moddef_flatten(moduledef: ModuleDef[M]): + return (), ( + moduledef._type, + moduledef._index, + moduledef._submodules, + moduledef._static_fields, + moduledef._variables, + moduledef._module_state, + ) + + +def _moddef_unflatten( + metadata: tuple[ + tp.Type[M], + int, + tuple[tuple[str, tp.Union['ModuleDef[Module]', int]], ...], + tuple[tuple[str, tp.Any], ...], + tuple[tuple[str, variableslib.Variable[variableslib.Empty]], ...], + 'ModuleStateTuple', + ], + _, 
+) -> ModuleDef[M]: + return ModuleDef(*metadata) + + +jtu.register_pytree_node(ModuleDef, _moddef_flatten, _moddef_unflatten) + + +SEEN_MODULES_REPR: tp.Optional[tp.Set[ids.UUID]] = None + +ModuleStateTuple = tuple[()] + + +class ModuleState(reprlib.Representable): + __slots__ = ('_trace_state', '_id') + + def __init__(self): + self._trace_state = tracers.TraceState() + self._id = ids.uuid() + + @property + def trace_state(self) -> tracers.TraceState: + return self._trace_state + + @property + def id(self) -> ids.UUID: + return self._id + + def to_tuple(self) -> ModuleStateTuple: + return () + + @classmethod + def from_tuple(cls, tup: ModuleStateTuple) -> 'ModuleState': + return cls(*tup) + + def __nnx_repr__(self): + yield reprlib.Object(type(self)) + yield reprlib.Attr('trace_state', self._trace_state) + + +class ModuleMeta(ABCMeta): + if not tp.TYPE_CHECKING: + + def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Any: + return self._meta_call(*args, **kwargs) + + def _meta_call(cls: tp.Type[M], *args, **kwargs) -> M: + module = cls.__new__(cls, *args, **kwargs) + vars(module)['_module__state'] = ModuleState() + module.__init__(*args, **kwargs) + + if dataclasses.is_dataclass(module): + if isinstance(module, _HasSetup): + module.setup() + + assert isinstance(module, Module) + + for field in dataclasses.fields(module): + value = vars(module)[field.name] + # set Rngs instances to None + if isinstance(value, Rngs): + vars(module)[field.name] = None + continue + + if 'nnx_variable_constructor' not in field.metadata: + continue + + variable_constructor = field.metadata['nnx_variable_constructor'] + vars(module)[field.name] = variable_constructor(value) + + return module + + +tuple_reduce = lambda xs, x: xs + (x,) +tuple_init = lambda: () + +Updates = tp.Union[ + M, + ModuleDef[M], + tuple[State, ModuleDef[M]], + tuple[tuple[State, ...], ModuleDef[M]], + State, + tuple[State, ...], +] + + +class Module(reprlib.Representable, metaclass=ModuleMeta): + if tp.TYPE_CHECKING: + _module__state: ModuleState + + if not tp.TYPE_CHECKING: + + def __getattribute__(self, name: str) -> tp.Any: + value = object.__getattribute__(self, name) + if isinstance(value, Variable): + return value.get_value() + return value + + def __setattr__(self, name: str, value: tp.Any) -> None: + self._setattr(name, value) + + def _setattr(self, name: str, value: tp.Any) -> None: + if not self._module__state.trace_state.is_valid(): + raise errors.TraceContextError( + 'Cannot mutate Module from different trace level' + ) + + vars_dict = vars(self) + if name in vars_dict and isinstance(vars_dict[name], Variable): + vars_dict[name] = vars_dict[name].set_value(value) + else: + if isinstance(value, Variable): + value = value.copy() + elif isinstance(value, (jax.Array, np.ndarray, State)): + raise ValueError( + f"Trying to assign a '{type(value).__name__}' to the Module" + f" attribute '{name}'. This is not supported. Non-hashable " + 'objects are not valid static state in JAX. Please wrap ' + 'the value in a Variable type instead.'
+ ) + vars_dict[name] = value + + def __deepcopy__(self: M, memo=None) -> M: + state, moduledef = self.split() + moduledef = deepcopy(moduledef) + state = deepcopy(state) + return moduledef.merge(state) + + def __hash__(self) -> int: + return hash(self._module__state.id) + + def __nnx_repr__(self): + global SEEN_MODULES_REPR + + if SEEN_MODULES_REPR is None: + SEEN_MODULES_REPR = set() + clear_seen = True + else: + clear_seen = False + + if self._module__state.id in SEEN_MODULES_REPR: + yield reprlib.Object(type=type(self), empty_repr='...') + return + + yield reprlib.Object(type=type(self)) + SEEN_MODULES_REPR.add(self._module__state.id) + + try: + for name, value in vars(self).items(): + if isinstance(value, Module) or ( + not isinstance(value, Variable) and not name.startswith('_') + ): + yield reprlib.Attr(name, value) + finally: + if clear_seen: + SEEN_MODULES_REPR = None + + @classmethod + def init(cls: type[M], *args, **kwargs) -> tuple[State, ModuleDef[M]]: + return cls(*args, **kwargs).split() + + @classmethod + @property + def create_abstract(cls: type[M]) -> type[M]: + accessesor = DelayedAccessor() + + def lift_rngs(kwargs: dict[str, tp.Any]): + if 'rngs' in kwargs and isinstance(kwargs['rngs'], Rngs): + kwargs['rngs'] = kwargs['rngs'].copy() + return kwargs + + def _create_abstract(accessesor, *args, **kwargs): + constructor = accessesor(cls) + state, moduledef = jax.eval_shape( + lambda: constructor(*args, **lift_rngs(kwargs)).split() + ) + return moduledef.merge(state) + + return CallableProxy(_create_abstract, accessesor) # type: ignore + + def clone(self: M) -> M: + return merge(self.split()) + + @tp.overload + def split(self: M) -> tuple[State, ModuleDef[M]]: + ... + + @tp.overload + def split(self: M, first: filterlib.Filter, /) -> tuple[State, ModuleDef[M]]: + ... + + @tp.overload + def split( + self: M, + first: filterlib.Filter, + second: filterlib.Filter, + /, + *filters: filterlib.Filter, + ) -> tuple[State, tpe.Unpack[tuple[State, ...]], ModuleDef[M]]: + ... + + def split( + self: M, *filters: filterlib.Filter + ) -> tuple[State, tpe.Unpack[tuple[State, ...]], ModuleDef[M]]: + moduledef = self.get_moduledef() + state = self.get_state() + + if len(filters) == 0: + states = (state,) + elif len(filters) == 1: + states = (state.split(filters[0]),) + else: + states = state.split(filters[0], filters[1], *filters[2:]) + + return *states, moduledef + + def get_state(self) -> State: + return State(_iter_state(self)) + + def get_moduledef(self: M) -> ModuleDef[M]: + module_index: tp.Dict[ids.UUID, int] = {} + path: PathParts = () + moduledef = _make_moduledef_recursive(self, module_index, path) + assert isinstance(moduledef, ModuleDef) + return moduledef + + @tp.overload + def extract(self, first: filterlib.Filter, /) -> State: + ... + + @tp.overload + def extract( + self, + first: filterlib.Filter, + second: filterlib.Filter, + /, + *filters: filterlib.Filter, + ) -> tuple[State, ...]: + ... + + def extract( + self, + first: filterlib.Filter, + /, + *filters: filterlib.Filter, + ) -> tp.Union[State, tuple[State, ...]]: + state = self.get_state() + + if len(filters) == 0: + states = state.extract(first) + else: + states = state.extract(first, filters[0], *filters[1:]) + + return states + + @tp.overload + def pop( + self, + filter: filterlib.Filter, + /, + ) -> State: + ... + + @tp.overload + def pop( + self, + filter: filterlib.Filter, + filter2: filterlib.Filter, + /, + *filters: filterlib.Filter, + ) -> tuple[State, ...]: + ... 
+ + def pop( + self, *filters: filterlib.Filter + ) -> tp.Union[State, tuple[State, ...]]: + if len(filters) == 0: + raise ValueError('Expected at least one filter') + + states = _pop(self, filters) + + if len(states) == 1: + return states[0] + else: + return states + + @property + def apply(self: M) -> ApplyCaller[M]: + accessesor = DelayedAccessor() + + def _apply(accessesor, *args, **kwargs) -> tuple[tp.Any, M]: + module = self.clone() + fn = accessesor(module) + out = fn(*args, **kwargs) + return out, module + + return CallableProxy(_apply, accessesor) # type: ignore + + def update(self: M, update: Updates[M], *updates: Updates[M]) -> None: + updates = (update, *updates) + + def _states_and_moduledef( + updates, + ) -> tuple[list[State], tp.Optional[Module]]: + leaves = jax.tree_util.tree_leaves( + updates, is_leaf=lambda x: isinstance(x, (ModuleDef, State)) + ) + states: list[State] = [] + module: tp.Optional[Module] = None + + for leaf in leaves: + if isinstance(leaf, (Module, ModuleDef)): + if module is not None: + raise ValueError( + 'Expected only one ModuleDef or Module in the updates' + ) + + if isinstance(leaf, Module): + module = leaf + states.append(leaf.get_state()) + else: + module = leaf.make_module() + elif isinstance(leaf, State): + states.append(leaf) + else: + raise ValueError( + 'Expected a ModuleDef, Module or State, got' + f' {type(leaf).__name__}' + ) + + return states, module + + states, module_update = _states_and_moduledef(updates) + + if module_update is not None: + _update_module_static_state(self, module_update) + + if states: + _update_module_dynamic_state(self, states) + + def sow( + self, + variable_type: tp.Type[variableslib.Variable[tp.Any]], + name: str, + value: A, + reduce_fn: tp.Callable[[B, A], B] = tuple_reduce, + init_fn: tp.Callable[[], B] = tuple_init, # type: ignore + ) -> None: + if hasattr(self, name): + variable = vars(self)[name] + if not isinstance(variable, variableslib.Variable): + raise ValueError( + f"Expected '{name}' to be a Variable, got {type(variable).__name__}" + ) + elif type(variable) != variable_type: + raise ValueError( + f"Expected '{name}' to be of type '{variable_type.__name__}', " + f"got '{type(variable).__name__}'" + ) + reduced_value = reduce_fn(variable.value, value) + setattr(self, name, reduced_value) + else: + reduced_value = reduce_fn(init_fn(), value) + setattr(self, name, variable_type(reduced_value)) + + def for_each( + self, module_type: tp.Type[M], fn: tp.Callable[[M], None] + ) -> None: + visited: tp.Set[ids.UUID] = set() + self._on_all(module_type, fn, visited) + + def _on_all( + self, + module_type: tp.Type[M], + fn: tp.Callable[[M], None], + visited: tp.Set[ids.UUID], + ) -> None: + if self._module__state.id in visited: + return + + visited.add(self._module__state.id) + + if isinstance(self, module_type): + fn(self) + + for value in vars(self).values(): + if isinstance(value, Module): + value._on_all(module_type, fn, visited) + + def __init_subclass__(cls, experimental_pytree: bool = False) -> None: + super().__init_subclass__() + + if experimental_pytree: + jtu.register_pytree_with_keys( + cls, + partial(_module_flatten, with_keys=True), + _module_unflatten, + flatten_func=partial(_module_flatten, with_keys=False), + ) + + +# Pytree Definition +def _module_flatten(module: Module, *, with_keys: bool): + state, moduledef = module.split() + variables = state.variables + paths = tuple(variables.keys()) + + if with_keys: + children = tuple( + (jtu.DictKey(path), variable) for path, variable in 
variables.items() + ) + else: + children = tuple(variables.values()) + + return children, (paths, moduledef) + + +def _module_unflatten( + paths_moduledef: tuple[tuple[Path, ...], ModuleDef[M]], + variables: tuple[Variable[tp.Any], ...], +) -> M: + paths, moduledef = paths_moduledef + return moduledef.merge(State(zip(paths, variables))) + + +def _make_moduledef_recursive( + module: M, + module_index: tp.Dict[ids.UUID, int], + path: PathParts, +) -> tp.Union[ModuleDef[M], int]: + if module._module__state.id in module_index: + return module_index[module._module__state.id] + + index = len(module_index) + module_index[module._module__state.id] = index + + submodules = [] + static_fields = [] + variables = [] + + for name, value in sorted(vars(module).items(), key=lambda x: x[0]): + value_path = (*path, name) + if isinstance(value, Module): + submodule_def = _make_moduledef_recursive(value, module_index, value_path) + submodules.append((name, submodule_def)) + elif isinstance(value, variableslib.Variable): + variables.append((name, value.as_empty())) + elif not name.startswith('_module__'): + static_fields.append((name, value)) + + module_def = ModuleDef( + type=type(module), + index=index, + submodules=tuple(submodules), + static_fields=tuple(static_fields), + variables=tuple(variables), + module_state=module._module__state.to_tuple(), + ) + return module_def + + +def _iter_state(module: Module) -> tp.Iterator[tuple[Path, tp.Any]]: + seen_modules: tp.Set[ids.UUID] = set() + path_parts: PathParts = () + + yield from _iter_state_recursive(module, seen_modules, path_parts) + + +def _iter_state_recursive( + module: Module, seen_modules: tp.Set[ids.UUID], path_parts: PathParts +) -> tp.Iterator[tuple[Path, tp.Any]]: + if module._module__state.id in seen_modules: + return + + seen_modules.add(module._module__state.id) + + for name, value in sorted(vars(module).items(), key=lambda x: x[0]): + new_path_parts = (*path_parts, name) + if isinstance(value, Module): + yield from _iter_state_recursive(value, seen_modules, new_path_parts) + elif isinstance(value, variableslib.Variable): + if value.is_empty: + # skip empty Variables + continue + path = '/'.join(new_path_parts) + yield path, value + + +def _set_value_at_path( + module: tp.Any, path_parts: tp.Union[PathParts, tp.List[str]], value: tp.Any +): + if len(path_parts) == 1: + setattr(module, path_parts[0], value) + else: + _set_value_at_path(vars(module)[path_parts[0]], path_parts[1:], value) + + +def _get_value_path(module: tp.Any, path: tp.Sequence[str]) -> tp.Any: + if len(path) == 0: + return module + else: + return _get_value_path(vars(module)[path[0]], path[1:]) + + +def _build_module(moduledef: ModuleDef[M]) -> M: + index_module: tp.Dict[int, Module] = {} + module = _build_module_recursive(moduledef, index_module) + return module + + +def _build_module_recursive( + moduledef: tp.Union[ModuleDef[M], int], + index_module: tp.Dict[int, Module], +) -> M: + if isinstance(moduledef, int): + return index_module[moduledef] # type: ignore + + assert moduledef.index not in index_module + + # add a dummy module to the index to avoid infinite recursion + module = object.__new__(moduledef.type) + index_module[moduledef.index] = module + + submodules = { + name: _build_module_recursive(submodule, index_module) + for name, submodule in moduledef.submodules + } + + vars(module).update(moduledef.static_fields) + vars(module).update(moduledef.variables) + vars(module).update(submodules) + vars(module)['_module__state'] = ModuleState.from_tuple( + 
moduledef.module_state + ) + + return module + + +def _pop( + module: Module, + filters: tuple[filterlib.Filter, ...], +) -> tuple[State, ...]: + module_index: tp.Dict[ids.UUID, int] = {} + path_parts: PathParts = () + predicates = tuple(filterlib.to_predicate(filter) for filter in filters) + states = tuple({} for _ in predicates) + _pop_recursive(module, module_index, path_parts, states, predicates) + + return tuple(State(x) for x in states) + + +def _pop_recursive( + module: Module, + module_index: tp.Dict[ids.UUID, int], + path_parts: PathParts, + states: tuple[tp.Dict[Path, tp.Any]], + predicates: tuple[filterlib.Predicate, ...], +) -> None: + if module._module__state.id in module_index: + return + + for name, value in list(vars(module).items()): + if isinstance(value, Module): + _pop_recursive( + value, module_index, (*path_parts, name), states, predicates + ) + continue + elif not isinstance(value, Variable): + continue + elif value.is_empty: + continue + + path = '/'.join((*path_parts, name)) + for state, predicate in zip(states, predicates): + if predicate(path, value): + state[path] = value + # empty Variable attributes + setattr(module, name, value.as_empty()) + break + else: + # NOTE: should we raise an error here? + pass + + module_index[module._module__state.id] = len(module_index) + + +def _update_module_dynamic_state( + module: Module, + updates: tp.Union[State, tp.Sequence[State]], +) -> None: + if isinstance(updates, State): + new_states = (updates,) + else: + new_states = updates + + state: StateDict = {} + for new_state in new_states: + state.update(new_state.variables) + + for path, value in state.items(): + path_parts = path.split('/') + _set_value_at_path(module, path_parts, value) + + +# _StaticSubmoduleState = tp.Literal["new", "updated"] +class _StaticModuleStatus(enum.Enum): + NEW = enum.auto() + UPDATED = enum.auto() + + +def _update_module_static_state(module: M, updates: M) -> None: + cache: dict[Module, _StaticModuleStatus] = {} + _update_module_static_state_recursive( + module, updates, cache, _StaticModuleStatus.UPDATED, () + ) + + +def _update_module_static_state_recursive( + module: M, + updates: M, + cache: dict[Module, _StaticModuleStatus], + status: _StaticModuleStatus, + path: PathParts, +) -> None: + if type(module) != type(updates): + raise ValueError( + f'Expected an instance of {type(module).__name__}, got' + f' {type(updates).__name__}' + ) + + if updates in cache: + if cache[updates] != status: + str_path = '/'.join(path) + if status is _StaticModuleStatus.NEW: + raise ValueError( + f'Trying to add a new submodule at path {str_path!r} but a' + ' submodule with the same reference has been updated' + ) + else: + raise ValueError( + f'Trying to update a submodule at path {str_path!r} but a new' + ' submodule with the same reference has been added' + ) + return + + cache[updates] = status + + module_vars = vars(module) + for name, value in vars(updates).items(): + if isinstance(value, variableslib.Variable): + continue + elif isinstance(value, Module): + if name in module_vars: + _update_module_static_state_recursive( + module_vars[name], + value, + cache, + _StaticModuleStatus.UPDATED, + (*path, name), + ) + else: + if value in cache: + if cache[value] is not _StaticModuleStatus.NEW: + raise ValueError( + f'Trying to add a new submodule at path {name!r} but a' + ' submodule with the same reference has been updated' + ) + else: + cache[value] = _StaticModuleStatus.NEW + + setattr(module, name, value) + else: # static field + setattr(module, name, 
value) + + +def first_from(*args: tp.Optional[A]) -> A: + """Return the first non-None argument.""" + for arg in args: + if arg is not None: + return arg + raise ValueError('No non-None arguments found.') + + +def merge( + state_and_def: tuple[tpe.Unpack[tuple[State, ...]], ModuleDef[M]] +) -> M: + *states, moduledef = state_and_def + return moduledef.merge(*states) diff --git a/flax/experimental/nnx/nnx/nn/__init__.py b/flax/experimental/nnx/nnx/nn/__init__.py new file mode 100644 index 0000000000..e80ba0b35f --- /dev/null +++ b/flax/experimental/nnx/nnx/nn/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/flax/experimental/nnx/nnx/nn/activations.py b/flax/experimental/nnx/nnx/nn/activations.py new file mode 100644 index 0000000000..55fee03d5a --- /dev/null +++ b/flax/experimental/nnx/nnx/nn/activations.py @@ -0,0 +1,69 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from jax.nn import ( + celu, + elu, + gelu, + glu, + hard_sigmoid, + hard_silu, + hard_swish, + hard_tanh, + leaky_relu, + log_sigmoid, + log_softmax, + logsumexp, + normalize, + one_hot, + relu, + relu6, + selu, + sigmoid, + silu, + soft_sign, + softmax, + softplus, + standardize, + swish, +) +from jax.numpy import tanh + +__all__ = [ + 'celu', + 'elu', + 'gelu', + 'glu', + 'hard_sigmoid', + 'hard_silu', + 'hard_swish', + 'hard_tanh', + 'leaky_relu', + 'log_sigmoid', + 'log_softmax', + 'logsumexp', + 'normalize', + 'one_hot', + 'relu', + 'relu6', + 'selu', + 'sigmoid', + 'silu', + 'soft_sign', + 'softmax', + 'softplus', + 'standardize', + 'swish', + 'tanh', +] diff --git a/flax/experimental/nnx/nnx/nn/dtypes.py b/flax/experimental/nnx/nnx/nn/dtypes.py new file mode 100644 index 0000000000..0d9beab834 --- /dev/null +++ b/flax/experimental/nnx/nnx/nn/dtypes.py @@ -0,0 +1,80 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
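To make the `split` / `apply` / `merge` round trip above concrete, a small sketch (hypothetical one-parameter module, in the style of the examples directory):

import jax.numpy as jnp
from flax.experimental import nnx

class Bias(nnx.Module):
  def __init__(self, dout: int):
    self.b = nnx.Param(jnp.zeros((dout,)))

  def __call__(self, x):
    return x + self.b

state, moduledef = Bias(3).split()            # State + static ModuleDef
y, (state, _) = moduledef.apply(state)(jnp.ones((1, 3)))
module = moduledef.merge(state)               # back to a stateful Module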
+ +from typing import Any, List, Optional + +from jax import numpy as jnp + +Dtype = Any +Array = Any + + +def canonicalize_dtype( + *args, dtype: Optional[Dtype] = None, inexact: bool = True +) -> Dtype: + """Canonicalize an optional dtype to the definitive dtype. + + If ``dtype`` is None, this function infers the dtype from the input + arguments using ``jnp.result_type``. If it is not None, it is returned + unmodified, or an exception is raised if the dtype is invalid. + + Args: + *args: JAX array compatible values. None values + are ignored. + dtype: Optional dtype override. If specified the arguments are cast to + the specified dtype instead and dtype inference is disabled. + inexact: When True, the output dtype must be a subdtype + of `jnp.inexact`. Inexact dtypes are real or complex floating points. This + is useful when you want to apply operations that don't work directly on + integers like taking a mean for example. + Returns: + The dtype that *args should be cast to. + """ + if dtype is None: + args_filtered = [jnp.asarray(x) for x in args if x is not None] + dtype = jnp.result_type(*args_filtered) + if inexact and not jnp.issubdtype(dtype, jnp.inexact): + dtype = jnp.promote_types(jnp.float32, dtype) + if inexact and not jnp.issubdtype(dtype, jnp.inexact): + raise ValueError(f'Dtype must be inexact: {dtype}') + return dtype + + +def promote_dtype(*args, dtype=None, inexact=True) -> List[Array]: + """Promotes input arguments to a specified or inferred dtype. + + All args are cast to the same dtype. See ``canonicalize_dtype`` for how + this dtype is determined. + + The behavior of promote_dtype is mostly a convenience wrapper around + ``jax.numpy.promote_types``, the difference being that it automatically casts + all input to the inferred dtype, allows inference to be overridden by a + forced dtype, and has an optional check to guarantee the resulting dtype is + inexact. + + Args: + *args: JAX array compatible values. None values + are returned as is. + dtype: Optional dtype override. If specified the arguments are cast to + the specified dtype instead and dtype inference is disabled. + inexact: When True, the output dtype must be a subdtype + of `jnp.inexact`. Inexact dtypes are real or complex floating points. This + is useful when you want to apply operations that don't work directly on + integers like taking a mean for example. + Returns: + The arguments cast to arrays of the same dtype. + """ + dtype = canonicalize_dtype(*args, dtype=dtype, inexact=inexact) + return [jnp.asarray(x, dtype) if x is not None else None for x in args] diff --git a/flax/experimental/nnx/nnx/nn/initializers.py b/flax/experimental/nnx/nnx/nn/initializers.py new file mode 100644 index 0000000000..0e989c80ca --- /dev/null +++ b/flax/experimental/nnx/nnx/nn/initializers.py @@ -0,0 +1,73 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
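A quick sketch of the promotion rules implemented above (outcomes follow ``jnp.result_type``):

import jax.numpy as jnp
from flax.experimental.nnx.nnx.nn import dtypes

x = jnp.arange(4)                             # int32
w = jnp.ones((4,), jnp.float32)
x, w, b = dtypes.promote_dtype(x, w, None)    # None passes through untouched
assert x.dtype == w.dtype == jnp.float32 and b is None

(y,) = dtypes.promote_dtype(jnp.arange(4), dtype=jnp.bfloat16)  # forced dtype
assert y.dtype == jnp.bfloat16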
+ +import typing as tp + +import jax +import jax.numpy as jnp +from jax.nn.initializers import constant as constant +from jax.nn.initializers import delta_orthogonal as delta_orthogonal +from jax.nn.initializers import glorot_normal as glorot_normal +from jax.nn.initializers import glorot_uniform as glorot_uniform +from jax.nn.initializers import he_normal as he_normal +from jax.nn.initializers import he_uniform as he_uniform +from jax.nn.initializers import kaiming_normal as kaiming_normal +from jax.nn.initializers import kaiming_uniform as kaiming_uniform +from jax.nn.initializers import lecun_normal as lecun_normal +from jax.nn.initializers import lecun_uniform as lecun_uniform +from jax.nn.initializers import normal as normal +from jax.nn.initializers import orthogonal as orthogonal +from jax.nn.initializers import uniform as uniform +from jax.nn.initializers import variance_scaling as variance_scaling +from jax.nn.initializers import xavier_normal as xavier_normal +from jax.nn.initializers import xavier_uniform as xavier_uniform + +Shape = tp.Sequence[int] +DTypeLikeInexact = tp.Any +Array = jax.Array + + +class Initializer(tp.Protocol): + @staticmethod + def __call__( + key: Array, shape: Shape, dtype: DTypeLikeInexact = jnp.float_ + ) -> Array: + ... + + +def zeros() -> Initializer: + """Builds an initializer that returns a constant array full of zeros. + + >>> import jax, jax.numpy as jnp + >>> from flax.experimental.nnx.nnx.nn import initializers + >>> zeros_initializer = initializers.zeros() + >>> zeros_initializer(jax.random.key(42), (2, 3), jnp.float32) + Array([[0., 0., 0.], + [0., 0., 0.]], dtype=float32) + """ + return jax.nn.initializers.zeros + + +def ones() -> Initializer: + """Builds an initializer that returns a constant array full of ones. + + >>> import jax, jax.numpy as jnp + >>> from flax.experimental.nnx.nnx.nn import initializers + >>> ones_initializer = initializers.ones() + >>> ones_initializer(jax.random.key(42), (3, 2), jnp.float32) + Array([[1., 1.], + [1., 1.], + [1., 1.]], dtype=float32) + """ + return jax.nn.initializers.ones diff --git a/flax/experimental/nnx/nnx/nn/linear.py b/flax/experimental/nnx/nnx/nn/linear.py new file mode 100644 index 0000000000..1bedc53605 --- /dev/null +++ b/flax/experimental/nnx/nnx/nn/linear.py @@ -0,0 +1,445 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing as tp + +import jax +import jax.numpy as jnp +import numpy as np +from jax import lax + +from flax.experimental import nnx +from flax.experimental.nnx.nnx import rnglib +from flax.experimental.nnx.nnx.module import Module +from flax.experimental.nnx.nnx.nn import dtypes, initializers + +Array = jax.Array +KeyArray = jax.Array +Shape = tp.Tuple[int, ...] +Dtype = tp.Any # this could be a real type?
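All of the initializers re-exported above follow the same `Initializer` calling convention; for example (shapes arbitrary):

import jax, jax.numpy as jnp
from flax.experimental.nnx.nnx.nn import initializers

init = initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
kernel = init(jax.random.PRNGKey(0), (8, 4), jnp.float32)
assert kernel.shape == (8, 4) and kernel.dtype == jnp.float32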
+PrecisionLike = tp.Union[ + None, + str, + lax.Precision, + tp.Tuple[str, str], + tp.Tuple[lax.Precision, lax.Precision], +] +ConvGeneralDilatedT = tp.Callable[..., Array] +PaddingLike = tp.Union[str, int, tp.Sequence[tp.Union[int, tp.Tuple[int, int]]]] +LaxPadding = tp.Union[str, tp.Sequence[tp.Tuple[int, int]]] +DotGeneralT = tp.Callable[..., Array] + + +default_kernel_init = initializers.lecun_normal() + + +def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding: + """Canonicalizes conv padding to a jax.lax supported format.""" + if isinstance(padding, str): + return padding + if isinstance(padding, int): + return [(padding, padding)] * rank + if isinstance(padding, tp.Sequence) and len(padding) == rank: + new_pad = [] + for p in padding: + if isinstance(p, int): + new_pad.append((p, p)) + elif isinstance(p, tuple) and len(p) == 2: + new_pad.append(p) + else: + break + if len(new_pad) == rank: + return new_pad + raise ValueError( + f'Invalid padding format: {padding}, should be str, int,' + f' or a sequence of len {rank} where each element is an' + ' int or pair of ints.' + ) + + +def _conv_dimension_numbers(input_shape): + """Computes the dimension numbers based on the input shape.""" + ndim = len(input_shape) + lhs_spec = (0, ndim - 1) + tuple(range(1, ndim - 1)) + rhs_spec = (ndim - 1, ndim - 2) + tuple(range(0, ndim - 2)) + out_spec = lhs_spec + return lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec) + + +class Linear(Module): + """A linear transformation applied over the last dimension of the input. + + Attributes: + in_features: the number of input features. + out_features: the number of output features. + use_bias: whether to add a bias to the output (default: True). + dtype: the dtype of the computation (default: infer from input and params). + param_dtype: the dtype passed to parameter initializers (default: float32). + precision: numerical precision of the computation, see `jax.lax.Precision` + for details. + kernel_init: initializer function for the weight matrix. + bias_init: initializer function for the bias. + """ + + def __init__( + self, + in_features: int, + out_features: int, + *, + use_bias: bool = True, + dtype: tp.Optional[Dtype] = None, + param_dtype: Dtype = jnp.float32, + precision: PrecisionLike = None, + kernel_init: tp.Callable[ + [KeyArray, Shape, Dtype], Array + ] = default_kernel_init, + bias_init: tp.Callable[ + [KeyArray, Shape, Dtype], Array + ] = initializers.zeros(), + dot_general: DotGeneralT = lax.dot_general, + rngs: rnglib.Rngs, + ): + kernel_key = rngs.params() + self.kernel = nnx.Param( + kernel_init(kernel_key, (in_features, out_features), param_dtype) + ) + if use_bias: + bias_key = rngs.params() + self.bias = nnx.Param(bias_init(bias_key, (out_features,), param_dtype)) + else: + self.bias = nnx.Param(None) + + self.in_features = in_features + self.out_features = out_features + self.use_bias = use_bias + self.dtype = dtype + self.param_dtype = param_dtype + self.precision = precision + self.kernel_init = kernel_init + self.bias_init = bias_init + self.dot_general = dot_general + + def __call__(self, inputs: Array) -> Array: + """Applies a linear transformation to the inputs along the last dimension. + + Args: + inputs: The nd-array to be transformed. + + Returns: + The transformed input.
+ """ + kernel = self.kernel + bias = self.bias + + inputs, kernel, bias = dtypes.promote_dtype( + inputs, kernel, bias, dtype=self.dtype + ) + y = self.dot_general( + inputs, + kernel, + (((inputs.ndim - 1,), (0,)), ((), ())), + precision=self.precision, + ) + if bias is not None: + y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) + return y + + +class Conv(Module): + """Convolution Module wrapping `lax.conv_general_dilated[_local]`. + + Attributes: + features: number of convolution filters. + kernel_size: shape of the convolutional kernel. For 1D convolution, + the kernel size can be passed as an integer. For all other cases, it must + be a sequence of integers. + strides: an integer or a sequence of `n` integers, representing the + inter-window strides (default: 1). + padding: either the string `'SAME'`, the string `'VALID'`, the string + `'CIRCULAR'` (periodic boundary conditions), or a sequence of `n` `(low, + high)` integer pairs that give the padding to apply before and after each + spatial dimension. A single int is interpeted as applying the same padding + in all dims and passign a single int in a sequence causes the same padding + to be used on both sides. `'CAUSAL'` padding for a 1D convolution will + left-pad the convolution axis, resulting in same-sized output. + input_dilation: an integer or a sequence of `n` integers, giving the + dilation factor to apply in each spatial dimension of `inputs` + (default: 1). Convolution with input dilation `d` is equivalent to + transposed convolution with stride `d`. + kernel_dilation: an integer or a sequence of `n` integers, giving the + dilation factor to apply in each spatial dimension of the convolution + kernel (default: 1). Convolution with kernel dilation + is also known as 'atrous convolution'. + feature_group_count: integer, default 1. If specified divides the input + features into groups. + use_bias: whether to add a bias to the output (default: True). + mask: Optional mask for the weights during masked convolution. The mask must + be the same shape as the convolution weight matrix. + dtype: the dtype of the computation (default: infer from input and params). + param_dtype: the dtype passed to parameter initializers (default: float32). + precision: numerical precision of the computation see `jax.lax.Precision` + for details. + kernel_init: initializer for the convolutional kernel. + bias_init: initializer for the bias. + """ + + def __init__( + self, + in_features: int, + out_features: int, + kernel_size: tp.Sequence[int], + strides: tp.Union[None, int, tp.Sequence[int]] = 1, + *, + padding: PaddingLike = 'SAME', + input_dilation: tp.Union[None, int, tp.Sequence[int]] = 1, + kernel_dilation: tp.Union[None, int, tp.Sequence[int]] = 1, + feature_group_count: int = 1, + use_bias: bool = True, + mask_fn: tp.Optional[tp.Callable[[Array], Array]] = None, + dtype: tp.Optional[Dtype] = None, + param_dtype: Dtype = jnp.float32, + precision: PrecisionLike = None, + kernel_init: tp.Callable[ + [KeyArray, Shape, Dtype], Array + ] = default_kernel_init, + bias_init: tp.Callable[ + [KeyArray, Shape, Dtype], Array + ] = initializers.zeros(), + conv_general_dilated: ConvGeneralDilatedT = lax.conv_general_dilated, + rngs: rnglib.Rngs, + ): + if isinstance(kernel_size, int): + raise TypeError( + 'Expected Conv kernel_size to be a' + ' tuple/list of integers (eg.: [3, 3]) but got' + f' {kernel_size}.' 
+      )
+    else:
+      kernel_size = tuple(kernel_size)
+
+    kernel_shape = kernel_size + (
+      in_features // feature_group_count,
+      out_features,
+    )
+    kernel_key = rngs.params()
+    self.kernel = nnx.Param(kernel_init(kernel_key, kernel_shape, param_dtype))
+
+    if use_bias:
+      bias_shape = (out_features,)
+      bias_key = rngs.params()
+      self.bias = nnx.Param(bias_init(bias_key, bias_shape, param_dtype))
+    else:
+      self.bias = nnx.Param(None)
+
+    self.in_features = in_features
+    self.out_features = out_features
+    self.kernel_size = kernel_size
+    self.strides = strides
+    self.padding = padding
+    self.input_dilation = input_dilation
+    self.kernel_dilation = kernel_dilation
+    self.feature_group_count = feature_group_count
+    self.use_bias = use_bias
+    self.mask_fn = mask_fn
+    self.dtype = dtype
+    self.param_dtype = param_dtype
+    self.precision = precision
+    self.kernel_init = kernel_init
+    self.bias_init = bias_init
+    self.conv_general_dilated = conv_general_dilated
+
+  def __call__(self, inputs: Array) -> Array:
+    """Applies a convolution to the inputs.
+
+    Args:
+      inputs: input data with dimensions (*batch_dims, spatial_dims...,
+        features). This is the channels-last convention, i.e. NHWC for a 2d
+        convolution and NDHWC for a 3D convolution. Note: this is different
+        from the input convention used by `lax.conv_general_dilated`, which
+        puts the spatial dimensions last.
+        Note: If the input has more than 1 batch dimension, all batch
+        dimensions are flattened into a single dimension for the convolution
+        and restored before returning. In some cases directly vmap'ing the
+        layer may yield better performance than this default flattening
+        approach. If the input lacks a batch dimension it will be added for
+        the convolution and removed on return, an allowance made to enable
+        writing single-example code.
+
+    Returns:
+      The convolved data.
+    """
+
+    assert isinstance(self.kernel_size, tuple)
+    kernel_size = self.kernel_size
+
+    def maybe_broadcast(
+      x: tp.Optional[tp.Union[int, tp.Sequence[int]]]
+    ) -> tp.Tuple[int, ...]:
+      if x is None:
+        # backward compatibility with using None as sentinel for
+        # broadcast 1
+        x = 1
+      if isinstance(x, int):
+        return (x,) * len(kernel_size)
+      return tuple(x)
+
+    # Combine all input batch dimensions into a single leading batch axis.
+    num_batch_dimensions = inputs.ndim - (len(kernel_size) + 1)
+    if num_batch_dimensions != 1:
+      input_batch_shape = inputs.shape[:num_batch_dimensions]
+      total_batch_size = int(np.prod(input_batch_shape))
+      flat_input_shape = (total_batch_size,) + inputs.shape[
+        num_batch_dimensions:
+      ]
+      inputs = jnp.reshape(inputs, flat_input_shape)
+
+    strides = maybe_broadcast(self.strides)
+    input_dilation = maybe_broadcast(self.input_dilation)
+    kernel_dilation = maybe_broadcast(self.kernel_dilation)
+
+    padding_lax = canonicalize_padding(self.padding, len(kernel_size))
+    if padding_lax == 'CIRCULAR':
+      kernel_size_dilated = [
+        (k - 1) * d + 1 for k, d in zip(kernel_size, kernel_dilation)
+      ]
+      zero_pad: tp.List[tp.Tuple[int, int]] = [(0, 0)]
+      pads = (
+        zero_pad
+        + [((k - 1) // 2, k // 2) for k in kernel_size_dilated]
+        + [(0, 0)]
+      )
+      inputs = jnp.pad(inputs, pads, mode='wrap')
+      padding_lax = 'VALID'
+    elif padding_lax == 'CAUSAL':
+      if len(kernel_size) != 1:
+        raise ValueError(
+          'Causal padding is only implemented for 1D convolutions.'
+        )
+      left_pad = kernel_dilation[0] * (kernel_size[0] - 1)
+      pads = [(0, 0), (left_pad, 0), (0, 0)]
+      inputs = jnp.pad(inputs, pads)
+      padding_lax = 'VALID'
+
+    dimension_numbers = _conv_dimension_numbers(inputs.shape)
+
+    # One shared convolutional kernel for all pixels in the output.
+    assert self.in_features % self.feature_group_count == 0
+
+    kernel = self.kernel
+
+    if self.mask_fn is not None:
+      kernel = self.mask_fn(kernel)
+
+    bias = self.bias
+
+    inputs, kernel, bias = dtypes.promote_dtype(
+      inputs, kernel, bias, dtype=self.dtype
+    )
+
+    y = self.conv_general_dilated(
+      inputs,
+      kernel,
+      strides,
+      padding_lax,
+      lhs_dilation=input_dilation,
+      rhs_dilation=kernel_dilation,
+      dimension_numbers=dimension_numbers,
+      feature_group_count=self.feature_group_count,
+      precision=self.precision,
+    )
+
+    if self.use_bias:
+      bias = bias.reshape((1,) * (y.ndim - bias.ndim) + bias.shape)
+      y += bias
+
+    if num_batch_dimensions != 1:
+      output_shape = input_batch_shape + y.shape[1:]
+      y = jnp.reshape(y, output_shape)
+    return y
+
+
+default_embed_init = initializers.variance_scaling(
+  1.0, 'fan_in', 'normal', out_axis=0
+)
+
+
+class Embed(Module):
+  """Embedding Module.
+
+  A parameterized function from integers [0, n) to d-dimensional vectors.
+
+  Attributes:
+    num_embeddings: number of embeddings.
+    features: number of feature dimensions for each embedding.
+    dtype: the dtype of the embedding vectors (default: same as embedding).
+    param_dtype: the dtype passed to parameter initializers (default: float32).
+    embedding_init: embedding initializer.
+  """
+
+  def __init__(
+    self,
+    num_embeddings: int,
+    features: int,
+    *,
+    dtype: tp.Optional[Dtype] = None,
+    param_dtype: Dtype = jnp.float32,
+    embedding_init: tp.Callable[
+      [KeyArray, Shape, Dtype], Array
+    ] = default_embed_init,
+    rngs: rnglib.Rngs,
+  ):
+    self.embedding = nnx.Param(
+      embedding_init(rngs.params(), (num_embeddings, features), param_dtype)
+    )
+
+    self.num_embeddings = num_embeddings
+    self.features = features
+    self.dtype = dtype or self.embedding.dtype
+    self.param_dtype = param_dtype
+    self.embedding_init = embedding_init
+
+  def __call__(self, inputs: Array) -> Array:
+    """Embeds the inputs along the last dimension.
+
+    Args:
+      inputs: input data, all dimensions are considered batch dimensions.
+
+    Returns:
+      Output which is embedded input data. The output shape follows the input,
+      with an additional `features` dimension appended.
+    """
+    if not jnp.issubdtype(inputs.dtype, jnp.integer):
+      raise ValueError('Input type must be an integer or unsigned integer.')
+    # Use take because fancy indexing numpy arrays with JAX indices does not
+    # work correctly.
+    (embedding,) = dtypes.promote_dtype(
+      self.embedding, dtype=self.dtype, inexact=False
+    )
+    return jnp.take(embedding, inputs, axis=0)
+
+  def attend(self, query: Array) -> Array:
+    """Attend over the embedding using a query array.
+
+    Args:
+      query: array with last dimension equal to the feature depth `features`
+        of the embedding.
+
+    Returns:
+      An array with final dim `num_embeddings` corresponding to the batched
+      inner-product of the array of query vectors against each embedding.
+      Commonly used for weight-sharing between embeddings and logit transform
+      in NLP models.
+ """ + query, embedding = dtypes.promote_dtype( + query, self.embedding, dtype=self.dtype + ) + return jnp.dot(query, embedding.T) diff --git a/flax/experimental/nnx/nnx/nn/normalization.py b/flax/experimental/nnx/nnx/nn/normalization.py new file mode 100644 index 0000000000..1df89eeb9f --- /dev/null +++ b/flax/experimental/nnx/nnx/nn/normalization.py @@ -0,0 +1,401 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing as tp + +import jax +import jax.numpy as jnp +from jax import lax + +from flax.experimental import nnx +from flax.experimental.nnx.nnx import flaglib, rnglib +from flax.experimental.nnx.nnx.module import Module, first_from +from flax.experimental.nnx.nnx.nn import dtypes, initializers + +KeyArray = jax.Array +Array = jax.Array +Shape = tp.Tuple[int, ...] +Dtype = tp.Any # this could be a real type? + +Axes = tp.Union[int, tp.Any] + + +def _canonicalize_axes(rank: int, axes: Axes) -> tp.Tuple[int, ...]: + """Returns a tuple of deduplicated, sorted, and positive axes.""" + if not isinstance(axes, tp.Iterable): + axes = (axes,) + return tuple(set([rank + axis if axis < 0 else axis for axis in axes])) + + +def _abs_sq(x): + """Computes the elementwise square of the absolute value |x|^2.""" + if jnp.iscomplexobj(x): + return lax.square(lax.real(x)) + lax.square(lax.imag(x)) + else: + return lax.square(x) + + +def _compute_stats( + x: Array, + axes: tp.Optional[Axes], + dtype: tp.Optional[Dtype], + axis_name: tp.Optional[str] = None, + axis_index_groups: tp.Any = None, + use_mean: bool = True, +): + """Computes mean and variance statistics. + + This implementation takes care of a few important details: + - Computes in float32 precision for stability in half precision training. + - mean and variance are computable in a single XLA fusion, + by using Var = E[|x|^2] - |E[x]|^2 instead of Var = E[|x - E[x]|^2]). + - Clips negative variances to zero which can happen due to + roundoff errors. This avoids downstream NaNs. + - Supports averaging across a parallel axis and subgroups of a parallel axis + with a single `lax.pmean` call to avoid latency. + + Arguments: + x: Input array. + axes: The axes in ``x`` to compute mean and variance statistics for. + dtype: tp.Optional dtype specifying the minimal precision. Statistics + are always at least float32 for stability (default: dtype of x). + axis_name: tp.Optional name for the pmapped axis to compute mean over. + axis_index_groups: tp.Optional axis indices. + use_mean: If true, calculate the mean from the input and use it when + computing the variance. If false, set the mean to zero and compute + the variance without subtracting the mean. + + Returns: + A pair ``(mean, var)``. 
+ """ + if dtype is None: + dtype = jnp.result_type(x) + # promote x to at least float32, this avoids half precision computation + # but preserves double or complex floating points + dtype = jnp.promote_types(dtype, jnp.float32) + x = jnp.asarray(x, dtype) + + mean2 = jnp.mean(_abs_sq(x), axes) + if use_mean: + mean = jnp.mean(x, axes) + else: + mean = jnp.zeros(mean2.shape, dtype=dtype) + + if axis_name is not None: + concatenated_mean = jnp.concatenate([mean, mean2]) + mean, mean2 = jnp.split( + lax.pmean( + concatenated_mean, + axis_name=axis_name, + axis_index_groups=axis_index_groups, + ), + 2, + ) + # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due + # to floating point round-off errors. + var = jnp.maximum(0.0, mean2 - _abs_sq(mean)) + return mean, var + + +def _normalize( + x: Array, + mean: Array, + var: Array, + scale: tp.Optional[Array], + bias: tp.Optional[Array], + reduction_axes: Axes, + feature_axes: Axes, + dtype: Dtype, + epsilon: float, +): + """ "Normalizes the input of a normalization layer and optionally applies a learned scale and bias. + + Arguments: + x: The input. + mean: Mean to use for normalization. + var: Variance to use for normalization. + reduction_axes: The axes in ``x`` to reduce. + feature_axes: Axes containing features. A separate bias and scale is learned + for each specified feature. + dtype: The dtype of the result (default: infer from input and params). + epsilon: Normalization epsilon. + + Returns: + The normalized input. + """ + reduction_axes = _canonicalize_axes(x.ndim, reduction_axes) + feature_axes = _canonicalize_axes(x.ndim, feature_axes) + stats_shape = list(x.shape) + for axis in reduction_axes: + stats_shape[axis] = 1 + mean = mean.reshape(stats_shape) + var = var.reshape(stats_shape) + feature_shape = [1] * x.ndim + reduced_feature_shape = [] + for ax in feature_axes: + feature_shape[ax] = x.shape[ax] + reduced_feature_shape.append(x.shape[ax]) + y = x - mean + mul = lax.rsqrt(var + epsilon) + args = [x] + if scale is not None: + scale = scale.reshape(feature_shape) + mul *= scale + args.append(scale) + y *= mul + if bias is not None: + bias = bias.reshape(feature_shape) + y += bias + args.append(bias) + dtype = dtypes.canonicalize_dtype(*args, dtype=dtype) + return jnp.asarray(y, dtype) + + +class BatchNorm(Module): + """BatchNorm Module. + + Attributes: + use_running_average: if True, the statistics stored in batch_stats + will be used instead of computing the batch statistics on the input. + axis: the feature or non-batch axis of the input. + momentum: decay rate for the exponential moving average of + the batch statistics. + epsilon: a small float added to variance to avoid dividing by zero. + dtype: the dtype of the result (default: infer from input and params). + param_dtype: the dtype passed to parameter initializers (default: float32). + use_bias: if True, bias (beta) is added. + use_scale: if True, multiply by scale (gamma). + When the next layer is linear (also e.g. nn.relu), this can be disabled + since the scaling will be done by the next layer. + bias_init: initializer for bias, by default, zero. + scale_init: initializer for scale, by default, one. + axis_name: the axis name used to combine batch statistics from multiple + devices. See `jax.pmap` for a description of axis names (default: None). + axis_index_groups: groups of axis indices within that named axis + representing subsets of devices to reduce over (default: None). 
For + example, `[[0, 1], [2, 3]]` would independently batch-normalize over + the examples on the first two and last two devices. See `jax.lax.psum` + for more details. + """ + + def __init__( + self, + num_features: int, + *, + use_running_average: tp.Optional[bool] = None, + axis: int = -1, + momentum: float = 0.99, + epsilon: float = 1e-5, + dtype: tp.Optional[Dtype] = None, + param_dtype: Dtype = jnp.float32, + use_bias: bool = True, + use_scale: bool = True, + bias_init: tp.Callable[ + [KeyArray, Shape, Dtype], Array + ] = initializers.zeros(), + scale_init: tp.Callable[ + [KeyArray, Shape, Dtype], Array + ] = initializers.ones(), + axis_name: tp.Optional[str] = None, + axis_index_groups: tp.Any = None, + rngs: rnglib.Rngs, + ): + feature_shape = (num_features,) + self.mean = nnx.BatchStat(jnp.zeros(feature_shape, jnp.float32)) + self.var = nnx.BatchStat(jnp.ones(feature_shape, jnp.float32)) + + if use_scale: + key = rngs.params() + self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype)) + else: + self.scale = nnx.Param(None) + + if use_bias: + key = rngs.params() + self.bias = nnx.Param(bias_init(key, feature_shape, param_dtype)) + else: + self.bias = nnx.Param(None) + + self.num_features = num_features + self.use_running_average = use_running_average + self.axis = axis + self.momentum = momentum + self.epsilon = epsilon + self.dtype = dtype + self.param_dtype = param_dtype + self.use_bias = use_bias + self.use_scale = use_scale + self.bias_init = bias_init + self.scale_init = scale_init + self.axis_name = axis_name + self.axis_index_groups = axis_index_groups + + def __call__( + self, + x, + use_running_average: tp.Optional[bool] = None, + ): + """Normalizes the input using batch statistics. + + Args: + x: the input to be normalized. + use_running_average: if true, the statistics stored in batch_stats + will be used instead of computing the batch statistics on the input. + + Returns: + Normalized inputs (the same shape as inputs). + """ + + use_running_average = first_from( + use_running_average, + self.use_running_average, + flaglib.flags.get('use_running_average'), + ) + feature_axes = _canonicalize_axes(x.ndim, self.axis) + reduction_axes = tuple(i for i in range(x.ndim) if i not in feature_axes) + + if use_running_average: + mean, var = self.mean, self.var + else: + mean, var = _compute_stats( + x, + reduction_axes, + dtype=self.dtype, + axis_name=self.axis_name, + axis_index_groups=self.axis_index_groups, + ) + + self.mean = self.momentum * self.mean + (1 - self.momentum) * mean + self.var = self.momentum * self.var + (1 - self.momentum) * var + + return _normalize( + x, + mean, + var, + self.scale, + self.bias, + reduction_axes, + feature_axes, + self.dtype, + self.epsilon, + ) + + +class LayerNorm(Module): + """Layer normalization (https://arxiv.org/abs/1607.06450). + + LayerNorm normalizes the activations of the layer for each given example in a + batch independently, rather than across a batch like Batch Normalization. + i.e. applies a transformation that maintains the mean activation within + each example close to 0 and the activation standard deviation close to 1. + + Attributes: + epsilon: A small float added to variance to avoid dividing by zero. + dtype: the dtype of the result (default: infer from input and params). + param_dtype: the dtype passed to parameter initializers (default: float32). + use_bias: If True, bias (beta) is added. + use_scale: If True, multiply by scale (gamma). When the next layer is linear + (also e.g. 
nn.relu), this can be disabled since the scaling will be done + by the next layer. + bias_init: Initializer for bias, by default, zero. + scale_init: Initializer for scale, by default, one. + reduction_axes: Axes for computing normalization statistics. + feature_axes: Feature axes for learned bias and scaling. + axis_name: the axis name used to combine batch statistics from multiple + devices. See `jax.pmap` for a description of axis names (default: None). + This is only needed if the model is subdivided across devices, i.e. the + array being normalized is sharded across devices within a pmap. + axis_index_groups: groups of axis indices within that named axis + representing subsets of devices to reduce over (default: None). For + example, `[[0, 1], [2, 3]]` would independently batch-normalize over + the examples on the first two and last two devices. See `jax.lax.psum` + for more details. + """ + + def __init__( + self, + num_features: int, + *, + epsilon: float = 1e-6, + dtype: tp.Optional[Dtype] = None, + param_dtype: Dtype = jnp.float32, + use_bias: bool = True, + use_scale: bool = True, + bias_init: tp.Callable[ + [KeyArray, Shape, Dtype], Array + ] = initializers.zeros(), + scale_init: tp.Callable[ + [KeyArray, Shape, Dtype], Array + ] = initializers.ones(), + reduction_axes: Axes = -1, + feature_axes: Axes = -1, + axis_name: tp.Optional[str] = None, + axis_index_groups: tp.Any = None, + rngs: rnglib.Rngs, + ): + feature_shape = (num_features,) + + if use_scale: + key = rngs.params() + self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype)) + else: + self.scale = nnx.Param(None) + + if use_bias: + key = rngs.params() + self.bias = nnx.Param(bias_init(key, feature_shape, param_dtype)) + else: + self.bias = nnx.Param(None) + + self.num_features = num_features + self.epsilon = epsilon + self.dtype = dtype + self.param_dtype = param_dtype + self.use_bias = use_bias + self.use_scale = use_scale + self.bias_init = bias_init + self.scale_init = scale_init + self.reduction_axes = reduction_axes + self.feature_axes = feature_axes + self.axis_name = axis_name + self.axis_index_groups = axis_index_groups + + def __call__(self, x): + """Applies layer normalization on the input. + + Args: + x: the inputs + + Returns: + Normalized inputs (the same shape as inputs). + """ + mean, var = _compute_stats( + x, + self.reduction_axes, + self.dtype, + self.axis_name, + self.axis_index_groups, + ) + + return _normalize( + x, + mean, + var, + self.scale, + self.bias, + self.reduction_axes, + self.feature_axes, + self.dtype, + self.epsilon, + ) diff --git a/flax/experimental/nnx/nnx/nn/stochastic.py b/flax/experimental/nnx/nnx/nn/stochastic.py new file mode 100644 index 0000000000..50c02a8d50 --- /dev/null +++ b/flax/experimental/nnx/nnx/nn/stochastic.py @@ -0,0 +1,86 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
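+
+# Illustrative usage sketch (assumes an `nnx.Rngs` carrying a 'dropout'
+# stream):
+#
+#   dropout = Dropout(rate=0.5)
+#   y = dropout(x, deterministic=False, rngs=nnx.Rngs(dropout=0))
+#
+# With `deterministic=True` (or `rate == 0.0`) the input is returned
+# unchanged and no rngs are needed.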
+
+from typing import Optional, Sequence
+
+import jax.numpy as jnp
+from jax import lax, random
+
+from flax.experimental.nnx.nnx import dataclasses as nnx_dataclasses
+from flax.experimental.nnx.nnx import flaglib, rnglib
+from flax.experimental.nnx.nnx.module import Module, first_from
+
+
+@nnx_dataclasses.dataclass
+class Dropout(Module):
+  """Create a dropout layer.
+
+  Attributes:
+    rate: the dropout probability. (_not_ the keep rate!)
+    broadcast_dims: dimensions that will share the same dropout mask.
+    deterministic: if false the inputs are scaled by `1 / (1 - rate)` and
+      masked, whereas if true, no mask is applied and the inputs are returned
+      as is.
+    rng_collection: the rng collection name to use when requesting an rng key.
+  """
+
+  rate: float
+  broadcast_dims: Sequence[int] = ()
+  deterministic: Optional[bool] = None
+  rng_collection: str = 'dropout'
+
+  def __call__(
+    self,
+    inputs,
+    *,
+    deterministic: Optional[bool] = None,
+    rngs: Optional[rnglib.Rngs] = None,
+  ):
+    """Applies a random dropout mask to the input.
+
+    Args:
+      inputs: the inputs that should be randomly masked.
+      deterministic: if false the inputs are scaled by `1 / (1 - rate)` and
+        masked, whereas if true, no mask is applied and the inputs are
+        returned as is.
+      rngs: optional `Rngs` container; the mask key is drawn from its
+        `rng_collection` stream when a mask is needed.
+
+    Returns:
+      The masked inputs reweighted to preserve mean.
+    """
+    deterministic = first_from(
+      deterministic,
+      self.deterministic,
+      flaglib.flags.get('deterministic'),
+    )
+
+    if (self.rate == 0.0) or deterministic:
+      return inputs
+
+    # Prevent gradient NaNs in 1.0 edge-case.
+    if self.rate == 1.0:
+      return jnp.zeros_like(inputs)
+
+    if rngs is None:
+      raise ValueError(
+        "Dropout needs to generate a random mask but no 'rngs' were provided."
+      )
+
+    keep_prob = 1.0 - self.rate
+    rng = rngs[self.rng_collection]()
+    broadcast_shape = list(inputs.shape)
+    for dim in self.broadcast_dims:
+      broadcast_shape[dim] = 1
+    mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
+    mask = jnp.broadcast_to(mask, inputs.shape)
+    return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))
diff --git a/flax/experimental/nnx/nnx/pytreelib.py b/flax/experimental/nnx/nnx/pytreelib.py
new file mode 100644
index 0000000000..ee179a0bb0
--- /dev/null
+++ b/flax/experimental/nnx/nnx/pytreelib.py
@@ -0,0 +1,291 @@
+# Copyright 2023 The Flax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
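+
+# Illustrative sketch of the classes below: Pytree subclasses register with
+# jax.tree_util, so Variable-wrapped fields become traced leaves while plain
+# fields travel as static aux data:
+#
+#   class Counter(Pytree, mutable=True):
+#     def __init__(self):
+#       self.count = TreeNode(0)  # leaf
+#       self.name = 'counter'     # static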
+
+import contextlib
+import dataclasses
+import importlib.util
+import inspect
+import typing as tp
+from abc import ABCMeta
+from copy import copy
+from functools import partial
+from types import MappingProxyType
+
+import jax
+import numpy as np
+
+from flax.experimental.nnx.nnx import module as modulelib
+from flax.experimental.nnx.nnx import reprlib, variables
+from flax.experimental.nnx.nnx.state import State
+
+A = tp.TypeVar('A')
+P = tp.TypeVar('P', bound='Pytree')
+
+
+class TreeNode(variables.Variable[A]):
+  pass
+
+
+@contextlib.contextmanager
+def _mutable(obj: P) -> tp.Iterator[None]:
+  vars(obj)['_pytree__is_mutable'] = True
+  try:
+    yield
+  finally:
+    del vars(obj)['_pytree__is_mutable']
+
+
+@contextlib.contextmanager
+def _initializing(obj: P) -> tp.Iterator[None]:
+  vars(obj)['_pytree__initializing'] = True
+  try:
+    yield
+  finally:
+    del vars(obj)['_pytree__initializing']
+
+
+class PytreeMeta(ABCMeta):
+  if not tp.TYPE_CHECKING:
+
+    def __call__(cls: tp.Type[P], *args: tp.Any, **kwargs: tp.Any) -> P:
+      return cls.call(*args, **kwargs)
+
+  def call(cls: tp.Type[P], *args: tp.Any, **kwargs: tp.Any) -> P:
+    obj: P = cls.__new__(cls, *args, **kwargs)
+    vars(obj)['_pytree__sorted_fields'] = ['_pytree__sorted_fields']
+
+    with _mutable(obj), _initializing(obj):
+      obj.__init__(*args, **kwargs)
+
+    if dataclasses.is_dataclass(obj):
+      assert isinstance(obj, Pytree)
+      for field in dataclasses.fields(obj):
+        if 'nnx_variable_constructor' not in field.metadata:
+          continue
+
+        container_fn = field.metadata['nnx_variable_constructor']
+        value = vars(obj)[field.name]
+        value = container_fn(value)
+        vars(obj)[field.name] = value
+
+    vars(obj)['_pytree__sorted_fields'] = sorted(vars(obj))
+
+    return obj
+
+
+class Pytree(reprlib.Representable, metaclass=PytreeMeta):
+  _pytree__is_mutable: bool
+  _pytree__class_is_mutable: bool
+  _pytree__sorted_fields: tp.Tuple[str, ...]
+
+  if not tp.TYPE_CHECKING:
+
+    def __getattribute__(self, name: str) -> tp.Any:
+      value = object.__getattribute__(self, name)
+      if isinstance(value, variables.Variable):
+        return value.value
+      return value
+
+  def __setattr__(self, name: str, value: tp.Any) -> None:
+    self._setattr(name, value)
+
+  def _setattr(self: P, name: str, value: tp.Any):
+    vars_dict = vars(self)
+    if '_pytree__initializing' in vars_dict:
+      pass
+    elif name not in vars_dict:
+      raise AttributeError(r'Cannot add new fields to an initialized Pytree')
+    elif (
+      '_pytree__is_mutable' not in vars_dict
+      and not self._pytree__class_is_mutable
+    ):
+      raise AttributeError(
+        f'{type(self)} is immutable, trying to update field {name}'
+      )
+
+    if name in vars_dict and isinstance(vars_dict[name], variables.Variable):
+      vars_dict[name] = vars_dict[name].replace(value=value)
+    else:
+      if isinstance(value, variables.Variable):
+        value = value.copy()
+      elif isinstance(value, (jax.Array, np.ndarray, State)):
+        raise ValueError(
+          f"Trying to assign a '{type(value).__name__}' to the Pytree"
+          f" attribute '{name}'. This is not supported. Non-hashable "
+          'objects are not valid static state in JAX. Please wrap '
+          'the value in a Variable type instead.'
+ ) + vars_dict[name] = value + + def __init_subclass__(cls, mutable: bool = False): + super().__init_subclass__() + # init class variables + cls._pytree__is_mutable = False + cls._pytree__class_is_mutable = mutable + + # TODO: clean up this in the future once minimal supported version is 0.4.7 + if hasattr(jax.tree_util, 'register_pytree_with_keys'): + if ( + 'flatten_func' + in inspect.signature(jax.tree_util.register_pytree_with_keys).parameters + ): + jax.tree_util.register_pytree_with_keys( + cls, + partial( + cls._pytree__flatten, + with_key_paths=True, + ), + cls._pytree__unflatten, + flatten_func=partial( + cls._pytree__flatten, + with_key_paths=False, + ), + ) + else: + jax.tree_util.register_pytree_with_keys( + cls, + partial( + cls._pytree__flatten, + with_key_paths=True, + ), + cls._pytree__unflatten, + ) + else: + jax.tree_util.register_pytree_node( + cls, + partial( + cls._pytree__flatten, + with_key_paths=False, + ), + cls._pytree__unflatten, + ) + + # flax serialization support + if importlib.util.find_spec('flax') is not None: + from flax import serialization + + serialization.register_serialization_state( + cls, cls._to_flax_state_dict, cls._from_flax_state_dict + ) + + @classmethod + def _pytree__flatten( + cls, + pytree: 'Pytree', + *, + with_key_paths: bool, + ): + all_vars = vars(pytree) + static = {} + node_values = [] + node_names = [] + + for field in pytree._pytree__sorted_fields: + value = all_vars[field] + + if isinstance(value, (modulelib.Module, variables.Variable, Pytree)): + node_names.append(field) + if with_key_paths: + node_values.append((jax.tree_util.GetAttrKey(field), value)) + else: + node_values.append(value) + else: + static[field] = value + + return node_values, (tuple(node_names), MappingProxyType(static)) + + @classmethod + def _pytree__unflatten( + cls: tp.Type[P], + metadata: tp.Tuple[tp.Tuple[str, ...], tp.Mapping[str, tp.Any]], + node_values: tp.Tuple[tp.Any, ...], + ) -> P: + node_names, static_fields = metadata + pytree = object.__new__(cls) + pytree.__dict__.update(zip(node_names, node_values)) + pytree.__dict__.update(static_fields) + return pytree + + @classmethod + def _to_flax_state_dict(cls, pytree: 'Pytree') -> tp.Dict[str, tp.Any]: + from flax import serialization + + state_dict = { + name: serialization.to_state_dict(getattr(pytree, name)) + for name, value in vars(pytree).items() + if isinstance(value, (modulelib.Module, variables.Variable, Pytree)) + } + return state_dict + + @classmethod + def _from_flax_state_dict( + cls, + pytree: P, + state: tp.Dict[str, tp.Any], + ) -> P: + """Restore the state of a data class.""" + from flax import serialization + + state = state.copy() # copy the state so we can pop the restored fields. 
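+    # Restore each Module/Variable/Pytree field from its entry in `state`;
+    # anything left over afterwards names an unknown field and raises below.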
+ updates = {} + for name, value in vars(pytree).items(): + if not isinstance(value, (modulelib.Module, variables.Variable, Pytree)): + continue + if name not in state: + raise ValueError( + f'Missing field {name} in state dict while restoring' + f' an instance of {type(pytree).__name__},' + f' at path {serialization.current_path()}' + ) + value_state = state.pop(name) + updates[name] = serialization.from_state_dict( + value, value_state, name=name + ) + if state: + names = ','.join(state.keys()) + raise ValueError( + f'Unknown field(s) "{names}" in state dict while' + f' restoring an instance of {type(pytree).__name__}' + f' at path {serialization.current_path()}' + ) + return pytree.replace(**updates) + + def replace(self: P, **kwargs: tp.Any) -> P: + """ + Replace the values of the fields of the object with the values of the + keyword arguments. If the object is a dataclass, `dataclasses.replace` + will be used. Otherwise, a new object will be created with the same + type as the original object. + """ + if dataclasses.is_dataclass(self): + return dataclasses.replace(self, **kwargs) + + unknown_keys = set(kwargs) - set(vars(self)) + if unknown_keys and not self._pytree__class_is_mutable: + raise ValueError( + f'Trying to replace unknown fields {unknown_keys} ' + f"for '{type(self).__name__}'" + ) + + pytree = copy(self) + with _mutable(pytree): + for key, value in kwargs.items(): + setattr(pytree, key, value) + + return pytree + + def __nnx_repr__(self): + yield reprlib.Object(type(self)) + for name, value in vars(self).items(): + yield reprlib.Attr(name, repr(value)) diff --git a/flax/experimental/nnx/nnx/reprlib.py b/flax/experimental/nnx/nnx/reprlib.py new file mode 100644 index 0000000000..4fac09dd16 --- /dev/null +++ b/flax/experimental/nnx/nnx/reprlib.py @@ -0,0 +1,108 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
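+
+# Illustrative sketch of the protocol defined below: a Representable yields
+# a single Object header followed by zero or more Attr items, e.g.
+#
+#   def __nnx_repr__(self):
+#     yield Object(type(self))
+#     yield Attr('features', 3)
+#
+# get_repr() then renders this as "MyType(\n  features=3\n)".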
+
+import contextlib
+import dataclasses
+import threading
+import typing as tp
+from abc import ABC, abstractmethod
+
+
+@dataclasses.dataclass
+class ReprContext(threading.local):
+  indent_stack: tp.List[str] = dataclasses.field(default_factory=lambda: [''])
+
+
+REPR_CONTEXT = ReprContext()
+
+
+@dataclasses.dataclass
+class Object:
+  type: tp.Union[str, type]
+  start: str = '('
+  end: str = ')'
+  value_sep: str = '='
+  elem_indent: str = '  '
+  empty_repr: str = ''
+
+
+@dataclasses.dataclass
+class Attr:
+  key: str
+  value: tp.Union[str, tp.Any]
+  start: str = ''
+  end: str = ''
+
+
+class Representable(ABC):
+  __slots__ = ()
+
+  @abstractmethod
+  def __nnx_repr__(self) -> tp.Iterator[tp.Union[Object, Attr]]:
+    raise NotImplementedError
+
+  def __repr__(self) -> str:
+    return get_repr(self)
+
+
+@contextlib.contextmanager
+def add_indent(indent: str) -> tp.Iterator[None]:
+  REPR_CONTEXT.indent_stack.append(REPR_CONTEXT.indent_stack[-1] + indent)
+
+  try:
+    yield
+  finally:
+    REPR_CONTEXT.indent_stack.pop()
+
+
+def get_indent() -> str:
+  return REPR_CONTEXT.indent_stack[-1]
+
+
+def get_repr(obj: Representable) -> str:
+  if not isinstance(obj, Representable):
+    raise TypeError(f'Object {obj!r} is not representable')
+
+  iterator = obj.__nnx_repr__()
+  config = next(iterator)
+  if not isinstance(config, Object):
+    raise TypeError(f'First item must be Object, got {type(config).__name__}')
+
+  def _repr_elem(elem: tp.Any) -> str:
+    if not isinstance(elem, Attr):
+      raise TypeError(f'Item must be Attr, got {type(elem).__name__}')
+
+    value = elem.value if isinstance(elem.value, str) else repr(elem.value)
+
+    if '\n' in value and not isinstance(elem.value, Representable):
+      value = value.replace('\n', '\n' + get_indent())
+
+    return (
+      f'{get_indent()}{elem.start}{elem.key}{config.value_sep}{value}{elem.end}'
+    )
+
+  with add_indent(config.elem_indent):
+    elems = list(map(_repr_elem, iterator))
+    elems = ',\n'.join(elems)
+
+  if elems:
+    elems = '\n' + elems + '\n' + get_indent()
+  else:
+    elems = config.empty_repr
+
+  type_repr = (
+    config.type if isinstance(config.type, str) else config.type.__name__
+  )
+
+  return f'{type_repr}{config.start}{elems}{config.end}'
diff --git a/flax/experimental/nnx/nnx/rnglib.py b/flax/experimental/nnx/nnx/rnglib.py
new file mode 100644
index 0000000000..e12ca1a60c
--- /dev/null
+++ b/flax/experimental/nnx/nnx/rnglib.py
@@ -0,0 +1,227 @@
+# Copyright 2023 The Flax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
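+
+# Design note for the code below: each RngStream keeps a static path of call
+# counts next to its key; make_rng() folds a stable hash of that count path
+# into the key, so identical (key, counts) pairs always reproduce the same
+# rng while successive calls yield fresh ones.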
+from __future__ import annotations + +import dataclasses +import hashlib +import typing as tp + +import jax +import numpy as np + +from flax.experimental.nnx.nnx import errors, filterlib, tracers + +Counts = list[int] +AxesValue = tp.Union[int, None] +Pattern = tp.Union[AxesValue, tuple[AxesValue, ...]] + + +class Missing: + pass + + +MISSING = Missing() + + +def _stable_hash(data: tp.Sequence[tp.Hashable]) -> int: + hash_str = ' '.join(str(x) for x in data) + _hash = hashlib.blake2s(hash_str.encode()) + hash_bytes = _hash.digest() + # uint32 is represented as 4 bytes in big endian + return int.from_bytes(hash_bytes[:4], byteorder='big') + + +@dataclasses.dataclass +class RngStream: + key: jax.Array # dynamic + counts: list[int] # static + + def make_rng(self) -> jax.Array: + fold_data = _stable_hash(self.counts) + self.counts[-1] += 1 + return jax.random.fold_in(self.key, fold_data) + + def fork(self, pattern: Pattern) -> 'RngStream': + if pattern is None: + # broadcast key + key = self.key + count_path = [*self.counts, 0] + self.counts[-1] += 1 + else: + key = self.make_rng() + # split key + if isinstance(pattern, int): + key = jax.random.split(key, pattern) + else: + num_splits = int(np.prod([x for x in pattern if x is not None])) + axis_size = tuple(x if x is not None else 1 for x in pattern) + # reshape key + key = jax.random.split(key, num_splits).reshape(*axis_size) + count_path = [0] + return RngStream(key, count_path) + + def copy(self) -> 'RngStream': + return RngStream(self.key, self.counts.copy()) + + +jax.tree_util.register_pytree_node( + RngStream, + lambda rng: ((rng.key,), tuple(rng.counts)), + lambda counts, nodes: RngStream(nodes[0], list(counts)), +) + +RngValue = tp.Union[int, jax.Array, RngStream] +RngDict = tp.Union[ + dict[str, int], + dict[str, jax.Array], + dict[str, RngStream], + dict[str, RngValue], +] + + +class Rngs(tp.Mapping[str, tp.Callable[[], jax.Array]]): + __slots__ = ('_trace_state', '_rngs', '_counts') + + def __init__( + self, + default: RngValue | RngDict | None = None, + **rngs: RngValue, + ): + if default is not None: + if isinstance(default, dict): + rngs = {**default, **rngs} + else: + rngs['default'] = default + + self._rngs = { + name: ( + RngStream(jax.random.key(value), [0]) + if isinstance(value, int) + else RngStream(value, [0]) + if isinstance(value, jax.Array) + else value.copy() + ) + for name, value in rngs.items() + } + self._trace_state = tracers.TraceState() + + def _make_rng(self, name: str) -> jax.Array: + if not self.is_valid(): + raise errors.TraceContextError( + 'Cannot use Rngs from a different trace level' + ) + if name not in self._rngs: + if 'default' not in self._rngs: + raise ValueError(f"No RNG named {name!r} or 'default' found in Rngs.") + stream = self._rngs['default'] + else: + stream = self._rngs[name] + + return stream.make_rng() + + def __getitem__(self, name: str) -> tp.Callable[[], jax.Array]: + return lambda: self._make_rng(name) + + __getattr__ = __getitem__ + + def __call__(self): + return self.default() + + def __iter__(self) -> tp.Iterator[str]: + return iter(self._rngs) + + def __len__(self) -> int: + return len(self._rngs) + + def __contains__(self, name: tp.Any) -> bool: + return name in self._rngs + + def copy(self) -> 'Rngs': + return Rngs(**self._rngs) + + def replace(self, **kwargs: tp.Union[int, jax.Array, RngStream]) -> 'Rngs': + rngs: dict[str, tp.Any] = self._rngs.copy() + rngs.update(kwargs) + return Rngs(**rngs) + + def is_valid(self) -> bool: + return self._trace_state.is_valid() + + 
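+  # Illustrative usage sketch:
+  #
+  #   rngs = Rngs(0, dropout=1)
+  #   k1 = rngs.params()   # no 'params' stream, falls back to 'default'
+  #   k2 = rngs.dropout()  # drawn from the 'dropout' stream
+  #
+  # Each call advances the stream's count, so successive keys differ.
+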
@tp.overload
+  def fork(self) -> dict[str, RngStream]:
+    ...
+
+  @tp.overload
+  def fork(self, __default: Pattern) -> dict[str, RngStream]:
+    ...
+
+  @tp.overload
+  def fork(
+    self,
+    __default: Pattern | dict[filterlib.Filter, Pattern] | Missing = MISSING,
+    **patterns: Pattern,
+  ) -> tuple[dict[str, RngStream], dict[str, RngStream]]:
+    ...
+
+  def fork(
+    self,
+    _default: Pattern | dict[filterlib.Filter, Pattern] | Missing = MISSING,
+    **patterns: Pattern,
+  ) -> dict[str, RngStream] | tuple[dict[str, RngStream], dict[str, RngStream]]:
+    if not self.is_valid():
+      raise errors.TraceContextError(
+        'Cannot use Rngs from a different trace level'
+      )
+
+    filter_patterns: list[tuple[filterlib.Filter, Pattern]]
+    if isinstance(_default, dict):
+      # merge default and patterns
+      filter_patterns = [
+        *_default.items(),
+        *patterns.items(),
+        (..., None),  # broadcast all remaining
+      ]
+    else:
+      default = None if isinstance(_default, Missing) else _default
+      filter_patterns = [
+        *patterns.items(),
+        (..., default),  # split all remaining with default
+      ]
+
+    predicate_pattern = [
+      (filterlib.to_predicate(filter_), pattern)
+      for filter_, pattern in filter_patterns
+    ]
+
+    splits: dict[str, RngStream] = {}
+    broadcasts: dict[str, RngStream] = {}
+
+    for name, stream in self._rngs.items():
+      for predicate, pattern in predicate_pattern:
+        if predicate(name, stream):
+          fork = stream.fork(pattern)
+          if pattern is None:
+            broadcasts[name] = fork
+          else:
+            splits[name] = fork
+          break
+      else:
+        raise RuntimeError(
+          f'Stream {name!r} did not match any predicate, this is a bug.'
+        )
+
+    if isinstance(_default, dict) or patterns:
+      return splits, broadcasts
+    else:
+      return {**splits, **broadcasts}
diff --git a/flax/experimental/nnx/nnx/spmd.py b/flax/experimental/nnx/nnx/spmd.py
new file mode 100644
index 0000000000..a885aaeec9
--- /dev/null
+++ b/flax/experimental/nnx/nnx/spmd.py
@@ -0,0 +1,223 @@
+# Copyright 2023 The Flax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+import typing as tp
+
+import jax
+from jax.experimental import maps
+from jax.sharding import Mesh, PartitionSpec
+
+from flax.experimental.nnx.nnx import variables
+from flax.experimental.nnx.nnx.pytreelib import TreeNode
+from flax.experimental.nnx.nnx.state import State
+
+# Real types and dummy aliases for documentation
+Array = tp.Any  # pylint: disable=invalid-name
+ArrayPytree = tp.Any  # pylint: disable=invalid-name
+PartitionSpecPytree = tp.Any  # pylint: disable=invalid-name
+
+A = tp.TypeVar('A')
+F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any])
+PARTITION_NAME = 'partition_name'
+Sharding = tuple[tp.Optional[str], ...]
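+
+# Example (illustrative): sharding=('data', None) on a 2D kernel maps the
+# first axis to the 'data' mesh axis and leaves the second unconstrained
+# (replicated); see `get_partition_spec` below.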
+ + +@tp.runtime_checkable +class HasSharding(tp.Protocol): + sharding: tp.Optional[Sharding] + + +def add_axis( + state: State, index: int, params: tp.Mapping[tp.Any, tp.Any] +) -> State: + axis_name = _get_partition_name(params) + + def _add_axis(x: tp.Any): + if isinstance(x, variables.Variable): + if isinstance(x, HasSharding) and x.sharding is not None: + sharding = list(x.sharding) + while len(sharding) < index: + sharding.append(None) + sharding.insert(index, axis_name) + x = x.replace(sharding=tuple(sharding)) + + x = x.add_axis(axis_name, index) + return x + + return jax.tree_map( + _add_axis, state, is_leaf=lambda x: isinstance(x, variables.Variable) + ) + + +def remove_axis( + state: State, index: int, params: tp.Mapping[tp.Any, tp.Any] +) -> State: + axis_name = _get_partition_name(params) + + def _remove_axis(x: tp.Any): + if isinstance(x, variables.Variable): + if isinstance(x, HasSharding) and x.sharding is not None: + sharding = list(x.sharding) + assert sharding.pop(index) == axis_name + x = x.replace(sharding=tuple(sharding)) + x = x.remove_axis(axis_name, index) + return x + + return jax.tree_map( + _remove_axis, state, is_leaf=lambda x: isinstance(x, variables.Variable) + ) + + +def _get_partition_name(params: tp.Mapping[tp.Any, tp.Any]) -> str: + if PARTITION_NAME not in params: + raise ValueError( + 'Trying to transform a Partitioned variable but "partition_name" ' + f'is not specified in scan_metadata: {params}' + ) + return params[PARTITION_NAME] + + +def get_partition_spec(tree: A) -> A: + """Extracts a PartitionSpec tree from a PyTree containing ``Variable`` values.""" + + def _maybe_replicate(x): + if hasattr(x, 'shape'): + return PartitionSpec() + else: + return None + + def f(x): + if isinstance(x, variables.Variable): + if isinstance(x, HasSharding) and x.sharding: + return x.replace(value=PartitionSpec(*x.sharding)) + else: + return x.replace(value=_maybe_replicate(x.value)) + + return _maybe_replicate(x) + + return jax.tree_map( + f, + tree, + is_leaf=lambda x: isinstance(x, variables.Variable) + and not isinstance(x, TreeNode), + ) + + +# Dynamic Axis Mapping Rngs +# ------------------------------------------------------------------------------ + + +def _global_mesh_defined() -> bool: + """Checks if global xmap/pjit mesh resource environment is defined.""" + maps_env = maps.thread_resources.env + return maps_env.physical_mesh.devices.shape != () # pylint: disable=g-explicit-bool-comparison + + +def _with_sharding_constraint( + x: Array, + axis_resources: tp.Optional[jax.sharding.PartitionSpec], + mesh: tp.Optional[jax.sharding.Mesh] = None, +): + # if jax.devices()[0].platform == "cpu" or ( + if not _global_mesh_defined() and mesh is None: + return x + else: + if mesh is not None and axis_resources is not None: + sharding = jax.sharding.NamedSharding(mesh, axis_resources) + return jax.lax.with_sharding_constraint(x, sharding) + return jax.lax.with_sharding_constraint(x, axis_resources) + + +def _is_spec(x): + return x is None or ( + isinstance(x, tuple) and all(isinstance(e, str) or e is None for e in x) + ) + + +def with_sharding_constraint( + x: ArrayPytree, + axis_resources: PartitionSpecPytree, + mesh: tp.Optional[jax.sharding.Mesh] = None, +): + # If no axis binding is set, this is a no-op. + if axis_resources is None: + return x + # Translate logical names to mesh assignments. 
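+  # `axis_resources` is passed as the first tree so its PartitionSpec leaves
+  # (detected by `_is_spec`) are zipped against the corresponding subtrees
+  # of `x`.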
+ return jax.tree_util.tree_map( + functools.partial(_with_sharding_constraint, mesh=mesh), + axis_resources, + x, + is_leaf=_is_spec, + ) + + +# ------------------------------------- +# Partitioning Axis Metadata +# ------------------------------------- + + +@tp.runtime_checkable +class Partitioned(tp.Protocol): + get_value_hooks: tp.Callable[[variables.Variable[tp.Any]], tp.Any] + sharding: Sharding + mesh: tp.Optional[Mesh] + + +def sharding_hook( + node: variables.Variable[tp.Any], + value: tp.Any, + /, +) -> tp.Any: + if _global_mesh_defined() or ( + isinstance(node, Partitioned) and node.mesh is not None + ): + return with_sharding_constraint( + value, + get_partition_spec(node), + mesh=node.mesh, + ) + return value + + +def with_partitioning( + initializer: F, + sharding: Sharding, + mesh: tp.Optional[jax.sharding.Mesh] = None, + get_value_hooks: tp.Union[ + variables.GetValueHook[A], tp.Sequence[variables.GetValueHook[A]] + ] = (), + create_value_hooks: tp.Union[ + variables.CreateValueHook[A], tp.Sequence[variables.CreateValueHook[A]] + ] = (), + **metadata: tp.Any, +) -> F: + if callable(get_value_hooks): + get_value_hooks = (get_value_hooks, sharding_hook) + else: + get_value_hooks = (*get_value_hooks, sharding_hook) + + if callable(create_value_hooks): + create_value_hooks = (create_value_hooks, sharding_hook) + else: + create_value_hooks = (*create_value_hooks, sharding_hook) + + return variables.with_metadata( + initializer, + get_value_hooks=get_value_hooks, + create_value_hooks=create_value_hooks, + sharding=sharding, + mesh=mesh, + **metadata, + ) diff --git a/flax/experimental/nnx/nnx/state.py b/flax/experimental/nnx/nnx/state.py new file mode 100644 index 0000000000..7c62927217 --- /dev/null +++ b/flax/experimental/nnx/nnx/state.py @@ -0,0 +1,226 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
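+
+# Illustrative sketch: a State is a flat mapping from string paths to
+# Variable leaves, e.g.
+#
+#   State({'linear/kernel': Param(...), 'linear/bias': Param(...)})
+#
+# Indexing returns the unwrapped leaf value, while assignment stores the new
+# value back into the existing Variable.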
+from __future__ import annotations + +import typing as tp + +import jax +import jax.tree_util as jtu + +from flax.experimental.nnx.nnx import filterlib, reprlib +from flax.experimental.nnx.nnx.variables import Variable + +A = tp.TypeVar('A') + +Leaf = tp.Any +Path = str +StateDict = tp.Dict[Path, tp.Any] +StateMapping = tp.Mapping[Path, tp.Any] + + +class State(tp.MutableMapping[Path, Leaf], reprlib.Representable): + __slots__ = ('_mapping',) + + def __init__( + self, + __input: tp.Union[ + tp.Mapping[Path, Variable[Leaf]], + tp.Iterator[tp.Tuple[Path, Variable[Leaf]]], + ], + /, + ): + self._mapping = dict(__input) + + @property + def variables(self) -> dict[Path, Variable[Leaf]]: + return self._mapping + + def __getitem__(self, __key: Path) -> Leaf: + return self._mapping[__key].value + + def __setitem__(self, __key: Path, __value: Leaf) -> None: + self._mapping[__key] = self._mapping[__key].replace(value=__value) + + def __delitem__(self, __key: Path) -> None: + del self._mapping[__key] + + def __iter__(self) -> tp.Iterator[Path]: + return iter(self._mapping) + + def __len__(self) -> int: + return len(self._mapping) + + def __nnx_repr__(self): + yield reprlib.Object(type(self), value_sep=': ', start='({', end='})') + + for k, v in self.items(): + yield reprlib.Attr(repr(k), v) + + @tp.overload + def split(self, first: filterlib.Filter, /) -> 'State': + ... + + @tp.overload + def split( + self, + first: filterlib.Filter, + second: filterlib.Filter, + /, + *filters: filterlib.Filter, + ) -> tp.Tuple['State', ...]: + ... + + def split( + self, first: filterlib.Filter, /, *filters: filterlib.Filter + ) -> tp.Union['State', tp.Tuple['State', ...]]: + filters = (first, *filters) + *states, rest = _split_state(self, *filters) + + if rest: + raise ValueError( + 'Non-exhaustive filters, got a non-empty remainder: ' + f'{list(rest.keys())}.\nUse `...` to match all remaining elements.' + ) + + if len(states) == 1: + states = State(states[0]) + else: + states = tuple(State(state) for state in states) + return states + + @tp.overload + def extract( + self, + first: filterlib.Filter, + /, + ) -> 'State': + ... + + @tp.overload + def extract( + self, + first: filterlib.Filter, + second: filterlib.Filter, + /, + *filters: filterlib.Filter, + ) -> tp.Tuple['State', ...]: + ... 
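+
+  # Note: unlike `split` above, `extract` (below) silently discards entries
+  # that match no filter instead of raising on a non-empty remainder.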
+ + def extract( + self, + first: filterlib.Filter, + /, + *filters: filterlib.Filter, + ) -> tp.Union['State', tp.Tuple['State', ...]]: + *states, _rest = _split_state(self, first, *filters) + + assert len(states) == len(filters) + 1 + + if len(states) == 1: + states = State(states[0]) + else: + states = tuple(State(state) for state in states) + + return states + + @staticmethod + def merge(state: 'State', /, *states: 'State') -> 'State': + states = (state, *states) + + if len(states) == 1: + return states[0] + + new_state: StateDict = {} + + for state in states: + new_state.update(state.variables) + + return State(new_state) + + def __or__(self, other: 'State') -> 'State': + if not other: + return self + return State.merge(self, other) + + def __sub__(self, other: 'State') -> 'State': + if not other: + return self + + # create new State via __new__ to avoid __init__ sorting + _mapping = {k: v for k, v in self._mapping.items() if k not in other} + state = object.__new__(State) + state._mapping = _mapping + return state + + +def _state_flatten_with_keys( + x: State, +): + items = sorted(x._mapping.items(), key=lambda item: item[0]) + children = tuple((jtu.DictKey(key), value) for key, value in items) + return children, tuple(x._mapping.keys()) + + +def _state_unflatten( + static: tp.Tuple[Path, ...] | None, + leaves: tp.Tuple[Leaf, ...] | tuple[dict[str, Leaf]], +): + return State(zip(static, leaves)) if static else State(leaves[0]) + + +def _state_flatten(x: State): + return (x._mapping,), None + + +jax.tree_util.register_pytree_with_keys( + State, + _state_flatten_with_keys, + _state_unflatten, + flatten_func=_state_flatten, +) + + +def _split_state( + state: StateMapping, + *filters: filterlib.Filter, +) -> tp.Tuple[StateDict, ...]: + for i, filter_ in enumerate(filters): + if filter_ is ... and i != len(filters) - 1: + raise ValueError( + 'Ellipsis `...` can only be used as the last filter, ' + f'got it at index {i}.' + ) + predicates = tuple(map(filterlib.to_predicate, filters)) + + # we have n + 1 states, where n is the number of predicates + # the last state is for values that don't match any predicate + states: tp.Tuple[StateDict, ...] = tuple( + {} for _ in range(len(predicates) + 1) + ) + + if isinstance(state, State): + items = state._mapping.items() + else: + items = state.items() + + for path, value in items: + for i, predicate in enumerate(predicates): + if predicate(path, value): + states[i][path] = value + break + else: + # if we didn't break, set leaf to last state + states[-1][path] = value + + return states diff --git a/flax/experimental/nnx/nnx/tracers.py b/flax/experimental/nnx/nnx/tracers.py new file mode 100644 index 0000000000..158bee22e8 --- /dev/null +++ b/flax/experimental/nnx/nnx/tracers.py @@ -0,0 +1,113 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
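+
+# Summary of the machinery below: TraceState captures the ambient jax and nnx
+# traces when it is created; is_valid() later reports whether we are still at
+# those same trace levels, guarding stateful updates against leaking across
+# jax transform boundaries.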
+
+# Taken from flax/core/tracer.py 🏴‍☠️
+
+import contextlib
+import dataclasses
+import threading
+import typing as tp
+
+import jax
+import jax.core
+from jax.core import MainTrace
+
+from flax.experimental.nnx.nnx import reprlib
+
+
+@tp.runtime_checkable
+class Tracer(tp.Protocol):
+  _trace: jax.core.Trace
+
+
+def get_top_trace(pytree: tp.Union[tp.Any, Tracer]) -> MainTrace:
+  """Returns the main top trace of a sequence of tracers."""
+  if isinstance(pytree, Tracer):
+    return pytree._trace.main
+
+  return jax.core.find_top_trace(jax.tree_util.tree_leaves(pytree)).main
+
+
+def current_jax_trace() -> MainTrace:
+  """Returns the innermost Jax trace."""
+  return get_top_trace(())
+
+
+def get_all_traces(pytree: tp.Union[tp.Any, Tracer]) -> tp.Set[MainTrace]:
+  """Returns the set of main traces of all tracers in the pytree."""
+  if isinstance(pytree, Tracer):
+    return {pytree._trace.main}
+  else:
+    return {
+      trace._trace.main
+      for trace in jax.tree_util.tree_leaves(pytree)
+      if isinstance(trace, Tracer)
+    }
+
+
+def trace_level(main):
+  """Returns the level of the trace, or -infinity if it is None."""
+  if main:
+    return main.level
+  return float('-inf')
+
+
+@dataclasses.dataclass
+class TraceContext(threading.local):
+  nnx_trace_stack: tp.List[MainTrace] = dataclasses.field(
+    default_factory=lambda: [current_jax_trace()]
+  )
+
+
+TRACE_CONTEXT = TraceContext()
+
+
+@contextlib.contextmanager
+def nnx_trace(trace: MainTrace):
+  TRACE_CONTEXT.nnx_trace_stack.append(trace)
+  try:
+    yield
+  finally:
+    TRACE_CONTEXT.nnx_trace_stack.pop()
+
+
+def current_nnx_trace() -> MainTrace:
+  return TRACE_CONTEXT.nnx_trace_stack[-1]
+
+
+class TraceState(reprlib.Representable):
+  __slots__ = ['_jax_trace', '_nnx_trace']
+
+  def __init__(self):
+    self._jax_trace = current_jax_trace()
+    self._nnx_trace = current_nnx_trace()
+
+  @property
+  def jax_trace(self):
+    return self._jax_trace
+
+  @property
+  def nnx_trace(self):
+    return self._nnx_trace
+
+  def is_valid(self) -> bool:
+    return (
+      self._jax_trace is current_jax_trace()
+      and self._nnx_trace is current_nnx_trace()
+    )
+
+  def __nnx_repr__(self):
+    yield reprlib.Object(f'{type(self).__name__}')
+    yield reprlib.Attr('jax_trace', self._jax_trace)
+    yield reprlib.Attr('nnx_trace', self._nnx_trace)
diff --git a/flax/experimental/nnx/nnx/transforms.py b/flax/experimental/nnx/nnx/transforms.py
new file mode 100644
index 0000000000..db86214d4f
--- /dev/null
+++ b/flax/experimental/nnx/nnx/transforms.py
@@ -0,0 +1,1639 @@
+# Copyright 2023 The Flax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
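+
+# Pattern used throughout this file (sketch): each lifted transform wraps a
+# plain jax transform, splitting Modules into (State, ModuleDef) pytrees on
+# the way in and merging updated state back on the way out, so stateful
+# Modules can cross transform boundaries.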
+from __future__ import annotations + +import dataclasses +import functools +import typing as tp +from abc import abstractmethod +from types import MappingProxyType +from typing import Any + +import jax +import jax.numpy as jnp +import jax.stages + +from flax.experimental.nnx.nnx import ( + filterlib, + rnglib, + spmd, + tracers, + variables, +) +from flax.experimental.nnx.nnx.module import ( + CallableProxy, + DelayedAccessor, + Module, + ModuleDef, + ModuleMeta, +) +from flax.experimental.nnx.nnx.state import State + +A = tp.TypeVar('A') +C = tp.TypeVar('C') +B = tp.TypeVar('B') +F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) +G = tp.TypeVar('G', bound=tp.Callable[..., tp.Any]) +M = tp.TypeVar('M', bound=Module) +N = tp.TypeVar('N', bound=Module) + +AxisName = tp.Hashable +Leaf = tp.Any +Leaves = tp.List[Leaf] + + +def _check_args(args: tuple[tp.Any, ...]): + """Check if Rngs is passed as a positional argument and raise an error.""" + for arg in args: + if isinstance(arg, rnglib.Rngs): + raise ValueError( + "Rngs must be passed as a keyword argument named 'rngs', not a" + ' positional argument' + ) + + +class LiftedModule(Module, tp.Generic[M]): + @abstractmethod + def _call(self, accessesor: DelayedAccessor, *args, **kwargs) -> tp.Any: + ... + + @property + @abstractmethod + def _submodule(self) -> M: + ... + + def __call__(self, *args, **kwargs) -> tp.Any: + return self.call(*args, **kwargs) # type: ignore + + @property + def call(self) -> tp.Any: + module = self + + def check_and_call(*args, **kwargs): + _check_args(args) + return self._call(*args, **kwargs) + + proxy = CallableProxy(check_and_call, DelayedAccessor()) + + while isinstance(module._submodule, LiftedModule): + module = module._submodule + proxy = proxy.call + + return proxy # type: ignore + + +# ------------------------------- +# jit +# ------------------------------- + +UNSPECIFIED = object() + + +@dataclasses.dataclass +class JITOptions: + in_shardings: tp.Any + out_shardings: tp.Any + static_argnums: tp.Union[int, tp.Sequence[int], None] + static_argnames: tp.Union[str, tp.Iterable[str], None] + donate_argnums: tp.Union[int, tp.Sequence[int]] + keep_unused: bool + device: tp.Optional[jax.Device] + backend: tp.Optional[str] + inline: bool + abstracted_axes: tp.Optional[tp.Any] + + def get_kwargs(self) -> dict[str, tp.Any]: + kwargs = vars(self).copy() + if kwargs['in_shardings'] is UNSPECIFIED: + kwargs.pop('in_shardings') + if kwargs['out_shardings'] is UNSPECIFIED: + kwargs.pop('out_shardings') + return kwargs + + +class JITMeta(ModuleMeta): + def __call__( + self, + module_constructor: tp.Callable[..., M], + *, + in_shardings: tp.Any = UNSPECIFIED, + out_shardings: tp.Any = UNSPECIFIED, + static_argnums: tp.Union[int, tp.Sequence[int], None] = None, + static_argnames: tp.Union[str, tp.Iterable[str], None] = None, + donate_argnums: tp.Union[int, tp.Sequence[int]] = (), + keep_unused: bool = False, + device: tp.Optional[jax.Device] = None, + backend: tp.Optional[str] = None, + inline: bool = False, + abstracted_axes: tp.Optional[tp.Any] = None, + ) -> tp.Callable[..., 'JIT[M]']: + super_call = super().__call__ + + def _create_jit(*args, **kwargs) -> JIT[M]: + _check_args(args) + return super_call( + module_constructor=module_constructor, + in_shardings=in_shardings, + out_shardings=out_shardings, + static_argnums=static_argnums, + static_argnames=static_argnames, + donate_argnums=donate_argnums, + keep_unused=keep_unused, + device=device, + backend=backend, + inline=inline, + abstracted_axes=abstracted_axes, + # 
submodule args + module_init_args=args, + module_init_kwargs=kwargs, + ) + + return _create_jit + + +class JittedFn(tp.Protocol, tp.Generic[M]): + def __call__( + self, state_and_def: tuple[State | tuple[State, ...], ModuleDef[M]] + ) -> tuple[tuple[State | tuple[State, ...], ModuleDef[M]], tp.Any]: + ... + + +def get_jitted_fn(_module_type: type[M], f, options: JITOptions) -> JittedFn[M]: + jit_kwargs = options.get_kwargs() + + @functools.partial(jax.jit, **jit_kwargs) + def jitted_fn( + state_and_def: tuple[State | tuple[State, ...], ModuleDef[M]], + *args, + **kwargs, + ): + _check_args(args) + states, moduledef = state_and_def + + if isinstance(states, State): + states = (states,) + + nnx_trace = tracers.get_top_trace((args, kwargs)) + with tracers.nnx_trace(nnx_trace): + if 'rngs' in kwargs: + kwargs['rngs'] = rnglib.Rngs(kwargs['rngs']) + module = moduledef.merge(*states) + out = f(module, *args, **kwargs) + + updates = module.split() + out = (updates, out) + + return out + + return jitted_fn + + +def jit_init( + jitted_fn: JittedFn[M], + module: M, + args: tuple[tp.Any, ...], + kwargs: dict[str, tp.Any], +) -> None: + if not isinstance(module, Module): + raise TypeError(f'Expected Module, got {type(module).__name__}') + + module = tp.cast(M, module) + + if 'rngs' in kwargs and isinstance(rngs := kwargs['rngs'], rnglib.Rngs): + kwargs['rngs'] = rngs.fork() + + state_and_def = module.split() + out = jitted_fn(state_and_def, *args, **kwargs) + updates, _ = out + module.update(updates) + + +def jit_apply( + jitted_fn: JittedFn[M], + module: M, + args: tuple[tp.Any, ...], + kwargs: dict[str, tp.Any], +) -> tp.Any: + if not isinstance(module, Module): + raise TypeError(f'Expected Module, got {type(module).__name__}') + + module = tp.cast(M, module) + + if 'rngs' in kwargs and isinstance(rngs := kwargs['rngs'], rnglib.Rngs): + kwargs['rngs'] = rngs.fork() + + state_and_def = module.split() + updates, out = jitted_fn(state_and_def, *args, **kwargs) + module.update(updates) + return out + + +class JIT(LiftedModule[M], metaclass=JITMeta): + def __init__( + self, + module_constructor: tp.Callable[..., M], + *, + in_shardings: tp.Any = UNSPECIFIED, + out_shardings: tp.Any = UNSPECIFIED, + static_argnums: tp.Union[int, tp.Sequence[int], None] = None, + static_argnames: tp.Union[str, tp.Iterable[str], None] = None, + donate_argnums: tp.Union[int, tp.Sequence[int]] = (), + keep_unused: bool = False, + device: tp.Optional[jax.Device] = None, + backend: tp.Optional[str] = None, + inline: bool = False, + abstracted_axes: tp.Optional[tp.Any] = None, + # submodule args + module_init_args: tuple[tp.Any, ...], + module_init_kwargs: dict[str, tp.Any], + ): + self.options = JITOptions( + in_shardings=in_shardings, + out_shardings=out_shardings, + static_argnums=static_argnums, + static_argnames=static_argnames, + donate_argnums=donate_argnums, + keep_unused=keep_unused, + device=device, + backend=backend, + inline=inline, + abstracted_axes=abstracted_axes, + ) + self.accessor: tp.Optional[DelayedAccessor] = None + + def jit_call_module(module, *args, **kwargs): + assert self.accessor is not None + f = self.accessor(module) + return f(*args, **kwargs) + + self.jitted_fn: JittedFn[M] = get_jitted_fn( + M, jit_call_module, self.options + ) + self.module_constructor = module_constructor + self.jit_module = self.module_constructor( + *module_init_args, **module_init_kwargs + ) + + @property + def _submodule(self) -> M: + return self.jit_module + + def _call(self, accessesor: DelayedAccessor, *args, **kwargs) -> 
Any: + self.accessor = accessesor + try: + out = jit_apply(self.jitted_fn, self.jit_module, args, kwargs) + finally: + self.accessor = None + return out + + +def jit( + f: F, + *, + in_shardings: tp.Any = UNSPECIFIED, + out_shardings: tp.Any = UNSPECIFIED, + static_argnums: tp.Union[int, tp.Sequence[int], None] = None, + static_argnames: tp.Union[str, tp.Iterable[str], None] = None, + donate_argnums: tp.Union[int, tp.Sequence[int]] = (), + keep_unused: bool = False, + device: tp.Optional[jax.Device] = None, + backend: tp.Optional[str] = None, + inline: bool = False, + abstracted_axes: tp.Optional[tp.Any] = None, + is_init: tp.Optional[bool] = None, +) -> F: + if is_init is None: + is_init = f.__name__ == '__init__' + + if static_argnames is None: + static_argnames = [] + elif isinstance(static_argnames, str): + static_argnames = [static_argnames] + else: + static_argnames = list(static_argnames) + + options = JITOptions( + in_shardings=in_shardings, + out_shardings=out_shardings, + static_argnums=static_argnums, + static_argnames=static_argnames, + donate_argnums=donate_argnums, + keep_unused=keep_unused, + device=device, + backend=backend, + inline=inline, + abstracted_axes=abstracted_axes, + ) + jitted_fn = get_jitted_fn(Module, f, options) + + if is_init: + + @functools.wraps(f) + def jit_init_wrapper(module: Module, *args, **kwargs): + _check_args(args) + jit_init(jitted_fn, module, args, kwargs) + + wrapper = jit_init_wrapper + wrapper.inner = jitted_fn + else: + + @functools.wraps(f) + def jit_apply_wrapper(module: Module, *args, **kwargs): + _check_args(args) + return jit_apply(jitted_fn, module, args, kwargs) + + wrapper = jit_apply_wrapper + wrapper.inner = jitted_fn + + return wrapper # type: ignore + + +# ------------------------------- +# grad +# ------------------------------- + + +@dataclasses.dataclass +class GradOptions: + wrt: filterlib.Filter + has_aux: bool + holomorphic: bool + allow_int: bool + reduce_axes: tp.Sequence[AxisName] + return_value: bool + + +class GradMeta(ModuleMeta): + def __call__( + self, + module_constructor: tp.Callable[..., M], + *, + wrt: filterlib.Filter = variables.Param, + has_aux: bool = False, + holomorphic: bool = False, + allow_int: bool = False, + reduce_axes: tp.Sequence[AxisName] = (), + return_value: bool = False, + ) -> tp.Callable[..., 'Grad[M]']: + super_call = super().__call__ + + def _create_grad(*args, **kwargs) -> Grad[M]: + _check_args(args) + return super_call( + module_constructor=module_constructor, + wrt=wrt, + has_aux=has_aux, + holomorphic=holomorphic, + allow_int=allow_int, + reduce_axes=reduce_axes, + return_value=return_value, + # submodule args + module_init_args=args, + module_init_kwargs=kwargs, + ) + + return _create_grad + + +class Grad(LiftedModule[M], metaclass=GradMeta): + def __init__( + self, + module_constructor: tp.Callable[..., M], + *, + wrt: filterlib.Filter = variables.Param, + has_aux: bool = False, + holomorphic: bool = False, + allow_int: bool = False, + reduce_axes: tp.Sequence[AxisName] = (), + return_value: bool = False, + # submodule args + module_init_args: tuple[tp.Any, ...], + module_init_kwargs: dict[str, tp.Any], + ): + self.options = GradOptions( + wrt=wrt, + has_aux=has_aux, + holomorphic=holomorphic, + allow_int=allow_int, + reduce_axes=reduce_axes, + return_value=return_value, + ) + self.module_constructor = module_constructor + self.grad_module = self.module_constructor( + *module_init_args, **module_init_kwargs + ) + + @property + def _submodule(self) -> M: + return self.grad_module + + 
def _call(self, accessesor: DelayedAccessor, *args, **kwargs) -> Any: + def grad_call_apply(module, *args, **kwargs): + return accessesor(module)(*args, **kwargs) + + return grad_apply( + self.options, grad_call_apply, self.grad_module, *args, **kwargs + ) + + +def grad_apply(options: GradOptions, f, module: Module, *args, **kwargs): + if not isinstance(module, Module): + raise TypeError(f'Expected a Module, got {type(module).__name__}') + + predicate = filterlib.to_predicate(options.wrt) + + diff, nondiff, moduledef = module.split(predicate, ...) + transform = jax.value_and_grad if options.return_value else jax.grad + + @functools.partial( + transform, + argnums=0, # we'll handle this ourselves + has_aux=True, + holomorphic=options.holomorphic, + allow_int=options.allow_int, + reduce_axes=options.reduce_axes, + ) + def grad_fn(diff: State): + nonlocal moduledef + + with tracers.nnx_trace(tracers.get_top_trace(diff)): + module = moduledef.merge(diff, nondiff) + out = f(module, *args, **kwargs) + + updates, moduledef = module.split() + if options.has_aux: + loss, aux = out + out = (loss, (updates, aux)) + else: + out = (out, updates) + + return out + + out = grad_fn(diff) + + updates: State + if options.return_value: + if options.has_aux: + (loss, (updates, aux)), grads = out + out = (loss, aux), grads + else: + (loss, updates), grads = out + out = loss, grads + else: + if options.has_aux: + grads, (updates, aux) = out + out = grads, aux + else: + out, updates = out + + module.update((updates, moduledef)) + return out + + +@tp.overload +def grad( + f: tp.Callable[..., tp.Any], + wrt: filterlib.Filter = variables.Param, + *, + holomorphic: bool = False, + allow_int: bool = False, + reduce_axes: tp.Sequence[AxisName] = (), +) -> tp.Callable[..., State]: + ... + + +@tp.overload +def grad( + f: tp.Callable[..., tp.Any], + wrt: filterlib.Filter = variables.Param, + *, + has_aux: tp.Literal[True], + holomorphic: bool = False, + allow_int: bool = False, + reduce_axes: tp.Sequence[AxisName] = (), +) -> tp.Callable[..., tuple[State, tp.Any]]: + ... + + +def grad( + f: tp.Callable[..., tp.Any], + wrt: filterlib.Filter = variables.Param, + *, + has_aux: bool = False, + holomorphic: bool = False, + allow_int: bool = False, + reduce_axes: tp.Sequence[AxisName] = (), +) -> tp.Callable[..., tp.Union[tuple[State, tp.Any], State]]: + if f.__name__ == '__init__': + raise ValueError('Cannot use `grad` with `__init__`') + + options = GradOptions( + wrt=wrt, + has_aux=has_aux, + holomorphic=holomorphic, + allow_int=allow_int, + reduce_axes=reduce_axes, + return_value=False, + ) + + @functools.wraps(f) + def grad_wrapper(module: Module, *args, **kwargs): + _check_args(args) + return grad_apply(options, f, module, *args, **kwargs) + + return grad_wrapper # type: ignore + + +@tp.overload +def value_and_grad( + f: tp.Callable[..., tp.Any], + wrt: filterlib.Filter = variables.Param, + *, + holomorphic: bool = False, + allow_int: bool = False, + reduce_axes: tp.Sequence[AxisName] = (), +) -> tp.Callable[..., tuple[jax.Array, State]]: + ... + + +@tp.overload +def value_and_grad( + f: tp.Callable[..., tp.Any], + wrt: filterlib.Filter = variables.Param, + *, + has_aux: tp.Literal[True], + holomorphic: bool = False, + allow_int: bool = False, + reduce_axes: tp.Sequence[AxisName] = (), +) -> tp.Callable[..., tuple[tuple[jax.Array, tp.Any], State]]: + ... 
+ + +def value_and_grad( + f: tp.Callable[..., tp.Any], + wrt: filterlib.Filter = variables.Param, + *, + has_aux: bool = False, + holomorphic: bool = False, + allow_int: bool = False, + reduce_axes: tp.Sequence[AxisName] = (), +) -> tp.Callable[ + ..., + tp.Union[tuple[tuple[jax.Array, tp.Any], State], tuple[jax.Array, State]], +]: + if f.__name__ == '__init__': + raise ValueError('Cannot use `value_and_grad` with `__init__`') + + options = GradOptions( + wrt=wrt, + has_aux=has_aux, + holomorphic=holomorphic, + allow_int=allow_int, + reduce_axes=reduce_axes, + return_value=True, + ) + + @functools.wraps(f) + def value_and_grad_wrapper(module: Module, *args, **kwargs): + _check_args(args) + return grad_apply(options, f, module, *args, **kwargs) + + return value_and_grad_wrapper # type: ignore + + +# ------------------------------- +# scan +# ------------------------------- + + +@dataclasses.dataclass +class ScanOptions: + variable_axes: tp.Mapping[filterlib.Filter, int] + broadcast_rngs: filterlib.Filter + in_args_axes: tp.Any + in_kwargs_axes: tp.Any + out_axes: tp.Any + length: tp.Optional[int] + reverse: bool + unroll: int + scan_metadata: tp.Mapping[str, tp.Any] + scan_output: bool + + +class ScanMeta(ModuleMeta): + def __call__( + self, + module_constructor: tp.Callable[..., M], + *, + variable_axes: tp.Mapping[filterlib.Filter, int] = MappingProxyType({}), + broadcast_rngs: filterlib.Filter = None, + in_args_axes: tp.Any = 0, + in_kwargs_axes: tp.Any = 0, + out_axes: tp.Any = 0, + length: tp.Optional[int] = None, + reverse: bool = False, + unroll: int = 1, + scan_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), + scan_output: bool = True, + ) -> tp.Callable[..., 'Scan[M]']: + super_call = super().__call__ + + def _create_scan(*args, **kwargs) -> Scan[M]: + _check_args(args) + return super_call( + module_constructor=module_constructor, + module_init_args=args, + module_init_kwargs=kwargs, + variable_axes=variable_axes, + broadcast_rngs=broadcast_rngs, + in_args_axes=in_args_axes, + in_kwargs_axes=in_kwargs_axes, + out_axes=out_axes, + length=length, + reverse=reverse, + unroll=unroll, + scan_metadata=scan_metadata, + scan_output=scan_output, + ) + + return _create_scan + + +class Scan(LiftedModule[M], metaclass=ScanMeta): + def __init__( + self, + module_constructor: tp.Callable[..., M], + *, + variable_axes: tp.Mapping[filterlib.Filter, int] = MappingProxyType({}), + broadcast_rngs: filterlib.Filter = None, + in_args_axes: tp.Any = 0, + in_kwargs_axes: tp.Any = 0, + out_axes: tp.Any = 0, + length: tp.Optional[int] = None, + reverse: bool = False, + unroll: int = 1, + scan_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), + scan_output: bool = True, + # submodule args + module_init_args: tuple[tp.Any, ...], + module_init_kwargs: dict[str, tp.Any], + ): + self.module_constructor = module_constructor + self.options = ScanOptions( + variable_axes=variable_axes, + broadcast_rngs=broadcast_rngs, + in_args_axes=in_args_axes, + in_kwargs_axes=in_kwargs_axes, + out_axes=out_axes, + length=length, + reverse=reverse, + unroll=unroll, + scan_metadata=scan_metadata, + scan_output=scan_output, + ) + self.scan_module = scan_init( + self.options, module_constructor, module_init_args, module_init_kwargs + ) + + @property + def _submodule(self) -> M: + return self.scan_module + + def _call( + self, accessesor: DelayedAccessor, *args, **kwargs + ) -> tuple[tp.Any, tp.Any]: + if len(args) < 1: + raise TypeError( + f'Expected at least 1 positional arguments, got {len(args)}' + ) + 
_check_args(args) + carry_arg, args = args[0], args[1:] + + def scan_call_apply(module, *args, **kwargs): + return accessesor(module)(*args, **kwargs) + + return scan_apply( + self.options, + scan_call_apply, + self.scan_module, + carry_arg, + args, + kwargs, + ) + + +class ScanCall(tp.Protocol, tp.Generic[C, B]): + def __call__( + self, + module: Module, + carry_arg: C, + *args: tp.Any, + **kwargs: tp.Any, + ) -> tuple[C, B] | C: + ... + + +def scan_init( + options: ScanOptions, + module_constructor: tp.Callable[..., M], + module_init_args: tuple[tp.Any, ...], + module_init_kwargs: dict[str, tp.Any], +) -> M: + if options.variable_axes and options.length is None: + raise ValueError('Cannot use variable_axes without specifying a length') + + _check_args(module_init_args) + + rngs = module_init_kwargs.pop('rngs', None) + + if rngs is not None and not isinstance(rngs, rnglib.Rngs): + raise TypeError(f'Expected a Rngs, got {type(rngs).__name__}') + + split_keys = [] + + if rngs is not None: + if not isinstance(rngs, rnglib.Rngs): + raise TypeError(f'Expected a Rngs, got {type(rngs).__name__}') + + split_keys, broadcast_keys = rngs.fork( + {filterlib.Not(options.broadcast_rngs): options.length} + ) + + if split_keys and options.length is None: + raise ValueError('Cannot split RNGs without specifying a length') + + else: + split_keys = None + broadcast_keys = None + + moduledef: tp.Optional[ModuleDef[M]] = None + + def _init_state(split_keys, broadcast_keys): + nonlocal moduledef + + if split_keys is not None: + assert broadcast_keys is not None + module_init_kwargs['rngs'] = rnglib.Rngs(**split_keys, **broadcast_keys) + + module = module_constructor(*module_init_args, **module_init_kwargs) + + # lift module + filters = (*options.variable_axes.keys(), ...) + + *states, moduledef = module.split(*filters) + + return tuple(states) + + if split_keys is not None or options.variable_axes: + init_out_axes = (*options.variable_axes.values(), None) + _init_state = jax.vmap( + _init_state, + in_axes=(0, None), + out_axes=init_out_axes, + axis_size=options.length, + ) + + *axes_states, carry_state = _init_state(split_keys, broadcast_keys) + moduledef = tp.cast(ModuleDef[M], moduledef) + + # add additional axis name to Variable.sharding + if spmd.PARTITION_NAME in options.scan_metadata: + axes_states = [ + spmd.add_axis(state, index, options.scan_metadata) + for state, index in zip(axes_states, options.variable_axes.values()) + ] + + module = moduledef.merge(*axes_states, carry_state) + + return module + + +def scan_apply( + options: ScanOptions, + f: ScanCall[C, B], + module: Module, + carry_arg: C, + args: tuple[tp.Any, ...], + kwargs: dict[str, tp.Any], +) -> tuple[C, B] | C: + rngs = kwargs.pop('rngs', None) + + # split module state + filters = (*options.variable_axes.keys(), ...) 
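+ # the trailing `...` filter collects every leaf not matched by a + # `variable_axes` filter into `carry_state`, which is carried, not + # scanned, across iterations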
+ *scan_states, carry_state, moduledef = module.split(*filters) + + # transpose axes state + scan_states = tuple( + jax.tree_map(lambda x: jnp.moveaxis(x, axis, 0), axes_state) + for axes_state, axis in zip(scan_states, options.variable_axes.values()) + ) + # transpose axes arg + scan_args = jax.tree_map( + lambda axis, node: jax.tree_map(lambda x: jnp.moveaxis(x, axis, 0), node) + if axis is not None + else None, + options.in_args_axes, + args, + is_leaf=lambda x: x is None, + ) + broadcast_args = jax.tree_map( + lambda axis, node: None if axis is not None else node, + options.in_args_axes, + args, + is_leaf=lambda x: x is None, + ) + scan_kwargs = jax.tree_map( + lambda axis, node: jax.tree_map(lambda x: jnp.moveaxis(x, axis, 0), node) + if axis is not None + else None, + options.in_kwargs_axes, + kwargs, + is_leaf=lambda x: x is None, + ) + broadcast_kwargs = jax.tree_map( + lambda axis, node: None if axis is not None else node, + options.in_kwargs_axes, + kwargs, + is_leaf=lambda x: x is None, + ) + + # infer length + lengths: tp.Set[int] = set( + x.shape[0] + for x in jax.tree_util.tree_leaves((scan_states, scan_args, scan_kwargs)) + ) + + if len(lengths) > 1: + raise ValueError( + 'Inconsistent lengths between variable_axes states and ' + f'arguments: {lengths}' + ) + elif len(lengths) == 0: + if options.length is None: + raise ValueError( + 'Cannot infer length from variable_axes states or axes_arg, ' + 'please specify `length`' + ) + length = options.length + else: + length = lengths.pop() + if options.length is not None and options.length != length: + raise ValueError( + f'Specified length {options.length} is not the same as the inferred ' + f'length {length}' + ) + + # split rng state + if rngs is not None: + if not isinstance(rngs, rnglib.Rngs): + raise TypeError(f'Expected a Rngs, got {type(rngs).__name__}') + split_keys, broadcast_keys = rngs.fork( + {filterlib.Not(options.broadcast_rngs): length} + ) + else: + split_keys = None + broadcast_keys = None + + moduledef_out: tp.Optional[ModuleDef[Module]] = None + + def scan_fn( + carry: tuple[State, tp.Any], + scan: tuple[ + dict[str, rnglib.RngStream] | None, + tuple[State, ...], + tuple[tp.Any, ...], + dict[str, tp.Any], + ], + ): + nonlocal moduledef_out + carry_state, carry_arg = carry + split_keys, scan_states, scan_args, scan_kwargs = scan + + # merge args and kwargs + args = jax.tree_map( + lambda axis, scan, broadcast: scan if axis is not None else broadcast, + options.in_args_axes, + scan_args, + broadcast_args, + is_leaf=lambda x: x is None, + ) + kwargs = jax.tree_map( + lambda axis, scan, broadcast: scan if axis is not None else broadcast, + options.in_kwargs_axes, + scan_kwargs, + broadcast_kwargs, + is_leaf=lambda x: x is None, + ) + + # merge rng state + if split_keys is not None: + assert broadcast_keys is not None + kwargs['rngs'] = rnglib.Rngs(**split_keys, **broadcast_keys) + + # remove metadata axis name from Variable.sharding + if spmd.PARTITION_NAME in options.scan_metadata: + scan_states = [ + spmd.remove_axis(state, index, options.scan_metadata) + for state, index in zip(scan_states, options.variable_axes.values()) + ] + + # merge module state + module = moduledef.merge(*scan_states, carry_state) + + output = f(module, carry_arg, *args, **kwargs) + + if options.scan_output: + if not isinstance(output, tuple) or len(output) != 2: + raise ValueError( + 'Expected a tuple of length 2 as the output of the scan function, ' + f'got {output}' + ) + output = tp.cast(tuple[C, B], output) + carry_out, scan_out = 
output + else: + output = tp.cast(C, output) + carry_out = output + scan_out = None + + # split module state + *scan_states_out, carry_state_out, moduledef_out = module.split(*filters) + carry_state_new = carry_state_out - carry_state + + # remove new carry state + carry_state_out = carry_state_out - carry_state_new + + # add metadata axis name to Variable.sharding + if spmd.PARTITION_NAME in options.scan_metadata: + scan_states_out = [ + spmd.add_axis(state, index, options.scan_metadata) + for state, index in zip(scan_states_out, options.variable_axes.values()) + ] + + full_carry_out = (carry_state_out, carry_out) + full_scan_out = (scan_states_out, carry_state_new, scan_out) + + return full_carry_out, full_scan_out + + carry = (carry_state, carry_arg) + scan = (split_keys, scan_states, scan_args, scan_kwargs) + + full_carry_out, full_scan_out = jax.lax.scan( + scan_fn, + carry, + scan, + length=length, + reverse=options.reverse, + unroll=options.unroll, + ) + carry_state, carry_out = full_carry_out + scan_states, carry_state_new, scan_out = full_scan_out + assert moduledef_out is not None + + # transpose axes state + scan_states = tuple( + jax.tree_map(lambda x: jnp.moveaxis(x, 0, axis), axes_state) + for axes_state, axis in zip(scan_states, options.variable_axes.values()) + ) + # transpose axes arg + scan_out = jax.tree_map( + lambda axis, node: jax.tree_map(lambda x: jnp.moveaxis(x, 0, axis), node), + options.out_axes, + scan_out, + ) + # slice new carry state + carry_state_new = jax.tree_map(lambda x: x[0], carry_state_new) + + module.update(((*scan_states, carry_state, carry_state_new), moduledef_out)) + + if options.scan_output: + return carry_out, scan_out + else: + return carry_out + + +def scan( + f: F, + *, + variable_axes: tp.Mapping[filterlib.Filter, int] = MappingProxyType({}), + broadcast_rngs: filterlib.Filter = None, + in_args_axes: tp.Any = 0, + in_kwargs_axes: tp.Any = 0, + out_axes: tp.Any = 0, + length: tp.Optional[int] = None, + reverse: bool = False, + unroll: int = 1, + is_init: tp.Optional[bool] = None, + scan_metadata: tp.Mapping[str, tp.Any] = {}, + scan_output: bool = True, +) -> F: + if is_init is None: + is_init = f.__name__ == '__init__' + + options = ScanOptions( + variable_axes=variable_axes, + broadcast_rngs=broadcast_rngs, + in_args_axes=in_args_axes, + in_kwargs_axes=in_kwargs_axes, + out_axes=out_axes, + length=length, + reverse=reverse, + unroll=unroll, + scan_metadata=scan_metadata, + scan_output=scan_output, + ) + + if is_init: + + @functools.wraps(f) + def scan_init_wrapper(module: Module, *args, **kwargs): + def module_constructor(*args, **kwargs): + _check_args(args) + f(module, *args, **kwargs) + return module + + lifted_module = scan_init(options, module_constructor, args, kwargs) + module.update(lifted_module) + + wrapper = scan_init_wrapper + + else: + + @functools.wraps(f) + def scan_apply_wrapper( + module: Module, + *args, + **kwargs, + ) -> tuple[C, tp.Any]: + if len(args) < 2: + raise TypeError( + f'Expected at least 2 positional arguments, got {len(args)}' + ) + _check_args(args) + + carry_arg, args = args[0], args[1:] + return scan_apply(options, f, module, carry_arg, args, kwargs) + + wrapper = scan_apply_wrapper + + return wrapper # type: ignore + + +# ------------------------------- +# remat +# ------------------------------- + + +class RematMeta(ModuleMeta): + def __call__( + self, + module_constructor: tp.Callable[..., M], + # variables: lift.CollectionFilter = True, + # rngs: lift.PRNGSequenceFilter = True, + prevent_cse: bool = 
True, + static_argnums: tp.Union[int, tuple[int, ...]] = (), + policy: tp.Optional[tp.Callable[..., bool]] = None, + ) -> tp.Callable[..., 'Remat[M]']: + super_call = super().__call__ + + def create_remat(*args, **kwargs) -> Remat[M]: + _check_args(args) + return super_call( + module_constructor=module_constructor, + module_init_args=args, + module_init_kwargs=kwargs, + prevent_cse=prevent_cse, + static_argnums=static_argnums, + policy=policy, + ) + + return create_remat + + +@dataclasses.dataclass +class RematOptions: + prevent_cse: bool + static_argnums: tp.Union[int, tuple[int, ...]] + policy: tp.Optional[tp.Callable[..., bool]] + + def __post_init__(self): + if isinstance(self.static_argnums, int): + self.static_argnums = (self.static_argnums,) + + # add 2 as an offset to account for state and keys + self.static_argnums = tuple( + x + 2 if x >= 0 else x for x in self.static_argnums + ) + + +class Remat(LiftedModule[M], metaclass=RematMeta): + def __init__( + self, + *, + module_constructor: tp.Callable[..., M], + prevent_cse: bool = True, + static_argnums: tp.Union[int, tuple[int, ...]] = (), + policy: tp.Optional[tp.Callable[..., bool]] = None, + # submodule args + module_init_args: tuple[tp.Any, ...], + module_init_kwargs: dict[str, tp.Any], + ): + self.options = RematOptions( + prevent_cse=prevent_cse, + static_argnums=static_argnums, + policy=policy, + ) + self.module_constructor = module_constructor + self.remat_module = self.module_constructor( + *module_init_args, **module_init_kwargs + ) + + @property + def _submodule(self) -> M: + return self.remat_module + + def _call( + self, + accessesor: DelayedAccessor, + *args, + rngs: tp.Optional[rnglib.Rngs] = None, + ) -> tp.Any: + def remat_call_apply(module, *args, **kwargs): + return accessesor(module)(*args, **kwargs) + + return remat_apply( + self.options, + remat_call_apply, + self.remat_module, + args, + rngs, + ) + + +class RematCall(tp.Protocol): + def __call__(self, *args, rngs: tp.Optional[rnglib.Rngs]) -> tp.Any: + ... 
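A sketch of the intended `Remat` usage (assuming `Remat`, `Linear`, and `Rngs` are re-exported from `flax.experimental.nnx`, as the other lifted transforms in this PR are):

import jax.numpy as jnp
from flax.experimental import nnx

# wrap the constructor, then build the submodule with its usual arguments
remat_linear = nnx.Remat(nnx.Linear)(2, 4, rngs=nnx.Rngs(0))

x = jnp.ones((1, 2))
y = remat_linear(x)  # the forward pass is checkpointed via jax.checkpoint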
+ + +def remat_apply( + options: RematOptions, + f: RematCall, + module: Module, + args: tuple[tp.Any, ...], + rngs: tp.Optional[rnglib.Rngs], +): + _check_args(args) + + state, moduledef = module.split() + keys = rngs.fork() if rngs is not None else None + + def _remat_fn( + state: State, + keys: tp.Optional[dict[str, jax.Array]], + *args, + ) -> tuple[tuple[State, ModuleDef[Module]], tp.Any]: + kwargs = {} + if keys is not None: + kwargs['rngs'] = rnglib.Rngs(keys) + + module = moduledef.merge(state) + out = f(module, *args, **kwargs) + + state_and_def = module.split() + + return state_and_def, out + + state_and_def: tuple[State, ModuleDef[Module]] + state_and_def, out = jax.checkpoint( + _remat_fn, + prevent_cse=options.prevent_cse, + static_argnums=options.static_argnums, + policy=options.policy, + )(state, keys, *args) + + module.update(state_and_def) + + return out + + +def remat( + f: F, + *, + # variables: lift.CollectionFilter, + # rngs: lift.PRNGSequenceFilter, + prevent_cse: bool = True, + static_argnums: tp.Union[int, tuple[int, ...]] = (), + policy: tp.Optional[tp.Callable[..., bool]] = None, + is_init: tp.Optional[bool] = None, +) -> F: + if is_init is None: + is_init = f.__name__ == '__init__' + + options = RematOptions( + # variables=variables, + # rngs=rngs, + prevent_cse=prevent_cse, + static_argnums=static_argnums, + policy=policy, + ) + + if is_init: + return f + else: + + @functools.wraps(f) + def remat_wrapper( + module: Module, *args, rngs: tp.Optional[rnglib.Rngs] = None + ): + return remat_apply(options, f, module, args, rngs) + + return remat_wrapper # type: ignore + + +# ------------------------------- +# vmap +# ------------------------------- + + +@dataclasses.dataclass +class VmapOptions: + variable_axes: tp.Mapping[filterlib.Filter, int] + broadcast_rngs: filterlib.Filter + in_args_axes: tp.Any + in_kwargs_axes: tp.Any + out_axes: tp.Any + axis_size: int | None + axis_name: str | None + spmd_axis_name: str | None + vmap_metadata: tp.Mapping[str, tp.Any] + + +class VmapMeta(ModuleMeta): + def __call__( + self, + module_constructor: tp.Callable[..., M], + *, + variable_axes: tp.Mapping[filterlib.Filter, int] = MappingProxyType({}), + broadcast_rngs: filterlib.Filter = None, + in_args_axes: tp.Any = 0, + in_kwargs_axes: tp.Any = 0, + out_axes: tp.Any = 0, + axis_size: int | None = None, + axis_name: str | None = None, + spmd_axis_name: str | None = None, + vmap_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), + ) -> tp.Callable[..., 'Vmap[M]']: + super_call = super().__call__ + + def _create_vmap(*args, **kwargs) -> Vmap[M]: + _check_args(args) + return super_call( + module_constructor=module_constructor, + module_init_args=args, + module_init_kwargs=kwargs, + variable_axes=variable_axes, + broadcast_rngs=broadcast_rngs, + in_args_axes=in_args_axes, + in_kwargs_axes=in_kwargs_axes, + out_axes=out_axes, + axis_size=axis_size, + axis_name=axis_name, + spmd_axis_name=spmd_axis_name, + vmap_metadata=vmap_metadata, + ) + + return _create_vmap + + +class Vmap(LiftedModule[M], metaclass=VmapMeta): + def __init__( + self, + module_constructor: tp.Callable[..., M], + *, + variable_axes: tp.Mapping[filterlib.Filter, int] = MappingProxyType({}), + broadcast_rngs: filterlib.Filter = None, + in_args_axes: tp.Any = 0, + in_kwargs_axes: tp.Any = 0, + out_axes: tp.Any = 0, + axis_size: int | None = None, + axis_name: str | None = None, + spmd_axis_name: str | None = None, + vmap_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), + # submodule args +
module_init_args: tuple[tp.Any, ...], + module_init_kwargs: dict[str, tp.Any], + ): + self.module_constructor = module_constructor + self.options = VmapOptions( + variable_axes=variable_axes, + broadcast_rngs=broadcast_rngs, + in_args_axes=in_args_axes, + in_kwargs_axes=in_kwargs_axes, + out_axes=out_axes, + axis_size=axis_size, + axis_name=axis_name, + spmd_axis_name=spmd_axis_name, + vmap_metadata=vmap_metadata, + ) + self.vmap_module = vmap_init( + self.options, module_constructor, module_init_args, module_init_kwargs + ) + + @property + def _submodule(self) -> M: + return self.vmap_module + + def _call( + self, accessesor: DelayedAccessor, *args, **kwargs + ) -> tuple[tp.Any, tp.Any]: + _check_args(args) + + def vmap_call_apply(module, *args, **kwargs): + return accessesor(module)(*args, **kwargs) + + return vmap_apply( + self.options, + vmap_call_apply, + self.vmap_module, + args, + kwargs, + ) + + +class VmapCall(tp.Protocol): + def __call__( + self, + module: Module, + *args: tp.Any, + **kwargs: tp.Any, + ) -> tp.Any: + ... + + +def vmap_init( + options: VmapOptions, + module_constructor: tp.Callable[..., M], + module_init_args: tuple[tp.Any, ...], + module_init_kwargs: dict[str, tp.Any], +) -> M: + if options.variable_axes and options.axis_size is None: + raise ValueError('Cannot use variable_axes without specifying a length') + + _check_args(module_init_args) + + rngs = module_init_kwargs.pop('rngs', None) + + if rngs is not None and not isinstance(rngs, rnglib.Rngs): + raise TypeError(f'Expected a Rngs, got {type(rngs).__name__}') + + if rngs is not None: + if not isinstance(rngs, rnglib.Rngs): + raise TypeError(f'Expected a Rngs, got {type(rngs).__name__}') + split_keys, broadcast_keys = rngs.fork( + {filterlib.Not(options.broadcast_rngs): options.axis_size} + ) + if split_keys and options.axis_size is None: + raise ValueError('Cannot split RNGs without specifying a length') + else: + split_keys = None + broadcast_keys = None + + moduledef: tp.Optional[ModuleDef[M]] = None + + def _init_state(split_keys, broadcast_keys): + nonlocal moduledef + + if split_keys is not None: + assert broadcast_keys is not None + module_init_kwargs['rngs'] = rnglib.Rngs(**split_keys, **broadcast_keys) + + module = module_constructor(*module_init_args, **module_init_kwargs) + + # lift module + filters = (*options.variable_axes.keys(), ...) + + *states, moduledef = module.split(*filters) + + return tuple(states) + + if split_keys is not None or options.variable_axes: + init_out_axes = (*options.variable_axes.values(), None) + _init_state = jax.vmap( + _init_state, + in_axes=(0, None), + out_axes=init_out_axes, + axis_size=options.axis_size, + ) + + *axes_states, carry_state = _init_state(split_keys, broadcast_keys) + moduledef = tp.cast(ModuleDef[M], moduledef) + + # add additional axis name to Variable.sharding + if spmd.PARTITION_NAME in options.vmap_metadata: + axes_states = [ + spmd.add_axis(state, index, options.vmap_metadata) + for state, index in zip(axes_states, options.variable_axes.values()) + ] + + module = moduledef.merge(*axes_states, carry_state) + return module + + +def vmap_apply( + options: VmapOptions, + f: VmapCall, + module: Module, + args: tuple[tp.Any, ...], + kwargs: dict[str, tp.Any], +) -> tp.Any: + rngs = kwargs.pop('rngs', None) + + # split module state + filters = (*options.variable_axes.keys(), ...) 
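+ # as in scan_apply, the trailing `...` filter collects every leaf not + # matched by a `variable_axes` filter into `broadcast_state`, which is + # shared, not vectorized, across the mapped axis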
+ *vectorized_states, broadcast_state, moduledef = module.split(*filters) + + # infer length + axis_sizes: tp.Set[int] = set() + args_sizes = jax.tree_map( + lambda axis, node: jax.tree_map(lambda x: x.shape[axis], node) + if axis is not None + else None, + options.in_args_axes, + args, + is_leaf=lambda x: x is None, + ) + kwargs_sizes = jax.tree_map( + lambda axis, node: jax.tree_map(lambda x: x.shape[axis], node) + if axis is not None + else None, + options.in_kwargs_axes, + kwargs, + is_leaf=lambda x: x is None, + ) + axis_sizes.update(jax.tree_util.tree_leaves(args_sizes)) + axis_sizes.update(jax.tree_util.tree_leaves(kwargs_sizes)) + + if len(axis_sizes) > 1: + raise ValueError( + 'Inconsistent lengths between variable_axes states and ' + f'arguments: {axis_sizes}' + ) + elif len(axis_sizes) == 0: + if options.axis_size is None: + raise ValueError( + 'Cannot infer length from variable_axes states or axes_arg, ' + 'please specify `length`' + ) + axis_size = options.axis_size + else: + axis_size = axis_sizes.pop() + if options.axis_size is not None and options.axis_size != axis_size: + raise ValueError( + f'Specified axis_size {options.axis_size} is not the same as the' + f' inferred length {axis_size}' + ) + + # split rng state + if rngs is not None: + if not isinstance(rngs, rnglib.Rngs): + raise TypeError(f'Expected a Rngs, got {type(rngs).__name__}') + + split_keys, broadcast_keys = rngs.fork( + {filterlib.Not(options.broadcast_rngs): axis_size} + ) + else: + split_keys = None + broadcast_keys = None + + moduledef_out: tp.Optional[ModuleDef[Module]] = None + + keys_axes = 0 + states_axes = list(options.variable_axes.values()) + args_axes = options.in_args_axes + kwargs_axes = options.in_kwargs_axes + out_axes = options.out_axes + + @functools.partial( + jax.vmap, + in_axes=(keys_axes, states_axes, args_axes, kwargs_axes), + out_axes=(None, states_axes, out_axes), + axis_name=options.axis_name, + axis_size=axis_size, + spmd_axis_name=options.spmd_axis_name, + ) + def vmap_fn( + split_keys: dict[str, rnglib.RngStream] | None, + vectorized_states: list[State], + args: tuple[tp.Any, ...], + kwargs: dict[str, tp.Any], + ): + nonlocal moduledef_out + + # merge rng state + if split_keys is not None: + assert broadcast_keys is not None + kwargs['rngs'] = rnglib.Rngs(**split_keys, **broadcast_keys) + + # remove metadata axis name from Variable.sharding + if spmd.PARTITION_NAME in options.vmap_metadata: + vectorized_states = [ + spmd.remove_axis(state, index, options.vmap_metadata) + for state, index in zip( + vectorized_states, options.variable_axes.values() + ) + ] + + # merge module state + module = moduledef.merge(*vectorized_states, broadcast_state) + + output = f(module, *args, **kwargs) + + # split module state + *vectorized_states_out, broadcast_state_out, moduledef_out = module.split( + *filters + ) + + # add metadata axis name to Variable.sharding + if spmd.PARTITION_NAME in options.vmap_metadata: + vectorized_states_out = [ + spmd.add_axis(state, index, options.vmap_metadata) + for state, index in zip( + vectorized_states_out, options.variable_axes.values() + ) + ] + + return broadcast_state_out, vectorized_states_out, output + + broadcast_state, vectorized_states, output = vmap_fn( + split_keys, vectorized_states, args, kwargs + ) + assert moduledef_out is not None + + module.update(((*vectorized_states, broadcast_state), moduledef_out)) + + return output + + +def vmap( + f: F, + *, + variable_axes: tp.Mapping[filterlib.Filter, int] = MappingProxyType({}), + broadcast_rngs: 
filterlib.Filter = None, + in_args_axes: tp.Any = 0, + in_kwargs_axes: tp.Any = 0, + out_axes: tp.Any = 0, + axis_size: int | None = None, + axis_name: str | None = None, + spmd_axis_name: str | None = None, + vmap_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), + is_init: tp.Optional[bool] = None, +) -> F: + if is_init is None: + is_init = f.__name__ == '__init__' + + options = VmapOptions( + variable_axes=variable_axes, + broadcast_rngs=broadcast_rngs, + in_args_axes=in_args_axes, + in_kwargs_axes=in_kwargs_axes, + out_axes=out_axes, + axis_size=axis_size, + axis_name=axis_name, + spmd_axis_name=spmd_axis_name, + vmap_metadata=vmap_metadata, + ) + + if is_init: + + @functools.wraps(f) + def vmap_init_wrapper(module: Module, *args, **kwargs): + def module_constructor(*args, **kwargs): + _check_args(args) + f(module, *args, **kwargs) + return module + + lifted_module = vmap_init(options, module_constructor, args, kwargs) + module.update(lifted_module) + + wrapper = vmap_init_wrapper + + else: + + @functools.wraps(f) + def vmap_apply_wrapper(module: Module, *args, **kwargs) -> tp.Any: + _check_args(args) + return vmap_apply(options, f, module, args, kwargs) + + wrapper = vmap_apply_wrapper + + return wrapper # type: ignore diff --git a/flax/experimental/nnx/nnx/variables.py b/flax/experimental/nnx/nnx/variables.py new file mode 100644 index 0000000000..b45e01cca0 --- /dev/null +++ b/flax/experimental/nnx/nnx/variables.py @@ -0,0 +1,485 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +import functools +import typing as tp +from abc import ABCMeta +from functools import partial +from typing import Any + +import jax +import jax.tree_util as jtu + +from flax.experimental.nnx.nnx import reprlib + +A = tp.TypeVar('A') +B = tp.TypeVar('B') +F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) +V = tp.TypeVar('V', bound='Variable[Any]') +Sharding = tp.Tuple[tp.Optional[str], ...] +GetValueHook = tp.Callable[['Variable[A]', A], A] +SetValueHook = tp.Callable[['Variable[A]', A], A] +CreateValueHook = tp.Callable[['Variable[A]', A], A] +AxisName = str +AxisIndex = int +AddAxisHook = tp.Callable[[V, AxisName, AxisIndex], V] +RemoveAxisHook = tp.Callable[[V, AxisName, AxisIndex], V] + +VariableTypeCache: tp.Dict[str, tp.Type['Variable[tp.Any]']] = {} + + +class Empty: + def __repr__(self): + return 'Empty' + + def __eq__(self, other): + return isinstance(other, Empty) + + def __hash__(self): + return hash(Empty) + + +jtu.register_pytree_node( + Empty, + lambda empty: ((), None), + lambda _0, _1: EMPTY, +) + +EMPTY = Empty() + + +@dataclasses.dataclass +class VariableMetadata(tp.Generic[A]): + value: A + set_value_hooks: tuple[SetValueHook[A], ...] + get_value_hooks: tuple[GetValueHook[A], ...] + create_value_hooks: tuple[CreateValueHook[A], ...] + add_axis_hooks: tuple[AddAxisHook['Variable[A]'], ...] + remove_axis_hooks: tuple[RemoveAxisHook['Variable[A]'], ...] 
+ metadata: tp.Mapping[str, tp.Any] + + +class VariableMetaclass(ABCMeta): + def __call__(self, value: A, **metadata: tp.Any) -> A: + if isinstance(value, Variable): + container = value + value = container.value + else: + container = None + + obj = super().__call__(value, **metadata) + + if container is not None and not container.is_equivalent(obj): + raise ValueError( + f"input value of type '{type(container).__name__}' is not compatible " + f"with return type '{type(obj).__name__}'" + ) + + return obj + + +class Variable( + tp.Generic[A], reprlib.Representable, metaclass=VariableMetaclass +): + value: A + set_value_hooks: tuple[SetValueHook[A], ...] + get_value_hooks: tuple[GetValueHook[A], ...] + create_value_hooks: tuple[CreateValueHook[A], ...] + add_axis_hooks: tuple[AddAxisHook['Variable[A]'], ...] + remove_axis_hooks: tuple[RemoveAxisHook['Variable[A]'], ...] + + def __init__( + self, + value: tp.Union[A, VariableMetadata[A]], + set_value_hooks: tp.Union[ + SetValueHook[A], tp.Sequence[SetValueHook[A]] + ] = (), + get_value_hooks: tp.Union[ + GetValueHook[A], tp.Sequence[GetValueHook[A]] + ] = (), + create_value_hooks: tp.Union[ + CreateValueHook[A], tp.Sequence[CreateValueHook[A]] + ] = (), + add_axis_hooks: tp.Union[ + AddAxisHook['Variable[A]'], tp.Sequence[AddAxisHook['Variable[A]']] + ] = (), + remove_axis_hooks: tp.Union[ + RemoveAxisHook['Variable[A]'], + tp.Sequence[RemoveAxisHook['Variable[A]']], + ] = (), + **metadata: tp.Any, + ): + if set_value_hooks: + if callable(set_value_hooks): + set_value_hooks = (set_value_hooks,) + else: + set_value_hooks = tuple(set_value_hooks) + else: + set_value_hooks = () + if get_value_hooks: + if callable(get_value_hooks): + get_value_hooks = (get_value_hooks,) + else: + get_value_hooks = tuple(get_value_hooks) + else: + get_value_hooks = () + + if create_value_hooks: + if callable(create_value_hooks): + create_value_hooks = (create_value_hooks,) + else: + create_value_hooks = tuple(create_value_hooks) + else: + create_value_hooks = () + + if add_axis_hooks: + if callable(add_axis_hooks): + add_axis_hooks = (add_axis_hooks,) + else: + add_axis_hooks = tuple(add_axis_hooks) + else: + add_axis_hooks = () + + if remove_axis_hooks: + if callable(remove_axis_hooks): + remove_axis_hooks = (remove_axis_hooks,) + else: + remove_axis_hooks = tuple(remove_axis_hooks) + else: + remove_axis_hooks = () + + if isinstance(value, VariableMetadata): + value_metadata = dict(value.metadata) + if set_value_hooks and value.set_value_hooks: + set_value_hooks = set_value_hooks + value.set_value_hooks + elif value.set_value_hooks: + set_value_hooks = value.set_value_hooks + if get_value_hooks and value.get_value_hooks: + get_value_hooks = get_value_hooks + value.get_value_hooks + elif value.get_value_hooks: + get_value_hooks = value.get_value_hooks + if create_value_hooks and value.create_value_hooks: + create_value_hooks = create_value_hooks + value.create_value_hooks + elif value.create_value_hooks: + create_value_hooks = value.create_value_hooks + if add_axis_hooks and value.add_axis_hooks: + add_axis_hooks = add_axis_hooks + value.add_axis_hooks + elif value.add_axis_hooks: + add_axis_hooks = value.add_axis_hooks + if remove_axis_hooks and value.remove_axis_hooks: + remove_axis_hooks = remove_axis_hooks + value.remove_axis_hooks + elif value.remove_axis_hooks: + remove_axis_hooks = value.remove_axis_hooks + + metadata.update(value_metadata) + value = tp.cast(A, value.value) + + if hasattr(self, 'on_get_value'): + on_get_value = getattr(type(self), 
'on_get_value') + if on_get_value not in get_value_hooks: + get_value_hooks = (on_get_value, *get_value_hooks) + + if hasattr(self, 'on_set_value'): + on_set_value = getattr(type(self), 'on_set_value') + if on_set_value not in set_value_hooks: + set_value_hooks = (on_set_value, *set_value_hooks) + + if hasattr(self, 'on_create_value'): + on_create_value = getattr(type(self), 'on_create_value') + if on_create_value not in create_value_hooks: + create_value_hooks = (on_create_value, *create_value_hooks) + + if hasattr(self, 'on_add_axis'): + on_add_axis = getattr(type(self), 'on_add_axis') + if on_add_axis not in add_axis_hooks: + add_axis_hooks = (on_add_axis, *add_axis_hooks) + + if hasattr(self, 'on_remove_axis'): + on_remove_axis = getattr(type(self), 'on_remove_axis') + if on_remove_axis not in remove_axis_hooks: + remove_axis_hooks = (on_remove_axis, *remove_axis_hooks) + + self.value = value + self.get_value_hooks = get_value_hooks + self.set_value_hooks = set_value_hooks + self.create_value_hooks = create_value_hooks + self.add_axis_hooks = add_axis_hooks + self.remove_axis_hooks = remove_axis_hooks + vars(self).update(metadata) + + # run create_value hooks + self.value = self.create_value(self.value) + + @property + def is_empty(self) -> bool: + return self.value is EMPTY + + if tp.TYPE_CHECKING: + + def __getattr__(self, name: str) -> tp.Any: + ... + + def get_value(self) -> A: + value = self.value + if self.get_value_hooks: + for hook in self.get_value_hooks: + value = hook(self, value) + return value + + def set_value(self: V, value: A) -> V: + if self.set_value_hooks: + for hook in self.set_value_hooks: + value = hook(self, value) + return self.replace(value=value) + + def create_value(self, value: A): + for hook in self.create_value_hooks: + value = hook(self, value) + return value + + def add_axis(self: V, axis_name: AxisName, axis_index: AxisIndex) -> V: + box = self + for hook in self.add_axis_hooks: + box = hook(box, axis_name, axis_index) + return box # type: ignore + + def remove_axis(self: V, axis_name: AxisName, axis_index: AxisIndex) -> V: + box = self + for hook in self.remove_axis_hooks: + box = hook(box, axis_name, axis_index) + return box # type: ignore + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Variable): + return False + return type(self) is type(other) and vars(other) == vars(self) + + @tp.overload + def replace(self, *, value: B, **kwargs) -> 'Variable[B]': + ... + + @tp.overload + def replace(self, **kwargs) -> 'Variable[A]': + ... 
+ + def replace(self, **kwargs) -> 'Variable[tp.Any]': + # return `value` if it is a Variable + if 'value' in kwargs and isinstance(value := kwargs['value'], Variable): + # remove value from kwargs + kwargs.pop('value') + if not self.is_equivalent(value): + raise ValueError( + 'Cannot replace value from incompatible container, ' + f'expected {type(self).__name__}, got {type(value).__name__}' + ) + # if kwargs aren't empty, recursively call replace + # else return variable value + if kwargs: + return value.replace(**kwargs) + else: + return value + + # get and update attributes + attributes = vars(self).copy() + attributes.update(**kwargs) + # return new instance with updated attributes + obj = object.__new__(type(self)) + vars(obj).update(attributes) + return obj + + def as_empty(self: V) -> V: + return self.replace(value=EMPTY) + + def is_equivalent(self, other: tp.Any) -> bool: + return type(self) is type(other) + + def copy(self: 'Variable[A]') -> 'Variable[A]': + obj = object.__new__(type(self)) + vars(obj).update(vars(self)) + return obj + + def __nnx_repr__(self): + yield reprlib.Object(type=type(self)) + for name, value in vars(self).items(): + if name.endswith('_hooks'): + continue + yield reprlib.Attr(name, repr(value)) + + def __init_subclass__(cls): + super().__init_subclass__() + + jtu.register_pytree_with_keys( + cls, + partial(_variable_flatten, with_keys=True), # type: ignore + partial(_variable_unflatten, cls=cls), # type: ignore + flatten_func=partial(_variable_flatten, with_keys=False), # type: ignore + ) + + # hooks API + if tp.TYPE_CHECKING: + + def on_get_value(self, value: A) -> A: + raise NotImplementedError + + def on_set_value(self, value: A) -> A: + raise NotImplementedError + + def on_create_value(self, value: A) -> A: + raise NotImplementedError + + def on_add_axis(self: V, axis_name: AxisName, axis_index: AxisIndex) -> V: + raise NotImplementedError + + def on_remove_axis(self: V, axis_name: AxisName, axis_index: AxisIndex) -> V: + raise NotImplementedError + + +def _variable_flatten(x: Variable[tp.Any], *, with_keys: bool): + attributes = vars(x).copy() + value = attributes.pop('value') + if with_keys: + node = (jtu.GetAttrKey('value'), value) + else: + node = value + + return (node,), attributes + + +def _variable_unflatten( + metadata: tp.Mapping[str, tp.Any], + children: tp.Tuple[A], + *, + cls: type[Variable[A]], +) -> Variable[A]: + return cls(children[0], **metadata) # type: ignore + + +jtu.register_pytree_with_keys( + Variable, + partial(_variable_flatten, with_keys=True), # type: ignore + partial(_variable_unflatten, cls=Variable), # type: ignore + flatten_func=partial(_variable_flatten, with_keys=False), # type: ignore +) + + +class Param(Variable[A]): + pass + + +class BatchStat(Variable[A]): + pass + + +class Cache(Variable[A]): + pass + + +class Intermediate(Variable[A]): + pass + + +class Rng(Variable[jax.Array]): + tag: str + + def __init__(self, value: jax.Array, *, tag: str, **metadata: tp.Any): + super().__init__(value, tag=tag, **metadata) + + def on_get_value(self, value: jax.Array): + self.value, value = jax.random.split(value) + return value + + +def with_metadata( + initializer: F, + set_value_hooks: tp.Union[SetValueHook[A], tp.Sequence[SetValueHook[A]]] = (), + get_value_hooks: tp.Union[SetValueHook[A], tp.Sequence[SetValueHook[A]]] = (), + create_value_hooks: tp.Union[ + CreateValueHook[A], tp.Sequence[CreateValueHook[A]] + ] = (), + add_axis_hooks: tp.Union[ + AddAxisHook['Variable[A]'], tp.Sequence[AddAxisHook['Variable[A]']] + ] = 
(), + remove_axis_hooks: tp.Union[ + RemoveAxisHook['Variable[A]'], + tp.Sequence[RemoveAxisHook['Variable[A]']], + ] = (), + **metadata: tp.Any, +) -> F: + if set_value_hooks: + if callable(set_value_hooks): + set_value_hooks = (set_value_hooks,) + else: + set_value_hooks = tuple(set_value_hooks) + else: + set_value_hooks = () + + if get_value_hooks: + if callable(get_value_hooks): + get_value_hooks = (get_value_hooks,) + else: + get_value_hooks = tuple(get_value_hooks) + else: + get_value_hooks = () + + if create_value_hooks: + if callable(create_value_hooks): + create_value_hooks = (create_value_hooks,) + else: + create_value_hooks = tuple(create_value_hooks) + else: + create_value_hooks = () + + if add_axis_hooks: + if callable(add_axis_hooks): + add_axis_hooks = (add_axis_hooks,) + else: + add_axis_hooks = tuple(add_axis_hooks) + else: + add_axis_hooks = () + + if remove_axis_hooks: + if callable(remove_axis_hooks): + remove_axis_hooks = (remove_axis_hooks,) + else: + remove_axis_hooks = tuple(remove_axis_hooks) + else: + remove_axis_hooks = () + + @functools.wraps(initializer) + def wrapper(*args): + return VariableMetadata( + initializer(*args), + set_value_hooks=set_value_hooks, + get_value_hooks=get_value_hooks, + create_value_hooks=create_value_hooks, + add_axis_hooks=add_axis_hooks, + remove_axis_hooks=remove_axis_hooks, + metadata=metadata, + ) + + return wrapper # type: ignore + + +def variable_type(name: str) -> tp.Type[Variable[tp.Any]]: + if name not in VariableTypeCache: + VariableTypeCache[name] = type(name, (Variable,), {}) + return VariableTypeCache[name] + + +# add known variable type names +VariableTypeCache['params'] = Param +VariableTypeCache['batch_stats'] = BatchStat +VariableTypeCache['cache'] = Cache +VariableTypeCache['intermediates'] = Intermediate diff --git a/flax/experimental/nnx/scripts/requirements.txt b/flax/experimental/nnx/scripts/requirements.txt new file mode 100644 index 0000000000..7a24b6e2b7 --- /dev/null +++ b/flax/experimental/nnx/scripts/requirements.txt @@ -0,0 +1 @@ +datasets>=2.12.0 diff --git a/flax/experimental/nnx/scripts/run-all-examples.bash b/flax/experimental/nnx/scripts/run-all-examples.bash new file mode 100644 index 0000000000..205e13acb1 --- /dev/null +++ b/flax/experimental/nnx/scripts/run-all-examples.bash @@ -0,0 +1,12 @@ +set -e + +cd ../../.. +source .venv/bin/activate +cd flax/experimental/nnx + +for f in $(find examples -name "*.py"); do + echo -e "\n---------------------------------" + echo "$f" + echo "---------------------------------" + python "$f" +done diff --git a/flax/experimental/nnx/tests/__init__.py b/flax/experimental/nnx/tests/__init__.py new file mode 100644 index 0000000000..e80ba0b35f --- /dev/null +++ b/flax/experimental/nnx/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
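A sketch of the `with_metadata` helper from variables.py above, which wraps an initializer so the variables it creates carry extra metadata; the `sharding` entry is a hypothetical example, and `with_metadata` is assumed to be re-exported from `flax.experimental.nnx`:

import jax
from flax.experimental import nnx

kernel_init = nnx.with_metadata(
    jax.nn.initializers.lecun_normal(),
    sharding=('data', None),  # hypothetical metadata entry
)

# the wrapper returns a VariableMetadata box that Variable merges in
meta = kernel_init(jax.random.PRNGKey(0), (2, 4))
param = nnx.Param(meta)
assert param.sharding == ('data', None)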
diff --git a/flax/experimental/nnx/tests/test_compatibility.py b/flax/experimental/nnx/tests/test_compatibility.py new file mode 100644 index 0000000000..50f12dc877 --- /dev/null +++ b/flax/experimental/nnx/tests/test_compatibility.py @@ -0,0 +1,22 @@ +import jax + +from flax import linen +from flax.experimental import nnx + + +class TestCompatibility: + def test_functional(self): + # Functional API for NNX Modules + functional = nnx.compatibility.functional(nnx.Linear)(32, 64) + state = functional.init(rngs=nnx.Rngs(0)) + x = jax.numpy.ones((1, 32)) + y, updates = functional.apply(state)(x) + + def test_linen_wrapper(self): + ## Wrapper API for Linen Modules + linen_module = linen.Dense(features=64) + x = jax.numpy.ones((1, 32)) + module = nnx.compatibility.LinenWrapper( + linen_module, x, rngs=nnx.Rngs(0) + ) # init + y = module(x) # apply diff --git a/flax/experimental/nnx/tests/test_containers.py b/flax/experimental/nnx/tests/test_containers.py new file mode 100644 index 0000000000..aff4dc062a --- /dev/null +++ b/flax/experimental/nnx/tests/test_containers.py @@ -0,0 +1,91 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from flax.experimental import nnx + + +class TestContainers: + def test_node_idempotence(self): + x = nnx.Variable(1) + x = nnx.Variable(x) + + assert isinstance(x, nnx.Variable) + + def test_variable_idempotence(self): + x = nnx.Variable(1) + x = nnx.Variable(x) + + assert isinstance(x, nnx.Variable) + assert x.value == 1 + + def test_variable_cannot_change_collection(self): + x = nnx.Param(1) + + with pytest.raises(ValueError, match='is not compatible with return type'): + x = nnx.BatchStat(x) + + def test_container_cannot_change_type(self): + x = nnx.Variable(1) + + with pytest.raises(ValueError, match='is not compatible with return type'): + x = nnx.Param(x) + + x = nnx.Param(2) + + with pytest.raises(ValueError, match='is not compatible with return type'): + x = nnx.Variable(x) + + def test_unbox(self): + x: nnx.Param[int] = nnx.Param( + 1, + get_value_hooks=[lambda c, x: x + 1, lambda c, x: x * 2], # type: ignore + ) + + assert x.get_value() == 4 + + def test_box(self): + x: nnx.Param[int] = nnx.Param( + 1, + set_value_hooks=[lambda c, x: x + 1, lambda c, x: x * 2], # type: ignore + ) + x = x.set_value(5) + + assert x.value == 12 + + def test_module_unbox(self): + class Foo(nnx.Module): + def __init__(self) -> None: + self.x = nnx.Param( + 1, get_value_hooks=[lambda c, x: x + 1, lambda c, x: x * 2] + ) + + module = Foo() + + assert module.x == 4 + assert vars(module)['x'].value == 1 + + def test_module_box(self): + class Foo(nnx.Module): + def __init__(self) -> None: + self.x = nnx.Param( + 1, set_value_hooks=[lambda c, x: x + 1, lambda c, x: x * 2] + ) + + module = Foo() + module.x = 5 + + assert module.x == 12 + assert vars(module)['x'].value == 12 diff --git a/flax/experimental/nnx/tests/test_helpers.py b/flax/experimental/nnx/tests/test_helpers.py new file mode 100644 index 0000000000..732358c0a4 ---
/dev/null +++ b/flax/experimental/nnx/tests/test_helpers.py @@ -0,0 +1,73 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax +import jax.numpy as jnp +import optax + +from flax.experimental import nnx + + +class TestHelpers: + def test_train_state(self): + m = nnx.Dict(a=nnx.Param(1), b=nnx.BatchStat(2)) + + params, batch_stats, moduledef = m.split(nnx.Param, nnx.BatchStat) + + state = nnx.TrainState( + moduledef, + params=params, + tx=optax.sgd(1.0), + batch_stats=nnx.TreeNode(batch_stats), + other=nnx.Variable(100), + int=200, + ) + + leaves = jax.tree_util.tree_leaves(state) + + assert 1 in leaves + assert 2 in leaves + assert 100 in leaves + assert 200 not in leaves + + def test_train_state_methods(self): + class Foo(nnx.Module): + def __init__(self, *, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 4, rngs=rngs) + self.batch_norm = nnx.BatchNorm(4, rngs=rngs) + + def __call__(self, x: jax.Array, train: bool) -> jax.Array: + x = self.linear(x) + x = self.batch_norm(x, use_running_average=not train) + return x + + module = Foo(rngs=nnx.Rngs(0)) + params, batch_stats, moduledef = module.split(nnx.Param, nnx.BatchStat) + + state = nnx.TrainState( + moduledef, + params=params, + tx=optax.sgd(1.0), + batch_stats=batch_stats, + ) + + x = jax.numpy.ones((1, 2)) + y, _updates = state.apply('params', 'batch_stats')(x, train=True) + + assert y.shape == (1, 4) + + # fake gradient + grads = jax.tree_map(jnp.ones_like, state.params) + # test apply_gradients + state = state.apply_gradients(grads) diff --git a/flax/experimental/nnx/tests/test_ids.py b/flax/experimental/nnx/tests/test_ids.py new file mode 100644 index 0000000000..28bf66cc4d --- /dev/null +++ b/flax/experimental/nnx/tests/test_ids.py @@ -0,0 +1,30 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +from flax.experimental.nnx.nnx import ids + + +class TestIds: + def test_hashable(self): + id1 = ids.uuid() + id2 = ids.uuid() + assert id1 == id1 + assert id1 != id2 + assert hash(id1) != hash(id2) + id1c = copy.copy(id1) + id1dc = copy.deepcopy(id1) + assert hash(id1) != hash(id1c) + assert hash(id1) != hash(id1dc) diff --git a/flax/experimental/nnx/tests/test_integration.py b/flax/experimental/nnx/tests/test_integration.py new file mode 100644 index 0000000000..87c48bae85 --- /dev/null +++ b/flax/experimental/nnx/tests/test_integration.py @@ -0,0 +1,250 @@ +# Copyright 2023 The Flax Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing as tp + +import jax +import jax.numpy as jnp +import numpy as np + +from flax.experimental import nnx + +A = tp.TypeVar('A') + + +class TestIntegration: + def test_shared_modules(self): + class Block(nnx.Module): + def __init__(self, linear: nnx.Linear, *, rngs): + self.linear = linear + self.bn = nnx.BatchNorm(2, rngs=rngs) + + def __call__(self, x): + x = self.linear(x) + x = self.bn(x) + return nnx.relu(x) + + class Model(nnx.Module): + def __init__(self, *, rngs): + shared = nnx.Linear(2, 2, rngs=rngs) + self.block1 = Block(shared, rngs=rngs) + self.block2 = Block(shared, rngs=rngs) + + def __call__(self, x): + x = self.block1(x) + x = self.block2(x) + return x + + @nnx.jit + def train_step(model: Model, x, y): + @nnx.grad + def loss_fn(model: Model): + with nnx.flags(use_running_average=False): + y_pred = model(x) + return jnp.mean((y - y_pred) ** 2) + + grads = loss_fn(model) + model.update( + jax.tree_map(lambda w, g: w - 0.1 * g, model.extract(nnx.Param), grads) + ) + + model = Model(rngs=nnx.Rngs(0)) + + x = np.random.uniform(size=(4, 2)) + y = np.random.uniform(size=(4, 2)) + + for _i in range(3): + train_step(model, x, y) + + assert model.block1.linear is model.block2.linear + assert model.block1.linear.bias is not None + assert model.block1.bn is not model.block2.bn + + def test_shared_modules_pure(self): + class Block(nnx.Module): + def __init__(self, linear: nnx.Linear, *, rngs: nnx.Rngs): + self.linear = linear + self.bn = nnx.BatchNorm(2, rngs=rngs) + + def __call__(self, x): + x = self.linear(x) + x = self.bn(x) + return nnx.relu(x) + + class Model(nnx.Module): + def __init__(self, *, rngs: nnx.Rngs): + shared = nnx.Linear(2, 2, rngs=rngs) + self.block1 = Block(shared, rngs=rngs) + self.block2 = Block(shared, rngs=rngs) + + def __call__(self, x): + x = self.block1(x) + x = self.block2(x) + return x + + @jax.jit + def train_step(state: nnx.State, moduledef: nnx.ModuleDef[Model], x, y): + model = moduledef.merge(state) + + @nnx.grad + def loss_fn(model: Model): + with nnx.flags(use_running_average=False): + y_pred = model(x) + return jnp.mean((y - y_pred) ** 2) + + grads = loss_fn(model) + model.update( + jax.tree_map(lambda w, g: w - 0.1 * g, model.extract(nnx.Param), grads) + ) + + return model.split() + + moduledef: nnx.ModuleDef[Model] + state, moduledef = Model(rngs=nnx.Rngs(0)).split() + + x = np.random.uniform(size=(4, 2)) + y = np.random.uniform(size=(4, 2)) + + for _i in range(3): + state, moduledef = train_step(state, moduledef, x, y) + + model = moduledef.merge(state) + + assert model.block1.linear.bias is not None + assert model.block2.linear.bias is not None + assert model.block1.linear.kernel is model.block2.linear.kernel + assert model.block1.linear.bias is model.block2.linear.bias + assert model.block1.bn is not model.block2.bn + + def test_stateful_example(self): + class State(nnx.Variable[A]): + pass + + class Linear(nnx.Module): + def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): + key 
= rngs.params() + self.w = nnx.Param(jax.random.uniform(key, (din, dout))) + self.b = nnx.Param(jnp.zeros((dout,))) + self.count = State(0) + + def __call__(self, x): + self.count += 1 + return x @ self.w + self.b[None] + + model = Linear(din=12, dout=2, rngs=nnx.Rngs(0)) + # forward pass + x = jnp.ones((8, 12)) + y = model(x) + assert model.count == 1 + + @nnx.jit + def train_step(model, x, y): + def loss_fn(model): + y_pred = model(x) + return jax.numpy.mean((y_pred - y) ** 2) + + # compute gradient + grads: nnx.State = nnx.grad(loss_fn, wrt=nnx.Param)(model) + # SGD update + model.update( + jax.tree_map(lambda w, g: w - 0.1 * g, model.extract(nnx.Param), grads) + ) + + # execute the training step + train_step(model, x, y) + assert model.count == 2 + + def test_functional_example(self): + class Count(nnx.Variable[A]): + pass + + class Linear(nnx.Module): + def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): + key = rngs.params() + self.w = nnx.Param(jax.random.uniform(key, (din, dout))) + self.b = nnx.Param(jnp.zeros((dout,))) + self.count = Count(0) + + def __call__(self, x): + self.count += 1 + return x @ self.w + self.b[None] + + model = Linear(din=12, dout=2, rngs=nnx.Rngs(0)) + # forward pass + x = jnp.ones((8, 12)) + y = model(x) + assert model.count == 1 + + params, counts, moduledef = model.split(nnx.Param, Count) + + @jax.jit + def train_step(params, counts, x, y): + def loss_fn(params): + y_pred, (updates, _) = moduledef.apply(params, counts)(x) + loss = jax.numpy.mean((y_pred - y) ** 2) + return loss, updates.extract(Count) + + # compute gradient + grads, counts = jax.grad(loss_fn, has_aux=True)(params) + # SGD update + params = jax.tree_map(lambda w, g: w - 0.1 * g, params, grads) + + return params, counts + + # execute the training step + params, counts = train_step(params, counts, x, y) + model = moduledef.merge(params, counts) + assert model.count == 2 + + def test_intermediates_example(self): + class Linear(nnx.Module): + def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): + key = rngs.params() + self.w = nnx.Param(jax.random.uniform(key, (din, dout))) + self.b = nnx.Param(jnp.zeros((dout,))) + + def __call__(self, x): + y = x @ self.w + self.b[None] + self.y = nnx.Intermediate(y) + return y + + model = Linear(12, 2, rngs=nnx.Rngs(0)) + + y = model(jnp.ones((8, 12))) + + intermediates = model.pop(nnx.Intermediate) + + assert 'y' in intermediates + + def test_intermediates_example_functional(self): + class Linear(nnx.Module): + def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): + key = rngs.params() + self.w = nnx.Param(jax.random.uniform(key, (din, dout))) + self.b = nnx.Param(jnp.zeros((dout,))) + + def __call__(self, x): + y = x @ self.w + self.b[None] + self.y = nnx.Intermediate(y) + return y + + model = Linear(12, 2, rngs=nnx.Rngs(0)) + + state, moduledef = model.split() + + y, (state, _) = moduledef.apply(state)(jnp.ones((8, 12))) + + intermediates, state = state.split(nnx.Intermediate, ...) + + assert 'y' in intermediates diff --git a/flax/experimental/nnx/tests/test_module.py b/flax/experimental/nnx/tests/test_module.py new file mode 100644 index 0000000000..2553ca9c5d --- /dev/null +++ b/flax/experimental/nnx/tests/test_module.py @@ -0,0 +1,621 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +from copy import deepcopy +from typing import Any, TypeVar + +import jax +import jax.numpy as jnp +import pytest + +from flax.experimental import nnx + +A = TypeVar('A') + + +class TestModule: + def test_has_module_state(self): + class Foo(nnx.Module): + ... + + foo = Foo() + + assert hasattr(foo, '_module__state') + + def test_trace_level(self): + m = nnx.Dict(a=nnx.Param(1)) + + @jax.jit + def f(): + with pytest.raises( + nnx.TraceContextError, + match='Cannot mutate Module from different trace level', + ): + m.a = 2 + + f() + + def test_tree_map(self): + m = nnx.Dict(a=nnx.Param(1)) + + state, static = m.split() + + state = jax.tree_map(lambda x: x + 1, state) + + def test_split_2(self): + m = nnx.Dict(a=nnx.Param(1)) + + empty, some, static = m.split(None, ...) + + some = jax.tree_map(lambda x: x + 1, some) + + def test_split_merge(self): + m = nnx.Dict(a=nnx.Param(1)) + + @jax.jit + def g(state: nnx.State, moduledef: nnx.ModuleDef[nnx.Dict[int]]): + m = moduledef.merge(state) + m.a = 2 + return m.split() + + state, moduledef = g(*m.split()) + m2 = moduledef.merge(state) + + assert m2.a == 2 + + def test_no_trace_level_error_on_grad(self): + # No trace level error occurs because jax doesn't update + # its top trace for grad. + m = nnx.Dict(a=nnx.Param(1.0)) + + @jax.grad + def f(_): + m.a = 2.0 + return 1.0 + + f(1.0) + + def test_trace_level_error_on_nnx_grad(self): + # error occurs because nnx updates its nnx_trace + # in nnx.grad. 
+ m = nnx.Dict(a=nnx.Param(1.0)) + + @nnx.grad + def f(_): + with pytest.raises( + nnx.TraceContextError, + match='Cannot mutate Module from different trace level', + ): + m.a = 2.0 + return 1.0 + + f(m) + + def test_call(self): + class Foo(nnx.Module): + def __init__(self, c: float, *, rngs: nnx.Rngs): + key = rngs.params() + self.w = nnx.Param(jax.random.uniform(key, ())) + self.c = c + + def __call__(self, x, *, rngs: nnx.Rngs): + key = rngs.e() + return self.w * x + jax.random.normal(key, ()) + self.c + + foo = Foo(c=1.0, rngs=nnx.Rngs(0)) + + y = foo(x=2.0, rngs=nnx.Rngs(e=1)) + + assert isinstance(y, jax.Array) + + def test_shared_module(self): + m1 = nnx.Dict(a=nnx.Param(1), b=nnx.Param(2)) + m2 = nnx.Dict(x=m1, y=m1, z=nnx.Param(3)) + + m3 = nnx.merge(m2.split()) + + assert m3['x'] is m3['y'] + assert m3['x']['a'] is m3['y']['a'] + assert m3['x']['b'] is m3['y']['b'] + + def test_module_graph(self): + class Foo(nnx.Module): + def __init__(self): + self.a = nnx.Param(1) + self.sub = self + + m = Foo() + + state, moduledef = m.split() + assert len(state) == 1 + + m2 = moduledef.merge(state) + assert m2 is m2.sub + + def test_deref_through_jit(self): + r1 = nnx.Variable(1) + r2 = nnx.Variable(2) + + m = m0 = nnx.Dict({'a': nnx.Sequence([r1, r2]), 'b': r1}) + + @jax.jit + def f(state: nnx.State, moduledef: nnx.ModuleDef[nnx.Dict[Any]]): + m = moduledef.merge(state) + + assert m['a'][0] is not m['b'] + assert m['a'][1] is not m['b'] + + return m.split() + + state, moduledef = f(*m.split()) + m = moduledef.merge(state) + + assert m['a'][0] is not m['b'] + assert m['a'][1] is not m['b'] + + # compare with pytree0 + assert m['a'][0] is not m0['a'][0] + assert m['a'][1] is not m0['a'][1] + assert m['b'] is not m0['b'] + + def test_cross_barrier(self): + m = nnx.Dict(a=nnx.Param(1)) + + @jax.jit + def g(state: nnx.State, moduledef: nnx.ModuleDef[nnx.Dict[int]]): + m = moduledef.merge(state) + m.a += 1 + return m.split() + + state, moduledef = g(*m.split()) + m2 = moduledef.merge(state) + assert m2 is not m + assert m.a == 1 + assert m2.a == 2 + + def test_no_rejit(self): + n = 0 + m = nnx.Dict(a=nnx.Param(1)) + + @jax.jit + def g(state_and_def): + nonlocal n + n += 1 + m = nnx.merge(state_and_def) + m.a += 1 + return m.split() + + m2 = nnx.merge(g(m.split())) + + assert n == 1 + assert m2 is not m + assert m.a == 1 + assert m2.a == 2 + + g(m.split()) + assert n == 1 + + g(m2.split()) + assert n == 1 + + m2.b = nnx.Param(10) + g(m2.split()) + + assert n == 2 + + def test_deref_number_of_fields(self): + r1 = nnx.Variable(1) + r2 = nnx.Variable(2) + v1 = 3 + m = nnx.Dict( + { + 'a': nnx.Sequence([r1, r2, v1]), + 'b': nnx.Dict({'c': r1, 'd': r2}), + } + ) + + p, moduledef = m.split() + assert len(p) == 4 + assert len(jax.tree_util.tree_leaves(p)) == 4 + + def test_deref_array_attributes_not_allowed(self): + # test arrays are nodes + r1 = nnx.Variable(1) + r2 = nnx.Variable(2) + v1 = jax.numpy.array(3) + + with pytest.raises( + ValueError, + match=f"Trying to assing a '{type(v1).__name__}' to the Module", + ): + m = nnx.Dict( + { + 'a': nnx.Sequence([r1, r2, v1]), + 'b': nnx.Dict({'c': r1, 'd': r2}), + } + ) + + def test_clone(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(1), nnx.Param(2), 3]), + b=nnx.Dict(c=nnx.Param(1), d=nnx.Param(2)), + ) + + m2 = m.clone() + + assert m is not m2 + assert m2.a[0] == m2.b.c + assert m2.a[1] == m2.b.d + + assert m.a[0] == m2.a[0] + assert m.a[1] == m2.a[1] + assert m.b.c == m2.b.c + assert m.b.d == m2.b.d + + def test_sow_basic(self): + class 
Foo(nnx.Module): + def __call__(self, x): + y = x + 1 + self.sow(nnx.Intermediate, 'y', y) + return y + + m = Foo() + y1 = m(2) + y2 = m(10) + + assert y1 == 3 + assert y2 == 11 + assert m.y == (3, 11) + + intermediates = m.pop(nnx.Intermediate) + + assert isinstance(intermediates.variables['y'], nnx.Intermediate) + assert intermediates['y'] == (3, 11) + + assert hasattr(m, 'y') + assert m.y is nnx.EMPTY + + def test_sow_existing_non_variable_field(self): + class Foo(nnx.Module): + def __init__(self) -> None: + self.y = 10 + + def __call__(self, x): + y = x + 1 + self.sow(nnx.Intermediate, 'y', y) + return y + + m = Foo() + + with pytest.raises(ValueError, match='to be a Variable, got'): + m(2) + + def test_sow_wrong_collection(self): + class Foo(nnx.Module): + def __init__(self) -> None: + self.y = nnx.Param(10) + + def __call__(self, x): + y = x + 1 + self.sow(nnx.Intermediate, 'y', y) + return y + + m = Foo() + + with pytest.raises(ValueError, match='to be of type'): + m(2) + + def test_update_static_state(self): + class Foo(nnx.Module): + def add_field(self): + self.a = 1 + + m1 = Foo() + m2 = Foo() + m2.add_field() + + m1.update(m2) + + assert m1.a == 1 + + def test_update_moduledef(self): + class Foo(nnx.Module): + def add_field(self): + self.a = 1 + + m1 = Foo() + m2 = Foo() + m2.add_field() + + m1.update(m2.get_moduledef()) + + assert m1.a == 1 + + def test_update_static_state_submodules(self): + class Bar(nnx.Module): + def __init__(self) -> None: + self.x = 1 + + def add_field(self): + self.y = 2 + + class Foo(nnx.Module): + def __init__(self) -> None: + self.a = Bar() + self.b = self.a + + m1 = Foo() + m2 = Foo() + m2.a.add_field() + + m1.update(m2) + + assert m1.a.x == 1 + assert m1.a.y == 2 + assert m1.b.x == 1 + assert m1.b.y == 2 + + def test_update_new_submodule(self): + class Bar(nnx.Module): + def __init__(self) -> None: + self.x = 1 + + class Foo(nnx.Module): + def __init__(self) -> None: + self.a = Bar() + + def add_module(self): + self.b = Bar() + + m1 = Foo() + m2 = Foo() + m2.add_module() + + m1.update(m2) + + assert m1.a.x == 1 + assert m1.b.x == 1 + + def test_update_update_submodule(self): + class Bar(nnx.Module): + def __init__(self) -> None: + self.x = 1 + + class Foo(nnx.Module): + def __init__(self) -> None: + self.a = Bar() + self.b = self.a + + m1 = Foo() + m2 = Foo() + m2.a.x = 2 + + m1.update(m2) + + assert m1.a.x == 2 + assert m1.b.x == 2 + + def test_update_add_shared_error(self): + class Bar(nnx.Module): + def __init__(self) -> None: + self.x = 1 + + class Foo(nnx.Module): + def __init__(self) -> None: + self.a = Bar() + self.b = self.a + + def add_submodule(self): + self.c = self.a + + m1 = Foo() + m2 = Foo() + m2.add_submodule() + + assert hasattr(m2, 'c') + + with pytest.raises( + ValueError, match='Trying to add a new submodule at path' + ): + m1.update(m2) + + def test_update_add_shared_error_new_first(self): + class Bar(nnx.Module): + def __init__(self) -> None: + self.x = 1 + + class Foo(nnx.Module): + def __init__(self) -> None: + self.b = Bar() + self.c = self.b + + def add_submodule(self): + self.a = self.b + + m1 = Foo() + m2 = Foo() + m2.add_submodule() + + assert hasattr(m2, 'a') + + m2 = m2.clone() # clone to sort the fields + + with pytest.raises( + ValueError, match='Trying to update a submodule at path' + ): + m1.update(m2) + + def test_create_abstract(self): + linear = nnx.Linear.create_abstract(2, 3, rngs=nnx.Rngs(0)) + + assert linear.kernel == jax.ShapeDtypeStruct((2, 3), jnp.float32) + assert linear.bias == 
jax.ShapeDtypeStruct((3,), jnp.float32) + + def test_deepcopy(self): + class Foo(nnx.Module): + def __init__(self) -> None: + self.a = nnx.Param(1) + self.b = [1, 2, 3] + self.c = nnx.Param(jnp.array([1.0])) + self.self = self + + m1 = Foo() + m2 = deepcopy(m1) + + assert m1.a == m2.a + assert vars(m1)['a'] is not vars(m2)['a'] + assert m1.b is not m2.b + assert m1.c is not m2.c + assert m1.self is m1 + + +class TestModulePytree: + def test_tree_map(self): + class Foo(nnx.Module, experimental_pytree=True): + def __init__(self): + self.node = nnx.Param(1) + self.static = 1 + + m = Foo() + + m = jax.tree_map(lambda x: x + 1, m) + + assert m.node == 2 + assert m.static == 1 + + +class TestModuleDataclass: + def test_basic(self): + @nnx.dataclass + class Foo(nnx.Module): + a: int + b: int = nnx.treenode_field() + c: int = nnx.param_field() + d: int = nnx.variable_field(nnx.BatchStat) + e: int + f: int + + m = Foo( + a=1, # static + b=2, # node + c=3, # param + d=4, # var + e=5, # static int + f=nnx.Variable(6), # test that we can pass in a node + ) + + state, moduledef = m.split() + + assert len(state) == 4 + assert state.variables['b'] == nnx.TreeNode(2) + assert state.variables['c'] == nnx.Param(3) + assert state.variables['d'] == nnx.BatchStat(4) + assert state.variables['f'] == nnx.Variable(6) + + def test_no_override(self): + @nnx.dataclass + class Foo(nnx.Module): + a: int = nnx.treenode_field() + + with pytest.raises(ValueError, match='is not compatible with return type'): + _m = Foo(a=nnx.Param(1)) + + _m = Foo(a=nnx.TreeNode(1)) + + def test_context_none_after_init(self): + @dataclasses.dataclass + class DFoo(nnx.Module): + din: int + dout: int + rngs: nnx.Rngs + + def __post_init__(self): + self.bar = nnx.Linear(self.din, self.dout, rngs=self.rngs) + + def __call__(self, x): + return self.bar(x) + + m = DFoo(1, 1, rngs=nnx.Rngs(0)) + + assert hasattr(m, 'bar') + assert m.rngs is None + + def test_setup_is_called(self): + @dataclasses.dataclass + class DFoo(nnx.Module): + din: int + dout: int + rngs: nnx.Rngs + + def setup(self): + self.bar = nnx.Linear(self.din, self.dout, rngs=self.rngs) + + def __call__(self, x): + return self.bar(x) + + m = DFoo(1, 1, rngs=nnx.Rngs(0)) + + assert hasattr(m, 'bar') + assert m.rngs is None + + +class TestModuleDef: + def test_apply(self): + class Foo(nnx.Module): + def __init__(self, c: float, *, rngs: nnx.Rngs): + self.w = nnx.Param(jax.random.uniform(rngs.params(), ())) + self.c = c + + def __call__(self, x, *, rngs: nnx.Rngs): + key = rngs.e() + return self.w * x + jax.random.normal(key, ()) + self.c + + rngs = nnx.Rngs(0) + foo = Foo(c=1.0, rngs=rngs) + + states, moduledef = foo.split() + + assert isinstance(states, nnx.State) + assert isinstance(states.variables['w'], nnx.Param) + # assert isinstance(states["c"], jax.Array) + + y, _updates = moduledef.apply(states)(x=2.0, rngs=nnx.Rngs(e=1)) + + assert isinstance(y, jax.Array) + + def test_derefed_mod_apply(self): + class Foo(nnx.Module): + def __init__(self, c: float, *, rngs: nnx.Rngs): + self.w = nnx.Param( + jax.random.uniform(rngs.params(), ()), + ) + self.c = nnx.Variable(c) + + def __call__(self, x, *, rngs: nnx.Rngs): + key = rngs.e() + return self.w * x + jax.random.normal(key, ()) + self.c + + foo = Foo(c=1.0, rngs=nnx.Rngs(0)) + + state, moduledef = foo.split() + + assert isinstance(moduledef, nnx.ModuleDef) + assert isinstance(state, nnx.State) + assert isinstance(state.variables['w'], nnx.Param) + assert isinstance(state.variables['c'], nnx.Variable) + + y, (state, moduledef) = 
moduledef.apply(state)(x=2.0, rngs=nnx.Rngs(e=1)) + + assert isinstance(y, jax.Array) diff --git a/flax/experimental/nnx/tests/test_partitioning.py b/flax/experimental/nnx/tests/test_partitioning.py new file mode 100644 index 0000000000..fb1862c080 --- /dev/null +++ b/flax/experimental/nnx/tests/test_partitioning.py @@ -0,0 +1,159 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import jax +import pytest + +from flax.experimental import nnx + + +class TestPartitioning: + def test_partition(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(1), nnx.BatchStat(2)]), + b=nnx.Param(2), + c=100, + ) + + params, rest, moduledef = m.split(nnx.Param, ...) + + assert len(params) == 2 + assert len(rest) == 1 + + # check params + assert params['a/0'] == m.a[0] + assert params['b'] == m.b + + # check rest + assert rest['a/1'] == m.a[1] + + m2 = moduledef.merge(params, rest) + + assert m2.a[0] == m.a[0] + assert m2.a[1] == m.a[1] + assert m2.b == m.b + assert m2.c == 100 + + def test_complete_partitioning(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(1), nnx.Param(2), nnx.Variable(3)]), + b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), + ) + + # no error + m.split(nnx.Param, nnx.BatchStat, nnx.Variable) + + def test_complete_partitioning_plus_ellipsis(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(1), nnx.Param(2), nnx.Variable(3)]), + b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), + ) + + # no error if additional ... is passed at the end + m.split(nnx.Param, nnx.BatchStat, nnx.Variable, ...) 
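The tests above pin down the filter contract of `Module.split`: the filters must either cover every `Variable` in the module or end with `...` to catch the remainder. A compact restatement of the round trip they exercise, using only calls that appear in these tests:

    import jax
    from flax.experimental import nnx

    m = nnx.Dict(a=nnx.Param(1.0), b=nnx.BatchStat(2.0))

    # Partition state by Variable type; the trailing `...` collects the rest.
    params, rest, moduledef = m.split(nnx.Param, ...)

    # Each partition is a State pytree, so plain JAX tree utilities apply.
    params = jax.tree_map(lambda x: x * 2, params)

    # merge reassembles a module from the partitions.
    m2 = moduledef.merge(params, rest)
    assert m2.a == 2.0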
+
+  def test_incomplete_partition_error(self):
+    m = nnx.Dict(
+      a=nnx.Sequence([nnx.Param(1), nnx.Param(2), nnx.Variable(3)]),
+      b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)),
+    )
+
+    with pytest.raises(
+      ValueError, match='Non-exhaustive filters, got a non-empty remainder'
+    ):
+      m.split(nnx.Param)
+
+  def test_ellipsis_not_last_error(self):
+    m = nnx.Dict(
+      a=nnx.Sequence([nnx.Param(1), nnx.Param(2), nnx.Variable(3)]),
+      b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)),
+    )
+
+    with pytest.raises(
+      ValueError, match='Ellipsis `...` can only be used as the last filter,'
+    ):
+      m.split(..., nnx.Param)
+
+  def test_update_from(self):
+    m = nnx.Dict(
+      a=nnx.Sequence([nnx.Param(1), nnx.BatchStat(3)]),
+      b=nnx.Param(2),
+      c=100,
+    )
+
+    state = m.split()[0]
+    state = jax.tree_map(lambda x: x * 2, state)
+
+    m.update(state)
+
+    assert m.a[0] == 2
+    assert m.a[1] == 6
+    assert m.b == 4
+    assert m.c == 100
+
+  def test_update_from_with_array_leaf(self):
+    m = nnx.Dict(
+      a=nnx.Sequence([nnx.Param(1), nnx.BatchStat(3)]),
+      b=nnx.Param(2),
+      c=nnx.Variable(jax.numpy.array(100)),
+    )
+
+    state, moduledef = m.split()
+    state = jax.tree_map(lambda x: x * 2, state)
+
+    m.update(state)
+
+    assert m.a[0] == 2
+    assert m.a[1] == 6
+    assert m.b == 4
+    assert m.c == 200
+
+  def test_grad_example(self):
+    m = nnx.Dict(
+      a=nnx.Sequence([nnx.Param(1.0), nnx.BatchStat(-10)]),
+      b=nnx.Param(2.0),
+      c=100,
+    )
+
+    params = m.extract(nnx.Param)
+
+    def loss(params):
+      return sum(2 * p for p in jax.tree_util.tree_leaves(params))
+
+    grads = jax.grad(loss)(params)
+    m.update(grads)
+
+    assert m.a[0] == 2.0
+    assert m.a[1] == -10
+    assert m.b == 2.0
+    assert m.c == 100
+
+  def test_get_partition(self):
+    m = nnx.Dict(
+      a=nnx.Sequence([nnx.Param(10.0), nnx.Param(20.0)]),
+      b=nnx.Param(10.0),
+      c=7,
+      d=5.0,
+    )
+
+    # test Variables not shared
+    assert vars(m.a)['0'] is not vars(m)['b']
+
+    state = m.extract(nnx.Variable)
+    assert state['a/0'] == m.a[0]
+    assert state['a/1'] == m.a[1]
+    assert state['b'] == m.b
+    assert state.variables['b'] is not state.variables['a/0']
+    assert len(state) == 3
diff --git a/flax/experimental/nnx/tests/test_pytree.py b/flax/experimental/nnx/tests/test_pytree.py
new file mode 100644
index 0000000000..37b3c2559f
--- /dev/null
+++ b/flax/experimental/nnx/tests/test_pytree.py
@@ -0,0 +1,264 @@
+# Copyright 2023 The Flax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
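Before the pytree tests, note the training-loop idiom that test_grad_example and test_update_from compose: `extract` pulls a filtered `State` out of the module, plain `jax.grad` differentiates it, and `update` writes transformed leaves back in place. A minimal SGD step in that style (hypothetical toy module, same API calls as the tests above):

    import jax
    from flax.experimental import nnx

    m = nnx.Dict(w=nnx.Param(1.0), s=nnx.BatchStat(0.0))
    params = m.extract(nnx.Param)  # only the Params; the BatchStat stays behind

    def loss(params):
      return sum(jax.tree_util.tree_leaves(params))

    grads = jax.grad(loss)(params)

    # Write the updated leaves back into the module in place.
    m.update(jax.tree_map(lambda w, g: w - 0.1 * g, params, grads))
    assert m.w == 0.9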
+ +from typing import Generic, TypeVar + +import jax +import pytest + +from flax import serialization +from flax.experimental import nnx + + +class TestPytree: + def test_immutable_pytree(self): + class Foo(nnx.Pytree): + def __init__(self, y) -> None: + self.x = 2 + self.y = nnx.Variable(y) + + pytree = Foo(y=3) + + leaves = jax.tree_util.tree_leaves(pytree) + assert leaves == [3] + + pytree = jax.tree_map(lambda x: x * 2, pytree) + assert pytree.x == 2 + assert pytree.y == 6 + + pytree = pytree.replace(x=3) + assert pytree.x == 3 + assert pytree.y == 6 + + with pytest.raises( + AttributeError, match='is immutable, trying to update field' + ): + pytree.x = 4 + + def test_immutable_pytree_dataclass(self): + @nnx.dataclass(frozen=True) + class Foo(nnx.Pytree): + y: int = nnx.treenode_field() + x: int = nnx.field(default=2) + + pytree = Foo(y=3) + + leaves = jax.tree_util.tree_leaves(pytree) + assert leaves == [3] + + pytree = jax.tree_map(lambda x: x * 2, pytree) + assert pytree.x == 2 + assert pytree.y == 6 + + pytree = pytree.replace(x=3) + assert pytree.x == 3 + assert pytree.y == 6 + + with pytest.raises(AttributeError, match='cannot assign to field'): + pytree.x = 4 + + def test_jit(self): + @nnx.dataclass + class Foo(nnx.Pytree): + a: int = nnx.treenode_field() + b: int = nnx.field() + + module = Foo(a=1, b=2) + + @jax.jit + def f(m: Foo): + return m.a + m.b + + assert f(module) == 3 + + def test_flax_serialization(self): + class Bar(nnx.Pytree): + def __init__(self, a, b): + self.a = a + self.b = nnx.Variable(b) + + @nnx.dataclass + class Foo(nnx.Pytree): + bar: Bar + c: int = nnx.treenode_field() + d: int = nnx.field() + + foo: Foo = Foo(bar=Bar(a=1, b=2), c=3, d=4) + + state_dict = serialization.to_state_dict(foo) + + assert state_dict == { + 'bar': { + 'b': 2, + }, + 'c': 3, + } + + state_dict['bar']['b'] = 5 + + foo = serialization.from_state_dict(foo, state_dict) + + assert foo.bar.b == 5 + + del state_dict['bar']['b'] + + with pytest.raises(ValueError, match='Missing field'): + serialization.from_state_dict(foo, state_dict) + + state_dict['bar']['b'] = 5 + + # add unknown field + state_dict['x'] = 6 + + with pytest.raises(ValueError, match='Unknown field'): + serialization.from_state_dict(foo, state_dict) + + def test_generics(self): + T = TypeVar('T') + + class MyClass(nnx.Pytree, Generic[T]): + def __init__(self, x: T): + self.x = x + + MyClass[int] + + def test_key_paths(self): + @nnx.dataclass + class Bar(nnx.Pytree): + a: int = nnx.treenode_field(default=1) + b: int = nnx.field(default=2) + + @nnx.dataclass + class Foo(nnx.Pytree): + x: int = nnx.treenode_field(default=3) + y: int = nnx.field(default=4) + z: Bar = nnx.treenode_field(default_factory=Bar) + + foo = Foo() + + path_values, treedef = jax.tree_util.tree_flatten_with_path(foo) + path_values = [(list(map(str, path)), value) for path, value in path_values] + + assert path_values[0] == (['.x', '.value'], 3) + assert path_values[1] == (['.z', '.value', '.a', '.value'], 1) + + def test_replace_unknown_fields_error(self): + class Foo(nnx.Pytree): + pass + + with pytest.raises(ValueError, match='Trying to replace unknown fields'): + Foo().replace(y=1) + + def test_dataclass_inheritance(self): + @nnx.dataclass + class A(nnx.Pytree): + a: int = nnx.treenode_field(default=1) + b: int = nnx.field(default=2) + + @nnx.dataclass + class B(A): + c: int = nnx.treenode_field(default=3) + + pytree = B() + leaves = jax.tree_util.tree_leaves(pytree) + assert leaves == [1, 3] + + def test_pytree_with_new(self): + class A(nnx.Pytree): 
+ def __init__(self, a): + self.a = a + + def __new__(cls, a): + return super().__new__(cls) + + pytree = A(a=1) + + pytree = jax.tree_map(lambda x: x * 2, pytree) + + def test_deterministic_order(self): + class A(nnx.Pytree): + def __init__(self, order: bool): + if order: + self.a = 1 + self.b = 2 + else: + self.b = 2 + self.a = 1 + + p1 = A(order=True) + p2 = A(order=False) + + leaves1 = jax.tree_util.tree_leaves(p1) + leaves2 = jax.tree_util.tree_leaves(p2) + + assert leaves1 == leaves2 + + +class TestMutablePytree: + def test_pytree(self): + class Foo(nnx.Pytree, mutable=True): + def __init__(self, y) -> None: + self.x = 2 + self.y = nnx.Variable(y) + + pytree = Foo(y=3) + + leaves = jax.tree_util.tree_leaves(pytree) + assert leaves == [3] + + pytree = jax.tree_map(lambda x: x * 2, pytree) + assert pytree.x == 2 + assert pytree.y == 6 + + pytree = pytree.replace(x=3) + assert pytree.x == 3 + assert pytree.y == 6 + + # test mutation + pytree.x = 4 + assert pytree.x == 4 + + def test_no_new_fields_after_init(self): + class Foo(nnx.Pytree, mutable=True): + def __init__(self, x): + self.x = nnx.Variable(x) + + foo = Foo(x=1) + foo.x = 2 + + with pytest.raises(AttributeError, match=r'Cannot add new fields to'): + foo.y = 2 + + def test_pytree_dataclass(self): + @nnx.dataclass + class Foo(nnx.Pytree, mutable=True): + y: int = nnx.treenode_field() + x: int = nnx.field(default=2) + + pytree: Foo = Foo(y=3) + + leaves = jax.tree_util.tree_leaves(pytree) + assert leaves == [3] + + pytree = jax.tree_map(lambda x: x * 2, pytree) + assert pytree.x == 2 + assert pytree.y == 6 + + pytree = pytree.replace(x=3) + assert pytree.x == 3 + assert pytree.y == 6 + + # test mutation + pytree.x = 4 + assert pytree.x == 4 diff --git a/flax/experimental/nnx/tests/test_rngs.py b/flax/experimental/nnx/tests/test_rngs.py new file mode 100644 index 0000000000..adede26d52 --- /dev/null +++ b/flax/experimental/nnx/tests/test_rngs.py @@ -0,0 +1,180 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
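The Rngs tests below hinge on stream semantics: each named stream keeps its seed key fixed and folds a per-call count into it, and `fork()` snapshots the streams so a new `Rngs` can continue independently. A short sketch of the usage they exercise, using no calls beyond those in the tests:

    from flax.experimental import nnx

    rngs = nnx.Rngs(params=0, dropout=1)

    # Each call derives a fresh key; the underlying seed key is unchanged.
    k1 = rngs.params()
    k2 = rngs.params()  # distinct from k1

    # fork() captures the current streams for hand-off across a boundary.
    keys = rngs.fork()
    rngs2 = nnx.Rngs(keys)  # continues from the forked state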
+ +from typing import Any + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +from flax.experimental import nnx +from flax.experimental.nnx.nnx.rnglib import _stable_hash + + +class TestRngs: + def test_hash(self): + _hash = _stable_hash('hi') + assert isinstance(_hash, int) + + def test_call(self): + rngs = nnx.Rngs(0) + key = rngs() + + def test_fallback(self): + rngs = nnx.Rngs(0) + key = rngs.dropout() + + def test_fallback_error_no_default(self): + rngs = nnx.Rngs(some_name=0) + with pytest.raises(ValueError, match='No RNG named'): + key = rngs.dropout() + + def test_rng_stream(self): + key0 = jax.random.key(0) + rngs = nnx.Rngs(params=key0) + assert rngs._rngs['params'].counts[-1] == 0 + + key1 = rngs.params() + assert rngs._rngs['params'].counts[-1] == 1 + assert rngs._rngs['params'].key is key0 + assert not np.equal(key0, key1).all() + + key2 = rngs.params() + assert rngs._rngs['params'].counts[-1] == 2 + assert rngs._rngs['params'].key is key0 + assert not np.equal(key1, key2).all() + + def test_rng_fork(self): + key0 = jax.random.key(0) + rngs1 = nnx.Rngs(params=key0) + rngs2 = nnx.Rngs(rngs1.fork()) + + assert rngs2._rngs['params'].counts == [0, 0] + + key1 = rngs1.params() + key2 = rngs2.params() + + assert not np.equal(key1, key2).all() + + def test_rng_trace_level_constraints(self): + rngs = nnx.Rngs(0) + + @jax.jit + def f(): + with pytest.raises( + nnx.TraceContextError, + match='Cannot use Rngs from a different trace level', + ): + rngs.params() + + f() + + @jax.jit + def f(): + with pytest.raises( + nnx.TraceContextError, + match='Cannot use Rngs from a different trace level', + ): + rngs.fork() + + f() + + rngs1: Any = None + + @jax.jit + def g(): + nonlocal rngs1 + rngs1 = nnx.Rngs(1) + + g() + + assert isinstance(rngs1, nnx.Rngs) + with pytest.raises( + nnx.TraceContextError, + match='Cannot use Rngs from a different trace level', + ): + rngs1.params() + + def test_partition_merge(self): + rngs = nnx.Rngs(dropout=0) + + keys = rngs.fork() + + assert 'dropout' in keys + assert keys['dropout'].counts == [0, 0] + + rngs2 = nnx.Rngs(keys) + + key1 = rngs.dropout() + key2 = rngs2.dropout() + assert not np.equal(key1, key2).all() + + rngs3 = nnx.Rngs(keys) + key3 = rngs3.dropout() + assert np.equal(key2, key3).all() + + def test_fork_broadcast(self): + rngs = nnx.Rngs(params=0, dropout=1) + jax.random.key + + keys = rngs.fork() # all broadcast + + assert keys['params'].key.shape == () + assert keys['dropout'].key.shape == () + assert jnp.allclose(keys['params'].key, jax.random.key(0)) + assert jnp.allclose(keys['dropout'].key, jax.random.key(1)) + + def test_fork_split(self): + rngs = nnx.Rngs(params=0, dropout=1) + keys = rngs.fork(4) # split all + + assert keys['params'].key.shape == (4,) + assert keys['dropout'].key.shape == (4,) + + def test_fork_split_and_broadcast(self): + rngs = nnx.Rngs(params=0, dropout=1) + splits, broadcasts = rngs.fork(params=4, dropout=None) + + assert splits['params'].key.shape == (4,) + assert broadcasts['dropout'].key.shape == () + + def test_fork_filters(self): + rngs = nnx.Rngs(params=0, dropout=1) + splits, broadcasts = rngs.fork({'params': 4}) + + assert splits['params'].key.shape == (4,) + assert broadcasts['dropout'].key.shape == () + + def test_fork_multidimensional_split(self): + rngs = nnx.Rngs(params=0, dropout=1) + keys = rngs.fork((4, None, 3)) # split all + + assert keys['params'].key.shape == (4, 1, 3) + assert keys['dropout'].key.shape == (4, 1, 3) + + def test_fork_multidimensional_split_mixed(self): + 
rngs = nnx.Rngs(params=0, dropout=1) + splits, broadcasts = rngs.fork(params=(4, None, 3)) # split all + + assert splits['params'].key.shape == (4, 1, 3) + assert broadcasts['dropout'].key.shape == () + + def test_rng_stream_pytree(self): + rngs = nnx.Rngs(params=0, dropout=1) + stream = rngs.fork()['params'] + + stream2 = jax.tree_map(lambda x: x, stream) + + assert stream.key is stream2.key diff --git a/flax/experimental/nnx/tests/test_spmd.py b/flax/experimental/nnx/tests/test_spmd.py new file mode 100644 index 0000000000..96d9cd065c --- /dev/null +++ b/flax/experimental/nnx/tests/test_spmd.py @@ -0,0 +1,75 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax +import jax.numpy as jnp +import optax +from jax._src import test_util as jtu +from jax.experimental import mesh_utils +from jax.sharding import Mesh, PartitionSpec + +from flax.experimental import nnx + + +class TestSPMD: + @jtu.skip_on_devices('cpu', 'gpu') + def test_init(self): + class Foo(nnx.Module): + def __init__(self): + self.w = nnx.Param( + nnx.with_partitioning( + lambda: jnp.ones((8, 2)), + sharding=('model', 'data'), + )() + ) + + def __call__(self, x): + return x @ self.w + + @jax.jit + def create_module(): + return Foo().split() + + mesh = Mesh(mesh_utils.create_device_mesh((2, 2)), ('model', 'data')) + + with mesh: + m: Foo = nnx.merge(create_module()) + + assert m.w.shape == (8, 2) + assert m.w.sharding.shard_shape(m.w.shape) == (4, 1) + + def test_get_partition_spec(self): + class Foo(nnx.Module): + def __init__(self): + self.w = nnx.Param( + nnx.with_partitioning( + lambda: jnp.ones((8, 2)), + sharding=('row', 'col'), + )() + ) + + def __call__(self, x): + return x @ self.w + + params, moduledef = Foo().split() + state = nnx.TrainState( + moduledef, + params=params, + tx=optax.adam(1e-3), + ) + state_spec = nnx.get_partition_spec(state) + + assert state_spec.params['w'] == PartitionSpec('row', 'col') + assert state_spec.opt_state[0].mu['w'] == PartitionSpec('row', 'col') + assert state_spec.opt_state[0].nu['w'] == PartitionSpec('row', 'col') diff --git a/flax/experimental/nnx/tests/test_transforms.py b/flax/experimental/nnx/tests/test_transforms.py new file mode 100644 index 0000000000..ff23c82fac --- /dev/null +++ b/flax/experimental/nnx/tests/test_transforms.py @@ -0,0 +1,656 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
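Two pieces carry the SPMD story tested above: `nnx.with_partitioning` attaches sharding metadata to an initializer's output, and `nnx.get_partition_spec` later mirrors the pytree with `PartitionSpec`s recovered from that metadata. A condensed restatement of test_get_partition_spec (same calls, minus the assertions on optimizer state):

    import jax.numpy as jnp
    import optax
    from jax.sharding import PartitionSpec

    from flax.experimental import nnx

    class Foo(nnx.Module):
      def __init__(self):
        # Sharding metadata rides along with the initialized Param.
        self.w = nnx.Param(
          nnx.with_partitioning(
            lambda: jnp.ones((8, 2)),
            sharding=('row', 'col'),
          )()
        )

    params, moduledef = Foo().split()
    state = nnx.TrainState(moduledef, params=params, tx=optax.adam(1e-3))

    # Values are replaced by PartitionSpecs derived from the metadata.
    state_spec = nnx.get_partition_spec(state)
    assert state_spec.params['w'] == PartitionSpec('row', 'col')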
+ +import typing as tp +from functools import partial + +import jax +import jax.numpy as jnp +import pytest + +from flax.experimental import nnx + + +class TestJIT: + def test_jit(self): + m = nnx.Dict(a=nnx.Param(1)) + + @nnx.jit + def g(m: nnx.Dict): + m.a = 2 + return 1.0 + + out = g(m) + + assert m.a == 2 + assert out == 1.0 + + def test_jit_on_init(self): + n = 0 + + class Foo(nnx.Module): + @partial(nnx.jit, static_argnums=(1, 2)) + def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): + nonlocal n + n += 1 + + key = rngs.params() + self.w = nnx.Param(jax.random.normal(key, shape=(din, dout))) + self.din = din + self.dout = dout + + m = Foo(2, 3, rngs=nnx.Rngs(0)) + assert n == 1 + assert m.w.shape == (2, 3) + assert m.din == 2 + assert m.dout == 3 + assert isinstance(m.din, int) + assert isinstance(m.dout, int) + assert isinstance(m.w, jax.Array) + + m = Foo(2, 3, rngs=nnx.Rngs(0)) + assert n == 1 + + def test_jit_on_call(self): + n = 0 + + class Foo(nnx.Module): + def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): + key = rngs.params() + self.w = nnx.Param(jax.random.normal(key, shape=(din, dout))) + self.din = din + self.dout = dout + + @nnx.jit + def __call__(self, x: jax.Array) -> jax.Array: + nonlocal n + n += 1 + return jnp.dot(x, self.w) + + m = Foo(2, 3, rngs=nnx.Rngs(0)) + assert m.w.shape == (2, 3) + assert m.din == 2 + assert m.dout == 3 + assert isinstance(m.din, int) + assert isinstance(m.dout, int) + assert isinstance(m.w, jax.Array) + + y = m(jnp.ones((1, 2))) + assert y.shape == (1, 3) + assert n == 1 + y = m(jnp.ones((1, 2))) + assert n == 1 + + def test_jit_combinator(self): + n = 0 + + class Foo(nnx.Module): + def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): + key = rngs.params() + self.w = nnx.Param(jax.random.normal(key, shape=(din, dout))) + self.din = din + self.dout = dout + + @nnx.jit + def __call__(self, x: jax.Array) -> jax.Array: + nonlocal n + n += 1 + return jnp.dot(x, self.w) + + m = nnx.JIT(Foo)(2, 3, rngs=nnx.Rngs(0)) + + y = m(jnp.ones((1, 2))) + assert y.shape == (1, 3) + assert n == 1 + y = m(jnp.ones((1, 2))) + assert n == 1 + + +class TestGrad: + def test_grad(self): + p1 = nnx.Param(10.0) + p2 = nnx.Param(20.0) + + m = nnx.Dict( + a=nnx.Sequence([p1, p2]), + b=p1, + c=7, + d=5.0, + ) + + @nnx.grad + def f(m: nnx.Dict): + # sum all params + return m['a'][0] + m['a'][1] + m['b'] + + grads = f(m) + + assert isinstance(grads, nnx.State) + assert grads['a/0'] == 1.0 + assert isinstance(grads.variables['a/0'], nnx.Variable) + assert grads['a/1'] == 1.0 + assert isinstance(grads.variables['a/1'], nnx.Variable) + assert grads['b'] == 1.0 + assert isinstance(grads.variables['b'], nnx.Variable) + assert len(grads) == 3 + + m.update(grads) + + assert m['a'][0] == 1.0 + assert m['a'][1] == 1.0 + assert m['b'] == 1.0 + assert m['c'] == 7 + assert m['d'] == 5.0 + + def test_grad_with_multiple_ref_types(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(10.0), nnx.BatchStat(20.0)]), + b=nnx.Param(10.0), + c=7, + d=5.0, + ) + + @nnx.grad + def f(m: nnx.Dict): + # sum all params + return m.a[0] + m.a[1] + m.b + + grads = f(m) + + assert isinstance(grads, nnx.State) + assert grads['a/0'] == 1.0 + assert isinstance(grads.variables['a/0'], nnx.Param) + assert len(grads) == 2 + + m.update(grads) + + assert m.a[0] == 1.0 + assert m.a[1] == 20.0 + assert m.b == 1.0 + assert m.c == 7 + assert m.d == 5.0 + + def test_grad_with_type_predicate(self): + m = nnx.Dict( + a=nnx.Sequence([nnx.Param(10.0), nnx.BatchStat(20.0)]), + b=nnx.Param(10.0), + 
c=7, + d=5.0, + ) + + @partial(nnx.grad, wrt=nnx.BatchStat) + def f(m: nnx.Dict): + # sum all params + return m.a[0] + m.a[1] + m.b + + grads = f(m) + + assert isinstance(grads, nnx.State) + assert grads['a/1'] == 1.0 + assert isinstance(grads.variables['a/1'], nnx.BatchStat) + assert len(grads) == 1 + + m.update(grads) + + assert m.a[0] == 10.0 + assert m.a[1] == 1.0 + assert m.b == 10.0 + assert m.c == 7 + assert m.d == 5.0 + + +class TestScan: + def test_basic(self): + class Block(nnx.Module): + def __init__(self, *, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + self.node = nnx.Variable(jnp.ones((2,))) + + def __call__(self, x: jax.Array) -> tp.Tuple[jax.Array, None]: + x = self.linear(x) + x = nnx.gelu(x) + return x, None + + MLP = nnx.Scan( + Block, + variable_axes={nnx.Param: 0}, + length=5, + ) + + module = MLP(rngs=nnx.Rngs(0)) + + assert module.scan_module.linear.kernel.shape == (5, 3, 3) + assert module.scan_module.linear.bias.shape == (5, 3) + assert module.scan_module.node.shape == (2,) + + x = jnp.ones((1, 3)) + y, out = module(x) + + assert y.shape == (1, 3) + assert out is None + + def test_no_scan_output(self): + class Block(nnx.Module): + def __init__(self, *, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + self.node = nnx.Variable(jnp.ones((2,))) + + def __call__(self, x: jax.Array): + x = self.linear(x) + x = nnx.gelu(x) + return x + + MLP = nnx.Scan( + Block, + variable_axes={nnx.Param: 0}, + length=5, + scan_output=False, + ) + + module = MLP(rngs=nnx.Rngs(0)) + + assert module.scan_module.linear.kernel.shape == (5, 3, 3) + assert module.scan_module.linear.bias.shape == (5, 3) + assert module.scan_module.node.shape == (2,) + + x = jnp.ones((1, 3)) + y = module(x) + + assert y.shape == (1, 3) + + def test_out_axes(self): + class Block(nnx.Module): + def __init__(self, *, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + self.node = nnx.Variable(jnp.ones((2,))) + + def __call__(self, x: jax.Array): + x = self.linear(x) + x = nnx.gelu(x) + return x, (x, x) + + MLP = nnx.Scan( + Block, + variable_axes={nnx.Param: 0}, + length=5, + out_axes=(1, 2), + ) + + module = MLP(rngs=nnx.Rngs(0)) + + assert module.scan_module.linear.kernel.shape == (5, 3, 3) + assert module.scan_module.linear.bias.shape == (5, 3) + assert module.scan_module.node.shape == (2,) + + x = jnp.ones((1, 3)) + c, (y1, y2) = module(x) + + assert c.shape == (1, 3) + assert y1.shape == (1, 5, 3) + assert y2.shape == (1, 3, 5) + + def test_in_axes(self): + class Block(nnx.Module): + def __init__(self, *, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + self.node = nnx.Variable(jnp.ones((2,))) + + def __call__( + self, x: jax.Array, a: jax.Array + ) -> tp.Tuple[jax.Array, None]: + assert x.shape == a.shape + x = x + a + x = self.linear(x) + x = nnx.gelu(x) + return x, None + + MLP = nnx.Scan( + Block, + variable_axes={nnx.Param: 0}, + length=5, + ) + + module = MLP(rngs=nnx.Rngs(0)) + + assert module.scan_module.linear.kernel.shape == (5, 3, 3) + assert module.scan_module.linear.bias.shape == (5, 3) + assert module.scan_module.node.shape == (2,) + + x = jnp.ones((1, 3)) + a = jnp.ones((5, 1, 3)) + y, out = module(x, a) + + assert y.shape == (1, 3) + assert out is None + + def test_in_axes_broadcast(self): + class Block(nnx.Module): + def __init__(self, *, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + self.node = nnx.Variable(jnp.ones((2,))) + + def __call__( + self, x: jax.Array, a: jax.Array, b: jax.Array + ) -> tp.Tuple[jax.Array, 
None]: + assert x.shape == a.shape + assert x.shape == b.shape + x = x + a + b + x = self.linear(x) + x = nnx.gelu(x) + return x, None + + MLP = nnx.Scan( + Block, + variable_axes={nnx.Param: 0}, + length=5, + in_args_axes=(0, None), + ) + + module = MLP(rngs=nnx.Rngs(0)) + + assert module.scan_module.linear.kernel.shape == (5, 3, 3) + assert module.scan_module.linear.bias.shape == (5, 3) + assert module.scan_module.node.shape == (2,) + + x = jnp.ones((1, 3)) + a = jnp.ones((5, 1, 3)) + b = jnp.ones((1, 3)) + y, out = module(x, a, b) + + assert y.shape == (1, 3) + assert out is None + + def test_complex(self): + class Block(nnx.Module): + def __init__(self, *, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + self.bn = nnx.BatchNorm(3, rngs=rngs) + self.dropout = nnx.Dropout(0.5) + self.node = nnx.Variable(jnp.ones((2,))) + + def __call__(self, x: jax.Array, *, rngs: nnx.Rngs) -> jax.Array: + x = self.linear(x) + x = self.bn(x) + x = self.dropout(x, rngs=rngs) + x = nnx.gelu(x) + return x + + MLP = nnx.Scan( + Block, variable_axes={nnx.Param: 0}, length=5, scan_output=False + ) + + module = MLP(rngs=nnx.Rngs(0)) + + assert module.scan_module.linear.kernel.shape == (5, 3, 3) + assert module.scan_module.linear.bias.shape == (5, 3) + assert module.scan_module.node.shape == (2,) + + x = jnp.ones((1, 3)) + with nnx.flags(deterministic=False, use_running_average=False): + y = module(x, rngs=nnx.Rngs(1)) + + assert y.shape == (1, 3) + + def test_complex_decorator(self): + scan_over_layers = partial( + nnx.scan, + variable_axes={nnx.Param: 0}, + length=5, + ) + + class Block(nnx.Module): + @scan_over_layers + def __init__(self, *, rngs: nnx.Rngs): + self.d = 3 + self.linear = nnx.Linear(3, 3, rngs=rngs) + self.bn = nnx.BatchNorm(3, rngs=rngs) + self.dropout = nnx.Dropout(0.5) + self.node = nnx.Variable(jnp.ones((2,))) + + @scan_over_layers + def __call__( + self, x: jax.Array, _, *, rngs: nnx.Rngs + ) -> tp.Tuple[jax.Array, None]: + x = self.linear(x) + x = self.bn(x) + x = self.dropout(x, rngs=rngs) + x = nnx.gelu(x) + return x, None + + module = Block(rngs=nnx.Rngs(0)) + + assert module.d == 3 + assert module.linear.kernel.shape == (5, 3, 3) + assert module.linear.bias.shape == (5, 3) + assert module.node.shape == (2,) + + x = jnp.ones((1, 3)) + with nnx.flags(deterministic=False, use_running_average=False): + y, out = module(x, None, rngs=nnx.Rngs(dropout=1)) + + assert y.shape == (1, 3) + assert out is None + + def test_scan_with_sharding(self): + class Block(nnx.Module): + def __init__(self, *, rngs: nnx.Rngs): + self.linear = nnx.Linear( + 3, + 3, + kernel_init=nnx.with_metadata( + nnx.initializers.lecun_normal(), + sharding=('din', 'dout'), + ), + bias_init=nnx.with_metadata( + nnx.initializers.zeros(), + sharding=('dout',), + ), + rngs=rngs, + ) + + def __call__(self, x: jax.Array, _) -> tp.Tuple[jax.Array, None]: + x = self.linear(x) + + # test sharding layer axes is not present inside scan + variables = self.linear.get_state().variables + assert variables['kernel'].value.shape == (3, 3) + assert variables['kernel'].sharding == ('din', 'dout') + assert variables['bias'].value.shape == (3,) + assert variables['bias'].sharding == ('dout',) + + return x, None + + MLP = nnx.Scan( + Block, + variable_axes={nnx.Param: 0}, + length=5, + scan_metadata={nnx.PARTITION_NAME: 'layers'}, + ) + + m = MLP(rngs=nnx.Rngs(0)) + + # test sharding layers axes is set + variables = m.get_state().variables + assert variables['scan_module/linear/kernel'].value.shape == (5, 3, 3) + assert 
variables['scan_module/linear/kernel'].sharding == ( + 'layers', + 'din', + 'dout', + ) + assert variables['scan_module/linear/bias'].value.shape == (5, 3) + assert variables['scan_module/linear/bias'].sharding == ('layers', 'dout') + + x = jnp.ones((1, 3)) + y, out = m(x, None) + + # test sharding axes is preserved + variables = m.get_state().variables + assert variables['scan_module/linear/kernel'].value.shape == (5, 3, 3) + assert variables['scan_module/linear/kernel'].sharding == ( + 'layers', + 'din', + 'dout', + ) + assert variables['scan_module/linear/bias'].value.shape == (5, 3) + assert variables['scan_module/linear/bias'].sharding == ('layers', 'dout') + + def test_type_error_less_than_one_args(self): + class Block(nnx.Module): + def __init__(self, *, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + + def __call__(self): + return None, None + + MLP = nnx.Scan( + Block, + variable_axes={nnx.Param: 0}, + length=5, + ) + + mlp = MLP(rngs=nnx.Rngs(0)) + + with pytest.raises( + TypeError, match='Expected at least 1 positional argument' + ): + mlp() + + def test_value_error_positional_argument_type_context(self): + class Block(nnx.Module): + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + + def __call__(self, x: jax.Array) -> tp.Tuple[jax.Array, None]: + x = self.linear(x) + return x, None + + MLP = nnx.Scan( + Block, + variable_axes={nnx.Param: 0}, + length=5, + ) + + with pytest.raises( + ValueError, match='Rngs must be passed as a keyword argument named' + ): + MLP(nnx.Rngs(0)) + + +class TestRemat: + def test_basic_remat(self): + RematLinear = nnx.Remat(nnx.Linear) + + module = RematLinear(2, 3, rngs=nnx.Rngs(0)) + + y = module(jnp.ones((1, 2))) + + assert y.shape == (1, 3) + + def test_remat_decorator(self): + class RematLinear(nnx.Module): + @nnx.remat + def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): + self.linear = nnx.Linear(din, dout, rngs=rngs) + + @nnx.remat + def __call__(self, x: jax.Array) -> jax.Array: + return self.linear(x) + + module = RematLinear(2, 3, rngs=nnx.Rngs(0)) + + y = module(jnp.ones((1, 2))) + + assert y.shape == (1, 3) + + def test_remat_with_scan(self): + class LinearBlock(nnx.Module): + def __init__(self, *, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + + def __call__(self, x: jax.Array, _) -> tp.Tuple[jax.Array, None]: + x = self.linear(x) + return x, None + + RematLinear = nnx.Remat(LinearBlock) + + ScanRematLinear = nnx.Scan( + RematLinear, + variable_axes={nnx.Param: 0}, + length=5, + ) + + m = ScanRematLinear(rngs=nnx.Rngs(0)) + + assert m.scan_module.remat_module.linear.kernel.shape == (5, 3, 3) + assert m.scan_module.remat_module.linear.bias.shape == (5, 3) + + y, _ = m(jnp.ones((1, 3)), None) + assert y.shape == (1, 3) + + y, _ = m(jnp.ones((1, 3)), None) + assert y.shape == (1, 3) + + def test_remat_with_scan_decorator(self): + scan = partial( + nnx.scan, + variable_axes={nnx.Param: 0}, + length=5, + ) + + class ScanLinear(nnx.Module): + @scan + def __init__(self, *, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + + @scan + @nnx.remat + def __call__(self, x: jax.Array, _) -> tp.Tuple[jax.Array, None]: + x = self.linear(x) + return x, None + + m = ScanLinear(rngs=nnx.Rngs(0)) + + assert m.linear.kernel.shape == (5, 3, 3) + assert m.linear.bias.shape == (5, 3) + + y, _ = m(jnp.ones((1, 3)), None) + assert y.shape == (1, 3) + + +class TestVmap: + def test_basic(self): + class Block(nnx.Module): + def __init__(self, *, rngs: nnx.Rngs): + self.linear = 
nnx.Linear(3, 3, rngs=rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + x = self.linear(x) + x = nnx.gelu(x) + return x + + MLP = nnx.Vmap(Block, variable_axes={nnx.Param: 0}, axis_size=5) + + module = MLP(rngs=nnx.Rngs(0)) + + assert not jnp.allclose( + module.vmap_module.linear.kernel[0], + module.vmap_module.linear.kernel[1], + ) + assert module.vmap_module.linear.kernel.shape == (5, 3, 3) + assert module.vmap_module.linear.bias.shape == (5, 3) + + x = jnp.ones((5, 1, 3)) + y = module(x) + + assert y.shape == (5, 1, 3) diff --git a/flax/experimental/nnx/tests/test_variable.py b/flax/experimental/nnx/tests/test_variable.py new file mode 100644 index 0000000000..9d41e97386 --- /dev/null +++ b/flax/experimental/nnx/tests/test_variable.py @@ -0,0 +1,33 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing as tp + +import jax + +from flax.experimental import nnx + +A = tp.TypeVar('A') + + +class TestVariable: + def test_value(self): + r1 = nnx.Variable(1) + assert r1.value == 1 + + r2 = jax.tree_map(lambda x: x + 1, r1) + + assert r1.value == 1 + assert r2.value == 2 + assert r1 is not r2 diff --git a/flax/linen/kw_only_dataclasses.py b/flax/linen/kw_only_dataclasses.py index c2c7f32cf6..d2d617d0c3 100644 --- a/flax/linen/kw_only_dataclasses.py +++ b/flax/linen/kw_only_dataclasses.py @@ -56,7 +56,7 @@ class that defines a field with a default, and a subclass that defines a field from types import MappingProxyType from typing import Any, TypeVar -from typing_extensions import dataclass_transform +import typing_extensions as tpe import flax @@ -101,7 +101,7 @@ def field(*, metadata=None, kw_only=dataclasses.MISSING, **kwargs): return dataclasses.field(metadata=metadata, **kwargs) -@dataclass_transform(field_specifiers=(field,)) # type: ignore[literal-required] +@tpe.dataclass_transform(field_specifiers=(field,)) # type: ignore[literal-required] def dataclass(cls=None, extra_fields=None, **kwargs): """Wrapper for dataclasses.dataclass that adds support for kw_only fields. diff --git a/flax/linen/module.py b/flax/linen/module.py index 37175e7df8..1a0c1edebf 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -43,7 +43,7 @@ import jax import jax.numpy as jnp -from typing_extensions import Protocol, dataclass_transform +import typing_extensions as tpe import flax import flax.linen as nn @@ -785,7 +785,7 @@ def __set__(self, obj, value): object.__setattr__(obj, '_parent_ref', maybe_weak) -class Descriptor(Protocol): +class Descriptor(tpe.Protocol): __isabstractmethod__: bool def __get__(self, obj, objtype=None) -> Any: @@ -873,7 +873,7 @@ def module_field(*, kw_only: bool = False, default: Optional[Any] = ...) -> Any: # * Other attributes are annotated for completeness. Because we are using # the `if typing.TYPE_CHECKING` pattern, these annotations are not present # at runtime so they don't affect the dataclass behavior. 
-@dataclass_transform(field_specifiers=(module_field,)) # type: ignore[literal-required] +@tpe.dataclass_transform(field_specifiers=(module_field,)) # type: ignore[literal-required] class ModuleBase: if typing.TYPE_CHECKING: scope: Optional[Scope] @@ -997,9 +997,9 @@ def _customized_dataclass_transform(cls, kw_only: bool): if tuple(sys.version_info)[:3] >= (3, 10, 0): for ( name, - annotation, + annotation, # pytype: disable=invalid-annotation default, - ) in extra_fields: # pytype: disable=invalid-annotation + ) in extra_fields: setattr(cls, name, default) cls.__annotations__[name] = annotation dataclasses.dataclass( # type: ignore[call-overload] diff --git a/pyproject.toml b/pyproject.toml index 3ac7851091..34d4263e52 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,6 +111,10 @@ module = [ "yaml", ] ignore_missing_imports = true +# exclude nnx +[[tool.mypy.overrides]] +module = "flax.experimental.nnx.*" +ignore_errors = true [tool.pytest.ini_options] filterwarnings = [ @@ -134,6 +138,10 @@ filterwarnings = [ "ignore:.*jax.config.define_bool_state is deprecated.:DeprecationWarning", # pytest-cov uses a deprecated feature of pytest-xdist. (2023-11-06) "ignore:The --rsyncdir command line argument and rsyncdirs config variable are deprecated.:DeprecationWarning", + # DeprecationWarning: jax.random.KeyArray is deprecated. + "ignore:.*jax.random.KeyArray is deprecated.*:DeprecationWarning", + # DeprecationWarning: jax.core.Shape is deprecated. + "ignore:.*jax.core.Shape is deprecated.*:DeprecationWarning", ] [tool.coverage.report] diff --git a/tests/run_all_tests.sh b/tests/run_all_tests.sh index 2f015ba3e1..a2b016138d 100755 --- a/tests/run_all_tests.sh +++ b/tests/run_all_tests.sh @@ -82,7 +82,10 @@ if $RUN_DOCTEST; then # test build html sphinx-build -M html docs docs/_build -T # test docstrings - pytest -n auto flax --doctest-modules --suppress-no-test-exit-code + pytest -n auto flax \ + --doctest-modules \ + --suppress-no-test-exit-code \ + --ignore=flax/experimental/nnx fi # check that flax is running on editable mode @@ -117,19 +120,22 @@ if $RUN_PYTEST; then for egd in $(find examples -maxdepth 1 -mindepth 1 -type d); do pytest $egd done + + # Run nnx tests + pytest -n auto flax/experimental/nnx/tests $PYTEST_OPTS $PYTEST_IGNORE fi if $RUN_PYTYPE; then echo "=== RUNNING PYTYPE ===" # Validate types in library code. - pytype --jobs auto --config pyproject.toml flax/ + pytype --jobs auto --config pyproject.toml flax/ --exclude flax/experimental/nnx # Validate types in examples. for egd in $(find examples -maxdepth 1 -mindepth 1 -type d); do # use cd to make sure pytype cache lives in example dir and doesn't name clash # use *.py to avoid importing configs as a top-level import which leads to import errors # because config files use relative imports (e.g. from config import ...). - (cd $egd ; pytype --jobs auto --config ../../pyproject.toml "*.py") + (cd $egd ; pytype --jobs auto --exclude flax/experimental/nnx --config ../../pyproject.toml "*.py") done fi
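Taken together, the lifted transforms exercised in test_transforms.py compose into the training step that test_shared_modules runs end to end. A closing sketch of that shape, restating the tests' own calls on a plain nnx.Linear (no BatchNorm, so no `nnx.flags` context is needed; returning nothing from the jitted step is an assumption, the tests return a value):

    import jax
    import jax.numpy as jnp

    from flax.experimental import nnx

    model = nnx.Linear(2, 2, rngs=nnx.Rngs(0))

    @nnx.jit  # lifted jit: the Module is split/merged across the boundary
    def train_step(model, x, y):
      @nnx.grad  # differentiates w.r.t. the nnx.Param state by default
      def loss_fn(model):
        return jnp.mean((model(x) - y) ** 2)

      grads = loss_fn(model)
      # In-place SGD update on the Param partition.
      model.update(
        jax.tree_map(lambda w, g: w - 0.1 * g, model.extract(nnx.Param), grads)
      )

    x = jnp.ones((4, 2))
    y = jnp.zeros((4, 2))
    train_step(model, x, y)  # mutations propagate back to `model`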