diff --git a/flax/experimental/nnx/README.md b/flax/experimental/nnx/README.md index 7fb5972da1..c10c1fc486 100644 --- a/flax/experimental/nnx/README.md +++ b/flax/experimental/nnx/README.md @@ -65,12 +65,12 @@ In this example `nnx.Rngs(0)` create a `random.key` for `params` with seed `0`, The [Functional API](#functional-api) converts an NNX Module python semantics into pure pytree object with functional semantics. It is the recommended way to use NNX as it provides tight control over the state, allows you to use regular JAX transformations, and it minimizes overhead. In this example the model will be trained using Stochastic Gradient Descent (SGD). ```python -params, counts, moduledef = model.split(nnx.Param, Count) +params, counts, static = model.split(nnx.Param, Count) @jax.jit def train_step(params, counts, x, y): def loss_fn(params): - model = moduledef.merge(params, counts) + model = static.merge(params, counts) y_pred = model(x) loss = jax.numpy.mean((y_pred - y) ** 2) return loss, updates.extract(Count) @@ -84,7 +84,7 @@ def train_step(params, counts, x, y): # execute the training step params, counts = train_step(params, counts, x, y) -model = moduledef.merge(params, counts) +model = static.merge(params, counts) assert model.count == 2 ``` @@ -141,51 +141,52 @@ NNX Modules are normal python classes, they obey regular python semantics such a ```python class Foo(nnx.Module): - def __init__(self, rngs: nnx.Rngs): - # node attributes - self.variable = nnx.Param(jnp.array(1)) - self.implicit_param = jnp.array(3) - self.submodule = nnx.Linear(2, 4, rngs=rngs) - # static attributes - self.int = 1 - self.float = 2.0 - self.str = "hello" - self.list = [1, 2, 3] + def __init__(self, rngs: nnx.Rngs): + # node attributes + self.variable = nnx.Param(jnp.array(1)) + self.submodule = nnx.Linear(2, 3, rngs=rngs) + self.container = [4, nnx.Linear(5, 6, rngs=rngs), 7] + # static attributes + self.int = 8 + self.float = 9.0 + self.str = 'hello' model = Foo(din=12, dout=2, rngs=nnx.Rngs(0)) ``` -As shown above, python container types such as `list`, `tuple`, and `dict` are treated as static attributes, if similar functionality is needed, NNX provides the `Sequence` and `Dict` Modules. +As shown above, python container types such as `list`, `tuple`, and `dict` are treated as node attributes, +this means you can naturally have e.g. `list`s or `dict`s of Modules. ### Functional API -NNX Modules are not pytrees so they cannot be passed to JAX transformations. In order to interact with JAX, a Module must be partitioned into a `State` and `ModuleDef` objects. The `State` object is a flat dictionary-like pytree structure that contains all the deduplicated node attributes, and the `ModuleDef` contains the static attributes and structural information needed to reconstruct the Module. +NNX Modules are not pytrees so they cannot be passed to JAX transformations. In order to interact with JAX, a Module must be partitioned into a `State` and `GraphDef` objects. The `State` object is a flat dictionary-like pytree structure that contains all the deduplicated node attributes, and the `GraphDef` contains the static attributes and structural information needed to reconstruct the Module. ```python -state, moduledef = model.split() +state, static = model.split() ``` ``` -State({ - 'implicit_param',: Param(value=Array(3)), - 'submodule/bias': Param(value=Array(...)), - 'submodule/kernel': Param(value=Array(...)), - 'variable': Param(value=Array(1)) +state = State({ + 'variable': Array(1, dtype=int32, weak_type=True), + 'submodule/kernel': Array(..., dtype=float32), + 'submodule/bias': Array(..., dtype=float32), + 'container/1/kernel': Array(..., dtype=float32), + 'container/1/bias': Array(..., dtype=float32) }) ``` -`State` and `ModuleDef` are pytrees so they can be passed to JAX transformations. More over, `ModuleDef` provides 2 very important methods: `merge` and `apply`. The `merge` method can be used to create a new `Module` from a `State` object: +`State` and `GraphDef` are pytrees so they can be passed to JAX transformations. More over, `GraphDef` provides 2 very important methods: `merge` and `apply`. The `merge` method can be used to create a new `Module` from a `State` object: ```python -model = moduledef.merge(state) +model = static.merge(state) ``` This can be use to e.g. recreate a module inside a JAX transformation. The `apply` provides a functional interface to the module, it can be used call any method or submodule and get the output and the updated state: ```python # run __call__ -y, (state, moduledef) = moduledef.apply(state)(x) +y, (state, static) = static.apply(state)(x) # run some_method -y, (state, moduledef) = moduledef.apply(state).some_method(x) +y, (state, static) = static.apply(state).some_method(x) # run submodule -y, (state, moduledef) = moduledef.apply(state).submodule(x) +y, (state, static) = static.apply(state).submodule(x) ``` `apply` can call any nested method or submodule as long as it can be accessed via the `.` or `[]` operators. @@ -196,14 +197,14 @@ In NNX you can filter based on any node type, most commonly you will want to fil Here are various examples of how you can use the `split` method to split a module into multiple substates: ```python -# split the module into the state with all the nodes and the moduledef -state, moduledef = model.split() +# split the module into the state with all the nodes and the static information +state, static = model.split() # verify that the state contains only params, else raise an error -params, moduledef = model.split(nnx.Param) +params, static = model.split(nnx.Param) # split the state into params and batch_stats, verify no nodes are left -params, batch_stats, moduledef = model.split(nnx.Param, nnx.BatchStat) +params, batch_stats, static = model.split(nnx.Param, nnx.BatchStat) # if there are any nodes left, use the `...` filter to capture them -params, batch_stats, rest, moduledef = model.split(nnx.Param, nnx.BatchStat, ...) +params, batch_stats, rest, static = model.split(nnx.Param, nnx.BatchStat, ...) # using `...` as the only filter is equivalent to not passing any filters model.split(...) = model.split() ``` @@ -215,13 +216,13 @@ model.split(...) = model.split() To reconstruct the module from a set of substates, you can use `merge` as usual but passing the substates as additional arguments: ```python -model = moduledef.merge(params, batch_stats, rest) +model = static.merge(params, batch_stats, rest) ``` The same is true for `apply`. ```python -y, (state, moduledef) = moduledef.apply(params, batch_stats, rest)(x) +y, (state, static) = static.apply(params, batch_stats, rest)(x) ``` Note that `apply` will return a single `state` object, if you need to `split` the state you can use `State`'s own `split` method: @@ -295,8 +296,8 @@ State({ If you use the functional API to call the module instead, the `Intermediate` nodes will be present in the output `state`. To retrieve the `Intermediate` nodes and optionally separate them from the output `state` you can use `State.split`: ```python -state, moduledef = model.split() -y, (state, moduledef) = moduledef.apply(state)(jnp.ones((8, 12))) +state, static = model.split() +y, (state, static) = static.apply(state)(jnp.ones((8, 12))) # "pop" the intermediates from the state intermediates, state = state.split(nnx.Intermediate, ...) ``` @@ -344,10 +345,10 @@ class ScanMLP(nnx.Module): def __init__(self, dim: int, *, n_layers: int, rngs: nnx.Rngs): params_key = jax.random.split(rngs.params(), n_layers) self.n_layers = n_layers - state, moduledef = jax.vmap( + state, static = jax.vmap( lambda key: Block(dim, rngs=nnx.Rngs(params=key)).split() )(params_key) - self.layers = moduledef.merge(state) + self.layers = static.merge(state) ``` Note that we split the `params` key into `n_layers` keys so each layer has different parameters. @@ -359,11 +360,11 @@ apply it to the input `x`, passing the sliced `dropout_key` as part of the `Rngs ```python def __call__(self, x: jax.Array, *, train: bool, rngs: nnx.Rngs) -> jax.Array: dropout_key = jax.random.split(rngs.dropout(), self.n_layers) - params, moduledef = self.layers.split(nnx.Param) + params, static = self.layers.split(nnx.Param) def scan_fn(x: inputs): params, dropout_key = inputs - module = moduledef.merge(params) + module = static.merge(params) x = module(x, train=train, rngs=nnx.Rngs(dropout=dropout_key)) return x, module.extract(nnx.Param) diff --git a/flax/experimental/nnx/__init__.py b/flax/experimental/nnx/__init__.py index adfad6b761..efca7ba64e 100644 --- a/flax/experimental/nnx/__init__.py +++ b/flax/experimental/nnx/__init__.py @@ -18,19 +18,23 @@ from flax.linen.pooling import pool as pool from .nnx import compatibility as compatibility +from .nnx import graph_utils from .nnx.dataclasses import dataclass as dataclass from .nnx.dataclasses import field as field from .nnx.dataclasses import param_field as param_field from .nnx.dataclasses import treenode_field as treenode_field from .nnx.dataclasses import variable_field as variable_field from .nnx.errors import TraceContextError as TraceContextError +from .nnx.filterlib import All as All +from .nnx.filterlib import Not as Not from .nnx.flaglib import flags as flags +from .nnx.graph_utils import GraphDef as GraphDef from .nnx.helpers import Dict as Dict from .nnx.helpers import Sequence as Sequence from .nnx.helpers import TrainState as TrainState +from .nnx.module import GraphDef as GraphDef from .nnx.module import M as M from .nnx.module import Module as Module -from .nnx.module import ModuleDef as ModuleDef from .nnx.module import merge as merge from .nnx.nn import initializers as initializers from .nnx.nn.activations import celu as celu @@ -64,8 +68,6 @@ from .nnx.nn.normalization import BatchNorm as BatchNorm from .nnx.nn.normalization import LayerNorm as LayerNorm from .nnx.nn.stochastic import Dropout as Dropout -from .nnx.filterlib import All as All -from .nnx.filterlib import Not as Not from .nnx.pytreelib import Pytree as Pytree from .nnx.pytreelib import TreeNode as TreeNode from .nnx.rnglib import Rngs as Rngs diff --git a/flax/experimental/nnx/docs/quick_start.ipynb b/flax/experimental/nnx/docs/quick_start.ipynb index 540aec36f4..ad3be07d3c 100644 --- a/flax/experimental/nnx/docs/quick_start.ipynb +++ b/flax/experimental/nnx/docs/quick_start.ipynb @@ -335,7 +335,7 @@ "\n", "Now that we have a working model, lets see how to train it with `jax.jit` using NNX's Functional API. The `Module.split` method allows you to convert a Module into pytrees with functional semantics, this allows you to integrate with JAX's functional APIs like `jax.jit` and `jax.grad`.\n", "\n", - "In this next example we will use the `.split` method to split the model into a `params: State` and `moduledef: ModuleDef` objects. We pass the `\"params\"` filter to check that the Module's state only contain `Variables` with the `params` collection. Having `params` and `moduledef` its pretty easy to implement a jitted `train_step` much like you would in Flax or Haiku. `ModuleDef` exposes an `apply` method which accepts some `State` and creates a function that runs the Module's `__call__` method. This function then returns the output of the Module along with the updated state." + "In this next example we will use the `.split` method to split the model into a `params: State` and `static: GraphDef` objects. We pass the `\"params\"` filter to check that the Module's state only contain `Variables` with the `params` collection. Having `params` and `static` its pretty easy to implement a jitted `train_step` much like you would in Flax or Haiku. `GraphDef` exposes an `apply` method which accepts some `State` and creates a function that runs the Module's `__call__` method. This function then returns the output of the Module along with the updated state." ] }, { @@ -344,13 +344,13 @@ "metadata": {}, "outputs": [], "source": [ - "params, moduledef = model.split(\"params\")\n", + "params, static = model.split(\"params\")\n", "\n", "\n", "@jax.jit\n", "def train_step(params: nnx.State, x, y):\n", " def loss_fn(params):\n", - " logits, _updates = moduledef.apply(params)(x)\n", + " logits, _updates = static.apply(params)(x)\n", " return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()\n", "\n", " loss, grads = jax.value_and_grad(loss_fn)(params)\n", @@ -420,7 +420,7 @@ "outputs": [], "source": [ "state = nnx.TrainState(\n", - " apply_fn=moduledef.apply,\n", + " apply_fn=static.apply,\n", " params=params,\n", " tx=optax.adam(0.001),\n", ")\n", diff --git a/flax/experimental/nnx/docs/tiny_nnx.ipynb b/flax/experimental/nnx/docs/tiny_nnx.ipynb index 05d35fd34a..b3944ae600 100644 --- a/flax/experimental/nnx/docs/tiny_nnx.ipynb +++ b/flax/experimental/nnx/docs/tiny_nnx.ipynb @@ -15,21 +15,22 @@ }, { "cell_type": "code", - "execution_count": 67, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ + "import dataclasses\n", "import hashlib\n", "import typing as tp\n", + "\n", "import jax\n", "import jax.numpy as jnp\n", "from jax import random\n", - "import dataclasses\n", "\n", "A = tp.TypeVar(\"A\")\n", "M = tp.TypeVar(\"M\", bound=\"Module\")\n", "Sharding = tp.Tuple[tp.Optional[str], ...]\n", - "Array = random.Array\n", + "Array = jax.Array\n", "\n", "\n", "class Variable(tp.Generic[A]):\n", @@ -53,13 +54,13 @@ " jax.tree_util.register_pytree_node(\n", " cls,\n", " lambda x: ((x.value,), (x.sharding,)),\n", - " lambda metadata, value: Variable(value[0], sharding=metadata[0]),\n", + " lambda metadata, value: cls(value[0], sharding=metadata[0]),\n", " )\n", "\n", "\n", "class State(dict[str, Variable[tp.Any]]):\n", "\n", - " def filter(self, variable_type: tp.Type[Variable]) -> \"State\":\n", + " def extract(self, variable_type: tp.Type[Variable]) -> \"State\":\n", " return State(\n", " {\n", " path: variable\n", @@ -85,42 +86,42 @@ "\n", "\n", "@dataclasses.dataclass\n", - "class ModuleDef(tp.Generic[M]):\n", + "class GraphDef(tp.Generic[M]):\n", " type: tp.Type[M]\n", " index: int\n", - " submodules: tp.Dict[str, tp.Union[\"ModuleDef[Module]\", int]]\n", - " static_fields: tp.Dict[str, tp.Any]\n", + " submodules: dict[str, tp.Union[\"GraphDef[Module]\", int]]\n", + " static_fields: dict[str, tp.Any]\n", "\n", " def merge(self, state: State) -> M:\n", - " module = ModuleDef._build_module_recursive(self, {})\n", + " module = GraphDef._build_module_recursive(self, {})\n", " module.update(state)\n", " return module\n", "\n", " @staticmethod\n", " def _build_module_recursive(\n", - " moduledef: tp.Union[\"ModuleDef[M]\", int],\n", - " index_to_module: tp.Dict[int, \"Module\"],\n", + " graphdef: tp.Union[\"GraphDef[M]\", int],\n", + " index_to_module: dict[int, \"Module\"],\n", " ) -> M:\n", - " if isinstance(moduledef, int):\n", - " return index_to_module[moduledef] # type: ignore\n", + " if isinstance(graphdef, int):\n", + " return index_to_module[graphdef] # type: ignore\n", "\n", - " assert moduledef.index not in index_to_module\n", + " assert graphdef.index not in index_to_module\n", "\n", " # add a dummy module to the index to avoid infinite recursion\n", - " module = object.__new__(moduledef.type)\n", - " index_to_module[moduledef.index] = module\n", + " module = object.__new__(graphdef.type)\n", + " index_to_module[graphdef.index] = module\n", "\n", " submodules = {\n", - " name: ModuleDef._build_module_recursive(submodule, index_to_module)\n", - " for name, submodule in moduledef.submodules.items()\n", + " name: GraphDef._build_module_recursive(submodule, index_to_module)\n", + " for name, submodule in graphdef.submodules.items()\n", " }\n", - " vars(module).update(moduledef.static_fields)\n", + " vars(module).update(graphdef.static_fields)\n", " vars(module).update(submodules)\n", " return module\n", "\n", " def apply(\n", " self, state: State\n", - " ) -> tp.Callable[..., tuple[tp.Any, tuple[State, \"ModuleDef[M]\"]]]:\n", + " ) -> tp.Callable[..., tuple[tp.Any, tuple[State, \"GraphDef[M]\"]]]:\n", " def _apply(*args, **kwargs):\n", " module = self.merge(state)\n", " out = module(*args, **kwargs) # type: ignore\n", @@ -131,21 +132,21 @@ "\n", "class Module:\n", "\n", - " def split(self: M) -> tp.Tuple[State, ModuleDef[M]]:\n", + " def split(self: M) -> tp.Tuple[State, GraphDef[M]]:\n", " state = State()\n", - " moduledef = Module._partition_recursive(\n", + " graphdef = Module._partition_recursive(\n", " module=self, module_id_to_index={}, path_parts=(), state=state\n", " )\n", - " assert isinstance(moduledef, ModuleDef)\n", - " return state, moduledef\n", + " assert isinstance(graphdef, GraphDef)\n", + " return state, graphdef\n", "\n", " @staticmethod\n", " def _partition_recursive(\n", " module: M,\n", - " module_id_to_index: tp.Dict[int, int],\n", + " module_id_to_index: dict[int, int],\n", " path_parts: tp.Tuple[str, ...],\n", " state: State,\n", - " ) -> tp.Union[ModuleDef[M], int]:\n", + " ) -> tp.Union[GraphDef[M], int]:\n", " if id(module) in module_id_to_index:\n", " return module_id_to_index[id(module)]\n", "\n", @@ -167,17 +168,17 @@ " # if value is a Variable, add to state\n", " elif isinstance(value, Variable):\n", " state[\"/\".join(value_path)] = value\n", - " else: # otherwise, add to static fields\n", + " else: # otherwise, add to graphdef fields\n", " static_fields[name] = value\n", "\n", - " return ModuleDef(\n", + " return GraphDef(\n", " type=type(module),\n", " index=index,\n", " submodules=submodules,\n", " static_fields=static_fields,\n", " )\n", "\n", - " def update_state(self, state: State) -> None:\n", + " def update(self, state: State) -> None:\n", " for path, value in state.items():\n", " path_parts = path.split(\"/\")\n", " Module._set_value_at_path(self, path_parts, value)\n", @@ -242,7 +243,7 @@ }, { "cell_type": "code", - "execution_count": 84, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -314,7 +315,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -340,27 +341,27 @@ "\n", " # lift init\n", " key = random.split(rngs.make_rng(), n_layers - 1)\n", - " moduledef: ModuleDef[Block] = None # type: ignore\n", + " graphdef: GraphDef[Block] = None # type: ignore\n", "\n", " def init_fn(key):\n", - " nonlocal moduledef\n", - " state, moduledef = Block(\n", + " nonlocal graphdef\n", + " state, graphdef = Block(\n", " hidden_size, hidden_size, rngs=Rngs(key)\n", " ).split()\n", " return state\n", "\n", " state = jax.vmap(init_fn)(key)\n", - " self.layers = moduledef.merge(state)\n", + " self.layers = graphdef.merge(state)\n", " self.linear = Linear(hidden_size, hidden_size, rngs=rngs)\n", "\n", " def __call__(self, x: jax.Array, *, train: bool, rngs: Rngs) -> jax.Array:\n", " # lift call\n", " key: jax.Array = random.split(rngs.make_rng(), self.n_layers - 1) # type: ignore\n", - " state, moduledef = self.layers.split()\n", + " state, graphdef = self.layers.split()\n", "\n", " def scan_fn(x, inputs: tuple[jax.Array, State]):\n", " key, state = inputs\n", - " x, (state, _) = moduledef.apply(state)(x, train=train, rngs=Rngs(key))\n", + " x, (state, _) = graphdef.apply(state)(x, train=train, rngs=Rngs(key))\n", " return x, state\n", "\n", " x, state = jax.lax.scan(scan_fn, x, (key, state))\n", @@ -371,7 +372,7 @@ }, { "cell_type": "code", - "execution_count": 87, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -379,16 +380,16 @@ "output_type": "stream", "text": [ "state = State({\n", - " 'layers/bn/bias': Variable(value=(4, 10), collection=params, sharding=None),\n", - " 'layers/bn/mean': Variable(value=(4, 1, 10), collection=batch_stats, sharding=None),\n", - " 'layers/bn/scale': Variable(value=(4, 10), collection=params, sharding=None),\n", - " 'layers/bn/var': Variable(value=(4, 1, 10), collection=batch_stats, sharding=None),\n", - " 'layers/linear/b': Variable(value=(4, 10), collection=params, sharding=None),\n", - " 'layers/linear/w': Variable(value=(4, 10, 10), collection=params, sharding=None),\n", - " 'linear/b': Variable(value=(10,), collection=params, sharding=None),\n", - " 'linear/w': Variable(value=(10, 10), collection=params, sharding=None)\n", + " 'layers/bn/bias': Param(value=(4, 10), sharding=None),\n", + " 'layers/bn/mean': BatchStat(value=(4, 10), sharding=None),\n", + " 'layers/bn/scale': Param(value=(4, 10), sharding=None),\n", + " 'layers/bn/var': BatchStat(value=(4, 10), sharding=None),\n", + " 'layers/linear/b': Param(value=(4, 10), sharding=None),\n", + " 'layers/linear/w': Param(value=(4, 10, 10), sharding=None),\n", + " 'linear/b': Param(value=(10,), sharding=None),\n", + " 'linear/w': Param(value=(10, 10), sharding=None)\n", "})\n", - "moduledef = ModuleDef(type=, index=0, submodules={'layers': ModuleDef(type=, index=1, submodules={'bn': ModuleDef(type=, index=2, submodules={}, static_fields={'mu': 0.95}), 'dropout': ModuleDef(type=, index=3, submodules={}, static_fields={'rate': 0.1}), 'linear': ModuleDef(type=, index=4, submodules={}, static_fields={'din': 10, 'dout': 10})}, static_fields={}), 'linear': ModuleDef(type=, index=5, submodules={}, static_fields={'din': 10, 'dout': 10})}, static_fields={'n_layers': 5})\n" + "graphdef = GraphDef(type=, index=0, submodules={'layers': GraphDef(type=, index=1, submodules={'bn': GraphDef(type=, index=2, submodules={}, static_fields={'mu': 0.95}), 'dropout': GraphDef(type=, index=3, submodules={}, static_fields={'rate': 0.1}), 'linear': GraphDef(type=, index=4, submodules={}, static_fields={'din': 10, 'dout': 10})}, static_fields={}), 'linear': GraphDef(type=, index=5, submodules={}, static_fields={'din': 10, 'dout': 10})}, static_fields={'n_layers': 5})\n" ] } ], @@ -397,9 +398,9 @@ "x = jax.random.normal(random.key(0), (2, 10))\n", "y = module(x, train=True, rngs=Rngs(random.key(1)))\n", "\n", - "state, moduledef = module.split()\n", + "state, graphdef = module.split()\n", "print(\"state =\", jax.tree_map(jnp.shape, state))\n", - "print(\"moduledef =\", moduledef)" + "print(\"graphdef =\", graphdef)" ] }, { @@ -412,7 +413,7 @@ }, { "cell_type": "code", - "execution_count": 89, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -420,16 +421,16 @@ "output_type": "stream", "text": [ "params = State({\n", - " 'layers/bn/bias': Variable(value=(4, 10), collection=params, sharding=None),\n", - " 'layers/bn/scale': Variable(value=(4, 10), collection=params, sharding=None),\n", - " 'layers/linear/b': Variable(value=(4, 10), collection=params, sharding=None),\n", - " 'layers/linear/w': Variable(value=(4, 10, 10), collection=params, sharding=None),\n", - " 'linear/b': Variable(value=(10,), collection=params, sharding=None),\n", - " 'linear/w': Variable(value=(10, 10), collection=params, sharding=None)\n", + " 'layers/bn/bias': Param(value=(4, 10), sharding=None),\n", + " 'layers/bn/scale': Param(value=(4, 10), sharding=None),\n", + " 'layers/linear/b': Param(value=(4, 10), sharding=None),\n", + " 'layers/linear/w': Param(value=(4, 10, 10), sharding=None),\n", + " 'linear/b': Param(value=(10,), sharding=None),\n", + " 'linear/w': Param(value=(10, 10), sharding=None)\n", "})\n", "batch_stats = State({\n", - " 'layers/bn/mean': Variable(value=(4, 1, 10), collection=batch_stats, sharding=None),\n", - " 'layers/bn/var': Variable(value=(4, 1, 10), collection=batch_stats, sharding=None)\n", + " 'layers/bn/mean': BatchStat(value=(4, 10), sharding=None),\n", + " 'layers/bn/var': BatchStat(value=(4, 10), sharding=None)\n", "})\n" ] } @@ -457,7 +458,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.16" + "version": "3.10.13" } }, "nbformat": 4, diff --git a/flax/experimental/nnx/docs/why.ipynb b/flax/experimental/nnx/docs/why.ipynb index 0fc10c0168..f5af116457 100644 --- a/flax/experimental/nnx/docs/why.ipynb +++ b/flax/experimental/nnx/docs/why.ipynb @@ -250,7 +250,7 @@ "\n", "NNX has two very simple APIs to interact with JAX: `split` and `merge`.\n", "\n", - "The `Module.split` method allows you to convert into a `State` dict-like object that contains the dynamic state of the Module, and a `ModuleDef` object that contains the static structure of the Module." + "The `Module.split` method allows you to convert into a `State` dict-like object that contains the dynamic state of the Module, and a `GraphDef` object that contains the static structure of the Module." ] }, { @@ -273,7 +273,7 @@ " [-0.06992685, -0.64693886, 0.20232596, 1.1200062 ]], dtype=float32)\n", "})\n", "\n", - "static = ModuleDef(\n", + "static = GraphDef(\n", " type=CounterLinear,\n", " index=0,\n", " static_fields=(),\n", @@ -281,7 +281,7 @@ " value=Empty\n", " )),),\n", " submodules=(\n", - " ('linear', ModuleDef(\n", + " ('linear', GraphDef(\n", " type=Linear,\n", " index=1,\n", " static_fields=(('bias_init', ), ('dot_general', ), ('dtype', None), ('in_features', 4), ('kernel_init', .init at 0x7f3dc9ad3370>), ('out_features', 4), ('param_dtype', ), ('precision', None), ('use_bias', True)),\n", @@ -313,7 +313,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The `ModuleDef.merge` method allows you to take a `ModuleDef` and one or more `State` objects and merge them back into a `Module` object.\n", + "The `GraphDef.merge` method allows you to take a `GraphDef` and one or more `State` objects and merge them back into a `Module` object.\n", "\n", "Using `split` and `merge` in conjunction allows you to carry your Module in and out of any JAX transformation. Here is a simple jitted `forward` function as an example:" ] @@ -336,7 +336,7 @@ ], "source": [ "@jax.jit\n", - "def forward(static: nnx.ModuleDef, state: nnx.State, x: jax.Array):\n", + "def forward(static: nnx.GraphDef, state: nnx.State, x: jax.Array):\n", " model = static.merge(state)\n", " y = model(x)\n", " state, _ = model.split()\n", diff --git a/flax/experimental/nnx/docs/why.md b/flax/experimental/nnx/docs/why.md index 2c262204b2..ee953c85f7 100644 --- a/flax/experimental/nnx/docs/why.md +++ b/flax/experimental/nnx/docs/why.md @@ -140,7 +140,7 @@ While NNX Modules inherently follow reference semantics, they can be easily conv NNX has two very simple APIs to interact with JAX: `split` and `merge`. -The `Module.split` method allows you to convert into a `State` dict-like object that contains the dynamic state of the Module, and a `ModuleDef` object that contains the static structure of the Module. +The `Module.split` method allows you to convert into a `State` dict-like object that contains the dynamic state of the Module, and a `GraphDef` object that contains the static structure of the Module. ```{code-cell} :outputId: 9a3f378b-739e-4f45-9968-574651200ede @@ -156,7 +156,7 @@ print(f'{state = }') print(f'\n{static = }') ``` -The `ModuleDef.merge` method allows you to take a `ModuleDef` and one or more `State` objects and merge them back into a `Module` object. +The `GraphDef.merge` method allows you to take a `GraphDef` and one or more `State` objects and merge them back into a `Module` object. Using `split` and `merge` in conjunction allows you to carry your Module in and out of any JAX transformation. Here is a simple jitted `forward` function as an example: @@ -164,7 +164,7 @@ Using `split` and `merge` in conjunction allows you to carry your Module in and :outputId: 0007d357-152a-449e-bcb9-b1b5a91d2d8d @jax.jit -def forward(static: nnx.ModuleDef, state: nnx.State, x: jax.Array): +def forward(static: nnx.GraphDef, state: nnx.State, x: jax.Array): model = static.merge(state) y = model(x) state, _ = model.split() diff --git a/flax/experimental/nnx/examples/00_demo.ipynb b/flax/experimental/nnx/examples/00_demo.ipynb index 97da6426bf..932b802367 100644 --- a/flax/experimental/nnx/examples/00_demo.ipynb +++ b/flax/experimental/nnx/examples/00_demo.ipynb @@ -64,7 +64,7 @@ " [0.31418085, 0.7399571 ]], dtype=float32)\n", " )\n", "})\n", - "ModuleDef(\n", + "GraphDef(\n", " type=Linear,\n", " index=0,\n", " static_fields=(('din', 2), ('dout', 2)),\n", @@ -79,10 +79,10 @@ } ], "source": [ - "state, moduledef = linear.split()\n", + "state, graphdef = linear.split()\n", "\n", "print(state)\n", - "print(moduledef)" + "print(graphdef)" ] }, { @@ -145,7 +145,7 @@ " [0.31418085, 0.7399571 ]], dtype=float32)\n", " )\n", "})\n", - "ModuleDef(\n", + "GraphDef(\n", " type=Linear,\n", " index=0,\n", " static_fields=(('din', 2), ('dout', 2)),\n", @@ -162,10 +162,10 @@ } ], "source": [ - "state, moduledef = linear.split()\n", + "state, graphdef = linear.split()\n", "\n", "print(state)\n", - "print(moduledef)" + "print(graphdef)" ] }, { @@ -185,7 +185,7 @@ } ], "source": [ - "linear2 = moduledef.merge(state)\n", + "linear2 = graphdef.merge(state)\n", "\n", "linear2.submodule is linear2" ] @@ -262,7 +262,7 @@ ], "source": [ "intermediates = linear.pop(nnx.Intermediate)\n", - "state, moduledef = linear.split()\n", + "state, graphdef = linear.split()\n", "\n", "print(intermediates)\n", "print(state)" diff --git a/flax/experimental/nnx/examples/03_train_state.py b/flax/experimental/nnx/examples/03_train_state.py index bc65d4165c..dcbe102173 100644 --- a/flax/experimental/nnx/examples/03_train_state.py +++ b/flax/experimental/nnx/examples/03_train_state.py @@ -58,12 +58,12 @@ def __call__(self, x): return x -params, counts, moduledef = MLP( - din=1, dhidden=32, dout=1, rngs=nnx.Rngs(0) -).split(nnx.Param, ...) +params, counts, static = MLP(din=1, dhidden=32, dout=1, rngs=nnx.Rngs(0)).split( + nnx.Param, ... +) state = nnx.TrainState( - moduledef, + static, params=params, tx=optax.sgd(0.1), counts=counts, @@ -107,7 +107,7 @@ def test_step(state: nnx.TrainState[MLP], batch): if step >= total_steps - 1: break -model = moduledef.merge(state.params, state.counts) +model = static.merge(state.params, state.counts) print('times called:', model.count) y_pred = model(X) diff --git a/flax/experimental/nnx/examples/05_vae.py b/flax/experimental/nnx/examples/05_vae.py index 719bab9ef4..fdad4c00fe 100644 --- a/flax/experimental/nnx/examples/05_vae.py +++ b/flax/experimental/nnx/examples/05_vae.py @@ -113,7 +113,7 @@ def generate(self, z): return nnx.sigmoid(logits) -params, moduledef = VAE( +params, static = VAE( din=int(np.prod(image_shape)), hidden_size=256, latent_size=latent_size, @@ -122,7 +122,7 @@ def generate(self, z): ).split(nnx.Param) state = nnx.TrainState( - moduledef, + static, params=params, tx=optax.adam(1e-3), ) diff --git a/flax/experimental/nnx/examples/06_scan_over_layers.py b/flax/experimental/nnx/examples/06_scan_over_layers.py index 24dcfdb22c..0130035280 100644 --- a/flax/experimental/nnx/examples/06_scan_over_layers.py +++ b/flax/experimental/nnx/examples/06_scan_over_layers.py @@ -56,14 +56,14 @@ def __call__(self, x: jax.Array, *, rngs: nnx.Rngs) -> jax.Array: # fork Rngs, split keys into `n_layers` keys = rngs.fork(self.n_layers) # split Module to get params - params, moduledef = self.layers.split(nnx.Param) + params, static = self.layers.split(nnx.Param) def scan_fn( x: jax.Array, inputs: Tuple[nnx.State, dict[str, nnx.RngStream]] ) -> Tuple[jax.Array, nnx.State]: params, keys = inputs # merge back Module and Rngs - module = moduledef.merge(params) + module = static.merge(params) # forward pass x = module(x, rngs=nnx.Rngs(keys)) # split state and return diff --git a/flax/experimental/nnx/examples/07_transformer.py b/flax/experimental/nnx/examples/07_transformer.py index d0352e32dd..61b2b969d0 100644 --- a/flax/experimental/nnx/examples/07_transformer.py +++ b/flax/experimental/nnx/examples/07_transformer.py @@ -389,14 +389,14 @@ def __call__(self, cfg: Config, x, *, rngs: nnx.Rngs): if cfg.scanned: assert isinstance(self.layers, DecoderBlock) - state, moduledef = self.layers.split() + state, static = self.layers.split() rngs, rngsdef = rngs.fork() dropout_key = jax.random.split(rngs['dropout'], cfg.layers) def scan_fn(x, s: tp.Tuple[jax.Array, nnx.State]): dropout_key, state = s rngs = rngsdef.merge({'dropout': dropout_key}) - y, (state, _) = moduledef.apply(state)(cfg, x, rngs=rngs) + y, (state, _) = static.apply(state)(cfg, x, rngs=rngs) return y, state x, state = jax.lax.scan( diff --git a/flax/experimental/nnx/examples/08_save_load_checkpoints.py b/flax/experimental/nnx/examples/08_save_load_checkpoints.py index 4e958a8ad1..076c7028a9 100644 --- a/flax/experimental/nnx/examples/08_save_load_checkpoints.py +++ b/flax/experimental/nnx/examples/08_save_load_checkpoints.py @@ -47,12 +47,12 @@ def create_and_save(seed: int, path: str): def load_model(path: str) -> MLP: # create that model with abstract shapes - state, moduledef = jax.eval_shape(lambda: create_model(0).split()) + state, static = jax.eval_shape(lambda: create_model(0).split()) # Load the parameters checkpointer = orbax.PyTreeCheckpointer() state = checkpointer.restore(f'{path}/state', item=state) # Merge the parameters into the model - model = moduledef.merge(state) + model = static.merge(state) return model diff --git a/flax/experimental/nnx/examples/09_parameter_surgery.py b/flax/experimental/nnx/examples/09_parameter_surgery.py index cbdeae3eed..dc82f7ea27 100644 --- a/flax/experimental/nnx/examples/09_parameter_surgery.py +++ b/flax/experimental/nnx/examples/09_parameter_surgery.py @@ -50,7 +50,7 @@ def __call__(self, x): ) # split the parameters into trainable and non-trainable parameters -trainable_params, non_trainable, moduledef = model.split(is_trainable, ...) +trainable_params, non_trainable, static = model.split(is_trainable, ...) print('trainable_params =', jax.tree_map(jax.numpy.shape, trainable_params)) print('non_trainable = ', jax.tree_map(jax.numpy.shape, non_trainable)) diff --git a/flax/experimental/nnx/examples/10_quantization.py b/flax/experimental/nnx/examples/10_quantization.py index 0ac4ac7de2..63f7d796c6 100644 --- a/flax/experimental/nnx/examples/10_quantization.py +++ b/flax/experimental/nnx/examples/10_quantization.py @@ -106,12 +106,12 @@ def __call__(self, x: jax.Array) -> jax.Array: return x -params, moduledef = MLP( +params, static = MLP( din=np.prod(image_shape), dmid=256, dout=10, rngs=nnx.Rngs(0) ).split(nnx.Param) state = nnx.TrainState( - moduledef, + static, params=params, tx=optax.adam(1e-3), ) @@ -188,7 +188,7 @@ def forward(state: nnx.TrainState[MLP], inputs: jax.Array) -> jax.Array: plt.show() -model = state.moduledef.merge(state.params) +model = state.static.merge(state.params) # %% # Quantization @@ -234,7 +234,7 @@ def optimize( num_steps: int = 100, debug: bool = False, ): - q_hparams, rest, moduledef = self.split(QHParam, ...) + q_hparams, rest, static = self.split(QHParam, ...) tx = optax.adam(1e-3) opt_state = tx.init(q_hparams) @@ -250,7 +250,7 @@ def optimization_step( print('JITTING') def loss_fn(q_hparams: nnx.State): - model = moduledef.merge(q_hparams, rest) + model = static.merge(q_hparams, rest) model.qkernel = model.quantize(pretrained.kernel, 8, jnp.uint8) assert pretrained.bias is not None model.qbias = model.quantize(pretrained.bias, 16, jnp.uint16) diff --git a/flax/experimental/nnx/nnx/compatibility.py b/flax/experimental/nnx/nnx/compatibility.py index f20d798383..ab205d16ab 100644 --- a/flax/experimental/nnx/nnx/compatibility.py +++ b/flax/experimental/nnx/nnx/compatibility.py @@ -19,7 +19,7 @@ from flax import linen from flax.experimental.nnx.nnx import helpers from flax.experimental.nnx.nnx import variables as variableslib -from flax.experimental.nnx.nnx.module import Module, ModuleDef +from flax.experimental.nnx.nnx.module import GraphDef, Module from flax.experimental.nnx.nnx.rnglib import Rngs from flax.experimental.nnx.nnx.state import State @@ -30,7 +30,7 @@ @dataclasses.dataclass class Functional(tp.Generic[M]): module_type: tp.Type[M] - moduledef: tp.Optional[ModuleDef[M]] + graphdef: tp.Optional[GraphDef[M]] args: tuple[tp.Any, ...] kwargs: dict[str, tp.Any] @@ -39,13 +39,13 @@ def init(self, *, rngs: tp.Optional[Rngs] = None) -> State: if rngs is not None: kwargs['rngs'] = rngs module = self.module_type(*self.args, **self.kwargs, **kwargs) - state, moduledef = module.split() - self.moduledef = moduledef + state, graphdef = module.split() + self.graphdef = graphdef return state def apply(self, *states: tp.Any): - assert self.moduledef is not None - return self.moduledef.apply(*states) + assert self.graphdef is not None + return self.graphdef.apply(*states) def functional(cls: tp.Type[M]) -> tp.Callable[..., Functional[M]]: diff --git a/flax/experimental/nnx/nnx/graph_utils.py b/flax/experimental/nnx/nnx/graph_utils.py new file mode 100644 index 0000000000..33964673e0 --- /dev/null +++ b/flax/experimental/nnx/nnx/graph_utils.py @@ -0,0 +1,702 @@ +from __future__ import annotations + +import dataclasses +import enum +import typing as tp +from itertools import groupby + +import jax + +from flax.experimental.nnx.nnx import filterlib, reprlib +from flax.experimental.nnx.nnx.proxy_caller import ( + ApplyCaller, + CallableProxy, + DelayedAccessor, +) +from flax.experimental.nnx.nnx.state import State +from flax.experimental.nnx.nnx.variables import EMPTY, Empty, Variable + +Index = int +Names = tp.Sequence[int] +PathParts = tuple[str, ...] +Path = str +Node = tp.TypeVar('Node') +Leaf = tp.TypeVar('Leaf') +AuxData = tp.TypeVar('AuxData') + +NODE_TYPES: dict[type, 'NodeImpl[tp.Any, tp.Any, tp.Any]'] = {} + + +@dataclasses.dataclass(frozen=True) +class NodeImpl(tp.Generic[Node, Leaf, AuxData]): + type: type + flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[str, Leaf]], AuxData]] + get_key: tp.Callable[[Node, str], Leaf] + set_key: tp.Callable[[Node, str, Leaf], Node] + has_key: tp.Callable[[Node, str], bool] + all_keys: tp.Callable[[Node], tuple[str, ...]] + unflatten: tp.Callable[[tuple[tuple[str, Leaf], ...], AuxData], Node] | None + create_empty: tp.Callable[[AuxData], Node] | None + init: tp.Callable[[Node, tuple[tuple[str, Leaf], ...]], None] | None + + def items(self, node: Node) -> tp.Iterator[tuple[str, Leaf]]: + for key in self.all_keys(node): + yield key, self.get_key(node, key) + + +@tp.overload +def register_node_type( + type: type, + flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[str, Leaf]], AuxData]], + get_key: tp.Callable[[Node, str], Leaf], + set_key: tp.Callable[[Node, str, Leaf], Node], + has_key: tp.Callable[[Node, str], bool], + all_keys: tp.Callable[[Node], tuple[str, ...]], + *, + unflatten: tp.Callable[[tuple[tuple[str, Leaf], ...], AuxData], Node], +): + ... + + +@tp.overload +def register_node_type( + type: type, + flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[str, Leaf]], AuxData]], + get_key: tp.Callable[[Node, str], Leaf], + set_key: tp.Callable[[Node, str, Leaf], Node], + has_key: tp.Callable[[Node, str], bool], + all_keys: tp.Callable[[Node], tuple[str, ...]], + *, + create_empty: tp.Callable[[AuxData], Node], + init: tp.Callable[[Node, tuple[tuple[str, Leaf], ...]], None], +): + ... + + +def register_node_type( + type: type, + flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[str, Leaf]], AuxData]], + get_key: tp.Callable[[Node, str], Leaf], + set_key: tp.Callable[[Node, str, Leaf], Node], + has_key: tp.Callable[[Node, str], bool], + all_keys: tp.Callable[[Node], tuple[str, ...]], + *, + unflatten: tp.Callable[[tuple[tuple[str, Leaf], ...], AuxData], Node] + | None = None, + create_empty: tp.Callable[[AuxData], Node] | None = None, + init: tp.Callable[[Node, tuple[tuple[str, Leaf], ...]], None] | None = None, +): + if type in NODE_TYPES: + raise ValueError(f"Node type '{type}' already registered.") + NODE_TYPES[type] = NodeImpl( + type, + flatten, + get_key, + set_key, + has_key, + all_keys, + unflatten, + create_empty, + init, + ) + + +def is_node(x: tp.Any) -> bool: + return type(x) in NODE_TYPES + + +def is_node_type(x: type[tp.Any]) -> bool: + return x in NODE_TYPES + + +@tp.overload +def get_node_impl(x: type[Node]) -> NodeImpl[Node, tp.Any, tp.Any]: + ... + + +@tp.overload +def get_node_impl(x: Node) -> NodeImpl[Node, tp.Any, tp.Any]: + ... + + +def get_node_impl(x: type[Node] | Node) -> NodeImpl[Node, tp.Any, tp.Any]: + if not isinstance(x, type): + x = type(x) + if not is_node_type(x): + raise ValueError(f'Unknown node type: {x}') + return NODE_TYPES[x] + + +@dataclasses.dataclass(repr=False) +class _SubgraphRepr(reprlib.Representable): + subgraphs: tuple[tuple[str, tp.Union['GraphDef[tp.Any]', int]], ...] + + def __nnx_repr__(self): + yield reprlib.Object(type='', value_sep=', ') + + for name, subgraph in self.subgraphs: + yield reprlib.Attr(repr(name), subgraph, start='(', end=')') + + +class GraphDef(tp.Generic[Node], reprlib.Representable): + __slots__ = ( + '_type', + '_index', + '_subgraphs', + '_static_fields', + '_variables', + '_metadata', + ) + + def __init__( + self, + type: tp.Type[Node], + index: int, + subgraphs: tuple[tuple[str, tp.Union['GraphDef[Node]', int]], ...], + static_fields: tuple[tuple[str, tp.Any], ...], + variables: tuple[tuple[str, Variable[Empty]], ...], + metadata: tp.Any, + ): + self._type = type + self._index = index + self._subgraphs = subgraphs + self._static_fields = static_fields + self._variables = variables + self._metadata = metadata + + def __nnx_repr__(self): + yield reprlib.Object(type=type(self)) + + yield reprlib.Attr('type', self._type.__name__) + yield reprlib.Attr('index', self._index) + yield reprlib.Attr('subgraphs', _SubgraphRepr(self._subgraphs)) + yield reprlib.Attr('static_fields', self._static_fields) + yield reprlib.Attr('variables', self._variables) + yield reprlib.Attr('metadata', self._metadata) + + def __hash__(self) -> int: + return hash((self._type, self._subgraphs)) + + def __eq__(self, other: tp.Any) -> bool: + if not isinstance(other, GraphDef): + return False + return self._type == other._type and self._subgraphs == other._subgraphs + + @property + def type(self) -> tp.Type[Node]: + return self._type + + @property + def index(self) -> int: + return self._index + + @property + def subgraphs( + self + ) -> tuple[tuple[str, tp.Union['GraphDef[tp.Any]', int]], ...]: + return self._subgraphs + + @property + def static_fields(self) -> tuple[tuple[str, tp.Any], ...]: + return self._static_fields + + @property + def variables(self) -> tuple[tuple[str, Variable[Empty]], ...]: + return self._variables + + @property + def metadata(self) -> tp.Any: + return self._metadata + + def merge(self, state: State, *states: State) -> Node: + if states: + state = State.merge(state, *states) + return graph_unflatten(self, state) + + def apply( + self, state: State, *states: State + ) -> ApplyCaller[tuple[State, 'GraphDef[Node]']]: + accessesor = DelayedAccessor() + + def _apply( + accessesor, *args, **kwargs + ) -> tuple[tp.Any, tuple[State, GraphDef[Node]]]: + module = self.merge(state, *states) + fn = accessesor(module) + out = fn(*args, **kwargs) + return out, graph_flatten(module) + + return CallableProxy(_apply, accessesor) # type: ignore + + def make_empty(self) -> Node: + return self.merge(State({})) + + +def _gradphdef_flatten(graphdef: GraphDef[tp.Any]): + return (), ( + graphdef._type, + graphdef._index, + graphdef._subgraphs, + graphdef._static_fields, + graphdef._variables, + graphdef._metadata, + ) + + +def _graphdef_unflatten( + metadata: tuple[ + tp.Type[Node], + int, + tuple[tuple[str, GraphDef[Node] | int], ...], + tuple[tuple[str, tp.Any], ...], + tuple[tuple[str, Variable[Empty]], ...], + tp.Any, + ], + _, +) -> GraphDef[Node]: + return GraphDef(*metadata) + + +jax.tree_util.register_pytree_node( + GraphDef, _gradphdef_flatten, _graphdef_unflatten +) + + +@dataclasses.dataclass +class FlattenState: + id_to_index: dict[int, Index] + state: dict[Path, Variable[tp.Any]] + + +def graph_flatten(x: Node) -> tuple[State, GraphDef[Node]]: + flatten_state = FlattenState({}, {}) + dagdef = _graph_flatten((), flatten_state, x) + assert not isinstance(dagdef, int) + return State(flatten_state.state), dagdef + + +def _graph_flatten( + path: PathParts, flatten_state: FlattenState, node: Node +) -> GraphDef[Node] | int: + if not is_node(node): + raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.') + + if (index := id(node)) in flatten_state.id_to_index: + return flatten_state.id_to_index[index] + + index = len(flatten_state.id_to_index) + flatten_state.id_to_index[id(node)] = index + + subgraphs: list[tuple[str, tp.Union[GraphDef[Node], int]]] = [] + static_fields: list[tuple[str, tp.Any]] = [] + variables: list[tuple[str, Variable[Empty]]] = [] + + node_impl = get_node_impl(node) + values, metadata = node_impl.flatten(node) + for key, value in values: + if is_node(value): + graphdef = _graph_flatten((*path, key), flatten_state, value) + subgraphs.append((key, graphdef)) + elif isinstance(value, Variable): + str_path = '/'.join((*path, key)) + flatten_state.state[str_path] = value + variables.append((key, value.as_empty())) + else: + static_fields.append((key, value)) + + graphdef = GraphDef( + type=type(node), + index=index, + subgraphs=tuple(subgraphs), + static_fields=tuple(static_fields), + variables=tuple(variables), + metadata=metadata, + ) + return graphdef + + +@dataclasses.dataclass +class UnflattenState: + index_to_node: dict[Index, tp.Any] + + +def _group_state_recursive(state: dict[PathParts, Variable[Empty]]): + groups = groupby(state.items(), lambda item: item[0][0]) + nested_state: dict[str, Variable[tp.Any] | dict[str, tp.Any]] = {} + + for key, group in groups: + group = list(group) + if len(group[0][0]) == 1: + nested_state[key] = group[0][1] + else: + nested_state[key] = _group_state_recursive( + {path_parts[1:]: value for path_parts, value in group} + ) + + return nested_state + + +def graph_unflatten(graphdef: GraphDef[Node], state: State) -> Node: + unfalatten_state = UnflattenState({}) + sorted_elements = sorted(state.variables.items(), key=lambda item: item[0]) + nested_state = _group_state_recursive( + {tuple(path.split('/')): value for path, value in sorted_elements} + ) + return _graph_unflatten(graphdef, nested_state, unfalatten_state) + + +_sentinel = object() + + +def _graph_unflatten( + graphdef: tp.Union[GraphDef[Node], int], + state: dict[str, Variable[Empty] | dict[str, tp.Any]], + unflatten_state: UnflattenState, +) -> Node: + if isinstance(graphdef, int): + return unflatten_state.index_to_node[graphdef] + + if not is_node_type(graphdef.type): + raise RuntimeError(f'Unsupported type: {graphdef.type}, this is a bug.') + + if graphdef.index in unflatten_state.index_to_node: + raise RuntimeError(f'GraphDef index {graphdef.index} already used.') + + node_impl = get_node_impl(graphdef.type) + + def _get_children(): + subgraph_nodes: dict[str, tp.Any] = {} + + for key, subgraphdef in graphdef.subgraphs: + substate = state.pop(key, {}) + if isinstance(substate, Variable): + raise ValueError( + f'Expected a subgraph for {key!r}, but got a variable.' + ) + subgraph_nodes[key] = _graph_unflatten( + subgraphdef, substate, unflatten_state + ) + + return {**subgraph_nodes, **state, **dict(graphdef.static_fields)} + + if node_impl.create_empty: + assert node_impl.init is not None + # we create an empty node first and add it to the index + # this avoids infinite recursion when there is a reference cycle + node = node_impl.create_empty(graphdef.metadata) + unflatten_state.index_to_node[graphdef.index] = node + children = _get_children() + node_impl.init(node, tuple(children.items())) + else: + # if the node type does not support the creation of an empty object it means + # that it cannot reference itself, so we can create its children first + assert node_impl.unflatten is not None + children = _get_children() + node = node_impl.unflatten(tuple(children.items()), graphdef.metadata) + unflatten_state.index_to_node[graphdef.index] = node + + return node + + +def _set_value_at_path( + node: tp.Any, path_parts: PathParts | tp.List[str], value: tp.Any +): + if not is_node(node): + raise RuntimeError(f'Unsupported type: {type(node)}') + + node_impl = get_node_impl(node) + if len(path_parts) == 1: + node_impl.set_key(node, path_parts[0], value) + else: + _set_value_at_path( + node_impl.get_key(node, path_parts[0]), path_parts[1:], value + ) + + +def graph_pop( + node: tp.Any, + filters: tuple[filterlib.Filter, ...], +) -> tuple[State, ...]: + id_to_index: dict[int, Index] = {} + path_parts: PathParts = () + predicates = tuple(filterlib.to_predicate(filter) for filter in filters) + states = tuple({} for _ in predicates) + _graph_pop(node, id_to_index, path_parts, states, predicates) + return tuple(State(x) for x in states) + + +def _graph_pop( + node: tp.Any, + id_to_index: dict[int, Index], + path_parts: PathParts, + states: tuple[dict[Path, tp.Any], ...], + predicates: tuple[filterlib.Predicate, ...], +) -> None: + if not is_node(node): + raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.') + + if id(node) in id_to_index: + return + + index = len(id_to_index) + id_to_index[id(node)] = index + + for name, value in list(vars(node).items()): + if is_node(value): + _graph_pop(value, id_to_index, (*path_parts, name), states, predicates) + continue + elif not isinstance(value, Variable): + continue + elif value.is_empty: + continue + + path = '/'.join((*path_parts, name)) + node_impl = get_node_impl(node) + for state, predicate in zip(states, predicates): + if predicate(path, value): + state[path] = value + # empty Variable attributes + node_impl.set_key(node, name, value.as_empty()) + break + else: + # NOTE: should we raise an error here? + pass + + +def graph_update_dynamic( + node: tp.Any, + updates: State | tp.Sequence[State], +) -> None: + if not is_node(node): + raise ValueError(f'Unsupported type: {type(node)}') + + if isinstance(updates, State): + new_states = (updates,) + else: + new_states = updates + + state: dict[Path, tp.Any] = {} + for new_state in new_states: + state.update(new_state.variables) + + for path, value in state.items(): + path_parts = path.split('/') + _set_value_at_path(node, path_parts, value) + + +class _StaticModuleStatus(enum.Enum): + NEW = enum.auto() + UPDATED = enum.auto() + + +def graph_update_static(node: Node, updates: Node) -> None: + cache: dict[int, _StaticModuleStatus] = {} + _graph_update_static(node, updates, cache, _StaticModuleStatus.UPDATED, ()) + + +def _graph_update_static( + node: Node, + updates: Node, + cache: dict[int, _StaticModuleStatus], + status: _StaticModuleStatus, + path: PathParts, +) -> None: + if type(node) != type(updates): + raise ValueError( + f'Trying to update a node with a different type: ' + f'expected {type(node).__name__!r}, ' + f'but got {type(updates).__name__!r}' + ) + if not is_node(node): + raise ValueError(f'Unsupported node type: {type(node)}') + + if id(updates) in cache: + if cache[id(updates)] != status: + str_path = '/'.join(path) + if status is _StaticModuleStatus.NEW: + raise ValueError( + f'Trying to add a new node at path {str_path!r} but a' + ' node with the same reference has been updated' + ) + else: + raise ValueError( + f'Trying to update a node at path {str_path!r} but a new' + ' node with the same reference has been added' + ) + return + + cache[id(updates)] = status + + node_impl = get_node_impl(node) + for name, value_updates in node_impl.items(updates): + if isinstance(value_updates, Variable): + continue + elif is_node(value_updates): + if node_impl.has_key(node, name): + _graph_update_static( + node_impl.get_key(node, name), + value_updates, + cache, + _StaticModuleStatus.UPDATED, + (*path, name), + ) + else: + if id(value_updates) in cache: + if cache[id(value_updates)] is not _StaticModuleStatus.NEW: + raise ValueError( + f'Trying to add a new node at path {name!r} but a ' + 'node with the same reference has been updated' + ) + else: + cache[id(value_updates)] = _StaticModuleStatus.NEW + + node_impl.set_key(node, name, value_updates) + else: # static field + node_impl.set_key(node, name, value_updates) + + +def clone(node: Node) -> Node: + state, static = graph_flatten(node) + return static.merge(state) + + +# ----------------------------- +# register node types +# ----------------------------- +# dict +def _flatten_dict( + node: dict[str, tp.Any] +) -> tuple[tuple[tuple[str, tp.Any], ...], None]: + return tuple(node.items()), None + + +def _get_key_dict(node: dict[str, tp.Any], key: str) -> tp.Any: + return node[key] + + +def _set_key_dict( + node: dict[str, tp.Any], key: str, value: tp.Any +) -> dict[str, tp.Any]: + node[key] = value + return node + + +def _has_key_dict(node: dict[str, tp.Any], key: str) -> bool: + return key in node + + +def _all_keys_dict(node: dict[str, tp.Any]) -> tuple[str, ...]: + return tuple(node.keys()) + + +def _create_empty_dict(metadata: None) -> dict[str, tp.Any]: + return {} + + +def _init_dict(node: dict[str, tp.Any], items: tuple[tuple[str, tp.Any], ...]): + node.update(items) + + +register_node_type( + dict, + _flatten_dict, + _get_key_dict, + _set_key_dict, + _has_key_dict, + _all_keys_dict, + create_empty=_create_empty_dict, + init=_init_dict, +) + + +# list +def _flatten_list( + node: list[tp.Any] +) -> tuple[tuple[tuple[str, tp.Any], ...], int]: + return tuple((str(i), value) for i, value in enumerate(node)), len(node) + + +def _get_key_list(node: list[tp.Any], key: str) -> tp.Any: + return node[int(key)] + + +def _set_key_list(node: list[tp.Any], key: str, value: tp.Any) -> list[tp.Any]: + int_key = int(key) + if int_key >= len(node): + node.extend([EMPTY] * (int_key - len(node) + 1)) + node[int_key] = value + return node + + +def _has_key_list(node: list[tp.Any], key: str) -> bool: + return int(key) < len(node) + + +def _all_keys_list(node: list[tp.Any]) -> tuple[str, ...]: + return tuple(str(i) for i in range(len(node))) + + +def _create_empty_list(length: int) -> list[tp.Any]: + return [EMPTY] * length + + +def _init_list(node: list[tp.Any], items: tuple[tuple[str, tp.Any], ...]): + for key, value in items: + _set_key_list(node, key, value) + + +register_node_type( + list, + _flatten_list, + _get_key_list, + _set_key_list, + _has_key_list, + _all_keys_list, + create_empty=_create_empty_list, + init=_init_list, +) + + +# tuple +def _flatten_tuple( + node: tuple[tp.Any, ...] +) -> tuple[tuple[tuple[str, tp.Any], ...], int]: + return tuple((str(i), value) for i, value in enumerate(node)), len(node) + + +def _unflatten_tuple( + items: tuple[tuple[str, tp.Any], ...], length: int +) -> tuple[tp.Any, ...]: + node = [EMPTY] * length + for key, value in items: + node[int(key)] = value + return tuple(node) + + +def _get_key_tuple(node: tuple[tp.Any, ...], key: str) -> tp.Any: + return node[int(key)] + + +def _set_key_tuple( + node: tuple[tp.Any, ...], key: str, value: tp.Any +) -> tuple[tp.Any, ...]: + raise ValueError("'tuple' object is immutable, does not support assignment") + + +def _has_key_tuple(node: tuple[tp.Any, ...], key: str) -> bool: + return int(key) < len(node) + + +def _all_keys_tuple(node: tuple[tp.Any, ...]) -> tuple[str, ...]: + return tuple(str(i) for i in range(len(node))) + + +register_node_type( + tuple, + _flatten_tuple, + _get_key_tuple, + _set_key_tuple, + _has_key_tuple, + _all_keys_tuple, + unflatten=_unflatten_tuple, +) diff --git a/flax/experimental/nnx/nnx/helpers.py b/flax/experimental/nnx/nnx/helpers.py index 1d7d367253..f2e8714d3b 100644 --- a/flax/experimental/nnx/nnx/helpers.py +++ b/flax/experimental/nnx/nnx/helpers.py @@ -36,7 +36,8 @@ import optax from flax.experimental.nnx.nnx import pytreelib -from flax.experimental.nnx.nnx.module import ApplyCaller, Module, ModuleDef +from flax.experimental.nnx.nnx.module import GraphDef, Module +from flax.experimental.nnx.nnx.proxy_caller import ApplyCaller from flax.experimental.nnx.nnx.rnglib import Rngs from flax.experimental.nnx.nnx.state import State @@ -112,7 +113,7 @@ def __call__(self, *args, rngs: tp.Optional[Rngs] = None, **kwargs) -> tp.Any: if isinstance(output, tp.Tuple): args = output kwargs = {} - elif isinstance(output, tp.Dict): + elif isinstance(output, dict): args = () kwargs = output else: @@ -129,21 +130,21 @@ def __call__(self, *args, rngs: tp.Optional[Rngs] = None, **kwargs) -> tp.Any: class ModuleDefApply(tp.Protocol, tp.Generic[M]): def __call__( self, state: State, *states: State - ) -> ApplyCaller[tuple[State, ModuleDef[M]]]: + ) -> ApplyCaller[tuple[State, GraphDef[M]]]: ... class TrainState(pytreelib.Pytree, tp.Generic[M]): def __init__( self, - moduledef: ModuleDef[M], + graphdef: GraphDef[M], *, params: State, tx: optax.GradientTransformation, step: int = 0, **kwargs, ): - self.moduledef = moduledef + self.graphdef = graphdef self.params: State = pytreelib.TreeNode(params) self.tx = tx self.opt_state = pytreelib.TreeNode(tx.init(self.params)) @@ -160,7 +161,7 @@ def __getattr__(self, key: str) -> tp.Any: def apply( self, state: tp.Union[State, str], *states: tp.Union[State, str] - ) -> ApplyCaller[tuple[State, ModuleDef[M]]]: + ) -> ApplyCaller[tuple[State, GraphDef[M]]]: states = (state, *states) _states = ( @@ -168,7 +169,7 @@ def apply( for state in states ) - return self.moduledef.apply(*_states) + return self.graphdef.apply(*_states) def apply_gradients(self, grads: State, **kwargs) -> 'TrainState[M]': updates, opt_state = self.tx.update(grads, self.opt_state, self.params) diff --git a/flax/experimental/nnx/nnx/module.py b/flax/experimental/nnx/nnx/module.py index e7103680f5..f15c6436df 100644 --- a/flax/experimental/nnx/nnx/module.py +++ b/flax/experimental/nnx/nnx/module.py @@ -15,7 +15,6 @@ from __future__ import annotations import dataclasses -import enum import typing as tp from abc import ABCMeta from copy import deepcopy @@ -29,11 +28,18 @@ from flax.experimental.nnx.nnx import ( errors, filterlib, + graph_utils, ids, reprlib, tracers, ) from flax.experimental.nnx.nnx import variables as variableslib +from flax.experimental.nnx.nnx.graph_utils import GraphDef +from flax.experimental.nnx.nnx.proxy_caller import ( + ApplyCaller, + CallableProxy, + DelayedAccessor, +) from flax.experimental.nnx.nnx.rnglib import Rngs from flax.experimental.nnx.nnx.state import State from flax.experimental.nnx.nnx.variables import Variable @@ -46,13 +52,8 @@ Path = str PathParts = tuple[str, ...] -StateDict = tp.Dict[Path, tp.Any] -StateMapping = tp.Mapping[Path, tp.Any] - -class _ProxyContext(tp.Protocol): - def __call__(self, accessor: 'DelayedAccessor', /, *args, **kwargs) -> tp.Any: - ... +StateMapping = tp.Mapping[Path, tp.Any] @tp.runtime_checkable @@ -61,199 +62,8 @@ def setup(self) -> None: ... -@dataclasses.dataclass -class CallableProxy: - _proxy_context: _ProxyContext - _proxy_callable: tp.Callable[..., tp.Any] - - def __call__(self, *args, **kwargs): - return self._proxy_context(self._proxy_callable, *args, **kwargs) - - def __getattr__(self, name) -> 'CallableProxy': - return CallableProxy( - self._proxy_context, getattr(self._proxy_callable, name) - ) - - def __getitem__(self, key) -> 'CallableProxy': - return CallableProxy(self._proxy_context, self._proxy_callable[key]) - - -def _identity(x): - return x - - -@dataclasses.dataclass -class DelayedAccessor: - accessor: tp.Callable[[tp.Any], tp.Any] = _identity - - def __call__(self, x): - return self.accessor(x) - - def __getattr__(self, name): - return DelayedAccessor(lambda x: getattr(x, name)) - - def __getitem__(self, key): - return DelayedAccessor(lambda x: x[key]) - - -class ApplyCaller(tp.Protocol, tp.Generic[A]): - def __getattr__(self, __name) -> 'ApplyCaller[A]': - ... - - def __getitem__(self, __name) -> 'ApplyCaller[A]': - ... - - def __call__(self, *args, **kwargs) -> tuple[tp.Any, A]: - ... - - -@dataclasses.dataclass(repr=False) -class _SubmodulesRepr(reprlib.Representable): - submodules: tuple[tuple[str, tp.Union['ModuleDef[Module]', int]], ...] - - def __nnx_repr__(self): - yield reprlib.Object(type='', value_sep=', ') - - for name, submodule in self.submodules: - yield reprlib.Attr(repr(name), submodule, start='(', end=')') - - -class ModuleDef(tp.Generic[M], reprlib.Representable): - __slots__ = ( - '_type', - '_index', - '_submodules', - '_static_fields', - '_variables', - '_module_state', - ) - - def __init__( - self, - type: tp.Type[M], - index: int, - submodules: tuple[tuple[str, tp.Union['ModuleDef[Module]', int]], ...], - static_fields: tuple[tuple[str, tp.Any], ...], - variables: tuple[ - tuple[str, variableslib.Variable[variableslib.Empty]], ... - ], - module_state: 'ModuleStateTuple', - ): - self._type = type - self._index = index - self._submodules = submodules - self._static_fields = static_fields - self._variables = variables - self._module_state = module_state - - def __nnx_repr__(self): - yield reprlib.Object(type=type(self)) - - yield reprlib.Attr('type', self._type.__name__) - yield reprlib.Attr('index', self._index) - yield reprlib.Attr('static_fields', self._static_fields) - yield reprlib.Attr('variables', self._variables) - yield reprlib.Attr('submodules', _SubmodulesRepr(self._submodules)) - - def __hash__(self) -> int: - return hash( - (self._type, self._submodules, self._static_fields, self._variables) - ) - - def __eq__(self, other: tp.Any) -> bool: - if not isinstance(other, ModuleDef): - return False - return ( - self._type == other._type - and self._submodules == other._submodules - and self._static_fields == other._static_fields - ) - - @property - def type(self) -> tp.Type[M]: - return self._type - - @property - def index(self) -> int: - return self._index - - @property - def submodules( - self, - ) -> tuple[tuple[str, tp.Union['ModuleDef[Module]', int]], ...]: - return self._submodules - - @property - def static_fields(self) -> tuple[tuple[str, tp.Any], ...]: - return self._static_fields - - @property - def variables( - self, - ) -> tuple[tuple[str, variableslib.Variable[variableslib.Empty]], ...]: - return self._variables - - @property - def module_state(self) -> 'ModuleStateTuple': - return self._module_state - - def make_module(self) -> M: - return _build_module(self) - - def merge(self, state: State, *states: State) -> M: - states = (state, *states) - module = self.make_module() - _update_module_dynamic_state(module, states) - return module - - def apply( - self, state: State, *states: State - ) -> ApplyCaller[tuple[State, 'ModuleDef[M]']]: - accessesor = DelayedAccessor() - - def _apply( - accessesor, *args, **kwargs - ) -> tuple[tp.Any, tuple[State, ModuleDef[M]]]: - module = self.merge(state, *states) - fn = accessesor(module) - out = fn(*args, **kwargs) - return out, module.split() - - return CallableProxy(_apply, accessesor) # type: ignore - - -def _moddef_flatten(moduledef: ModuleDef[M]): - return (), ( - moduledef._type, - moduledef._index, - moduledef._submodules, - moduledef._static_fields, - moduledef._variables, - moduledef._module_state, - ) - - -def _moddef_unflatten( - metadata: tuple[ - tp.Type[M], - int, - tuple[tuple[str, tp.Union['ModuleDef[Module]', int]], ...], - tuple[tuple[str, tp.Any], ...], - tuple[tuple[str, variableslib.Variable[variableslib.Empty]], ...], - 'ModuleStateTuple', - ], - _, -) -> ModuleDef[M]: - return ModuleDef(*metadata) - - -jtu.register_pytree_node(ModuleDef, _moddef_flatten, _moddef_unflatten) - - SEEN_MODULES_REPR: tp.Optional[tp.Set[ids.UUID]] = None -ModuleStateTuple = tuple[()] - class ModuleState(reprlib.Representable): __slots__ = ('_trace_state', '_id') @@ -270,13 +80,6 @@ def trace_state(self) -> tracers.TraceState: def id(self) -> ids.UUID: return self._id - def to_tuple(self) -> ModuleStateTuple: - return () - - @classmethod - def from_tuple(cls, tup: ModuleStateTuple) -> 'ModuleState': - return cls(*tup) - def __nnx_repr__(self): yield reprlib.Object(type(self)) yield reprlib.Attr('trace_state', self._trace_state) @@ -320,9 +123,9 @@ def _meta_call(cls: tp.Type[M], *args, **kwargs) -> M: Updates = tp.Union[ M, - ModuleDef[M], - tuple[State, ModuleDef[M]], - tuple[tuple[State, ...], ModuleDef[M]], + GraphDef[M], + tuple[State, GraphDef[M]], + tuple[tuple[State, ...], GraphDef[M]], State, tuple[State, ...], ] @@ -343,7 +146,7 @@ def __getattribute__(self, name: str) -> Any: def __setattr__(self, name: str, value: Any) -> None: self._setattr(name, value) - def _setattr(self, name: str, value: Any) -> None: + def _setattr(self, name: str, value: tp.Any) -> None: if not self._module__state.trace_state.is_valid(): raise errors.TraceContextError( 'Cannot mutate Module from different trace level' @@ -365,10 +168,10 @@ def _setattr(self, name: str, value: Any) -> None: vars_dict[name] = value def __deepcopy__(self: M, memo=None) -> M: - state, moduledef = self.split() - moduledef = deepcopy(moduledef) + state, graphdef = self.split() + graphdef = deepcopy(graphdef) state = deepcopy(state) - return moduledef.merge(state) + return graphdef.merge(state) def __hash__(self) -> int: return hash(self._module__state.id) @@ -400,7 +203,7 @@ def __nnx_repr__(self): SEEN_MODULES_REPR = None @classmethod - def init(cls: type[M], *args, **kwargs) -> tuple[State, ModuleDef[M]]: + def init(cls: type[M], *args, **kwargs) -> tuple[State, GraphDef[M]]: return cls(*args, **kwargs).split() @classmethod @@ -415,10 +218,10 @@ def lift_rngs(kwargs: dict[str, tp.Any]): def _create_abstract(accessesor, *args, **kwargs): constructor = accessesor(cls) - state, moduledef = jax.eval_shape( + state, graphdef = jax.eval_shape( lambda: constructor(*args, **lift_rngs(kwargs)).split() ) - return moduledef.merge(state) + return graphdef.merge(state) return CallableProxy(_create_abstract, accessesor) # type: ignore @@ -426,11 +229,11 @@ def clone(self: M) -> M: return merge(self.split()) @tp.overload - def split(self: M) -> tuple[State, ModuleDef[M]]: + def split(self: M) -> tuple[State, GraphDef[M]]: ... @tp.overload - def split(self: M, first: filterlib.Filter, /) -> tuple[State, ModuleDef[M]]: + def split(self: M, first: filterlib.Filter, /) -> tuple[State, GraphDef[M]]: ... @tp.overload @@ -440,14 +243,13 @@ def split( second: filterlib.Filter, /, *filters: filterlib.Filter, - ) -> tuple[State, tpe.Unpack[tuple[State, ...]], ModuleDef[M]]: + ) -> tuple[State, tpe.Unpack[tuple[State, ...]], GraphDef[M]]: ... def split( self: M, *filters: filterlib.Filter - ) -> tuple[State, tpe.Unpack[tuple[State, ...]], ModuleDef[M]]: - moduledef = self.get_moduledef() - state = self.get_state() + ) -> tuple[State, tpe.Unpack[tuple[State, ...]], GraphDef[M]]: + state, graphdef = graph_utils.graph_flatten(self) if len(filters) == 0: states = (state,) @@ -456,17 +258,15 @@ def split( else: states = state.split(filters[0], filters[1], *filters[2:]) - return *states, moduledef + return *states, graphdef def get_state(self) -> State: - return State(_iter_state(self)) + state, _ = self.split() + return state - def get_moduledef(self: M) -> ModuleDef[M]: - module_index: tp.Dict[ids.UUID, int] = {} - path: PathParts = () - moduledef = _make_moduledef_recursive(self, module_index, path) - assert isinstance(moduledef, ModuleDef) - return moduledef + def get_graphdef(self: M) -> GraphDef[M]: + _, graphdef = self.split() + return graphdef @tp.overload def extract(self, first: filterlib.Filter, /) -> State: @@ -521,7 +321,7 @@ def pop( if len(filters) == 0: raise ValueError('Expected at least one filter') - states = _pop(self, filters) + states = graph_utils.graph_pop(self, filters) if len(states) == 1: return states[0] @@ -547,28 +347,28 @@ def _states_and_moduledef( updates, ) -> tuple[list[State], tp.Optional[Module]]: leaves = jax.tree_util.tree_leaves( - updates, is_leaf=lambda x: isinstance(x, (ModuleDef, State)) + updates, is_leaf=lambda x: isinstance(x, (GraphDef, State)) ) states: list[State] = [] module: tp.Optional[Module] = None for leaf in leaves: - if isinstance(leaf, (Module, ModuleDef)): + if isinstance(leaf, (Module, GraphDef)): if module is not None: raise ValueError( - 'Expected only one ModuleDef or Module in the updates' + 'Expected only one GraphDef or Module in the updates' ) if isinstance(leaf, Module): module = leaf states.append(leaf.get_state()) else: - module = leaf.make_module() + module = leaf.make_empty() elif isinstance(leaf, State): states.append(leaf) else: raise ValueError( - 'Expected a ModuleDef, Module or State, got' + 'Expected a GraphDef, Module or State, got' f' {type(leaf).__name__}' ) @@ -577,10 +377,10 @@ def _states_and_moduledef( states, module_update = _states_and_moduledef(updates) if module_update is not None: - _update_module_static_state(self, module_update) + graph_utils.graph_update_static(self, module_update) if states: - _update_module_dynamic_state(self, states) + graph_utils.graph_update_dynamic(self, states) def sow( self, @@ -634,6 +434,17 @@ def _on_all( def __init_subclass__(cls, experimental_pytree: bool = False) -> None: super().__init_subclass__() + graph_utils.register_node_type( + cls, + _module_graph_flatten, + _module_graph_get_key, + _module_graph_set_key, + _module_graph_has_key, + _module_graph_all_keys, + create_empty=_module_graph_create_empty, + init=_module_graph_init, + ) + if experimental_pytree: jtu.register_pytree_with_keys( cls, @@ -643,9 +454,11 @@ def __init_subclass__(cls, experimental_pytree: bool = False) -> None: ) +# ------------------------- # Pytree Definition +# ------------------------- def _module_flatten(module: Module, *, with_keys: bool): - state, moduledef = module.split() + state, graphdef = module.split() variables = state.variables paths = tuple(variables.keys()) @@ -656,267 +469,59 @@ def _module_flatten(module: Module, *, with_keys: bool): else: children = tuple(variables.values()) - return children, (paths, moduledef) + return children, (paths, graphdef) def _module_unflatten( - paths_moduledef: tuple[tuple[Path, ...], ModuleDef[M]], + paths_moduledef: tuple[tuple[Path, ...], GraphDef[M]], variables: tuple[Variable[tp.Any], ...], ) -> M: - paths, moduledef = paths_moduledef - return moduledef.merge(State(zip(paths, variables))) - - -def _make_moduledef_recursive( - module: M, - module_index: tp.Dict[ids.UUID, int], - path: PathParts, -) -> tp.Union[ModuleDef[M], int]: - if module._module__state.id in module_index: - return module_index[module._module__state.id] - - index = len(module_index) - module_index[module._module__state.id] = index - - submodules = [] - static_fields = [] - variables = [] - - for name, value in sorted(vars(module).items(), key=lambda x: x[0]): - value_path = (*path, name) - if isinstance(value, Module): - submodule_def = _make_moduledef_recursive(value, module_index, value_path) - submodules.append((name, submodule_def)) - elif isinstance(value, variableslib.Variable): - variables.append((name, value.as_empty())) - elif not name.startswith('_module__'): - static_fields.append((name, value)) - - module_def = ModuleDef( - type=type(module), - index=index, - submodules=tuple(submodules), - static_fields=tuple(static_fields), - variables=tuple(variables), - module_state=module._module__state.to_tuple(), + paths, graphdef = paths_moduledef + return graphdef.merge(State(zip(paths, variables))) + + +# ------------------------- +# Graph Definition +# ------------------------- +def _module_graph_flatten(module: Module): + nodes = tuple( + (name, value) + for name, value in vars(module).items() + if name != '_module__state' ) - return module_def - - -def _iter_state(module: Module) -> tp.Iterator[tuple[Path, tp.Any]]: - seen_modules: tp.Set[ids.UUID] = set() - path_parts: PathParts = () - - yield from _iter_state_recursive(module, seen_modules, path_parts) - + return nodes, type(module) -def _iter_state_recursive( - module: Module, seen_modules: tp.Set[ids.UUID], path_parts: PathParts -) -> tp.Iterator[tuple[Path, tp.Any]]: - if module._module__state.id in seen_modules: - return - seen_modules.add(module._module__state.id) +def _module_graph_get_key(module: Module, name: str) -> tp.Any: + return getattr(module, name) - for name, value in sorted(vars(module).items(), key=lambda x: x[0]): - new_path_parts = (*path_parts, name) - if isinstance(value, Module): - yield from _iter_state_recursive(value, seen_modules, new_path_parts) - elif isinstance(value, variableslib.Variable): - if value.is_empty: - # skip empty Variables - continue - path = '/'.join(new_path_parts) - yield path, value - - -def _set_value_at_path( - module: tp.Any, path_parts: tp.Union[PathParts, tp.List[str]], value: tp.Any -): - if len(path_parts) == 1: - setattr(module, path_parts[0], value) - else: - _set_value_at_path(vars(module)[path_parts[0]], path_parts[1:], value) - - -def _get_value_path(module: tp.Any, path: tp.Sequence[str]) -> tp.Any: - if len(path) == 0: - return module - else: - return _get_value_path(vars(module)[path[0]], path[1:]) - -def _build_module(moduledef: ModuleDef[M]) -> M: - index_module: tp.Dict[int, Module] = {} - module = _build_module_recursive(moduledef, index_module) +def _module_graph_set_key(module: M, name: str, value: tp.Any) -> M: + setattr(module, name, value) return module -def _build_module_recursive( - moduledef: tp.Union[ModuleDef[M], int], - index_module: tp.Dict[int, Module], -) -> M: - if isinstance(moduledef, int): - return index_module[moduledef] # type: ignore - - assert moduledef.index not in index_module +def _module_graph_has_key(module: Module, name: str) -> bool: + return hasattr(module, name) - # add a dummy module to the index to avoid infinite recursion - module = object.__new__(moduledef.type) - index_module[moduledef.index] = module - submodules = { - name: _build_module_recursive(submodule, index_module) - for name, submodule in moduledef.submodules - } +def _module_graph_all_keys(module: Module) -> tuple[str, ...]: + return tuple(name for name in vars(module).keys() if name != '_module__state') - vars(module).update(moduledef.static_fields) - vars(module).update(moduledef.variables) - vars(module).update(submodules) - vars(module)['_module__state'] = ModuleState.from_tuple( - moduledef.module_state - ) +def _module_graph_create_empty(cls: tp.Type[M]) -> M: + module = object.__new__(cls) + vars(module).update(_module__state=ModuleState()) return module -def _pop( - module: Module, - filters: tuple[filterlib.Filter, ...], -) -> tuple[State, ...]: - module_index: tp.Dict[ids.UUID, int] = {} - path_parts: PathParts = () - predicates = tuple(filterlib.to_predicate(filter) for filter in filters) - states = tuple({} for _ in predicates) - _pop_recursive(module, module_index, path_parts, states, predicates) - - return tuple(State(x) for x in states) - - -def _pop_recursive( - module: Module, - module_index: tp.Dict[ids.UUID, int], - path_parts: PathParts, - states: tuple[tp.Dict[Path, tp.Any]], - predicates: tuple[filterlib.Predicate, ...], -) -> None: - if module._module__state.id in module_index: - return - - for name, value in list(vars(module).items()): - if isinstance(value, Module): - _pop_recursive( - value, module_index, (*path_parts, name), states, predicates - ) - continue - elif not isinstance(value, Variable): - continue - elif value.is_empty: - continue - - path = '/'.join((*path_parts, name)) - for state, predicate in zip(states, predicates): - if predicate(path, value): - state[path] = value - # empty Variable attributes - setattr(module, name, value.as_empty()) - break - else: - # NOTE: should we raise an error here? - pass - - module_index[module._module__state.id] = len(module_index) - - -def _update_module_dynamic_state( - module: Module, - updates: tp.Union[State, tp.Sequence[State]], -) -> None: - if isinstance(updates, State): - new_states = (updates,) - else: - new_states = updates - - state: StateDict = {} - for new_state in new_states: - state.update(new_state.variables) - - for path, value in state.items(): - path_parts = path.split('/') - _set_value_at_path(module, path_parts, value) - - -# _StaticSubmoduleState = tp.Literal["new", "updated"] -class _StaticModuleStatus(enum.Enum): - NEW = enum.auto() - UPDATED = enum.auto() - - -def _update_module_static_state(module: M, updates: M) -> None: - cache: dict[Module, _StaticModuleStatus] = {} - _update_module_static_state_recursive( - module, updates, cache, _StaticModuleStatus.UPDATED, () - ) - - -def _update_module_static_state_recursive( - module: M, - updates: M, - cache: dict[Module, _StaticModuleStatus], - status: _StaticModuleStatus, - path: PathParts, -) -> None: - if type(module) != type(updates): - raise ValueError( - f'Expected an instance of {type(module).__name__}, got' - f' {type(updates).__name__}' - ) - - if updates in cache: - if cache[updates] != status: - str_path = '/'.join(path) - if status is _StaticModuleStatus.NEW: - raise ValueError( - f'Trying to add a new submodule at path {str_path!r} but a' - ' submodule with the same reference has been updated' - ) - else: - raise ValueError( - f'Trying to update a submodule at path {str_path!r} but a new' - ' submodule with the same reference has been added' - ) - return - - cache[updates] = status - - module_vars = vars(module) - for name, value in vars(updates).items(): - if isinstance(value, variableslib.Variable): - continue - elif isinstance(value, Module): - if name in module_vars: - _update_module_static_state_recursive( - module_vars[name], - value, - cache, - _StaticModuleStatus.UPDATED, - (*path, name), - ) - else: - if value in cache: - if cache[value] is not _StaticModuleStatus.NEW: - raise ValueError( - f'Trying to add a new submodule at path {name!r} but a' - ' submodule with the same reference has been updated' - ) - else: - cache[value] = _StaticModuleStatus.NEW - - setattr(module, name, value) - else: # static field - setattr(module, name, value) +def _module_graph_init(node: Module, items: tuple[tuple[str, tp.Any], ...]): + vars(node).update(items) +# ------------------------- +# utils +# ------------------------- def first_from(*args: tp.Optional[A]) -> A: """Return the first non-None argument.""" for arg in args: @@ -926,7 +531,7 @@ def first_from(*args: tp.Optional[A]) -> A: def merge( - state_and_def: tuple[tpe.Unpack[tuple[State, ...]], ModuleDef[M]] + state_and_def: tuple[tpe.Unpack[tuple[State, ...]], GraphDef[M]] ) -> M: - *states, moduledef = state_and_def - return moduledef.merge(*states) + *states, graphdef = state_and_def + return graphdef.merge(*states) diff --git a/flax/experimental/nnx/nnx/proxy_caller.py b/flax/experimental/nnx/nnx/proxy_caller.py new file mode 100644 index 0000000000..dd3803abe7 --- /dev/null +++ b/flax/experimental/nnx/nnx/proxy_caller.py @@ -0,0 +1,57 @@ +import dataclasses +import typing as tp + +import typing_extensions as tpe + +A = tp.TypeVar('A') + + +class _ProxyContext(tpe.Protocol): + def __call__(self, accessor: 'DelayedAccessor', /, *args, **kwargs) -> tp.Any: + ... + + +@dataclasses.dataclass +class CallableProxy: + _proxy_context: _ProxyContext + _proxy_callable: tp.Callable[..., tp.Any] + + def __call__(self, *args, **kwargs): + return self._proxy_context(self._proxy_callable, *args, **kwargs) + + def __getattr__(self, name) -> 'CallableProxy': + return CallableProxy( + self._proxy_context, getattr(self._proxy_callable, name) + ) + + def __getitem__(self, key) -> 'CallableProxy': + return CallableProxy(self._proxy_context, self._proxy_callable[key]) + + +def _identity(x): + return x + + +@dataclasses.dataclass +class DelayedAccessor: + accessor: tp.Callable[[tp.Any], tp.Any] = _identity + + def __call__(self, x): + return self.accessor(x) + + def __getattr__(self, name): + return DelayedAccessor(lambda x: getattr(x, name)) + + def __getitem__(self, key): + return DelayedAccessor(lambda x: x[key]) + + +class ApplyCaller(tp.Protocol, tp.Generic[A]): + def __getattr__(self, __name) -> 'ApplyCaller[A]': + ... + + def __getitem__(self, __name) -> 'ApplyCaller[A]': + ... + + def __call__(self, *args, **kwargs) -> tuple[tp.Any, A]: + ... diff --git a/flax/experimental/nnx/nnx/pytreelib.py b/flax/experimental/nnx/nnx/pytreelib.py index ee179a0bb0..92e71d661a 100644 --- a/flax/experimental/nnx/nnx/pytreelib.py +++ b/flax/experimental/nnx/nnx/pytreelib.py @@ -218,7 +218,7 @@ def _pytree__unflatten( return pytree @classmethod - def _to_flax_state_dict(cls, pytree: 'Pytree') -> tp.Dict[str, tp.Any]: + def _to_flax_state_dict(cls, pytree: 'Pytree') -> dict[str, tp.Any]: from flax import serialization state_dict = { @@ -232,7 +232,7 @@ def _to_flax_state_dict(cls, pytree: 'Pytree') -> tp.Dict[str, tp.Any]: def _from_flax_state_dict( cls, pytree: P, - state: tp.Dict[str, tp.Any], + state: dict[str, tp.Any], ) -> P: """Restore the state of a data class.""" from flax import serialization diff --git a/flax/experimental/nnx/nnx/state.py b/flax/experimental/nnx/nnx/state.py index c4840f35cc..98eb2b12f7 100644 --- a/flax/experimental/nnx/nnx/state.py +++ b/flax/experimental/nnx/nnx/state.py @@ -39,7 +39,7 @@ Leaf = tp.Any Path = str -StateDict = tp.Dict[Path, tp.Any] +StateDict = dict[Path, tp.Any] StateMapping = tp.Mapping[Path, tp.Any] diff --git a/flax/experimental/nnx/nnx/transforms.py b/flax/experimental/nnx/nnx/transforms.py index 3ee713634f..39bffbac46 100644 --- a/flax/experimental/nnx/nnx/transforms.py +++ b/flax/experimental/nnx/nnx/transforms.py @@ -45,12 +45,10 @@ tracers, variables, ) -from flax.experimental.nnx.nnx.module import ( +from flax.experimental.nnx.nnx.module import GraphDef, Module, ModuleMeta +from flax.experimental.nnx.nnx.proxy_caller import ( CallableProxy, DelayedAccessor, - Module, - ModuleDef, - ModuleMeta, ) from flax.experimental.nnx.nnx.state import State @@ -178,8 +176,8 @@ def _create_jit(*args, **kwargs) -> JIT[M]: class JittedFn(tp.Protocol, tp.Generic[M]): def __call__( - self, state_and_def: tuple[State | tuple[State, ...], ModuleDef[M]] - ) -> tuple[tuple[State | tuple[State, ...], ModuleDef[M]], tp.Any]: + self, state_and_def: tuple[State | tuple[State, ...], GraphDef[M]] + ) -> tuple[tuple[State | tuple[State, ...], GraphDef[M]], tp.Any]: ... @@ -188,12 +186,12 @@ def get_jitted_fn(_module_type: type[M], f, options: JITOptions) -> JittedFn[M]: @functools.partial(jax.jit, **jit_kwargs) def jitted_fn( - state_and_def: tuple[State | tuple[State, ...], ModuleDef[M]], + state_and_def: tuple[State | tuple[State, ...], GraphDef[M]], *args, **kwargs, ): _check_args(args) - states, moduledef = state_and_def + states, graphdef = state_and_def if isinstance(states, State): states = (states,) @@ -202,7 +200,7 @@ def jitted_fn( with tracers.nnx_trace(nnx_trace): if 'rngs' in kwargs: kwargs['rngs'] = rnglib.Rngs(kwargs['rngs']) - module = moduledef.merge(*states) + module = graphdef.merge(*states) out = f(module, *args, **kwargs) updates = module.split() @@ -467,7 +465,7 @@ def grad_apply(options: GradOptions, f, module: Module, *args, **kwargs): predicate = filterlib.to_predicate(options.wrt) - diff, nondiff, moduledef = module.split(predicate, ...) + diff, nondiff, graphdef = module.split(predicate, ...) transform = jax.value_and_grad if options.return_value else jax.grad @functools.partial( @@ -479,13 +477,13 @@ def grad_apply(options: GradOptions, f, module: Module, *args, **kwargs): reduce_axes=options.reduce_axes, ) def grad_fn(diff: State): - nonlocal moduledef + nonlocal graphdef with tracers.nnx_trace(tracers.get_top_trace(diff)): - module = moduledef.merge(diff, nondiff) + module = graphdef.merge(diff, nondiff) out = f(module, *args, **kwargs) - updates, moduledef = module.split() + updates, graphdef = module.split() if options.has_aux: loss, aux = out out = (loss, (updates, aux)) @@ -511,7 +509,7 @@ def grad_fn(diff: State): else: out, updates = out - module.update((updates, moduledef)) + module.update((updates, graphdef)) return out @@ -791,10 +789,10 @@ def scan_init( split_keys = None broadcast_keys = None - moduledef: tp.Optional[ModuleDef[M]] = None + graphdef: tp.Optional[GraphDef[M]] = None def _init_state(split_keys, broadcast_keys): - nonlocal moduledef + nonlocal graphdef if split_keys is not None: assert broadcast_keys is not None @@ -805,7 +803,7 @@ def _init_state(split_keys, broadcast_keys): # lift module filters = (*options.variable_axes.keys(), ...) - *states, moduledef = module.split(*filters) + *states, graphdef = module.split(*filters) return tuple(states) @@ -819,7 +817,7 @@ def _init_state(split_keys, broadcast_keys): ) *axes_states, carry_state = _init_state(split_keys, broadcast_keys) - moduledef = tp.cast(ModuleDef[M], moduledef) + graphdef = tp.cast(GraphDef[M], graphdef) # add additional axis name to Variable.sharding if spmd.PARTITION_NAME in options.scan_metadata: @@ -828,7 +826,7 @@ def _init_state(split_keys, broadcast_keys): for state, index in zip(axes_states, options.variable_axes.values()) ] - module = moduledef.merge(*axes_states, carry_state) + module = graphdef.merge(*axes_states, carry_state) return module @@ -845,7 +843,7 @@ def scan_apply( # split module state filters = (*options.variable_axes.keys(), ...) - *scan_states, carry_state, moduledef = module.split(*filters) + *scan_states, carry_state, graphdef = module.split(*filters) # transpose axes state scan_states = tuple( @@ -919,7 +917,7 @@ def scan_apply( split_keys = None broadcast_keys = None - moduledef_out: tp.Optional[ModuleDef[Module]] = None + moduledef_out: tp.Optional[GraphDef[Module]] = None def scan_fn( carry: tuple[State, tp.Any], @@ -963,7 +961,7 @@ def scan_fn( ] # merge module state - module = moduledef.merge(*scan_states, carry_state) + module = graphdef.merge(*scan_states, carry_state) output = f(module, carry_arg, *args, **kwargs) @@ -1208,26 +1206,26 @@ def remat_apply( ): _check_args(args) - state, moduledef = module.split() + state, graphdef = module.split() keys = rngs.fork() if rngs is not None else None def _remat_fn( state: State, keys: tp.Optional[dict[str, jax.Array]], *args, - ) -> tuple[tuple[State, ModuleDef[Module]], tp.Any]: + ) -> tuple[tuple[State, GraphDef[Module]], tp.Any]: kwargs = {} if keys is not None: kwargs['rngs'] = rnglib.Rngs(keys) - module = moduledef.merge(state) + module = graphdef.merge(state) out = f(module, *args, **kwargs) state_and_def = module.split() return state_and_def, out - state_and_def: tuple[State, ModuleDef[Module]] + state_and_def: tuple[State, GraphDef[Module]] state_and_def, out = jax.checkpoint( _remat_fn, prevent_cse=options.prevent_cse, @@ -1422,10 +1420,10 @@ def vmap_init( split_keys = None broadcast_keys = None - moduledef: tp.Optional[ModuleDef[M]] = None + graphdef: tp.Optional[GraphDef[M]] = None def _init_state(split_keys, broadcast_keys): - nonlocal moduledef + nonlocal graphdef if split_keys is not None: assert broadcast_keys is not None @@ -1436,7 +1434,7 @@ def _init_state(split_keys, broadcast_keys): # lift module filters = (*options.variable_axes.keys(), ...) - *states, moduledef = module.split(*filters) + *states, graphdef = module.split(*filters) return tuple(states) @@ -1450,7 +1448,7 @@ def _init_state(split_keys, broadcast_keys): ) *axes_states, carry_state = _init_state(split_keys, broadcast_keys) - moduledef = tp.cast(ModuleDef[M], moduledef) + graphdef = tp.cast(GraphDef[M], graphdef) # add additional axis name to Variable.sharding if spmd.PARTITION_NAME in options.vmap_metadata: @@ -1459,7 +1457,7 @@ def _init_state(split_keys, broadcast_keys): for state, index in zip(axes_states, options.variable_axes.values()) ] - module = moduledef.merge(*axes_states, carry_state) + module = graphdef.merge(*axes_states, carry_state) return module @@ -1474,7 +1472,7 @@ def vmap_apply( # split module state filters = (*options.variable_axes.keys(), ...) - *vectorized_states, broadcast_state, moduledef = module.split(*filters) + *vectorized_states, broadcast_state, graphdef = module.split(*filters) # infer length axis_sizes: tp.Set[int] = set() @@ -1529,7 +1527,7 @@ def vmap_apply( split_keys = None broadcast_keys = None - moduledef_out: tp.Optional[ModuleDef[Module]] = None + moduledef_out: tp.Optional[GraphDef[Module]] = None keys_axes = 0 states_axes = list(options.variable_axes.values()) @@ -1568,7 +1566,7 @@ def vmap_fn( ] # merge module state - module = moduledef.merge(*vectorized_states, broadcast_state) + module = graphdef.merge(*vectorized_states, broadcast_state) output = f(module, *args, **kwargs) diff --git a/flax/experimental/nnx/nnx/variables.py b/flax/experimental/nnx/nnx/variables.py index ef47b7f457..6977f01f50 100644 --- a/flax/experimental/nnx/nnx/variables.py +++ b/flax/experimental/nnx/nnx/variables.py @@ -51,7 +51,7 @@ AddAxisHook = tp.Callable[[V, AxisName, AxisIndex], V] RemoveAxisHook = tp.Callable[[V, AxisName, AxisIndex], V] -VariableTypeCache: tp.Dict[str, tp.Type['Variable[tp.Any]']] = {} +VariableTypeCache: dict[str, tp.Type['Variable[tp.Any]']] = {} class Empty: diff --git a/flax/experimental/nnx/tests/test_graph_utils.py b/flax/experimental/nnx/tests/test_graph_utils.py new file mode 100644 index 0000000000..46525bb90a --- /dev/null +++ b/flax/experimental/nnx/tests/test_graph_utils.py @@ -0,0 +1,103 @@ +import pytest + +from flax.experimental import nnx + + +class TestGraphUtils: + def test_flatten(self): + a = {'a': 1, 'b': nnx.Param(2)} + g = [a, 3, a, nnx.Param(4)] + + state, static = nnx.graph_utils.graph_flatten(g) + + state['0/b'] = 2 + state['3'] = 4 + + def test_unflatten(self): + a = {'a': 1, 'b': nnx.Param(2)} + g = [a, 3, a, nnx.Param(4)] + + state, static = nnx.graph_utils.graph_flatten(g) + g = static.merge(state) + + assert g[0] is g[2] + + def test_unflatten_empty(self): + a = {'a': 1, 'b': nnx.Param(2)} + g = [a, 3, a, nnx.Param(4)] + + state, static = nnx.graph_utils.graph_flatten(g) + g = static.merge(nnx.State({})) + + assert g[0] is g[2] + assert 'b' not in g[0] + assert g[3] is nnx.EMPTY + + def test_update_dynamic(self): + a = {'a': 1, 'b': nnx.Param(2)} + g = [a, 3, a, nnx.Param(4)] + + state, static = nnx.graph_utils.graph_flatten(g) + + state['0/b'] = 3 + nnx.graph_utils.graph_update_dynamic(g, state) + + assert g[0]['b'].value == 3 + assert g[2]['b'].value == 3 + + def test_update_static(self): + a = {'a': 1, 'b': nnx.Param(2)} + g = [a, 3, a, nnx.Param(4)] + + g2 = nnx.graph_utils.clone(g) + g2[0]['a'] = 5 + + nnx.graph_utils.graph_update_static(g, g2) + + assert g[0]['a'] == 5 + assert g[2]['a'] == 5 + + def test_update_static_inconsistent_types(self): + a = {'a': 1, 'b': nnx.Param(2)} + g = [a, 3, a, nnx.Param(4)] + g2 = [a, a, 3, nnx.Param(4)] + + with pytest.raises( + ValueError, match='Trying to update a node with a different type' + ): + nnx.graph_utils.graph_update_static(g, g2) + + def test_update_static_add_new(self): + a = {'a': 1, 'b': nnx.Param(2)} + b = [5, 6] + g = [a, 3, a, nnx.Param(4)] + g2 = [a, 3, a, nnx.Param(4), b] + + nnx.graph_utils.graph_update_static(g, g2) + + assert g[4][0] == 5 + assert g[4][1] == 6 + + def test_update_static_add_shared_error(self): + a = {'a': 1, 'b': nnx.Param(2)} + g = [a, 3, a, nnx.Param(4)] + g2 = [a, 3, a, nnx.Param(4), a] + + with pytest.raises(ValueError, match='Trying to add a new node at path'): + nnx.graph_utils.graph_update_static(g, g2) + + def test_module_list(self): + rngs = nnx.Rngs(0) + ls = [ + nnx.Linear(2, 2, rngs=rngs), + nnx.BatchNorm(2, rngs=rngs), + ] + + state, static = nnx.graph_utils.graph_flatten(ls) + + assert state['0/kernel'].shape == (2, 2) + assert state['0/bias'].shape == (2,) + assert state['1/scale'].shape == (2,) + assert state['1/bias'].shape == (2,) + assert state['1/mean'].shape == (2,) + assert state['1/var'].shape == (2,) diff --git a/flax/experimental/nnx/tests/test_helpers.py b/flax/experimental/nnx/tests/test_helpers.py index 732358c0a4..c0f2a57db9 100644 --- a/flax/experimental/nnx/tests/test_helpers.py +++ b/flax/experimental/nnx/tests/test_helpers.py @@ -23,10 +23,10 @@ class TestHelpers: def test_train_state(self): m = nnx.Dict(a=nnx.Param(1), b=nnx.BatchStat(2)) - params, batch_stats, moduledef = m.split(nnx.Param, nnx.BatchStat) + params, batch_stats, graphdef = m.split(nnx.Param, nnx.BatchStat) state = nnx.TrainState( - moduledef, + graphdef, params=params, tx=optax.sgd(1.0), batch_stats=nnx.TreeNode(batch_stats), @@ -53,10 +53,10 @@ def __call__(self, x: jax.Array, train: bool) -> jax.Array: return x module = Foo(rngs=nnx.Rngs(0)) - params, batch_stats, moduledef = module.split(nnx.Param, nnx.BatchStat) + params, batch_stats, graphdef = module.split(nnx.Param, nnx.BatchStat) state = nnx.TrainState( - moduledef, + graphdef, params=params, tx=optax.sgd(1.0), batch_stats=batch_stats, diff --git a/flax/experimental/nnx/tests/test_integration.py b/flax/experimental/nnx/tests/test_integration.py index 87c48bae85..2644c19f1b 100644 --- a/flax/experimental/nnx/tests/test_integration.py +++ b/flax/experimental/nnx/tests/test_integration.py @@ -94,8 +94,8 @@ def __call__(self, x): return x @jax.jit - def train_step(state: nnx.State, moduledef: nnx.ModuleDef[Model], x, y): - model = moduledef.merge(state) + def train_step(state: nnx.State, graphdef: nnx.GraphDef[Model], x, y): + model = graphdef.merge(state) @nnx.grad def loss_fn(model: Model): @@ -110,16 +110,16 @@ def loss_fn(model: Model): return model.split() - moduledef: nnx.ModuleDef[Model] - state, moduledef = Model(rngs=nnx.Rngs(0)).split() + graphdef: nnx.GraphDef[Model] + state, graphdef = Model(rngs=nnx.Rngs(0)).split() x = np.random.uniform(size=(4, 2)) y = np.random.uniform(size=(4, 2)) for _i in range(3): - state, moduledef = train_step(state, moduledef, x, y) + state, graphdef = train_step(state, graphdef, x, y) - model = moduledef.merge(state) + model = graphdef.merge(state) assert model.block1.linear.bias is not None assert model.block2.linear.bias is not None @@ -186,12 +186,12 @@ def __call__(self, x): y = model(x) assert model.count == 1 - params, counts, moduledef = model.split(nnx.Param, Count) + params, counts, graphdef = model.split(nnx.Param, Count) @jax.jit def train_step(params, counts, x, y): def loss_fn(params): - y_pred, (updates, _) = moduledef.apply(params, counts)(x) + y_pred, (updates, _) = graphdef.apply(params, counts)(x) loss = jax.numpy.mean((y_pred - y) ** 2) return loss, updates.extract(Count) @@ -204,7 +204,7 @@ def loss_fn(params): # execute the training step params, counts = train_step(params, counts, x, y) - model = moduledef.merge(params, counts) + model = graphdef.merge(params, counts) assert model.count == 2 def test_intermediates_example(self): @@ -241,9 +241,9 @@ def __call__(self, x): model = Linear(12, 2, rngs=nnx.Rngs(0)) - state, moduledef = model.split() + state, graphdef = model.split() - y, (state, _) = moduledef.apply(state)(jnp.ones((8, 12))) + y, (state, _) = graphdef.apply(state)(jnp.ones((8, 12))) intermediates, state = state.split(nnx.Intermediate, ...) diff --git a/flax/experimental/nnx/tests/test_module.py b/flax/experimental/nnx/tests/test_module.py index 2553ca9c5d..08f05d0c3e 100644 --- a/flax/experimental/nnx/tests/test_module.py +++ b/flax/experimental/nnx/tests/test_module.py @@ -65,13 +65,13 @@ def test_split_merge(self): m = nnx.Dict(a=nnx.Param(1)) @jax.jit - def g(state: nnx.State, moduledef: nnx.ModuleDef[nnx.Dict[int]]): - m = moduledef.merge(state) + def g(state: nnx.State, graphdef: nnx.GraphDef[nnx.Dict[int]]): + m = graphdef.merge(state) m.a = 2 return m.split() - state, moduledef = g(*m.split()) - m2 = moduledef.merge(state) + state, graphdef = g(*m.split()) + m2 = graphdef.merge(state) assert m2.a == 2 @@ -138,10 +138,10 @@ def __init__(self): m = Foo() - state, moduledef = m.split() + state, graphdef = m.split() assert len(state) == 1 - m2 = moduledef.merge(state) + m2 = graphdef.merge(state) assert m2 is m2.sub def test_deref_through_jit(self): @@ -151,16 +151,16 @@ def test_deref_through_jit(self): m = m0 = nnx.Dict({'a': nnx.Sequence([r1, r2]), 'b': r1}) @jax.jit - def f(state: nnx.State, moduledef: nnx.ModuleDef[nnx.Dict[Any]]): - m = moduledef.merge(state) + def f(state: nnx.State, graphdef: nnx.GraphDef[nnx.Dict[Any]]): + m = graphdef.merge(state) assert m['a'][0] is not m['b'] assert m['a'][1] is not m['b'] return m.split() - state, moduledef = f(*m.split()) - m = moduledef.merge(state) + state, graphdef = f(*m.split()) + m = graphdef.merge(state) assert m['a'][0] is not m['b'] assert m['a'][1] is not m['b'] @@ -174,13 +174,13 @@ def test_cross_barrier(self): m = nnx.Dict(a=nnx.Param(1)) @jax.jit - def g(state: nnx.State, moduledef: nnx.ModuleDef[nnx.Dict[int]]): - m = moduledef.merge(state) + def g(state: nnx.State, graphdef: nnx.GraphDef[nnx.Dict[int]]): + m = graphdef.merge(state) m.a += 1 return m.split() - state, moduledef = g(*m.split()) - m2 = moduledef.merge(state) + state, graphdef = g(*m.split()) + m2 = graphdef.merge(state) assert m2 is not m assert m.a == 1 assert m2.a == 2 @@ -226,7 +226,7 @@ def test_deref_number_of_fields(self): } ) - p, moduledef = m.split() + p, graphdef = m.split() assert len(p) == 4 assert len(jax.tree_util.tree_leaves(p)) == 4 @@ -339,7 +339,7 @@ def add_field(self): m2 = Foo() m2.add_field() - m1.update(m2.get_moduledef()) + m1.update(m2.get_graphdef()) assert m1.a == 1 @@ -426,9 +426,7 @@ def add_submodule(self): assert hasattr(m2, 'c') - with pytest.raises( - ValueError, match='Trying to add a new submodule at path' - ): + with pytest.raises(ValueError, match='Trying to add a new node at path'): m1.update(m2) def test_update_add_shared_error_new_first(self): @@ -452,9 +450,7 @@ def add_submodule(self): m2 = m2.clone() # clone to sort the fields - with pytest.raises( - ValueError, match='Trying to update a submodule at path' - ): + with pytest.raises(ValueError, match='Trying to add a new node at path'): m1.update(m2) def test_create_abstract(self): @@ -516,7 +512,7 @@ class Foo(nnx.Module): f=nnx.Variable(6), # test that we can pass in a node ) - state, moduledef = m.split() + state, graphdef = m.split() assert len(state) == 4 assert state.variables['b'] == nnx.TreeNode(2) @@ -585,13 +581,13 @@ def __call__(self, x, *, rngs: nnx.Rngs): rngs = nnx.Rngs(0) foo = Foo(c=1.0, rngs=rngs) - states, moduledef = foo.split() + states, graphdef = foo.split() assert isinstance(states, nnx.State) assert isinstance(states.variables['w'], nnx.Param) # assert isinstance(states["c"], jax.Array) - y, _updates = moduledef.apply(states)(x=2.0, rngs=nnx.Rngs(e=1)) + y, _updates = graphdef.apply(states)(x=2.0, rngs=nnx.Rngs(e=1)) assert isinstance(y, jax.Array) @@ -609,13 +605,13 @@ def __call__(self, x, *, rngs: nnx.Rngs): foo = Foo(c=1.0, rngs=nnx.Rngs(0)) - state, moduledef = foo.split() + state, graphdef = foo.split() - assert isinstance(moduledef, nnx.ModuleDef) + assert isinstance(graphdef, nnx.GraphDef) assert isinstance(state, nnx.State) assert isinstance(state.variables['w'], nnx.Param) assert isinstance(state.variables['c'], nnx.Variable) - y, (state, moduledef) = moduledef.apply(state)(x=2.0, rngs=nnx.Rngs(e=1)) + y, (state, graphdef) = graphdef.apply(state)(x=2.0, rngs=nnx.Rngs(e=1)) assert isinstance(y, jax.Array) diff --git a/flax/experimental/nnx/tests/test_partitioning.py b/flax/experimental/nnx/tests/test_partitioning.py index fb1862c080..099c02f5ea 100644 --- a/flax/experimental/nnx/tests/test_partitioning.py +++ b/flax/experimental/nnx/tests/test_partitioning.py @@ -27,7 +27,7 @@ def test_partition(self): c=100, ) - params, rest, moduledef = m.split(nnx.Param, ...) + params, rest, graphdef = m.split(nnx.Param, ...) assert len(params) == 2 assert len(rest) == 1 @@ -39,7 +39,7 @@ def test_partition(self): # check rest assert rest['a/1'] == m.a[1] - m2 = moduledef.merge(params, rest) + m2 = graphdef.merge(params, rest) assert m2.a[0] == m.a[0] assert m2.a[1] == m.a[1] @@ -110,7 +110,7 @@ def test_update_from_with_array_leaf(self): c=nnx.Variable(jax.numpy.array(100)), ) - state, moduledef = m.split() + state, graphdef = m.split() state = jax.tree_map(lambda x: x * 2, state) m.update(state) diff --git a/flax/experimental/nnx/tests/test_rngs.py b/flax/experimental/nnx/tests/test_rngs.py index c3fb5614aa..5a50f674e7 100644 --- a/flax/experimental/nnx/tests/test_rngs.py +++ b/flax/experimental/nnx/tests/test_rngs.py @@ -16,7 +16,6 @@ import jax import jax.numpy as jnp -import numpy as np import pytest from flax.experimental import nnx @@ -66,7 +65,7 @@ def test_rng_fork(self): key1 = rngs1.params() key2 = rngs2.params() - assert not np.equal(key1, key2).all() + assert not jnp.allclose(key1, key2) def test_rng_trace_level_constraints(self): rngs = nnx.Rngs(0) @@ -119,11 +118,11 @@ def test_partition_merge(self): key1 = rngs.dropout() key2 = rngs2.dropout() - assert not np.equal(key1, key2).all() + assert not jnp.allclose(key1, key2) rngs3 = nnx.Rngs(keys) key3 = rngs3.dropout() - assert np.equal(key2, key3).all() + assert jnp.allclose(key2, key3) def test_fork_broadcast(self): rngs = nnx.Rngs(params=0, dropout=1) diff --git a/flax/experimental/nnx/tests/test_spmd.py b/flax/experimental/nnx/tests/test_spmd.py index 96d9cd065c..c4cb67317a 100644 --- a/flax/experimental/nnx/tests/test_spmd.py +++ b/flax/experimental/nnx/tests/test_spmd.py @@ -62,9 +62,9 @@ def __init__(self): def __call__(self, x): return x @ self.w - params, moduledef = Foo().split() + params, graphdef = Foo().split() state = nnx.TrainState( - moduledef, + graphdef, params=params, tx=optax.adam(1e-3), )