From 186ca31f9c8ad37b69210cf7d246f51f7e92e158 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Tue, 14 Jan 2025 14:51:24 +0000 Subject: [PATCH] [nnx] add cache_args --- benchmarks/nnx_graph_overhead.py | 63 +- benchmarks/nnx_mlpmixer_training.py | 235 +++ benchmarks/nnx_simple_training.py | 81 +- docs_nnx/nnx_basics.ipynb | 116 +- .../nnx_toy_examples/02_lifted_transforms.py | 6 +- flax/configurations.py | 11 + flax/nnx/__init__.py | 1 + flax/nnx/bridge/variables.py | 8 +- flax/nnx/extract.py | 111 +- flax/nnx/graph.py | 1460 +++++++++++++---- flax/nnx/helpers.py | 4 + flax/nnx/nn/stochastic.py | 3 + flax/nnx/reprlib.py | 1 + flax/nnx/rnglib.py | 15 +- flax/nnx/statelib.py | 126 +- flax/nnx/tracers.py | 13 +- flax/nnx/transforms/autodiff.py | 122 +- flax/nnx/transforms/compilation.py | 68 +- flax/nnx/transforms/general.py | 6 +- flax/nnx/transforms/iteration.py | 147 +- flax/nnx/transforms/transforms.py | 3 +- flax/nnx/variablelib.py | 60 +- flax/typing.py | 6 +- flaxlib_src/CMakeLists.txt | 54 + flaxlib_src/meson.build | 14 - flaxlib_src/pyproject.toml | 17 +- .../{flaxlib.pyi => src/flaxlib/__init__.py} | 3 +- flaxlib_src/src/flaxlib/flaxlib_cpp.pyi | 25 + flaxlib_src/src/lib.cc | 300 +++- flaxlib_src/src/lib.rs | 28 - pyproject.toml | 6 + tests/nnx/bridge/wrappers_test.py | 4 +- tests/nnx/graph_utils_test.py | 195 ++- tests/nnx/module_test.py | 20 +- tests/nnx/transforms_test.py | 70 +- uv.lock | 53 +- 36 files changed, 2531 insertions(+), 924 deletions(-) create mode 100644 benchmarks/nnx_mlpmixer_training.py create mode 100644 flaxlib_src/CMakeLists.txt delete mode 100644 flaxlib_src/meson.build rename flaxlib_src/{flaxlib.pyi => src/flaxlib/__init__.py} (84%) create mode 100644 flaxlib_src/src/flaxlib/flaxlib_cpp.pyi delete mode 100644 flaxlib_src/src/lib.rs diff --git a/benchmarks/nnx_graph_overhead.py b/benchmarks/nnx_graph_overhead.py index 88809f7775..6d10f79e07 100644 --- a/benchmarks/nnx_graph_overhead.py +++ b/benchmarks/nnx_graph_overhead.py @@ -24,31 +24,52 @@ from absl import app FLAGS = flags.FLAGS -flags.DEFINE_enum('mode', 'all', ['all', 'nnx', 'jax'], 'Mode to run the script in') +flags.DEFINE_enum( + 'mode', 'nnx', ['all', 'nnx', 'jax'], 'Mode to run the script in' +) flags.DEFINE_integer('total_steps', 100, 'Total number of training steps') flags.DEFINE_integer('width', 32, 'Hidden layer size') flags.DEFINE_integer('depth', 5, 'Depth of the model') - class Linear(nnx.Module): def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): - self.list = [ - nnx.Param(jax.random.uniform(rngs.params(), (din, dout))), - nnx.Param(jnp.zeros((dout,))), - ] - self.dict = { - 'w': nnx.Param(jax.random.uniform(rngs.params(), (din, dout))), - 'b': nnx.Param(jnp.zeros((dout,))), - } + self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout))) + self.b = nnx.Param(jnp.zeros((dout,))) + + def __call__(self, x): + return x @ self.w + self.b + + +class Block(nnx.Module): + def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): + self.linear = Linear(din, dout, rngs=rngs) + self.bn = nnx.BatchNorm(dout, rngs=rngs) + + def __call__(self, x): + return nnx.relu(self.bn(self.linear(x))) + +class Count(nnx.Variable): + pass class MLP(nnx.Module): - def __init__(self, depth, *, rngs: nnx.Rngs): + def __init__(self, din, dhidden, dout, depth, *, rngs: nnx.Rngs): + self.count = Count(jnp.array(0)) + self.linear_in = Block(din, dhidden, rngs=rngs) self.intermediates = [ - Linear(10, 10, rngs=rngs) for _ in range(depth) + Block(dhidden, dhidden, rngs=rngs) for _ in range(depth - 2) ] + self.linear_out = Block(dhidden, dout, rngs=rngs) + + def __call__(self, x): + self.count.value += 1 + x = nnx.relu(self.linear_in(x)) + for layer in self.intermediates: + x = nnx.relu(layer(x)) + x = self.linear_out(x) + return x def main(argv): @@ -63,21 +84,24 @@ def main(argv): X = np.linspace(0, 1, 100)[:, None] Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) - model = MLP(depth=depth, rngs=nnx.Rngs(0)) - tx = optax.sgd(1e-3) - optimizer = nnx.Optimizer(model, tx) - #------------------------------------------------------------ # NNX #------------------------------------------------------------ if mode in ['all', 'nnx']: + model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0)) + tx = optax.sgd(1e-3) + optimizer = nnx.Optimizer(model, tx) + t0 = time() + @nnx.jit def step_nnx(model: MLP, optimizer: nnx.Optimizer): pass + cached_step_nnx = nnx.cache_args(step_nnx, model, optimizer) + t0 = time() for _ in range(total_steps): - step_nnx(model, optimizer) + cached_step_nnx() total_time = time() - t0 time_per_step = total_time / total_steps @@ -93,6 +117,11 @@ def step_nnx(model: MLP, optimizer: nnx.Optimizer): #------------------------------------------------------------ if mode in ['all', 'jax']: + model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0)) + tx = optax.sgd(1e-3) + optimizer = nnx.Optimizer(model, tx) + t0 = time() + @jax.jit def step_jax(graphdef, state): return graphdef, state diff --git a/benchmarks/nnx_mlpmixer_training.py b/benchmarks/nnx_mlpmixer_training.py new file mode 100644 index 0000000000..68d5e79734 --- /dev/null +++ b/benchmarks/nnx_mlpmixer_training.py @@ -0,0 +1,235 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# %% +from functools import partial +import jax +import jax.numpy as jnp +from flax import nnx +import optax +import numpy as np +from einop import einop +from time import time +from tqdm import tqdm + +from flax import nnx + +from absl import flags +from absl import app + +FLAGS = flags.FLAGS +flags.DEFINE_enum( + 'mode', 'all', ['all', 'nnx', 'jax'], 'Mode to run the script in' +) +flags.DEFINE_integer('total_steps', 10_000, 'Total number of training steps') +flags.DEFINE_integer('batch_size', 32, 'Batch size') +flags.DEFINE_integer('width', 32, 'Hidden layer size') +flags.DEFINE_integer('depth', 4, 'Depth of the model') + + +class MlpBlock(nnx.Module): + def __init__(self, din: int, mlp_dim: int, rngs: nnx.Rngs): + self.din, self.mlp_dim = din, mlp_dim + self.linear_in = nnx.Linear(din, mlp_dim, rngs=rngs) + self.linear_out = nnx.Linear(mlp_dim, din, rngs=rngs) + + def __call__(self, x): + return self.linear_out(nnx.gelu(self.linear_in(x))) + + +class MixerBlock(nnx.Module): + def __init__( + self, + tokens_mlp_dim: int, + channels_mlp_dim: int, + hidden_dim: int, + rngs: nnx.Rngs, + ): + self.tokens_mlp_dim = tokens_mlp_dim + self.channels_mlp_dim = channels_mlp_dim + self.hidden_dim = hidden_dim + self.token_mixing = MlpBlock(tokens_mlp_dim, hidden_dim, rngs=rngs) + self.channel_mixing = MlpBlock(channels_mlp_dim, hidden_dim, rngs=rngs) + self.ln1 = nnx.LayerNorm(channels_mlp_dim, rngs=rngs) + self.ln2 = nnx.LayerNorm(channels_mlp_dim, rngs=rngs) + + def __call__(self, x): + y = self.ln1(x) + y = y.swapaxes(1, 2) + y = self.token_mixing(y) + y = y.swapaxes(1, 2) + x = x + y + y = self.ln2(x) + return x + self.channel_mixing(y) + + +class MlpMixer(nnx.Module): + def __init__( + self, + din: int, + kernel_size: tuple[int, int], + strides: tuple[int, int], + num_blocks: int, + hidden_dim: int, + tokens_mlp_dim: int, + channels_mlp_dim: int, + rngs: nnx.Rngs, + ): + self.din = din + self.kernel_size = kernel_size + self.num_blocks = num_blocks + self.hidden_dim = hidden_dim + self.tokens_mlp_dim = tokens_mlp_dim + self.channels_mlp_dim = channels_mlp_dim + self.stem = nnx.Conv( + din + 1, + channels_mlp_dim, + kernel_size=kernel_size, + strides=strides, + rngs=rngs, + ) + self.blocks = [ + MixerBlock(tokens_mlp_dim, channels_mlp_dim, hidden_dim, rngs=rngs) + for _ in range(num_blocks) + ] + self.pre_head_layer_norm = nnx.LayerNorm(channels_mlp_dim, rngs=rngs) + self.conv_t = nnx.ConvTranspose( + channels_mlp_dim, din, kernel_size=kernel_size, strides=strides, rngs=rngs + ) + + def __call__(self, *, x, t): + # add time feature to input + t = einop(t, 'n -> n h w c', h=x.shape[1], w=x.shape[2], c=1) + x = jnp.concatenate([x, t], axis=-1) + # create patches + x = self.stem(x) + h, w = x.shape[1], x.shape[2] + x = einop(x, 'n h w c -> n (h w) c') + # apply blocks + for block in self.blocks: + x = block(x) + x = self.pre_head_layer_norm(x) + # recreate image + x = einop(x, 'n (h w) c -> n h w c', h=h, w=w) + x = self.conv_t(x) + return x + + +def main(argv): + print(argv) + mode: str = FLAGS.mode + total_steps: int = FLAGS.total_steps + batch_size: int = FLAGS.batch_size + width: int = FLAGS.width + depth: int = FLAGS.depth + + print(f'{mode=}, {total_steps=}, {batch_size=}, {width=}') + + X = np.random.uniform(size=(batch_size, 28, 28, 1)) + + if mode == 'nnx' or mode == 'all': + rngs = nnx.Rngs(0) + flow = MlpMixer( + din=1, + kernel_size=(2, 2), + strides=(2, 2), + num_blocks=4, + hidden_dim=512, + tokens_mlp_dim=196, + channels_mlp_dim=512, + rngs=rngs, + ) + optimizer = nnx.Optimizer(flow, tx=optax.adamw(1e-4)) + t0 = time() + + mse = lambda a, b: jnp.mean((a - b) ** 2) + + @nnx.jit(donate_argnums=(0, 1, 2)) + def train_step_nnx(flow, optimizer, rngs, x_1): + print('JITTING NNX') + x_0 = jax.random.normal(rngs(), x_1.shape) + t = jax.random.uniform(rngs(), (len(x_1),)) + + x_t = jax.vmap(lambda x_0, x_1, t: (1 - t) * x_0 + t * x_1)(x_0, x_1, t) + dx_t = x_1 - x_0 + + loss, grads = nnx.value_and_grad( + lambda flow: mse(flow(x=x_t, t=t), dx_t) + )(flow) + optimizer.update(grads) + return loss + + losses = [] + t0 = time() + for step in tqdm(range(total_steps), desc='NNX'): + loss = train_step_nnx(flow, optimizer, rngs, X) + losses.append(loss) + + total_time = time() - t0 + print('### NNX ###') + print(f'final loss: {losses[-1]}') + print('total time:', total_time) + print(f'time per step: {total_time / total_steps * 1e6:.2f} µs') + + if mode == 'jax' or mode == 'all': + rngs = nnx.Rngs(0) + flow = MlpMixer( + din=1, + kernel_size=(2, 2), + strides=(2, 2), + num_blocks=depth, + hidden_dim=width, + tokens_mlp_dim=196, + channels_mlp_dim=width, + rngs=rngs, + ) + optimizer = nnx.Optimizer(flow, tx=optax.adamw(1e-4)) + graphdef, state = nnx.split((flow, optimizer, rngs)) + t0 = time() + + mse = lambda a, b: jnp.mean((a - b) ** 2) + + @partial(nnx.jit, donate_argnums=0) + def train_step_jax(state, x_1): + print('JITTING JAX') + flow, optimizer, rngs = nnx.merge(graphdef, state) + x_0 = jax.random.normal(rngs(), x_1.shape) + t = jax.random.uniform(rngs(), (len(x_1),)) + + x_t = jax.vmap(lambda x_0, x_1, t: (1 - t) * x_0 + t * x_1)(x_0, x_1, t) + dx_t = x_1 - x_0 + + loss, grads = nnx.value_and_grad( + lambda flow: mse(flow(x=x_t, t=t), dx_t) + )(flow) + optimizer.update(grads) + state = nnx.state((flow, optimizer, rngs)) + return loss, state + + losses = [] + t0 = time() + for step in tqdm(range(total_steps), desc='JAX'): + loss, state = train_step_jax(state, X) + losses.append(loss) + + nnx.update((flow, optimizer, rngs), state) + total_time = time() - t0 + print('### JAX ###') + print(f'final loss: {losses[-1]}') + print('total time:', total_time) + print(f'time per step: {total_time / total_steps * 1e6:.2f} µs') + + +if __name__ == '__main__': + app.run(main) diff --git a/benchmarks/nnx_simple_training.py b/benchmarks/nnx_simple_training.py index 0cb08066fe..88195b3ffd 100644 --- a/benchmarks/nnx_simple_training.py +++ b/benchmarks/nnx_simple_training.py @@ -13,6 +13,7 @@ # limitations under the License. # %% +from functools import partial import jax import jax.numpy as jnp import numpy as np @@ -25,7 +26,9 @@ from absl import app FLAGS = flags.FLAGS -flags.DEFINE_enum('mode', 'nnx', ['nnx', 'jax'], 'Mode to run the script in') +flags.DEFINE_enum( + 'mode', 'all', ['all', 'nnx', 'jax'], 'Mode to run the script in' +) flags.DEFINE_integer('total_steps', 10_000, 'Total number of training steps') flags.DEFINE_integer('batch_size', 32, 'Batch size') flags.DEFINE_integer('width', 32, 'Hidden layer size') @@ -46,6 +49,13 @@ def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): def __call__(self, x): return x @ self.w + self.b +class Block(nnx.Module): + def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): + self.linear = Linear(din, dout, rngs=rngs) + self.bn = nnx.BatchNorm(dout, rngs=rngs) + + def __call__(self, x): + return nnx.relu(self.bn(self.linear(x))) class Count(nnx.Variable): pass @@ -54,11 +64,11 @@ class Count(nnx.Variable): class MLP(nnx.Module): def __init__(self, din, dhidden, dout, depth, *, rngs: nnx.Rngs): self.count = Count(jnp.array(0)) - self.linear_in = Linear(din, dhidden, rngs=rngs) + self.linear_in = Block(din, dhidden, rngs=rngs) self.intermediates = [ - Linear(dhidden, dhidden, rngs=rngs) for _ in range(depth - 2) + Block(dhidden, dhidden, rngs=rngs) for _ in range(depth - 2) ] - self.linear_out = Linear(dhidden, dout, rngs=rngs) + self.linear_out = Block(dhidden, dout, rngs=rngs) def __call__(self, x): self.count.value += 1 @@ -79,20 +89,16 @@ def main(argv): print(f'{mode=}, {total_steps=}, {batch_size=}, {width=}') - if mode not in ['nnx', 'jax']: - raise ValueError(f'Invalid mode: {mode}') - X = np.linspace(0, 1, 100)[:, None] Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) - model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0)) - tx = optax.sgd(1e-3) - optimizer = nnx.Optimizer(model, tx) - t0 = time() - - if mode == 'nnx': + if mode == 'nnx' or mode == 'all': + model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0)) + tx = optax.sgd(1e-3) + optimizer = nnx.Optimizer(model, tx) + t0 = time() - @nnx.jit + @nnx.jit(donate_argnums=(0, 1)) def train_step_nnx(model: MLP, optimizer: nnx.Optimizer, batch): x, y = batch @@ -103,26 +109,40 @@ def loss_fn(model: MLP): grads: nnx.State = nnx.grad(loss_fn)(model) optimizer.update(grads) - @nnx.jit + @nnx.jit(donate_argnums=0) def test_step_nnx(model: MLP, batch): x, y = batch y_pred = model(x) loss = jnp.mean((y - y_pred) ** 2) return {'loss': loss} + cached_train_step_nnx = nnx.cache_args(train_step_nnx, model, optimizer) + cached_test_step_nnx = nnx.cache_args(test_step_nnx, model) + for step, batch in enumerate(dataset(X, Y, batch_size)): - train_step_nnx(model, optimizer, batch) + cached_train_step_nnx(batch) if step % 1000 == 0: - logs = test_step_nnx(model, (X, Y)) - print(f"step: {step}, loss: {logs['loss']}") + logs = cached_test_step_nnx((X, Y)) if step >= total_steps - 1: break - else: - @jax.jit - def train_step_jax(graphdef, state, batch): + print('### NNX ###') + print(f"final loss: {logs['loss']}") + total_time = time() - t0 + print('total time:', total_time) + print(f'time per step: {total_time / total_steps * 1e6:.2f} µs') + print('times called:', model.count.value) + + if mode == 'jax' or mode == 'all': + model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0)) + tx = optax.sgd(1e-3) + optimizer = nnx.Optimizer(model, tx) + t0 = time() + + @partial(jax.jit, donate_argnums=0) + def train_step_jax(state, batch): model, optimizer = nnx.merge(graphdef, state) x, y = batch @@ -135,8 +155,8 @@ def loss_fn(model: MLP): return nnx.state((model, optimizer)) - @jax.jit - def test_step_jax(graphdef, state, batch): + @partial(jax.jit, donate_argnums=0) + def test_step_jax(state, batch): model, optimizer = nnx.merge(graphdef, state) x, y = batch y_pred = model(x) @@ -147,21 +167,22 @@ def test_step_jax(graphdef, state, batch): graphdef, state = nnx.split((model, optimizer)) for step, batch in enumerate(dataset(X, Y, batch_size)): - state = train_step_jax(graphdef, state, batch) + state = train_step_jax(state, batch) if step % 1000 == 0: - state, logs = test_step_jax(graphdef, state, (X, Y)) - print(f"step: {step}, loss: {logs['loss']}") + state, logs = test_step_jax(state, (X, Y)) if step >= total_steps - 1: break model, optimizer = nnx.merge(graphdef, state) - total_time = time() - t0 - print('total time:', total_time) - print(f'time per step: {total_time / total_steps * 1e6:.2f} µs') - print('times called:', model.count.value) + print('### JAX ###') + print(f"final loss: {logs['loss']}") + total_time = time() - t0 + print('total time:', total_time) + print(f'time per step: {total_time / total_steps * 1e6:.2f} µs') + print('times called:', model.count.value) if __name__ == '__main__': diff --git a/docs_nnx/nnx_basics.ipynb b/docs_nnx/nnx_basics.ipynb index 03d0624911..bf040b98d0 100644 --- a/docs_nnx/nnx_basics.ipynb +++ b/docs_nnx/nnx_basics.ipynb @@ -92,7 +92,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -104,7 +104,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -190,13 +190,13 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -208,7 +208,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -263,7 +263,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -275,7 +275,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -399,84 +399,26 @@ { "data": { "text/html": [ - "
                                              MLP Summary                                               \n",
