diff --git a/benchmarks/nnx_graph_overhead.py b/benchmarks/nnx_graph_overhead.py
index 88809f777..6d10f79e0 100644
--- a/benchmarks/nnx_graph_overhead.py
+++ b/benchmarks/nnx_graph_overhead.py
@@ -24,31 +24,52 @@
from absl import app
FLAGS = flags.FLAGS
-flags.DEFINE_enum('mode', 'all', ['all', 'nnx', 'jax'], 'Mode to run the script in')
+flags.DEFINE_enum(
+ 'mode', 'nnx', ['all', 'nnx', 'jax'], 'Mode to run the script in'
+)
flags.DEFINE_integer('total_steps', 100, 'Total number of training steps')
flags.DEFINE_integer('width', 32, 'Hidden layer size')
flags.DEFINE_integer('depth', 5, 'Depth of the model')
-
class Linear(nnx.Module):
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
- self.list = [
- nnx.Param(jax.random.uniform(rngs.params(), (din, dout))),
- nnx.Param(jnp.zeros((dout,))),
- ]
- self.dict = {
- 'w': nnx.Param(jax.random.uniform(rngs.params(), (din, dout))),
- 'b': nnx.Param(jnp.zeros((dout,))),
- }
+ self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout)))
+ self.b = nnx.Param(jnp.zeros((dout,)))
+
+ def __call__(self, x):
+ return x @ self.w + self.b
+
+
+class Block(nnx.Module):
+ def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
+ self.linear = Linear(din, dout, rngs=rngs)
+ self.bn = nnx.BatchNorm(dout, rngs=rngs)
+
+ def __call__(self, x):
+ return nnx.relu(self.bn(self.linear(x)))
+
+class Count(nnx.Variable):
+ pass
class MLP(nnx.Module):
- def __init__(self, depth, *, rngs: nnx.Rngs):
+ def __init__(self, din, dhidden, dout, depth, *, rngs: nnx.Rngs):
+ self.count = Count(jnp.array(0))
+ self.linear_in = Block(din, dhidden, rngs=rngs)
self.intermediates = [
- Linear(10, 10, rngs=rngs) for _ in range(depth)
+ Block(dhidden, dhidden, rngs=rngs) for _ in range(depth - 2)
]
+ self.linear_out = Block(dhidden, dout, rngs=rngs)
+
+ def __call__(self, x):
+ self.count.value += 1
+ x = nnx.relu(self.linear_in(x))
+ for layer in self.intermediates:
+ x = nnx.relu(layer(x))
+ x = self.linear_out(x)
+ return x
def main(argv):
@@ -63,21 +84,24 @@ def main(argv):
X = np.linspace(0, 1, 100)[:, None]
Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape)
- model = MLP(depth=depth, rngs=nnx.Rngs(0))
- tx = optax.sgd(1e-3)
- optimizer = nnx.Optimizer(model, tx)
-
#------------------------------------------------------------
# NNX
#------------------------------------------------------------
if mode in ['all', 'nnx']:
+ model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
+ tx = optax.sgd(1e-3)
+ optimizer = nnx.Optimizer(model, tx)
+ t0 = time()
+
@nnx.jit
def step_nnx(model: MLP, optimizer: nnx.Optimizer):
pass
+ cached_step_nnx = nnx.cache_args(step_nnx, model, optimizer)
+
t0 = time()
for _ in range(total_steps):
- step_nnx(model, optimizer)
+ cached_step_nnx()
total_time = time() - t0
time_per_step = total_time / total_steps
@@ -93,6 +117,11 @@ def step_nnx(model: MLP, optimizer: nnx.Optimizer):
#------------------------------------------------------------
if mode in ['all', 'jax']:
+ model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
+ tx = optax.sgd(1e-3)
+ optimizer = nnx.Optimizer(model, tx)
+ t0 = time()
+
@jax.jit
def step_jax(graphdef, state):
return graphdef, state
diff --git a/benchmarks/nnx_mlpmixer_training.py b/benchmarks/nnx_mlpmixer_training.py
new file mode 100644
index 000000000..68d5e7973
--- /dev/null
+++ b/benchmarks/nnx_mlpmixer_training.py
@@ -0,0 +1,235 @@
+# Copyright 2024 The Flax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# %%
+from functools import partial
+import jax
+import jax.numpy as jnp
+from flax import nnx
+import optax
+import numpy as np
+from einop import einop
+from time import time
+from tqdm import tqdm
+
+from absl import flags
+from absl import app
+
+FLAGS = flags.FLAGS
+flags.DEFINE_enum(
+ 'mode', 'all', ['all', 'nnx', 'jax'], 'Mode to run the script in'
+)
+flags.DEFINE_integer('total_steps', 10_000, 'Total number of training steps')
+flags.DEFINE_integer('batch_size', 32, 'Batch size')
+flags.DEFINE_integer('width', 32, 'Hidden layer size')
+flags.DEFINE_integer('depth', 4, 'Depth of the model')
+
+
+class MlpBlock(nnx.Module):
+ def __init__(self, din: int, mlp_dim: int, rngs: nnx.Rngs):
+ self.din, self.mlp_dim = din, mlp_dim
+ self.linear_in = nnx.Linear(din, mlp_dim, rngs=rngs)
+ self.linear_out = nnx.Linear(mlp_dim, din, rngs=rngs)
+
+ def __call__(self, x):
+ return self.linear_out(nnx.gelu(self.linear_in(x)))
+
+
+class MixerBlock(nnx.Module):
+ def __init__(
+ self,
+ tokens_mlp_dim: int,
+ channels_mlp_dim: int,
+ hidden_dim: int,
+ rngs: nnx.Rngs,
+ ):
+ self.tokens_mlp_dim = tokens_mlp_dim
+ self.channels_mlp_dim = channels_mlp_dim
+ self.hidden_dim = hidden_dim
+ self.token_mixing = MlpBlock(tokens_mlp_dim, hidden_dim, rngs=rngs)
+ self.channel_mixing = MlpBlock(channels_mlp_dim, hidden_dim, rngs=rngs)
+ self.ln1 = nnx.LayerNorm(channels_mlp_dim, rngs=rngs)
+ self.ln2 = nnx.LayerNorm(channels_mlp_dim, rngs=rngs)
+
+ def __call__(self, x):
+ y = self.ln1(x)
+ y = y.swapaxes(1, 2)
+ y = self.token_mixing(y)
+ y = y.swapaxes(1, 2)
+ x = x + y
+ y = self.ln2(x)
+ return x + self.channel_mixing(y)
+
+
+class MlpMixer(nnx.Module):
+ def __init__(
+ self,
+ din: int,
+ kernel_size: tuple[int, int],
+ strides: tuple[int, int],
+ num_blocks: int,
+ hidden_dim: int,
+ tokens_mlp_dim: int,
+ channels_mlp_dim: int,
+ rngs: nnx.Rngs,
+ ):
+ self.din = din
+ self.kernel_size = kernel_size
+ self.num_blocks = num_blocks
+ self.hidden_dim = hidden_dim
+ self.tokens_mlp_dim = tokens_mlp_dim
+ self.channels_mlp_dim = channels_mlp_dim
+ self.stem = nnx.Conv(
+ din + 1,
+ channels_mlp_dim,
+ kernel_size=kernel_size,
+ strides=strides,
+ rngs=rngs,
+ )
+ self.blocks = [
+ MixerBlock(tokens_mlp_dim, channels_mlp_dim, hidden_dim, rngs=rngs)
+ for _ in range(num_blocks)
+ ]
+ self.pre_head_layer_norm = nnx.LayerNorm(channels_mlp_dim, rngs=rngs)
+ self.conv_t = nnx.ConvTranspose(
+ channels_mlp_dim, din, kernel_size=kernel_size, strides=strides, rngs=rngs
+ )
+
+ def __call__(self, *, x, t):
+ # add time feature to input
+ t = einop(t, 'n -> n h w c', h=x.shape[1], w=x.shape[2], c=1)
+ x = jnp.concatenate([x, t], axis=-1)
+ # create patches
+ x = self.stem(x)
+ h, w = x.shape[1], x.shape[2]
+ x = einop(x, 'n h w c -> n (h w) c')
+ # apply blocks
+ for block in self.blocks:
+ x = block(x)
+ x = self.pre_head_layer_norm(x)
+ # recreate image
+ x = einop(x, 'n (h w) c -> n h w c', h=h, w=w)
+ x = self.conv_t(x)
+ return x
+
+
+def main(argv):
+ print(argv)
+ mode: str = FLAGS.mode
+ total_steps: int = FLAGS.total_steps
+ batch_size: int = FLAGS.batch_size
+ width: int = FLAGS.width
+ depth: int = FLAGS.depth
+
+ print(f'{mode=}, {total_steps=}, {batch_size=}, {width=}')
+
+ X = np.random.uniform(size=(batch_size, 28, 28, 1))
+
+ if mode == 'nnx' or mode == 'all':
+ rngs = nnx.Rngs(0)
+ flow = MlpMixer(
+ din=1,
+ kernel_size=(2, 2),
+ strides=(2, 2),
+      num_blocks=depth,
+      hidden_dim=width,
+      tokens_mlp_dim=196,
+      channels_mlp_dim=width,
+ rngs=rngs,
+ )
+ optimizer = nnx.Optimizer(flow, tx=optax.adamw(1e-4))
+ t0 = time()
+
+ mse = lambda a, b: jnp.mean((a - b) ** 2)
+
+ @nnx.jit(donate_argnums=(0, 1, 2))
+ def train_step_nnx(flow, optimizer, rngs, x_1):
+ print('JITTING NNX')
+ x_0 = jax.random.normal(rngs(), x_1.shape)
+ t = jax.random.uniform(rngs(), (len(x_1),))
+
+ x_t = jax.vmap(lambda x_0, x_1, t: (1 - t) * x_0 + t * x_1)(x_0, x_1, t)
+ dx_t = x_1 - x_0
+
+ loss, grads = nnx.value_and_grad(
+ lambda flow: mse(flow(x=x_t, t=t), dx_t)
+ )(flow)
+ optimizer.update(grads)
+ return loss
+
+ losses = []
+ t0 = time()
+ for step in tqdm(range(total_steps), desc='NNX'):
+ loss = train_step_nnx(flow, optimizer, rngs, X)
+ losses.append(loss)
+
+ total_time = time() - t0
+ print('### NNX ###')
+ print(f'final loss: {losses[-1]}')
+ print('total time:', total_time)
+ print(f'time per step: {total_time / total_steps * 1e6:.2f} µs')
+
+ if mode == 'jax' or mode == 'all':
+ rngs = nnx.Rngs(0)
+ flow = MlpMixer(
+ din=1,
+ kernel_size=(2, 2),
+ strides=(2, 2),
+ num_blocks=depth,
+ hidden_dim=width,
+ tokens_mlp_dim=196,
+ channels_mlp_dim=width,
+ rngs=rngs,
+ )
+ optimizer = nnx.Optimizer(flow, tx=optax.adamw(1e-4))
+ graphdef, state = nnx.split((flow, optimizer, rngs))
+ t0 = time()
+
+ mse = lambda a, b: jnp.mean((a - b) ** 2)
+
+  @partial(jax.jit, donate_argnums=0)
+ def train_step_jax(state, x_1):
+ print('JITTING JAX')
+ flow, optimizer, rngs = nnx.merge(graphdef, state)
+ x_0 = jax.random.normal(rngs(), x_1.shape)
+ t = jax.random.uniform(rngs(), (len(x_1),))
+
+ x_t = jax.vmap(lambda x_0, x_1, t: (1 - t) * x_0 + t * x_1)(x_0, x_1, t)
+ dx_t = x_1 - x_0
+
+ loss, grads = nnx.value_and_grad(
+ lambda flow: mse(flow(x=x_t, t=t), dx_t)
+ )(flow)
+ optimizer.update(grads)
+ state = nnx.state((flow, optimizer, rngs))
+ return loss, state
+
+ losses = []
+ t0 = time()
+ for step in tqdm(range(total_steps), desc='JAX'):
+ loss, state = train_step_jax(state, X)
+ losses.append(loss)
+
+ nnx.update((flow, optimizer, rngs), state)
+ total_time = time() - t0
+ print('### JAX ###')
+ print(f'final loss: {losses[-1]}')
+ print('total time:', total_time)
+ print(f'time per step: {total_time / total_steps * 1e6:.2f} µs')
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/benchmarks/nnx_simple_training.py b/benchmarks/nnx_simple_training.py
index 0cb08066f..88195b3ff 100644
--- a/benchmarks/nnx_simple_training.py
+++ b/benchmarks/nnx_simple_training.py
@@ -13,6 +13,7 @@
# limitations under the License.
# %%
+from functools import partial
import jax
import jax.numpy as jnp
import numpy as np
@@ -25,7 +26,9 @@
from absl import app
FLAGS = flags.FLAGS
-flags.DEFINE_enum('mode', 'nnx', ['nnx', 'jax'], 'Mode to run the script in')
+flags.DEFINE_enum(
+ 'mode', 'all', ['all', 'nnx', 'jax'], 'Mode to run the script in'
+)
flags.DEFINE_integer('total_steps', 10_000, 'Total number of training steps')
flags.DEFINE_integer('batch_size', 32, 'Batch size')
flags.DEFINE_integer('width', 32, 'Hidden layer size')
@@ -46,6 +49,13 @@ def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
def __call__(self, x):
return x @ self.w + self.b
+class Block(nnx.Module):
+ def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
+ self.linear = Linear(din, dout, rngs=rngs)
+ self.bn = nnx.BatchNorm(dout, rngs=rngs)
+
+ def __call__(self, x):
+ return nnx.relu(self.bn(self.linear(x)))
class Count(nnx.Variable):
pass
@@ -54,11 +64,11 @@ class Count(nnx.Variable):
class MLP(nnx.Module):
def __init__(self, din, dhidden, dout, depth, *, rngs: nnx.Rngs):
self.count = Count(jnp.array(0))
- self.linear_in = Linear(din, dhidden, rngs=rngs)
+ self.linear_in = Block(din, dhidden, rngs=rngs)
self.intermediates = [
- Linear(dhidden, dhidden, rngs=rngs) for _ in range(depth - 2)
+ Block(dhidden, dhidden, rngs=rngs) for _ in range(depth - 2)
]
- self.linear_out = Linear(dhidden, dout, rngs=rngs)
+ self.linear_out = Block(dhidden, dout, rngs=rngs)
def __call__(self, x):
self.count.value += 1
@@ -79,20 +89,16 @@ def main(argv):
print(f'{mode=}, {total_steps=}, {batch_size=}, {width=}')
- if mode not in ['nnx', 'jax']:
- raise ValueError(f'Invalid mode: {mode}')
-
X = np.linspace(0, 1, 100)[:, None]
Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape)
- model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
- tx = optax.sgd(1e-3)
- optimizer = nnx.Optimizer(model, tx)
- t0 = time()
-
- if mode == 'nnx':
+ if mode == 'nnx' or mode == 'all':
+ model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
+ tx = optax.sgd(1e-3)
+ optimizer = nnx.Optimizer(model, tx)
+ t0 = time()
- @nnx.jit
+ @nnx.jit(donate_argnums=(0, 1))
def train_step_nnx(model: MLP, optimizer: nnx.Optimizer, batch):
x, y = batch
@@ -103,26 +109,40 @@ def loss_fn(model: MLP):
grads: nnx.State = nnx.grad(loss_fn)(model)
optimizer.update(grads)
- @nnx.jit
+ @nnx.jit(donate_argnums=0)
def test_step_nnx(model: MLP, batch):
x, y = batch
y_pred = model(x)
loss = jnp.mean((y - y_pred) ** 2)
return {'loss': loss}
+ cached_train_step_nnx = nnx.cache_args(train_step_nnx, model, optimizer)
+ cached_test_step_nnx = nnx.cache_args(test_step_nnx, model)
+
for step, batch in enumerate(dataset(X, Y, batch_size)):
- train_step_nnx(model, optimizer, batch)
+ cached_train_step_nnx(batch)
if step % 1000 == 0:
- logs = test_step_nnx(model, (X, Y))
- print(f"step: {step}, loss: {logs['loss']}")
+ logs = cached_test_step_nnx((X, Y))
if step >= total_steps - 1:
break
- else:
- @jax.jit
- def train_step_jax(graphdef, state, batch):
+ print('### NNX ###')
+ print(f"final loss: {logs['loss']}")
+ total_time = time() - t0
+ print('total time:', total_time)
+ print(f'time per step: {total_time / total_steps * 1e6:.2f} µs')
+ print('times called:', model.count.value)
+
+ if mode == 'jax' or mode == 'all':
+ model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
+ tx = optax.sgd(1e-3)
+ optimizer = nnx.Optimizer(model, tx)
+ t0 = time()
+
+ @partial(jax.jit, donate_argnums=0)
+ def train_step_jax(state, batch):
model, optimizer = nnx.merge(graphdef, state)
x, y = batch
@@ -135,8 +155,8 @@ def loss_fn(model: MLP):
return nnx.state((model, optimizer))
- @jax.jit
- def test_step_jax(graphdef, state, batch):
+ @partial(jax.jit, donate_argnums=0)
+ def test_step_jax(state, batch):
model, optimizer = nnx.merge(graphdef, state)
x, y = batch
y_pred = model(x)
@@ -147,21 +167,22 @@ def test_step_jax(graphdef, state, batch):
graphdef, state = nnx.split((model, optimizer))
for step, batch in enumerate(dataset(X, Y, batch_size)):
- state = train_step_jax(graphdef, state, batch)
+ state = train_step_jax(state, batch)
if step % 1000 == 0:
- state, logs = test_step_jax(graphdef, state, (X, Y))
- print(f"step: {step}, loss: {logs['loss']}")
+ state, logs = test_step_jax(state, (X, Y))
if step >= total_steps - 1:
break
model, optimizer = nnx.merge(graphdef, state)
- total_time = time() - t0
- print('total time:', total_time)
- print(f'time per step: {total_time / total_steps * 1e6:.2f} µs')
- print('times called:', model.count.value)
+ print('### JAX ###')
+ print(f"final loss: {logs['loss']}")
+ total_time = time() - t0
+ print('total time:', total_time)
+ print(f'time per step: {total_time / total_steps * 1e6:.2f} µs')
+ print('times called:', model.count.value)
if __name__ == '__main__':
diff --git a/docs_nnx/api_reference/flax.nnx/graph.rst b/docs_nnx/api_reference/flax.nnx/graph.rst
index 2cf65c945..2630d256a 100644
--- a/docs_nnx/api_reference/flax.nnx/graph.rst
+++ b/docs_nnx/api_reference/flax.nnx/graph.rst
@@ -16,6 +16,7 @@ graph
.. autofunction:: iter_graph
.. autofunction:: clone
.. autofunction:: call
+.. autofunction:: cache_args
.. autoclass:: GraphDef
:members:
diff --git a/docs_nnx/api_reference/flax.nnx/transforms.rst b/docs_nnx/api_reference/flax.nnx/transforms.rst
index 54ba3399a..5b4440ed3 100644
--- a/docs_nnx/api_reference/flax.nnx/transforms.rst
+++ b/docs_nnx/api_reference/flax.nnx/transforms.rst
@@ -15,6 +15,7 @@ transforms
.. autofunction:: grad
.. autofunction:: jit
+.. autofunction:: shard_map
.. autofunction:: remat
.. autofunction:: scan
.. autofunction:: value_and_grad
diff --git a/docs_nnx/conf.py b/docs_nnx/conf.py
index 7eee47063..a32f4ae0c 100644
--- a/docs_nnx/conf.py
+++ b/docs_nnx/conf.py
@@ -137,7 +137,7 @@
# -- Options for myst ----------------------------------------------
# uncomment line below to avoid running notebooks during development
-# nb_execution_mode = 'off'
+nb_execution_mode = 'off'
# Notebook cell execution timeout; defaults to 30.
nb_execution_timeout = 100
# List of patterns, relative to source directory, that match notebook
diff --git a/docs_nnx/guides/performance.ipynb b/docs_nnx/guides/performance.ipynb
index 8f91fb704..3c4671624 100644
--- a/docs_nnx/guides/performance.ipynb
+++ b/docs_nnx/guides/performance.ipynb
@@ -6,29 +6,12 @@
"source": [
"# Performance considerations\n",
"\n",
- "Currently, Flax [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) traverses the object graph in pure Python, which is slow and adds overhead. This is why in order to solve this the Flax team will be developing a Rust extension called `flaxlib` to speed up some of the traversal logic in [`graph.py`](https://github.com/google/flax/blob/main/flax/nnx/graph.py). This will be similar to how the JAX team resolved a similar issue by introducing [`jaxlib`](https://jax.readthedocs.io/en/latest/installation.html#installation) for standard [JAX pytrees](https://jax.readthedocs.io/en/latest/key-concepts.html#pytrees) (refer to the first steps in [Flax PR #4196](https://github.com/google/flax/pull/4196)).\n",
+ "Currently, Flax [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) traverses the object graph in pure Python which can add overhead. This overhead mostly affects small to medium models and can be mitigated in the following ways:\n",
+ "* By leveraging JAX's [Asynchronous dispatch](#asynchronous-dispatch).\n",
+ "* By using [nnx.cache_args](#caching-graph-node-traversals) to cache the graph node traversals.\n",
+ "* By using a [Functional training loop](#functional-training-loop) which stages out the graph traversals.\n",
"\n",
- "However, there are two things to consider:\n",
- "\n",
- "* The overhead is only relevant for small models (refer to [Asynchronous dispatch](#asynchronous-dispatch).\n",
- "* You can remove the overhead by using [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) + [`flax.nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split) / [`flax.nnx.merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge) to stage out the traversal logic (Refer to [Lowering the Python overhead](#lowering-the-python-overhead).\n",
- "\n",
- "\n",
- "## Asynchronous dispatch\n",
- "\n",
- "In [benchmarks/nnx_simple_training.py](https://github.com/google/flax/blob/main/benchmarks/nnx_simple_training.py) we are increasing the layer width (features per layer) and measuring the total training time for the same model trained both with [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) and [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html).\n",
- "\n",
- "As demonstrated in the graph below, after a certain model size the time spent in the traversal is completely absorbed by async dispatch. This happens when Python is able to finish the current for loop step, and reach the next `train_step` and JAX is still not done with the previous `train_step`. \n",
- "\n",
- "![performance-graph](images/performance-graph.png)\n",
- "\n",
- "This means that you only need to worry about the `nnx.jit` overhead for small models. If you are working with a small model, check out the next section to see how you can remove the overhead.\n",
- "\n",
- "## Lowering the Python overhead\n",
- "\n",
- "To remove the Python overhead, you can use regular `jax.jit` in combination with `nnx.split` and `nnx.merge` to stage out the traversal logic.\n",
- "\n",
- "To learn how to do this, let’s first create the following simple `Model`:"
+ "A full resolution _might_ involve developing a C extension (e.g. `flaxlib`) to speed up some of the traversal logic in [`graph.py`](https://github.com/google/flax/blob/main/flax/nnx/graph.py). Before we continue lets an example of a model and a simple training loop:"
]
},
{
@@ -51,22 +34,8 @@
"\n",
" def __call__(self, x):\n",
" x = nnx.relu(self.dropout(self.bn(self.linear(x))))\n",
- " return self.linear_out(x)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Next, let’s create a `train_step()` function that uses [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit), taking in the `model`, `optimizer`, and `metrics`, all of which are Flax NNX objects:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
+ " return self.linear_out(x)\n",
+ " \n",
"model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization\n",
"optimizer = nnx.Optimizer(model, optax.adam(1e-3)) # reference sharing\n",
"metrics = nnx.MultiMetric(\n",
@@ -94,11 +63,61 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "To speed this up, before starting the training loop we can use [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split) over all the Flax NNX objects that are inputs to `train_step()` to create `graphdef` and `state` pytrees that are faster to traverse.\n",
+ "Important thing here is that we created a `train_step()` function that uses `nnx.jit` and takes in a `model`, `optimizer`, and `metrics` arguments, all of which are Flax NNX objects. We'll later see how to improve this."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Asynchronous dispatch\n",
+ "\n",
+ "Asynchronous dispatch is a feature of JAX where it runs operations in the background whenever possible so Python can continue executing other code. This can be use to absorve the cost of data loading and in this case the overhead of `nnx.jit` and similar transforms. In general, as the amount of computation JAX has to perform per iteration increases the more it is able to absorve the python overhead since eventually the JAX computation will be the main blocker and programs with different overhead will have the same performance. This could be achieved in a couple of ways:\n",
+ "\n",
+ "* Increasing the batch size.\n",
+ "* Increasing the model size.\n",
+ "* Performing more JAX steps per python step if data loading is fast enough.\n",
"\n",
- "Next, we change `train_step()` to accept `graphdef` and `state`, and use [`nnx.merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge) and [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split) at the beginning and the end of `train_step()` to switch back and forth between the objects and their pytree representations. And even though `nnx.split` and `nnx.merge` are slow, it doesn't matter because they will run only once during tracing.\n",
+ "To demonstrate this, the graph below which shows total time of running [benchmarks/nnx_simple_training.py](https://github.com/google/flax/blob/main/benchmarks/nnx_simple_training.py) for both `jax.jit` and `nnx.jit` with different model sizes:\n",
"\n",
- "With this in place, we can change the `train_step()` function to use `jax.jit` instead of `nnx.jit`:"
+ "![performance-graph](images/performance-graph.png)\n",
+ "\n",
+ "As we can observe, after a certain model size both `jax.jit` and `nnx.jit` converge to the same runtime cost. This means we don't have to modify our training loop above.\n",
+ "\n",
+ "## Caching graph node traversals\n",
+ "\n",
+ "The simplest way to get rid of the traversal overhead entirely is by using `nnx.cache_args` to convert a transformed function and the input graph objects into a partial function which caches the graph object and just expects the remaining arguments. In this example we use `nnx.cache_args` over `train_step` and partially apply `model`, `optimizer`, and `metrics`, to create `faster_train_step`. Then we simply update our training loop to use `faster_train_step` which only expects the `x` and `y` inputs:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "faster_train_step = nnx.cache_args(train_step, model, optimizer, metrics)\n",
+ "\n",
+ "for _ in range(10):\n",
+ " x, y = jnp.ones((32, 2)), jnp.zeros((32, 3))\n",
+ " loss = faster_train_step(x, y)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Note that `cache_args` will enforce that the structure of the graph nodes doesn't change during `train_step` (no mutations except for `Variable` state update) so the cache is guaranteed to be up-to-date and we can avoid costly checks which require traversals. This is actually what is expected for most step functions as making any change here would imply costly recompilation, so enforcing this might be a secondary feature that could be useful for this purpose.\n",
+ "\n",
+ "Similarly, to prevent the user from mutating the cached objects outside, `cache_args` creates a copy of all the graph nodes but, to allow state to be propagated to the original objects, they share references to the same `Variable`s."
+ ]
+ },
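+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a rough sketch (reusing the `model` and `faster_train_step` defined above, and assuming `nnx.Linear` stores its weight as `kernel`), we can check that updates made through the cached function are reflected on the original `model`, since the copies share the same `Variable`s:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "kernel_before = model.linear.kernel.value  # read from the original object\n",
+    "x, y = jnp.ones((32, 2)), jnp.zeros((32, 3))\n",
+    "loss = faster_train_step(x, y)  # updates go through the cached copy...\n",
+    "kernel_after = model.linear.kernel.value  # ...but are visible on the original model\n",
+    "print(jnp.allclose(kernel_before, kernel_after))  # False: the parameters were updated"
+   ]
+  },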
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Functional training loop\n",
+ "\n",
+ "To remove the Python overhead we can create a functional training loop that uses regular `jax.jit` in combination with `nnx.split` and `nnx.merge` to stage out the traversal logic. Concretely we can use [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split) before the training loop to create a single `graphdef` and `state` pytrees for all the graph nodes. Then we change `train_step()` to accept `graphdef` and `state`, and use [`nnx.merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge) to recreate the objects inside, and either `nnx.split` or `nnx.state` at the end to get the output `state`. At the end of the training loop or whenever needed we can use [`nnx.update`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.update) to update the objects to the current `state`."
]
},
{
@@ -107,16 +126,11 @@
"metadata": {},
"outputs": [],
"source": [
- "model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization\n",
- "optimizer = nnx.Optimizer(model, optax.adamw(1e-3)) # reference sharing\n",
- "metrics = nnx.MultiMetric(\n",
- " loss=nnx.metrics.Average('loss'),\n",
- ")\n",
"# split before training loop\n",
"graphdef, state = nnx.split((model, optimizer, metrics))\n",
"\n",
"@jax.jit # regular JAX\n",
- "def train_step(graphdef, state, x, y):\n",
+ "def jax_train_step(graphdef, state, x, y):\n",
" # merge at the beginning of the function\n",
" model, optimizer, metrics = nnx.merge(graphdef, state)\n",
"\n",
@@ -128,15 +142,12 @@
" optimizer.update(grads)\n",
" metrics.update(loss=loss)\n",
"\n",
- " # split at the end of the function\n",
- " _, state = nnx.split((model, optimizer, metrics))\n",
- "\n",
- " # return new state\n",
- " return state, loss\n",
+ " state = nnx.state((model, optimizer, metrics))\n",
+ " return loss, state\n",
"\n",
"for _ in range(10):\n",
" x, y = jnp.ones((32, 2)), jnp.zeros((32, 3))\n",
- " state, loss = train_step(graphdef, state, x, y)\n",
+ " state, loss = jax_train_step(graphdef, state, x, y)\n",
"\n",
"# update objects after training\n",
"nnx.update((model, optimizer, metrics), state)"
@@ -146,9 +157,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "Notice that we only do this for `jit`. You can still use other [Flax transforms](https://flax.readthedocs.io/en/latest/guides/transforms.html#transformations) like [`nnx.value_and_grad`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.value_and_grad) shown in the above example since their overhead is already absorbed by the outer `jit`.\n",
- "\n",
- "And after the training loop is done (or whenever it is needed), we can use Flax [`nnx.update`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.update) to update Flax NNX objects like `model`, `optimizer`, and `metrics` to a new `state`."
+ "Notice that we only need to do this for `jit`, the use of other Flax transforms like [`nnx.value_and_grad`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.value_and_grad) inside `train_step` doesn't have any performance cost since `jit` will make sure this only traced once."
]
}
],
@@ -158,7 +167,7 @@
},
"language_info": {
"name": "python",
- "version": "3.10.13"
+ "version": "3.11.9"
}
},
"nbformat": 4,
diff --git a/docs_nnx/guides/performance.md b/docs_nnx/guides/performance.md
index d2b67260b..e9f7c0ead 100644
--- a/docs_nnx/guides/performance.md
+++ b/docs_nnx/guides/performance.md
@@ -10,29 +10,12 @@ jupytext:
# Performance considerations
-Currently, Flax [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) traverses the object graph in pure Python, which is slow and adds overhead. This is why in order to solve this the Flax team will be developing a Rust extension called `flaxlib` to speed up some of the traversal logic in [`graph.py`](https://github.com/google/flax/blob/main/flax/nnx/graph.py). This will be similar to how the JAX team resolved a similar issue by introducing [`jaxlib`](https://jax.readthedocs.io/en/latest/installation.html#installation) for standard [JAX pytrees](https://jax.readthedocs.io/en/latest/key-concepts.html#pytrees) (refer to the first steps in [Flax PR #4196](https://github.com/google/flax/pull/4196)).
+Currently, Flax [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) traverses the object graph in pure Python, which can add overhead. This overhead mostly affects small to medium models and can be mitigated in the following ways:
+* By leveraging JAX's [Asynchronous dispatch](#asynchronous-dispatch).
+* By using [nnx.cache_args](#caching-graph-node-traversals) to cache the graph node traversals.
+* By using a [Functional training loop](#functional-training-loop) which stages out the graph traversals.
-However, there are two things to consider:
-
-* The overhead is only relevant for small models (refer to [Asynchronous dispatch](#asynchronous-dispatch).
-* You can remove the overhead by using [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) + [`flax.nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split) / [`flax.nnx.merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge) to stage out the traversal logic (Refer to [Lowering the Python overhead](#lowering-the-python-overhead).
-
-
-## Asynchronous dispatch
-
-In [benchmarks/nnx_simple_training.py](https://github.com/google/flax/blob/main/benchmarks/nnx_simple_training.py) we are increasing the layer width (features per layer) and measuring the total training time for the same model trained both with [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) and [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html).
-
-As demonstrated in the graph below, after a certain model size the time spent in the traversal is completely absorbed by async dispatch. This happens when Python is able to finish the current for loop step, and reach the next `train_step` and JAX is still not done with the previous `train_step`.
-
-![performance-graph](images/performance-graph.png)
-
-This means that you only need to worry about the `nnx.jit` overhead for small models. If you are working with a small model, check out the next section to see how you can remove the overhead.
-
-## Lowering the Python overhead
-
-To remove the Python overhead, you can use regular `jax.jit` in combination with `nnx.split` and `nnx.merge` to stage out the traversal logic.
-
-To learn how to do this, let’s first create the following simple `Model`:
+A full resolution _might_ involve developing a C extension (e.g. `flaxlib`) to speed up some of the traversal logic in [`graph.py`](https://github.com/google/flax/blob/main/flax/nnx/graph.py). Before we continue, let's look at an example of a model and a simple training loop:
```{code-cell}
from flax import nnx
@@ -50,11 +33,7 @@ class Model(nnx.Module):
def __call__(self, x):
x = nnx.relu(self.dropout(self.bn(self.linear(x))))
return self.linear_out(x)
-```
-
-Next, let’s create a `train_step()` function that uses [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit), taking in the `model`, `optimizer`, and `metrics`, all of which are Flax NNX objects:
-
-```{code-cell}
+
model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization
optimizer = nnx.Optimizer(model, optax.adam(1e-3)) # reference sharing
metrics = nnx.MultiMetric(
@@ -78,23 +57,52 @@ for _ in range(10):
loss = train_step(model, optimizer, metrics, x, y)
```
-To speed this up, before starting the training loop we can use [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split) over all the Flax NNX objects that are inputs to `train_step()` to create `graphdef` and `state` pytrees that are faster to traverse.
+The important thing here is that we created a `train_step()` function that uses `nnx.jit` and takes in `model`, `optimizer`, and `metrics` arguments, all of which are Flax NNX objects. We'll see later how to improve this.
-Next, we change `train_step()` to accept `graphdef` and `state`, and use [`nnx.merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge) and [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split) at the beginning and the end of `train_step()` to switch back and forth between the objects and their pytree representations. And even though `nnx.split` and `nnx.merge` are slow, it doesn't matter because they will run only once during tracing.
++++
-With this in place, we can change the `train_step()` function to use `jax.jit` instead of `nnx.jit`:
+## Asynchronous dispatch
+
+Asynchronous dispatch is a feature of JAX where operations are run in the background whenever possible so that Python can continue executing other code. It can be used to absorb the cost of data loading and, in this case, the overhead of `nnx.jit` and similar transforms. In general, the more computation JAX has to perform per iteration, the more Python overhead it can absorb, since eventually the JAX computation becomes the main blocker and programs with different amounts of overhead end up with the same performance. This can be achieved in a couple of ways:
+
+* Increasing the batch size.
+* Increasing the model size.
+* Performing more JAX steps per Python step if data loading is fast enough.
+
+To demonstrate this, the graph below shows the total time of running [benchmarks/nnx_simple_training.py](https://github.com/google/flax/blob/main/benchmarks/nnx_simple_training.py) with both `jax.jit` and `nnx.jit` for different model sizes:
+
+![performance-graph](images/performance-graph.png)
+
+As we can observe, after a certain model size both `jax.jit` and `nnx.jit` converge to the same runtime cost. This means that for sufficiently large models we don't have to modify our training loop above.
+
+## Caching graph node traversals
+
+The simplest way to get rid of the traversal overhead entirely is to use `nnx.cache_args` to convert a transformed function and its input graph objects into a partial function that caches the graph objects and only expects the remaining arguments. In this example we use `nnx.cache_args` over `train_step` and partially apply `model`, `optimizer`, and `metrics` to create `faster_train_step`. Then we simply update our training loop to use `faster_train_step`, which only expects the `x` and `y` inputs:
+
+```{code-cell}
+faster_train_step = nnx.cache_args(train_step, model, optimizer, metrics)
+
+for _ in range(10):
+ x, y = jnp.ones((32, 2)), jnp.zeros((32, 3))
+ loss = faster_train_step(x, y)
+```
+
+Note that `cache_args` enforces that the structure of the graph nodes doesn't change during `train_step` (no mutations except for `Variable` state updates), so the cache is guaranteed to stay up-to-date and we can avoid costly checks that require traversals. This is what is expected of most step functions anyway, since any structural change would imply costly recompilation, so this enforcement can be a useful feature in its own right.
+
+Similarly, to prevent the user from mutating the cached objects from the outside, `cache_args` creates a copy of all the graph nodes; however, to allow state to be propagated to the original objects, the copies share references to the same `Variable`s.
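+
+As a rough sketch (reusing the `model` and `faster_train_step` defined above, and assuming `nnx.Linear` stores its weight as `kernel`), we can check that updates made through the cached function are reflected on the original `model`, since the copies share the same `Variable`s:
+
+```{code-cell}
+kernel_before = model.linear.kernel.value  # read from the original object
+x, y = jnp.ones((32, 2)), jnp.zeros((32, 3))
+loss = faster_train_step(x, y)  # updates go through the cached copy...
+kernel_after = model.linear.kernel.value  # ...but are visible on the original model
+print(jnp.allclose(kernel_before, kernel_after))  # False: the parameters were updated
+```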
+
++++
+
+## Functional training loop
+
+To remove the Python overhead we can create a functional training loop that uses regular `jax.jit` in combination with `nnx.split` and `nnx.merge` to stage out the traversal logic. Concretely, we use [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split) before the training loop to create single `graphdef` and `state` pytrees for all the graph nodes. Then we change `train_step()` to accept `graphdef` and `state`, use [`nnx.merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge) to recreate the objects inside, and use either `nnx.split` or `nnx.state` at the end to get the output `state`. At the end of the training loop, or whenever needed, we can use [`nnx.update`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.update) to update the objects with the current `state`.
```{code-cell}
-model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization
-optimizer = nnx.Optimizer(model, optax.adamw(1e-3)) # reference sharing
-metrics = nnx.MultiMetric(
- loss=nnx.metrics.Average('loss'),
-)
# split before training loop
graphdef, state = nnx.split((model, optimizer, metrics))
@jax.jit # regular JAX
-def train_step(graphdef, state, x, y):
+def jax_train_step(graphdef, state, x, y):
# merge at the beginning of the function
model, optimizer, metrics = nnx.merge(graphdef, state)
@@ -106,20 +114,15 @@ def train_step(graphdef, state, x, y):
optimizer.update(grads)
metrics.update(loss=loss)
- # split at the end of the function
- _, state = nnx.split((model, optimizer, metrics))
-
- # return new state
- return state, loss
+ state = nnx.state((model, optimizer, metrics))
+ return loss, state
for _ in range(10):
x, y = jnp.ones((32, 2)), jnp.zeros((32, 3))
- state, loss = train_step(graphdef, state, x, y)
+  loss, state = jax_train_step(graphdef, state, x, y)
# update objects after training
nnx.update((model, optimizer, metrics), state)
```
-Notice that we only do this for `jit`. You can still use other [Flax transforms](https://flax.readthedocs.io/en/latest/guides/transforms.html#transformations) like [`nnx.value_and_grad`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.value_and_grad) shown in the above example since their overhead is already absorbed by the outer `jit`.
-
-And after the training loop is done (or whenever it is needed), we can use Flax [`nnx.update`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.update) to update Flax NNX objects like `model`, `optimizer`, and `metrics` to a new `state`.
+Notice that we only need to do this for `jit`; using other Flax transforms like [`nnx.value_and_grad`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.value_and_grad) inside `train_step` doesn't add any performance cost since `jit` ensures they are only traced once.
diff --git a/docs_nnx/nnx_basics.ipynb b/docs_nnx/nnx_basics.ipynb
index 03d062491..bf040b98d 100644
--- a/docs_nnx/nnx_basics.ipynb
+++ b/docs_nnx/nnx_basics.ipynb
@@ -92,7 +92,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -104,7 +104,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -190,13 +190,13 @@
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -208,7 +208,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -263,7 +263,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -275,7 +275,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -399,84 +399,26 @@
{
"data": {
"text/html": [
- " MLP Summary \n",
- "┏━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓\n",
- "┃ path ┃ type ┃ BatchStat ┃ Param ┃ RngState ┃\n",
- "┡━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩\n",
- "│ bn │ BatchNorm │ mean: float32[5,32] │ bias: float32[5,32] │ │\n",
- "│ │ │ var: float32[5,32] │ scale: float32[5,32] │ │\n",
- "│ │ │ │ │ │\n",
- "│ │ │ 320 (1.3 KB) │ 320 (1.3 KB) │ │\n",
- "├──────────────────────┼───────────┼─────────────────────┼──────────────────────┼──────────────────────┤\n",
- "│ dropout/rngs/default │ RngStream │ │ │ count: │\n",
- "│ │ │ │ │ tag: default │\n",
- "│ │ │ │ │ value: uint32[5] │\n",
- "│ │ │ │ │ key: │\n",
- "│ │ │ │ │ tag: default │\n",
- "│ │ │ │ │ value: key<fry>[5] │\n",
- "│ │ │ │ │ │\n",
- "│ │ │ │ │ 10 (60 B) │\n",
- "├──────────────────────┼───────────┼─────────────────────┼──────────────────────┼──────────────────────┤\n",
- "│ linear1 │ Linear │ │ b: float32[5,32] │ │\n",
- "│ │ │ │ w: float32[5,10,32] │ │\n",
- "│ │ │ │ │ │\n",
- "│ │ │ │ 1,760 (7.0 KB) │ │\n",
- "├──────────────────────┼───────────┼─────────────────────┼──────────────────────┼──────────────────────┤\n",
- "│ linear2 │ Linear │ │ b: float32[5,10] │ │\n",
- "│ │ │ │ w: float32[5,32,10] │ │\n",
- "│ │ │ │ │ │\n",
- "│ │ │ │ 1,650 (6.6 KB) │ │\n",
- "├──────────────────────┼───────────┼─────────────────────┼──────────────────────┼──────────────────────┤\n",
- "│ │ Total │ 320 (1.3 KB) │ 3,730 (14.9 KB) │ 10 (60 B) │\n",
- "└──────────────────────┴───────────┴─────────────────────┴──────────────────────┴──────────────────────┘\n",
- " \n",
- " Total Parameters: 4,060 (16.3 KB) \n",
- "
\n"
+ "
"
],
"text/plain": [
- "\u001b[3m MLP Summary \u001b[0m\n",
- "┏━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓\n",
- "┃\u001b[1m \u001b[0m\u001b[1mpath \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mtype \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mBatchStat \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mParam \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mRngState \u001b[0m\u001b[1m \u001b[0m┃\n",
- "┡━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩\n",
- "│ bn │ BatchNorm │ mean: \u001b[2mfloat32\u001b[0m[5,32] │ bias: \u001b[2mfloat32\u001b[0m[5,32] │ │\n",
- "│ │ │ var: \u001b[2mfloat32\u001b[0m[5,32] │ scale: \u001b[2mfloat32\u001b[0m[5,32] │ │\n",
- "│ │ │ │ │ │\n",
- "│ │ │ \u001b[1m320 \u001b[0m\u001b[1;2m(1.3 KB)\u001b[0m │ \u001b[1m320 \u001b[0m\u001b[1;2m(1.3 KB)\u001b[0m │ │\n",
- "├──────────────────────┼───────────┼─────────────────────┼──────────────────────┼──────────────────────┤\n",
- "│ dropout/rngs/default │ RngStream │ │ │ count: │\n",
- "│ │ │ │ │ tag: default │\n",
- "│ │ │ │ │ value: \u001b[2muint32\u001b[0m[5] │\n",
- "│ │ │ │ │ key: │\n",
- "│ │ │ │ │ tag: default │\n",
- "│ │ │ │ │ value: \u001b[2mkey\u001b[0m[5] │\n",
- "│ │ │ │ │ │\n",
- "│ │ │ │ │ \u001b[1m10 \u001b[0m\u001b[1;2m(60 B)\u001b[0m │\n",
- "├──────────────────────┼───────────┼─────────────────────┼──────────────────────┼──────────────────────┤\n",
- "│ linear1 │ Linear │ │ b: \u001b[2mfloat32\u001b[0m[5,32] │ │\n",
- "│ │ │ │ w: \u001b[2mfloat32\u001b[0m[5,10,32] │ │\n",
- "│ │ │ │ │ │\n",
- "│ │ │ │ \u001b[1m1,760 \u001b[0m\u001b[1;2m(7.0 KB)\u001b[0m │ │\n",
- "├──────────────────────┼───────────┼─────────────────────┼──────────────────────┼──────────────────────┤\n",
- "│ linear2 │ Linear │ │ b: \u001b[2mfloat32\u001b[0m[5,10] │ │\n",
- "│ │ │ │ w: \u001b[2mfloat32\u001b[0m[5,32,10] │ │\n",
- "│ │ │ │ │ │\n",
- "│ │ │ │ \u001b[1m1,650 \u001b[0m\u001b[1;2m(6.6 KB)\u001b[0m │ │\n",
- "├──────────────────────┼───────────┼─────────────────────┼──────────────────────┼──────────────────────┤\n",
- "│\u001b[1m \u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m Total\u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m320 \u001b[0m\u001b[1;2m(1.3 KB)\u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m3,730 \u001b[0m\u001b[1;2m(14.9 KB)\u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m10 \u001b[0m\u001b[1;2m(60 B)\u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\n",
- "└──────────────────────┴───────────┴─────────────────────┴──────────────────────┴──────────────────────┘\n",
- "\u001b[1m \u001b[0m\n",
- "\u001b[1m Total Parameters: 4,060 \u001b[0m\u001b[1;2m(16.3 KB)\u001b[0m\u001b[1m \u001b[0m\n"
+ ""
]
},
"metadata": {},
"output_type": "display_data"
},
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\n"
- ]
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
}
],
"source": [
@@ -528,7 +470,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -540,7 +482,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -589,7 +531,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -601,7 +543,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -613,7 +555,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -714,7 +656,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -726,7 +668,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -738,7 +680,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -750,7 +692,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -803,7 +745,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.13"
+ "version": "3.11.9"
}
},
"nbformat": 4,
diff --git a/examples/nnx_toy_examples/02_lifted_transforms.py b/examples/nnx_toy_examples/02_lifted_transforms.py
index 9fef3adf2..f6d745560 100644
--- a/examples/nnx_toy_examples/02_lifted_transforms.py
+++ b/examples/nnx_toy_examples/02_lifted_transforms.py
@@ -82,13 +82,15 @@ def test_step(model: MLP, batch):
loss = jnp.mean((y - y_pred) ** 2)
return {'loss': loss}
+cached_train_step = nnx.cache_args(train_step, model, optimizer)
+cached_test_step = nnx.cache_args(test_step, model)
total_steps = 10_000
for step, batch in enumerate(dataset(32)):
- train_step(model, optimizer, batch)
+ cached_train_step(batch)
if step % 1000 == 0:
- logs = test_step(model, (X, Y))
+ logs = cached_test_step((X, Y))
print(f"step: {step}, loss: {logs['loss']}")
if step >= total_steps - 1:
diff --git a/flax/configurations.py b/flax/configurations.py
index ba19a572f..5e1a492fc 100644
--- a/flax/configurations.py
+++ b/flax/configurations.py
@@ -22,6 +22,7 @@
class Config:
+ flax_use_flaxlib: bool
# See https://google.github.io/pytype/faq.html.
_HAS_DYNAMIC_ATTRIBUTES = True
@@ -62,6 +63,10 @@ def update(self, name_or_holder, value, /):
raise LookupError(f'Unrecognized config option: {name}')
self._values[name] = value
+ def __repr__(self):
+ values_repr = ', '.join(f'\n {k}={v!r}' for k, v in self._values.items())
+ return f'Config({values_repr}\n)'
+
config = Config()
@@ -201,3 +206,9 @@ def temp_flip_flag(var_name: str, var_value: bool):
' PRNG keys.'
),
)
+
+flax_use_flaxlib = bool_flag(
+ name='flax_use_flaxlib',
+ default=False,
+ help='Whether to use flaxlib for C++ acceleration.',
+)
\ No newline at end of file
diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py
index fcb15f060..a30deb55b 100644
--- a/flax/nnx/__init__.py
+++ b/flax/nnx/__init__.py
@@ -56,6 +56,7 @@
from .graph import MergeContext as MergeContext
from .graph import merge_context as merge_context
from .graph import variables as variables
+from .graph import cache_args as cache_args
from .nn import initializers as initializers
from .nn.activations import celu as celu
from .nn.activations import elu as elu
@@ -143,6 +144,7 @@
from .transforms.autodiff import custom_vjp as custom_vjp
from .transforms.autodiff import remat as remat
from .transforms.compilation import jit as jit
+from .transforms.compilation import shard_map as shard_map
from .transforms.compilation import StateSharding as StateSharding
from .transforms.iteration import Carry as Carry
from .transforms.iteration import scan as scan
diff --git a/flax/nnx/bridge/variables.py b/flax/nnx/bridge/variables.py
index 121bb98eb..da83cd545 100644
--- a/flax/nnx/bridge/variables.py
+++ b/flax/nnx/bridge/variables.py
@@ -21,7 +21,6 @@
from flax.nnx import spmd
from flax.nnx import traversals
from flax.nnx import variablelib as variableslib
-from flax.nnx.module import GraphDef
import typing as tp
@@ -193,12 +192,7 @@ def linen_vars_to_nnx_attrs(variables: tp.Mapping[str, Any]) -> dict[str, Any]:
def nnx_attrs_to_linen_vars(nnx_attrs: dict) -> dict:
"""Convert a dict of NNX variables (or variable states) to Linen-style variables."""
linen_structured = {}
- for kp, v in traversals.flatten_mapping(
- nnx_attrs,
- is_leaf=lambda _, x: isinstance(
- x, variableslib.Variable | variableslib.VariableState | GraphDef
- ),
- ).items():
+ for kp, v in traversals.flatten_mapping(nnx_attrs).items():
if isinstance(v, variableslib.Variable):
col_name = variable_type_name(type(v))
v = to_linen_var(v.to_state())
diff --git a/flax/nnx/bridge/wrappers.py b/flax/nnx/bridge/wrappers.py
index eed4ba2f7..6f1da04f9 100644
--- a/flax/nnx/bridge/wrappers.py
+++ b/flax/nnx/bridge/wrappers.py
@@ -22,7 +22,7 @@
from flax.core import meta
from flax.nnx import graph
from flax.nnx.bridge import variables as bv
-from flax.nnx.module import GraphDef, Module
+from flax.nnx.module import Module
from flax.nnx.object import Object
from flax.nnx.rnglib import Rngs
from flax.nnx.statelib import State
@@ -36,7 +36,7 @@
@dataclasses.dataclass
class Functional(tp.Generic[M]):
module_type: tp.Type[M]
- graphdef: tp.Optional[GraphDef[M]]
+ graphdef: tp.Optional[graph.NodeDef[M]]
args: tuple[tp.Any, ...]
kwargs: dict[str, tp.Any]
@@ -46,6 +46,7 @@ def init(self, *, rngs: tp.Optional[Rngs] = None) -> State:
kwargs['rngs'] = rngs
module = self.module_type(*self.args, **self.kwargs, **kwargs)
graphdef, state = nnx.split(module)
+ assert type(graphdef) is graph.NodeDef
self.graphdef = graphdef
return state
diff --git a/flax/nnx/extract.py b/flax/nnx/extract.py
index 191a0c195..364177b5f 100644
--- a/flax/nnx/extract.py
+++ b/flax/nnx/extract.py
@@ -13,9 +13,6 @@
# limitations under the License.
import abc
-import contextlib
-import dataclasses
-import threading
import typing as tp
import jax
@@ -67,7 +64,7 @@ def extract_graph_nodes(
| tuple[A, tuple[tp.Any, ...], tuple[tp.Any, ...]]
):
"""Extracts all graph nodes from a pytree."""
- nodes = graph.RefMap[tp.Any, Index]()
+ nodes: dict[tp.Any, Index] = {}
node_prefixes = []
leaves = []
@@ -134,11 +131,10 @@ def check_consistent_aliasing(
prefix: tuple[tp.Any, ...],
/,
*,
- node_prefixes: graph.RefMap[tp.Any, list[tuple[PathParts, tp.Any]]]
- | None = None,
+ node_prefixes: dict[tp.Any, list[tuple[PathParts, tp.Any]]] | None = None,
):
if node_prefixes is None:
- node_prefixes = graph.RefMap()
+ node_prefixes = {}
# collect all paths and prefixes for each node
for path, value in graph.iter_graph(node):
@@ -181,50 +177,6 @@ def check_consistent_aliasing(
+ '\n'.join(node_msgs)
)
-
-# -----------------------------
-# broadcast
-# -----------------------------
-
-
-@dataclasses.dataclass
-class BroadcastContext(threading.local):
- broadcast_state_stacks: dict[str, list[tp.Any]] = dataclasses.field(
- default_factory=dict
- )
-
-
-BROADCAST_CONTEXT = BroadcastContext()
-
-
-@contextlib.contextmanager
-def broadcast_state(tag: str, state: tp.Any):
- if tag in BROADCAST_CONTEXT.broadcast_state_stacks:
- stack = BROADCAST_CONTEXT.broadcast_state_stacks[tag]
- else:
- stack = BROADCAST_CONTEXT.broadcast_state_stacks[tag] = []
- stack.append(state)
- try:
- yield
- finally:
- stack.pop()
- if not stack:
- del BROADCAST_CONTEXT.broadcast_state_stacks[tag]
-
-
-def get_broadcast_state(tag: str) -> tp.Any:
- if tag not in BROADCAST_CONTEXT.broadcast_state_stacks:
- raise ValueError(f'No broadcast state found for {tag!r}')
-
- stack = BROADCAST_CONTEXT.broadcast_state_stacks[tag]
-
- if not stack:
- raise RuntimeError(
- f'Empty broadcast state stack for {tag!r}, this is a bug'
- )
-
- return stack[-1]
-
# -----------------------------
# to_tree/from_tree
# -----------------------------
@@ -251,10 +203,13 @@ class GraphDefState(struct.PyTreeNode):
graphdef: graph.GraphDef[tp.Any] = struct.field(pytree_node=False)
state: graph.GraphState = struct.field(pytree_node=True)
+S = tp.TypeVar(
+ 'S', bound=graph.GraphState | graph.GraphFlatState | list[tp.Any]
+)
-class NodeStates(struct.PyTreeNode):
+class NodeStates(struct.PyTreeNode, tp.Generic[S]):
_graphdef: graph.GraphDef[tp.Any] | None
- states: tuple[graph.GraphState, ...]
+ states: tuple[S, ...]
metadata: tp.Any = struct.field(pytree_node=False)
@property
@@ -264,7 +219,7 @@ def graphdef(self) -> graph.GraphDef[tp.Any]:
return self._graphdef
@property
- def state(self) -> graph.GraphState:
+ def state(self) -> S:
if len(self.states) != 1:
raise ValueError(
f'Expected exactly one GraphDefState, got {len(self.states)}'
@@ -275,15 +230,19 @@ def state(self) -> graph.GraphState:
def from_split(
cls,
graphdef: graph.GraphDef[tp.Any],
- state: graph.GraphState,
+ state: S,
/,
- *states: graph.GraphState,
+ *states: S,
metadata: tp.Any = None,
):
return cls(_graphdef=graphdef, states=(state, *states), metadata=metadata)
@classmethod
- def from_states(cls, state: graph.GraphState, *states: graph.GraphState):
+ def from_states(
+ cls,
+ state: S,
+ *states: S,
+ ):
return cls(_graphdef=None, states=(state, *states), metadata=None)
@classmethod
@@ -312,9 +271,18 @@ def to_tree(
[graph.SplitContext, KeyPath, Prefix, Leaf], tp.Any
] = default_split_fn,
map_non_graph_nodes: bool = False,
- ctxtag: str | None = None,
+ ctxtag: tp.Hashable | None = None,
check_aliasing: bool = True,
) -> tp.Any:
+ if prefix is Missing or prefix is None:
+ # fast path, no need for prefix broadcasting or consistent aliasing checks
+ with graph.split_context(ctxtag) as split_ctx:
+ return jax.tree.map(
+ lambda x: split_fn(split_ctx, (), prefix, x)
+ if map_non_graph_nodes or graph.is_graph_node(x)
+ else x,
+ tree,
+ )
leaf_prefixes = broadcast_prefix(
prefix,
tree,
@@ -324,7 +292,7 @@ def to_tree(
assert len(leaf_keys) == len(leaf_prefixes)
leaves_out = []
- node_prefixes = graph.RefMap[tp.Any, list[tuple[PathParts, tp.Any]]]()
+ node_prefixes: dict[tp.Any, list[tuple[PathParts, tp.Any]]] = {}
with graph.split_context(ctxtag) as split_ctx:
for (keypath, leaf), leaf_prefix in zip(leaf_keys, leaf_prefixes):
@@ -367,8 +335,19 @@ def from_tree(
is_node_leaf: tp.Callable[[Leaf], bool] = is_tree_node,
is_leaf: tp.Callable[[Leaf], bool] = is_tree_node,
map_non_graph_nodes: bool = False,
- ctxtag: str | None = None,
+ is_inner: bool | None = None,
+ ctxtag: tp.Hashable | None = None,
) -> tp.Any:
+ if prefix is Missing or prefix is None:
+ # fast path, no need for prefix broadcasting or consistent aliasing checks
+ with graph.merge_context(is_inner, ctxtag) as merge_ctx:
+ return jax.tree.map(
+ lambda x: merge_fn(merge_ctx, (), prefix, x)
+ if map_non_graph_nodes or is_node_leaf(x)
+ else x,
+ tree,
+ is_leaf=is_leaf,
+ )
leaf_prefixes = broadcast_prefix(
prefix,
tree,
@@ -381,15 +360,11 @@ def from_tree(
assert len(leaf_keys) == len(leaf_prefixes)
leaves_out = []
- with graph.merge_context(ctxtag) as merge_ctx:
+ with graph.merge_context(is_inner, ctxtag) as merge_ctx:
for (keypath, leaf), leaf_prefix in zip(leaf_keys, leaf_prefixes):
- if is_node_leaf(leaf):
- leaf_out = merge_fn(merge_ctx, keypath, leaf_prefix, leaf)
- leaves_out.append(leaf_out)
- else:
- if map_non_graph_nodes:
- leaf = merge_fn(merge_ctx, keypath, leaf_prefix, leaf)
- leaves_out.append(leaf)
+ if map_non_graph_nodes or is_node_leaf(leaf):
+ leaf = merge_fn(merge_ctx, keypath, leaf_prefix, leaf)
+ leaves_out.append(leaf)
pytree_out = jax.tree.unflatten(treedef, leaves_out)
return pytree_out
diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py
index 8cc272f8e..d59c1b39e 100644
--- a/flax/nnx/graph.py
+++ b/flax/nnx/graph.py
@@ -14,23 +14,26 @@
from __future__ import annotations
+from collections import deque
import contextlib
import dataclasses
import functools
import threading
import typing as tp
+from weakref import WeakKeyDictionary
+from flax import config
import jax
import numpy as np
import typing_extensions as tpe
-from flax.nnx import filterlib, reprlib, visualization
+from flax.nnx import filterlib, reprlib
from flax.nnx.proxy_caller import (
ApplyCaller,
CallableProxy,
DelayedAccessor,
)
-from flax.nnx.statelib import State
+from flax.nnx.statelib import FlatState, State
from flax.nnx import variablelib
from flax.nnx.variablelib import Variable, VariableState
from flax.typing import Key, PathParts, is_key_like
@@ -53,6 +56,7 @@
StateLeaf = VariableState[tp.Any]
NodeLeaf = Variable[tp.Any]
GraphState = State[Key, StateLeaf]
+GraphFlatState = FlatState[StateLeaf]
def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]:
@@ -62,37 +66,12 @@ def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]:
def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[NodeLeaf]:
return isinstance(x, Variable)
+RefMap = dict
-class RefMap(tp.MutableMapping[A, B], reprlib.MappingReprMixin):
- """A mapping that uses object id as the hash for the keys."""
-
- def __init__(
- self, mapping: tp.Mapping[A, B] | tp.Iterable[tuple[A, B]] = (), /
- ):
- self._mapping: dict[int, tuple[A, B]] = {}
- self.update(mapping)
-
- def __getitem__(self, key: A) -> B:
- return self._mapping[id(key)][1]
-
- def __contains__(self, key: object) -> bool:
- return id(key) in self._mapping
-
- def __setitem__(self, key: A, value: B):
- self._mapping[id(key)] = (key, value)
-
- def __delitem__(self, key: A):
- del self._mapping[id(key)]
-
- def __iter__(self) -> tp.Iterator[A]:
- return (key for key, _ in self._mapping.values())
-
- def __len__(self) -> int:
- return len(self._mapping)
-
- def __str__(self) -> str:
- return repr(self)
+if not tp.TYPE_CHECKING and config.flax_use_flaxlib:
+ import flaxlib
+ RefMap = flaxlib.RefMap
@dataclasses.dataclass(frozen=True, slots=True)
class NodeImplBase(tp.Generic[Node, Leaf, AuxData]):
@@ -175,9 +154,9 @@ def is_node_type(x: type[tp.Any]) -> bool:
return x in GRAPH_REGISTRY or x in PYTREE_REGISTRY or x is GenericPytree
-def get_node_impl(x: Node) -> NodeImpl[Node, tp.Any, tp.Any]:
+def get_node_impl(x: Node) -> NodeImpl[Node, tp.Any, tp.Any] | None:
if isinstance(x, Variable):
- raise ValueError(f'Variable is not a node: {x}')
+ return None
node_type = type(x)
@@ -185,22 +164,27 @@ def get_node_impl(x: Node) -> NodeImpl[Node, tp.Any, tp.Any]:
return GRAPH_REGISTRY[node_type]
elif node_type in PYTREE_REGISTRY:
return PYTREE_REGISTRY[node_type]
- elif is_pytree_node(x):
+ elif node_type in JAX_PYTREE_REGISTRY or issubclass(node_type, tuple):
return PYTREE_NODE_IMPL # type: ignore
else:
- raise ValueError(f'Unknown node type: {x}')
+ return None
-def get_node_impl_for_type(x: type[Node]) -> NodeImpl[Node, tp.Any, tp.Any]:
+def get_node_impl_for_type(
+ x: type[Node],
+) -> NodeImpl[Node, tp.Any, tp.Any] | None:
if x is GenericPytree:
return PYTREE_NODE_IMPL # type: ignore
elif x in PYTREE_REGISTRY:
return PYTREE_REGISTRY[x]
- else:
+ elif x in GRAPH_REGISTRY:
return GRAPH_REGISTRY[x]
+ else:
+ return None
class HashableMapping(tp.Mapping[HA, HB], tp.Hashable):
+ _mapping: dict[HA, HB] | tp.Mapping[HA, HB]
def __init__(self, mapping: tp.Mapping[HA, HB], copy: bool = True):
self._mapping = dict(mapping) if copy else mapping
@@ -228,17 +212,8 @@ def __repr__(self) -> str:
return repr(self._mapping)
-class GraphDef(tp.Generic[Node]):
- """A class that represents all the static, stateless, and Pythonic parts of a Flax
- :class:`Module`. A ``GraphDef`` can be generated by either calling :func:`split` or
- :func:`graphdef` on the :class:`Module`."""
-
- type: type[Node]
- index: int
-
-
@dataclasses.dataclass(frozen=True, repr=False)
-class NodeRef(GraphDef[Node], reprlib.Representable):
+class NodeRef(tp.Generic[Node], reprlib.Representable):
type: type[Node]
index: int
@@ -248,7 +223,8 @@ def __nnx_repr__(self):
yield reprlib.Attr('index', self.index)
def __treescope_repr__(self, path, subtree_renderer):
- return visualization.render_object_constructor(
+ import treescope # type: ignore[import-not-found,import-untyped]
+ return treescope.repr_lib.render_object_constructor(
object_type=type(self),
attributes={'type': self.type, 'index': self.index},
path=path,
@@ -262,16 +238,33 @@ def __treescope_repr__(self, path, subtree_renderer):
class VariableDef(reprlib.Representable):
type: type[Variable]
index: int
+ outer_index: int | None
metadata: HashableMapping[str, tp.Any]
+ def with_no_outer_index(self) -> VariableDef:
+ return VariableDef(
+ type=self.type, index=self.index, outer_index=None, metadata=self.metadata
+ )
+
+ def with_same_outer_index(self) -> VariableDef:
+ return VariableDef(
+ type=self.type,
+ index=self.index,
+ outer_index=self.index,
+ metadata=self.metadata,
+ )
+
def __nnx_repr__(self):
yield reprlib.Object(type=type(self))
yield reprlib.Attr('type', self.type.__name__)
yield reprlib.Attr('index', self.index)
+ yield reprlib.Attr('outer_index', self.outer_index)
yield reprlib.Attr('metadata', reprlib.PrettyMapping(self.metadata))
def __treescope_repr__(self, path, subtree_renderer):
- return visualization.render_object_constructor(
+ import treescope # type: ignore[import-not-found,import-untyped]
+
+ return treescope.repr_lib.render_object_constructor(
object_type=type(self),
attributes={
'type': self.type,
@@ -286,71 +279,74 @@ def __treescope_repr__(self, path, subtree_renderer):
jax.tree_util.register_static(VariableDef)
-@dataclasses.dataclass(frozen=True, slots=True)
-class SubGraphAttribute:
- key: Key
- value: NodeDef[tp.Any] | NodeRef[tp.Any]
-
-
-@dataclasses.dataclass(frozen=True, slots=True)
-class StaticAttribute:
- key: Key
- value: tp.Any
-
-
-@dataclasses.dataclass(frozen=True, slots=True)
-class LeafAttribute:
- key: Key
- value: VariableDef | NodeRef[tp.Any]
-
-
@dataclasses.dataclass(frozen=True, repr=False, slots=True)
-class NodeDef(GraphDef[Node], reprlib.Representable):
+class NodeDef(tp.Generic[Node], reprlib.Representable):
"""A dataclass that denotes the tree structure of a
:class:`Module`. A ``GraphDef`` can be generated by either
calling :func:`split` or :func:`graphdef` on the :class:`Module`."""
type: tp.Type[Node]
index: int
- attributes: tuple[SubGraphAttribute | StaticAttribute | LeafAttribute, ...]
+ outer_index: int | None
+ attributes: tuple[
+ tuple[
+ Key, NodeDef[tp.Any] | VariableDef | NodeRef[tp.Any] | Static[tp.Any]
+ ],
+ ...,
+ ]
metadata: tp.Any
- index_mapping: HashableMapping[Index, Index] | None
- @classmethod
- def create(
- cls,
- type: tp.Type[Node],
- index: int,
- attributes: tuple[SubGraphAttribute | StaticAttribute | LeafAttribute, ...],
- metadata: tp.Any,
- index_mapping: tp.Mapping[Index, Index] | None,
- ):
- return cls(
- type=type,
- index=index,
+ def with_no_outer_index(self) -> NodeDef[Node]:
+ attributes = tuple(
+ (
+ key,
+ value.with_no_outer_index()
+ if isinstance(value, NodeDef | VariableDef)
+ else value,
+ )
+ for key, value in self.attributes
+ )
+ return NodeDef(
+ type=self.type,
+ index=self.index,
+ outer_index=None,
+ attributes=attributes,
+ metadata=self.metadata,
+ )
+
+ def with_same_outer_index(self) -> NodeDef[Node]:
+ attributes = tuple(
+ (
+ key,
+ value.with_same_outer_index()
+ if isinstance(value, NodeDef | VariableDef)
+ else value,
+ )
+ for key, value in self.attributes
+ )
+ return NodeDef(
+ type=self.type,
+ index=self.index,
+ outer_index=self.index if self.index >= 0 else None,
attributes=attributes,
- metadata=metadata,
- index_mapping=HashableMapping(index_mapping)
- if index_mapping is not None
- else None,
+ metadata=self.metadata,
)
+ def replace(self, **kwargs):
+ return dataclasses.replace(self, **kwargs)
+
def __nnx_repr__(self):
yield reprlib.Object(type=type(self))
yield reprlib.Attr('type', self.type.__name__)
yield reprlib.Attr('index', self.index)
- yield reprlib.Attr('attributes', reprlib.PrettySequence(self.attributes))
+ yield reprlib.Attr('outer_index', self.outer_index)
+ yield reprlib.Attr('attributes', self.attributes)
yield reprlib.Attr('metadata', self.metadata)
- yield reprlib.Attr(
- 'index_mapping',
- reprlib.PrettyMapping(self.index_mapping)
- if self.index_mapping is not None
- else None,
- )
def __treescope_repr__(self, path, subtree_renderer):
- return visualization.render_object_constructor(
+ import treescope # type: ignore[import-not-found,import-untyped]
+ return treescope.repr_lib.render_object_constructor(
object_type=type(self),
attributes={
'type': self.type,
@@ -373,19 +369,89 @@ def _apply(
module = merge(self, state, *states)
fn = accessor(module)
out = fn(*args, **kwargs)
- return out, flatten(module)
+ graphdef, flat_state = flatten(module)
+ state_ = State.from_flat_path(flat_state)
+ return out, (graphdef, state_)
return CallableProxy(_apply, accessor) # type: ignore
jax.tree_util.register_static(NodeDef)
-PureState = tuple[GraphDef[A], GraphState]
+GraphDef = tp.Union[NodeDef[Node], NodeRef[Node]]
+PureState = tuple[GraphDef[Node], GraphState]
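
``GraphDef`` is now a plain union of ``NodeDef`` and ``NodeRef`` instead of a common base class,
and the per-graphdef ``index_mapping`` is replaced by the ``outer_index`` recorded on each
``NodeDef``/``VariableDef``. A minimal sketch of how shared references show up, assuming the
public ``nnx.split`` API is otherwise unchanged::

    from flax import nnx

    shared = nnx.Linear(2, 2, rngs=nnx.Rngs(0))
    parent = nnx.Sequential(shared, shared)  # same submodule referenced twice

    graphdef, state = nnx.split(parent)
    # graphdef is a NodeDef; the second occurrence of `shared` is recorded as a
    # NodeRef to the index of its first occurrence, so its Variables appear
    # only once in `state`.
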
+@tp.overload
def flatten(
- node: Node, /, ref_index: RefMap[tp.Any, Index] | None = None
-) -> tuple[GraphDef[Node], GraphState]:
+ node: Node,
+ /,
+ *,
+ ref_index: RefMap | None = None,
+ ref_outer_index: RefMap | None = None,
+) -> tuple[GraphDef[Node], FlatState[VariableState[tp.Any]]]: ...
+@tp.overload
+def flatten(
+ node: Node,
+ /,
+ *,
+ with_paths: tp.Literal[True],
+ return_variables: tp.Literal[True],
+ ref_index: RefMap | None = None,
+ ref_outer_index: RefMap | None = None,
+) -> tuple[
+ GraphDef[Node],
+ FlatState[Variable[tp.Any]],
+]: ...
+@tp.overload
+def flatten(
+ node: Node,
+ /,
+ *,
+ with_paths: tp.Literal[False],
+ return_variables: tp.Literal[True],
+ ref_index: RefMap | None = None,
+ ref_outer_index: RefMap | None = None,
+) -> tuple[
+ GraphDef[Node],
+ list[Variable[tp.Any]],
+]: ...
+@tp.overload
+def flatten(
+ node: Node,
+ /,
+ *,
+ return_variables: tp.Literal[True],
+ ref_index: RefMap | None = None,
+ ref_outer_index: RefMap | None = None,
+) -> tuple[
+ GraphDef[Node],
+ FlatState[Variable[tp.Any]],
+]: ...
+@tp.overload
+def flatten(
+ node: Node,
+ /,
+ *,
+ with_paths: bool,
+ ref_index: RefMap | None = None,
+ ref_outer_index: RefMap | None = None,
+) -> tuple[
+ GraphDef[Node],
+ FlatState[VariableState[tp.Any]] | list[tp.Any],
+]: ...
+def flatten(
+ node: Node,
+ /,
+ *,
+ with_paths: bool = True,
+ return_variables: bool = False,
+ ref_index: RefMap | None = None,
+ ref_outer_index: RefMap | None = None,
+) -> tuple[
+ GraphDef[Node],
+ FlatState[VariableState[tp.Any]] | FlatState[Variable[tp.Any]] | list[tp.Any],
+]:
"""Flattens a graph node into a (graphdef, state) pair.
Args:
@@ -393,81 +459,353 @@ def flatten(
ref_index: A mapping from nodes to indexes, defaults to None. If not provided, a new
empty dictionary is created. This argument can be used to flatten a sequence of graph
nodes that share references.
+    with_paths: A boolean that indicates whether to return a FlatState object that includes
+      the paths to the VariableState objects, or just a list of the Variables' inner values.
"""
if ref_index is None:
ref_index = RefMap()
- flat_state: list[tuple[PathParts, StateLeaf]] = []
- graphdef = _graph_flatten((), ref_index, flat_state, node)
- return graphdef, GraphState.from_flat_path(flat_state)
+
+ leaves: list[StateLeaf | Variable[tp.Any]] = []
+ path: list[Key] | None = [] if with_paths else None
+ paths: list[PathParts] | None = [] if with_paths else None
+ node_impl = get_node_impl(node)
+ if node_impl is None:
+ raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.')
+ graphdef = _graph_flatten(
+ node,
+ node_impl,
+ path,
+ ref_index,
+ ref_outer_index,
+ leaves,
+ paths,
+ return_variables,
+ )
+
+ if paths is not None:
+ return graphdef, FlatState.from_sorted_keys_values(tuple(paths), leaves) # type: ignore[return-value]
+ else:
+ return graphdef, leaves
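
A minimal sketch of the new ``flatten`` keyword options (assuming ``flax.nnx.graph`` is imported
directly; the leaf types follow the overloads above)::

    from flax import nnx
    from flax.nnx import graph

    model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))

    # default: FlatState of VariableState leaves, with paths
    graphdef, flat_state = graph.flatten(model)
    # with_paths=False: just a flat list of the raw leaf values
    graphdef, leaves = graph.flatten(model, with_paths=False)
    # return_variables=True: the Variable objects themselves (used by the caches below)
    graphdef, flat_vars = graph.flatten(model, return_variables=True)
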
def _graph_flatten(
- path: PathParts,
- ref_index: RefMap[tp.Any, Index],
- flat_state: list[tuple[PathParts, StateLeaf]],
node: Node,
-) -> NodeDef[Node] | NodeRef:
- if not is_node(node):
- raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.')
-
- if node in ref_index:
+ node_impl: NodeImpl[Node, Leaf, AuxData],
+ path: list[Key] | None,
+ ref_index: RefMap,
+ ref_outer_index: RefMap | None,
+ leaves: list[StateLeaf | Variable[tp.Any]],
+ paths: list[PathParts] | None,
+ return_variables: bool,
+) -> NodeDef[tp.Any] | NodeRef:
+ is_pytree_node_ = isinstance(node_impl, PytreeNodeImpl)
+ is_graph_node_ = isinstance(node_impl, GraphNodeImpl)
+
+ if not is_pytree_node_ and node in ref_index:
return NodeRef(type(node), ref_index[node])
- node_impl = get_node_impl(node)
-
# only cache graph nodes
- if isinstance(node_impl, GraphNodeImpl):
+ if is_graph_node_:
index = len(ref_index)
ref_index[node] = index
else:
index = -1
- attributes: list[SubGraphAttribute | StaticAttribute | LeafAttribute] = []
+ attributes: list[
+ tuple[Key, Static[tp.Any] | NodeDef[tp.Any] | VariableDef | NodeRef[tp.Any]]
+ ] = []
values, metadata = node_impl.flatten(node)
for key, value in values:
- if is_node(value):
- nodedef = _graph_flatten((*path, key), ref_index, flat_state, value)
- # subgraphs.append((key, nodedef))
- attributes.append(SubGraphAttribute(key, nodedef))
+ value_node_impl = get_node_impl(value)
+ if path is not None:
+ path.append(key)
+ if value_node_impl is not None:
+ nodedef = _graph_flatten(
+ value,
+ value_node_impl,
+ path,
+ ref_index,
+ ref_outer_index,
+ leaves,
+ paths,
+ return_variables,
+ )
+ attributes.append((key, nodedef))
elif isinstance(value, Variable):
if value in ref_index:
- attributes.append(
- LeafAttribute(key, NodeRef(type(value), ref_index[value]))
- )
+ attributes.append((key, NodeRef(type(value), ref_index[value])))
else:
- flat_state.append(((*path, key), value.to_state()))
+ if return_variables:
+ leaf = value
+ elif path is None:
+ leaf = value.raw_value
+ else:
+ leaf = value.to_state() # type: ignore[assignment]
+ leaves.append(leaf)
+ if path is not None:
+ assert paths is not None
+ paths.append(tuple(path))
variable_index = ref_index[value] = len(ref_index)
variabledef = VariableDef(
- type(value), variable_index, HashableMapping(value._var_metadata)
+ type=type(value),
+ index=variable_index,
+ outer_index=ref_outer_index.get(value, None)
+ if ref_outer_index
+ else None,
+ metadata=HashableMapping(value._var_metadata),
)
- attributes.append(LeafAttribute(key, variabledef))
+ attributes.append((key, variabledef))
else:
if isinstance(value, (jax.Array, np.ndarray)):
- path_str = '/'.join(map(str, (*path, key)))
- raise ValueError(
+ if path is not None:
+ path_str = '/'.join(map(str, path))
+ raise ValueError(
f'Arrays leaves are not supported, at {path_str!r}: {value}'
- )
+ )
+ else:
+          raise ValueError(f'Array leaves are not supported, found {value}')
# static_fields.append((key, value))
- attributes.append(StaticAttribute(key, value))
+ attributes.append((key, Static(value)))
- nodedef = NodeDef.create(
- type=node_impl.type,
+ if path is not None:
+ path.pop()
+
+ nodedef = NodeDef(
+ type=node_impl.type, # type: ignore[arg-type]
index=index,
+ outer_index=ref_outer_index[node]
+ if is_graph_node_ and ref_outer_index and node in ref_outer_index
+ else None,
attributes=tuple(attributes),
metadata=metadata,
- index_mapping=None,
)
return nodedef
+@dataclasses.dataclass(slots=True)
+class FingerprintContext:
+ next_index: int
+
+def fingerprint(
+ node,
+ /,
+ *,
+ ref_index: RefMap | None = None,
+ new_ref_index: RefMap | None = None,
+) -> list[tp.Hashable]:
+ """ """
+ if ref_index is None:
+ ref_index = RefMap()
+
+ if new_ref_index is None:
+ new_ref_index = RefMap()
+ node_impl = get_node_impl(node)
+ if node_impl is None:
+ raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.')
+ ctx = FingerprintContext(len(ref_index) + len(new_ref_index))
+ fp: list[tp.Hashable] = []
+ _graph_fingerprint(ctx, fp.append, node, node_impl, ref_index, new_ref_index)
+ return fp
+
+
+def _graph_fingerprint(
+ ctx: FingerprintContext,
+ append_fn: tp.Callable[[tp.Any], None],
+ node,
+ node_impl: NodeImpl[Node, Leaf, AuxData],
+ ref_index: RefMap,
+ new_ref_index: RefMap,
+):
+ is_pytree_node_ = type(node_impl) is PytreeNodeImpl
+ is_graph_node_ = type(node_impl) is GraphNodeImpl
+
+ append_fn(type(node))
+
+ if is_graph_node_:
+ append_fn(id(node))
+ if node in ref_index:
+ append_fn(ref_index[node])
+ return
+ elif node in new_ref_index:
+ append_fn(new_ref_index[node])
+ return
+ index = new_ref_index[node] = ctx.next_index
+ ctx.next_index += 1
+ else:
+ index = -1
+
+ values, metadata = node_impl.flatten(node)
+
+ append_fn(index)
+ append_fn(metadata)
+
+ for key, value in values:
+ value_node_impl = get_node_impl(value)
+ append_fn(key)
+ if value_node_impl is not None:
+ _graph_fingerprint(
+ ctx,
+ append_fn,
+ value,
+ value_node_impl,
+ ref_index,
+ new_ref_index,
+ )
+ elif isinstance(value, Variable):
+ append_fn(id(value))
+ append_fn(type(value))
+ if value in ref_index:
+ append_fn(ref_index[value])
+ elif value in new_ref_index:
+ append_fn(new_ref_index[value])
+ else:
+ variable_index = new_ref_index[value] = ctx.next_index
+ ctx.next_index += 1
+ append_fn(variable_index)
+ for key_value in value._var_metadata.items():
+ append_fn(key_value)
+ else:
+ if isinstance(value, (jax.Array, np.ndarray)):
+        raise ValueError(f'Array leaves are not supported: {value}')
+ append_fn(value)
+
+def check_fingerprint(
+ node,
+ fp: list[tp.Hashable],
+ /,
+ *,
+ ref_index: RefMap | None = None,
+ new_ref_index: RefMap | None = None,
+) -> bool:
+ """ """
+ if ref_index is None:
+ ref_index = RefMap()
+
+ if new_ref_index is None:
+ new_ref_index = RefMap()
+ node_impl = get_node_impl(node)
+ if node_impl is None:
+ raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.')
+ ctx = FingerprintContext(len(ref_index) + len(new_ref_index))
+ fp_matches = _check_graph_fingerprint(
+ ctx, iter(fp), node, node_impl, ref_index, new_ref_index
+ )
+ return fp_matches
+
+
+def _check_graph_fingerprint(
+ ctx: FingerprintContext,
+ fp_iterator: tp.Iterator[tp.Hashable],
+ node,
+ node_impl: NodeImpl[Node, Leaf, AuxData],
+ ref_index: RefMap,
+ new_ref_index: RefMap,
+) -> bool:
+ is_pytree_node_ = type(node_impl) is PytreeNodeImpl
+ is_graph_node_ = type(node_impl) is GraphNodeImpl
+
+ if type(node) != next(fp_iterator):
+ return False
+
+ if is_graph_node_:
+ # append_fn(id(node))
+ if id(node) != next(fp_iterator):
+ return False
+ if node in ref_index:
+ # append_fn(ref_index[node])
+ return ref_index[node] == next(fp_iterator)
+ elif node in new_ref_index:
+ # append_fn(new_ref_index[node])
+ return new_ref_index[node] == next(fp_iterator)
+ index = new_ref_index[node] = ctx.next_index
+ ctx.next_index += 1
+ else:
+ index = -1
+
+ values, metadata = node_impl.flatten(node)
+
+ # append_fn(index)
+ if index != next(fp_iterator):
+ return False
+ # append_fn(metadata)
+ if metadata != next(fp_iterator):
+ return False
+
+ for key, value in values:
+ value_node_impl = get_node_impl(value)
+ # append_fn(key)
+ if key != next(fp_iterator):
+ return False
+ if value_node_impl is not None:
+ if not _check_graph_fingerprint(
+ ctx,
+ fp_iterator,
+ value,
+ value_node_impl,
+ ref_index,
+ new_ref_index,
+ ):
+ return False
+ elif isinstance(value, Variable):
+ # append_fn(id(value))
+ if id(value) != next(fp_iterator):
+ return False
+ # append_fn(type(value))
+ if type(value) != next(fp_iterator):
+ return False
+ if value in ref_index:
+ # append_fn(ref_index[value])
+ if ref_index[value] != next(fp_iterator):
+ return False
+ elif value in new_ref_index:
+ # append_fn(new_ref_index[value])
+ if new_ref_index[value] != next(fp_iterator):
+ return False
+ else:
+ variable_index = new_ref_index[value] = ctx.next_index
+ ctx.next_index += 1
+ # append_fn(variable_index)
+ if variable_index != next(fp_iterator):
+ return False
+ for key_value in value._var_metadata.items():
+ # append_fn(key_value)
+ if key_value != next(fp_iterator):
+ return False
+ else:
+ if isinstance(value, (jax.Array, np.ndarray)):
+        raise ValueError(f'Array leaves are not supported: {value}')
+ # append_fn(value)
+ if value != next(fp_iterator):
+ return False
+
+ return True
+
+
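
A sketch of how the fingerprint helpers pair up; note the fingerprint includes ``id()`` of graph
nodes and Variables, so a structurally identical but distinct instance will not match::

    from flax import nnx
    from flax.nnx import graph

    model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
    fp = graph.fingerprint(model)

    assert graph.check_fingerprint(model, fp)      # same objects -> match
    other = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
    assert not graph.check_fingerprint(other, fp)  # different object ids -> mismatch
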
+def _get_sorted_leaves(
+ xs: tp.Mapping[tp.Any, tp.Any],
+) -> deque[tp.Any]:
+ if not isinstance(xs, tp.Mapping): # type: ignore
+ raise TypeError(f'expected Mapping; got {type(xs).__qualname__}')
+ leaves: deque[tp.Any] = deque()
+
+ def _flatten(xs):
+ if not isinstance(xs, tp.Mapping):
+ leaves.append(xs)
+ else:
+ for _, value in sorted(xs.items()):
+ _flatten(value)
+
+ _flatten(xs)
+ return leaves
+
def unflatten(
graphdef: GraphDef[Node],
- state: tp.Mapping[KeyT, StateLeaf | tp.Mapping[Key, tp.Any]],
+ state: State[Key, tp.Any] | FlatState[tp.Any] | list[tp.Any],
/,
*,
index_ref: dict[Index, tp.Any] | None = None,
- index_ref_cache: dict[Index, tp.Any] | None = None,
+ outer_index_outer_ref: dict[Index, tp.Any] | None = None,
) -> Node:
"""Unflattens a graphdef into a node with the given state.
@@ -484,19 +822,41 @@ def unflatten(
existing graph nodes are mutated to have the new content/topology
specified by the graphdef.
"""
- if isinstance(state, State):
- state = state.raw_mapping # type: ignore
+ if isinstance(state, (State, dict)):
+ leaves = _get_sorted_leaves(state)
+ elif isinstance(state, FlatState):
+ leaves = deque(state.leaves)
+ elif isinstance(state, list): # type: ignore
+ leaves = deque(state)
+ else:
+ raise ValueError(f'Unsupported state type: {type(state)}')
if index_ref is None:
index_ref = {}
- assert isinstance(graphdef, (NodeDef, NodeRef))
- node = _graph_unflatten(graphdef, state, index_ref, index_ref_cache)
+
+ if isinstance(graphdef, NodeRef):
+ node = index_ref[graphdef.index]
+ else:
+ assert isinstance(graphdef, NodeDef)
+ node_impl = get_node_impl_for_type(graphdef.type)
+ if node_impl is None:
+ raise RuntimeError(f'Unsupported type: {graphdef.type}, this is a bug.')
+ node = _graph_unflatten(
+ graphdef, node_impl, leaves, index_ref, outer_index_outer_ref
+ )
+ if leaves:
+ raise ValueError(
+ f'Incorrect number of leaves: got an extra {len(leaves)} leaves in the state'
+ )
+
return node
+
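
``unflatten`` now accepts a nested ``State``, a ``FlatState``, or a bare list of leaves whose
order matches the sorted paths produced by ``flatten``; a minimal sketch::

    from flax import nnx
    from flax.nnx import graph

    model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
    graphdef, flat_state = graph.flatten(model)

    m1 = graph.unflatten(graphdef, flat_state)                    # FlatState
    m2 = graph.unflatten(graphdef, flat_state.to_nested_state())  # nested State
    m3 = graph.unflatten(graphdef, list(flat_state.leaves))       # plain list of leaves
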
def _graph_unflatten(
nodedef: NodeDef[Node] | NodeRef[Node],
- state: tp.Mapping[KeyT, StateLeaf | tp.Mapping[Key, tp.Any]],
+ node_impl: NodeImpl[Node, Leaf, AuxData],
+ leaves: deque[tp.Any],
index_ref: dict[Index, tp.Any],
- index_ref_cache: dict[Index, tp.Any] | None,
+ outer_index_outer_ref: dict[Index, tp.Any] | None,
) -> Node:
"""Recursive helper for graph_unflatten.
@@ -511,134 +871,83 @@ def _graph_unflatten(
existing graph nodes are mutated to have the new content/topology
specified by the nodedef.
"""
- if isinstance(nodedef, NodeRef):
+ if type(nodedef) is NodeRef:
return index_ref[nodedef.index]
-
- if not is_node_type(nodedef.type):
- raise RuntimeError(f'Unsupported type: {nodedef.type}, this is a bug.')
-
+ assert type(nodedef) is NodeDef
if nodedef.index in index_ref:
raise RuntimeError(f'GraphDef index {nodedef.index} already used.')
- node_impl = get_node_impl_for_type(nodedef.type)
-
- def _get_children():
- children: list[tuple[Key, NodeLeaf | Node]] = []
- state_keys: set = set(state.keys())
-
- # for every key in attributes there are 6 possible cases:
- # - (2) the key can either be present in the state or not
- # - (3) the key can be a subgraph, a leaf, or a static attribute
- for attribute in nodedef.attributes:
- key = attribute.key
- if key not in state:
- # if key is not present create an empty types
- if type(attribute) is StaticAttribute:
- children.append((key, attribute.value))
- elif type(attribute) is SubGraphAttribute:
- # if the key is a subgraph we create an empty node
- subgraphdef = attribute.value
- assert not isinstance(subgraphdef, VariableDef)
- if isinstance(subgraphdef, NodeRef):
- # subgraph exists, take it from the cache
- children.append((key, index_ref[subgraphdef.index]))
- else:
- # create a node from an empty state, reasoning:
- # * its a node with no state
- # * its a node with state but only through references of already
- # created nodes
- substate = {}
- subnode = _graph_unflatten(
- subgraphdef, substate, index_ref, index_ref_cache
- )
- children.append((key, subnode))
- elif type(attribute) is LeafAttribute:
- variabledef = attribute.value
- if variabledef.index in index_ref:
- # variable exists, take it from the cache
- children.append((key, index_ref[variabledef.index]))
- else:
- # key for a variable is missing, raise an error
+ def _get_children() -> list[tuple[Key, tp.Any]]:
+ children: list[tuple[Key, NodeLeaf | Node]] = [] # type: ignore[invalid-annotation]
+
+ assert type(nodedef) is NodeDef
+ for key, value in nodedef.attributes:
+ if type(value) is Static:
+ children.append((key, value.value))
+ elif type(value) is NodeRef:
+ children.append((key, index_ref[value.index]))
+ elif type(value) is NodeDef:
+ # if the key is a subgraph we create an empty node
+ subgraphdef = value
+ value_node_impl = get_node_impl_for_type(subgraphdef.type)
+ assert value_node_impl is not None
+ subnode = _graph_unflatten(
+ subgraphdef, value_node_impl, leaves, index_ref, outer_index_outer_ref
+ )
+ children.append((key, subnode))
+ elif type(value) is VariableDef:
+ variabledef = value
+ if not leaves:
+ raise ValueError('Not enough leaves to unflatten the graph')
+          # it's an unseen variable, create a new one
+          value = leaves.popleft()
+          # when outer_index_outer_ref is present, check if the Variable
+          # exists there and update it if it does
+ if (
+ outer_index_outer_ref is not None
+ and variabledef.outer_index in outer_index_outer_ref
+ ):
+ # if variable exists, update it
+ variable = outer_index_outer_ref[variabledef.outer_index]
+ if not isinstance(variable, Variable):
raise ValueError(
- f'Expected key {key!r} in state while building node of type '
- f'{nodedef.type.__name__}.'
+ f'Expected a Variable type for {key!r}, but got {type(variable)}.'
)
- else:
- raise RuntimeError(f'Unknown static field: {key!r}')
- else:
- state_keys.remove(key)
- value = state[key]
- # if key in nodedef.static_fields:
- if type(attribute) is StaticAttribute:
- raise ValueError(
- f'Got state for static field {key!r}, this is not supported.'
- )
- elif type(attribute) is SubGraphAttribute:
- if is_state_leaf(value):
+ elif isinstance(value, Variable):
raise ValueError(
- f'Expected value of type {attribute.value} for '
- f'{key!r}, but got {value!r}'
+ f'Cannot unflatten flat_state containing Variables when using `outer_index_outer_ref`. '
+ f'Got {value!r} for {key!r}.'
)
- assert isinstance(value, dict)
- subgraphdef = attribute.value
-
- if isinstance(subgraphdef, NodeRef):
- children.append((key, index_ref[subgraphdef.index]))
+ elif isinstance(value, VariableState):
+ variable.update_from_state(value)
else:
- subnode = _graph_unflatten(
- subgraphdef, value, index_ref, index_ref_cache
- )
- children.append((key, subnode))
-
- elif type(attribute) is LeafAttribute:
- variabledef = attribute.value
-
- if variabledef.index in index_ref:
- # add an existing variable
- assert isinstance(variabledef, NodeRef)
- children.append((key, index_ref[variabledef.index]))
+ variable.raw_value = value
+          else:  # variabledef.outer_index not in outer_index_outer_ref
+ # variable reference does not exist outside, create a new one
+ if isinstance(value, Variable):
+ variable = value
+ elif isinstance(value, VariableState):
+ variable = value.to_variable()
else:
- # its a unseen variable, create a new one
- assert isinstance(variabledef, VariableDef)
- # when idxmap is present, check if the Varable exists there
- # and update existing variables if it does
- if (
- index_ref_cache is not None
- and variabledef.index in index_ref_cache
- ):
- # if variable exists, update it
- variable = index_ref_cache[variabledef.index]
- if not isinstance(variable, Variable):
- raise ValueError(
- f'Expected a Variable type for {key!r}, but got {type(variable)}.'
- )
- if isinstance(value, VariableState):
- variable.update_from_state(value)
- else:
- variable.raw_value = value
- else: # if it doesn't, create a new variable
- if isinstance(value, VariableState):
- variable = value.to_variable()
- else:
- variable = variabledef.type.from_metadata(
- value, variabledef.metadata
- )
- children.append((key, variable))
- index_ref[variabledef.index] = variable
- else:
- raise RuntimeError(f'Unknown key: {key!r}, this is a bug.')
-
- # NOTE: we could allw adding new StateLeafs here
- if state_keys:
- raise ValueError(f'Unknown keys: {state_keys}')
+ variable = variabledef.type.from_metadata(
+ value, variabledef.metadata
+ )
+ children.append((key, variable))
+ index_ref[variabledef.index] = variable
+      else:
+        raise RuntimeError(f'Unknown attribute value type for {key!r}: {type(value)}')
return children
if isinstance(node_impl, GraphNodeImpl):
# we create an empty node first and add it to the index
# this avoids infinite recursion when there is a reference cycle
- if index_ref_cache is not None and nodedef.index in index_ref_cache:
- node = index_ref_cache[nodedef.index]
+ assert type(nodedef) is NodeDef
+ if (
+ outer_index_outer_ref is not None
+ and nodedef.outer_index in outer_index_outer_ref
+ ):
+ node = outer_index_outer_ref[nodedef.outer_index]
if type(node) != nodedef.type:
raise ValueError(
f'Expected a node of type {nodedef.type} for index '
@@ -688,6 +997,8 @@ def _graph_pop(
id_to_index[id(node)] = len(id_to_index)
node_impl = get_node_impl(node)
+ if node_impl is None:
+ raise TypeError(f'Unknown node type: {type(node)}')
node_dict = node_impl.node_dict(node)
for name, value in node_dict.items():
@@ -707,6 +1018,9 @@ def _graph_pop(
node_path = (*path_parts, name)
node_impl = get_node_impl(node)
+ if node_impl is None:
+ raise TypeError(f'Unknown node type: {type(node)}')
+
for state, predicate in zip(flat_states, predicates):
if predicate(node_path, value):
if isinstance(node_impl, PytreeNodeImpl):
@@ -729,6 +1043,8 @@ def _graph_update_dynamic(node: tp.Any, state: tp.Mapping[KeyT, tp.Any]):
raise RuntimeError(f'Unsupported type: {type(node)}')
node_impl = get_node_impl(node)
+ if node_impl is None:
+ raise TypeError(f'Unknown node type: {type(node)}')
node_dict = node_impl.node_dict(node)
for key, value in state.items():
# case 1: new state is being added
@@ -765,26 +1081,202 @@ def _graph_update_dynamic(node: tp.Any, state: tp.Mapping[KeyT, tp.Any]):
# updated from raw value
current_value.raw_value = value
+
# --------------------------------------------------------
# UpdateContext
# --------------------------------------------------------
+
+class DynamicCache(tp.NamedTuple):
+ fingerprint: list[tp.Hashable]
+ graphdef: GraphDef[tp.Any]
+ final_graphdef: GraphDef[tp.Any]
+ paths: tuple[PathParts, ...]
+ variables: list[Variable[tp.Any]]
+ new_index_ref: dict[Index, tp.Any]
+
+ @staticmethod
+ def create(
+ fingerprint: list[tp.Hashable],
+ graphdef: GraphDef[tp.Any],
+ paths: tuple[PathParts, ...],
+ variables: list[Variable[tp.Any]],
+ new_ref_index: RefMap,
+ ):
+ new_index_ref = {index: obj for obj, index in new_ref_index.items()}
+ final_graphdef: NodeDef[tp.Any] | NodeRef[tp.Any]
+ if type(graphdef) is NodeDef:
+ final_graphdef = graphdef.with_same_outer_index()
+ else:
+ final_graphdef = graphdef
+ return DynamicCache(
+ fingerprint=fingerprint,
+ graphdef=graphdef,
+ final_graphdef=final_graphdef,
+ paths=paths,
+ variables=variables,
+ new_index_ref=new_index_ref,
+ )
+
+class StaticCache(tp.NamedTuple):
+ graphdef: GraphDef[tp.Any]
+ final_graphdef: GraphDef[tp.Any]
+ paths: tuple[PathParts, ...]
+ variables: list[Variable[tp.Any]]
+ new_ref_index: RefMap
+ new_index_ref: dict[Index, tp.Any]
+
+ @staticmethod
+ def create(
+ graphdef: GraphDef[tp.Any],
+ paths: tuple[PathParts, ...],
+ variables: list[Variable[tp.Any]],
+ new_ref_index: RefMap,
+ ):
+ new_index_ref = {index: obj for obj, index in new_ref_index.items()}
+ final_graphdef: NodeDef[tp.Any] | NodeRef[tp.Any]
+ if type(graphdef) is NodeDef:
+ final_graphdef = graphdef.with_same_outer_index()
+ else:
+ final_graphdef = graphdef
+ return StaticCache(
+ graphdef=graphdef,
+ final_graphdef=final_graphdef,
+ paths=paths,
+ variables=variables,
+ new_ref_index=new_ref_index,
+ new_index_ref=new_index_ref,
+ )
+
@dataclasses.dataclass
class GraphContext(threading.local):
- update_context_stacks: dict[str, list[UpdateContext]] = dataclasses.field(
- default_factory=dict
+ update_context_stacks: dict[tp.Hashable, list[UpdateContext]] = (
+ dataclasses.field(default_factory=dict)
)
ref_index_stack: list[SplitContext] = dataclasses.field(default_factory=list)
index_ref_stack: list[MergeContext] = dataclasses.field(default_factory=list)
+ dynamic_cache_context: WeakKeyDictionary[
+ tp.Hashable, WeakKeyDictionary[tp.Any, DynamicCache]
+ ] = dataclasses.field(default_factory=WeakKeyDictionary)
+ tmp_static_cache: WeakKeyDictionary[tp.Any, StaticCache] | None = None
+ caching: bool = False
GRAPH_CONTEXT = GraphContext()
+@contextlib.contextmanager
+def static_cache(static_cache: WeakKeyDictionary[tp.Any, StaticCache]):
+ if GRAPH_CONTEXT.caching:
+ yield
+ return
+
+ GRAPH_CONTEXT.tmp_static_cache = static_cache
+
+ try:
+ yield
+ finally:
+ if GRAPH_CONTEXT.tmp_static_cache is not None:
+ raise ValueError(
+ 'GRAPH_CONTEXT.tmp_static_cache should be None, no context consumed it.'
+ )
+
+
+def _cache_args(f: tp.Callable[..., tp.Any], *cached_args):
+ """Create a partial from a NNX transformed function alog with some cached input arguments
+ and reduces the python overhead by caching the traversal of NNX graph nodes. This is useful
+ for speed up function that are called repeatedly with the same subset of inputs e.g. a
+ ``train_step`` with a ``model`` and ``optimizer``::
+
+ >>> from flax import nnx
+ >>> import jax.numpy as jnp
+ >>> import optax
+ ...
+ >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
+ >>> optimizer = nnx.Optimizer(model, optax.adamw(1e-3))
+ ...
+ >>> @nnx.jit
+ ... def train_step(model, optimizer, x, y):
+ ... def loss_fn(model):
+ ... return jnp.mean((model(x) - y) ** 2)
+ ...
+ ... loss, grads = nnx.value_and_grad(loss_fn)(model)
+ ... optimizer.update(grads)
+ ... return loss
+ ...
+ >>> faster_train_step = nnx.cache_args(train_step, model, optimizer)
+ ...
+ >>> for step in range(total_steps:=2):
+ ... x, y = jnp.ones((10, 2)), jnp.ones((10, 3))
+ ... # loss = train_step(model, optimizer, x, y)
+ ... loss = faster_train_step(x, y)
+ ... print(f'Step {step}: loss={loss:.3f}')
+ Step 0: loss=1.649
+ Step 1: loss=1.642
+
+  Note that ``cache_args`` will clone all cached graph nodes to guarantee the validity
+  of the cache. These clones contain references to the same ``Variable`` objects, which
+  guarantees that state is propagated correctly back to the original graph nodes.
+  Because of this, the final structure of all graph nodes must be the same after each
+  call to the cached function, otherwise an error will be raised. Temporary mutations
+  are allowed (e.g. the use of ``Module.sow``) as long as they are cleaned up before
+  the function returns (e.g. via ``nnx.pop``).
+
+ Args:
+ f: A function to cache.
+ *cached_args: A subset of the input arguments containing the graph nodes to cache.
+
+ Returns:
+ A partial function expecting the remaining arguments to the original function.
+ """
+ cache: WeakKeyDictionary[tp.Any, StaticCache] = WeakKeyDictionary()
+ original_ref_index: RefMap = RefMap()
+ index_ref: dict[Index, tp.Any] = {}
+ cached_ref_index: RefMap = RefMap()
+
+ def create_static_cache(x):
+ if is_graph_node(x):
+ graphdef, flat_state = flatten(
+ x, with_paths=True, return_variables=True, ref_index=original_ref_index
+ )
+ paths = flat_state.paths
+ variables = flat_state.leaves
+ # clone but keep the same variable references
+ node_cache = unflatten(graphdef, flat_state, index_ref=index_ref)
+ cached_new_ref_index = RefMap()
+ _fp = fingerprint(
+ node_cache,
+ ref_index=cached_ref_index,
+ new_ref_index=cached_new_ref_index,
+ )
+ cached_ref_index.update(cached_new_ref_index)
+ cache[node_cache] = StaticCache.create(
+ graphdef, paths, variables, cached_new_ref_index
+ )
+ return node_cache
+ return x
+
+ cached_args = jax.tree.map(create_static_cache, cached_args)
+
+ @functools.wraps(f)
+ def cache_args_wrapper(*args, **kwargs):
+ with static_cache(cache):
+ return f(*cached_args, *args, **kwargs)
+
+ return cache_args_wrapper
+
+
+if tp.TYPE_CHECKING:
+ cache_args = functools.partial
+else:
+ cache_args = _cache_args
+
+
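
Note that under ``tp.TYPE_CHECKING`` the public ``cache_args`` is aliased to
``functools.partial`` so that type checkers treat ``nnx.cache_args(f, *args)`` as a partial
application of ``f`` and infer the remaining argument signature; at runtime the ``_cache_args``
wrapper above is what actually executes.
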
@dataclasses.dataclass
class SplitContext:
- ctxtag: str | None
- ref_index: RefMap[tp.Any, Index]
+ ctxtag: tp.Hashable | None
+ ref_index: RefMap
+ is_inner: bool | None
@tp.overload
def split(self, graph_node: A, /) -> tuple[GraphDef[A], GraphState]: ...
@@ -800,91 +1292,403 @@ def split(
second: filterlib.Filter,
/,
*filters: filterlib.Filter,
- ) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]: ...
+ ) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]: ... # type: ignore[not-supported-yet]
def split(
self, node: A, *filters: filterlib.Filter
- ) -> tuple[GraphDef[A], tpe.Unpack[tuple[GraphState, ...]]]:
+ ) -> tuple[GraphDef[A], tpe.Unpack[tuple[GraphState, ...]]]: # type: ignore[not-supported-yet]
ctx = (
current_update_context(self.ctxtag) if self.ctxtag is not None else None
)
- graphdef, state = flatten(node, self.ref_index)
- states = _split_state(state, filters)
- if ctx is not None:
- if ctx.index_ref is not None and isinstance(graphdef, NodeDef):
- index_to_index = compose_mapping(ctx.index_ref, self.ref_index)
- graphdef = dataclasses.replace(
- graphdef, index_mapping=HashableMapping(index_to_index, copy=False)
- )
+ inner_ref_outer_index = (
+ ctx.inner_ref_outer_index if ctx and ctx.inner_ref_outer_index else None
+ )
+ graphdef, flat_state = flatten(
+ node, ref_index=self.ref_index, ref_outer_index=inner_ref_outer_index
+ )
+ flat_states = _split_state(flat_state, filters)
+ states = tuple(
+ State.from_flat_path(flat_state) for flat_state in flat_states
+ )
return graphdef, *states
+ @tp.overload
+ def flatten(
+ self,
+ graph_node: A,
+ /,
+ *,
+ with_paths: tp.Literal[False],
+ ) -> tuple[GraphDef[A], list[tp.Any]]: ...
+ @tp.overload
+ def flatten(
+ self,
+ graph_node: A,
+ /,
+ ) -> tuple[GraphDef[A], FlatState[VariableState[tp.Any]]]: ...
+ @tp.overload
+ def flatten(
+ self,
+ graph_node: A,
+ first: filterlib.Filter,
+ /,
+ ) -> tuple[GraphDef[A], FlatState[VariableState[tp.Any]]]: ...
+ @tp.overload
+ def flatten(
+ self,
+ graph_node: A,
+ first: filterlib.Filter,
+ second: filterlib.Filter,
+ /,
+ *filters: filterlib.Filter,
+ ) -> tuple[
+ GraphDef[A],
+ FlatState[VariableState[tp.Any]],
+ tpe.Unpack[tuple[FlatState[VariableState[tp.Any]], ...]],
+ ]: ...
+ def flatten(
+ self,
+ node: A,
+ *filters: filterlib.Filter,
+ with_paths: bool = True,
+ ) -> tuple[
+ GraphDef[A],
+ FlatState[VariableState[tp.Any]] | list[tp.Any],
+ tpe.Unpack[tuple[FlatState[VariableState[tp.Any]], ...]],
+ ]:
+ if not with_paths and filters:
+ raise ValueError('Cannot use filters with with_paths=False')
+
+ ctx = (
+ current_update_context(self.ctxtag) if self.ctxtag is not None else None
+ )
+ dynamic_cache = (
+ ctx.dynamic_cache if ctx is not None and self.is_inner is False else None
+ )
+ static_cache = (
+ ctx.static_cache if ctx is not None and self.is_inner is False else None
+ )
+ ref_outer_index = (
+ ctx.inner_ref_outer_index if ctx and ctx.inner_ref_outer_index else None
+ )
+ flat_state: (
+ FlatState[VariableState[tp.Any]]
+ | FlatState[Variable[tp.Any]]
+ | list[tp.Any]
+ )
+ leaves: list[tp.Any]
+ if node in self.ref_index:
+ # node is already in the ref_index, call flatten which will return a NodeRef
+ graphdef, flat_state = flatten(
+ node,
+ ref_index=self.ref_index,
+ ref_outer_index=ref_outer_index,
+ with_paths=with_paths,
+ )
+ if with_paths:
+ assert isinstance(flat_state, FlatState)
+ paths = flat_state.paths
+ leaves = flat_state.leaves
+ else:
+ assert isinstance(flat_state, list)
+ paths = None
+ leaves = flat_state
+ elif static_cache is not None and node in static_cache:
+ node_static_cache = static_cache[node]
+ graphdef = node_static_cache.graphdef
+ # add the new references to the ref_index
+ self.ref_index.update(node_static_cache.new_ref_index)
+
+ if with_paths:
+ paths = node_static_cache.paths
+ leaves = [
+ variable.to_state() for variable in node_static_cache.variables
+ ]
+ else:
+ paths = None
+ leaves = [
+ variable.raw_value for variable in node_static_cache.variables
+ ]
+
+ elif dynamic_cache is not None and node in dynamic_cache:
+ node_dynamic_cache = dynamic_cache[node]
+ cache_fp = node_dynamic_cache.fingerprint
+ new_ref_index: RefMap = RefMap()
+ fp_matches = check_fingerprint(
+ node, cache_fp, ref_index=self.ref_index, new_ref_index=new_ref_index
+ )
+ if fp_matches:
+ graphdef = node_dynamic_cache.graphdef
+ self.ref_index.update(new_ref_index)
+
+ if with_paths:
+ paths = node_dynamic_cache.paths
+ leaves = [
+ variable.to_state() for variable in node_dynamic_cache.variables
+ ]
+ else:
+ paths = None
+ leaves = [
+ variable.raw_value for variable in node_dynamic_cache.variables
+ ]
+ else:
+ del cache_fp
+ del node_dynamic_cache
+ new_ref_index = RefMap()
+ node_fp = fingerprint(
+ node, ref_index=self.ref_index, new_ref_index=new_ref_index
+ )
+ graphdef, flat_state = flatten(
+ node,
+ ref_index=self.ref_index,
+ ref_outer_index=ref_outer_index,
+ with_paths=True,
+ return_variables=True,
+ )
+ paths = flat_state.paths
+ variables = flat_state.leaves
+ assert paths is not None
+ if with_paths:
+ leaves = [variable.to_state() for variable in variables]
+ else:
+ leaves = [variable.raw_value for variable in variables]
+ dynamic_cache[node] = DynamicCache.create(
+ node_fp, graphdef, paths, variables, new_ref_index
+ )
+    elif dynamic_cache is not None:  # node not in the dynamic cache
+ new_ref_index = RefMap()
+ node_fp = fingerprint(
+ node, ref_index=self.ref_index, new_ref_index=new_ref_index
+ )
+ graphdef, flat_state = flatten(
+ node,
+ ref_index=self.ref_index,
+ ref_outer_index=ref_outer_index,
+ with_paths=True,
+ return_variables=True,
+ )
+ paths = flat_state.paths
+ variables = flat_state.leaves
+ if with_paths:
+ leaves = [variable.to_state() for variable in variables]
+ else:
+ leaves = [variable.raw_value for variable in variables]
+ dynamic_cache[node] = DynamicCache.create(
+ node_fp, graphdef, paths, variables, new_ref_index
+ )
+ else:
+ graphdef, flat_state = flatten(
+ node,
+ ref_index=self.ref_index,
+ ref_outer_index=ref_outer_index,
+ with_paths=with_paths,
+ )
+ if with_paths:
+ assert isinstance(flat_state, FlatState)
+ paths = flat_state.paths
+ leaves = flat_state.leaves
+ else:
+ assert isinstance(flat_state, list)
+ paths = None
+ leaves = flat_state
+
+ if with_paths:
+ assert paths is not None
+ flat_state = FlatState.from_sorted_keys_values(paths, leaves)
+ flat_states = _split_state(flat_state, filters)
+ return graphdef, *flat_states # type: ignore[bad-return-type]
+ else:
+ return graphdef, leaves
+
@contextlib.contextmanager
-def split_context(ctxtag: str | None = None):
- index_ref: RefMap[tp.Any, Index] = RefMap()
- flatten_ctx = SplitContext(ctxtag, index_ref)
- GRAPH_CONTEXT.ref_index_stack.append(flatten_ctx)
+def split_context(ctxtag: tp.Hashable | None = None):
+ ctx = current_update_context(ctxtag) if ctxtag is not None else None
+ is_inner = ctx.outer_ref_outer_index is not None if ctx is not None else None
+ GRAPH_CONTEXT.ref_index_stack.append(SplitContext(ctxtag, RefMap(), is_inner))
try:
- yield flatten_ctx
+ yield GRAPH_CONTEXT.ref_index_stack[-1]
finally:
- GRAPH_CONTEXT.ref_index_stack.pop()
+ flatten_ctx = GRAPH_CONTEXT.ref_index_stack.pop()
if ctxtag is not None:
ctx = current_update_context(ctxtag)
- ctx.flatten_end(index_ref)
+ ctx.flatten_end(flatten_ctx.ref_index)
del flatten_ctx.ref_index
del flatten_ctx.ctxtag
@dataclasses.dataclass
class MergeContext:
- ctxtag: str | None
+ ctxtag: tp.Hashable | None
index_ref: dict[Index, tp.Any]
+ is_inner: bool | None
def merge(
- self, graphdef: GraphDef[A], state: GraphState, /, *states: GraphState
+ self,
+ graphdef: GraphDef[A],
+ state: GraphState,
+ /,
+ *states: GraphState,
) -> A:
ctx = (
current_update_context(self.ctxtag) if self.ctxtag is not None else None
)
- if (
- ctx is not None
- and isinstance(graphdef, NodeDef)
- and graphdef.index_mapping is not None
- ):
- # outer merge (4), create index_ref_cache
- assert ctx.ref_index is not None
- index_ref_cache = compose_mapping_reversed(
- ctx.ref_index, graphdef.index_mapping
- )
- else:
- # inner merge (2)
- index_ref_cache = None
+ outer_index_outer_ref = (
+ ctx.outer_index_outer_ref if ctx and ctx.outer_index_outer_ref else None
+ )
- state = State.merge(state, *states)
+ _state = State.merge(state, *states)
node = unflatten(
graphdef,
- state,
+ _state,
index_ref=self.index_ref,
- index_ref_cache=index_ref_cache,
+ outer_index_outer_ref=outer_index_outer_ref,
)
return node
+ def unflatten(
+ self,
+ graphdef: GraphDef[A],
+ flat_state: GraphFlatState | list[tp.Any],
+ /,
+ *flat_states: GraphFlatState,
+ ) -> A:
+ ctx = (
+ current_update_context(self.ctxtag) if self.ctxtag is not None else None
+ )
+ dynamic_cache = (
+ ctx.dynamic_cache if ctx is not None and self.is_inner is False else None
+ )
+ static_cache = (
+ ctx.static_cache if ctx is not None and self.is_inner is False else None
+ )
+ state: FlatState[tp.Any] | list[tp.Any]
+ if type(flat_state) is list:
+ if flat_states:
+ raise ValueError(
+ 'Cannot use multiple flat_states when flat_state is a list, '
+ f'got flat_state: {flat_state!r}, flat_states: {flat_states!r}'
+ )
+ state = flat_state
+ else:
+ state = FlatState.merge(flat_state, *flat_states)
-@contextlib.contextmanager
-def merge_context(ctxtag: str | None = None):
- index_ref: dict[Index, tp.Any] = {}
+ if type(graphdef) is NodeRef:
+ node = unflatten(
+ graphdef,
+ state,
+ index_ref=self.index_ref,
+ )
- unflatten_ctx = MergeContext(ctxtag, index_ref)
- GRAPH_CONTEXT.index_ref_stack.append(unflatten_ctx)
+ elif dynamic_cache is not None or static_cache is not None:
+ assert isinstance(graphdef, NodeDef)
+ assert ctx is not None
+ if (outer_index := graphdef.outer_index) is not None:
+ outer_index_outer_ref = ctx.outer_index_outer_ref
+ assert outer_index_outer_ref is not None
+ node = outer_index_outer_ref[outer_index]
+
+ if static_cache and node in static_cache:
+ static_cache_node = static_cache[node]
+ if static_cache_node.final_graphdef != graphdef:
+ raise ValueError(
+ 'The graph structure of a node added to cache_args was mutated inside the transformation, '
+              f'this is not allowed.\nNode: {node}\nOutput graphdef: {graphdef}\nExpected graphdef: {static_cache_node.final_graphdef}'
+ )
+ if type(state) is list:
+ leaves = state
+ elif type(state) is FlatState:
+ leaves = state.leaves
+ else:
+ raise ValueError(f'Unsupported state type: {type(state)}')
+
+ if len(leaves) != len(static_cache_node.variables):
+ raise ValueError(
+ f'Incorrect number of leaves: expected {len(static_cache_node.variables)} '
+ f'leaves in the state, got {len(leaves)}'
+ )
+ for variable, leaf in zip(static_cache_node.variables, leaves):
+ if type(leaf) is VariableState:
+ variable.update_from_state(leaf)
+ else:
+ variable.raw_value = leaf
+ self.index_ref.update(static_cache_node.new_index_ref)
+ elif dynamic_cache and node in dynamic_cache:
+        # node is in the dynamic cache, retrieve its entry
+ dyn_cache_node = dynamic_cache[node]
+ # check if the graphdef is the same
+ if dyn_cache_node.final_graphdef == graphdef:
+ if type(state) is list:
+ leaves = state
+ elif type(state) is FlatState: # type: ignore
+ leaves = state.leaves
+ else:
+ raise ValueError(f'Unsupported state type: {type(state)}')
+
+ # graphdefs match, update variables from state
+ if len(leaves) != len(dyn_cache_node.variables):
+ raise ValueError(
+ f'Incorrect number of leaves: expected {len(dyn_cache_node.variables)} '
+ f'leaves in the state, got {len(leaves)}'
+ )
+ for variable, leaf in zip(dyn_cache_node.variables, leaves):
+ if type(leaf) is VariableState:
+ variable.update_from_state(leaf)
+ else:
+ variable.raw_value = leaf
+ self.index_ref.update(dyn_cache_node.new_index_ref)
+ else: # cache.graphdef != graphdef_fp
+ # graph changed, re-create the node
+ node = unflatten(
+ graphdef,
+ state,
+ index_ref=self.index_ref,
+ outer_index_outer_ref=outer_index_outer_ref,
+ )
+ else:
+            # all nodes with an outer_index must be in the static or dynamic cache
+            raise RuntimeError(f'Node not found in the static or dynamic cache, node: {node}')
+ else: # graphdef.outer_index is None
+ # its a new node, create it
+ node = unflatten(
+ graphdef,
+ state,
+ index_ref=self.index_ref,
+ )
+ else:
+ outer_index_outer_ref = (
+ ctx.outer_index_outer_ref if ctx and ctx.outer_index_outer_ref else None
+ )
+ node = unflatten(
+ graphdef,
+ state,
+ index_ref=self.index_ref,
+ outer_index_outer_ref=outer_index_outer_ref,
+ )
+ return node
+
+
+@tp.overload
+@contextlib.contextmanager
+def merge_context(): ...
+@tp.overload
+@contextlib.contextmanager
+def merge_context(inner: bool | None, ctxtag: tp.Hashable | None): ...
+@contextlib.contextmanager
+def merge_context(inner: bool | None = None, ctxtag: tp.Hashable | None = None):
+ GRAPH_CONTEXT.index_ref_stack.append(MergeContext(ctxtag, {}, inner))
try:
- yield unflatten_ctx
+ yield GRAPH_CONTEXT.index_ref_stack[-1]
finally:
- GRAPH_CONTEXT.index_ref_stack.pop()
+ unflatten_ctx = GRAPH_CONTEXT.index_ref_stack.pop()
+ index_ref = unflatten_ctx.index_ref
if ctxtag is not None:
+ if inner is None:
+        raise ValueError('inner must be specified when using ctxtag')
ctx = current_update_context(ctxtag)
- ctx.unflatten_end(index_ref)
+ ctx.unflatten_end(index_ref, inner)
del unflatten_ctx.index_ref
del unflatten_ctx.ctxtag
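
A minimal sketch of the standalone split/merge context pair (when a ``ctxtag`` is given,
``merge_context`` must also be told via ``inner`` whether it is the inner or the outer merge of
the corresponding update context)::

    from flax import nnx
    from flax.nnx import graph

    model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))

    with graph.split_context() as split_ctx:
      graphdef, state = split_ctx.split(model)

    with graph.merge_context() as merge_ctx:
      restored = merge_ctx.merge(graphdef, state)
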
@@ -893,9 +1697,14 @@ def merge_context(ctxtag: str | None = None):
class UpdateContext:
"""A context manager for handling complex state updates."""
- tag: str
- ref_index: RefMap[tp.Any, Index] | None
- index_ref: dict[Index, tp.Any] | None
+ tag: tp.Hashable
+ outer_ref_outer_index: RefMap | None
+ outer_index_inner_ref: dict[Index, tp.Any] | None
+ # reverse caches
+ outer_index_outer_ref: dict[Index, tp.Any] | None
+ inner_ref_outer_index: RefMap | None
+ dynamic_cache: WeakKeyDictionary[tp.Any, DynamicCache] | None
+ static_cache: WeakKeyDictionary[tp.Any, StaticCache] | None
# define hash and eq to make this an opaque object
def __hash__(self):
@@ -904,16 +1713,25 @@ def __hash__(self):
def __eq__(self, other):
return isinstance(other, UpdateContext)
- def flatten_end(self, ref_index: RefMap[tp.Any, Index]):
- if self.ref_index is None:
+ def flatten_end(self, ref_index: RefMap):
+ if self.outer_ref_outer_index is None:
# outer split (1), store the references
- self.ref_index = ref_index
+ self.outer_ref_outer_index = ref_index
+ self.outer_index_outer_ref = {
+ index: obj for obj, index in self.outer_ref_outer_index.items()
+ }
else:
# inner split (3), clear index_ref
- self.index_ref = None
+ self.outer_index_inner_ref = None
+ self.inner_ref_outer_index = None
- def unflatten_end(self, index_ref: dict[Index, tp.Any]):
- self.index_ref = index_ref
+ def unflatten_end(self, index_ref: dict[Index, tp.Any], inner_merge: bool):
+ if inner_merge:
+ # inner merge (2)
+ self.outer_index_inner_ref = index_ref
+ self.inner_ref_outer_index = RefMap(
+ {obj: index for index, obj in index_ref.items()}
+ )
@tp.overload
def split(self, graph_node: A, /) -> tuple[GraphDef[A], GraphState]: ...
@@ -996,19 +1814,17 @@ def split(
:class:`GraphDef` and one or more :class:`State`'s equal to the number of filters passed. If no
filters are passed, a single :class:`State` is returned.
"""
- ref_index: RefMap[tp.Any, Index] = RefMap()
- graphdef, state = flatten(node, ref_index)
- states = _split_state(state, filters)
-
- if self.index_ref is not None and isinstance(graphdef, NodeDef):
- index_to_index = compose_mapping(self.index_ref, ref_index)
- graphdef = dataclasses.replace(
- graphdef, index_mapping=HashableMapping(index_to_index, copy=False)
- )
-
+ ref_index: RefMap = RefMap()
+ graphdef, flat_state = flatten(
+ node, ref_index=ref_index, ref_outer_index=self.inner_ref_outer_index
+ )
+ states = tuple(
+ State.from_flat_path(flat_state)
+ for flat_state in _split_state(flat_state, filters)
+ )
+ assert len(states) >= 1
self.flatten_end(ref_index)
-
- return graphdef, *states
+ return graphdef, *states # type: ignore[return-value]
def merge(
self,
@@ -1021,15 +1837,13 @@ def merge(
raise ValueError(
f'Expected a NodeDef instance, but got {type(graphdef)}.'
)
- if self.ref_index is None:
+ if self.outer_ref_outer_index is None:
raise ValueError('Cannot merge without ref_index.')
- if graphdef.index_mapping is not None:
+ if self.outer_ref_outer_index is not None:
# outer merge (4), create index_ref_cache
- assert self.ref_index is not None
- index_ref_cache = compose_mapping_reversed(
- self.ref_index, graphdef.index_mapping
- )
+ index_ref_cache = self.outer_index_outer_ref
+ assert index_ref_cache is not None
else:
# inner merge (2)
index_ref_cache = None
@@ -1037,10 +1851,13 @@ def merge(
state = State.merge(state, *states)
index_ref: dict[Index, tp.Any] = {}
node = unflatten(
- graphdef, state, index_ref=index_ref, index_ref_cache=index_ref_cache
+ graphdef,
+ state,
+ index_ref=index_ref,
+ outer_index_outer_ref=index_ref_cache,
)
- self.unflatten_end(index_ref)
+ self.unflatten_end(index_ref, True)
return node
@@ -1050,10 +1867,31 @@ def merge(
@dataclasses.dataclass
class UpdateContextManager:
- tag: str
+ tag: tp.Hashable
+ use_dynamic_cache: bool
def __enter__(self):
- ctx = UpdateContext(self.tag, None, None)
+ dynamic_cache: WeakKeyDictionary[tp.Any, DynamicCache] | None
+ if self.use_dynamic_cache:
+ dynamic_cache = WeakKeyDictionary()
+ else:
+ dynamic_cache = None
+
+ if GRAPH_CONTEXT.tmp_static_cache is not None:
+ # take current static cache
+ static_cache = GRAPH_CONTEXT.tmp_static_cache
+ GRAPH_CONTEXT.tmp_static_cache = None
+ else:
+ static_cache = None
+ ctx = UpdateContext(
+ tag=self.tag,
+ outer_ref_outer_index=None,
+ outer_index_inner_ref=None,
+ outer_index_outer_ref=None,
+ inner_ref_outer_index=None,
+ dynamic_cache=dynamic_cache,
+ static_cache=static_cache,
+ )
if self.tag not in GRAPH_CONTEXT.update_context_stacks:
GRAPH_CONTEXT.update_context_stacks[self.tag] = [ctx]
else:
@@ -1069,8 +1907,10 @@ def __exit__(self, *args):
ctx = stack.pop()
# clear references
- del ctx.ref_index
- del ctx.index_ref
+ del ctx.outer_ref_outer_index
+ del ctx.outer_index_inner_ref
+ del ctx.outer_index_outer_ref
+ del ctx.inner_ref_outer_index
if not stack:
del GRAPH_CONTEXT.update_context_stacks[self.tag]
@@ -1084,7 +1924,7 @@ def update_context_manager_wrapper(*args, **kwargs):
return update_context_manager_wrapper # type: ignore
-def update_context(tag: str):
+def update_context(tag: tp.Hashable, *, use_dynamic_cache: bool = False):
"""Creates an :class:`UpdateContext` context manager which can be used to handle
more complex state updates beyond what ``nnx.update`` can handle, including
updates to static properties and graph structure.
@@ -1176,10 +2016,10 @@ def update_context(tag: str):
Args:
tag: A string tag to identify the context.
"""
- return UpdateContextManager(tag)
+ return UpdateContextManager(tag=tag, use_dynamic_cache=use_dynamic_cache)
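
A sketch of opting into the new dynamic cache, assuming ``update_context`` remains re-exported
as ``nnx.update_context``::

    from flax import nnx

    model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))

    with nnx.update_context('example', use_dynamic_cache=True) as ctx:
      graphdef, state = ctx.split(model)
      # ... an inner merge/split pair would normally run inside a transform here ...
      model_out = ctx.merge(graphdef, state)
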
-def current_update_context(tag: str) -> UpdateContext:
+def current_update_context(tag: tp.Hashable) -> UpdateContext:
"""Returns the current active :class:`UpdateContext` for the given tag."""
if tag not in GRAPH_CONTEXT.update_context_stacks:
raise ValueError(f'No update context found for tag {tag!r}.')
@@ -1191,14 +2031,14 @@ def current_update_context(tag: str) -> UpdateContext:
# --------------------------------------------------------
def _split_state(
- state: GraphState,
+ state: FlatState[tp.Any],
filters: tuple[filterlib.Filter, ...],
-) -> tuple[GraphState, tpe.Unpack[tuple[GraphState, ...]]]:
+) -> tuple[FlatState[tp.Any], tpe.Unpack[tuple[FlatState[tp.Any], ...]]]:
if not filters:
- return (state,)
+ return (state,) # type: ignore[bad-return-type]
states = state.split(*filters)
- if isinstance(states, State):
- return (states,)
+ if not isinstance(states, tuple):
+ return (states,) # type: ignore[bad-return-type]
assert len(states) > 0
return states # type: ignore[return-value]
@@ -1288,15 +2128,17 @@ def split(
``GraphDef`` and one or more ``States`` equal to the number of filters passed. If no
filters are passed, a single ``State`` is returned.
"""
- graphdef, state = flatten(node)
- states = _split_state(state, filters)
- return graphdef, *states
+ graphdef, flat_state = flatten(node)
+ flat_states = _split_state(flat_state, filters)
+ states = tuple(State.from_flat_path(flat_state) for flat_state in flat_states)
+ return graphdef, *states # type: ignore[return-value]
+
def merge(
graphdef: GraphDef[A],
- state: tp.Mapping[KeyT, tp.Any],
+ state: tp.Mapping[Key, tp.Any],
/,
- *states: tp.Mapping[KeyT, tp.Any],
+ *states: tp.Mapping[Key, tp.Any],
) -> A:
"""The inverse of :func:`flax.nnx.split`.
@@ -1341,8 +2183,8 @@ def merge(
Returns:
The merged :class:`flax.nnx.Module`.
"""
- state = State.merge(state, *states)
- node = unflatten(graphdef, state)
+ _state = State.merge(state, *states)
+ node = unflatten(graphdef, _state)
return node
@@ -1481,7 +2323,8 @@ def state(
Returns:
One or more :class:`State` mappings.
"""
- _, state = flatten(node)
+ _, flat_state = flatten(node)
+ state = flat_state.to_nested_state()
states: GraphState | tuple[GraphState, ...]
if len(filters) == 0:
@@ -1748,23 +2591,16 @@ def _iter_graph(
if id(node) in visited:
return
visited.add(id(node))
- node_dict = get_node_impl(node).node_dict(node)
+ node_impl = get_node_impl(node)
+ if node_impl is None:
+ raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.')
+ node_dict = node_impl.node_dict(node)
for key, value in node_dict.items():
yield from _iter_graph(value, visited, (*path_parts, key))
yield path_parts, node
-def compose_mapping(
- map_ab: tp.Mapping[A, B], map_bc: tp.Mapping[B, C], /
-) -> dict[A, C]:
- return {a: map_bc[b] for a, b in map_ab.items() if b in map_bc}
-
-
-def compose_mapping_reversed(
- map_ab: tp.Mapping[A, B], map_bc: tp.Mapping[B, C], /
-) -> dict[C, A]:
- return {map_bc[b]: a for a, b in map_ab.items() if b in map_bc}
@dataclasses.dataclass(frozen=True)
@@ -1783,21 +2619,15 @@ class Static(tp.Generic[A]):
# ---------------------------------------------------------
class GenericPytree: ...
+from jax._src.tree_util import _registry as JAX_PYTREE_REGISTRY
def is_pytree_node(x: tp.Any) -> bool:
- t = type(x)
- if t in PYTREE_REGISTRY:
+ if type(x) in JAX_PYTREE_REGISTRY:
return True
- elif t in GRAPH_REGISTRY:
- return False
- # known non-pytree types
- elif isinstance(x, Variable):
- return False
- # known pytree types
- elif type(x) is VariableState or type(x) is State:
+ elif isinstance(x, tuple):
return True
else:
- return not jax.tree_util.all_leaves((x,))
+ return False
def _key_path_to_key(key: tp.Any) -> Key:
@@ -1806,7 +2636,7 @@ def _key_path_to_key(key: tp.Any) -> Key:
elif isinstance(
key, (jax.tree_util.DictKey, jax.tree_util.FlattenedIndexKey)
):
- if not is_key_like(key.key):
+ if not is_key_like(key.key): # type: ignore[not-supported-yet]
raise ValueError(
f'Invalid key: {key.key}. May be due to its type not being hashable or comparable.'
)
@@ -1816,20 +2646,28 @@ def _key_path_to_key(key: tp.Any) -> Key:
else:
return str(key)
+class IndexesPytreeDef(tp.NamedTuple):
+ key_index: HashableMapping[Key, int]
+ treedef: jax.tree_util.PyTreeDef
def _flatten_pytree(pytree: tp.Any):
leaves, treedef = jax.tree_util.tree_flatten_with_path(
pytree, is_leaf=lambda x: x is not pytree
)
- nodes = tuple((_key_path_to_key(path[0]), value) for path, value in leaves)
-
- return nodes, treedef
+ nodes = [(_key_path_to_key(path[0]), value) for path, value in leaves]
+ key_index = HashableMapping(
+ {key: i for i, (key, _) in enumerate(nodes)}, copy=False
+ )
+ nodes.sort() # sort by key
+ return nodes, IndexesPytreeDef(key_index, treedef)
def _unflatten_pytree(
- nodes: tuple[tuple[Key, tp.Any], ...], treedef: jax.tree_util.PyTreeDef
+ nodes: tuple[tuple[Key, tp.Any], ...], metadata: IndexesPytreeDef
):
- pytree = treedef.unflatten(value for _, value in nodes)
+ # sort to original order
+ sorted_nodes = sorted(nodes, key=lambda x: metadata.key_index[x[0]])
+ pytree = metadata.treedef.unflatten(value for _, value in sorted_nodes)
return pytree
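
Editorial note: the new ``IndexesPytreeDef`` metadata lets graph traversal see a generic pytree's leaves sorted by key while ``_unflatten_pytree`` still rebuilds them in the original flatten order. A standalone sketch of that bookkeeping (the keys and values are illustrative, not the flax internals):

nodes = [('w', 1.0), ('b', 0.5), ('scale', 2.0)]          # leaves in pytree flatten order
key_index = {key: i for i, (key, _) in enumerate(nodes)}  # remember each key's original position
nodes_sorted = sorted(nodes)                              # traversal sees keys in sorted order
restored = sorted(nodes_sorted, key=lambda kv: key_index[kv[0]])  # undo the sort on unflatten
assert restored == [('w', 1.0), ('b', 0.5), ('scale', 2.0)]
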
diff --git a/flax/nnx/helpers.py b/flax/nnx/helpers.py
index 96622f0e4..c72d7ce8c 100644
--- a/flax/nnx/helpers.py
+++ b/flax/nnx/helpers.py
@@ -20,6 +20,7 @@
import jax.numpy as jnp
import optax
+from flax.nnx import graph
from flax.nnx.module import GraphDef, Module
from flax.nnx.proxy_caller import ApplyCaller
from flax.nnx.rnglib import Rngs
@@ -62,6 +63,10 @@ def __iter__(self) -> tp.Iterator[str]:
def __len__(self) -> int:
return len(vars(self))
+ def __hash__(self) -> int:
+ return id(self)
+
+
class Sequential(Module):
def __init__(self, *fns: tp.Callable[..., tp.Any]):
self.layers = list(fns)
@@ -97,7 +102,7 @@ def __call__(
class TrainState(tp.Generic[M], struct.PyTreeNode):
- graphdef: GraphDef[M]
+ graphdef: graph.NodeDef[M]
params: State
opt_state: optax.OptState
step: jax.Array
@@ -106,7 +111,7 @@ class TrainState(tp.Generic[M], struct.PyTreeNode):
@classmethod
def create(
cls,
- graphdef: GraphDef[M],
+ graphdef: graph.NodeDef[M],
*,
params: State,
tx: optax.GradientTransformation,
diff --git a/flax/nnx/module.py b/flax/nnx/module.py
index b07efa771..9a4338463 100644
--- a/flax/nnx/module.py
+++ b/flax/nnx/module.py
@@ -257,7 +257,10 @@ def iter_children(self) -> tp.Iterator[tuple[Key, Module]]:
linear Linear
submodule SubModule
"""
- node_dict = graph.get_node_impl(self).node_dict(self)
+ node_impl = graph.get_node_impl(self)
+ if node_impl is None:
+ raise RuntimeError(f'Unsupported type: {type(self)}, this is a bug.')
+ node_dict = node_impl.node_dict(self)
for key, value in node_dict.items():
if isinstance(value, Module):
yield key, value
diff --git a/flax/nnx/nn/stochastic.py b/flax/nnx/nn/stochastic.py
index add545634..a3313bf6e 100644
--- a/flax/nnx/nn/stochastic.py
+++ b/flax/nnx/nn/stochastic.py
@@ -125,3 +125,6 @@ def __call__(
mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
mask = jnp.broadcast_to(mask, inputs.shape)
return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))
+
+ def __hash__(self):
+ return id(self)
diff --git a/flax/nnx/object.py b/flax/nnx/object.py
index b1f7478ee..a49d13d4e 100644
--- a/flax/nnx/object.py
+++ b/flax/nnx/object.py
@@ -60,7 +60,10 @@ def _collect_stats(
stats[var_type] = size_bytes
else:
- node_dict = graph.get_node_impl(node).node_dict(node)
+ node_impl = graph.get_node_impl(node)
+ if node_impl is None:
+ raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.')
+ node_dict = node_impl.node_dict(node)
for key, value in node_dict.items():
if id(value) in node_stats:
continue
diff --git a/flax/nnx/reprlib.py b/flax/nnx/reprlib.py
index 155c2e7e9..58a004145 100644
--- a/flax/nnx/reprlib.py
+++ b/flax/nnx/reprlib.py
@@ -235,6 +235,7 @@ def __nnx_repr__(self):
for key, value in self.mapping.items():
yield Attr(colorized(key), value, use_raw_key=True)
+
@dataclasses.dataclass(repr=False)
class SequenceReprMixin(Representable):
def __nnx_repr__(self):
diff --git a/flax/nnx/rnglib.py b/flax/nnx/rnglib.py
index ab9817aca..0fef2c173 100644
--- a/flax/nnx/rnglib.py
+++ b/flax/nnx/rnglib.py
@@ -13,7 +13,6 @@
# limitations under the License.
from __future__ import annotations
-import dataclasses
import functools
import typing as tp
@@ -48,7 +47,6 @@ class RngKey(RngState): ...
NotKey = filterlib.All(RngState, filterlib.Not(RngKey))
-@dataclasses.dataclass(repr=False)
class RngStream(Object):
def __init__(
self,
@@ -56,13 +54,12 @@ def __init__(
key: jax.Array,
count: jax.Array,
):
+ if not isinstance(key, jax.Array):
+ raise TypeError(f'key must be a jax.Array, got {type(key)}')
+
self.key = RngKey(key, tag=tag)
self.count = RngCount(count, tag=tag)
- def __post_init__(self):
- if not isinstance(self.key, jax.Array):
- raise TypeError(f'key must be a jax.Array, got {type(self.key)}')
-
def __call__(self) -> jax.Array:
self.check_valid_context(
lambda: 'Cannot call RngStream from a different trace level'
@@ -80,7 +77,7 @@ def __call__(self) -> jax.Array:
]
-class Rngs(Object, tp.Mapping[str, tp.Callable[[], jax.Array]]):
+class Rngs(Object):
"""NNX rng container class. To instantiate the ``Rngs``, pass
in an integer, specifying the starting seed. ``Rngs`` can have
different "streams", allowing the user to generate different
@@ -237,6 +234,10 @@ def __getstate__(self):
def __setstate__(self, state):
vars(self).update(state)
+ def items(self):
+ for name in self:
+ yield name, self[name]
+
class ForkStates(tp.NamedTuple):
split_keys: State
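
Editorial note: with the ``tp.Mapping`` base removed, iteration over streams goes through the new ``items()`` helper, and ``RngStream`` now rejects non-``jax.Array`` keys at construction. A minimal usage sketch (the stream names are illustrative):

import jax
from flax import nnx

rngs = nnx.Rngs(params=0, dropout=1)  # one named stream per integer seed
k1 = rngs.params()                    # each call draws a fresh key from the 'params' stream
k2 = rngs.dropout()
assert isinstance(k1, jax.Array) and isinstance(k2, jax.Array)
for name, stream in rngs.items():     # items() replaces the old Mapping-style iteration
  print(name, stream)
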
diff --git a/flax/nnx/statelib.py b/flax/nnx/statelib.py
index 38cb3da75..2b357be32 100644
--- a/flax/nnx/statelib.py
+++ b/flax/nnx/statelib.py
@@ -23,7 +23,7 @@
from flax.nnx import traversals
from flax.nnx import filterlib, reprlib
from flax.nnx import variablelib
-from flax.typing import PathParts
+from flax.typing import Key, PathParts
A = tp.TypeVar('A')
K = tp.TypeVar('K', bound=tp.Hashable)
@@ -54,26 +54,45 @@ def __treescope_repr__(self, path, subtree_renderer):
# Render as the dictionary itself at the same path.
return subtree_renderer(children, path=path)
-class FlatState(tp.Sequence[tuple[PathParts, V]], reprlib.SequenceReprMixin):
+class FlatState(tp.Sequence[tuple[PathParts, V]], reprlib.Representable):
+ __slots__ = ('_keys', '_values')
+
_keys: tuple[PathParts, ...]
_values: list[V]
- def __init__(self, items: tp.Iterable[tuple[PathParts, V]]):
+ def __init__(self, items: tp.Iterable[tuple[PathParts, V]], /, *, sort: bool):
keys, values = [], []
+ if sort:
+ items = sorted(items)
for key, value in items:
keys.append(key)
values.append(value)
self._keys = tuple(keys)
self._values = values
+ @staticmethod
+ def from_sorted_keys_values(
+ keys: tuple[PathParts, ...], values: list[V], /
+ ) -> FlatState[V]:
+ flat_state = object.__new__(FlatState)
+ flat_state._keys = keys
+ flat_state._values = values
+ return flat_state
+
@property
- def paths(self) -> tp.Sequence[PathParts]:
+ def paths(self) -> tp.Tuple[PathParts, ...]:
return self._keys
@property
- def leaves(self) -> tp.Sequence[V]:
+ def leaves(self) -> list[V]:
return self._values
+ def __nnx_repr__(self):
+ yield reprlib.Object(type='FlatState', kv_sep='', start='([', end='])')
+
+ for value in self:
+ yield reprlib.Attr('', value)
+
@tp.overload
def __getitem__(self, index: int) -> tuple[PathParts, V]: ...
@tp.overload
@@ -83,7 +102,7 @@ def __getitem__(
) -> tuple[PathParts, V] | FlatState[V]:
if isinstance(index, int):
return self._keys[index], self._values[index]
- return FlatState(zip(self._keys[index], self._values[index]))
+ return FlatState(zip(self._keys[index], self._values[index]), sort=False)
def __len__(self) -> int:
return len(self._keys)
@@ -91,6 +110,91 @@ def __len__(self) -> int:
def __iter__(self) -> tp.Iterator[tuple[PathParts, V]]:
return iter(zip(self._keys, self._values))
+ def to_nested_state(self) -> State[Key, V]:
+ return State.from_flat_path(self)
+
+ @tp.overload
+ def split(self, first: filterlib.Filter, /) -> FlatState[V]: ...
+
+ @tp.overload
+ def split(
+ self,
+ first: filterlib.Filter,
+ second: filterlib.Filter,
+ /,
+ *filters: filterlib.Filter,
+ ) -> tuple[FlatState[V], ...]: ...
+
+ @tp.overload
+ def split(
+ self, /, *filters: filterlib.Filter
+ ) -> tp.Union[FlatState[V], tuple[FlatState[V], ...]]: ...
+
+ def split( # type: ignore[misc]
+ self, first: filterlib.Filter, /, *filters: filterlib.Filter
+ ) -> tp.Union[FlatState[V], tuple[FlatState[V], ...]]:
+ filters = (first, *filters)
+ *flat_states_, rest = _split_state(self, *filters)
+
+ if rest:
+ raise ValueError(
+ 'Non-exhaustive filters, got a non-empty remainder: '
+ f'{rest}.\nUse `...` to match all remaining elements.'
+ )
+
+ flat_states: FlatState[V] | tuple[FlatState[V], ...]
+ if len(flat_states_) == 1:
+ flat_states = flat_states_[0]
+ else:
+ flat_states = tuple(flat_states_)
+ return flat_states # type: ignore
+
+ @tp.overload
+ def filter(self, first: filterlib.Filter, /) -> FlatState[V]: ...
+
+ @tp.overload
+ def filter(
+ self,
+ first: filterlib.Filter,
+ second: filterlib.Filter,
+ /,
+ *filters: filterlib.Filter,
+ ) -> tuple[FlatState[V], ...]: ...
+
+ def filter(
+ self,
+ first: filterlib.Filter,
+ /,
+ *filters: filterlib.Filter,
+ ) -> tp.Union[FlatState[V], tuple[FlatState[V], ...]]:
+ *flat_states_, _rest = _split_state(self, first, *filters)
+
+ assert len(flat_states_) == len(filters) + 1
+
+ flat_states: FlatState[V] | tuple[FlatState[V], ...]
+ if len(flat_states_) == 1:
+ flat_states = flat_states_[0]
+ else:
+ flat_states = tuple(flat_states_)
+
+ return flat_states # type: ignore
+
+ @staticmethod
+ def merge(
+ flat_state: tp.Iterable[tuple[PathParts, V]],
+ /,
+ *flat_states: tp.Iterable[tuple[PathParts, V]],
+ ) -> FlatState[V]:
+ if not flat_states:
+ if isinstance(flat_state, FlatState):
+ return flat_state
+ return FlatState(flat_state, sort=True)
+ flat_states = (flat_state, *flat_states)
+
+ return FlatState(
+ (elem for flat_state in flat_states for elem in flat_state), sort=True
+ )
+
def _flat_state_pytree_flatten(x: FlatState[V]):
return x._values, x._keys
@@ -211,7 +315,7 @@ def map(self, f: tp.Callable[[tuple, V], V]) -> State[K, V]:
return State.from_flat_path(result)
def flat_state(self) -> FlatState[V]:
- return FlatState(traversals.flatten_to_sequence(self._mapping))
+ return FlatState(traversals.flatten_to_sequence(self._mapping), sort=True)
@classmethod
def from_flat_path(
@@ -299,7 +403,8 @@ def split( # type: ignore[misc]
One or more ``States`` equal to the number of filters passed.
"""
filters = (first, *filters)
- *states_, rest = _split_state(self.flat_state(), *filters)
+ flat_states = _split_state(self.flat_state(), *filters)
+ *states_, rest = (state.to_nested_state() for state in flat_states)
if rest:
raise ValueError(
@@ -364,7 +469,8 @@ def filter(
Returns:
One or more ``States`` equal to the number of filters passed.
"""
- *states_, _rest = _split_state(self.flat_state(), first, *filters)
+ flat_states = _split_state(self.flat_state(), first, *filters)
+ *states_, _rest = (state.to_nested_state() for state in flat_states)
assert len(states_) == len(filters) + 1
@@ -464,7 +570,7 @@ def _state_unflatten(
def _split_state(
flat_state: FlatState[V],
*filters: filterlib.Filter,
-) -> tuple[State[PathParts, V], ...]:
+) -> tuple[FlatState[V], ...]:
for i, filter_ in enumerate(filters):
if filter_ in (..., True) and i != len(filters) - 1:
remaining_filters = filters[i + 1 :]
@@ -490,7 +596,7 @@ def _split_state(
# if we didn't break, set leaf to last state
flat_states[-1].append((path, value)) # type: ignore[index] # mypy is wrong here?
- return tuple(State.from_flat_path(flat_state) for flat_state in flat_states)
+ return tuple(FlatState(flat_state, sort=False) for flat_state in flat_states)
def create_path_filters(state: State):
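
Editorial note: a minimal sketch of the ``State``/``FlatState`` APIs touched above (``split``, ``filter``, ``flat_state``); the model and the printed leaf types are illustrative only.

from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
state = nnx.state(model)                    # nested State, built from a sorted FlatState internally

params, rest = state.split(nnx.Param, ...)  # exhaustive split; `...` collects the remainder
only_params = state.filter(nnx.Param)       # non-exhaustive; everything else is dropped

flat = state.flat_state()                   # FlatState: sorted (path, leaf) pairs
for path, leaf in flat:
  print(path, type(leaf).__name__)
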
diff --git a/flax/nnx/tracers.py b/flax/nnx/tracers.py
index a7b72b154..c53bbd5c4 100644
--- a/flax/nnx/tracers.py
+++ b/flax/nnx/tracers.py
@@ -18,7 +18,7 @@
import jax
import jax.core
-from flax.nnx import reprlib, visualization
+from flax.nnx import reprlib
def current_jax_trace():
@@ -47,11 +47,12 @@ def __nnx_repr__(self):
yield reprlib.Attr('jax_trace', self._jax_trace)
def __treescope_repr__(self, path, subtree_renderer):
- return visualization.render_object_constructor(
- object_type=type(self),
- attributes={'jax_trace': self._jax_trace},
- path=path,
- subtree_renderer=subtree_renderer,
+ import treescope # type: ignore[import-not-found,import-untyped]
+ return treescope.repr_lib.render_object_constructor(
+ object_type=type(self),
+ attributes={'jax_trace': self._jax_trace},
+ path=path,
+ subtree_renderer=subtree_renderer,
)
def __eq__(self, other):
diff --git a/flax/nnx/transforms/autodiff.py b/flax/nnx/transforms/autodiff.py
index 5ef0d183b..164c6d237 100644
--- a/flax/nnx/transforms/autodiff.py
+++ b/flax/nnx/transforms/autodiff.py
@@ -64,24 +64,26 @@ class DiffState:
class GradFn:
f: tp.Callable[..., tp.Any]
has_aux: bool
+ nondiff_states: deque[State | None]
def __post_init__(self):
functools.update_wrapper(self, self.f)
def __call__(self, *pure_args):
# rebuild diff_state from substates in args
- nondiff_states: deque[State | None] = extract.get_broadcast_state('grad')
def _grad_merge_fn(
ctx: graph.MergeContext, path, prefix, value: extract.NodeStates
):
- nondiff = nondiff_states.popleft()
+ nondiff = self.nondiff_states.popleft()
if nondiff is None:
return ctx.merge(value.graphdef, value.state)
else:
return ctx.merge(value.graphdef, value.state, nondiff)
- args = extract.from_tree(pure_args, merge_fn=_grad_merge_fn, ctxtag='grad')
+ args = extract.from_tree(
+ pure_args, merge_fn=_grad_merge_fn, ctxtag='grad', is_inner=True
+ )
out = self.f(*args)
@@ -129,15 +131,6 @@ def _grad_general(
else DiffState(-1, variablelib.Param)
)
- gradded_fn = transform(
- GradFn(f, has_aux),
- argnums=jax_argnums,
- has_aux=True,
- holomorphic=holomorphic,
- allow_int=allow_int,
- reduce_axes=reduce_axes,
- )
-
@graph.update_context('grad')
def grad_wrapper(*args, **kwargs):
args = resolve_kwargs(f, args, kwargs)
@@ -152,7 +145,7 @@ def _grad_split_fn(
return extract.NodeStates.from_split(*ctx.split(value))
else:
graphdef, diff, nondiff = ctx.split(value, prefix.filter, ...) # type: ignore[misc]
- nondiff_states.append(nondiff)
+ nondiff_states.append(nondiff) # type: ignore[container-type-mismatch]
return extract.NodeStates.from_split(graphdef, diff)
arg_filters = tuple(index_filter.get(i) for i in range(len(args)))
@@ -160,8 +153,16 @@ def _grad_split_fn(
args, prefix=arg_filters, split_fn=_grad_split_fn, ctxtag='grad'
)
- with extract.broadcast_state('grad', nondiff_states):
- fn_out = gradded_fn(*pure_args)
+ gradded_fn = transform(
+ GradFn(f, has_aux, nondiff_states),
+ argnums=jax_argnums,
+ has_aux=True,
+ holomorphic=holomorphic,
+ allow_int=allow_int,
+ reduce_axes=reduce_axes,
+ )
+
+ fn_out = gradded_fn(*pure_args)
def process_grads(grads):
return jax.tree.map(
@@ -171,7 +172,7 @@ def process_grads(grads):
)
def process_out(pure_out: A, /) -> A:
- return extract.from_tree(pure_out, ctxtag='grad')
+ return extract.from_tree(pure_out, ctxtag='grad', is_inner=False)
if return_value:
# unpack value_and_grad output
@@ -427,11 +428,11 @@ def _custom_vjp_split_fn(
nondiff_argnums: tuple[int, ...] = struct.field(pytree_node=False)
tangent_tree_node_args: tuple[tp.Any, ...] = struct.field(pytree_node=False)
-def _extract_index_mappings(x, *, index_mappings: deque[graph.HashableMapping]):
+def _extract_nodedefs(x, *, nodedefs: deque[graph.NodeDef]):
if isinstance(x, graph.NodeDef):
- assert x.index_mapping is not None
- index_mappings.append(x.index_mapping)
- return dataclasses.replace(x, index_mapping=None)
+ assert x.outer_index is not None
+ nodedefs.append(x)
+ return x.with_no_outer_index()
return x
@dataclasses.dataclass(eq=False)
@@ -440,6 +441,7 @@ class CustomVjpFnWrapper:
jax_nondiff_argnums: tuple[int, ...]
ctxtag: str
nondiff_states: list[extract.GraphDefState]
+ nodedefs: deque[graph.NodeDef]
def __post_init__(self):
functools.update_wrapper(self, self.f)
@@ -452,6 +454,7 @@ def __call__(self, *pure_args):
_custom_vjp_merge_fn, nondiff_states=nondiff_states
),
ctxtag=self.ctxtag,
+ is_inner=True,
)
out = self.f(*args)
@@ -464,13 +467,10 @@ def __call__(self, *pure_args):
pure_args_out, pure_out = extract.to_tree(
(args_out, out), ctxtag=self.ctxtag
)
- # remove index_mapping from NodeDef's but store them in global context
- index_mappings: deque[graph.HashableMapping] = extract.get_broadcast_state(
- self.ctxtag
- )
+ # strip the outer_index from NodeDefs, keeping the originals in self.nodedefs
pure_args_out, pure_out = jax.tree.map(
- functools.partial(_extract_index_mappings, index_mappings=index_mappings),
+ functools.partial(_extract_nodedefs, nodedefs=self.nodedefs),
(pure_args_out, pure_out),
is_leaf=lambda x: isinstance(x, graph.NodeDef),
)
@@ -484,6 +484,7 @@ class FwdFn:
nondiff_argnums: tuple[int, ...]
ctxtag: str
nondiff_states: list[extract.GraphDefState]
+ nodedefs: deque[graph.NodeDef]
def __post_init__(self):
functools.update_wrapper(self, self.fwd)
@@ -503,6 +504,7 @@ def __call__(self, *pure_args):
_custom_vjp_merge_fn, nondiff_states=nondiff_states
),
ctxtag=self.ctxtag if update_context_active else None,
+ is_inner=True,
)
out, residual = self.fwd(*args)
@@ -519,14 +521,9 @@ def __call__(self, *pure_args):
pure_residual = extract.to_tree(residual)
if update_context_active:
- # remove index_mapping from NodeDef's but store them in global context
- index_mappings: deque[graph.HashableMapping] = (
- extract.get_broadcast_state(self.ctxtag)
- )
+ # strip the outer_index from NodeDefs, keeping the originals in self.nodedefs
pure_args_out, pure_out = jax.tree.map(
- functools.partial(
- _extract_index_mappings, index_mappings=index_mappings
- ),
+ functools.partial(_extract_nodedefs, nodedefs=self.nodedefs),
(pure_args_out, pure_out),
is_leaf=lambda x: isinstance(x, graph.NodeDef),
)
@@ -544,7 +541,7 @@ def __post_init__(self):
def __call__(self, *args):
*nondiff, pure_residual, (pure_args_out_g, pure_out_g) = args
- residual = extract.from_tree(pure_residual)
+ residual = extract.from_tree(pure_residual, is_inner=True)
(pure_args_out_g, pure_out_g) = jax.tree.map(
lambda x: x.state if isinstance(x, extract.NodeStates) else x,
(pure_args_out_g, pure_out_g),
@@ -632,40 +629,41 @@ def __call__(
for i, x in enumerate(tree_node_args)
if i not in self.jax_nondiff_argnums
)
- index_mappings: deque[graph.HashableMapping] = deque()
- with extract.broadcast_state(self.ctxtag, index_mappings):
- if self.fwd is None or self.bwd is None or self.symbolic_zeros is None:
- raise ValueError()
-
- custom_vjp_fn = jax.custom_vjp(
- fun=CustomVjpFnWrapper(
- f=self.fun,
- jax_nondiff_argnums=self.jax_nondiff_argnums,
- ctxtag=self.ctxtag,
- nondiff_states=nondiff_states,
- ),
+ nodedefs: deque[graph.NodeDef] = deque()
+ if self.fwd is None or self.bwd is None or self.symbolic_zeros is None:
+ raise ValueError()
+
+ custom_vjp_fn = jax.custom_vjp(
+ fun=CustomVjpFnWrapper(
+ f=self.fun,
+ jax_nondiff_argnums=self.jax_nondiff_argnums,
+ ctxtag=self.ctxtag,
+ nondiff_states=nondiff_states,
+ nodedefs=nodedefs,
+ ),
+ nondiff_argnums=self.jax_nondiff_argnums,
+ )
+ custom_vjp_fn.defvjp(
+ fwd=FwdFn(
+ fwd=self.fwd,
nondiff_argnums=self.jax_nondiff_argnums,
- )
- custom_vjp_fn.defvjp(
- fwd=FwdFn(
- fwd=self.fwd,
- nondiff_argnums=self.jax_nondiff_argnums,
- ctxtag=self.ctxtag,
- nondiff_states=nondiff_states,
- ),
- bwd=BwdFn(
- bwd=self.bwd,
- tree_node_args=tree_node_args,
- ),
- symbolic_zeros=self.symbolic_zeros,
- )
- pure_args_out, pure_out = custom_vjp_fn(*pure_args)
+ ctxtag=self.ctxtag,
+ nondiff_states=nondiff_states,
+ nodedefs=nodedefs,
+ ),
+ bwd=BwdFn(
+ bwd=self.bwd,
+ tree_node_args=tree_node_args,
+ ),
+ symbolic_zeros=self.symbolic_zeros,
+ )
+ pure_args_out, pure_out = custom_vjp_fn(*pure_args)
# insert index_mappings
def _insert_index_mappings(x):
if isinstance(x, graph.NodeDef):
- index_mapping: graph.HashableMapping = index_mappings.popleft()
- return dataclasses.replace(x, index_mapping=index_mapping)
+ nodedef: graph.NodeDef = nodedefs.popleft()
+ return nodedef
return x
pure_args_out, pure_out = jax.tree_util.tree_map(
@@ -675,7 +673,7 @@ def _insert_index_mappings(x):
)
args_out, out = extract.from_tree(
- (pure_args_out, pure_out), ctxtag=self.ctxtag
+ (pure_args_out, pure_out), ctxtag=self.ctxtag, is_inner=False
)
return out
diff --git a/flax/nnx/transforms/compilation.py b/flax/nnx/transforms/compilation.py
index e5ce20f8e..32667d1e9 100644
--- a/flax/nnx/transforms/compilation.py
+++ b/flax/nnx/transforms/compilation.py
@@ -17,6 +17,9 @@
import functools
import typing as tp
+import jax.experimental
+import jax.experimental.shard_map
+
from flax.nnx import (
extract,
filterlib,
@@ -27,11 +30,13 @@
import jax
import jax.core
import jax.stages
+from jax._src.mesh import Mesh, AbstractMesh
from flax.typing import Missing
F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any])
-
+Specs = tp.Any
+AxisName = tp.Hashable
# -------------------------------
# jit
@@ -90,10 +95,15 @@ def __hash__(self):
def _jit_split_fn(ctx: graph.SplitContext, path, prefix, x):
if isinstance(prefix, StateSharding):
- return extract.NodeStates.from_split(
- *ctx.split(x, *prefix.filters), metadata=prefix
- )
- return extract.NodeStates.from_split(*ctx.split(x))
+ graphdef, *states = ctx.flatten(x, *prefix.filters)
+ return extract.NodeStates.from_split(graphdef, *states, metadata=prefix)
+ return extract.NodeStates.from_split(*ctx.flatten(x, with_paths=False))
+
+
+def _jit_merge_fn(ctx: graph.MergeContext, path, prefix, leaf) -> tp.Any:
+ if not isinstance(leaf, extract.NodeStates):
+ raise ValueError(f'Expected NodeStates, got {type(leaf)} at path {path}')
+ return ctx.unflatten(leaf.graphdef, *leaf.states)
@dataclasses.dataclass(eq=False)
@@ -102,12 +112,18 @@ class JitFn:
in_shardings: tp.Any
out_shardings: tp.Any
kwarg_shardings: tp.Any
+ ctxtag: tp.Hashable
def __post_init__(self):
functools.update_wrapper(self, self.f)
def __call__(self, *pure_args, **pure_kwargs):
- args, kwargs = extract.from_tree((pure_args, pure_kwargs), ctxtag='jit')
+ args, kwargs = extract.from_tree(
+ (pure_args, pure_kwargs),
+ merge_fn=_jit_merge_fn,
+ ctxtag=self.ctxtag,
+ is_inner=True,
+ )
out = self.f(*args, **kwargs)
@@ -115,7 +131,7 @@ def __call__(self, *pure_args, **pure_kwargs):
pure_args_out, pure_kwargs_out, pure_out = extract.to_tree(
(args_out, kwargs_out, out),
prefix=(self.in_shardings, self.kwarg_shardings, self.out_shardings),
- ctxtag='jit',
+ ctxtag=self.ctxtag,
split_fn=_jit_split_fn,
)
@@ -317,8 +333,32 @@ def jit(
out_shardings,
)
+ @functools.wraps(fun)
+ def jit_wrapper(*args, **kwargs):
+ # enter an update_context for this wrapper with the dynamic cache enabled
+ with graph.update_context(jit_wrapper, use_dynamic_cache=True):
+ pure_args, pure_kwargs = extract.to_tree(
+ (args, kwargs),
+ prefix=(in_shardings, kwarg_shardings)
+ if in_shardings is not None or kwarg_shardings is not None
+ else None,
+ split_fn=_jit_split_fn,
+ check_aliasing=in_shardings is not None or kwarg_shardings is not None,
+ ctxtag=jit_wrapper,
+ )
+ pure_args_out, pure_kwargs_out, pure_out = jitted_fn(
+ *pure_args, **pure_kwargs
+ )
+ _args_out, _kwargs_out, out = extract.from_tree(
+ (pure_args_out, pure_kwargs_out, pure_out),
+ merge_fn=_jit_merge_fn,
+ is_inner=False,
+ ctxtag=jit_wrapper,
+ )
+ return out
+
jitted_fn = jax.jit(
- JitFn(fun, in_shardings, out_shardings, kwarg_shardings),
+ JitFn(fun, in_shardings, out_shardings, kwarg_shardings, jit_wrapper),
in_shardings=jax_in_shardings,
out_shardings=(jax_in_shardings, kwarg_shardings, jax_out_shardings), # type: ignore
static_argnums=static_argnums,
@@ -332,24 +372,140 @@ def jit(
abstracted_axes=abstracted_axes,
)
- @functools.wraps(fun)
- @graph.update_context('jit')
- def jit_wrapper(*args, **kwargs):
- pure_args, pure_kwargs = extract.to_tree(
- (args, kwargs),
- prefix=(in_shardings, kwarg_shardings),
+ jit_wrapper.inner = jitted_fn # type: ignore
+
+ return jit_wrapper # type: ignore
+
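
Editorial note: the ``jit_wrapper`` above is what runs when ``nnx.jit`` is applied; a minimal usage sketch (the model and shapes are illustrative):

import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))

@nnx.jit
def forward(model, x):
  return model(x)

y = forward(model, jnp.ones((4, 2)))  # state updates inside the call propagate back to `model`
assert y.shape == (4, 3)
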
+# -------------------------------
+# shard_map
+# -------------------------------
+
+# TODO: create StateSpec and consider enabling a mode that does
+# not use filters during split for performance. Overall there might
+# be performance limitations to using shard_map at the top level
+
+@dataclasses.dataclass(eq=False)
+class ShardMapFn:
+ f: tp.Callable[..., tp.Any]
+ in_specs: tp.Any
+ out_specs: tp.Any
+ kwarg_specs: tp.Any
+ ctxtag: tp.Hashable
+
+ def __post_init__(self):
+ functools.update_wrapper(self, self.f)
+
+ def __call__(self, *pure_args, **pure_kwargs):
+ args, kwargs = extract.from_tree(
+ (pure_args, pure_kwargs),
+ merge_fn=_jit_merge_fn,
+ ctxtag=self.ctxtag,
+ is_inner=True,
+ )
+
+ out = self.f(*args, **kwargs)
+
+ args_out, kwargs_out = extract.clear_non_graph_nodes((args, kwargs))
+ pure_args_out, pure_kwargs_out, pure_out = extract.to_tree(
+ (args_out, kwargs_out, out),
+ prefix=(self.in_specs, self.kwarg_specs, self.out_specs),
+ ctxtag=self.ctxtag,
split_fn=_jit_split_fn,
- check_aliasing=in_shardings is not None,
- ctxtag='jit',
)
- pure_args_out, pure_kwargs_out, pure_out = jitted_fn(
- *pure_args, **pure_kwargs
+
+ return pure_args_out, pure_kwargs_out, pure_out
+
+
+@tp.overload
+def shard_map(
+ f: F,
+ *,
+ mesh: Mesh | AbstractMesh,
+ in_specs: Specs,
+ out_specs: Specs,
+ check_rep: bool = True,
+ auto: frozenset[AxisName] = frozenset(),
+) -> F: ...
+@tp.overload
+def shard_map(
+ *,
+ mesh: Mesh | AbstractMesh,
+ in_specs: Specs,
+ out_specs: Specs,
+ check_rep: bool = True,
+ auto: frozenset[AxisName] = frozenset(),
+) -> tp.Callable[[F], F]: ...
+def shard_map(
+ f: F | type[Missing] = Missing,
+ *,
+ mesh: Mesh | AbstractMesh,
+ in_specs: Specs,
+ out_specs: Specs,
+ check_rep: bool = True,
+ auto: frozenset[AxisName] = frozenset(),
+) -> F | tp.Callable[[F], F]:
+ if f is Missing:
+ return functools.partial(
+ shard_map,
+ mesh=mesh,
+ in_specs=in_specs,
+ out_specs=out_specs,
+ check_rep=check_rep,
+ auto=auto,
+ ) # type: ignore[return-value]
+ assert not isinstance(f, type)
+
+ kwarg_specs = PartitionSpec()
+ jax_in_specs = jax.tree.map(
+ lambda x: extract.NodeStates(
+ _graphdef=PartitionSpec(), states=x.shardings, metadata=x
)
- _args_out, _kwargs_out, out = extract.from_tree(
- (pure_args_out, pure_kwargs_out, pure_out), ctxtag='jit'
+ if isinstance(x, StateSharding)
+ else x,
+ in_specs,
+ )
+ jax_out_specs = jax.tree.map(
+ lambda x: extract.NodeStates(
+ _graphdef=PartitionSpec(), states=x.shardings, metadata=x
)
+ if isinstance(x, StateSharding)
+ else x,
+ out_specs,
+ )
+
+ @functools.wraps(f)
+ def shard_map_wrapper(*args, **kwargs):
+ # enter an update_context for this wrapper with the dynamic cache enabled
+ with graph.update_context(shard_map_wrapper, use_dynamic_cache=True):
+ pure_args, pure_kwargs = extract.to_tree(
+ (args, kwargs),
+ prefix=(in_specs, kwarg_specs)
+ if in_specs is not None or kwarg_specs is not None
+ else None,
+ split_fn=_jit_split_fn,
+ check_aliasing=in_specs is not None or kwarg_specs is not None,
+ ctxtag=shard_map_wrapper,
+ )
+ pure_args_out, pure_kwargs_out, pure_out = shard_map_fn(
+ *pure_args, **pure_kwargs
+ )
+ _args_out, _kwargs_out, out = extract.from_tree(
+ (pure_args_out, pure_kwargs_out, pure_out),
+ merge_fn=_jit_merge_fn,
+ is_inner=False,
+ ctxtag=shard_map_wrapper,
+ )
return out
- jit_wrapper.inner = jitted_fn # type: ignore
+ shard_map_fn = jax.experimental.shard_map.shard_map(
+ ShardMapFn(f, in_specs, out_specs, kwarg_specs, shard_map_wrapper),
+ mesh=mesh,
+ in_specs=jax_in_specs,
+ out_specs=(jax_in_specs, kwarg_specs, jax_out_specs), # type: ignore
+ check_rep=check_rep,
+ auto=auto,
+ )
- return jit_wrapper # type: ignore
+ shard_map_wrapper.inner = shard_map_fn # type: ignore
+
+ return shard_map_wrapper # type: ignore
\ No newline at end of file
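
Editorial note: a hedged sketch of how the new ``shard_map`` wrapper is meant to be called, mirroring the signature added above. It assumes the wrapper is re-exported as ``nnx.shard_map``; the mesh layout, the ``StateSharding`` with an ``...`` filter to replicate module state, and the batch spec are illustrative assumptions and are not tested here.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from flax import nnx

mesh = Mesh(np.array(jax.devices()), ('data',))
model = nnx.Linear(4, 4, rngs=nnx.Rngs(0))

@nnx.shard_map(
    mesh=mesh,
    in_specs=(nnx.StateSharding({...: P()}), P('data', None)),  # replicate params, shard the batch
    out_specs=P('data', None),
)
def forward(model, x):
  return model(x)

x = jnp.ones((jax.device_count() * 2, 4))
y = forward(model, x)  # assumption: behaves like jax.experimental.shard_map with module state threaded through
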
diff --git a/flax/nnx/transforms/general.py b/flax/nnx/transforms/general.py
index fa82cd890..553c3e892 100644
--- a/flax/nnx/transforms/general.py
+++ b/flax/nnx/transforms/general.py
@@ -151,7 +151,9 @@ def split_inputs(
def split_inputs_wrapper(*args):
pure_args = extract.to_tree(args, ctxtag=ctxtag)
pure_args_out, pure_out = f(*pure_args)
- args_out, out = extract.from_tree((pure_args_out, pure_out), ctxtag=ctxtag)
+ args_out, out = extract.from_tree(
+ (pure_args_out, pure_out), ctxtag=ctxtag, is_inner=False
+ )
return out
return split_inputs_wrapper # type: ignore
@@ -192,7 +194,7 @@ def merge_inputs(
@functools.wraps(f)
def merge_inputs_wrapper(*pure_args):
- args = extract.from_tree(pure_args, ctxtag=ctxtag)
+ args = extract.from_tree(pure_args, ctxtag=ctxtag, is_inner=True)
out = f(*args)
args_out = extract.clear_non_graph_nodes(args)
pure_args_out, pure_out = extract.to_tree((args_out, out), ctxtag=ctxtag)
diff --git a/flax/nnx/transforms/iteration.py b/flax/nnx/transforms/iteration.py
index 994e58286..62c41a3f9 100644
--- a/flax/nnx/transforms/iteration.py
+++ b/flax/nnx/transforms/iteration.py
@@ -165,7 +165,7 @@ def __call__(self, *pure_args: tuple[tp.Any, ...]):
pure_args = _update_variable_sharding_metadata(
pure_args, self.transform_metadata, spmd.remove_axis
)
- args = extract.from_tree(pure_args, ctxtag='vmap')
+ args = extract.from_tree(pure_args, ctxtag='vmap', is_inner=True)
out = self.f(*args)
@@ -343,7 +343,9 @@ def vmap_wrapper(*args, **kwargs):
args, prefix=in_axes, split_fn=_vmap_split_fn, ctxtag='vmap'
)
pure_args_out, pure_out = vmapped_fn(*pure_args)
- _args_out, out = extract.from_tree((pure_args_out, pure_out), ctxtag='vmap')
+ _args_out, out = extract.from_tree(
+ (pure_args_out, pure_out), ctxtag='vmap', is_inner=False
+ )
return out
return vmap_wrapper # type: ignore
@@ -369,7 +371,7 @@ def __call__(self, *pure_args: tuple[tp.Any, ...]):
pure_args = _update_variable_sharding_metadata(
pure_args, self.transform_metadata, spmd.remove_axis
)
- args = extract.from_tree(pure_args, ctxtag='pmap')
+ args = extract.from_tree(pure_args, ctxtag='pmap', is_inner=True)
out = self.f(*args)
@@ -566,7 +568,9 @@ def vmap_wrapper(*args):
args, prefix=in_axes, split_fn=_vmap_split_fn, ctxtag='pmap'
)
pure_args_out, pure_out = pmapped_fn(*pure_args)
- _args_out, out = extract.from_tree((pure_args_out, pure_out), ctxtag='pmap')
+ _args_out, out = extract.from_tree(
+ (pure_args_out, pure_out), ctxtag='pmap', is_inner=False
+ )
return out
return vmap_wrapper # type: ignore
@@ -648,21 +652,17 @@ def check_carry_same_references(key_path, arg, out):
check_carry_same_references, carry_arg, carry_arg_out
)
-def _extract_index_mappings(
- pure_carry_arg_out,
- carry_index_mappings: list[graph.HashableMapping[int, int]],
- /,
+def _extract_nodedefs(
+ pure_carry_arg_out, carry_nodedefs: list[graph.NodeDef], /
):
def extract_index_mappings(x):
if isinstance(x, extract.NodeStates) and isinstance(
x._graphdef, graph.NodeDef
):
- index_mapping = x._graphdef.index_mapping
- assert index_mapping is not None
- carry_index_mappings.append(index_mapping)
- x = x.replace(
- _graphdef=dataclasses.replace(x._graphdef, index_mapping=None)
- )
+ nodedef = x._graphdef
+ assert nodedef.outer_index is not None
+ carry_nodedefs.append(nodedef)
+ x = x.replace(_graphdef=nodedef.with_no_outer_index())
return x
pure_carry_arg_out = jax.tree.map(
@@ -673,19 +673,17 @@ def extract_index_mappings(x):
return pure_carry_arg_out
-def _insert_index_mappings(
+def _insert_nodedefs(
pure_carry_arg_out,
- carry_index_mappings: deque[graph.HashableMapping[int, int]],
+ carry_nodedefs: deque[graph.NodeDef],
/,
):
def insert_index_mappings(x):
if isinstance(x, extract.NodeStates) and isinstance(
x._graphdef, graph.NodeDef
):
- index_mapping = carry_index_mappings.popleft()
- x = x.replace(
- _graphdef=dataclasses.replace(x._graphdef, index_mapping=index_mapping)
- )
+ nodedef = carry_nodedefs.popleft()
+ x = x.replace(_graphdef=nodedef)
return x
pure_carry_arg_out = jax.tree.map(
@@ -1017,6 +1015,7 @@ def __call__(
is_leaf=lambda x: isinstance(x, (extract.NodeStates, Broadcasted)),
map_non_graph_nodes=True,
ctxtag='scan',
+ is_inner=True,
)
assert not carry_deque and not broadcast_deque and not broadcast_arrays
@@ -1096,10 +1095,8 @@ def __call__(
# next we have to remove all the index_mappings from the NodeDefs
# in the carry outputs because they are not present in the inputs
- carry_index_mappings: list[graph.HashableMapping[int, int]] = []
- pure_carry_arg_out = _extract_index_mappings(
- pure_carry_arg_out, carry_index_mappings
- )
+ carry_nodedefs: list[graph.NodeDef] = []
+ pure_carry_arg_out = _extract_nodedefs(pure_carry_arg_out, carry_nodedefs)
carry_arg_out = (
pure_carry_arg_out,
@@ -1108,7 +1105,7 @@ def __call__(
broadcast_arrays_out,
)
scan_out = (
- graph.Static(tuple(carry_index_mappings)),
+ carry_nodedefs,
pure_args_out,
pure_out,
)
@@ -1248,16 +1245,15 @@ def scan_wrapper(*args, **kwargs):
broadcast_arrays_out,
) = carry_out
(
- static_carry_index_mappings,
+ carry_nodedefs,
pure_args_out,
pure_out,
) = scan_out
# next we have to insert all the index_mappings back into the NodeDefs
# in the carry outputs
- carry_index_mappings = deque(static_carry_index_mappings.value)
- pure_carry_arg_out = _insert_index_mappings(
- pure_carry_arg_out, carry_index_mappings
+ pure_carry_arg_out = _insert_nodedefs(
+ pure_carry_arg_out, deque(carry_nodedefs)
)
# insert pure carry into pure_args_out
@@ -1280,6 +1276,7 @@ def scan_wrapper(*args, **kwargs):
is_leaf=lambda x: isinstance(x, (extract.NodeStates, Broadcasted)),
map_non_graph_nodes=True,
ctxtag='scan',
+ is_inner=False,
)
# extract the carry from args_out
@@ -1329,36 +1326,14 @@ def __call__(self, pure_val):
def _add_fake_index_mapping(tree: tp.Any):
- global_index_mapping = {} # for the whole context, over all inputs
- def per_node_state(ns: extract.NodeStates | tp.Any):
- if not isinstance(ns, extract.NodeStates) or not isinstance(
- ns._graphdef, graph.NodeDef
+ def per_node_state(node_state: extract.NodeStates | tp.Any):
+ if not isinstance(node_state, extract.NodeStates) or not isinstance(
+ node_state._graphdef, graph.NodeDef
):
- return ns
-
- def per_node_def(nd: graph.NodeDef | graph.NodeRef):
- if nd.index >= 0:
- global_index_mapping[nd.index] = nd.index
- if isinstance(nd, graph.NodeRef):
- return
-
- for attribute in nd.attributes:
- if type(attribute) is graph.SubGraphAttribute:
- per_node_def(attribute.value)
- elif (
- type(attribute) is graph.LeafAttribute
- and isinstance(attribute.value, (graph.VariableDef, graph.NodeRef))
- and attribute.value.index >= 0
- ):
- global_index_mapping[attribute.value.index] = attribute.value.index
- return
-
- per_node_def(ns._graphdef)
+ return node_state
+
return dataclasses.replace(
- ns,
- _graphdef=dataclasses.replace(
- ns._graphdef, index_mapping=graph.HashableMapping(global_index_mapping)
- ),
+ node_state, _graphdef=node_state._graphdef.with_same_outer_index()
)
return jax.tree.map(per_node_state, tree,
@@ -1366,16 +1341,18 @@ def per_node_def(nd: graph.NodeDef | graph.NodeRef):
def _remove_index_mapping(tree: tp.Any):
- '''Remove a fake index_mapping for the input to match that of the output.'''
- def per_node_state(ns: extract.NodeStates | tp.Any):
- if not isinstance(ns, extract.NodeStates) or not isinstance(
- ns._graphdef, graph.NodeDef
+ """Remove a fake outer_index for the input to match that of the output."""
+
+ def per_node_state(node_state: extract.NodeStates | tp.Any):
+ if not isinstance(node_state, extract.NodeStates) or not isinstance(
+ node_state._graphdef, graph.NodeDef
):
- return ns
- assert isinstance(ns._graphdef, graph.NodeDef)
- return dataclasses.replace(ns, _graphdef=dataclasses.replace(
- ns._graphdef, index_mapping=None
- ))
+ return node_state
+ assert isinstance(node_state._graphdef, graph.NodeDef)
+ node_state = dataclasses.replace(
+ node_state, _graphdef=node_state._graphdef.with_no_outer_index()
+ )
+ return node_state
return jax.tree.map(per_node_state, tree,
is_leaf=lambda x: isinstance(x, extract.NodeStates))
@@ -1393,19 +1370,23 @@ def __call__(self, pure_val):
# Removing the dummy index mapping being added outside of body function.
pure_val_in = _remove_index_mapping(pure_val)
- val = extract.from_tree(pure_val_in, ctxtag='while_loop_body')
+ val = extract.from_tree(
+ pure_val_in, ctxtag='while_loop_body', is_inner=True
+ )
out = self.f(val)
pure_out = extract.to_tree(out, ctxtag='while_loop_body')
try:
jax.tree.map(lambda a, b: None, pure_val, pure_out)
except ValueError as e:
- msg = ("nnx.while_loop requires body function's input and output to "
- "have the same reference and pytree structure, but they differ. "
- "If the mismatch comes from `index_mapping` field, you might "
- "have modified reference structure within the body function, "
- "which is not allowed."
- f"Detail of the mismatch: \n {str(e)}")
+ msg = (
+ "nnx.while_loop requires body function's input and output to "
+ 'have the same reference and pytree structure, but they differ. '
+ 'If the mismatch comes from `outer_index` field, you might '
+ 'have modified reference structure within the body function, '
+ 'which is not allowed.'
+ f'Detail of the mismatch: \n {str(e)}'
+ )
raise ValueError(msg)
return pure_out
@@ -1456,7 +1437,7 @@ def while_loop(cond_fun: tp.Callable[[T], tp.Any],
WhileLoopBodyFn(body_fun),
pure_init_val,
)
- out = extract.from_tree(pure_out, ctxtag='while_loop')
+ out = extract.from_tree(pure_out, ctxtag='while_loop', is_inner=False)
return out
@@ -1472,19 +1453,21 @@ def __call__(self, i, pure_val):
# Removing the dummy index mapping being added outside of body function.
pure_val_in = _remove_index_mapping(pure_val)
- val = extract.from_tree(pure_val_in, ctxtag='fori_loop_body')
+ val = extract.from_tree(pure_val_in, ctxtag='fori_loop_body', is_inner=True)
out = self.f(i, val)
pure_out = extract.to_tree(out, ctxtag='fori_loop_body')
try:
jax.tree.map(lambda a, b: None, pure_val, pure_out)
except ValueError as e:
- msg = ("nnx.fori_loop requires body function's input and output to "
- "have the same reference and pytree structure, but they differ. "
- "If the mismatch comes from `index_mapping` field, you might "
- "have modified reference structure within the body function, "
- "which is not allowed. "
- f"Detail of the mismatch: \n {str(e)}")
+ msg = (
+ "nnx.fori_loop requires body function's input and output to "
+ 'have the same reference and pytree structure, but they differ. '
+ 'If the mismatch comes from `outer_index` field, you might '
+ 'have modified reference structure within the body function, '
+ 'which is not allowed. '
+ f'Detail of the mismatch: \n {str(e)}'
+ )
raise ValueError(msg)
return pure_out
@@ -1545,5 +1528,5 @@ def fori_loop(lower: int, upper: int,
pure_out = jax.lax.fori_loop(lower, upper,
ForiLoopBodyFn(body_fun), pure_init_val,
unroll=unroll)
- out = extract.from_tree(pure_out, ctxtag='fori_loop')
+ out = extract.from_tree(pure_out, ctxtag='fori_loop', is_inner=False)
return out
diff --git a/flax/nnx/transforms/transforms.py b/flax/nnx/transforms/transforms.py
index 8a83a026d..3192b31aa 100644
--- a/flax/nnx/transforms/transforms.py
+++ b/flax/nnx/transforms/transforms.py
@@ -160,7 +160,7 @@ def __post_init__(self):
def __call__(self, *pure_args, **pure_kwargs):
args, kwargs = extract.from_tree(
- (pure_args, pure_kwargs), ctxtag='checkify'
+ (pure_args, pure_kwargs), ctxtag='checkify', is_inner=True
)
out = self.f(*args, **kwargs)
@@ -216,6 +216,7 @@ def jit_wrapper(*args, **kwargs):
args_out, kwargs_out, out = extract.from_tree(
(pure_args_out, pure_kwargs_out, pure_out),
ctxtag='checkify',
+ is_inner=False,
)
return error, out
diff --git a/flax/nnx/variablelib.py b/flax/nnx/variablelib.py
index b2c066096..2908c074a 100644
--- a/flax/nnx/variablelib.py
+++ b/flax/nnx/variablelib.py
@@ -47,7 +47,6 @@
VariableTypeCache: dict[str, tp.Type[Variable[tp.Any]]] = {}
-
@dataclasses.dataclass
class VariableMetadata(tp.Generic[A]):
raw_value: A
@@ -125,6 +124,8 @@ class Variable(tp.Generic[A], reprlib.Representable):
})
"""
+ __slots__ = ('raw_value', '_trace_state', '_var_metadata')
+
raw_value: A
_trace_state: tracers.TraceState
_var_metadata: dict[str, tp.Any]
@@ -134,9 +135,8 @@ def __init__(
value: tp.Union[A, VariableMetadata[A]],
**metadata: tp.Any,
):
- type_vars = vars(type(self))
- vars_self = vars(self)
- vars_self['_trace_state'] = tracers.TraceState()
+ var_t = type(self)
+ object.__setattr__(self, '_trace_state', tracers.TraceState())
if isinstance(value, VariableMetadata):
metadata.update(value.metadata)
@@ -144,27 +144,28 @@ def __init__(
object.__setattr__(self, 'raw_value', value)
- if 'on_get_value' in type_vars and 'on_get_value' not in metadata:
- metadata['get_value'] = getattr(type(self), 'on_get_value')
+ if hasattr(var_t, 'on_get_value') and 'on_get_value' not in metadata:
+ metadata['get_value'] = var_t.on_get_value
- if 'on_set_value' in type_vars and 'on_set_value' not in metadata:
- metadata['set_value'] = getattr(type(self), 'on_set_value')
+ if hasattr(var_t, 'on_set_value') and 'on_set_value' not in metadata:
+ metadata['set_value'] = var_t.on_set_value
- if 'on_create_value' in type_vars and 'on_create_value' not in metadata:
- metadata['create_value'] = getattr(type(self), 'on_create_value')
+ if hasattr(var_t, 'on_create_value') and 'on_create_value' not in metadata:
+ metadata['create_value'] = var_t.on_create_value
- if 'on_add_axis' in type_vars and 'on_add_axis' not in metadata:
- metadata['add_axis'] = getattr(type(self), 'on_add_axis')
+ if hasattr(var_t, 'on_add_axis') and 'on_add_axis' not in metadata:
+ metadata['add_axis'] = var_t.on_add_axis
- if 'on_remove_axis' in type_vars and 'on_remove_axis' not in metadata:
- metadata['remove_axis'] = getattr(type(self), 'on_remove_axis')
+ if hasattr(var_t, 'on_remove_axis') and 'on_remove_axis' not in metadata:
+ metadata['remove_axis'] = var_t.on_remove_axis
- vars_self['_var_metadata'] = metadata
+ object.__setattr__(self, '_var_metadata', metadata)
# run create_value hooks
- vars_self['raw_value'] = self.create_value(self.raw_value)
+ object.__setattr__(self, 'raw_value', self.create_value(self.raw_value))
+
def __getattr__(self, name: str) -> tp.Any:
- if name in vars(self)['_var_metadata']:
+ if name in object.__getattribute__(self, '_var_metadata'):
return self._var_metadata[name]
return getattr(self.value, name)
@@ -220,9 +221,10 @@ def copy_from(self, other: Variable[A]) -> None:
self._var_metadata.update(other.get_metadata())
def update_from_state(self, variable_state: VariableState[A]):
- vars_self = vars(self)
- vars_self['raw_value'] = variable_state.value
- vars_self['_var_metadata'] = variable_state._var_metadata.copy()
+ object.__setattr__(self, 'raw_value', variable_state.value)
+ object.__setattr__(
+ self, '_var_metadata', variable_state._var_metadata.copy()
+ )
@property
def value(self) -> A:
@@ -239,7 +241,7 @@ def value(self, value: A):
)
if 'on_set_value' in self._var_metadata:
value = self._var_metadata['on_set_value'](self, value)
- vars(self)['raw_value'] = value
+ object.__setattr__(self, 'raw_value', value)
def create_value(self, value: A):
if 'on_create_value' in self._var_metadata:
@@ -254,9 +256,6 @@ def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None):
if 'on_remove_axis' in self._var_metadata:
self._var_metadata['on_remove_axis'](self, axis_index, axis_name)
- def __eq__(self, other: object) -> bool:
- return type(self) is type(other) and vars(other) == vars(self)
-
@tp.overload
def replace(self, value: B, **kwargs) -> Variable[B]: ...
@@ -369,10 +368,16 @@ def __jax_array__(self):
# pickle support
def __getstate__(self):
- return vars(self).copy()
+ return {
+ 'raw_value': self.raw_value,
+ '_trace_state': self._trace_state,
+ '_var_metadata': self._var_metadata,
+ }
def __setstate__(self, state):
- vars(self).update(state)
+ object.__setattr__(self, 'raw_value', state['raw_value'])
+ object.__setattr__(self, '_trace_state', state['_trace_state'])
+ object.__setattr__(self, '_var_metadata', state['_var_metadata'])
# --------------------------------------------
# proxy methods
@@ -841,6 +846,7 @@ def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None):
if 'on_remove_axis' in self._var_metadata:
self._var_metadata['on_remove_axis'](self, axis_index, axis_name)
+GraphVariableState = VariableState[VariableState[tp.Any]]
def _variable_state_flatten(x: VariableState[tp.Any], *, with_keys: bool):
metadata = tuple(x.get_metadata().items())
@@ -944,7 +950,7 @@ def wrapper(*args):
def split_flat_state(
flat_state: tp.Iterable[tuple[PathParts, Variable | VariableState]],
- filters: tp.Sequence[filterlib.Filter],
+ filters: tuple[filterlib.Filter, ...],
) -> tuple[list[tuple[PathParts, Variable | VariableState]], ...]:
predicates = filterlib.filters_to_predicates(filters)
# we have n + 1 states, where n is the number of predicates
diff --git a/flax/typing.py b/flax/typing.py
index 0ae990d95..0f694383f 100644
--- a/flax/typing.py
+++ b/flax/typing.py
@@ -168,11 +168,11 @@ class Missing:
def _bytes_repr(num_bytes):
count, units = (
- (f'{num_bytes / 1e9 :,.1f}', 'GB')
+ (f'{num_bytes / 1e9:,.1f}', 'GB')
if num_bytes > 1e9
- else (f'{num_bytes / 1e6 :,.1f}', 'MB')
+ else (f'{num_bytes / 1e6:,.1f}', 'MB')
if num_bytes > 1e6
- else (f'{num_bytes / 1e3 :,.1f}', 'KB')
+ else (f'{num_bytes / 1e3:,.1f}', 'KB')
if num_bytes > 1e3
else (f'{num_bytes:,}', 'B')
)
diff --git a/flaxlib_src/CMakeLists.txt b/flaxlib_src/CMakeLists.txt
new file mode 100644
index 000000000..a5a61b5b2
--- /dev/null
+++ b/flaxlib_src/CMakeLists.txt
@@ -0,0 +1,54 @@
+# Set the minimum CMake version and policies for highest tested version
+cmake_minimum_required(VERSION 3.15...3.27)
+
+# Set up the project and ensure there is a working C++ compiler
+project(flaxlib LANGUAGES CXX)
+
+# Warn if the user invokes CMake directly
+if (NOT SKBUILD)
+ message(WARNING "\
+ This CMake file is meant to be executed using 'scikit-build-core'.
+ Running it directly will almost certainly not produce the desired
+ result. If you are a user trying to install this package, use the
+ command below, which will install all necessary build dependencies,
+ compile the package in an isolated environment, and then install it.
+ =====================================================================
+ $ pip install .
+ =====================================================================
+ If you are a software developer, and this is your own package, then
+ it is usually much more efficient to install the build dependencies
+ in your environment once and use the following command that avoids
+ a costly creation of a new virtual environment at every compilation:
+ =====================================================================
+ $ pip install nanobind scikit-build-core[pyproject]
+ $ pip install --no-build-isolation -ve .
+ =====================================================================
+ You may optionally add -Ceditable.rebuild=true to auto-rebuild when
+ the package is imported. Otherwise, you need to rerun the above
+ after editing C++ files.")
+endif()
+
+# Try to import all Python components potentially needed by nanobind
+find_package(Python 3.8
+ REQUIRED COMPONENTS Interpreter Development.Module
+ OPTIONAL_COMPONENTS Development.SABIModule)
+
+# Import nanobind through CMake's find_package mechanism
+find_package(nanobind CONFIG REQUIRED)
+
+# We are now ready to compile the actual extension module
+nanobind_add_module(
+ # Name of the extension
+ flaxlib_cpp
+
+ # Target the stable ABI for Python 3.12+, which reduces
+ # the number of binary wheels that must be built. This
+ # does nothing on older Python versions
+ STABLE_ABI
+
+ # Source code goes here
+ src/lib.cc
+)
+
+# Install directive for scikit-build-core
+install(TARGETS flaxlib_cpp LIBRARY DESTINATION flaxlib)
\ No newline at end of file
diff --git a/flaxlib_src/meson.build b/flaxlib_src/meson.build
deleted file mode 100644
index 0d78d9436..000000000
--- a/flaxlib_src/meson.build
+++ /dev/null
@@ -1,14 +0,0 @@
-project(
- 'flaxlib',
- 'cpp',
- version: '0.0.1',
- default_options: ['cpp_std=c++17'],
-)
-py = import('python').find_installation()
-nanobind_dep = dependency('nanobind', static: true)
-py.extension_module(
- 'flaxlib',
- sources: ['src/lib.cc'],
- dependencies: [nanobind_dep],
- install: true,
-)
\ No newline at end of file
diff --git a/flaxlib_src/pyproject.toml b/flaxlib_src/pyproject.toml
index 0afc7699a..fd6c0b61b 100644
--- a/flaxlib_src/pyproject.toml
+++ b/flaxlib_src/pyproject.toml
@@ -1,17 +1,28 @@
[build-system]
-requires = ['meson-python']
-build-backend = 'mesonpy'
+requires = ["scikit-build-core >=0.4.3", "nanobind >=1.3.2"]
+build-backend = "scikit_build_core.build"
[project]
name = "flaxlib"
+version = "0.0.1"
requires-python = ">=3.10"
classifiers = [
"Programming Language :: C++",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
-dynamic = ["version"]
+
[project.optional-dependencies]
tests = [
"pytest",
]
+
+[tool.scikit-build]
+# Protect the configuration against future changes in scikit-build-core
+minimum-version = "0.4"
+
+# Setuptools-style build caching in a local directory
+build-dir = "build/{wheel_tag}"
+
+# Build stable ABI wheels for CPython 3.12+
+wheel.py-api = "cp312"
\ No newline at end of file
diff --git a/flaxlib_src/flaxlib.pyi b/flaxlib_src/src/flaxlib/__init__.py
similarity index 84%
rename from flaxlib_src/flaxlib.pyi
rename to flaxlib_src/src/flaxlib/__init__.py
index 505fd3d0f..f45841771 100644
--- a/flaxlib_src/flaxlib.pyi
+++ b/flaxlib_src/src/flaxlib/__init__.py
@@ -12,4 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-def sum_as_string(a: int, b: int) -> str: ...
+from .flaxlib_cpp import RefMap as RefMap
+from .flaxlib_cpp import _graph_fingerprint as _graph_fingerprint
diff --git a/flaxlib_src/src/flaxlib/flaxlib_cpp.pyi b/flaxlib_src/src/flaxlib/flaxlib_cpp.pyi
new file mode 100644
index 000000000..03557efb9
--- /dev/null
+++ b/flaxlib_src/src/flaxlib/flaxlib_cpp.pyi
@@ -0,0 +1,25 @@
+# Copyright 2024 The Flax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import typing as tp
+
+RefMap = tp.MutableMapping[tp.Any, int]
+
+def _graph_fingerprint(
+ node,
+ node_impl,
+ ref_index: RefMap,
+ new_ref_index: RefMap,
+ next_index: int,
+) -> tuple[tuple[tp.Any, ...], int]: ...
\ No newline at end of file
diff --git a/flaxlib_src/src/lib.cc b/flaxlib_src/src/lib.cc
index c71458811..c91572703 100644
--- a/flaxlib_src/src/lib.cc
+++ b/flaxlib_src/src/lib.cc
@@ -1,14 +1,298 @@
+// Copyright 2024 The Flax Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include