From e9c635d761b36e0abbb35cfd5c8d4dadddf8f3c5 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Sun, 22 Dec 2024 18:04:38 -0500 Subject: [PATCH] [nnx] simpllify unflatten --- benchmarks/nnx_graph_overhead.py | 55 ++-- flax/nnx/bridge/variables.py | 11 +- flax/nnx/extract.py | 46 +++- flax/nnx/graph.py | 403 +++++++++++++++-------------- flax/nnx/statelib.py | 35 ++- flax/nnx/transforms/autodiff.py | 4 +- flax/nnx/transforms/compilation.py | 4 +- flax/nnx/transforms/iteration.py | 31 +-- tests/nnx/bridge/wrappers_test.py | 4 +- tests/nnx/graph_utils_test.py | 25 +- tests/nnx/module_test.py | 2 +- 11 files changed, 354 insertions(+), 266 deletions(-) diff --git a/benchmarks/nnx_graph_overhead.py b/benchmarks/nnx_graph_overhead.py index 88809f777..cffee7c4a 100644 --- a/benchmarks/nnx_graph_overhead.py +++ b/benchmarks/nnx_graph_overhead.py @@ -30,25 +30,44 @@ flags.DEFINE_integer('depth', 5, 'Depth of the model') - class Linear(nnx.Module): def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): - self.list = [ - nnx.Param(jax.random.uniform(rngs.params(), (din, dout))), - nnx.Param(jnp.zeros((dout,))), - ] - self.dict = { - 'w': nnx.Param(jax.random.uniform(rngs.params(), (din, dout))), - 'b': nnx.Param(jnp.zeros((dout,))), - } + self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout))) + self.b = nnx.Param(jnp.zeros((dout,))) + + def __call__(self, x): + return x @ self.w + self.b + + +class Block(nnx.Module): + def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): + self.linear = Linear(din, dout, rngs=rngs) + self.bn = nnx.BatchNorm(dout, rngs=rngs) + + def __call__(self, x): + return nnx.relu(self.bn(self.linear(x))) +class Count(nnx.Variable): + pass + class MLP(nnx.Module): - def __init__(self, depth, *, rngs: nnx.Rngs): + def __init__(self, din, dhidden, dout, depth, *, rngs: nnx.Rngs): + self.count = Count(jnp.array(0)) + self.linear_in = Block(din, dhidden, rngs=rngs) self.intermediates = [ - Linear(10, 10, rngs=rngs) for _ in range(depth) + Block(dhidden, dhidden, rngs=rngs) for _ in range(depth - 2) ] + self.linear_out = Block(dhidden, dout, rngs=rngs) + + def __call__(self, x): + self.count.value += 1 + x = nnx.relu(self.linear_in(x)) + for layer in self.intermediates: + x = nnx.relu(layer(x)) + x = self.linear_out(x) + return x def main(argv): @@ -63,14 +82,15 @@ def main(argv): X = np.linspace(0, 1, 100)[:, None] Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) - model = MLP(depth=depth, rngs=nnx.Rngs(0)) - tx = optax.sgd(1e-3) - optimizer = nnx.Optimizer(model, tx) - #------------------------------------------------------------ # NNX #------------------------------------------------------------ if mode in ['all', 'nnx']: + model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0)) + tx = optax.sgd(1e-3) + optimizer = nnx.Optimizer(model, tx) + t0 = time() + @nnx.jit def step_nnx(model: MLP, optimizer: nnx.Optimizer): pass @@ -93,6 +113,11 @@ def step_nnx(model: MLP, optimizer: nnx.Optimizer): #------------------------------------------------------------ if mode in ['all', 'jax']: + model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0)) + tx = optax.sgd(1e-3) + optimizer = nnx.Optimizer(model, tx) + t0 = time() + @jax.jit def step_jax(graphdef, state): return graphdef, state diff --git a/flax/nnx/bridge/variables.py b/flax/nnx/bridge/variables.py index 93531bb48..b1b78d168 100644 --- a/flax/nnx/bridge/variables.py +++ b/flax/nnx/bridge/variables.py @@ -18,10 +18,9 @@ import jax from flax import struct from flax.core import meta -from flax.nnx import spmd +from flax.nnx import graph, spmd from flax.nnx import traversals from flax.nnx import variablelib as variableslib -from flax.nnx.module import GraphDef import typing as tp @@ -192,9 +191,11 @@ def linen_vars_to_nnx_attrs(variables: tp.Mapping[str, Any]) -> dict[str, Any]: def nnx_attrs_to_linen_vars(nnx_attrs: dict) -> dict: linen_structured = {} for kp, v in traversals.flatten_mapping( - nnx_attrs, - is_leaf=lambda _, x: isinstance(x, variableslib.Variable | GraphDef), - ).items(): + nnx_attrs, + is_leaf=lambda _, x: isinstance( + x, variableslib.Variable | graph.NodeDef | graph.NodeRef + ), + ).items(): if isinstance(v, variableslib.Variable): col_name = variable_type_name(type(v)) else: diff --git a/flax/nnx/extract.py b/flax/nnx/extract.py index e5662e104..1572d609e 100644 --- a/flax/nnx/extract.py +++ b/flax/nnx/extract.py @@ -251,10 +251,13 @@ class GraphDefState(struct.PyTreeNode): graphdef: graph.GraphDef[tp.Any] = struct.field(pytree_node=False) state: graph.GraphState = struct.field(pytree_node=True) +S = tp.TypeVar( + 'S', bound=graph.GraphState | graph.GraphFlatState | list[tp.Any] +) -class NodeStates(struct.PyTreeNode): +class NodeStates(struct.PyTreeNode, tp.Generic[S]): _graphdef: graph.GraphDef[tp.Any] | None - states: tuple[graph.GraphState | graph.GraphFlatState, ...] + states: tuple[S, ...] metadata: tp.Any = struct.field(pytree_node=False) @property @@ -264,7 +267,7 @@ def graphdef(self) -> graph.GraphDef[tp.Any]: return self._graphdef @property - def state(self) -> graph.GraphState | graph.GraphFlatState: + def state(self) -> S: if len(self.states) != 1: raise ValueError( f'Expected exactly one GraphDefState, got {len(self.states)}' @@ -275,9 +278,9 @@ def state(self) -> graph.GraphState | graph.GraphFlatState: def from_split( cls, graphdef: graph.GraphDef[tp.Any], - state: graph.GraphState | graph.GraphFlatState, + state: S, /, - *states: graph.GraphState | graph.GraphFlatState, + *states: S, metadata: tp.Any = None, ): return cls(_graphdef=graphdef, states=(state, *states), metadata=metadata) @@ -285,8 +288,8 @@ def from_split( @classmethod def from_states( cls, - state: graph.GraphState | graph.GraphFlatState, - *states: graph.GraphState | graph.GraphFlatState, + state: S, + *states: S, ): return cls(_graphdef=None, states=(state, *states), metadata=None) @@ -319,6 +322,15 @@ def to_tree( ctxtag: str | None = None, check_aliasing: bool = True, ) -> tp.Any: + if prefix is Missing or prefix is None: + # fast path, no need for prefix broadcasting or consistent aliasing checks + with graph.split_context(ctxtag) as split_ctx: + return jax.tree.map( + lambda x: split_fn(split_ctx, (), prefix, x) + if map_non_graph_nodes or graph.is_graph_node(x) + else x, + tree, + ) leaf_prefixes = broadcast_prefix( prefix, tree, @@ -373,6 +385,16 @@ def from_tree( map_non_graph_nodes: bool = False, ctxtag: str | None = None, ) -> tp.Any: + if prefix is Missing or prefix is None: + # fast path, no need for prefix broadcasting or consistent aliasing checks + with graph.merge_context(ctxtag) as merge_ctx: + return jax.tree.map( + lambda x: merge_fn(merge_ctx, (), prefix, x) + if map_non_graph_nodes or is_node_leaf(x) + else x, + tree, + is_leaf=is_leaf, + ) leaf_prefixes = broadcast_prefix( prefix, tree, @@ -387,13 +409,9 @@ def from_tree( with graph.merge_context(ctxtag) as merge_ctx: for (keypath, leaf), leaf_prefix in zip(leaf_keys, leaf_prefixes): - if is_node_leaf(leaf): - leaf_out = merge_fn(merge_ctx, keypath, leaf_prefix, leaf) - leaves_out.append(leaf_out) - else: - if map_non_graph_nodes: - leaf = merge_fn(merge_ctx, keypath, leaf_prefix, leaf) - leaves_out.append(leaf) + if map_non_graph_nodes or is_node_leaf(leaf): + leaf = merge_fn(merge_ctx, keypath, leaf_prefix, leaf) + leaves_out.append(leaf) pytree_out = jax.tree.unflatten(treedef, leaves_out) return pytree_out diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index c18a710b3..44750d630 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -14,6 +14,7 @@ from __future__ import annotations +from collections import deque import contextlib import dataclasses import functools @@ -290,24 +291,6 @@ def __treescope_repr__(self, path, subtree_renderer): jax.tree_util.register_static(VariableDef) -@dataclasses.dataclass(frozen=True, slots=True) -class SubGraphAttribute: - key: Key - value: NodeDef[tp.Any] | NodeRef[tp.Any] - - -@dataclasses.dataclass(frozen=True, slots=True) -class StaticAttribute: - key: Key - value: tp.Any - - -@dataclasses.dataclass(frozen=True, slots=True) -class LeafAttribute: - key: Key - value: VariableDef | NodeRef[tp.Any] - - @dataclasses.dataclass(frozen=True, repr=False, slots=True) class NodeDef(GraphDef[Node], reprlib.Representable): """A dataclass that denotes the tree structure of a @@ -316,28 +299,17 @@ class NodeDef(GraphDef[Node], reprlib.Representable): type: tp.Type[Node] index: int - attributes: tuple[SubGraphAttribute | StaticAttribute | LeafAttribute, ...] + attributes: tuple[ + tuple[ + Key, NodeDef[tp.Any] | VariableDef | NodeRef[tp.Any] | Static[tp.Any] + ], + ..., + ] metadata: tp.Any index_mapping: HashableMapping[Index, Index] | None - @classmethod - def create( - cls, - type: tp.Type[Node], - index: int, - attributes: tuple[SubGraphAttribute | StaticAttribute | LeafAttribute, ...], - metadata: tp.Any, - index_mapping: tp.Mapping[Index, Index] | None, - ): - return cls( - type=type, - index=index, - attributes=attributes, - metadata=metadata, - index_mapping=HashableMapping(index_mapping) - if index_mapping is not None - else None, - ) + def replace(self, **kwargs): + return dataclasses.replace(self, **kwargs) def __nnx_repr__(self): yield reprlib.Object(type=type(self)) @@ -387,12 +359,31 @@ def _apply( jax.tree_util.register_static(NodeDef) -PureState = tuple[GraphDef[A], GraphState] - +PureState = tuple[GraphDef[Node], GraphState] +@tp.overload +def flatten( + node: Node, + /, + *, + ref_index: RefMap[tp.Any, Index] | None = None, + with_paths: tp.Literal[True] = True, +) -> tuple[GraphDef[Node], FlatState[VariableState[tp.Any]]]: ... +@tp.overload +def flatten( + node: Node, + /, + *, + ref_index: RefMap[tp.Any, Index] | None = None, + with_paths: tp.Literal[False], +) -> tuple[GraphDef[Node], list[tp.Any]]: ... def flatten( - node: Node, /, ref_index: RefMap[tp.Any, Index] | None = None -) -> tuple[GraphDef[Node], FlatState[tp.Any]]: + node: Node, + /, + *, + ref_index: RefMap[tp.Any, Index] | None = None, + with_paths: bool = True, +) -> tuple[GraphDef[Node], FlatState[VariableState[tp.Any]] | list[tp.Any]]: """Flattens a graph node into a (graphdef, state) pair. Args: @@ -400,19 +391,28 @@ def flatten( ref_index: A mapping from nodes to indexes, defaults to None. If not provided, a new empty dictionary is created. This argument can be used to flatten a sequence of graph nodes that share references. + with_paths: A boolean that indicates whether to return a FlatState object that includes + the paths to VariableState objects, or just a list of the Variable's inner values. """ if ref_index is None: ref_index = RefMap() - flat_state: list[tuple[PathParts, StateLeaf]] = [] - graphdef = _graph_flatten((), ref_index, flat_state, node) - return graphdef, FlatState(flat_state) + leaves: list[StateLeaf] = [] + path: list[Key] | None = [] if with_paths else None + paths: list[PathParts] | None = [] if with_paths else None + graphdef = _graph_flatten(node, path, ref_index, leaves, paths) + + if paths is not None: + return graphdef, FlatState.from_sorted_keys_values(paths, leaves) + else: + return graphdef, leaves def _graph_flatten( - path: PathParts, - ref_index: RefMap[tp.Any, Index], - flat_state: list[tuple[PathParts, StateLeaf]], node: Node, + path: list[Key] | None, + ref_index: RefMap[tp.Any, Index], + leaves: list[StateLeaf], + paths: list[PathParts] | None, ) -> NodeDef[Node] | NodeRef: if not is_node(node): raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.') @@ -429,36 +429,50 @@ def _graph_flatten( else: index = -1 - attributes: list[SubGraphAttribute | StaticAttribute | LeafAttribute] = [] + attributes: list[ + tuple[Key, Static[tp.Any] | NodeDef[tp.Any] | VariableDef | NodeRef[tp.Any]] + ] = [] values, metadata = node_impl.flatten(node) for key, value in values: + if path is not None: + path.append(key) if is_node(value): - nodedef = _graph_flatten((*path, key), ref_index, flat_state, value) - # subgraphs.append((key, nodedef)) - attributes.append(SubGraphAttribute(key, nodedef)) + nodedef = _graph_flatten(value, path, ref_index, leaves, paths) + attributes.append((key, nodedef)) elif isinstance(value, Variable): if value in ref_index: - attributes.append( - LeafAttribute(key, NodeRef(type(value), ref_index[value])) - ) + attributes.append((key, NodeRef(type(value), ref_index[value]))) else: - flat_state.append(((*path, key), value.to_state())) + if path is None: + leaf = value.raw_value + else: + leaf = value.to_state() + leaves.append(leaf) + if path is not None: + assert paths is not None + paths.append(tuple(path)) variable_index = ref_index[value] = len(ref_index) variabledef = VariableDef( type(value), variable_index, HashableMapping(value._var_metadata) ) - attributes.append(LeafAttribute(key, variabledef)) + attributes.append((key, variabledef)) else: if isinstance(value, (jax.Array, np.ndarray)): - path_str = '/'.join(map(str, (*path, key))) - raise ValueError( + if path is not None: + path_str = '/'.join(map(str, path)) + raise ValueError( f'Arrays leaves are not supported, at {path_str!r}: {value}' - ) + ) + else: + raise ValueError(f'Arrays leaves are not supported, found {value}') # static_fields.append((key, value)) - attributes.append(StaticAttribute(key, value)) + attributes.append((key, Static(value))) + + if path is not None: + path.pop() - nodedef = NodeDef.create( + nodedef = NodeDef( type=node_impl.type, index=index, attributes=tuple(attributes), @@ -467,10 +481,28 @@ def _graph_flatten( ) return nodedef +def _get_sorted_leaves( + xs: tp.Mapping[tp.Any, tp.Any], +) -> deque[tp.Any]: + if not isinstance(xs, tp.Mapping): # type: ignore + raise TypeError(f'expected Mapping; got {type(xs).__qualname__}') + leaves = deque() + + def _flatten(xs): + if not isinstance(xs, tp.Mapping): + leaves.append(xs) + else: + for _, value in sorted(xs.items()): + _flatten(value) + + _flatten(xs) + return leaves def unflatten( graphdef: GraphDef[Node], - state: tp.Mapping[KeyT, StateLeaf | tp.Mapping[Key, tp.Any]], + state: State[KeyT, tp.Any | dict[KeyT, tp.Any]] + | FlatState[tp.Any] + | list[tp.Any], /, *, index_ref: dict[Index, tp.Any] | None = None, @@ -491,17 +523,27 @@ def unflatten( existing graph nodes are mutated to have the new content/topology specified by the graphdef. """ - if isinstance(state, State): - state = state.raw_mapping # type: ignore + if isinstance(state, (State, dict)): + leaves = _get_sorted_leaves(state) + elif isinstance(state, FlatState): + leaves = deque(state.get_values()) + elif isinstance(state, list): # type: ignore + leaves = deque(state) + else: + raise ValueError(f'Unsupported state type: {type(state)}') if index_ref is None: index_ref = {} assert isinstance(graphdef, (NodeDef, NodeRef)) - node = _graph_unflatten(graphdef, state, index_ref, index_ref_cache) + node = _graph_unflatten(graphdef, leaves, index_ref, index_ref_cache) + if leaves: + raise ValueError( + f'Incorrect number of leaves: got an extra {len(leaves)} leaves in the state' + ) return node def _graph_unflatten( nodedef: NodeDef[Node] | NodeRef[Node], - state: tp.Mapping[KeyT, StateLeaf | tp.Mapping[Key, tp.Any]], + leaves: deque[tp.Any], index_ref: dict[Index, tp.Any], index_ref_cache: dict[Index, tp.Any] | None, ) -> Node: @@ -531,113 +573,50 @@ def _graph_unflatten( def _get_children(): children: list[tuple[Key, NodeLeaf | Node]] = [] - state_keys: set = set(state.keys()) - - # for every key in attributes there are 6 possible cases: - # - (2) the key can either be present in the state or not - # - (3) the key can be a subgraph, a leaf, or a static attribute - for attribute in nodedef.attributes: - key = attribute.key - if key not in state: - # if key is not present create an empty types - if type(attribute) is StaticAttribute: - children.append((key, attribute.value)) - elif type(attribute) is SubGraphAttribute: - # if the key is a subgraph we create an empty node - subgraphdef = attribute.value - assert not isinstance(subgraphdef, VariableDef) - if isinstance(subgraphdef, NodeRef): - # subgraph exists, take it from the cache - children.append((key, index_ref[subgraphdef.index])) - else: - # create a node from an empty state, reasoning: - # * its a node with no state - # * its a node with state but only through references of already - # created nodes - substate = {} - subnode = _graph_unflatten( - subgraphdef, substate, index_ref, index_ref_cache - ) - children.append((key, subnode)) - elif type(attribute) is LeafAttribute: - variabledef = attribute.value - if variabledef.index in index_ref: - # variable exists, take it from the cache - children.append((key, index_ref[variabledef.index])) - else: - # key for a variable is missing, raise an error - raise ValueError( - f'Expected key {key!r} in state while building node of type ' - f'{nodedef.type.__name__}.' - ) - else: - raise RuntimeError(f'Unknown static field: {key!r}') - else: - state_keys.remove(key) - value = state[key] - # if key in nodedef.static_fields: - if type(attribute) is StaticAttribute: - raise ValueError( - f'Got state for static field {key!r}, this is not supported.' - ) - elif type(attribute) is SubGraphAttribute: - if is_state_leaf(value): + + for key, value in nodedef.attributes: + if type(value) is Static: + children.append((key, value.value)) + elif type(value) is NodeRef: + children.append((key, index_ref[value.index])) + elif type(value) is NodeDef: + # if the key is a subgraph we create an empty node + subgraphdef = value + subnode = _graph_unflatten( + subgraphdef, leaves, index_ref, index_ref_cache + ) + children.append((key, subnode)) + elif type(value) is VariableDef: + variabledef = value + if not leaves: + raise ValueError('Not enough leaves to unflatten the graph') + # its a unseen variable, create a new one + value = leaves.popleft() + # when idxmap is present, check if the Varable exists there + # and update existing variables if it does + if index_ref_cache is not None and variabledef.index in index_ref_cache: + # if variable exists, update it + variable = index_ref_cache[variabledef.index] + if not isinstance(variable, Variable): raise ValueError( - f'Expected value of type {attribute.value} for ' - f'{key!r}, but got {value!r}' + f'Expected a Variable type for {key!r}, but got {type(variable)}.' ) - assert isinstance(value, dict) - subgraphdef = attribute.value - - if isinstance(subgraphdef, NodeRef): - children.append((key, index_ref[subgraphdef.index])) + if isinstance(value, VariableState): + variable.update_from_state(value) else: - subnode = _graph_unflatten( - subgraphdef, value, index_ref, index_ref_cache - ) - children.append((key, subnode)) - - elif type(attribute) is LeafAttribute: - variabledef = attribute.value - - if variabledef.index in index_ref: - # add an existing variable - assert isinstance(variabledef, NodeRef) - children.append((key, index_ref[variabledef.index])) + variable.raw_value = value + else: # variabledef.index not in index_ref_cache + # variable reference does not exist outside, create a new one + if isinstance(value, VariableState): + variable = value.to_variable() else: - # its a unseen variable, create a new one - assert isinstance(variabledef, VariableDef) - # when idxmap is present, check if the Varable exists there - # and update existing variables if it does - if ( - index_ref_cache is not None - and variabledef.index in index_ref_cache - ): - # if variable exists, update it - variable = index_ref_cache[variabledef.index] - if not isinstance(variable, Variable): - raise ValueError( - f'Expected a Variable type for {key!r}, but got {type(variable)}.' - ) - if isinstance(value, VariableState): - variable.update_from_state(value) - else: - variable.raw_value = value - else: # if it doesn't, create a new variable - if isinstance(value, VariableState): - variable = value.to_variable() - else: - variable = variabledef.type.from_metadata( - value, variabledef.metadata - ) - children.append((key, variable)) - index_ref[variabledef.index] = variable - else: - raise RuntimeError(f'Unknown key: {key!r}, this is a bug.') - - # NOTE: we could allw adding new StateLeafs here - if state_keys: - raise ValueError(f'Unknown keys: {state_keys}') + variable = variabledef.type.from_metadata( + value, variabledef.metadata + ) + children.append((key, variable)) + index_ref[variabledef.index] = variable + else: + raise RuntimeError(f'Unknown static field: {key!r}') return children @@ -814,7 +793,7 @@ def split( ctx = ( current_update_context(self.ctxtag) if self.ctxtag is not None else None ) - graphdef, flat_state = flatten(node, self.ref_index) + graphdef, flat_state = flatten(node, ref_index=self.ref_index) flat_states = _split_state(flat_state, filters) states = tuple( State.from_flat_path(flat_state) for flat_state in flat_states @@ -822,12 +801,16 @@ def split( if ctx is not None: if ctx.index_ref is not None and isinstance(graphdef, NodeDef): index_to_index = compose_mapping(ctx.index_ref, self.ref_index) - graphdef = dataclasses.replace( - graphdef, index_mapping=HashableMapping(index_to_index, copy=False) + graphdef = graphdef.replace( + index_mapping=HashableMapping(index_to_index, copy=False) ) return graphdef, *states + @tp.overload + def flatten( + self, graph_node: A, /, *, with_paths: tp.Literal[False] + ) -> tuple[GraphDef[A], list[tp.Any]]: ... @tp.overload def flatten( self, graph_node: A, / @@ -850,24 +833,36 @@ def flatten( tpe.Unpack[tuple[FlatState[VariableState[tp.Any]], ...]], ]: ... def flatten( - self, node: A, *filters: filterlib.Filter + self, node: A, *filters: filterlib.Filter, with_paths: bool = True ) -> tuple[ - GraphDef[A], tpe.Unpack[tuple[FlatState[VariableState[tp.Any]], ...]] + GraphDef[A], + FlatState[VariableState[tp.Any]] | list[tp.Any], + tpe.Unpack[tuple[FlatState[VariableState[tp.Any]], ...]], ]: ctx = ( current_update_context(self.ctxtag) if self.ctxtag is not None else None ) - graphdef, flat_state = flatten(node, self.ref_index) - flat_states = _split_state(flat_state, filters) + if with_paths: + graphdef, flat_state = flatten( + node, ref_index=self.ref_index, with_paths=True + ) + flat_states = _split_state(flat_state, filters) + else: + if filters: + raise ValueError('Cannot use filters with with_paths=False') + graphdef, flat_state = flatten( + node, ref_index=self.ref_index, with_paths=False + ) + flat_states = (flat_state,) if ctx is not None: if ctx.index_ref is not None and isinstance(graphdef, NodeDef): index_to_index = compose_mapping(ctx.index_ref, self.ref_index) - graphdef = dataclasses.replace( - graphdef, index_mapping=HashableMapping(index_to_index, copy=False) + graphdef = graphdef.replace( + index_mapping=HashableMapping(index_to_index, copy=False) ) - return graphdef, *flat_states + return graphdef, *flat_states # type: ignore @contextlib.contextmanager @@ -924,7 +919,7 @@ def merge( def unflatten( self, graphdef: GraphDef[A], - flat_state: GraphFlatState, + flat_state: GraphFlatState | list[tp.Any], /, *flat_states: GraphFlatState, ) -> A: @@ -945,7 +940,15 @@ def unflatten( # inner merge (2) index_ref_cache = None - state = FlatState.merge(flat_state, *flat_states).to_nested_state() + if type(flat_state) is list: + if flat_states: + raise ValueError( + 'Cannot use multiple flat_states when flat_state is a list, ' + f'got flat_state: {flat_state!r}, flat_states: {flat_states!r}' + ) + state = flat_state + else: + state = FlatState.merge(flat_state, *flat_states) node = unflatten( graphdef, state, @@ -1081,15 +1084,15 @@ def split( filters are passed, a single :class:`State` is returned. """ ref_index: RefMap[tp.Any, Index] = RefMap() - graphdef, flat_state = flatten(node, ref_index) + graphdef, flat_state = flatten(node, ref_index=ref_index) states = tuple( State.from_flat_path(flat_state) for flat_state in _split_state(flat_state, filters) ) if self.index_ref is not None and isinstance(graphdef, NodeDef): index_to_index = compose_mapping(self.index_ref, ref_index) - graphdef = dataclasses.replace( - graphdef, index_mapping=HashableMapping(index_to_index, copy=False) + graphdef = graphdef.replace( + index_mapping=HashableMapping(index_to_index, copy=False) ) self.flatten_end(ref_index) @@ -1872,21 +1875,15 @@ class Static(tp.Generic[A]): # --------------------------------------------------------- class GenericPytree: ... +from jax._src.tree_util import _registry as JAX_PYTREE_REGISTRY def is_pytree_node(x: tp.Any) -> bool: - t = type(x) - if t in PYTREE_REGISTRY: + if type(x) in JAX_PYTREE_REGISTRY: return True - elif t in GRAPH_REGISTRY: - return False - # known non-pytree types - elif isinstance(x, Variable): - return False - # known pytree types - elif type(x) is VariableState or type(x) is State: + elif isinstance(x, tuple): return True else: - return not jax.tree_util.all_leaves((x,)) + return False def _key_path_to_key(key: tp.Any) -> Key: @@ -1905,20 +1902,28 @@ def _key_path_to_key(key: tp.Any) -> Key: else: return str(key) +class IndexesPytreeDef(tp.NamedTuple): + key_index: HashableMapping[Key, int] + treedef: jax.tree_util.PyTreeDef def _flatten_pytree(pytree: tp.Any): leaves, treedef = jax.tree_util.tree_flatten_with_path( pytree, is_leaf=lambda x: x is not pytree ) - nodes = tuple((_key_path_to_key(path[0]), value) for path, value in leaves) - - return nodes, treedef + nodes = [(_key_path_to_key(path[0]), value) for path, value in leaves] + key_index = HashableMapping( + {key: i for i, (key, _) in enumerate(nodes)}, copy=False + ) + nodes.sort() # sort by key + return nodes, IndexesPytreeDef(key_index, treedef) def _unflatten_pytree( - nodes: tuple[tuple[Key, tp.Any], ...], treedef: jax.tree_util.PyTreeDef + nodes: tuple[tuple[Key, tp.Any], ...], metadata: IndexesPytreeDef ): - pytree = treedef.unflatten(value for _, value in nodes) + # sort to original order + sorted_nodes = sorted(nodes, key=lambda x: metadata.key_index[x[0]]) + pytree = metadata.treedef.unflatten(value for _, value in sorted_nodes) return pytree diff --git a/flax/nnx/statelib.py b/flax/nnx/statelib.py index 1c1b1b512..16cc2cf9e 100644 --- a/flax/nnx/statelib.py +++ b/flax/nnx/statelib.py @@ -55,17 +55,36 @@ def __treescope_repr__(self, path, subtree_renderer): return subtree_renderer(children, path=path) class FlatState(reprlib.SequenceReprMixin[tuple[PathParts, V]]): + __slots__ = ('_keys', '_values') + _keys: tuple[PathParts, ...] _values: list[V] - def __init__(self, items: tp.Iterable[tuple[PathParts, V]]): + def __init__(self, items: tp.Iterable[tuple[PathParts, V]], /, *, sort: bool): keys, values = [], [] + if sort: + items = sorted(items) for key, value in items: keys.append(key) values.append(value) self._keys = tuple(keys) self._values = values + @staticmethod + def from_sorted_keys_values( + keys: list[PathParts], values: list[V], / + ) -> FlatState[V]: + flat_state = object.__new__(FlatState) + flat_state._keys = tuple(keys) + flat_state._values = values + return flat_state + + def get_keys(self) -> tp.Tuple[PathParts, ...]: + return self._keys + + def get_values(self) -> tp.List[V]: + return self._values + @tp.overload def __getitem__(self, index: int) -> tuple[PathParts, V]: ... @tp.overload @@ -75,7 +94,7 @@ def __getitem__( ) -> tuple[PathParts, V] | FlatState[V]: if isinstance(index, int): return self._keys[index], self._values[index] - return FlatState(zip(self._keys[index], self._values[index])) + return FlatState(zip(self._keys[index], self._values[index]), sort=False) def __len__(self) -> int: return len(self._keys) @@ -158,9 +177,15 @@ def merge( /, *flat_states: tp.Iterable[tuple[PathParts, V]], ) -> FlatState[V]: + if not flat_states: + if isinstance(flat_state, FlatState): + return flat_state + return FlatState(flat_state, sort=True) flat_states = (flat_state, *flat_states) - return FlatState(elem for flat_state in flat_states for elem in flat_state) + return FlatState( + (elem for flat_state in flat_states for elem in flat_state), sort=True + ) def _flat_state_pytree_flatten(x: FlatState[V]): @@ -282,7 +307,7 @@ def map(self, f: tp.Callable[[tuple, V], V]) -> State[K, V]: return State.from_flat_path(result) def flat_state(self) -> FlatState[V]: - return FlatState(traversals.flatten_to_sequence(self._mapping)) + return FlatState(traversals.flatten_to_sequence(self._mapping), sort=True) @classmethod def from_flat_path( @@ -563,7 +588,7 @@ def _split_state( # if we didn't break, set leaf to last state flat_states[-1].append((path, value)) # type: ignore[index] # mypy is wrong here? - return tuple(FlatState(flat_state) for flat_state in flat_states) + return tuple(FlatState(flat_state, sort=False) for flat_state in flat_states) def create_path_filters(state: State): diff --git a/flax/nnx/transforms/autodiff.py b/flax/nnx/transforms/autodiff.py index 5ef0d183b..3af7f3913 100644 --- a/flax/nnx/transforms/autodiff.py +++ b/flax/nnx/transforms/autodiff.py @@ -431,7 +431,7 @@ def _extract_index_mappings(x, *, index_mappings: deque[graph.HashableMapping]): if isinstance(x, graph.NodeDef): assert x.index_mapping is not None index_mappings.append(x.index_mapping) - return dataclasses.replace(x, index_mapping=None) + return x.replace(index_mapping=None) return x @dataclasses.dataclass(eq=False) @@ -665,7 +665,7 @@ def __call__( def _insert_index_mappings(x): if isinstance(x, graph.NodeDef): index_mapping: graph.HashableMapping = index_mappings.popleft() - return dataclasses.replace(x, index_mapping=index_mapping) + return x.replace(index_mapping=index_mapping) return x pure_args_out, pure_out = jax.tree_util.tree_map( diff --git a/flax/nnx/transforms/compilation.py b/flax/nnx/transforms/compilation.py index d3420dd43..4763ea849 100644 --- a/flax/nnx/transforms/compilation.py +++ b/flax/nnx/transforms/compilation.py @@ -93,13 +93,13 @@ def _jit_split_fn(ctx: graph.SplitContext, path, prefix, x): return extract.NodeStates.from_split( *ctx.flatten(x, *prefix.filters), metadata=prefix ) - return extract.NodeStates.from_split(*ctx.flatten(x)) + return extract.NodeStates.from_split(*ctx.flatten(x, with_paths=False)) def _jit_merge_fn(ctx: graph.MergeContext, path, prefix, leaf) -> tp.Any: if not isinstance(leaf, extract.NodeStates): raise ValueError(f'Expected TreeNode, got {type(leaf)} at path {path}') - return ctx.unflatten(leaf.graphdef, *leaf.states) # type: ignore + return ctx.unflatten(leaf.graphdef, *leaf.states) @dataclasses.dataclass(eq=False) diff --git a/flax/nnx/transforms/iteration.py b/flax/nnx/transforms/iteration.py index 994e58286..b2eec6510 100644 --- a/flax/nnx/transforms/iteration.py +++ b/flax/nnx/transforms/iteration.py @@ -660,9 +660,7 @@ def extract_index_mappings(x): index_mapping = x._graphdef.index_mapping assert index_mapping is not None carry_index_mappings.append(index_mapping) - x = x.replace( - _graphdef=dataclasses.replace(x._graphdef, index_mapping=None) - ) + x = x.replace(_graphdef=x._graphdef.replace(index_mapping=None)) return x pure_carry_arg_out = jax.tree.map( @@ -683,9 +681,7 @@ def insert_index_mappings(x): x._graphdef, graph.NodeDef ): index_mapping = carry_index_mappings.popleft() - x = x.replace( - _graphdef=dataclasses.replace(x._graphdef, index_mapping=index_mapping) - ) + x = x.replace(_graphdef=x._graphdef.replace(index_mapping=index_mapping)) return x pure_carry_arg_out = jax.tree.map( @@ -1342,22 +1338,21 @@ def per_node_def(nd: graph.NodeDef | graph.NodeRef): if isinstance(nd, graph.NodeRef): return - for attribute in nd.attributes: - if type(attribute) is graph.SubGraphAttribute: - per_node_def(attribute.value) + for _, value in nd.attributes: + if type(value) is graph.NodeDef: + per_node_def(value) elif ( - type(attribute) is graph.LeafAttribute - and isinstance(attribute.value, (graph.VariableDef, graph.NodeRef)) - and attribute.value.index >= 0 + isinstance(value, (graph.VariableDef, graph.NodeRef)) + and value.index >= 0 ): - global_index_mapping[attribute.value.index] = attribute.value.index + global_index_mapping[value.index] = value.index return per_node_def(ns._graphdef) return dataclasses.replace( ns, - _graphdef=dataclasses.replace( - ns._graphdef, index_mapping=graph.HashableMapping(global_index_mapping) + _graphdef=ns._graphdef.replace( + index_mapping=graph.HashableMapping(global_index_mapping) ), ) @@ -1373,9 +1368,9 @@ def per_node_state(ns: extract.NodeStates | tp.Any): ): return ns assert isinstance(ns._graphdef, graph.NodeDef) - return dataclasses.replace(ns, _graphdef=dataclasses.replace( - ns._graphdef, index_mapping=None - )) + return dataclasses.replace( + ns, _graphdef=ns._graphdef.replace(index_mapping=None) + ) return jax.tree.map(per_node_state, tree, is_leaf=lambda x: isinstance(x, extract.NodeStates)) diff --git a/tests/nnx/bridge/wrappers_test.py b/tests/nnx/bridge/wrappers_test.py index 5b65603a2..b353dd492 100644 --- a/tests/nnx/bridge/wrappers_test.py +++ b/tests/nnx/bridge/wrappers_test.py @@ -228,7 +228,9 @@ def test_nnx_to_linen(self): assert y.shape == (1, 64) np.testing.assert_allclose(y, x @ variables['params']['kernel']) assert 'nnx' in variables - assert isinstance(variables['nnx']['graphdef'], nnx.GraphDef) + assert isinstance( + variables['nnx']['graphdef'], nnx.graph.NodeDef | nnx.graph.NodeRef + ) def test_nnx_to_linen_multiple_rngs(self): class NNXInner(nnx.Module): diff --git a/tests/nnx/graph_utils_test.py b/tests/nnx/graph_utils_test.py index 85c4f2a4c..daad0873a 100644 --- a/tests/nnx/graph_utils_test.py +++ b/tests/nnx/graph_utils_test.py @@ -65,10 +65,25 @@ def test_flatten(self): refmap = nnx.graph.RefMap() graphdef, flat_state = nnx.graph.flatten(g, ref_index=refmap) - state = flat_state.to_nested_state() - state[0]['b'].raw_value = 2 - state[3].raw_value = 4 + assert flat_state[0][1].value == 2 + assert flat_state[1][1].value == 4 + + assert len(refmap) == 2 + assert a['b'] in refmap + assert g[3] in refmap + + def test_flatten_no_paths(self): + a = {'a': 1, 'b': nnx.Param(2)} + g = [a, 3, a, nnx.Param(4)] + + refmap = nnx.graph.RefMap() + graphdef, flat_state = nnx.graph.flatten( + g, ref_index=refmap, with_paths=False + ) + + assert flat_state[0] == 2 + assert flat_state[1] == 4 assert len(refmap) == 2 assert a['b'] in refmap @@ -109,7 +124,9 @@ def test_unflatten_empty(self): graphdef, state = nnx.split(g) - with self.assertRaisesRegex(ValueError, 'Expected key'): + with self.assertRaisesRegex( + ValueError, 'Not enough leaves to unflatten the graph' + ): nnx.graph.unflatten(graphdef, nnx.State({})) def test_update_dynamic(self): diff --git a/tests/nnx/module_test.py b/tests/nnx/module_test.py index ce65186dd..058e69b89 100644 --- a/tests/nnx/module_test.py +++ b/tests/nnx/module_test.py @@ -676,7 +676,7 @@ def __call__(self, x, *, rngs: nnx.Rngs): graphdef, state = nnx.split(foo) - assert isinstance(graphdef, nnx.GraphDef) + assert isinstance(graphdef, nnx.graph.NodeDef | nnx.graph.NodeRef) assert isinstance(state, nnx.State) assert issubclass(state.w.type, nnx.Param) assert issubclass(state.c.type, nnx.Variable)