-       "┏━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓\n",
-       "┃ path                  type       BatchStat            Param                 RngState             ┃\n",
-       "┡━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩\n",
-       "│ bn                   │ BatchNorm │ mean: float32[5,32] │ bias: float32[5,32]  │                      │\n",
-       "│                      │           │ var: float32[5,32]  │ scale: float32[5,32] │                      │\n",
-       "│                      │           │                     │                      │                      │\n",
-       "│                      │           │ 320 (1.3 KB)320 (1.3 KB)         │                      │\n",
-       "├──────────────────────┼───────────┼─────────────────────┼──────────────────────┼──────────────────────┤\n",
-       "│ dropout/rngs/default │ RngStream │                     │                      │ count:               │\n",
-       "│                      │           │                     │                      │   tag: default       │\n",
-       "│                      │           │                     │                      │   value: uint32[5]   │\n",
-       "│                      │           │                     │                      │ key:                 │\n",
-       "│                      │           │                     │                      │   tag: default       │\n",
-       "│                      │           │                     │                      │   value: key<fry>[5] │\n",
-       "│                      │           │                     │                      │                      │\n",
-       "│                      │           │                     │                      │ 10 (60 B)            │\n",
-       "├──────────────────────┼───────────┼─────────────────────┼──────────────────────┼──────────────────────┤\n",
-       "│ linear1              │ Linear    │                     │ b: float32[5,32]     │                      │\n",
-       "│                      │           │                     │ w: float32[5,10,32]  │                      │\n",
-       "│                      │           │                     │                      │                      │\n",
-       "│                      │           │                     │ 1,760 (7.0 KB)       │                      │\n",
-       "├──────────────────────┼───────────┼─────────────────────┼──────────────────────┼──────────────────────┤\n",
-       "│ linear2              │ Linear    │                     │ b: float32[5,10]     │                      │\n",
-       "│                      │           │                     │ w: float32[5,32,10]  │                      │\n",
-       "│                      │           │                     │                      │                      │\n",
-       "│                      │           │                     │ 1,650 (6.6 KB)       │                      │\n",
-       "├──────────────────────┼───────────┼─────────────────────┼──────────────────────┼──────────────────────┤\n",
-       "│                           Total  320 (1.3 KB)         3,730 (14.9 KB)       10 (60 B)            │\n",
-       "└──────────────────────┴───────────┴─────────────────────┴──────────────────────┴──────────────────────┘\n",
-       "                                                                                                        \n",
-       "                                   Total Parameters: 4,060 (16.3 KB)                                    \n",
-       "
\n" + "
" ], "text/plain": [ - "\u001b[3m MLP Summary \u001b[0m\n", - "┏━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1mpath \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mtype \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mBatchStat \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mParam \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mRngState \u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│ bn │ BatchNorm │ mean: \u001b[2mfloat32\u001b[0m[5,32] │ bias: \u001b[2mfloat32\u001b[0m[5,32] │ │\n", - "│ │ │ var: \u001b[2mfloat32\u001b[0m[5,32] │ scale: \u001b[2mfloat32\u001b[0m[5,32] │ │\n", - "│ │ │ │ │ │\n", - "│ │ │ \u001b[1m320 \u001b[0m\u001b[1;2m(1.3 KB)\u001b[0m │ \u001b[1m320 \u001b[0m\u001b[1;2m(1.3 KB)\u001b[0m │ │\n", - "├──────────────────────┼───────────┼─────────────────────┼──────────────────────┼──────────────────────┤\n", - "│ dropout/rngs/default │ RngStream │ │ │ count: │\n", - "│ │ │ │ │ tag: default │\n", - "│ │ │ │ │ value: \u001b[2muint32\u001b[0m[5] │\n", - "│ │ │ │ │ key: │\n", - "│ │ │ │ │ tag: default │\n", - "│ │ │ │ │ value: \u001b[2mkey\u001b[0m[5] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m10 \u001b[0m\u001b[1;2m(60 B)\u001b[0m │\n", - "├──────────────────────┼───────────┼─────────────────────┼──────────────────────┼──────────────────────┤\n", - "│ linear1 │ Linear │ │ b: \u001b[2mfloat32\u001b[0m[5,32] │ │\n", - "│ │ │ │ w: \u001b[2mfloat32\u001b[0m[5,10,32] │ │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ \u001b[1m1,760 \u001b[0m\u001b[1;2m(7.0 KB)\u001b[0m │ │\n", - "├──────────────────────┼───────────┼─────────────────────┼──────────────────────┼──────────────────────┤\n", - "│ linear2 │ Linear │ │ b: \u001b[2mfloat32\u001b[0m[5,10] │ │\n", - "│ │ │ │ w: \u001b[2mfloat32\u001b[0m[5,32,10] │ │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ \u001b[1m1,650 \u001b[0m\u001b[1;2m(6.6 KB)\u001b[0m │ │\n", - "├──────────────────────┼───────────┼─────────────────────┼──────────────────────┼──────────────────────┤\n", - "│\u001b[1m \u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m Total\u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m320 \u001b[0m\u001b[1;2m(1.3 KB)\u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m3,730 \u001b[0m\u001b[1;2m(14.9 KB)\u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m10 \u001b[0m\u001b[1;2m(60 B)\u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\n", - "└──────────────────────┴───────────┴─────────────────────┴──────────────────────┴──────────────────────┘\n", - "\u001b[1m \u001b[0m\n", - "\u001b[1m Total Parameters: 4,060 \u001b[0m\u001b[1;2m(16.3 KB)\u001b[0m\u001b[1m \u001b[0m\n" + "" ] }, "metadata": {}, "output_type": "display_data" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -528,7 +470,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -540,7 +482,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -589,7 +531,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -601,7 +543,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -613,7 +555,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -714,7 +656,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -726,7 +668,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -738,7 +680,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -750,7 +692,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -803,7 +745,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.11.9" } }, "nbformat": 4, diff --git a/examples/nnx_toy_examples/02_lifted_transforms.py b/examples/nnx_toy_examples/02_lifted_transforms.py index 9fef3adf26..f6d7455601 100644 --- a/examples/nnx_toy_examples/02_lifted_transforms.py +++ b/examples/nnx_toy_examples/02_lifted_transforms.py @@ -82,13 +82,15 @@ def test_step(model: MLP, batch): loss = jnp.mean((y - y_pred) ** 2) return {'loss': loss} +cached_train_step = nnx.cache_args(train_step, model, optimizer) +cached_test_step = nnx.cache_args(test_step, model) total_steps = 10_000 for step, batch in enumerate(dataset(32)): - train_step(model, optimizer, batch) + cached_train_step(batch) if step % 1000 == 0: - logs = test_step(model, (X, Y)) + logs = cached_test_step((X, Y)) print(f"step: {step}, loss: {logs['loss']}") if step >= total_steps - 1: diff --git a/flax/configurations.py b/flax/configurations.py index ba19a572fc..5e1a492fcf 100644 --- a/flax/configurations.py +++ b/flax/configurations.py @@ -22,6 +22,7 @@ class Config: + flax_use_flaxlib: bool # See https://google.github.io/pytype/faq.html. _HAS_DYNAMIC_ATTRIBUTES = True @@ -62,6 +63,10 @@ def update(self, name_or_holder, value, /): raise LookupError(f'Unrecognized config option: {name}') self._values[name] = value + def __repr__(self): + values_repr = ', '.join(f'\n {k}={v!r}' for k, v in self._values.items()) + return f'Config({values_repr}\n)' + config = Config() @@ -201,3 +206,9 @@ def temp_flip_flag(var_name: str, var_value: bool): ' PRNG keys.' ), ) + +flax_use_flaxlib = bool_flag( + name='flax_use_flaxlib', + default=False, + help='Whether to use flaxlib for C++ acceleration.', +) \ No newline at end of file diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index fcb15f0608..1c0c19a46f 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -56,6 +56,7 @@ from .graph import MergeContext as MergeContext from .graph import merge_context as merge_context from .graph import variables as variables +from .graph import cache_args as cache_args from .nn import initializers as initializers from .nn.activations import celu as celu from .nn.activations import elu as elu diff --git a/flax/nnx/bridge/variables.py b/flax/nnx/bridge/variables.py index 121bb98eb8..da83cd545e 100644 --- a/flax/nnx/bridge/variables.py +++ b/flax/nnx/bridge/variables.py @@ -21,7 +21,6 @@ from flax.nnx import spmd from flax.nnx import traversals from flax.nnx import variablelib as variableslib -from flax.nnx.module import GraphDef import typing as tp @@ -193,12 +192,7 @@ def linen_vars_to_nnx_attrs(variables: tp.Mapping[str, Any]) -> dict[str, Any]: def nnx_attrs_to_linen_vars(nnx_attrs: dict) -> dict: """Convert a dict of NNX variables (or variable states) to Linen-style variables.""" linen_structured = {} - for kp, v in traversals.flatten_mapping( - nnx_attrs, - is_leaf=lambda _, x: isinstance( - x, variableslib.Variable | variableslib.VariableState | GraphDef - ), - ).items(): + for kp, v in traversals.flatten_mapping(nnx_attrs).items(): if isinstance(v, variableslib.Variable): col_name = variable_type_name(type(v)) v = to_linen_var(v.to_state()) diff --git a/flax/nnx/extract.py b/flax/nnx/extract.py index 191a0c195a..364177b5f5 100644 --- a/flax/nnx/extract.py +++ b/flax/nnx/extract.py @@ -13,9 +13,6 @@ # limitations under the License. import abc -import contextlib -import dataclasses -import threading import typing as tp import jax @@ -67,7 +64,7 @@ def extract_graph_nodes( | tuple[A, tuple[tp.Any, ...], tuple[tp.Any, ...]] ): """Extracts all graph nodes from a pytree.""" - nodes = graph.RefMap[tp.Any, Index]() + nodes: dict[tp.Any, Index] = {} node_prefixes = [] leaves = [] @@ -134,11 +131,10 @@ def check_consistent_aliasing( prefix: tuple[tp.Any, ...], /, *, - node_prefixes: graph.RefMap[tp.Any, list[tuple[PathParts, tp.Any]]] - | None = None, + node_prefixes: dict[tp.Any, list[tuple[PathParts, tp.Any]]] | None = None, ): if node_prefixes is None: - node_prefixes = graph.RefMap() + node_prefixes = {} # collect all paths and prefixes for each node for path, value in graph.iter_graph(node): @@ -181,50 +177,6 @@ def check_consistent_aliasing( + '\n'.join(node_msgs) ) - -# ----------------------------- -# broadcast -# ----------------------------- - - -@dataclasses.dataclass -class BroadcastContext(threading.local): - broadcast_state_stacks: dict[str, list[tp.Any]] = dataclasses.field( - default_factory=dict - ) - - -BROADCAST_CONTEXT = BroadcastContext() - - -@contextlib.contextmanager -def broadcast_state(tag: str, state: tp.Any): - if tag in BROADCAST_CONTEXT.broadcast_state_stacks: - stack = BROADCAST_CONTEXT.broadcast_state_stacks[tag] - else: - stack = BROADCAST_CONTEXT.broadcast_state_stacks[tag] = [] - stack.append(state) - try: - yield - finally: - stack.pop() - if not stack: - del BROADCAST_CONTEXT.broadcast_state_stacks[tag] - - -def get_broadcast_state(tag: str) -> tp.Any: - if tag not in BROADCAST_CONTEXT.broadcast_state_stacks: - raise ValueError(f'No broadcast state found for {tag!r}') - - stack = BROADCAST_CONTEXT.broadcast_state_stacks[tag] - - if not stack: - raise RuntimeError( - f'Empty broadcast state stack for {tag!r}, this is a bug' - ) - - return stack[-1] - # ----------------------------- # to_tree/from_tree # ----------------------------- @@ -251,10 +203,13 @@ class GraphDefState(struct.PyTreeNode): graphdef: graph.GraphDef[tp.Any] = struct.field(pytree_node=False) state: graph.GraphState = struct.field(pytree_node=True) +S = tp.TypeVar( + 'S', bound=graph.GraphState | graph.GraphFlatState | list[tp.Any] +) -class NodeStates(struct.PyTreeNode): +class NodeStates(struct.PyTreeNode, tp.Generic[S]): _graphdef: graph.GraphDef[tp.Any] | None - states: tuple[graph.GraphState, ...] + states: tuple[S, ...] metadata: tp.Any = struct.field(pytree_node=False) @property @@ -264,7 +219,7 @@ def graphdef(self) -> graph.GraphDef[tp.Any]: return self._graphdef @property - def state(self) -> graph.GraphState: + def state(self) -> S: if len(self.states) != 1: raise ValueError( f'Expected exactly one GraphDefState, got {len(self.states)}' @@ -275,15 +230,19 @@ def state(self) -> graph.GraphState: def from_split( cls, graphdef: graph.GraphDef[tp.Any], - state: graph.GraphState, + state: S, /, - *states: graph.GraphState, + *states: S, metadata: tp.Any = None, ): return cls(_graphdef=graphdef, states=(state, *states), metadata=metadata) @classmethod - def from_states(cls, state: graph.GraphState, *states: graph.GraphState): + def from_states( + cls, + state: S, + *states: S, + ): return cls(_graphdef=None, states=(state, *states), metadata=None) @classmethod @@ -312,9 +271,18 @@ def to_tree( [graph.SplitContext, KeyPath, Prefix, Leaf], tp.Any ] = default_split_fn, map_non_graph_nodes: bool = False, - ctxtag: str | None = None, + ctxtag: tp.Hashable | None = None, check_aliasing: bool = True, ) -> tp.Any: + if prefix is Missing or prefix is None: + # fast path, no need for prefix broadcasting or consistent aliasing checks + with graph.split_context(ctxtag) as split_ctx: + return jax.tree.map( + lambda x: split_fn(split_ctx, (), prefix, x) + if map_non_graph_nodes or graph.is_graph_node(x) + else x, + tree, + ) leaf_prefixes = broadcast_prefix( prefix, tree, @@ -324,7 +292,7 @@ def to_tree( assert len(leaf_keys) == len(leaf_prefixes) leaves_out = [] - node_prefixes = graph.RefMap[tp.Any, list[tuple[PathParts, tp.Any]]]() + node_prefixes: dict[tp.Any, list[tuple[PathParts, tp.Any]]] = {} with graph.split_context(ctxtag) as split_ctx: for (keypath, leaf), leaf_prefix in zip(leaf_keys, leaf_prefixes): @@ -367,8 +335,19 @@ def from_tree( is_node_leaf: tp.Callable[[Leaf], bool] = is_tree_node, is_leaf: tp.Callable[[Leaf], bool] = is_tree_node, map_non_graph_nodes: bool = False, - ctxtag: str | None = None, + is_inner: bool | None = None, + ctxtag: tp.Hashable | None = None, ) -> tp.Any: + if prefix is Missing or prefix is None: + # fast path, no need for prefix broadcasting or consistent aliasing checks + with graph.merge_context(is_inner, ctxtag) as merge_ctx: + return jax.tree.map( + lambda x: merge_fn(merge_ctx, (), prefix, x) + if map_non_graph_nodes or is_node_leaf(x) + else x, + tree, + is_leaf=is_leaf, + ) leaf_prefixes = broadcast_prefix( prefix, tree, @@ -381,15 +360,11 @@ def from_tree( assert len(leaf_keys) == len(leaf_prefixes) leaves_out = [] - with graph.merge_context(ctxtag) as merge_ctx: + with graph.merge_context(is_inner, ctxtag) as merge_ctx: for (keypath, leaf), leaf_prefix in zip(leaf_keys, leaf_prefixes): - if is_node_leaf(leaf): - leaf_out = merge_fn(merge_ctx, keypath, leaf_prefix, leaf) - leaves_out.append(leaf_out) - else: - if map_non_graph_nodes: - leaf = merge_fn(merge_ctx, keypath, leaf_prefix, leaf) - leaves_out.append(leaf) + if map_non_graph_nodes or is_node_leaf(leaf): + leaf = merge_fn(merge_ctx, keypath, leaf_prefix, leaf) + leaves_out.append(leaf) pytree_out = jax.tree.unflatten(treedef, leaves_out) return pytree_out diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index 8cc272f8eb..b1137d86a8 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -14,23 +14,26 @@ from __future__ import annotations +from collections import deque import contextlib import dataclasses import functools import threading import typing as tp +from weakref import WeakKeyDictionary +from flax import config import jax import numpy as np import typing_extensions as tpe -from flax.nnx import filterlib, reprlib, visualization +from flax.nnx import filterlib, reprlib from flax.nnx.proxy_caller import ( ApplyCaller, CallableProxy, DelayedAccessor, ) -from flax.nnx.statelib import State +from flax.nnx.statelib import FlatState, State from flax.nnx import variablelib from flax.nnx.variablelib import Variable, VariableState from flax.typing import Key, PathParts, is_key_like @@ -53,6 +56,7 @@ StateLeaf = VariableState[tp.Any] NodeLeaf = Variable[tp.Any] GraphState = State[Key, StateLeaf] +GraphFlatState = FlatState[StateLeaf] def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]: @@ -62,37 +66,12 @@ def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]: def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[NodeLeaf]: return isinstance(x, Variable) +RefMap = dict -class RefMap(tp.MutableMapping[A, B], reprlib.MappingReprMixin): - """A mapping that uses object id as the hash for the keys.""" - - def __init__( - self, mapping: tp.Mapping[A, B] | tp.Iterable[tuple[A, B]] = (), / - ): - self._mapping: dict[int, tuple[A, B]] = {} - self.update(mapping) - - def __getitem__(self, key: A) -> B: - return self._mapping[id(key)][1] - - def __contains__(self, key: object) -> bool: - return id(key) in self._mapping - - def __setitem__(self, key: A, value: B): - self._mapping[id(key)] = (key, value) - - def __delitem__(self, key: A): - del self._mapping[id(key)] - - def __iter__(self) -> tp.Iterator[A]: - return (key for key, _ in self._mapping.values()) - - def __len__(self) -> int: - return len(self._mapping) - - def __str__(self) -> str: - return repr(self) +if not tp.TYPE_CHECKING and config.flax_use_flaxlib: + import flaxlib + RefMap = flaxlib.RefMap @dataclasses.dataclass(frozen=True, slots=True) class NodeImplBase(tp.Generic[Node, Leaf, AuxData]): @@ -175,9 +154,9 @@ def is_node_type(x: type[tp.Any]) -> bool: return x in GRAPH_REGISTRY or x in PYTREE_REGISTRY or x is GenericPytree -def get_node_impl(x: Node) -> NodeImpl[Node, tp.Any, tp.Any]: +def get_node_impl(x: Node) -> NodeImpl[Node, tp.Any, tp.Any] | None: if isinstance(x, Variable): - raise ValueError(f'Variable is not a node: {x}') + return None node_type = type(x) @@ -185,19 +164,23 @@ def get_node_impl(x: Node) -> NodeImpl[Node, tp.Any, tp.Any]: return GRAPH_REGISTRY[node_type] elif node_type in PYTREE_REGISTRY: return PYTREE_REGISTRY[node_type] - elif is_pytree_node(x): + elif node_type in JAX_PYTREE_REGISTRY or issubclass(node_type, tuple): return PYTREE_NODE_IMPL # type: ignore else: - raise ValueError(f'Unknown node type: {x}') + return None -def get_node_impl_for_type(x: type[Node]) -> NodeImpl[Node, tp.Any, tp.Any]: +def get_node_impl_for_type( + x: type[Node], +) -> NodeImpl[Node, tp.Any, tp.Any] | None: if x is GenericPytree: return PYTREE_NODE_IMPL # type: ignore elif x in PYTREE_REGISTRY: return PYTREE_REGISTRY[x] - else: + elif x in GRAPH_REGISTRY: return GRAPH_REGISTRY[x] + else: + return None class HashableMapping(tp.Mapping[HA, HB], tp.Hashable): @@ -228,17 +211,8 @@ def __repr__(self) -> str: return repr(self._mapping) -class GraphDef(tp.Generic[Node]): - """A class that represents all the static, stateless, and Pythonic parts of a Flax - :class:`Module`. A ``GraphDef`` can be generated by either calling :func:`split` or - :func:`graphdef` on the :class:`Module`.""" - - type: type[Node] - index: int - - @dataclasses.dataclass(frozen=True, repr=False) -class NodeRef(GraphDef[Node], reprlib.Representable): +class NodeRef(tp.Generic[Node], reprlib.Representable): type: type[Node] index: int @@ -248,7 +222,8 @@ def __nnx_repr__(self): yield reprlib.Attr('index', self.index) def __treescope_repr__(self, path, subtree_renderer): - return visualization.render_object_constructor( + import treescope # type: ignore[import-not-found,import-untyped] + return treescope.repr_lib.render_object_constructor( object_type=type(self), attributes={'type': self.type, 'index': self.index}, path=path, @@ -262,16 +237,33 @@ def __treescope_repr__(self, path, subtree_renderer): class VariableDef(reprlib.Representable): type: type[Variable] index: int + outer_index: int | None metadata: HashableMapping[str, tp.Any] + def with_no_outer_index(self) -> VariableDef: + return VariableDef( + type=self.type, index=self.index, outer_index=None, metadata=self.metadata + ) + + def with_same_outer_index(self) -> VariableDef: + return VariableDef( + type=self.type, + index=self.index, + outer_index=self.index, + metadata=self.metadata, + ) + def __nnx_repr__(self): yield reprlib.Object(type=type(self)) yield reprlib.Attr('type', self.type.__name__) yield reprlib.Attr('index', self.index) + yield reprlib.Attr('outer_index', self.outer_index) yield reprlib.Attr('metadata', reprlib.PrettyMapping(self.metadata)) def __treescope_repr__(self, path, subtree_renderer): - return visualization.render_object_constructor( + import treescope # type: ignore[import-not-found,import-untyped] + + return treescope.repr_lib.render_object_constructor( object_type=type(self), attributes={ 'type': self.type, @@ -286,71 +278,74 @@ def __treescope_repr__(self, path, subtree_renderer): jax.tree_util.register_static(VariableDef) -@dataclasses.dataclass(frozen=True, slots=True) -class SubGraphAttribute: - key: Key - value: NodeDef[tp.Any] | NodeRef[tp.Any] - - -@dataclasses.dataclass(frozen=True, slots=True) -class StaticAttribute: - key: Key - value: tp.Any - - -@dataclasses.dataclass(frozen=True, slots=True) -class LeafAttribute: - key: Key - value: VariableDef | NodeRef[tp.Any] - - @dataclasses.dataclass(frozen=True, repr=False, slots=True) -class NodeDef(GraphDef[Node], reprlib.Representable): +class NodeDef(tp.Generic[Node], reprlib.Representable): """A dataclass that denotes the tree structure of a :class:`Module`. A ``GraphDef`` can be generated by either calling :func:`split` or :func:`graphdef` on the :class:`Module`.""" type: tp.Type[Node] index: int - attributes: tuple[SubGraphAttribute | StaticAttribute | LeafAttribute, ...] + outer_index: int | None + attributes: tuple[ + tuple[ + Key, NodeDef[tp.Any] | VariableDef | NodeRef[tp.Any] | Static[tp.Any] + ], + ..., + ] metadata: tp.Any - index_mapping: HashableMapping[Index, Index] | None - @classmethod - def create( - cls, - type: tp.Type[Node], - index: int, - attributes: tuple[SubGraphAttribute | StaticAttribute | LeafAttribute, ...], - metadata: tp.Any, - index_mapping: tp.Mapping[Index, Index] | None, - ): - return cls( - type=type, - index=index, + def with_no_outer_index(self) -> NodeDef[Node]: + attributes = tuple( + ( + key, + value.with_no_outer_index() + if isinstance(value, NodeDef | VariableDef) + else value, + ) + for key, value in self.attributes + ) + return NodeDef( + type=self.type, + index=self.index, + outer_index=None, attributes=attributes, - metadata=metadata, - index_mapping=HashableMapping(index_mapping) - if index_mapping is not None - else None, + metadata=self.metadata, ) + def with_same_outer_index(self) -> NodeDef[Node]: + attributes = tuple( + ( + key, + value.with_same_outer_index() + if isinstance(value, NodeDef | VariableDef) + else value, + ) + for key, value in self.attributes + ) + return NodeDef( + type=self.type, + index=self.index, + outer_index=self.index if self.index >= 0 else None, + attributes=attributes, + metadata=self.metadata, + ) + + def replace(self, **kwargs): + return dataclasses.replace(self, **kwargs) + def __nnx_repr__(self): yield reprlib.Object(type=type(self)) yield reprlib.Attr('type', self.type.__name__) yield reprlib.Attr('index', self.index) - yield reprlib.Attr('attributes', reprlib.PrettySequence(self.attributes)) + yield reprlib.Attr('outer_index', self.outer_index) + yield reprlib.Attr('attributes', self.attributes) yield reprlib.Attr('metadata', self.metadata) - yield reprlib.Attr( - 'index_mapping', - reprlib.PrettyMapping(self.index_mapping) - if self.index_mapping is not None - else None, - ) def __treescope_repr__(self, path, subtree_renderer): - return visualization.render_object_constructor( + import treescope # type: ignore[import-not-found,import-untyped] + return treescope.repr_lib.render_object_constructor( object_type=type(self), attributes={ 'type': self.type, @@ -373,19 +368,89 @@ def _apply( module = merge(self, state, *states) fn = accessor(module) out = fn(*args, **kwargs) - return out, flatten(module) + graphdef, flat_state = flatten(module) + state_ = State.from_flat_path(flat_state) + return out, (graphdef, state_) return CallableProxy(_apply, accessor) # type: ignore jax.tree_util.register_static(NodeDef) -PureState = tuple[GraphDef[A], GraphState] +GraphDef = tp.Union[NodeDef[Node], NodeRef[Node]] +PureState = tuple[GraphDef[Node], GraphState] +@tp.overload +def flatten( + node: Node, + /, + *, + ref_index: RefMap | None = None, + ref_outer_index: RefMap | None = None, +) -> tuple[GraphDef[Node], FlatState[VariableState[tp.Any]]]: ... +@tp.overload +def flatten( + node: Node, + /, + *, + with_paths: tp.Literal[True], + return_variables: tp.Literal[True], + ref_index: RefMap | None = None, + ref_outer_index: RefMap | None = None, +) -> tuple[ + GraphDef[Node], + FlatState[Variable[tp.Any]], +]: ... +@tp.overload def flatten( - node: Node, /, ref_index: RefMap[tp.Any, Index] | None = None -) -> tuple[GraphDef[Node], GraphState]: + node: Node, + /, + *, + with_paths: tp.Literal[False], + return_variables: tp.Literal[True], + ref_index: RefMap | None = None, + ref_outer_index: RefMap | None = None, +) -> tuple[ + GraphDef[Node], + list[Variable[tp.Any]], +]: ... +@tp.overload +def flatten( + node: Node, + /, + *, + return_variables: tp.Literal[True], + ref_index: RefMap | None = None, + ref_outer_index: RefMap | None = None, +) -> tuple[ + GraphDef[Node], + FlatState[Variable[tp.Any]], +]: ... +@tp.overload +def flatten( + node: Node, + /, + *, + with_paths: bool, + ref_index: RefMap | None = None, + ref_outer_index: RefMap | None = None, +) -> tuple[ + GraphDef[Node], + FlatState[VariableState[tp.Any]] | list[tp.Any], +]: ... +def flatten( + node: Node, + /, + *, + with_paths: bool = True, + return_variables: bool = False, + ref_index: RefMap | None = None, + ref_outer_index: RefMap | None = None, +) -> tuple[ + GraphDef[Node], + FlatState[VariableState[tp.Any]] | FlatState[Variable[tp.Any]] | list[tp.Any], +]: """Flattens a graph node into a (graphdef, state) pair. Args: @@ -393,81 +458,355 @@ def flatten( ref_index: A mapping from nodes to indexes, defaults to None. If not provided, a new empty dictionary is created. This argument can be used to flatten a sequence of graph nodes that share references. + with_paths: A boolean that indicates whether to return a FlatState object that includes + the paths to VariableState objects, or just a list of the Variable's inner values. """ if ref_index is None: ref_index = RefMap() - flat_state: list[tuple[PathParts, StateLeaf]] = [] - graphdef = _graph_flatten((), ref_index, flat_state, node) - return graphdef, GraphState.from_flat_path(flat_state) + + leaves: list[StateLeaf | Variable[tp.Any]] = [] + path: list[Key] | None = [] if with_paths else None + paths: list[PathParts] | None = [] if with_paths else None + node_impl = get_node_impl(node) + if node_impl is None: + raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.') + graphdef = _graph_flatten( + node, + node_impl, + path, + ref_index, + ref_outer_index, + leaves, + paths, + return_variables, + ) + + if paths is not None: + return graphdef, FlatState.from_sorted_keys_values(tuple(paths), leaves) + else: + return graphdef, leaves def _graph_flatten( - path: PathParts, - ref_index: RefMap[tp.Any, Index], - flat_state: list[tuple[PathParts, StateLeaf]], node: Node, + node_impl: NodeImpl[Node, Leaf, AuxData], + path: list[Key] | None, + ref_index: RefMap, + ref_outer_index: RefMap | None, + leaves: list[StateLeaf | Variable[tp.Any]], + paths: list[PathParts] | None, + return_variables: bool, ) -> NodeDef[Node] | NodeRef: - if not is_node(node): - raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.') + is_pytree_node_ = isinstance(node_impl, PytreeNodeImpl) + is_graph_node_ = isinstance(node_impl, GraphNodeImpl) - if node in ref_index: + if not is_pytree_node_ and node in ref_index: return NodeRef(type(node), ref_index[node]) - node_impl = get_node_impl(node) - # only cache graph nodes - if isinstance(node_impl, GraphNodeImpl): + if is_graph_node_: index = len(ref_index) ref_index[node] = index else: index = -1 - attributes: list[SubGraphAttribute | StaticAttribute | LeafAttribute] = [] + attributes: list[ + tuple[Key, Static[tp.Any] | NodeDef[tp.Any] | VariableDef | NodeRef[tp.Any]] + ] = [] values, metadata = node_impl.flatten(node) for key, value in values: - if is_node(value): - nodedef = _graph_flatten((*path, key), ref_index, flat_state, value) - # subgraphs.append((key, nodedef)) - attributes.append(SubGraphAttribute(key, nodedef)) + value_node_impl = get_node_impl(value) + if path is not None: + path.append(key) + if value_node_impl is not None: + nodedef = _graph_flatten( + value, + value_node_impl, + path, + ref_index, + ref_outer_index, + leaves, + paths, + return_variables, + ) + attributes.append((key, nodedef)) elif isinstance(value, Variable): if value in ref_index: - attributes.append( - LeafAttribute(key, NodeRef(type(value), ref_index[value])) - ) + attributes.append((key, NodeRef(type(value), ref_index[value]))) else: - flat_state.append(((*path, key), value.to_state())) + if return_variables: + leaf = value + elif path is None: + leaf = value.raw_value + else: + leaf = value.to_state() + leaves.append(leaf) + if path is not None: + assert paths is not None + paths.append(tuple(path)) variable_index = ref_index[value] = len(ref_index) variabledef = VariableDef( - type(value), variable_index, HashableMapping(value._var_metadata) + type=type(value), + index=variable_index, + outer_index=ref_outer_index.get(value, None) + if ref_outer_index + else None, + metadata=HashableMapping(value._var_metadata), ) - attributes.append(LeafAttribute(key, variabledef)) + attributes.append((key, variabledef)) else: if isinstance(value, (jax.Array, np.ndarray)): - path_str = '/'.join(map(str, (*path, key))) - raise ValueError( + if path is not None: + path_str = '/'.join(map(str, path)) + raise ValueError( f'Arrays leaves are not supported, at {path_str!r}: {value}' - ) + ) + else: + raise ValueError(f'Arrays leaves are not supported, found {value}') # static_fields.append((key, value)) - attributes.append(StaticAttribute(key, value)) + attributes.append((key, Static(value))) + + if path is not None: + path.pop() - nodedef = NodeDef.create( + nodedef = NodeDef( type=node_impl.type, index=index, + outer_index=ref_outer_index[node] + if is_graph_node_ and ref_outer_index and node in ref_outer_index + else None, attributes=tuple(attributes), metadata=metadata, - index_mapping=None, ) return nodedef +@dataclasses.dataclass(slots=True) +class FingerprintContext: + next_index: int + +def fingerprint( + node, + /, + *, + ref_index: RefMap | None = None, + new_ref_index: RefMap | None = None, +) -> list[tp.Hashable]: + """ """ + if ref_index is None: + ref_index = RefMap() + + if new_ref_index is None: + new_ref_index = RefMap() + node_impl = get_node_impl(node) + if node_impl is None: + raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.') + ctx = FingerprintContext(len(ref_index) + len(new_ref_index)) + fp: list[tp.Hashable] = [] + _graph_fingerprint(ctx, fp.append, node, node_impl, ref_index, new_ref_index) + return fp + + +def _graph_fingerprint( + ctx: FingerprintContext, + append_fn: tp.Callable[[tp.Hashable], None], + node, + node_impl: NodeImpl[Node, Leaf, AuxData], + ref_index: RefMap, + new_ref_index: RefMap, +): + is_pytree_node_ = type(node_impl) is PytreeNodeImpl + is_graph_node_ = type(node_impl) is GraphNodeImpl + + append_fn(type(node)) + + if is_graph_node_: + append_fn(id(node)) + if node in ref_index: + append_fn(ref_index[node]) + return + elif node in new_ref_index: + append_fn(new_ref_index[node]) + return + index = new_ref_index[node] = ctx.next_index + ctx.next_index += 1 + else: + index = -1 + + values, metadata = node_impl.flatten(node) + + append_fn(index) + append_fn(metadata) + + for key, value in values: + value_node_impl = get_node_impl(value) + append_fn(key) + if value_node_impl is not None: + _graph_fingerprint( + ctx, + append_fn, + value, + value_node_impl, + ref_index, + new_ref_index, + ) + elif isinstance(value, Variable): + append_fn(id(value)) + append_fn(type(value)) + if value in ref_index: + append_fn(ref_index[value]) + elif value in new_ref_index: + append_fn(new_ref_index[value]) + else: + variable_index = new_ref_index[value] = ctx.next_index + ctx.next_index += 1 + append_fn(variable_index) + for key_value in value._var_metadata.items(): + append_fn(key_value) + else: + if isinstance(value, (jax.Array, np.ndarray)): + raise ValueError(f'Arrays leaves are not supported: {value}') + append_fn(value) + +def check_fingerprint( + node, + fp: list[tp.Hashable], + /, + *, + ref_index: RefMap | None = None, + new_ref_index: RefMap | None = None, +) -> bool: + """ """ + if ref_index is None: + ref_index = RefMap() + + if new_ref_index is None: + new_ref_index = RefMap() + node_impl = get_node_impl(node) + if node_impl is None: + raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.') + ctx = FingerprintContext(len(ref_index) + len(new_ref_index)) + fp_matches = _check_graph_fingerprint( + ctx, iter(fp), node, node_impl, ref_index, new_ref_index + ) + return fp_matches + + +def _check_graph_fingerprint( + ctx: FingerprintContext, + fp_iterator: tp.Iterator[tp.Hashable], + node, + node_impl: NodeImpl[Node, Leaf, AuxData], + ref_index: RefMap, + new_ref_index: RefMap, +) -> bool: + is_pytree_node_ = type(node_impl) is PytreeNodeImpl + is_graph_node_ = type(node_impl) is GraphNodeImpl + + if type(node) != next(fp_iterator): + return False + + if is_graph_node_: + # append_fn(id(node)) + if id(node) != next(fp_iterator): + return False + if node in ref_index: + # append_fn(ref_index[node]) + return ref_index[node] == next(fp_iterator) + elif node in new_ref_index: + # append_fn(new_ref_index[node]) + return new_ref_index[node] == next(fp_iterator) + index = new_ref_index[node] = ctx.next_index + ctx.next_index += 1 + else: + index = -1 + + values, metadata = node_impl.flatten(node) + + # append_fn(index) + if index != next(fp_iterator): + return False + # append_fn(metadata) + if metadata != next(fp_iterator): + return False + + for key, value in values: + value_node_impl = get_node_impl(value) + # append_fn(key) + if key != next(fp_iterator): + return False + if value_node_impl is not None: + if not _check_graph_fingerprint( + ctx, + fp_iterator, + value, + value_node_impl, + ref_index, + new_ref_index, + ): + return False + elif isinstance(value, Variable): + # append_fn(id(value)) + if id(value) != next(fp_iterator): + return False + # append_fn(type(value)) + if type(value) != next(fp_iterator): + return False + if value in ref_index: + # append_fn(ref_index[value]) + if ref_index[value] != next(fp_iterator): + return False + elif value in new_ref_index: + # append_fn(new_ref_index[value]) + if new_ref_index[value] != next(fp_iterator): + return False + else: + variable_index = new_ref_index[value] = ctx.next_index + ctx.next_index += 1 + # append_fn(variable_index) + if variable_index != next(fp_iterator): + return False + for key_value in value._var_metadata.items(): + # append_fn(key_value) + if key_value != next(fp_iterator): + return False + else: + if isinstance(value, (jax.Array, np.ndarray)): + raise ValueError(f'Arrays leaves are not supported: {value}') + # append_fn(value) + if value != next(fp_iterator): + return False + + return True + + +def _get_sorted_leaves( + xs: tp.Mapping[tp.Any, tp.Any], +) -> deque[tp.Any]: + if not isinstance(xs, tp.Mapping): # type: ignore + raise TypeError(f'expected Mapping; got {type(xs).__qualname__}') + leaves = deque() + + def _flatten(xs): + if not isinstance(xs, tp.Mapping): + leaves.append(xs) + else: + for _, value in sorted(xs.items()): + _flatten(value) + + _flatten(xs) + return leaves + def unflatten( graphdef: GraphDef[Node], - state: tp.Mapping[KeyT, StateLeaf | tp.Mapping[Key, tp.Any]], + state: State[KeyT, tp.Any | dict[KeyT, tp.Any]] + | FlatState[tp.Any] + | list[tp.Any], /, *, index_ref: dict[Index, tp.Any] | None = None, - index_ref_cache: dict[Index, tp.Any] | None = None, + outer_index_outer_ref: dict[Index, tp.Any] | None = None, ) -> Node: """Unflattens a graphdef into a node with the given state. @@ -484,19 +823,41 @@ def unflatten( existing graph nodes are mutated to have the new content/topology specified by the graphdef. """ - if isinstance(state, State): - state = state.raw_mapping # type: ignore + if isinstance(state, (State, dict)): + leaves = _get_sorted_leaves(state) + elif isinstance(state, FlatState): + leaves = deque(state.leaves) + elif isinstance(state, list): # type: ignore + leaves = deque(state) + else: + raise ValueError(f'Unsupported state type: {type(state)}') if index_ref is None: index_ref = {} - assert isinstance(graphdef, (NodeDef, NodeRef)) - node = _graph_unflatten(graphdef, state, index_ref, index_ref_cache) + + if isinstance(graphdef, NodeRef): + node = index_ref[graphdef.index] + else: + assert isinstance(graphdef, NodeDef) + node_impl = get_node_impl_for_type(graphdef.type) + if node_impl is None: + raise RuntimeError(f'Unsupported type: {graphdef.type}, this is a bug.') + node = _graph_unflatten( + graphdef, node_impl, leaves, index_ref, outer_index_outer_ref + ) + if leaves: + raise ValueError( + f'Incorrect number of leaves: got an extra {len(leaves)} leaves in the state' + ) + return node + def _graph_unflatten( nodedef: NodeDef[Node] | NodeRef[Node], - state: tp.Mapping[KeyT, StateLeaf | tp.Mapping[Key, tp.Any]], + node_impl: NodeImpl[Node, Leaf, AuxData], + leaves: deque[tp.Any], index_ref: dict[Index, tp.Any], - index_ref_cache: dict[Index, tp.Any] | None, + outer_index_outer_ref: dict[Index, tp.Any] | None, ) -> Node: """Recursive helper for graph_unflatten. @@ -511,134 +872,82 @@ def _graph_unflatten( existing graph nodes are mutated to have the new content/topology specified by the nodedef. """ - if isinstance(nodedef, NodeRef): + if type(nodedef) is NodeRef: return index_ref[nodedef.index] - if not is_node_type(nodedef.type): - raise RuntimeError(f'Unsupported type: {nodedef.type}, this is a bug.') - if nodedef.index in index_ref: raise RuntimeError(f'GraphDef index {nodedef.index} already used.') - node_impl = get_node_impl_for_type(nodedef.type) - def _get_children(): children: list[tuple[Key, NodeLeaf | Node]] = [] - state_keys: set = set(state.keys()) - - # for every key in attributes there are 6 possible cases: - # - (2) the key can either be present in the state or not - # - (3) the key can be a subgraph, a leaf, or a static attribute - for attribute in nodedef.attributes: - key = attribute.key - if key not in state: - # if key is not present create an empty types - if type(attribute) is StaticAttribute: - children.append((key, attribute.value)) - elif type(attribute) is SubGraphAttribute: - # if the key is a subgraph we create an empty node - subgraphdef = attribute.value - assert not isinstance(subgraphdef, VariableDef) - if isinstance(subgraphdef, NodeRef): - # subgraph exists, take it from the cache - children.append((key, index_ref[subgraphdef.index])) - else: - # create a node from an empty state, reasoning: - # * its a node with no state - # * its a node with state but only through references of already - # created nodes - substate = {} - subnode = _graph_unflatten( - subgraphdef, substate, index_ref, index_ref_cache - ) - children.append((key, subnode)) - elif type(attribute) is LeafAttribute: - variabledef = attribute.value - if variabledef.index in index_ref: - # variable exists, take it from the cache - children.append((key, index_ref[variabledef.index])) - else: - # key for a variable is missing, raise an error + + assert type(nodedef) is NodeDef + for key, value in nodedef.attributes: + if type(value) is Static: + children.append((key, value.value)) + elif type(value) is NodeRef: + children.append((key, index_ref[value.index])) + elif type(value) is NodeDef: + # if the key is a subgraph we create an empty node + subgraphdef = value + value_node_impl = get_node_impl_for_type(subgraphdef.type) + assert value_node_impl is not None + subnode = _graph_unflatten( + subgraphdef, value_node_impl, leaves, index_ref, outer_index_outer_ref + ) + children.append((key, subnode)) + elif type(value) is VariableDef: + variabledef = value + if not leaves: + raise ValueError('Not enough leaves to unflatten the graph') + # its a unseen variable, create a new one + value = leaves.popleft() + # when idxmap is present, check if the Varable exists there + # and update existing variables if it does + if ( + outer_index_outer_ref is not None + and variabledef.outer_index in outer_index_outer_ref + ): + # if variable exists, update it + variable = outer_index_outer_ref[variabledef.outer_index] + if not isinstance(variable, Variable): raise ValueError( - f'Expected key {key!r} in state while building node of type ' - f'{nodedef.type.__name__}.' + f'Expected a Variable type for {key!r}, but got {type(variable)}.' ) - else: - raise RuntimeError(f'Unknown static field: {key!r}') - else: - state_keys.remove(key) - value = state[key] - # if key in nodedef.static_fields: - if type(attribute) is StaticAttribute: - raise ValueError( - f'Got state for static field {key!r}, this is not supported.' - ) - elif type(attribute) is SubGraphAttribute: - if is_state_leaf(value): + elif isinstance(value, Variable): raise ValueError( - f'Expected value of type {attribute.value} for ' - f'{key!r}, but got {value!r}' + f'Cannot unflatten flat_state containing Variables when using `outer_index_outer_ref`. ' + f'Got {value!r} for {key!r}.' ) - assert isinstance(value, dict) - subgraphdef = attribute.value - - if isinstance(subgraphdef, NodeRef): - children.append((key, index_ref[subgraphdef.index])) + elif isinstance(value, VariableState): + variable.update_from_state(value) else: - subnode = _graph_unflatten( - subgraphdef, value, index_ref, index_ref_cache - ) - children.append((key, subnode)) - - elif type(attribute) is LeafAttribute: - variabledef = attribute.value - - if variabledef.index in index_ref: - # add an existing variable - assert isinstance(variabledef, NodeRef) - children.append((key, index_ref[variabledef.index])) + variable.raw_value = value + else: # variabledef.index not in index_ref_cache + # variable reference does not exist outside, create a new one + if isinstance(value, Variable): + variable = value + elif isinstance(value, VariableState): + variable = value.to_variable() else: - # its a unseen variable, create a new one - assert isinstance(variabledef, VariableDef) - # when idxmap is present, check if the Varable exists there - # and update existing variables if it does - if ( - index_ref_cache is not None - and variabledef.index in index_ref_cache - ): - # if variable exists, update it - variable = index_ref_cache[variabledef.index] - if not isinstance(variable, Variable): - raise ValueError( - f'Expected a Variable type for {key!r}, but got {type(variable)}.' - ) - if isinstance(value, VariableState): - variable.update_from_state(value) - else: - variable.raw_value = value - else: # if it doesn't, create a new variable - if isinstance(value, VariableState): - variable = value.to_variable() - else: - variable = variabledef.type.from_metadata( - value, variabledef.metadata - ) - children.append((key, variable)) - index_ref[variabledef.index] = variable - else: - raise RuntimeError(f'Unknown key: {key!r}, this is a bug.') - - # NOTE: we could allw adding new StateLeafs here - if state_keys: - raise ValueError(f'Unknown keys: {state_keys}') + variable = variabledef.type.from_metadata( + value, variabledef.metadata + ) + children.append((key, variable)) + index_ref[variabledef.index] = variable + else: + raise RuntimeError(f'Unknown static field: {key!r}') return children if isinstance(node_impl, GraphNodeImpl): # we create an empty node first and add it to the index # this avoids infinite recursion when there is a reference cycle - if index_ref_cache is not None and nodedef.index in index_ref_cache: - node = index_ref_cache[nodedef.index] + if ( + outer_index_outer_ref is not None + and nodedef.outer_index in outer_index_outer_ref + ): + node = outer_index_outer_ref[nodedef.outer_index] if type(node) != nodedef.type: raise ValueError( f'Expected a node of type {nodedef.type} for index ' @@ -765,26 +1074,154 @@ def _graph_update_dynamic(node: tp.Any, state: tp.Mapping[KeyT, tp.Any]): # updated from raw value current_value.raw_value = value + # -------------------------------------------------------- # UpdateContext # -------------------------------------------------------- + +class DynamicCache(tp.NamedTuple): + fingerprint: list[tp.Hashable] + graphdef: GraphDef[tp.Any] + final_graphdef: GraphDef[tp.Any] + paths: tuple[PathParts, ...] + variables: list[Variable[tp.Any]] + new_index_ref: dict[Index, tp.Any] + + @staticmethod + def create( + fingerprint: list[tp.Hashable], + graphdef: GraphDef[tp.Any], + paths: tuple[PathParts, ...], + variables: list[Variable[tp.Any]], + new_ref_index: RefMap, + ): + new_index_ref = {index: obj for obj, index in new_ref_index.items()} + if type(graphdef) is NodeDef: + final_graphdef = graphdef.with_same_outer_index() + else: + final_graphdef = graphdef + return DynamicCache( + fingerprint=fingerprint, + graphdef=graphdef, + final_graphdef=final_graphdef, + paths=paths, + variables=variables, + new_index_ref=new_index_ref, + ) + +class StaticCache(tp.NamedTuple): + graphdef: GraphDef[tp.Any] + final_graphdef: GraphDef[tp.Any] + paths: tuple[PathParts, ...] + variables: list[Variable[tp.Any]] + new_ref_index: RefMap + new_index_ref: dict[Index, tp.Any] + + @staticmethod + def create( + graphdef: GraphDef[tp.Any], + paths: tuple[PathParts, ...], + variables: list[Variable[tp.Any]], + new_ref_index: RefMap, + ): + new_index_ref = {index: obj for obj, index in new_ref_index.items()} + if type(graphdef) is NodeDef: + final_graphdef = graphdef.with_same_outer_index() + else: + final_graphdef = graphdef + return StaticCache( + graphdef=graphdef, + final_graphdef=final_graphdef, + paths=paths, + variables=variables, + new_ref_index=new_ref_index, + new_index_ref=new_index_ref, + ) + @dataclasses.dataclass class GraphContext(threading.local): - update_context_stacks: dict[str, list[UpdateContext]] = dataclasses.field( - default_factory=dict + update_context_stacks: dict[tp.Hashable, list[UpdateContext]] = ( + dataclasses.field(default_factory=dict) ) ref_index_stack: list[SplitContext] = dataclasses.field(default_factory=list) index_ref_stack: list[MergeContext] = dataclasses.field(default_factory=list) + dynamic_cache_context: WeakKeyDictionary[ + tp.Hashable, WeakKeyDictionary[tp.Any, DynamicCache] + ] = dataclasses.field(default_factory=WeakKeyDictionary) + tmp_static_cache: WeakKeyDictionary[tp.Any, StaticCache] | None = None + caching: bool = False GRAPH_CONTEXT = GraphContext() +@contextlib.contextmanager +def static_cache(static_cache: WeakKeyDictionary[tp.Any, StaticCache]): + if GRAPH_CONTEXT.caching: + yield + return + + GRAPH_CONTEXT.tmp_static_cache = static_cache + + try: + yield + finally: + if GRAPH_CONTEXT.tmp_static_cache is not None: + raise ValueError( + 'GRAPH_CONTEXT.tmp_static_cache should be None, no context consumed it.' + ) + + +def _cache_args(f: tp.Callable[..., tp.Any], *cached_args): + cache: WeakKeyDictionary[tp.Any, StaticCache] = WeakKeyDictionary() + original_ref_index = RefMap() + index_ref: dict[Index, tp.Any] = {} + cached_ref_index = RefMap() + + def create_static_cache(x): + if is_graph_node(x): + graphdef, flat_state = flatten( + x, with_paths=True, return_variables=True, ref_index=original_ref_index + ) + paths = flat_state.paths + variables = flat_state.leaves + # clone but keep the same variable references + node_cache = unflatten(graphdef, flat_state, index_ref=index_ref) + cached_new_ref_index = RefMap() + _fp = fingerprint( + node_cache, + ref_index=cached_ref_index, + new_ref_index=cached_new_ref_index, + ) + cached_ref_index.update(cached_new_ref_index) + cache[node_cache] = StaticCache.create( + graphdef, paths, variables, cached_new_ref_index + ) + return node_cache + return x + + cached_args = jax.tree.map(create_static_cache, cached_args) + + @functools.wraps(f) + def cache_args_wrapper(*args, **kwargs): + with static_cache(cache): + return f(*cached_args, *args, **kwargs) + + return cache_args_wrapper + + +if tp.TYPE_CHECKING: + cache_args = functools.partial +else: + cache_args = _cache_args + + @dataclasses.dataclass class SplitContext: - ctxtag: str | None - ref_index: RefMap[tp.Any, Index] + ctxtag: tp.Hashable | None + ref_index: RefMap + is_inner: bool | None @tp.overload def split(self, graph_node: A, /) -> tuple[GraphDef[A], GraphState]: ... @@ -807,84 +1244,373 @@ def split( ctx = ( current_update_context(self.ctxtag) if self.ctxtag is not None else None ) - graphdef, state = flatten(node, self.ref_index) - states = _split_state(state, filters) - if ctx is not None: - if ctx.index_ref is not None and isinstance(graphdef, NodeDef): - index_to_index = compose_mapping(ctx.index_ref, self.ref_index) - graphdef = dataclasses.replace( - graphdef, index_mapping=HashableMapping(index_to_index, copy=False) - ) + inner_ref_outer_index = ctx and ctx.inner_ref_outer_index + graphdef, flat_state = flatten( + node, ref_index=self.ref_index, ref_outer_index=inner_ref_outer_index + ) + flat_states = _split_state(flat_state, filters) + states = tuple( + State.from_flat_path(flat_state) for flat_state in flat_states + ) return graphdef, *states + @tp.overload + def flatten( + self, + graph_node: A, + /, + *, + with_paths: tp.Literal[False], + ) -> tuple[GraphDef[A], list[tp.Any]]: ... + @tp.overload + def flatten( + self, + graph_node: A, + /, + ) -> tuple[GraphDef[A], FlatState[VariableState[tp.Any]]]: ... + @tp.overload + def flatten( + self, + graph_node: A, + first: filterlib.Filter, + /, + ) -> tuple[GraphDef[A], FlatState[VariableState[tp.Any]]]: ... + @tp.overload + def flatten( + self, + graph_node: A, + first: filterlib.Filter, + second: filterlib.Filter, + /, + *filters: filterlib.Filter, + ) -> tuple[ + GraphDef[A], + FlatState[VariableState[tp.Any]], + tpe.Unpack[tuple[FlatState[VariableState[tp.Any]], ...]], + ]: ... + def flatten( + self, + node: A, + *filters: filterlib.Filter, + with_paths: bool = True, + ) -> tuple[ + GraphDef[A], + FlatState[VariableState[tp.Any]] | list[tp.Any], + tpe.Unpack[tuple[FlatState[VariableState[tp.Any]], ...]], + ]: + if not with_paths and filters: + raise ValueError('Cannot use filters with with_paths=False') + + ctx = ( + current_update_context(self.ctxtag) if self.ctxtag is not None else None + ) + dynamic_cache = ( + ctx.dynamic_cache if ctx is not None and self.is_inner is False else None + ) + static_cache = ( + ctx.static_cache if ctx is not None and self.is_inner is False else None + ) + ref_outer_index = ctx and ctx.inner_ref_outer_index + + if node in self.ref_index: + # node is already in the ref_index, call flatten which will return a NodeRef + graphdef, flat_state = flatten( + node, + ref_index=self.ref_index, + ref_outer_index=ref_outer_index, + with_paths=with_paths, + ) + if with_paths: + assert isinstance(flat_state, FlatState) + paths = flat_state.paths + leaves = flat_state.leaves + else: + assert isinstance(flat_state, list) + paths = None + leaves = flat_state + elif static_cache is not None and node in static_cache: + node_cache = static_cache[node] + graphdef = node_cache.graphdef + # add the new references to the ref_index + self.ref_index.update(node_cache.new_ref_index) + + if with_paths: + paths = node_cache.paths + leaves = [variable.to_state() for variable in node_cache.variables] + else: + paths = None + leaves = [variable.raw_value for variable in node_cache.variables] + + elif dynamic_cache is not None and node in dynamic_cache: + node_cache = dynamic_cache[node] + cache_fp = node_cache.fingerprint + new_ref_index = RefMap() + fp_matches = check_fingerprint( + node, cache_fp, ref_index=self.ref_index, new_ref_index=new_ref_index + ) + if fp_matches: + graphdef = node_cache.graphdef + self.ref_index.update(new_ref_index) + + if with_paths: + paths = node_cache.paths + leaves = [variable.to_state() for variable in node_cache.variables] + else: + paths = None + leaves = [variable.raw_value for variable in node_cache.variables] + else: + del cache_fp + del node_cache + new_ref_index = RefMap() + node_fp = fingerprint( + node, ref_index=self.ref_index, new_ref_index=new_ref_index + ) + graphdef, flat_states = flatten( + node, + ref_index=self.ref_index, + ref_outer_index=ref_outer_index, + with_paths=True, + return_variables=True, + ) + paths = flat_states.paths + variables = flat_states.leaves + assert paths is not None + if with_paths: + leaves = [variable.to_state() for variable in variables] + else: + leaves = [variable.raw_value for variable in variables] + dynamic_cache[node] = DynamicCache.create( + node_fp, graphdef, paths, variables, new_ref_index + ) + elif dynamic_cache is not None: # node not in cache_context + new_ref_index = RefMap() + node_fp = fingerprint( + node, ref_index=self.ref_index, new_ref_index=new_ref_index + ) + graphdef, flat_state = flatten( + node, + ref_index=self.ref_index, + ref_outer_index=ref_outer_index, + with_paths=True, + return_variables=True, + ) + paths = flat_state.paths + variables = flat_state.leaves + if with_paths: + leaves = [variable.to_state() for variable in variables] + else: + leaves = [variable.raw_value for variable in variables] + dynamic_cache[node] = DynamicCache.create( + node_fp, graphdef, paths, variables, new_ref_index + ) + else: + graphdef, flat_state = flatten( + node, + ref_index=self.ref_index, + ref_outer_index=ref_outer_index, + with_paths=with_paths, + ) + if with_paths: + assert isinstance(flat_state, FlatState) + paths = flat_state.paths + leaves = flat_state.leaves + else: + assert isinstance(flat_state, list) + paths = None + leaves = flat_state + + if with_paths: + assert paths is not None + flat_state = FlatState.from_sorted_keys_values(paths, leaves) + flat_states = _split_state(flat_state, filters) + return graphdef, *flat_states + else: + return graphdef, leaves + @contextlib.contextmanager -def split_context(ctxtag: str | None = None): - index_ref: RefMap[tp.Any, Index] = RefMap() - flatten_ctx = SplitContext(ctxtag, index_ref) - GRAPH_CONTEXT.ref_index_stack.append(flatten_ctx) +def split_context(ctxtag: tp.Hashable | None = None): + ctx = current_update_context(ctxtag) if ctxtag is not None else None + is_inner = ctx.outer_ref_outer_index is not None if ctx is not None else None + GRAPH_CONTEXT.ref_index_stack.append(SplitContext(ctxtag, RefMap(), is_inner)) try: - yield flatten_ctx + yield GRAPH_CONTEXT.ref_index_stack[-1] finally: - GRAPH_CONTEXT.ref_index_stack.pop() + flatten_ctx = GRAPH_CONTEXT.ref_index_stack.pop() if ctxtag is not None: ctx = current_update_context(ctxtag) - ctx.flatten_end(index_ref) + ctx.flatten_end(flatten_ctx.ref_index) del flatten_ctx.ref_index del flatten_ctx.ctxtag @dataclasses.dataclass class MergeContext: - ctxtag: str | None + ctxtag: tp.Hashable | None index_ref: dict[Index, tp.Any] + is_inner: bool | None def merge( - self, graphdef: GraphDef[A], state: GraphState, /, *states: GraphState + self, + graphdef: GraphDef[A], + state: GraphState, + /, + *states: GraphState, ) -> A: ctx = ( current_update_context(self.ctxtag) if self.ctxtag is not None else None ) - if ( - ctx is not None - and isinstance(graphdef, NodeDef) - and graphdef.index_mapping is not None - ): - # outer merge (4), create index_ref_cache - assert ctx.ref_index is not None - index_ref_cache = compose_mapping_reversed( - ctx.ref_index, graphdef.index_mapping - ) - else: - # inner merge (2) - index_ref_cache = None state = State.merge(state, *states) node = unflatten( graphdef, state, index_ref=self.index_ref, - index_ref_cache=index_ref_cache, + outer_index_outer_ref=ctx and ctx.outer_index_outer_ref, ) return node + def unflatten( + self, + graphdef: GraphDef[A], + flat_state: GraphFlatState | list[tp.Any], + /, + *flat_states: GraphFlatState, + ) -> A: + ctx = ( + current_update_context(self.ctxtag) if self.ctxtag is not None else None + ) + dynamic_cache = ( + ctx.dynamic_cache if ctx is not None and self.is_inner is False else None + ) + static_cache = ( + ctx.static_cache if ctx is not None and self.is_inner is False else None + ) -@contextlib.contextmanager -def merge_context(ctxtag: str | None = None): - index_ref: dict[Index, tp.Any] = {} + if type(flat_state) is list: + if flat_states: + raise ValueError( + 'Cannot use multiple flat_states when flat_state is a list, ' + f'got flat_state: {flat_state!r}, flat_states: {flat_states!r}' + ) + state = flat_state + else: + state = FlatState.merge(flat_state, *flat_states) + + if type(graphdef) is NodeRef: + node = unflatten( + graphdef, + state, + index_ref=self.index_ref, + ) + + elif dynamic_cache is not None or static_cache is not None: + assert isinstance(graphdef, NodeDef) + assert ctx is not None + if (outer_index := graphdef.outer_index) is not None: + outer_index_outer_ref = ctx.outer_index_outer_ref + assert outer_index_outer_ref is not None + node = outer_index_outer_ref[outer_index] + + if static_cache and node in static_cache: + cache = static_cache[node] + if cache.final_graphdef != graphdef: + raise ValueError( + 'The graph structure of a node added to cache_args was mutated inside the transformation, ' + f'this is not allowed.\nNode: {node}\nOuput graphdef: {graphdef}\nExpected graphdef: {cache.final_graphdef}' + ) + if type(state) is list: + leaves = state + elif type(state) is FlatState: + leaves = state.leaves + else: + raise ValueError(f'Unsupported state type: {type(state)}') + + if len(leaves) != len(cache.variables): + raise ValueError( + f'Incorrect number of leaves: expected {len(cache.variables)} ' + f'leaves in the state, got {len(leaves)}' + ) + for variable, leaf in zip(cache.variables, leaves): + if type(leaf) is VariableState: + variable.update_from_state(leaf) + else: + variable.raw_value = leaf + self.index_ref.update(cache.new_index_ref) + elif dynamic_cache and node in dynamic_cache: + # node is in cache_context, retrieve its cache + cache = dynamic_cache[node] + # check if the graphdef is the same + if cache.final_graphdef == graphdef: + if type(state) is list: + leaves = state + elif type(state) is FlatState: # type: ignore + leaves = state.leaves + else: + raise ValueError(f'Unsupported state type: {type(state)}') + + # graphdefs match, update variables from state + if len(leaves) != len(cache.variables): + raise ValueError( + f'Incorrect number of leaves: expected {len(cache.variables)} ' + f'leaves in the state, got {len(leaves)}' + ) + for variable, leaf in zip(cache.variables, leaves): + if type(leaf) is VariableState: + variable.update_from_state(leaf) + else: + variable.raw_value = leaf + self.index_ref.update(cache.new_index_ref) + else: # cache.graphdef != graphdef_fp + # graph changed, re-create the node + node = unflatten( + graphdef, + state, + index_ref=self.index_ref, + outer_index_outer_ref=outer_index_outer_ref, + ) + else: + # all nodes in index_ref_cache must be in cache_context + raise RuntimeError(f'Node not found in cache_context, node: {node}') + else: # graphdef.outer_index is None + # its a new node, create it + node = unflatten( + graphdef, + state, + index_ref=self.index_ref, + ) + else: + node = unflatten( + graphdef, + state, + index_ref=self.index_ref, + outer_index_outer_ref=ctx and ctx.outer_index_outer_ref, + ) + return node - unflatten_ctx = MergeContext(ctxtag, index_ref) - GRAPH_CONTEXT.index_ref_stack.append(unflatten_ctx) + +@tp.overload +@contextlib.contextmanager +def merge_context(): ... +@tp.overload +@contextlib.contextmanager +def merge_context(inner: bool | None, ctxtag: tp.Hashable | None): ... +@contextlib.contextmanager +def merge_context(inner: bool | None = None, ctxtag: tp.Hashable | None = None): + GRAPH_CONTEXT.index_ref_stack.append(MergeContext(ctxtag, {}, inner)) try: - yield unflatten_ctx + yield GRAPH_CONTEXT.index_ref_stack[-1] finally: - GRAPH_CONTEXT.index_ref_stack.pop() + unflatten_ctx = GRAPH_CONTEXT.index_ref_stack.pop() + index_ref = unflatten_ctx.index_ref if ctxtag is not None: + if inner is None: + raise ValueError('inner_merge must be specified when using ctxtag') ctx = current_update_context(ctxtag) - ctx.unflatten_end(index_ref) + ctx.unflatten_end(index_ref, inner) del unflatten_ctx.index_ref del unflatten_ctx.ctxtag @@ -893,9 +1619,14 @@ def merge_context(ctxtag: str | None = None): class UpdateContext: """A context manager for handling complex state updates.""" - tag: str - ref_index: RefMap[tp.Any, Index] | None - index_ref: dict[Index, tp.Any] | None + tag: tp.Hashable + outer_ref_outer_index: RefMap | None + outer_index_inner_ref: dict[Index, tp.Any] | None + # reverse caches + outer_index_outer_ref: dict[Index, tp.Any] | None + inner_ref_outer_index: RefMap | None + dynamic_cache: WeakKeyDictionary[tp.Any, DynamicCache] | None + static_cache: WeakKeyDictionary[tp.Any, StaticCache] | None # define hash and eq to make this an opaque object def __hash__(self): @@ -904,16 +1635,25 @@ def __hash__(self): def __eq__(self, other): return isinstance(other, UpdateContext) - def flatten_end(self, ref_index: RefMap[tp.Any, Index]): - if self.ref_index is None: + def flatten_end(self, ref_index: RefMap): + if self.outer_ref_outer_index is None: # outer split (1), store the references - self.ref_index = ref_index + self.outer_ref_outer_index = ref_index + self.outer_index_outer_ref = { + index: obj for obj, index in self.outer_ref_outer_index.items() + } else: # inner split (3), clear index_ref - self.index_ref = None + self.outer_index_inner_ref = None + self.inner_ref_outer_index = None - def unflatten_end(self, index_ref: dict[Index, tp.Any]): - self.index_ref = index_ref + def unflatten_end(self, index_ref: dict[Index, tp.Any], inner_merge: bool): + if inner_merge: + # inner merge (2) + self.outer_index_inner_ref = index_ref + self.inner_ref_outer_index = RefMap( + {obj: index for index, obj in index_ref.items()} + ) @tp.overload def split(self, graph_node: A, /) -> tuple[GraphDef[A], GraphState]: ... @@ -996,15 +1736,14 @@ def split( :class:`GraphDef` and one or more :class:`State`'s equal to the number of filters passed. If no filters are passed, a single :class:`State` is returned. """ - ref_index: RefMap[tp.Any, Index] = RefMap() - graphdef, state = flatten(node, ref_index) - states = _split_state(state, filters) - - if self.index_ref is not None and isinstance(graphdef, NodeDef): - index_to_index = compose_mapping(self.index_ref, ref_index) - graphdef = dataclasses.replace( - graphdef, index_mapping=HashableMapping(index_to_index, copy=False) - ) + ref_index: RefMap = RefMap() + graphdef, flat_state = flatten( + node, ref_index=ref_index, ref_outer_index=self.inner_ref_outer_index + ) + states = tuple( + State.from_flat_path(flat_state) + for flat_state in _split_state(flat_state, filters) + ) self.flatten_end(ref_index) @@ -1021,15 +1760,13 @@ def merge( raise ValueError( f'Expected a NodeDef instance, but got {type(graphdef)}.' ) - if self.ref_index is None: + if self.outer_ref_outer_index is None: raise ValueError('Cannot merge without ref_index.') - if graphdef.index_mapping is not None: + if self.outer_ref_outer_index is not None: # outer merge (4), create index_ref_cache - assert self.ref_index is not None - index_ref_cache = compose_mapping_reversed( - self.ref_index, graphdef.index_mapping - ) + index_ref_cache = self.outer_index_outer_ref + assert index_ref_cache is not None else: # inner merge (2) index_ref_cache = None @@ -1037,10 +1774,13 @@ def merge( state = State.merge(state, *states) index_ref: dict[Index, tp.Any] = {} node = unflatten( - graphdef, state, index_ref=index_ref, index_ref_cache=index_ref_cache + graphdef, + state, + index_ref=index_ref, + outer_index_outer_ref=index_ref_cache, ) - self.unflatten_end(index_ref) + self.unflatten_end(index_ref, True) return node @@ -1050,10 +1790,31 @@ def merge( @dataclasses.dataclass class UpdateContextManager: - tag: str + tag: tp.Hashable + use_dynamic_cache: bool def __enter__(self): - ctx = UpdateContext(self.tag, None, None) + dynamic_cache: WeakKeyDictionary[tp.Any, DynamicCache] | None + if self.use_dynamic_cache: + dynamic_cache = WeakKeyDictionary() + else: + dynamic_cache = None + + if GRAPH_CONTEXT.tmp_static_cache is not None: + # take current static cache + static_cache = GRAPH_CONTEXT.tmp_static_cache + GRAPH_CONTEXT.tmp_static_cache = None + else: + static_cache = None + ctx = UpdateContext( + tag=self.tag, + outer_ref_outer_index=None, + outer_index_inner_ref=None, + outer_index_outer_ref=None, + inner_ref_outer_index=None, + dynamic_cache=dynamic_cache, + static_cache=static_cache, + ) if self.tag not in GRAPH_CONTEXT.update_context_stacks: GRAPH_CONTEXT.update_context_stacks[self.tag] = [ctx] else: @@ -1069,8 +1830,10 @@ def __exit__(self, *args): ctx = stack.pop() # clear references - del ctx.ref_index - del ctx.index_ref + del ctx.outer_ref_outer_index + del ctx.outer_index_inner_ref + del ctx.outer_index_outer_ref + del ctx.inner_ref_outer_index if not stack: del GRAPH_CONTEXT.update_context_stacks[self.tag] @@ -1084,7 +1847,7 @@ def update_context_manager_wrapper(*args, **kwargs): return update_context_manager_wrapper # type: ignore -def update_context(tag: str): +def update_context(tag: tp.Hashable, *, use_dynamic_cache: bool = False): """Creates an :class:`UpdateContext` context manager which can be used to handle more complex state updates beyond what ``nnx.update`` can handle, including updates to static properties and graph structure. @@ -1176,10 +1939,10 @@ def update_context(tag: str): Args: tag: A string tag to identify the context. """ - return UpdateContextManager(tag) + return UpdateContextManager(tag=tag, use_dynamic_cache=use_dynamic_cache) -def current_update_context(tag: str) -> UpdateContext: +def current_update_context(tag: tp.Hashable) -> UpdateContext: """Returns the current active :class:`UpdateContext` for the given tag.""" if tag not in GRAPH_CONTEXT.update_context_stacks: raise ValueError(f'No update context found for tag {tag!r}.') @@ -1191,13 +1954,13 @@ def current_update_context(tag: str) -> UpdateContext: # -------------------------------------------------------- def _split_state( - state: GraphState, + state: FlatState[tp.Any], filters: tuple[filterlib.Filter, ...], -) -> tuple[GraphState, tpe.Unpack[tuple[GraphState, ...]]]: +) -> tuple[FlatState[tp.Any], tpe.Unpack[tuple[FlatState[tp.Any], ...]]]: if not filters: return (state,) states = state.split(*filters) - if isinstance(states, State): + if not isinstance(states, tuple): return (states,) assert len(states) > 0 return states # type: ignore[return-value] @@ -1288,9 +2051,11 @@ def split( ``GraphDef`` and one or more ``States`` equal to the number of filters passed. If no filters are passed, a single ``State`` is returned. """ - graphdef, state = flatten(node) - states = _split_state(state, filters) - return graphdef, *states + graphdef, flat_state = flatten(node) + flat_states = _split_state(flat_state, filters) + states = tuple(State.from_flat_path(flat_state) for flat_state in flat_states) + return graphdef, *states # type: ignore[return-value] + def merge( graphdef: GraphDef[A], @@ -1482,6 +2247,7 @@ def state( One or more :class:`State` mappings. """ _, state = flatten(node) + state = state.to_nested_state() states: GraphState | tuple[GraphState, ...] if len(filters) == 0: @@ -1755,16 +2521,6 @@ def _iter_graph( yield path_parts, node -def compose_mapping( - map_ab: tp.Mapping[A, B], map_bc: tp.Mapping[B, C], / -) -> dict[A, C]: - return {a: map_bc[b] for a, b in map_ab.items() if b in map_bc} - - -def compose_mapping_reversed( - map_ab: tp.Mapping[A, B], map_bc: tp.Mapping[B, C], / -) -> dict[C, A]: - return {map_bc[b]: a for a, b in map_ab.items() if b in map_bc} @dataclasses.dataclass(frozen=True) @@ -1783,21 +2539,15 @@ class Static(tp.Generic[A]): # --------------------------------------------------------- class GenericPytree: ... +from jax._src.tree_util import _registry as JAX_PYTREE_REGISTRY def is_pytree_node(x: tp.Any) -> bool: - t = type(x) - if t in PYTREE_REGISTRY: + if type(x) in JAX_PYTREE_REGISTRY: return True - elif t in GRAPH_REGISTRY: - return False - # known non-pytree types - elif isinstance(x, Variable): - return False - # known pytree types - elif type(x) is VariableState or type(x) is State: + elif isinstance(x, tuple): return True else: - return not jax.tree_util.all_leaves((x,)) + return False def _key_path_to_key(key: tp.Any) -> Key: @@ -1816,20 +2566,28 @@ def _key_path_to_key(key: tp.Any) -> Key: else: return str(key) +class IndexesPytreeDef(tp.NamedTuple): + key_index: HashableMapping[Key, int] + treedef: jax.tree_util.PyTreeDef def _flatten_pytree(pytree: tp.Any): leaves, treedef = jax.tree_util.tree_flatten_with_path( pytree, is_leaf=lambda x: x is not pytree ) - nodes = tuple((_key_path_to_key(path[0]), value) for path, value in leaves) - - return nodes, treedef + nodes = [(_key_path_to_key(path[0]), value) for path, value in leaves] + key_index = HashableMapping( + {key: i for i, (key, _) in enumerate(nodes)}, copy=False + ) + nodes.sort() # sort by key + return nodes, IndexesPytreeDef(key_index, treedef) def _unflatten_pytree( - nodes: tuple[tuple[Key, tp.Any], ...], treedef: jax.tree_util.PyTreeDef + nodes: tuple[tuple[Key, tp.Any], ...], metadata: IndexesPytreeDef ): - pytree = treedef.unflatten(value for _, value in nodes) + # sort to original order + sorted_nodes = sorted(nodes, key=lambda x: metadata.key_index[x[0]]) + pytree = metadata.treedef.unflatten(value for _, value in sorted_nodes) return pytree diff --git a/flax/nnx/helpers.py b/flax/nnx/helpers.py index 96622f0e40..077817c4a1 100644 --- a/flax/nnx/helpers.py +++ b/flax/nnx/helpers.py @@ -62,6 +62,10 @@ def __iter__(self) -> tp.Iterator[str]: def __len__(self) -> int: return len(vars(self)) + def __hash__(self) -> int: + return id(self) + + class Sequential(Module): def __init__(self, *fns: tp.Callable[..., tp.Any]): self.layers = list(fns) diff --git a/flax/nnx/nn/stochastic.py b/flax/nnx/nn/stochastic.py index add545634a..a3313bf6e7 100644 --- a/flax/nnx/nn/stochastic.py +++ b/flax/nnx/nn/stochastic.py @@ -125,3 +125,6 @@ def __call__( mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape) mask = jnp.broadcast_to(mask, inputs.shape) return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs)) + + def __hash__(self): + return id(self) diff --git a/flax/nnx/reprlib.py b/flax/nnx/reprlib.py index 155c2e7e90..58a004145e 100644 --- a/flax/nnx/reprlib.py +++ b/flax/nnx/reprlib.py @@ -235,6 +235,7 @@ def __nnx_repr__(self): for key, value in self.mapping.items(): yield Attr(colorized(key), value, use_raw_key=True) + @dataclasses.dataclass(repr=False) class SequenceReprMixin(Representable): def __nnx_repr__(self): diff --git a/flax/nnx/rnglib.py b/flax/nnx/rnglib.py index ab9817acaa..0fef2c173b 100644 --- a/flax/nnx/rnglib.py +++ b/flax/nnx/rnglib.py @@ -13,7 +13,6 @@ # limitations under the License. from __future__ import annotations -import dataclasses import functools import typing as tp @@ -48,7 +47,6 @@ class RngKey(RngState): ... NotKey = filterlib.All(RngState, filterlib.Not(RngKey)) -@dataclasses.dataclass(repr=False) class RngStream(Object): def __init__( self, @@ -56,13 +54,12 @@ def __init__( key: jax.Array, count: jax.Array, ): + if not isinstance(key, jax.Array): + raise TypeError(f'key must be a jax.Array, got {type(key)}') + self.key = RngKey(key, tag=tag) self.count = RngCount(count, tag=tag) - def __post_init__(self): - if not isinstance(self.key, jax.Array): - raise TypeError(f'key must be a jax.Array, got {type(self.key)}') - def __call__(self) -> jax.Array: self.check_valid_context( lambda: 'Cannot call RngStream from a different trace level' @@ -80,7 +77,7 @@ def __call__(self) -> jax.Array: ] -class Rngs(Object, tp.Mapping[str, tp.Callable[[], jax.Array]]): +class Rngs(Object): """NNX rng container class. To instantiate the ``Rngs``, pass in an integer, specifying the starting seed. ``Rngs`` can have different "streams", allowing the user to generate different @@ -237,6 +234,10 @@ def __getstate__(self): def __setstate__(self, state): vars(self).update(state) + def items(self): + for name in self: + yield name, self[name] + class ForkStates(tp.NamedTuple): split_keys: State diff --git a/flax/nnx/statelib.py b/flax/nnx/statelib.py index 38cb3da759..44a72212c8 100644 --- a/flax/nnx/statelib.py +++ b/flax/nnx/statelib.py @@ -54,26 +54,45 @@ def __treescope_repr__(self, path, subtree_renderer): # Render as the dictionary itself at the same path. return subtree_renderer(children, path=path) -class FlatState(tp.Sequence[tuple[PathParts, V]], reprlib.SequenceReprMixin): +class FlatState(tp.Sequence[tuple[PathParts, V]], reprlib.Representable): + __slots__ = ('_keys', '_values') + _keys: tuple[PathParts, ...] _values: list[V] - def __init__(self, items: tp.Iterable[tuple[PathParts, V]]): + def __init__(self, items: tp.Iterable[tuple[PathParts, V]], /, *, sort: bool): keys, values = [], [] + if sort: + items = sorted(items) for key, value in items: keys.append(key) values.append(value) self._keys = tuple(keys) self._values = values + @staticmethod + def from_sorted_keys_values( + keys: tuple[PathParts, ...], values: list[V], / + ) -> FlatState[V]: + flat_state = object.__new__(FlatState) + flat_state._keys = keys + flat_state._values = values + return flat_state + @property - def paths(self) -> tp.Sequence[PathParts]: + def paths(self) -> tp.Tuple[PathParts, ...]: return self._keys @property - def leaves(self) -> tp.Sequence[V]: + def leaves(self) -> list[V]: return self._values + def __nnx_repr__(self): + yield reprlib.Object(type='FlatState', kv_sep='', start='([', end='])') + + for value in self: + yield reprlib.Attr('', value) + @tp.overload def __getitem__(self, index: int) -> tuple[PathParts, V]: ... @tp.overload @@ -83,7 +102,7 @@ def __getitem__( ) -> tuple[PathParts, V] | FlatState[V]: if isinstance(index, int): return self._keys[index], self._values[index] - return FlatState(zip(self._keys[index], self._values[index])) + return FlatState(zip(self._keys[index], self._values[index]), sort=False) def __len__(self) -> int: return len(self._keys) @@ -91,6 +110,91 @@ def __len__(self) -> int: def __iter__(self) -> tp.Iterator[tuple[PathParts, V]]: return iter(zip(self._keys, self._values)) + def to_nested_state(self) -> State[PathParts, V]: + return State.from_flat_path(self) + + @tp.overload + def split(self, first: filterlib.Filter, /) -> FlatState[V]: ... + + @tp.overload + def split( + self, + first: filterlib.Filter, + second: filterlib.Filter, + /, + *filters: filterlib.Filter, + ) -> tuple[FlatState[V], ...]: ... + + @tp.overload + def split( + self, /, *filters: filterlib.Filter + ) -> tp.Union[FlatState[V], tuple[FlatState[V], ...]]: ... + + def split( # type: ignore[misc] + self, first: filterlib.Filter, /, *filters: filterlib.Filter + ) -> tp.Union[FlatState[V], tuple[FlatState[V], ...]]: + filters = (first, *filters) + *flat_states_, rest = _split_state(self, *filters) + + if rest: + raise ValueError( + 'Non-exhaustive filters, got a non-empty remainder: ' + f'{rest}.\nUse `...` to match all remaining elements.' + ) + + flat_states: FlatState[V] | tuple[FlatState[V], ...] + if len(flat_states_) == 1: + flat_states = flat_states_[0] + else: + flat_states = tuple(flat_states_) + return flat_states # type: ignore + + @tp.overload + def filter(self, first: filterlib.Filter, /) -> FlatState[V]: ... + + @tp.overload + def filter( + self, + first: filterlib.Filter, + second: filterlib.Filter, + /, + *filters: filterlib.Filter, + ) -> tuple[FlatState[V], ...]: ... + + def filter( + self, + first: filterlib.Filter, + /, + *filters: filterlib.Filter, + ) -> tp.Union[FlatState[V], tuple[FlatState[V], ...]]: + *flat_states_, _rest = _split_state(self, first, *filters) + + assert len(flat_states_) == len(filters) + 1 + + flat_states: FlatState[V] | tuple[FlatState[V], ...] + if len(flat_states_) == 1: + flat_states = flat_states_[0] + else: + flat_states = tuple(flat_states_) + + return flat_states # type: ignore + + @staticmethod + def merge( + flat_state: tp.Iterable[tuple[PathParts, V]], + /, + *flat_states: tp.Iterable[tuple[PathParts, V]], + ) -> FlatState[V]: + if not flat_states: + if isinstance(flat_state, FlatState): + return flat_state + return FlatState(flat_state, sort=True) + flat_states = (flat_state, *flat_states) + + return FlatState( + (elem for flat_state in flat_states for elem in flat_state), sort=True + ) + def _flat_state_pytree_flatten(x: FlatState[V]): return x._values, x._keys @@ -211,7 +315,7 @@ def map(self, f: tp.Callable[[tuple, V], V]) -> State[K, V]: return State.from_flat_path(result) def flat_state(self) -> FlatState[V]: - return FlatState(traversals.flatten_to_sequence(self._mapping)) + return FlatState(traversals.flatten_to_sequence(self._mapping), sort=True) @classmethod def from_flat_path( @@ -299,7 +403,8 @@ def split( # type: ignore[misc] One or more ``States`` equal to the number of filters passed. """ filters = (first, *filters) - *states_, rest = _split_state(self.flat_state(), *filters) + flat_states = _split_state(self.flat_state(), *filters) + *states_, rest = (state.to_nested_state() for state in flat_states) if rest: raise ValueError( @@ -364,7 +469,8 @@ def filter( Returns: One or more ``States`` equal to the number of filters passed. """ - *states_, _rest = _split_state(self.flat_state(), first, *filters) + flat_states = _split_state(self.flat_state(), first, *filters) + *states_, _rest = (state.to_nested_state() for state in flat_states) assert len(states_) == len(filters) + 1 @@ -464,7 +570,7 @@ def _state_unflatten( def _split_state( flat_state: FlatState[V], *filters: filterlib.Filter, -) -> tuple[State[PathParts, V], ...]: +) -> tuple[FlatState[V], ...]: for i, filter_ in enumerate(filters): if filter_ in (..., True) and i != len(filters) - 1: remaining_filters = filters[i + 1 :] @@ -490,7 +596,7 @@ def _split_state( # if we didn't break, set leaf to last state flat_states[-1].append((path, value)) # type: ignore[index] # mypy is wrong here? - return tuple(State.from_flat_path(flat_state) for flat_state in flat_states) + return tuple(FlatState(flat_state, sort=False) for flat_state in flat_states) def create_path_filters(state: State): diff --git a/flax/nnx/tracers.py b/flax/nnx/tracers.py index a7b72b1540..c53bbd5c4d 100644 --- a/flax/nnx/tracers.py +++ b/flax/nnx/tracers.py @@ -18,7 +18,7 @@ import jax import jax.core -from flax.nnx import reprlib, visualization +from flax.nnx import reprlib def current_jax_trace(): @@ -47,11 +47,12 @@ def __nnx_repr__(self): yield reprlib.Attr('jax_trace', self._jax_trace) def __treescope_repr__(self, path, subtree_renderer): - return visualization.render_object_constructor( - object_type=type(self), - attributes={'jax_trace': self._jax_trace}, - path=path, - subtree_renderer=subtree_renderer, + import treescope # type: ignore[import-not-found,import-untyped] + return treescope.repr_lib.render_object_constructor( + object_type=type(self), + attributes={'jax_trace': self._jax_trace}, + path=path, + subtree_renderer=subtree_renderer, ) def __eq__(self, other): diff --git a/flax/nnx/transforms/autodiff.py b/flax/nnx/transforms/autodiff.py index 5ef0d183b7..24ca8c9d6d 100644 --- a/flax/nnx/transforms/autodiff.py +++ b/flax/nnx/transforms/autodiff.py @@ -64,24 +64,26 @@ class DiffState: class GradFn: f: tp.Callable[..., tp.Any] has_aux: bool + nondiff_states: deque[State | None] def __post_init__(self): functools.update_wrapper(self, self.f) def __call__(self, *pure_args): # rebuild diff_state from substates in args - nondiff_states: deque[State | None] = extract.get_broadcast_state('grad') def _grad_merge_fn( ctx: graph.MergeContext, path, prefix, value: extract.NodeStates ): - nondiff = nondiff_states.popleft() + nondiff = self.nondiff_states.popleft() if nondiff is None: return ctx.merge(value.graphdef, value.state) else: return ctx.merge(value.graphdef, value.state, nondiff) - args = extract.from_tree(pure_args, merge_fn=_grad_merge_fn, ctxtag='grad') + args = extract.from_tree( + pure_args, merge_fn=_grad_merge_fn, ctxtag='grad', is_inner=True + ) out = self.f(*args) @@ -129,15 +131,6 @@ def _grad_general( else DiffState(-1, variablelib.Param) ) - gradded_fn = transform( - GradFn(f, has_aux), - argnums=jax_argnums, - has_aux=True, - holomorphic=holomorphic, - allow_int=allow_int, - reduce_axes=reduce_axes, - ) - @graph.update_context('grad') def grad_wrapper(*args, **kwargs): args = resolve_kwargs(f, args, kwargs) @@ -160,8 +153,16 @@ def _grad_split_fn( args, prefix=arg_filters, split_fn=_grad_split_fn, ctxtag='grad' ) - with extract.broadcast_state('grad', nondiff_states): - fn_out = gradded_fn(*pure_args) + gradded_fn = transform( + GradFn(f, has_aux, nondiff_states), + argnums=jax_argnums, + has_aux=True, + holomorphic=holomorphic, + allow_int=allow_int, + reduce_axes=reduce_axes, + ) + + fn_out = gradded_fn(*pure_args) def process_grads(grads): return jax.tree.map( @@ -171,7 +172,7 @@ def process_grads(grads): ) def process_out(pure_out: A, /) -> A: - return extract.from_tree(pure_out, ctxtag='grad') + return extract.from_tree(pure_out, ctxtag='grad', is_inner=False) if return_value: # unpack value_and_grad output @@ -427,11 +428,11 @@ def _custom_vjp_split_fn( nondiff_argnums: tuple[int, ...] = struct.field(pytree_node=False) tangent_tree_node_args: tuple[tp.Any, ...] = struct.field(pytree_node=False) -def _extract_index_mappings(x, *, index_mappings: deque[graph.HashableMapping]): +def _extract_nodedefs(x, *, nodedefs: deque[graph.NodeDef]): if isinstance(x, graph.NodeDef): - assert x.index_mapping is not None - index_mappings.append(x.index_mapping) - return dataclasses.replace(x, index_mapping=None) + assert x.outer_index is not None + nodedefs.append(x) + return x.with_no_outer_index() return x @dataclasses.dataclass(eq=False) @@ -440,6 +441,7 @@ class CustomVjpFnWrapper: jax_nondiff_argnums: tuple[int, ...] ctxtag: str nondiff_states: list[extract.GraphDefState] + nodedefs: deque[graph.NodeDef] def __post_init__(self): functools.update_wrapper(self, self.f) @@ -452,6 +454,7 @@ def __call__(self, *pure_args): _custom_vjp_merge_fn, nondiff_states=nondiff_states ), ctxtag=self.ctxtag, + is_inner=True, ) out = self.f(*args) @@ -464,13 +467,10 @@ def __call__(self, *pure_args): pure_args_out, pure_out = extract.to_tree( (args_out, out), ctxtag=self.ctxtag ) - # remove index_mapping from NodeDef's but store them in global context - index_mappings: deque[graph.HashableMapping] = extract.get_broadcast_state( - self.ctxtag - ) + # remove outer_index from NodeDef's but store them in global context pure_args_out, pure_out = jax.tree.map( - functools.partial(_extract_index_mappings, index_mappings=index_mappings), + functools.partial(_extract_nodedefs, nodedefs=self.nodedefs), (pure_args_out, pure_out), is_leaf=lambda x: isinstance(x, graph.NodeDef), ) @@ -484,6 +484,7 @@ class FwdFn: nondiff_argnums: tuple[int, ...] ctxtag: str nondiff_states: list[extract.GraphDefState] + nodedefs: deque[graph.NodeDef] def __post_init__(self): functools.update_wrapper(self, self.fwd) @@ -503,6 +504,7 @@ def __call__(self, *pure_args): _custom_vjp_merge_fn, nondiff_states=nondiff_states ), ctxtag=self.ctxtag if update_context_active else None, + is_inner=True, ) out, residual = self.fwd(*args) @@ -519,14 +521,9 @@ def __call__(self, *pure_args): pure_residual = extract.to_tree(residual) if update_context_active: - # remove index_mapping from NodeDef's but store them in global context - index_mappings: deque[graph.HashableMapping] = ( - extract.get_broadcast_state(self.ctxtag) - ) + # remove outer_index from NodeDef's but store them in global context pure_args_out, pure_out = jax.tree.map( - functools.partial( - _extract_index_mappings, index_mappings=index_mappings - ), + functools.partial(_extract_nodedefs, nodedefs=self.nodedefs), (pure_args_out, pure_out), is_leaf=lambda x: isinstance(x, graph.NodeDef), ) @@ -544,7 +541,7 @@ def __post_init__(self): def __call__(self, *args): *nondiff, pure_residual, (pure_args_out_g, pure_out_g) = args - residual = extract.from_tree(pure_residual) + residual = extract.from_tree(pure_residual, is_inner=True) (pure_args_out_g, pure_out_g) = jax.tree.map( lambda x: x.state if isinstance(x, extract.NodeStates) else x, (pure_args_out_g, pure_out_g), @@ -632,40 +629,41 @@ def __call__( for i, x in enumerate(tree_node_args) if i not in self.jax_nondiff_argnums ) - index_mappings: deque[graph.HashableMapping] = deque() - with extract.broadcast_state(self.ctxtag, index_mappings): - if self.fwd is None or self.bwd is None or self.symbolic_zeros is None: - raise ValueError() - - custom_vjp_fn = jax.custom_vjp( - fun=CustomVjpFnWrapper( - f=self.fun, - jax_nondiff_argnums=self.jax_nondiff_argnums, - ctxtag=self.ctxtag, - nondiff_states=nondiff_states, - ), + nodedefs: deque[graph.NodeDef] = deque() + if self.fwd is None or self.bwd is None or self.symbolic_zeros is None: + raise ValueError() + + custom_vjp_fn = jax.custom_vjp( + fun=CustomVjpFnWrapper( + f=self.fun, + jax_nondiff_argnums=self.jax_nondiff_argnums, + ctxtag=self.ctxtag, + nondiff_states=nondiff_states, + nodedefs=nodedefs, + ), + nondiff_argnums=self.jax_nondiff_argnums, + ) + custom_vjp_fn.defvjp( + fwd=FwdFn( + fwd=self.fwd, nondiff_argnums=self.jax_nondiff_argnums, - ) - custom_vjp_fn.defvjp( - fwd=FwdFn( - fwd=self.fwd, - nondiff_argnums=self.jax_nondiff_argnums, - ctxtag=self.ctxtag, - nondiff_states=nondiff_states, - ), - bwd=BwdFn( - bwd=self.bwd, - tree_node_args=tree_node_args, - ), - symbolic_zeros=self.symbolic_zeros, - ) - pure_args_out, pure_out = custom_vjp_fn(*pure_args) + ctxtag=self.ctxtag, + nondiff_states=nondiff_states, + nodedefs=nodedefs, + ), + bwd=BwdFn( + bwd=self.bwd, + tree_node_args=tree_node_args, + ), + symbolic_zeros=self.symbolic_zeros, + ) + pure_args_out, pure_out = custom_vjp_fn(*pure_args) # insert index_mappings def _insert_index_mappings(x): if isinstance(x, graph.NodeDef): - index_mapping: graph.HashableMapping = index_mappings.popleft() - return dataclasses.replace(x, index_mapping=index_mapping) + nodedef: graph.NodeDef = nodedefs.popleft() + return nodedef return x pure_args_out, pure_out = jax.tree_util.tree_map( @@ -675,7 +673,7 @@ def _insert_index_mappings(x): ) args_out, out = extract.from_tree( - (pure_args_out, pure_out), ctxtag=self.ctxtag + (pure_args_out, pure_out), ctxtag=self.ctxtag, is_inner=False ) return out diff --git a/flax/nnx/transforms/compilation.py b/flax/nnx/transforms/compilation.py index e5ce20f8e3..e70092e2c9 100644 --- a/flax/nnx/transforms/compilation.py +++ b/flax/nnx/transforms/compilation.py @@ -90,10 +90,15 @@ def __hash__(self): def _jit_split_fn(ctx: graph.SplitContext, path, prefix, x): if isinstance(prefix, StateSharding): - return extract.NodeStates.from_split( - *ctx.split(x, *prefix.filters), metadata=prefix - ) - return extract.NodeStates.from_split(*ctx.split(x)) + graphdef, *states = ctx.flatten(x, *prefix.filters) + return extract.NodeStates.from_split(graphdef, *states, metadata=prefix) + return extract.NodeStates.from_split(*ctx.flatten(x, with_paths=False)) + + +def _jit_merge_fn(ctx: graph.MergeContext, path, prefix, leaf) -> tp.Any: + if not isinstance(leaf, extract.NodeStates): + raise ValueError(f'Expected TreeNode, got {type(leaf)} at path {path}') + return ctx.unflatten(leaf.graphdef, *leaf.states) @dataclasses.dataclass(eq=False) @@ -102,12 +107,18 @@ class JitFn: in_shardings: tp.Any out_shardings: tp.Any kwarg_shardings: tp.Any + ctxtag: tp.Hashable def __post_init__(self): functools.update_wrapper(self, self.f) def __call__(self, *pure_args, **pure_kwargs): - args, kwargs = extract.from_tree((pure_args, pure_kwargs), ctxtag='jit') + args, kwargs = extract.from_tree( + (pure_args, pure_kwargs), + merge_fn=_jit_merge_fn, + ctxtag=self.ctxtag, + is_inner=True, + ) out = self.f(*args, **kwargs) @@ -115,7 +126,7 @@ def __call__(self, *pure_args, **pure_kwargs): pure_args_out, pure_kwargs_out, pure_out = extract.to_tree( (args_out, kwargs_out, out), prefix=(self.in_shardings, self.kwarg_shardings, self.out_shardings), - ctxtag='jit', + ctxtag=self.ctxtag, split_fn=_jit_split_fn, ) @@ -317,8 +328,33 @@ def jit( out_shardings, ) + @functools.wraps(fun) + def jit_wrapper(*args, **kwargs): + # run dynamic_cache_context before update_context + with graph.update_context(jit_wrapper, use_dynamic_cache=True): + pure_args, pure_kwargs = extract.to_tree( + (args, kwargs), + prefix=(in_shardings, kwarg_shardings) + if in_shardings is not None or kwarg_shardings is not None + else None, + split_fn=_jit_split_fn, + check_aliasing=in_shardings is not None or kwarg_shardings is not None, + ctxtag=jit_wrapper, + ) + jax_in_shardings, kwarg_shardings, jax_out_shardings + pure_args_out, pure_kwargs_out, pure_out = jitted_fn( + *pure_args, **pure_kwargs + ) + _args_out, _kwargs_out, out = extract.from_tree( + (pure_args_out, pure_kwargs_out, pure_out), + merge_fn=_jit_merge_fn, + is_inner=False, + ctxtag=jit_wrapper, + ) + return out + jitted_fn = jax.jit( - JitFn(fun, in_shardings, out_shardings, kwarg_shardings), + JitFn(fun, in_shardings, out_shardings, kwarg_shardings, jit_wrapper), in_shardings=jax_in_shardings, out_shardings=(jax_in_shardings, kwarg_shardings, jax_out_shardings), # type: ignore static_argnums=static_argnums, @@ -332,24 +368,6 @@ def jit( abstracted_axes=abstracted_axes, ) - @functools.wraps(fun) - @graph.update_context('jit') - def jit_wrapper(*args, **kwargs): - pure_args, pure_kwargs = extract.to_tree( - (args, kwargs), - prefix=(in_shardings, kwarg_shardings), - split_fn=_jit_split_fn, - check_aliasing=in_shardings is not None, - ctxtag='jit', - ) - pure_args_out, pure_kwargs_out, pure_out = jitted_fn( - *pure_args, **pure_kwargs - ) - _args_out, _kwargs_out, out = extract.from_tree( - (pure_args_out, pure_kwargs_out, pure_out), ctxtag='jit' - ) - return out - jit_wrapper.inner = jitted_fn # type: ignore return jit_wrapper # type: ignore diff --git a/flax/nnx/transforms/general.py b/flax/nnx/transforms/general.py index fa82cd890a..553c3e8926 100644 --- a/flax/nnx/transforms/general.py +++ b/flax/nnx/transforms/general.py @@ -151,7 +151,9 @@ def split_inputs( def split_inputs_wrapper(*args): pure_args = extract.to_tree(args, ctxtag=ctxtag) pure_args_out, pure_out = f(*pure_args) - args_out, out = extract.from_tree((pure_args_out, pure_out), ctxtag=ctxtag) + args_out, out = extract.from_tree( + (pure_args_out, pure_out), ctxtag=ctxtag, is_inner=False + ) return out return split_inputs_wrapper # type: ignore @@ -192,7 +194,7 @@ def merge_inputs( @functools.wraps(f) def merge_inputs_wrapper(*pure_args): - args = extract.from_tree(pure_args, ctxtag=ctxtag) + args = extract.from_tree(pure_args, ctxtag=ctxtag, is_inner=True) out = f(*args) args_out = extract.clear_non_graph_nodes(args) pure_args_out, pure_out = extract.to_tree((args_out, out), ctxtag=ctxtag) diff --git a/flax/nnx/transforms/iteration.py b/flax/nnx/transforms/iteration.py index 994e582862..e379cf1b9c 100644 --- a/flax/nnx/transforms/iteration.py +++ b/flax/nnx/transforms/iteration.py @@ -165,7 +165,7 @@ def __call__(self, *pure_args: tuple[tp.Any, ...]): pure_args = _update_variable_sharding_metadata( pure_args, self.transform_metadata, spmd.remove_axis ) - args = extract.from_tree(pure_args, ctxtag='vmap') + args = extract.from_tree(pure_args, ctxtag='vmap', is_inner=True) out = self.f(*args) @@ -343,7 +343,9 @@ def vmap_wrapper(*args, **kwargs): args, prefix=in_axes, split_fn=_vmap_split_fn, ctxtag='vmap' ) pure_args_out, pure_out = vmapped_fn(*pure_args) - _args_out, out = extract.from_tree((pure_args_out, pure_out), ctxtag='vmap') + _args_out, out = extract.from_tree( + (pure_args_out, pure_out), ctxtag='vmap', is_inner=False + ) return out return vmap_wrapper # type: ignore @@ -369,7 +371,7 @@ def __call__(self, *pure_args: tuple[tp.Any, ...]): pure_args = _update_variable_sharding_metadata( pure_args, self.transform_metadata, spmd.remove_axis ) - args = extract.from_tree(pure_args, ctxtag='pmap') + args = extract.from_tree(pure_args, ctxtag='pmap', is_inner=True) out = self.f(*args) @@ -566,7 +568,9 @@ def vmap_wrapper(*args): args, prefix=in_axes, split_fn=_vmap_split_fn, ctxtag='pmap' ) pure_args_out, pure_out = pmapped_fn(*pure_args) - _args_out, out = extract.from_tree((pure_args_out, pure_out), ctxtag='pmap') + _args_out, out = extract.from_tree( + (pure_args_out, pure_out), ctxtag='pmap', is_inner=False + ) return out return vmap_wrapper # type: ignore @@ -648,21 +652,17 @@ def check_carry_same_references(key_path, arg, out): check_carry_same_references, carry_arg, carry_arg_out ) -def _extract_index_mappings( - pure_carry_arg_out, - carry_index_mappings: list[graph.HashableMapping[int, int]], - /, +def _extract_nodedefs( + pure_carry_arg_out, carry_nodedefs: list[graph.NodeDef], / ): def extract_index_mappings(x): if isinstance(x, extract.NodeStates) and isinstance( x._graphdef, graph.NodeDef ): - index_mapping = x._graphdef.index_mapping - assert index_mapping is not None - carry_index_mappings.append(index_mapping) - x = x.replace( - _graphdef=dataclasses.replace(x._graphdef, index_mapping=None) - ) + nodedef = x._graphdef + assert nodedef.outer_index is not None + carry_nodedefs.append(nodedef) + x = x.replace(_graphdef=nodedef.with_no_outer_index()) return x pure_carry_arg_out = jax.tree.map( @@ -673,19 +673,17 @@ def extract_index_mappings(x): return pure_carry_arg_out -def _insert_index_mappings( +def _insert_nodedefs( pure_carry_arg_out, - carry_index_mappings: deque[graph.HashableMapping[int, int]], + carry_nodedefs: deque[graph.NodeDef], /, ): def insert_index_mappings(x): if isinstance(x, extract.NodeStates) and isinstance( x._graphdef, graph.NodeDef ): - index_mapping = carry_index_mappings.popleft() - x = x.replace( - _graphdef=dataclasses.replace(x._graphdef, index_mapping=index_mapping) - ) + nodedef = carry_nodedefs.popleft() + x = x.replace(_graphdef=nodedef) return x pure_carry_arg_out = jax.tree.map( @@ -1017,6 +1015,7 @@ def __call__( is_leaf=lambda x: isinstance(x, (extract.NodeStates, Broadcasted)), map_non_graph_nodes=True, ctxtag='scan', + is_inner=True, ) assert not carry_deque and not broadcast_deque and not broadcast_arrays @@ -1096,10 +1095,8 @@ def __call__( # next we have to remove all the index_mappings from the NodeDefs # in the carry outputs because they are not present in the inputs - carry_index_mappings: list[graph.HashableMapping[int, int]] = [] - pure_carry_arg_out = _extract_index_mappings( - pure_carry_arg_out, carry_index_mappings - ) + carry_nodedefs: list[graph.NodeDef] = [] + pure_carry_arg_out = _extract_nodedefs(pure_carry_arg_out, carry_nodedefs) carry_arg_out = ( pure_carry_arg_out, @@ -1108,7 +1105,7 @@ def __call__( broadcast_arrays_out, ) scan_out = ( - graph.Static(tuple(carry_index_mappings)), + carry_nodedefs, pure_args_out, pure_out, ) @@ -1248,16 +1245,15 @@ def scan_wrapper(*args, **kwargs): broadcast_arrays_out, ) = carry_out ( - static_carry_index_mappings, + carry_nodedefs, pure_args_out, pure_out, ) = scan_out # next we have to insert all the index_mappings back into the NodeDefs # in the carry outputs - carry_index_mappings = deque(static_carry_index_mappings.value) - pure_carry_arg_out = _insert_index_mappings( - pure_carry_arg_out, carry_index_mappings + pure_carry_arg_out = _insert_nodedefs( + pure_carry_arg_out, deque(carry_nodedefs) ) # insert pure carry into pure_args_out @@ -1280,6 +1276,7 @@ def scan_wrapper(*args, **kwargs): is_leaf=lambda x: isinstance(x, (extract.NodeStates, Broadcasted)), map_non_graph_nodes=True, ctxtag='scan', + is_inner=False, ) # extract the carry from args_out @@ -1330,35 +1327,15 @@ def __call__(self, pure_val): def _add_fake_index_mapping(tree: tp.Any): global_index_mapping = {} # for the whole context, over all inputs - def per_node_state(ns: extract.NodeStates | tp.Any): - if not isinstance(ns, extract.NodeStates) or not isinstance( - ns._graphdef, graph.NodeDef + + def per_node_state(node_state: extract.NodeStates | tp.Any): + if not isinstance(node_state, extract.NodeStates) or not isinstance( + node_state._graphdef, graph.NodeDef ): - return ns - - def per_node_def(nd: graph.NodeDef | graph.NodeRef): - if nd.index >= 0: - global_index_mapping[nd.index] = nd.index - if isinstance(nd, graph.NodeRef): - return - - for attribute in nd.attributes: - if type(attribute) is graph.SubGraphAttribute: - per_node_def(attribute.value) - elif ( - type(attribute) is graph.LeafAttribute - and isinstance(attribute.value, (graph.VariableDef, graph.NodeRef)) - and attribute.value.index >= 0 - ): - global_index_mapping[attribute.value.index] = attribute.value.index - return - - per_node_def(ns._graphdef) + return node_state + return dataclasses.replace( - ns, - _graphdef=dataclasses.replace( - ns._graphdef, index_mapping=graph.HashableMapping(global_index_mapping) - ), + node_state, _graphdef=node_state._graphdef.with_same_outer_index() ) return jax.tree.map(per_node_state, tree, @@ -1366,16 +1343,18 @@ def per_node_def(nd: graph.NodeDef | graph.NodeRef): def _remove_index_mapping(tree: tp.Any): - '''Remove a fake index_mapping for the input to match that of the output.''' - def per_node_state(ns: extract.NodeStates | tp.Any): - if not isinstance(ns, extract.NodeStates) or not isinstance( - ns._graphdef, graph.NodeDef + """Remove a fake outer_index for the input to match that of the output.""" + + def per_node_state(node_state: extract.NodeStates | tp.Any): + if not isinstance(node_state, extract.NodeStates) or not isinstance( + node_state._graphdef, graph.NodeDef ): - return ns - assert isinstance(ns._graphdef, graph.NodeDef) - return dataclasses.replace(ns, _graphdef=dataclasses.replace( - ns._graphdef, index_mapping=None - )) + return node_state + assert isinstance(node_state._graphdef, graph.NodeDef) + node_state = dataclasses.replace( + node_state, _graphdef=node_state._graphdef.with_no_outer_index() + ) + return node_state return jax.tree.map(per_node_state, tree, is_leaf=lambda x: isinstance(x, extract.NodeStates)) @@ -1393,19 +1372,23 @@ def __call__(self, pure_val): # Removing the dummy index mapping being added outside of body function. pure_val_in = _remove_index_mapping(pure_val) - val = extract.from_tree(pure_val_in, ctxtag='while_loop_body') + val = extract.from_tree( + pure_val_in, ctxtag='while_loop_body', is_inner=True + ) out = self.f(val) pure_out = extract.to_tree(out, ctxtag='while_loop_body') try: jax.tree.map(lambda a, b: None, pure_val, pure_out) except ValueError as e: - msg = ("nnx.while_loop requires body function's input and output to " - "have the same reference and pytree structure, but they differ. " - "If the mismatch comes from `index_mapping` field, you might " - "have modified reference structure within the body function, " - "which is not allowed." - f"Detail of the mismatch: \n {str(e)}") + msg = ( + "nnx.while_loop requires body function's input and output to " + 'have the same reference and pytree structure, but they differ. ' + 'If the mismatch comes from `outer_index` field, you might ' + 'have modified reference structure within the body function, ' + 'which is not allowed.' + f'Detail of the mismatch: \n {str(e)}' + ) raise ValueError(msg) return pure_out @@ -1456,7 +1439,7 @@ def while_loop(cond_fun: tp.Callable[[T], tp.Any], WhileLoopBodyFn(body_fun), pure_init_val, ) - out = extract.from_tree(pure_out, ctxtag='while_loop') + out = extract.from_tree(pure_out, ctxtag='while_loop', is_inner=False) return out @@ -1472,19 +1455,21 @@ def __call__(self, i, pure_val): # Removing the dummy index mapping being added outside of body function. pure_val_in = _remove_index_mapping(pure_val) - val = extract.from_tree(pure_val_in, ctxtag='fori_loop_body') + val = extract.from_tree(pure_val_in, ctxtag='fori_loop_body', is_inner=True) out = self.f(i, val) pure_out = extract.to_tree(out, ctxtag='fori_loop_body') try: jax.tree.map(lambda a, b: None, pure_val, pure_out) except ValueError as e: - msg = ("nnx.fori_loop requires body function's input and output to " - "have the same reference and pytree structure, but they differ. " - "If the mismatch comes from `index_mapping` field, you might " - "have modified reference structure within the body function, " - "which is not allowed. " - f"Detail of the mismatch: \n {str(e)}") + msg = ( + "nnx.fori_loop requires body function's input and output to " + 'have the same reference and pytree structure, but they differ. ' + 'If the mismatch comes from `outer_index` field, you might ' + 'have modified reference structure within the body function, ' + 'which is not allowed. ' + f'Detail of the mismatch: \n {str(e)}' + ) raise ValueError(msg) return pure_out @@ -1545,5 +1530,5 @@ def fori_loop(lower: int, upper: int, pure_out = jax.lax.fori_loop(lower, upper, ForiLoopBodyFn(body_fun), pure_init_val, unroll=unroll) - out = extract.from_tree(pure_out, ctxtag='fori_loop') + out = extract.from_tree(pure_out, ctxtag='fori_loop', is_inner=False) return out diff --git a/flax/nnx/transforms/transforms.py b/flax/nnx/transforms/transforms.py index 8a83a026d4..3192b31aa7 100644 --- a/flax/nnx/transforms/transforms.py +++ b/flax/nnx/transforms/transforms.py @@ -160,7 +160,7 @@ def __post_init__(self): def __call__(self, *pure_args, **pure_kwargs): args, kwargs = extract.from_tree( - (pure_args, pure_kwargs), ctxtag='checkify' + (pure_args, pure_kwargs), ctxtag='checkify', is_inner=True ) out = self.f(*args, **kwargs) @@ -216,6 +216,7 @@ def jit_wrapper(*args, **kwargs): args_out, kwargs_out, out = extract.from_tree( (pure_args_out, pure_kwargs_out, pure_out), ctxtag='checkify', + is_inner=False, ) return error, out diff --git a/flax/nnx/variablelib.py b/flax/nnx/variablelib.py index b2c0660962..2908c074a8 100644 --- a/flax/nnx/variablelib.py +++ b/flax/nnx/variablelib.py @@ -47,7 +47,6 @@ VariableTypeCache: dict[str, tp.Type[Variable[tp.Any]]] = {} - @dataclasses.dataclass class VariableMetadata(tp.Generic[A]): raw_value: A @@ -125,6 +124,8 @@ class Variable(tp.Generic[A], reprlib.Representable): }) """ + __slots__ = ('raw_value', '_trace_state', '_var_metadata') + raw_value: A _trace_state: tracers.TraceState _var_metadata: dict[str, tp.Any] @@ -134,9 +135,8 @@ def __init__( value: tp.Union[A, VariableMetadata[A]], **metadata: tp.Any, ): - type_vars = vars(type(self)) - vars_self = vars(self) - vars_self['_trace_state'] = tracers.TraceState() + var_t = type(self) + object.__setattr__(self, '_trace_state', tracers.TraceState()) if isinstance(value, VariableMetadata): metadata.update(value.metadata) @@ -144,27 +144,28 @@ def __init__( object.__setattr__(self, 'raw_value', value) - if 'on_get_value' in type_vars and 'on_get_value' not in metadata: - metadata['get_value'] = getattr(type(self), 'on_get_value') + if hasattr(var_t, 'on_get_value') and 'on_get_value' not in metadata: + metadata['get_value'] = var_t.on_get_value - if 'on_set_value' in type_vars and 'on_set_value' not in metadata: - metadata['set_value'] = getattr(type(self), 'on_set_value') + if hasattr(var_t, 'on_set_value') and 'on_set_value' not in metadata: + metadata['set_value'] = var_t.on_set_value - if 'on_create_value' in type_vars and 'on_create_value' not in metadata: - metadata['create_value'] = getattr(type(self), 'on_create_value') + if hasattr(var_t, 'on_create_value') and 'on_create_value' not in metadata: + metadata['create_value'] = var_t.on_create_value - if 'on_add_axis' in type_vars and 'on_add_axis' not in metadata: - metadata['add_axis'] = getattr(type(self), 'on_add_axis') + if hasattr(var_t, 'on_add_axis') and 'on_add_axis' not in metadata: + metadata['add_axis'] = var_t.on_add_axis - if 'on_remove_axis' in type_vars and 'on_remove_axis' not in metadata: - metadata['remove_axis'] = getattr(type(self), 'on_remove_axis') + if hasattr(var_t, 'on_remove_axis') and 'on_remove_axis' not in metadata: + metadata['remove_axis'] = var_t.on_remove_axis - vars_self['_var_metadata'] = metadata + object.__setattr__(self, '_var_metadata', metadata) # run create_value hooks - vars_self['raw_value'] = self.create_value(self.raw_value) + object.__setattr__(self, 'raw_value', self.create_value(self.raw_value)) + def __getattr__(self, name: str) -> tp.Any: - if name in vars(self)['_var_metadata']: + if name in object.__getattribute__(self, '_var_metadata'): return self._var_metadata[name] return getattr(self.value, name) @@ -220,9 +221,10 @@ def copy_from(self, other: Variable[A]) -> None: self._var_metadata.update(other.get_metadata()) def update_from_state(self, variable_state: VariableState[A]): - vars_self = vars(self) - vars_self['raw_value'] = variable_state.value - vars_self['_var_metadata'] = variable_state._var_metadata.copy() + object.__setattr__(self, 'raw_value', variable_state.value) + object.__setattr__( + self, '_var_metadata', variable_state._var_metadata.copy() + ) @property def value(self) -> A: @@ -239,7 +241,7 @@ def value(self, value: A): ) if 'on_set_value' in self._var_metadata: value = self._var_metadata['on_set_value'](self, value) - vars(self)['raw_value'] = value + object.__setattr__(self, 'raw_value', value) def create_value(self, value: A): if 'on_create_value' in self._var_metadata: @@ -254,9 +256,6 @@ def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): if 'on_remove_axis' in self._var_metadata: self._var_metadata['on_remove_axis'](self, axis_index, axis_name) - def __eq__(self, other: object) -> bool: - return type(self) is type(other) and vars(other) == vars(self) - @tp.overload def replace(self, value: B, **kwargs) -> Variable[B]: ... @@ -369,10 +368,16 @@ def __jax_array__(self): # pickle support def __getstate__(self): - return vars(self).copy() + return { + 'raw_value': self.raw_value, + '_trace_state': self._trace_state, + '_var_metadata': self._var_metadata, + } def __setstate__(self, state): - vars(self).update(state) + object.__setattr__(self, 'raw_value', state['raw_value']) + object.__setattr__(self, '_trace_state', state['_trace_state']) + object.__setattr__(self, '_var_metadata', state['_var_metadata']) # -------------------------------------------- # proxy methods @@ -841,6 +846,7 @@ def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): if 'on_remove_axis' in self._var_metadata: self._var_metadata['on_remove_axis'](self, axis_index, axis_name) +GraphVariableState = VariableState[VariableState[tp.Any]] def _variable_state_flatten(x: VariableState[tp.Any], *, with_keys: bool): metadata = tuple(x.get_metadata().items()) @@ -944,7 +950,7 @@ def wrapper(*args): def split_flat_state( flat_state: tp.Iterable[tuple[PathParts, Variable | VariableState]], - filters: tp.Sequence[filterlib.Filter], + filters: tuple[filterlib.Filter, ...], ) -> tuple[list[tuple[PathParts, Variable | VariableState]], ...]: predicates = filterlib.filters_to_predicates(filters) # we have n + 1 states, where n is the number of predicates diff --git a/flax/typing.py b/flax/typing.py index 0ae990d95a..0f694383f6 100644 --- a/flax/typing.py +++ b/flax/typing.py @@ -168,11 +168,11 @@ class Missing: def _bytes_repr(num_bytes): count, units = ( - (f'{num_bytes / 1e9 :,.1f}', 'GB') + (f'{num_bytes / 1e9:,.1f}', 'GB') if num_bytes > 1e9 - else (f'{num_bytes / 1e6 :,.1f}', 'MB') + else (f'{num_bytes / 1e6:,.1f}', 'MB') if num_bytes > 1e6 - else (f'{num_bytes / 1e3 :,.1f}', 'KB') + else (f'{num_bytes / 1e3:,.1f}', 'KB') if num_bytes > 1e3 else (f'{num_bytes:,}', 'B') ) diff --git a/flaxlib_src/CMakeLists.txt b/flaxlib_src/CMakeLists.txt new file mode 100644 index 0000000000..a5a61b5b2a --- /dev/null +++ b/flaxlib_src/CMakeLists.txt @@ -0,0 +1,54 @@ +# Set the minimum CMake version and policies for highest tested version +cmake_minimum_required(VERSION 3.15...3.27) + +# Set up the project and ensure there is a working C++ compiler +project(flaxlib LANGUAGES CXX) + +# Warn if the user invokes CMake directly +if (NOT SKBUILD) + message(WARNING "\ + This CMake file is meant to be executed using 'scikit-build-core'. + Running it directly will almost certainly not produce the desired + result. If you are a user trying to install this package, use the + command below, which will install all necessary build dependencies, + compile the package in an isolated environment, and then install it. + ===================================================================== + $ pip install . + ===================================================================== + If you are a software developer, and this is your own package, then + it is usually much more efficient to install the build dependencies + in your environment once and use the following command that avoids + a costly creation of a new virtual environment at every compilation: + ===================================================================== + $ pip install nanobind scikit-build-core[pyproject] + $ pip install --no-build-isolation -ve . + ===================================================================== + You may optionally add -Ceditable.rebuild=true to auto-rebuild when + the package is imported. Otherwise, you need to rerun the above + after editing C++ files.") +endif() + +# Try to import all Python components potentially needed by nanobind +find_package(Python 3.8 + REQUIRED COMPONENTS Interpreter Development.Module + OPTIONAL_COMPONENTS Development.SABIModule) + +# Import nanobind through CMake's find_package mechanism +find_package(nanobind CONFIG REQUIRED) + +# We are now ready to compile the actual extension module +nanobind_add_module( + # Name of the extension + flaxlib_cpp + + # Target the stable ABI for Python 3.12+, which reduces + # the number of binary wheels that must be built. This + # does nothing on older Python versions + STABLE_ABI + + # Source code goes here + src/lib.cc +) + +# Install directive for scikit-build-core +install(TARGETS flaxlib_cpp LIBRARY DESTINATION flaxlib) \ No newline at end of file diff --git a/flaxlib_src/meson.build b/flaxlib_src/meson.build deleted file mode 100644 index 0d78d9436b..0000000000 --- a/flaxlib_src/meson.build +++ /dev/null @@ -1,14 +0,0 @@ -project( - 'flaxlib', - 'cpp', - version: '0.0.1', - default_options: ['cpp_std=c++17'], -) -py = import('python').find_installation() -nanobind_dep = dependency('nanobind', static: true) -py.extension_module( - 'flaxlib', - sources: ['src/lib.cc'], - dependencies: [nanobind_dep], - install: true, -) \ No newline at end of file diff --git a/flaxlib_src/pyproject.toml b/flaxlib_src/pyproject.toml index 0afc7699a5..fd6c0b61b4 100644 --- a/flaxlib_src/pyproject.toml +++ b/flaxlib_src/pyproject.toml @@ -1,17 +1,28 @@ [build-system] -requires = ['meson-python'] -build-backend = 'mesonpy' +requires = ["scikit-build-core >=0.4.3", "nanobind >=1.3.2"] +build-backend = "scikit_build_core.build" [project] name = "flaxlib" +version = "0.0.1" requires-python = ">=3.10" classifiers = [ "Programming Language :: C++", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dynamic = ["version"] + [project.optional-dependencies] tests = [ "pytest", ] + +[tool.scikit-build] +# Protect the configuration against future changes in scikit-build-core +minimum-version = "0.4" + +# Setuptools-style build caching in a local directory +build-dir = "build/{wheel_tag}" + +# Build stable ABI wheels for CPython 3.12+ +wheel.py-api = "cp312" \ No newline at end of file diff --git a/flaxlib_src/flaxlib.pyi b/flaxlib_src/src/flaxlib/__init__.py similarity index 84% rename from flaxlib_src/flaxlib.pyi rename to flaxlib_src/src/flaxlib/__init__.py index 505fd3d0f0..f458417719 100644 --- a/flaxlib_src/flaxlib.pyi +++ b/flaxlib_src/src/flaxlib/__init__.py @@ -12,4 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -def sum_as_string(a: int, b: int) -> str: ... +from .flaxlib_cpp import RefMap as RefMap +from .flaxlib_cpp import _graph_fingerprint as _graph_fingerprint diff --git a/flaxlib_src/src/flaxlib/flaxlib_cpp.pyi b/flaxlib_src/src/flaxlib/flaxlib_cpp.pyi new file mode 100644 index 0000000000..03557efb9f --- /dev/null +++ b/flaxlib_src/src/flaxlib/flaxlib_cpp.pyi @@ -0,0 +1,25 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing as tp + +RefMap = tp.MutableMapping[tp.Any, int] + +def _graph_fingerprint( + node, + node_impl, + ref_index: RefMap, + new_ref_index: RefMap, + next_index: int, +) -> tuple[tuple[tp.Any, ...], int]: ... \ No newline at end of file diff --git a/flaxlib_src/src/lib.cc b/flaxlib_src/src/lib.cc index c714588118..c915727030 100644 --- a/flaxlib_src/src/lib.cc +++ b/flaxlib_src/src/lib.cc @@ -1,14 +1,298 @@ +// Copyright 2024 The Flax Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include +#include -#include "nanobind/nanobind.h" -#include "nanobind/stl/string.h" +namespace nb = nanobind; +using namespace nb::literals; -namespace flaxlib { -std::string sum_as_string(int a, int b) { - return std::to_string(a + b); +// ----------------------------------- +// helper functions +// ----------------------------------- +intptr_t nb_id(const nb::object &obj) +{ + // Get the object ID + return reinterpret_cast(obj.ptr()); } -NB_MODULE(flaxlib, m) { - m.def("sum_as_string", &sum_as_string); +nb::tuple vector_to_tuple(const std::vector &vec) +{ + + if (vec.empty()) + { + return nb::tuple(); + } + else + { + return nb::tuple(nb::cast(vec)); + } } -} // namespace flaxlib \ No newline at end of file + +// 1. Hash function for nb::object +struct NbObjectHash +{ + std::size_t operator()(const nb::object &obj) const + { + return nb::hash(obj); + } +}; + +// 2. Equality function for nb::object (Important!) +struct NbObjectEqual +{ + bool operator()(const nb::object &a, const nb::object &b) const + { + return a.equal(b); + } +}; + +NB_MAKE_OPAQUE(std::unordered_map); + +namespace flaxlib +{ + //--------------------------------------------------------------- + // RefMap + //--------------------------------------------------------------- + + using RefMap = std::unordered_map; + + std::optional ref_map_get(RefMap &map, nb::object &key, std::optional default_value = std::nullopt) + { + auto it = map.find(key); + if (it != map.end()) + { + return it->second; + } + else + { + return std::nullopt; + } + } + + //--------------------------------------------------------------- + // NNXContext + //--------------------------------------------------------------- + + struct PythonContext + { + nb::object nnx; + nb::object graph; + nb::object jax; + nb::object np; + nb::object jax_Array; + nb::object np_ndarray; + nb::type_object GraphNodeImpl; + nb::type_object PytreeNodeImpl; + nb::type_object Object; + nb::type_object Variable; + nb::object get_node_impl; + + PythonContext() + { + nnx = nb::module_::import_("flax.nnx"); + graph = nb::module_::import_("flax.nnx.graph"); + jax = nb::module_::import_("jax"); + np = nb::module_::import_("numpy"); + jax_Array = jax.attr("Array"); + np_ndarray = np.attr("ndarray"); + GraphNodeImpl = graph.attr("GraphNodeImpl"); + PytreeNodeImpl = graph.attr("PytreeNodeImpl"); + Object = nnx.attr("Object"); + Variable = graph.attr("Variable"); + get_node_impl = graph.attr("get_node_impl"); + } + + ~PythonContext() + { + graph.release(); + jax.release(); + np.release(); + jax_Array.release(); + np_ndarray.release(); + GraphNodeImpl.release(); + PytreeNodeImpl.release(); + Variable.release(); + get_node_impl.release(); + } + }; + + static std::optional _python_context; + + PythonContext &get_python_context() + { + if (!_python_context) + { + _python_context.emplace(); + } + return *_python_context; + } + + //--------------------------------------------------------------- + // fingerprint + //--------------------------------------------------------------- + std::tuple _key_values_metadata( + PythonContext &ctx, + nb::object &node, + nb::object &node_impl) + { + if (nb::isinstance(node, ctx.Object)) + { + nb::dict nodes_dict = node.attr("__dict__"); + nb::handle object_state = nodes_dict["_object__state"]; + nb::del(nodes_dict["_object__state"]); + auto nodes = nodes_dict.items(); + nodes.sort(); + nodes_dict["_object__state"] = object_state; + auto metadata = nb::make_tuple(node.type(), object_state.attr("_initializing")); + return {nodes, metadata}; + } + else if (PyList_Check(node.ptr()) || PyTuple_Check(node.ptr())) + { + int i = 0; + nb::list values; + for (const auto &value : node) + { + values.append(nb::make_tuple(i, value)); + i += 1; + } + return {values, nb::none()}; + } + else + { + auto values_metadata = node_impl.attr("flatten")(node); + auto values = values_metadata[0]; + auto metadata = values_metadata[1]; + return {values, metadata}; + } + } + + nb::tuple _graph_fingerprint_recursive( + PythonContext &ctx, + nb::object &node, + nb::object &node_impl, + RefMap &ref_index, + RefMap &new_ref_index, + int &next_index) + { + bool is_pytree_node = node_impl.type().is(ctx.PytreeNodeImpl); + bool is_graph_node = node_impl.type().is(ctx.GraphNodeImpl); + + if (is_pytree_node) + { + // pass + } + else if (ref_index.find(node) != ref_index.end()) + { + return nb::make_tuple(nb_id(node), node.type(), ref_index[node]); + } + else if (new_ref_index.find(node) != new_ref_index.end()) + { + return nb::make_tuple(nb_id(node), node.type(), new_ref_index[node]); + } + + // only cache graph nodes + int index; + if (is_graph_node) + { + index = new_ref_index[node] = next_index; + next_index += 1; + } + else + { + index = -1; + } + + std::vector attributes; + + auto [values, metadata] = _key_values_metadata(ctx, node, node_impl); + + for (const auto &key_value : values) + { + nb::object key = key_value[0]; + nb::object value = key_value[1]; + auto value_node_impl = ctx.get_node_impl(value); + if (!value_node_impl.is_none()) + { + auto node_fp = _graph_fingerprint_recursive(ctx, value, value_node_impl, ref_index, new_ref_index, next_index); + attributes.push_back(nb::make_tuple(key, node_fp)); + } + else if (nb::isinstance(value, ctx.Variable)) + { + if (ref_index.find(value) != ref_index.end()) + { + attributes.push_back(nb::make_tuple(key, nb_id(value), value.type(), ref_index[value])); + } + else if (new_ref_index.find(value) != new_ref_index.end()) + { + attributes.push_back(nb::make_tuple(key, nb_id(value), value.type(), new_ref_index[value])); + } + else + { + auto variable_index = new_ref_index[value] = next_index; + next_index += 1; + auto var_meta = nb::tuple(value.attr("_var_metadata").attr("items")()); + attributes.push_back(nb::make_tuple(key, nb_id(value), value.type(), variable_index, var_meta)); + } + } + else // static attribute + { + if (nb::isinstance(value, ctx.jax_Array) || nb::isinstance(value, ctx.np_ndarray)) + { + auto repr = "Arrays leaves are not supported: " + nb::cast(nb::repr(value)); + } + attributes.push_back(nb::make_tuple(key, value)); + } + } + + auto node_fp = nb::make_tuple( + is_graph_node ? nb::cast(nb_id(node)) : nb::none(), + node_impl.attr("type"), + index, + vector_to_tuple(attributes), + metadata); + + return node_fp; + } + + nb::tuple _graph_fingerprint( + nb::object &node, + nb::object &node_impl, + RefMap &ref_index, + RefMap &new_ref_index, + int next_index) + { + auto ctx = get_python_context(); + auto node_fp = _graph_fingerprint_recursive(ctx, node, node_impl, ref_index, new_ref_index, next_index); + return nb::make_tuple(node_fp, next_index); + } + + NB_MODULE(flaxlib_cpp, m) + { + // Remove the conflicting binding + nb::bind_map(m, "RefMap") + .def("get", &ref_map_get, nb::arg("key").none(), nb::arg("default_value").none()); + m.def("_graph_fingerprint", &_graph_fingerprint); + } +} // namespace flaxlib \ No newline at end of file diff --git a/flaxlib_src/src/lib.rs b/flaxlib_src/src/lib.rs deleted file mode 100644 index cadab2ef22..0000000000 --- a/flaxlib_src/src/lib.rs +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2024 The Flax Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use pyo3::prelude::*; - -/// Formats the sum of two numbers as string. -#[pyfunction] -fn sum_as_string(a: usize, b: usize) -> PyResult { - Ok((a + b).to_string()) -} - -/// A Python module implemented in Rust. -#[pymodule] -fn flaxlib(_py: Python, m: &Bound) -> PyResult<()> { - m.add_function(wrap_pyfunction!(sum_as_string, m)?)?; - Ok(()) -} diff --git a/pyproject.toml b/pyproject.toml index f7a890fad0..339d065ffd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -229,3 +229,9 @@ quote-style = "single" [tool.uv] # Ignore uv.lock and always upgrade the package to the latest upgrade-package = ["jax", "jaxlib", "orbax-checkpoint"] + +[dependency-groups] +dev = [ + "nanobind>=2.4.0", + "scikit-build-core[pyproject]>=0.10.7", +] diff --git a/tests/nnx/bridge/wrappers_test.py b/tests/nnx/bridge/wrappers_test.py index 5b65603a24..b353dd4925 100644 --- a/tests/nnx/bridge/wrappers_test.py +++ b/tests/nnx/bridge/wrappers_test.py @@ -228,7 +228,9 @@ def test_nnx_to_linen(self): assert y.shape == (1, 64) np.testing.assert_allclose(y, x @ variables['params']['kernel']) assert 'nnx' in variables - assert isinstance(variables['nnx']['graphdef'], nnx.GraphDef) + assert isinstance( + variables['nnx']['graphdef'], nnx.graph.NodeDef | nnx.graph.NodeRef + ) def test_nnx_to_linen_multiple_rngs(self): class NNXInner(nnx.Module): diff --git a/tests/nnx/graph_utils_test.py b/tests/nnx/graph_utils_test.py index a7bbf178cb..397198ae41 100644 --- a/tests/nnx/graph_utils_test.py +++ b/tests/nnx/graph_utils_test.py @@ -64,10 +64,26 @@ def test_flatten(self): g = [a, 3, a, nnx.Param(4)] refmap = nnx.graph.RefMap() - graphdef, state = nnx.graph.flatten(g, ref_index=refmap) + graphdef, flat_state = nnx.graph.flatten(g, ref_index=refmap) - state[0]['b'].raw_value = 2 - state[3].raw_value = 4 + assert flat_state[0][1].value == 2 + assert flat_state[1][1].value == 4 + + assert len(refmap) == 2 + assert a['b'] in refmap + assert g[3] in refmap + + def test_flatten_no_paths(self): + a = {'a': 1, 'b': nnx.Param(2)} + g = [a, 3, a, nnx.Param(4)] + + refmap = nnx.graph.RefMap() + graphdef, flat_state = nnx.graph.flatten( + g, ref_index=refmap, with_paths=False + ) + + assert flat_state[0] == 2 + assert flat_state[1] == 4 assert len(refmap) == 2 assert a['b'] in refmap @@ -108,9 +124,40 @@ def test_unflatten_empty(self): graphdef, state = nnx.split(g) - with self.assertRaisesRegex(ValueError, 'Expected key'): + with self.assertRaisesRegex( + ValueError, 'Not enough leaves to unflatten the graph' + ): nnx.graph.unflatten(graphdef, nnx.State({})) + def test_unflatten_return_variables(self): + a = Dict({'a': 1, 'b': nnx.Param(2)}) + g = List([a, 3, a, nnx.Param(4)]) + + graphdef, state = nnx.graph.flatten( + g, with_paths=False, return_variables=True + ) + + self.assertLen(state, 2) + self.assertIsInstance(state, list) + self.assertIsInstance(state[0], nnx.Param) + self.assertIsInstance(state[1], nnx.Param) + + def test_clone_with_same_variables(self): + a = Dict({'a': 1, 'b': nnx.Param(2)}) + g = List([a, 3, a, nnx.Param(4)]) + + graphdef, state = nnx.graph.flatten( + g, with_paths=False, return_variables=True + ) + + g2 = nnx.graph.unflatten(graphdef, state) + + self.assertIsNot(g, g2) + self.assertIsNot(g[0], g2[0]) + self.assertIsNot(g[2], g2[2]) + self.assertIs(g[0]['b'], g2[0]['b']) + self.assertIs(g[3], g2[3]) + def test_update_dynamic(self): a = {'a': 1, 'b': nnx.Param(2)} g = [a, 3, a, nnx.Param(4)] @@ -303,7 +350,7 @@ def __init__(self): assert 'tree' in state assert 'a' in state.tree - assert graphdef.attributes[0].value.type is nnx.graph.GenericPytree + assert graphdef.attributes[0][1].type is nnx.graph.GenericPytree m2 = nnx.merge(graphdef, state) @@ -329,26 +376,28 @@ def f(m: Foo): ref_out_idx_out = nnx.graph.RefMap() graphdef: nnx.graph.GraphDef[Foo] graphdef, state = nnx.graph.flatten(m, ref_index=ref_out_idx_out) + state = state.to_nested_state() @partial(jax.jit, static_argnums=(0,)) def f_pure(graphdef: nnx.graph.GraphDef[Foo], state): idx_out_ref_in: dict[int, Any] = {} m = nnx.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in) + ref_in_idx_out = nnx.graph.RefMap( + {v: k for k, v in idx_out_ref_in.items()} + ) f(m) - ref_in_idx_in = nnx.graph.RefMap[Any, int]() - graphdef, state = nnx.graph.flatten(m, ref_index=ref_in_idx_in) - idx_out_idx_in = nnx.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in) - static_out = nnx.graph.Static((graphdef, idx_out_idx_in)) - return state, static_out - - static_out: nnx.graph.Static - state, static_out = f_pure(graphdef, state) - idx_out_idx_in: dict[int, int] - graphdef, idx_out_idx_in = static_out.value - idx_in_ref_out = nnx.graph.compose_mapping_reversed( - ref_out_idx_out, idx_out_idx_in + ref_in_idx_in = nnx.graph.RefMap() + graphdef, state = nnx.graph.flatten( + m, ref_index=ref_in_idx_in, ref_outer_index=ref_in_idx_out + ) + state = state.to_nested_state() + return state, graphdef + + state, graphdef_out = f_pure(graphdef, state) + idx_out_ref_out = {v: k for k, v in ref_out_idx_out.items()} + m2 = nnx.graph.unflatten( + graphdef_out, state, outer_index_outer_ref=idx_out_ref_out ) - m2 = nnx.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out) assert m2 is m assert m2.a is b assert m2.b is a @@ -366,29 +415,31 @@ def f(m: Foo): a = m.a b = m.b - ref_out_idx_out = nnx.graph.RefMap[Any, int]() + ref_out_idx_out = nnx.graph.RefMap() graphdef: nnx.graph.GraphDef[Foo] graphdef, state = nnx.graph.flatten(m, ref_index=ref_out_idx_out) + idx_out_ref_out = {v: k for k, v in ref_out_idx_out.items()} + state = state.to_nested_state() @partial(jax.jit, static_argnums=(0,)) def f_pure(graphdef: nnx.graph.GraphDef[Foo], state): idx_out_ref_in: dict[int, Any] = {} m = nnx.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in) + ref_in_idx_out = nnx.graph.RefMap( + {v: k for k, v in idx_out_ref_in.items()} + ) f(m) - ref_in_idx_in = nnx.graph.RefMap[Any, int]() - graphdef, state = nnx.graph.flatten(m, ref_index=ref_in_idx_in) - idx_out_idx_in = nnx.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in) - static_out = nnx.graph.Static((graphdef, idx_out_idx_in)) - return state, static_out - - static_out: nnx.graph.Static - state, static_out = f_pure(graphdef, state) - idx_out_idx_in: dict[int, int] - graphdef, idx_out_idx_in = static_out.value - idx_in_ref_out = nnx.graph.compose_mapping_reversed( - ref_out_idx_out, idx_out_idx_in + ref_in_idx_in = nnx.graph.RefMap() + graphdef, state = nnx.graph.flatten( + m, ref_index=ref_in_idx_in, ref_outer_index=ref_in_idx_out + ) + state = state.to_nested_state() + return state, graphdef + + state, graphdef = f_pure(graphdef, state) + m2 = nnx.graph.unflatten( + graphdef, state, outer_index_outer_ref=idx_out_ref_out ) - m2 = nnx.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out) assert m2 is m assert m2.a is b assert m2.b is a @@ -406,26 +457,28 @@ def f(m: Foo): ref_out_idx_out = nnx.graph.RefMap() graphdef: nnx.graph.GraphDef[Foo] graphdef, state = nnx.graph.flatten(m, ref_index=ref_out_idx_out) + idx_out_ref_out = {v: k for k, v in ref_out_idx_out.items()} + state = state.to_nested_state() @partial(jax.jit, static_argnums=(0,)) def f_pure(graphdef: nnx.graph.GraphDef[Foo], state): idx_out_ref_in: dict[int, Any] = {} m = nnx.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in) + ref_in_idx_out = nnx.graph.RefMap( + {v: k for k, v in idx_out_ref_in.items()} + ) f(m) - ref_in_idx_in = nnx.graph.RefMap[Any, int]() - graphdef, state = nnx.graph.flatten(m, ref_index=ref_in_idx_in) - idx_out_idx_in = nnx.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in) - static_out = nnx.graph.Static((graphdef, idx_out_idx_in)) - return state, static_out - - static_out: nnx.graph.Static - state, static_out = f_pure(graphdef, state) - idx_out_idx_in: dict[int, int] - graphdef, idx_out_idx_in = static_out.value - idx_in_ref_out = nnx.graph.compose_mapping_reversed( - ref_out_idx_out, idx_out_idx_in + ref_in_idx_in = nnx.graph.RefMap() + graphdef, state = nnx.graph.flatten( + m, ref_index=ref_in_idx_in, ref_outer_index=ref_in_idx_out + ) + state = state.to_nested_state() + return state, graphdef + + state, graphdef_out = f_pure(graphdef, state) + m2 = nnx.graph.unflatten( + graphdef_out, state, outer_index_outer_ref=idx_out_ref_out ) - m2 = nnx.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out) assert m2 is m assert m2.ref is m2 @@ -582,7 +635,7 @@ def __init__(self): @jax.jit def f(graphdef1, state1, graphdef2, state2): - with nnx.graph.merge_context(ctxtag) as ctx: + with nnx.graph.merge_context(True, ctxtag) as ctx: m1 = ctx.merge(graphdef1, state1) m2 = ctx.merge(graphdef2, state2) @@ -603,7 +656,7 @@ def f(graphdef1, state1, graphdef2, state2): graphdef1, state1, graphdef2, state2 ) - with nnx.graph.merge_context(ctxtag) as ctx: + with nnx.graph.merge_context(False, ctxtag) as ctx: m1_out = ctx.merge(graphdef1, state1) m2_out = ctx.merge(graphdef2, state2) @@ -671,7 +724,7 @@ def __init__(self): @jax.jit def f(pure_tree): - impure_tree2 = nnx.from_tree(pure_tree, ctxtag=ctxtag) + impure_tree2 = nnx.from_tree(pure_tree, ctxtag=ctxtag, is_inner=True) m1_out = impure_tree2[0] m2_out = impure_tree2[2]['b'] @@ -700,7 +753,7 @@ def f(pure_tree): pure_tree2 = f(pure_tree) - impure_tree2 = nnx.from_tree(pure_tree2, ctxtag=ctxtag) + impure_tree2 = nnx.from_tree(pure_tree2, ctxtag=ctxtag, is_inner=False) m1_out = impure_tree2[0] m2_out = impure_tree2[2]['b'] @@ -762,7 +815,7 @@ def split_fn(ctx: nnx.SplitContext, path, prefix, x): @partial(jax.vmap, in_axes=jax_in_axes, out_axes=(jax_in_axes, out_axes)) def f(*pure_args): - args = nnx.from_tree(pure_args, ctxtag=ctxtag) + args = nnx.from_tree(pure_args, ctxtag=ctxtag, is_inner=True) y = 0 @@ -785,7 +838,9 @@ def f(*pure_args): pure_args_out, y = f(*pure_args) - args_out, y = nnx.from_tree((pure_args_out, y), ctxtag=ctxtag) + args_out, y = nnx.from_tree( + (pure_args_out, y), ctxtag=ctxtag, is_inner=False + ) self.assertEqual(y.shape, (5,)) self.assertGreater(y.sum(), 5) @@ -793,6 +848,44 @@ def f(*pure_args): self.assertIs(m1, args_out[2]['b']) self.assertIs(m2, args_out[1]) + def test_fingerprint_basic(self): + m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + fp1 = nnx.graph.fingerprint(m) + fp2 = nnx.graph.fingerprint(m) + + self.assertEqual(fp1, fp2) + self.assertTrue(nnx.graph.check_fingerprint(m, fp1)) + self.assertTrue(nnx.graph.check_fingerprint(m, fp2)) + + def test_fingerprint_variable_id_sensitive(self): + m1 = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + fp1 = nnx.graph.fingerprint(m1) + + m2 = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + fp2 = nnx.graph.fingerprint(m2) + + self.assertNotEqual(fp1, fp2) + self.assertTrue(nnx.graph.check_fingerprint(m1, fp1)) + self.assertTrue(nnx.graph.check_fingerprint(m2, fp2)) + self.assertFalse(nnx.graph.check_fingerprint(m1, fp2)) + self.assertFalse(nnx.graph.check_fingerprint(m2, fp1)) + + def test_fingerprint_module_id_insensitive(self): + m1 = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + m2 = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + + m1.kernel = m2.kernel + m1.bias = m2.bias + + fp1 = nnx.graph.fingerprint(m1) + fp2 = nnx.graph.fingerprint(m2) + + self.assertNotEqual(fp1, fp2) + self.assertTrue(nnx.graph.check_fingerprint(m1, fp1)) + self.assertTrue(nnx.graph.check_fingerprint(m2, fp2)) + self.assertFalse(nnx.graph.check_fingerprint(m1, fp2)) + self.assertFalse(nnx.graph.check_fingerprint(m2, fp1)) + class SimpleModule(nnx.Module): pass diff --git a/tests/nnx/module_test.py b/tests/nnx/module_test.py index 64928f46b8..9a7cd0a7ce 100644 --- a/tests/nnx/module_test.py +++ b/tests/nnx/module_test.py @@ -25,7 +25,6 @@ import jax.numpy as jnp import numpy as np - A = TypeVar('A') class List(nnx.Module): @@ -262,13 +261,13 @@ def test_clone(self): m2 = nnx.clone(m) assert m is not m2 - assert m2.a[0] == m2.b.c - assert m2.a[1] == m2.b.d + assert m2.a[0].value == m2.b.c.value + assert m2.a[1].value == m2.b.d.value - assert m.a[0] == m2.a[0] - assert m.a[1] == m2.a[1] - assert m.b.c == m2.b.c - assert m.b.d == m2.b.d + assert m.a[0].value == m2.a[0].value + assert m.a[1].value == m2.a[1].value + assert m.b.c.value == m2.b.c.value + assert m.b.d.value == m2.b.d.value def test_sow_basic(self): class Foo(nnx.Module): @@ -465,7 +464,7 @@ def __init__(self) -> None: m1 = Foo() m2 = deepcopy(m1) - assert m1.a == m2.a + assert m1.a.value == m2.a.value assert vars(m1)['a'] is not vars(m2)['a'] assert m1.b is not m2.b assert m1.c is not m2.c @@ -639,6 +638,9 @@ class Foo(nnx.Module): e: nnx.Variable[int] f: int + def __hash__(self): + return id(self) + m = Foo( a=1, # graphdef b=nnx.Variable(2), # node @@ -717,7 +719,7 @@ def __call__(self, x, *, rngs: nnx.Rngs): graphdef, state = nnx.split(foo) - assert isinstance(graphdef, nnx.GraphDef) + assert isinstance(graphdef, nnx.graph.NodeDef | nnx.graph.NodeRef) assert isinstance(state, nnx.State) assert issubclass(state.w.type, nnx.Param) assert issubclass(state.c.type, nnx.Variable) diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py index bfa461be39..10653ef20a 100644 --- a/tests/nnx/transforms_test.py +++ b/tests/nnx/transforms_test.py @@ -67,6 +67,27 @@ def g(m: Dict): assert m.a == 2 assert out == 1.0 + def test_simple_double_call(self): + n = 0 + m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + + @nnx.jit + def f(m: nnx.Linear, x: jnp.ndarray) -> jnp.ndarray: + nonlocal n + n += 1 + return m(x) + + x = jnp.ones((1, 2)) + y = f(m, x) + + self.assertEqual(n, 1) + self.assertEqual(y.shape, (1, 3)) + + y = f(m, x) + + self.assertEqual(n, 1) + self.assertEqual(y.shape, (1, 3)) + def test_jit_on_init(self): n = 0 @@ -634,6 +655,9 @@ class Foo(nnx.Module): y: nnx.Param[jax.Array] z: int + def __hash__(self): + return id(self) + @nnx.custom_vjp def f(m: Foo): m.z += 1 @@ -674,6 +698,9 @@ class Foo(nnx.Module): y: nnx.Param[jax.Array] z: int + def __hash__(self): + return id(self) + x_in_path = nnx.PathContains('x') diff_state = nnx.DiffState(0, x_in_path) @@ -715,6 +742,9 @@ class Foo(nnx.Module): y: nnx.Param[jax.Array] z: int + def __hash__(self): + return id(self) + @nnx.custom_vjp @nnx.remat def f(m: Foo): @@ -760,6 +790,9 @@ class Foo(nnx.Module): y: nnx.Param[jax.Array] z: int + def __hash__(self): + return id(self) + @nnx.custom_vjp def f(m1: Foo, m2: Foo): m1.z += 1 @@ -813,6 +846,9 @@ class Foo(nnx.Module): y: nnx.Param[jax.Array] z: int + def __hash__(self): + return id(self) + @nnx.custom_vjp(nondiff_argnums=(0, 2)) def f(a, m: Foo, b): self.assertEqual(a, 1) @@ -1006,6 +1042,9 @@ def test_all_carry(self): class Foo(nnx.Module): n: nnx.BatchStat[int] + def __hash__(self): + return id(self) + foo = Foo(n=nnx.BatchStat(0)) @nnx.scan(in_axes=nnx.Carry, out_axes=nnx.Carry, length=3) @@ -1036,9 +1075,9 @@ def loop(foo: Foo, x): loop(foo, 0) def test_all_carry_new_reference_error(self): - @dataclasses.dataclass(repr=False) class Foo(nnx.Module): - n: nnx.BatchStat[int] + def __init__(self, n: nnx.BatchStat[int]): + self.n = n xs = jnp.arange(3) foo = Foo(n=nnx.BatchStat(0)) @@ -1056,9 +1095,9 @@ def loop(foo: Foo, x): loop(foo, xs) def test_all_scan(self): - @dataclasses.dataclass(repr=False) class Foo(nnx.Module): - n: nnx.BatchStat[jax.Array] + def __init__(self, n: nnx.BatchStat[jax.Array]): + self.n = n xs = jnp.arange(3) foo = Foo(n=nnx.BatchStat(jnp.arange(3))) @@ -1075,9 +1114,9 @@ def loop(foo: Foo, x): np.testing.assert_allclose(foo.n.value, jnp.arange(1, 4)) def test_all_broadcast(self): - @dataclasses.dataclass(repr=False) class Foo(nnx.Module): - n: nnx.BatchStat[int] + def __init__(self, n: nnx.BatchStat[int]): + self.n = n xs = jnp.array(1) foo = Foo(n=nnx.BatchStat(2)) @@ -1740,7 +1779,6 @@ def test_cache_tracing_object(self): x = jnp.arange(5) count = jnp.array(0) - @dataclasses.dataclass class Foo(nnx.Object): @nnx.split_rngs(splits=5) @@ -2696,6 +2734,9 @@ def zero(): class Foo(nnx.Object): timestep: TimeStep + def __hash__(self): + return id(self) + def update(self): def reward_2(self: Foo): self.timestep = TimeStep( @@ -2985,18 +3026,6 @@ def loop_fn(inputs): nnx.while_loop(lambda input: input[-1] > 0, while_loop_fn, (a, b, 2)) nnx.fori_loop(0, 2, fori_loop_fn, (a, b)) - def test_fori_output(self): - model = nnx.Linear(2, 2, rngs=nnx.Rngs(jax.random.PRNGKey(0))) - model2 = nnx.Linear(2, 2, rngs=nnx.Rngs(jax.random.PRNGKey(1))) - - def f(i, x): - return x - - model_out, model2_out = nnx.fori_loop(0, 10, f, (model, model2)) - - self.assertIs(model, model_out) - self.assertIs(model2, model2_out) - class TestSplitMergeInputs(absltest.TestCase): def test_split_inputs(self): @@ -3093,6 +3122,9 @@ def test_basic(self): class Foo(nnx.Module): a: nnx.Param + def __hash__(self): + return id(self) + @nnx.jit def f(m): y = jnp.sin(m.a.value) # error diff --git a/uv.lock b/uv.lock index 48bda4f756..bd7053e5ad 100644 --- a/uv.lock +++ b/uv.lock @@ -838,6 +838,12 @@ testing = [ { name = "treescope" }, ] +[package.dev-dependencies] +dev = [ + { name = "nanobind" }, + { name = "scikit-build-core" }, +] + [package.metadata] requires-dist = [ { name = "cloudpickle", marker = "extra == 'testing'", specifier = ">=3.0.0" }, @@ -890,11 +896,17 @@ requires-dist = [ { name = "tensorflow-text", marker = "platform_system != 'Darwin' and extra == 'testing'", specifier = ">=2.11.0" }, { name = "tensorstore" }, { name = "torch", marker = "extra == 'testing'" }, - { name = "treescope", specifier = ">=0.1.7" }, + { name = "treescope", specifier = ">=0.1.2" }, { name = "treescope", marker = "python_full_version >= '3.10' and extra == 'testing'", specifier = ">=0.1.1" }, { name = "typing-extensions", specifier = ">=4.2" }, ] +[package.metadata.requires-dev] +dev = [ + { name = "nanobind", specifier = ">=2.4.0" }, + { name = "scikit-build-core", extras = ["pyproject"], specifier = ">=0.10.7" }, +] + [[package]] name = "fonttools" version = "4.53.1" @@ -1935,6 +1947,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/73/59/7854fbfb59f8ae35483ce93493708be5942ebb6328cd85b3a609df629736/namex-0.0.8-py3-none-any.whl", hash = "sha256:7ddb6c2bb0e753a311b7590f84f6da659dd0c05e65cb89d519d54c0a250c0487", size = 5806 }, ] +[[package]] +name = "nanobind" +version = "2.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/1e/01/a28722f6626e5c8a606dee71cb40c0b2ab9f7715b96bd34a9553c79dbf42/nanobind-2.4.0.tar.gz", hash = "sha256:a0392dee5f58881085b2ac8bfe8e53f74285aa4868b1472bfaf76cfb414e1c96", size = 953467 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7a/07/abff41fcade3613349eac71dacb166352babef515efd960a751e3175c262/nanobind-2.4.0-py3-none-any.whl", hash = "sha256:8cf27b04fbadeb9deb4a73f02bd838bf9f7e3e5a8ce44c50c93142b5728da58a", size = 232882 }, +] + [[package]] name = "nbclient" version = "0.10.0" @@ -2303,6 +2324,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c6/ac/dac4a63f978e4dcb3c6d3a78c4d8e0192a113d288502a1216950c41b1027/parso-0.8.4-py2.py3-none-any.whl", hash = "sha256:a418670a20291dacd2dddc80c377c5c3791378ee1e8d12bffc35420643d43f18", size = 103650 }, ] +[[package]] +name = "pathspec" +version = "0.12.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ca/bc/f35b8446f4531a7cb215605d100cd88b7ac6f44ab3fc94870c120ab3adbf/pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712", size = 51043 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08", size = 31191 }, +] + [[package]] name = "pexpect" version = "4.9.0" @@ -2956,6 +2986,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d2/ea/6f121d1802f3adae1981aea4209ea66f9d3c7f2f6d6b85ef4f13a61d17ef/rpds_py-0.20.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:bb273176be34a746bdac0b0d7e4e2c467323d13640b736c4c477881a3220a989", size = 213529 }, ] +[[package]] +name = "scikit-build-core" +version = "0.10.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, + { name = "packaging" }, + { name = "pathspec" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/34/75/ad5664c8050bbbea46a5f2b6a3dfbc6e6cf284826c0eee0a12f861364b3f/scikit_build_core-0.10.7.tar.gz", hash = "sha256:04cbb59fe795202a7eeede1849112ee9dcbf3469feebd9b8b36aa541336ac4f8", size = 255019 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/fe/90476c4f6a1b2f922efa00d26e876dd40c7279e28ec18f08f0851ad21ba6/scikit_build_core-0.10.7-py3-none-any.whl", hash = "sha256:5e13ab7ca7c3c6dd019607c3a6f53cba67dade8757c4c4f75b459e2f90e4dbc3", size = 165511 }, +] + [[package]] name = "scikit-learn" version = "1.5.1" @@ -3669,14 +3714,14 @@ wheels = [ [[package]] name = "treescope" -version = "0.1.7" +version = "0.1.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/40/34/8ad5475c26837ca400c77951bcc0788b5f291d1509ae2eda5f97b042c24a/treescope-0.1.7.tar.gz", hash = "sha256:2c82ecb633f18d50e5809dd473703cf05aa074a4f3d1add74de7cf7ccdf81ae3", size = 530052 } +sdist = { url = "https://files.pythonhosted.org/packages/2f/5d/ecb176971c78d90a3f74b7878ab9d013995fed285e3386a503ca008c9b03/treescope-0.1.2.tar.gz", hash = "sha256:2e4b35780884dfdbdcf44315d1c1c98fcf41daa0ea48a5b45ecc716920f88c86", size = 402255 } wheels = [ - { url = "https://files.pythonhosted.org/packages/59/7d/f6da2b223749c58ec8ff95c87319196765fed05bd44dd86fb9bc4bf35f77/treescope-0.1.7-py3-none-any.whl", hash = "sha256:14e6527d4bfe6770ac9cbb8058e49b6685444d7cd0d3f85fd10c42491848b102", size = 175566 }, + { url = "https://files.pythonhosted.org/packages/af/11/1a4d1877e5f7202bb3d0778a77b6ca222848b9b36fa65cbbc1fe12cb82b7/treescope-0.1.2-py3-none-any.whl", hash = "sha256:1811df6fbf79a5f54804e3ce2230b100547dc6350c99d973a6b9ba2bcd932e57", size = 172154 }, ] [[package]]