Skip to content

Commit

Permalink
[nnx] cache flatten
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Dec 30, 2024
1 parent e9c635d commit 45f8f34
Show file tree
Hide file tree
Showing 17 changed files with 903 additions and 384 deletions.
4 changes: 3 additions & 1 deletion benchmarks/nnx_graph_overhead.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
from absl import app

FLAGS = flags.FLAGS
flags.DEFINE_enum('mode', 'all', ['all', 'nnx', 'jax'], 'Mode to run the script in')
flags.DEFINE_enum(
'mode', 'nnx', ['all', 'nnx', 'jax'], 'Mode to run the script in'
)
flags.DEFINE_integer('total_steps', 100, 'Total number of training steps')
flags.DEFINE_integer('width', 32, 'Hidden layer size')
flags.DEFINE_integer('depth', 5, 'Depth of the model')
Expand Down
5 changes: 3 additions & 2 deletions flax/nnx/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,11 +383,12 @@ def from_tree(
is_node_leaf: tp.Callable[[Leaf], bool] = is_tree_node,
is_leaf: tp.Callable[[Leaf], bool] = is_tree_node,
map_non_graph_nodes: bool = False,
is_inner: bool | None = None,
ctxtag: str | None = None,
) -> tp.Any:
if prefix is Missing or prefix is None:
# fast path, no need for prefix broadcasting or consistent aliasing checks
with graph.merge_context(ctxtag) as merge_ctx:
with graph.merge_context(is_inner, ctxtag) as merge_ctx:
return jax.tree.map(
lambda x: merge_fn(merge_ctx, (), prefix, x)
if map_non_graph_nodes or is_node_leaf(x)
Expand All @@ -407,7 +408,7 @@ def from_tree(
assert len(leaf_keys) == len(leaf_prefixes)
leaves_out = []

with graph.merge_context(ctxtag) as merge_ctx:
with graph.merge_context(is_inner, ctxtag) as merge_ctx:
for (keypath, leaf), leaf_prefix in zip(leaf_keys, leaf_prefixes):
if map_non_graph_nodes or is_node_leaf(leaf):
leaf = merge_fn(merge_ctx, keypath, leaf_prefix, leaf)
Expand Down
Loading

0 comments on commit 45f8f34

Please sign in to comment.