Skip to content

Commit

Permalink
[nnx] optimize graph
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Dec 4, 2024
1 parent 5d896bc commit d9acfb5
Show file tree
Hide file tree
Showing 10 changed files with 184 additions and 79 deletions.
4 changes: 2 additions & 2 deletions docs_nnx/guides/filters_guide.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"id": "068208fc",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -280,7 +280,7 @@
" predicates = [nnx.filterlib.to_predicate(f) for f in filters]\n",
" flat_states: list[dict[KeyPath, Any]] = [{} for p in predicates]\n",
"\n",
" for path, value in state.flat_state().items():\n",
" for path, value in state.flat_state():\n",
" for i, predicate in enumerate(predicates):\n",
" if predicate(path, value):\n",
" flat_states[i][path] = value\n",
Expand Down
2 changes: 1 addition & 1 deletion docs_nnx/guides/filters_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def split(node, *filters):
predicates = [nnx.filterlib.to_predicate(f) for f in filters]
flat_states: list[dict[KeyPath, Any]] = [{} for p in predicates]
for path, value in state.flat_state().items():
for path, value in state.flat_state():
for i, predicate in enumerate(predicates):
if predicate(path, value):
flat_states[i][path] = value
Expand Down
2 changes: 1 addition & 1 deletion examples/gemma/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def map_key_fn(path: tuple[str, ...]) -> tuple[str | int, ...]:

mdl: M = nnx.eval_shape(module_factory)
graph_def, state = nnx.split(mdl)
state = state.flat_state()
state = dict(state.flat_state())
for path, val in flax.traverse_util.flatten_dict(variables).items():
mapped_path = map_key_fn(path)
if mapped_path not in state:
Expand Down
4 changes: 2 additions & 2 deletions examples/lm1b_nnx/models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def transfer_params(
params_linen: dict[str, Any],
):
rules = dataclasses.asdict(config.axis_rules)
flat_params_nnx = params_nnx.flat_state()
flat_params_nnx = dict(params_nnx.flat_state())
flat_params_linen = nnx.traversals.flatten_mapping(params_linen, sep='/')

def apply_rules(names: tuple[str, ...]):
Expand Down Expand Up @@ -163,7 +163,7 @@ def transfer_cache(
cache_nnx: nnx.State,
cache_linen: dict[str, Any],
):
flat_cache_nnx = cache_nnx.flat_state()
flat_cache_nnx = dict(cache_nnx.flat_state())
flat_cache_linen = nnx.traversals.flatten_mapping(cache_linen, sep='/')

def copy_var(nnx_name: str, linen_name: str):
Expand Down
1 change: 1 addition & 0 deletions flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,4 @@
from .extract import to_tree as to_tree
from .extract import from_tree as from_tree
from .extract import NodeStates as NodeStates
from . import traversals as traversals
71 changes: 37 additions & 34 deletions flax/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
CallableProxy,
DelayedAccessor,
)
from flax.nnx.statelib import FlatState, State
from flax.nnx.statelib import State
from flax.nnx import variablelib
from flax.nnx.variablelib import Variable, VariableState
from flax.typing import Key, PathParts, is_key_like
Expand Down Expand Up @@ -110,15 +110,12 @@ class GraphNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
pop_key: tp.Callable[[Node, Key], Leaf]
create_empty: tp.Callable[[AuxData], Node]
clear: tp.Callable[[Node], None]

def init(self, node: Node, items: tuple[tuple[Key, Leaf], ...]):
for key, value in items:
self.set_key(node, key, value)
init: tp.Callable[[Node, tp.Iterable[tuple[Key, Leaf]]], None]


@dataclasses.dataclass(frozen=True, slots=True)
class PytreeNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
unflatten: tp.Callable[[tuple[tuple[Key, Leaf], ...], AuxData], Node]
unflatten: tp.Callable[[tp.Sequence[tuple[Key, Leaf]], AuxData], Node]


NodeImpl = tp.Union[
Expand All @@ -137,6 +134,7 @@ def register_graph_node_type(
pop_key: tp.Callable[[Node, Key], Leaf],
create_empty: tp.Callable[[AuxData], Node],
clear: tp.Callable[[Node], None],
init: tp.Callable[[Node, tp.Iterable[tuple[Key, Leaf]]], None],
):
if type in GRAPH_REGISTRY:
raise ValueError(f'Node type {type} is already registered.')
Expand All @@ -148,12 +146,13 @@ def register_graph_node_type(
pop_key=pop_key,
create_empty=create_empty,
clear=clear,
init=init,
)

def register_pytree_node_type(
type: type,
flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[Key, Leaf]], AuxData]],
unflatten: tp.Callable[[tuple[tuple[Key, Leaf], ...], AuxData], Node],
unflatten: tp.Callable[[tp.Sequence[tuple[Key, Leaf]], AuxData], Node],
):
if type in PYTREE_REGISTRY:
raise ValueError(f'Node type {type} is already registered.')
Expand Down Expand Up @@ -202,8 +201,8 @@ def get_node_impl_for_type(x: type[Node]) -> NodeImpl[Node, tp.Any, tp.Any]:


