[nnx] simplify unflatten #4462

Draft
wants to merge 1 commit into base: nnx-fast-jit
Changes from all commits
55 changes: 40 additions & 15 deletions benchmarks/nnx_graph_overhead.py
@@ -30,25 +30,44 @@
 flags.DEFINE_integer('depth', 5, 'Depth of the model')


 class Linear(nnx.Module):
   def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
-    self.list = [
-      nnx.Param(jax.random.uniform(rngs.params(), (din, dout))),
-      nnx.Param(jnp.zeros((dout,))),
-    ]
-    self.dict = {
-      'w': nnx.Param(jax.random.uniform(rngs.params(), (din, dout))),
-      'b': nnx.Param(jnp.zeros((dout,))),
-    }
     self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout)))
     self.b = nnx.Param(jnp.zeros((dout,)))

+  def __call__(self, x):
+    return x @ self.w + self.b
+
+
+class Block(nnx.Module):
+  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
+    self.linear = Linear(din, dout, rngs=rngs)
+    self.bn = nnx.BatchNorm(dout, rngs=rngs)
+
+  def __call__(self, x):
+    return nnx.relu(self.bn(self.linear(x)))
+
+
+class Count(nnx.Variable):
+  pass
+
+
 class MLP(nnx.Module):
-  def __init__(self, depth, *, rngs: nnx.Rngs):
+  def __init__(self, din, dhidden, dout, depth, *, rngs: nnx.Rngs):
+    self.count = Count(jnp.array(0))
+    self.linear_in = Block(din, dhidden, rngs=rngs)
     self.intermediates = [
-      Linear(10, 10, rngs=rngs) for _ in range(depth)
+      Block(dhidden, dhidden, rngs=rngs) for _ in range(depth - 2)
     ]
+    self.linear_out = Block(dhidden, dout, rngs=rngs)
+
+  def __call__(self, x):
+    self.count.value += 1
+    x = nnx.relu(self.linear_in(x))
+    for layer in self.intermediates:
+      x = nnx.relu(layer(x))
+    x = self.linear_out(x)
+    return x


 def main(argv):
@@ -63,14 +82,15 @@ def main(argv):
   X = np.linspace(0, 1, 100)[:, None]
   Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape)

-  model = MLP(depth=depth, rngs=nnx.Rngs(0))
-  tx = optax.sgd(1e-3)
-  optimizer = nnx.Optimizer(model, tx)
-
   #------------------------------------------------------------
   # NNX
   #------------------------------------------------------------
   if mode in ['all', 'nnx']:
+    model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
+    tx = optax.sgd(1e-3)
+    optimizer = nnx.Optimizer(model, tx)
+    t0 = time()
+
     @nnx.jit
     def step_nnx(model: MLP, optimizer: nnx.Optimizer):
       pass
@@ -93,6 +113,11 @@ def step_nnx(model: MLP, optimizer: nnx.Optimizer):
   #------------------------------------------------------------

   if mode in ['all', 'jax']:
+    model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
+    tx = optax.sgd(1e-3)
+    optimizer = nnx.Optimizer(model, tx)
+    t0 = time()
+
     @jax.jit
     def step_jax(graphdef, state):
       return graphdef, state
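Note: a minimal sketch of driving the reworked benchmark model above. The din/dhidden/dout/depth values here are illustrative stand-ins for the script's flags; MLP and Count refer to the classes defined in this diff.

import jax.numpy as jnp
from flax import nnx

# depth=4 yields linear_in + 2 intermediate Blocks + linear_out
model = MLP(din=1, dhidden=32, dout=1, depth=4, rngs=nnx.Rngs(0))
y = model(jnp.ones((8, 1)))  # each forward call bumps the Count variable
assert model.count.value == 1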
11 changes: 6 additions & 5 deletions flax/nnx/bridge/variables.py
@@ -18,10 +18,9 @@
 import jax
 from flax import struct
 from flax.core import meta
-from flax.nnx import spmd
+from flax.nnx import graph, spmd
 from flax.nnx import traversals
 from flax.nnx import variablelib as variableslib
-from flax.nnx.module import GraphDef
 import typing as tp


@@ -192,9 +191,11 @@ def linen_vars_to_nnx_attrs(variables: tp.Mapping[str, Any]) -> dict[str, Any]:
 def nnx_attrs_to_linen_vars(nnx_attrs: dict) -> dict:
   linen_structured = {}
   for kp, v in traversals.flatten_mapping(
-    nnx_attrs,
-    is_leaf=lambda _, x: isinstance(x, variableslib.Variable | GraphDef),
-  ).items():
+    nnx_attrs,
+    is_leaf=lambda _, x: isinstance(
+      x, variableslib.Variable | graph.NodeDef | graph.NodeRef
+    ),
+  ).items():
     if isinstance(v, variableslib.Variable):
       col_name = variable_type_name(type(v))
     else:
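Note: the widened is_leaf predicate reflects that sub-graph structure stored in NNX attributes now shows up as graph.NodeDef / graph.NodeRef objects rather than a GraphDef. A minimal sketch of how the is_leaf cutoff behaves, using a toy mapping with hypothetical values:

from flax import nnx
from flax.nnx import traversals

nested = {'layer': {'kernel': nnx.Param(1.0), 'meta': {'n': 2}}}
flat = traversals.flatten_mapping(
  nested,
  # stop descending once a leaf type is hit, mirroring the predicate above
  is_leaf=lambda _, x: isinstance(x, nnx.Variable),
)
# flat == {('layer', 'kernel'): Param(...), ('layer', 'meta', 'n'): 2}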
46 changes: 32 additions & 14 deletions flax/nnx/extract.py
@@ -251,10 +251,13 @@ class GraphDefState(struct.PyTreeNode):
   graphdef: graph.GraphDef[tp.Any] = struct.field(pytree_node=False)
   state: graph.GraphState = struct.field(pytree_node=True)

+S = tp.TypeVar(
+  'S', bound=graph.GraphState | graph.GraphFlatState | list[tp.Any]
+)

-class NodeStates(struct.PyTreeNode):
+class NodeStates(struct.PyTreeNode, tp.Generic[S]):
   _graphdef: graph.GraphDef[tp.Any] | None
-  states: tuple[graph.GraphState | graph.GraphFlatState, ...]
+  states: tuple[S, ...]
   metadata: tp.Any = struct.field(pytree_node=False)

   @property
@@ -264,7 +267,7 @@ def graphdef(self) -> graph.GraphDef[tp.Any]:
     return self._graphdef

   @property
-  def state(self) -> graph.GraphState | graph.GraphFlatState:
+  def state(self) -> S:
     if len(self.states) != 1:
       raise ValueError(
         f'Expected exactly one GraphDefState, got {len(self.states)}'
@@ -275,18 +278,18 @@ def state(self) -> graph.GraphState | graph.GraphFlatState:
   def from_split(
     cls,
     graphdef: graph.GraphDef[tp.Any],
-    state: graph.GraphState | graph.GraphFlatState,
+    state: S,
     /,
-    *states: graph.GraphState | graph.GraphFlatState,
+    *states: S,
     metadata: tp.Any = None,
   ):
     return cls(_graphdef=graphdef, states=(state, *states), metadata=metadata)

   @classmethod
   def from_states(
     cls,
-    state: graph.GraphState | graph.GraphFlatState,
-    *states: graph.GraphState | graph.GraphFlatState,
+    state: S,
+    *states: S,
   ):
     return cls(_graphdef=None, states=(state, *states), metadata=None)

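Note: with NodeStates generic in S, from_split / from_states now preserve the concrete state type instead of widening to the GraphState | GraphFlatState union. A minimal sketch (NodeStates is internal API in flax/nnx/extract.py, shown here for illustration only):

from flax import nnx
from flax.nnx import extract

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
graphdef, state = nnx.split(model)

ns = extract.NodeStates.from_split(graphdef, state)
assert ns.graphdef is graphdef
assert ns.state is state  # typed via S rather than the wide union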
@@ -319,6 +322,15 @@ def to_tree(
   ctxtag: str | None = None,
   check_aliasing: bool = True,
 ) -> tp.Any:
+  if prefix is Missing or prefix is None:
+    # fast path, no need for prefix broadcasting or consistent aliasing checks
+    with graph.split_context(ctxtag) as split_ctx:
+      return jax.tree.map(
+        lambda x: split_fn(split_ctx, (), prefix, x)
+        if map_non_graph_nodes or graph.is_graph_node(x)
+        else x,
+        tree,
+      )
   leaf_prefixes = broadcast_prefix(
     prefix,
     tree,
@@ -373,6 +385,16 @@ def from_tree(
   map_non_graph_nodes: bool = False,
   ctxtag: str | None = None,
 ) -> tp.Any:
+  if prefix is Missing or prefix is None:
+    # fast path, no need for prefix broadcasting or consistent aliasing checks
+    with graph.merge_context(ctxtag) as merge_ctx:
+      return jax.tree.map(
+        lambda x: merge_fn(merge_ctx, (), prefix, x)
+        if map_non_graph_nodes or is_node_leaf(x)
+        else x,
+        tree,
+        is_leaf=is_leaf,
+      )
   leaf_prefixes = broadcast_prefix(
     prefix,
     tree,
@@ -387,13 +409,9 @@ def from_tree(

   with graph.merge_context(ctxtag) as merge_ctx:
     for (keypath, leaf), leaf_prefix in zip(leaf_keys, leaf_prefixes):
-      if is_node_leaf(leaf):
-        leaf_out = merge_fn(merge_ctx, keypath, leaf_prefix, leaf)
-        leaves_out.append(leaf_out)
-      else:
-        if map_non_graph_nodes:
-          leaf = merge_fn(merge_ctx, keypath, leaf_prefix, leaf)
-        leaves_out.append(leaf)
+      if map_non_graph_nodes or is_node_leaf(leaf):
+        leaf = merge_fn(merge_ctx, keypath, leaf_prefix, leaf)
+      leaves_out.append(leaf)

   pytree_out = jax.tree.unflatten(treedef, leaves_out)
   return pytree_out
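Note: the new fast path skips broadcast_prefix and the consistent-aliasing checks whenever no prefix is supplied, collapsing to_tree/from_tree into a single jax.tree.map. A rough behavioral sketch of the to_tree side; to_tree_fast is a hypothetical stand-in, and the real code routes through split_context/split_fn rather than calling nnx.split directly:

import jax
from flax import nnx
from flax.nnx import graph

def to_tree_fast(tree):
  # split graph nodes in place, pass every other leaf through untouched
  return jax.tree.map(
    lambda x: nnx.split(x) if graph.is_graph_node(x) else x,
    tree,
  )

# modules are opaque leaves to jax.tree.map, so only they get split
out = to_tree_fast({'model': nnx.Linear(2, 3, rngs=nnx.Rngs(0)), 'lr': 1e-3})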