class HashableMapping(tp.Mapping[HA, HB], tp.Hashable):
def __init__(self, mapping: tp.Mapping[HA, HB] | tp.Iterable[tuple[HA, HB]]):
self._mapping = dict(mapping)
def __init__(self, mapping: tp.Mapping[HA, HB], copy: bool = True):
self._mapping = dict(mapping) if copy else mapping

def __contains__(self, key: object) -> bool:
return key in self._mapping
Expand Down Expand Up @@ -401,15 +400,15 @@ def flatten(
"""
if ref_index is None:
ref_index = RefMap()
flat_state: dict[PathParts, StateLeaf] = {}
flat_state: list[tuple[PathParts, StateLeaf]] = []
graphdef = _graph_flatten((), ref_index, flat_state, node)
return graphdef, GraphState.from_flat_path(flat_state)


def _graph_flatten(
path: PathParts,
ref_index: RefMap[tp.Any, Index],
flat_state: dict[PathParts, StateLeaf],
flat_state: list[tuple[PathParts, StateLeaf]],
node: Node,
) -> NodeDef[Node] | NodeRef:
if not is_node(node):
Expand Down Expand Up @@ -441,10 +440,10 @@ def _graph_flatten(
LeafAttribute(key, NodeRef(type(value), ref_index[value]))
)
else:
flat_state[(*path, key)] = value.to_state()
flat_state.append(((*path, key), value.to_state()))
variable_index = ref_index[value] = len(ref_index)
variabledef = VariableDef(
type(value), variable_index, HashableMapping(value.get_metadata())
type(value), variable_index, HashableMapping(value._var_metadata)
)
attributes.append(LeafAttribute(key, variabledef))
else:
Expand Down Expand Up @@ -528,7 +527,7 @@ def _graph_unflatten(
node_impl = get_node_impl_for_type(nodedef.type)

def _get_children():
children: dict[Key, NodeLeaf | Node] = {}
children: list[tuple[Key, NodeLeaf | Node]] = []
state_keys: set = set(state.keys())

# for every key in attributes there are 6 possible cases:
Expand All @@ -539,28 +538,29 @@ def _get_children():
if key not in state:
# if key is not present create an empty types
if type(attribute) is StaticAttribute:
children[key] = attribute.value
children.append((key, attribute.value))
elif type(attribute) is SubGraphAttribute:
# if the key is a subgraph we create an empty node
subgraphdef = attribute.value
assert not isinstance(subgraphdef, VariableDef)
if isinstance(subgraphdef, NodeRef):
# subgraph exists, take it from the cache
children[key] = index_ref[subgraphdef.index]
children.append((key, index_ref[subgraphdef.index]))
else:
# create a node from an empty state, reasoning:
# * its a node with no state
# * its a node with state but only through references of already
# created nodes
substate = {}
children[key] = _graph_unflatten(
subnode = _graph_unflatten(
subgraphdef, substate, index_ref, index_ref_cache
)
children.append((key, subnode))
elif type(attribute) is LeafAttribute:
variabledef = attribute.value
if variabledef.index in index_ref:
# variable exists, take it from the cache
children[key] = index_ref[variabledef.index]
children.append((key, index_ref[variabledef.index]))
else:
# key for a variable is missing, raise an error
raise ValueError(
Expand All @@ -587,19 +587,20 @@ def _get_children():
subgraphdef = attribute.value

if isinstance(subgraphdef, NodeRef):
children[key] = index_ref[subgraphdef.index]
children.append((key, index_ref[subgraphdef.index]))
else:
children[key] = _graph_unflatten(
subnode = _graph_unflatten(
subgraphdef, value, index_ref, index_ref_cache
)
children.append((key, subnode))

elif type(attribute) is LeafAttribute:
variabledef = attribute.value

if variabledef.index in index_ref:
# add an existing variable
assert isinstance(variabledef, NodeRef)
children[key] = index_ref[variabledef.index]
children.append((key, index_ref[variabledef.index]))
else:
# its a unseen variable, create a new one
assert isinstance(variabledef, VariableDef)
Expand All @@ -626,7 +627,7 @@ def _get_children():
variable = variabledef.type.from_metadata(
value, variabledef.metadata
)
children[key] = variable
children.append((key, variable))
index_ref[variabledef.index] = variable
else:
raise RuntimeError(f'Unknown key: {key!r}, this is a bug.')
Expand All @@ -651,13 +652,11 @@ def _get_children():
else:
node = node_impl.create_empty(nodedef.metadata)
index_ref[nodedef.index] = node
children = _get_children()
node_impl.init(node, tuple(children.items()))
node_impl.init(node, _get_children())
else:
# if the node type does not support the creation of an empty object it means
# that it cannot reference itself, so we can create its children first
children = _get_children()
node = node_impl.unflatten(tuple(children.items()), nodedef.metadata)
node = node_impl.unflatten(_get_children(), nodedef.metadata)

return node

Expand All @@ -669,7 +668,9 @@ def graph_pop(
id_to_index: dict[int, Index] = {}
path_parts: PathParts = ()
predicates = tuple(filterlib.to_predicate(filter) for filter in filters)
flat_states: tuple[FlatState[StateLeaf], ...] = tuple({} for _ in predicates)
flat_states: tuple[dict[PathParts, StateLeaf], ...] = tuple(
{} for _ in predicates
)
_graph_pop(node, id_to_index, path_parts, flat_states, predicates)
return tuple(
GraphState.from_flat_path(flat_state) for flat_state in flat_states
Expand All @@ -680,7 +681,7 @@ def _graph_pop(
node: tp.Any,
id_to_index: dict[int, Index],
path_parts: PathParts,
flat_states: tuple[FlatState[StateLeaf], ...],
flat_states: tuple[dict[PathParts, StateLeaf], ...],
predicates: tuple[filterlib.Predicate, ...],
) -> None:
if not is_node(node):
Expand Down Expand Up @@ -816,7 +817,7 @@ def split(
if ctx.index_ref is not None and isinstance(graphdef, NodeDef):
index_to_index = compose_mapping(ctx.index_ref, self.ref_index)
graphdef = dataclasses.replace(
graphdef, index_mapping=HashableMapping(index_to_index)
graphdef, index_mapping=HashableMapping(index_to_index, copy=False)
)

return graphdef, *states
Expand Down Expand Up @@ -1006,7 +1007,7 @@ def split(
if self.index_ref is not None and isinstance(graphdef, NodeDef):
index_to_index = compose_mapping(self.index_ref, ref_index)
graphdef = dataclasses.replace(
graphdef, index_mapping=HashableMapping(index_to_index)
graphdef, index_mapping=HashableMapping(index_to_index, copy=False)
)

self.flatten_end(ref_index)
Expand Down Expand Up @@ -1570,7 +1571,9 @@ def pop(
id_to_index: dict[int, Index] = {}
path_parts: PathParts = ()
predicates = tuple(filterlib.to_predicate(filter) for filter in filters)
flat_states: tuple[FlatState[StateLeaf], ...] = tuple({} for _ in predicates)
flat_states: tuple[dict[PathParts, StateLeaf], ...] = tuple(
{} for _ in predicates
)
_graph_pop(
node=node,
id_to_index=id_to_index,
Expand Down Expand Up @@ -1786,8 +1789,8 @@ def is_pytree_node(x: tp.Any) -> bool:
# known non-pytree types
elif isinstance(x, Variable):
return False
# knon pytree types
elif isinstance(x, (VariableState, State)):
# known pytree types
elif type(x) is VariableState or type(x) is State:
return True
else:
return not jax.tree_util.all_leaves((x,))
Expand Down Expand Up @@ -1829,7 +1832,7 @@ def _unflatten_pytree(
PYTREE_NODE_IMPL = PytreeNodeImpl(
type=GenericPytree,
flatten=_flatten_pytree,
unflatten=_unflatten_pytree,
unflatten=_unflatten_pytree, # type: ignore
)

# common pytrees
Expand Down
21 changes: 11 additions & 10 deletions flax/nnx/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
)
from flax.nnx import graph
from flax.nnx.variablelib import Variable, VariableState
from flax.typing import Key
from flax import errors

G = tp.TypeVar('G', bound='Object')
Expand Down Expand Up @@ -109,10 +108,11 @@ def __init_subclass__(cls) -> None:
graph.register_graph_node_type(
type=cls,
flatten=cls._graph_node_flatten,
set_key=cls._graph_node_set_key,
pop_key=cls._graph_node_pop_key,
set_key=cls._graph_node_set_key, # type: ignore
pop_key=cls._graph_node_pop_key, # type: ignore
create_empty=cls._graph_node_create_empty,
clear=cls._graph_node_clear,
init=cls._graph_node_init, # type: ignore
)

if not tp.TYPE_CHECKING:
Expand Down Expand Up @@ -189,14 +189,12 @@ def __treescope_repr__(self, path, subtree_renderer):

# Graph Definition
def _graph_node_flatten(self):
nodes = sorted(
(key, value)
for key, value in vars(self).items()
if key != '_object__state'
)
nodes = vars(self).copy()
del nodes['_object__state']
nodes = sorted(nodes.items())
return nodes, (type(self), self._object__state._initializing)

def _graph_node_set_key(self, key: Key, value: tp.Any):
def _graph_node_set_key(self, key: str, value: tp.Any):
if not isinstance(key, str):
raise KeyError(f'Invalid key: {key!r}')
elif (
Expand All @@ -208,7 +206,7 @@ def _graph_node_set_key(self, key: Key, value: tp.Any):
else:
setattr(self, key, value)

def _graph_node_pop_key(self, key: Key):
def _graph_node_pop_key(self, key: str):
if not isinstance(key, str):
raise KeyError(f'Invalid key: {key!r}')
return vars(self).pop(key)
Expand All @@ -225,3 +223,6 @@ def _graph_node_clear(self):
module_vars = vars(self)
module_vars.clear()
module_vars['_object__state'] = module_state

def _graph_node_init(self, attributes: tp.Iterable[tuple[str, tp.Any]]):
vars(self).update(attributes)
Loading

0 comments on commit d9acfb5

Please sign in to comment.