Skip to content

Commit

Permalink
[nnx] simpllify unflatten
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Dec 23, 2024
1 parent 6cf5b7d commit e9c635d
Show file tree
Hide file tree
Showing 11 changed files with 354 additions and 266 deletions.
55 changes: 40 additions & 15 deletions benchmarks/nnx_graph_overhead.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,25 +30,44 @@
flags.DEFINE_integer('depth', 5, 'Depth of the model')



class Linear(nnx.Module):
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
self.list = [
nnx.Param(jax.random.uniform(rngs.params(), (din, dout))),
nnx.Param(jnp.zeros((dout,))),
]
self.dict = {
'w': nnx.Param(jax.random.uniform(rngs.params(), (din, dout))),
'b': nnx.Param(jnp.zeros((dout,))),
}
self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout)))
self.b = nnx.Param(jnp.zeros((dout,)))

def __call__(self, x):
return x @ self.w + self.b


class Block(nnx.Module):
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
self.linear = Linear(din, dout, rngs=rngs)
self.bn = nnx.BatchNorm(dout, rngs=rngs)

def __call__(self, x):
return nnx.relu(self.bn(self.linear(x)))


class Count(nnx.Variable):
pass


class MLP(nnx.Module):
def __init__(self, depth, *, rngs: nnx.Rngs):
def __init__(self, din, dhidden, dout, depth, *, rngs: nnx.Rngs):
self.count = Count(jnp.array(0))
self.linear_in = Block(din, dhidden, rngs=rngs)
self.intermediates = [
Linear(10, 10, rngs=rngs) for _ in range(depth)
Block(dhidden, dhidden, rngs=rngs) for _ in range(depth - 2)
]
self.linear_out = Block(dhidden, dout, rngs=rngs)

def __call__(self, x):
self.count.value += 1
x = nnx.relu(self.linear_in(x))
for layer in self.intermediates:
x = nnx.relu(layer(x))
x = self.linear_out(x)
return x


def main(argv):
Expand All @@ -63,14 +82,15 @@ def main(argv):
X = np.linspace(0, 1, 100)[:, None]
Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape)

model = MLP(depth=depth, rngs=nnx.Rngs(0))
tx = optax.sgd(1e-3)
optimizer = nnx.Optimizer(model, tx)

#------------------------------------------------------------
# NNX
#------------------------------------------------------------
if mode in ['all', 'nnx']:
model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
tx = optax.sgd(1e-3)
optimizer = nnx.Optimizer(model, tx)
t0 = time()

@nnx.jit
def step_nnx(model: MLP, optimizer: nnx.Optimizer):
pass
Expand All @@ -93,6 +113,11 @@ def step_nnx(model: MLP, optimizer: nnx.Optimizer):
#------------------------------------------------------------

if mode in ['all', 'jax']:
model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
tx = optax.sgd(1e-3)
optimizer = nnx.Optimizer(model, tx)
t0 = time()

@jax.jit
def step_jax(graphdef, state):
return graphdef, state
Expand Down
11 changes: 6 additions & 5 deletions flax/nnx/bridge/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@
import jax
from flax import struct
from flax.core import meta
from flax.nnx import spmd
from flax.nnx import graph, spmd
from flax.nnx import traversals
from flax.nnx import variablelib as variableslib
from flax.nnx.module import GraphDef
import typing as tp


Expand Down Expand Up @@ -192,9 +191,11 @@ def linen_vars_to_nnx_attrs(variables: tp.Mapping[str, Any]) -> dict[str, Any]:
def nnx_attrs_to_linen_vars(nnx_attrs: dict) -> dict:
linen_structured = {}
for kp, v in traversals.flatten_mapping(
nnx_attrs,
is_leaf=lambda _, x: isinstance(x, variableslib.Variable | GraphDef),
).items():
nnx_attrs,
is_leaf=lambda _, x: isinstance(
x, variableslib.Variable | graph.NodeDef | graph.NodeRef
),
).items():
if isinstance(v, variableslib.Variable):
col_name = variable_type_name(type(v))
else:
Expand Down
46 changes: 32 additions & 14 deletions flax/nnx/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,10 +251,13 @@ class GraphDefState(struct.PyTreeNode):
graphdef: graph.GraphDef[tp.Any] = struct.field(pytree_node=False)
state: graph.GraphState = struct.field(pytree_node=True)

S = tp.TypeVar(
'S', bound=graph.GraphState | graph.GraphFlatState | list[tp.Any]
)

class NodeStates(struct.PyTreeNode):
class NodeStates(struct.PyTreeNode, tp.Generic[S]):
_graphdef: graph.GraphDef[tp.Any] | None
states: tuple[graph.GraphState | graph.GraphFlatState, ...]
states: tuple[S, ...]
metadata: tp.Any = struct.field(pytree_node=False)

@property
Expand All @@ -264,7 +267,7 @@ def graphdef(self) -> graph.GraphDef[tp.Any]:
return self._graphdef

@property
def state(self) -> graph.GraphState | graph.GraphFlatState:
def state(self) -> S:
if len(self.states) != 1:
raise ValueError(
f'Expected exactly one GraphDefState, got {len(self.states)}'
Expand All @@ -275,18 +278,18 @@ def state(self) -> graph.GraphState | graph.GraphFlatState:
def from_split(
cls,
graphdef: graph.GraphDef[tp.Any],
state: graph.GraphState | graph.GraphFlatState,
state: S,
/,
*states: graph.GraphState | graph.GraphFlatState,
*states: S,
metadata: tp.Any = None,
):
return cls(_graphdef=graphdef, states=(state, *states), metadata=metadata)

@classmethod
def from_states(
cls,
state: graph.GraphState | graph.GraphFlatState,
*states: graph.GraphState | graph.GraphFlatState,
state: S,
*states: S,
):
return cls(_graphdef=None, states=(state, *states), metadata=None)

Expand Down Expand Up @@ -319,6 +322,15 @@ def to_tree(
ctxtag: str | None = None,
check_aliasing: bool = True,
) -> tp.Any:
if prefix is Missing or prefix is None:
# fast path, no need for prefix broadcasting or consistent aliasing checks
with graph.split_context(ctxtag) as split_ctx:
return jax.tree.map(
lambda x: split_fn(split_ctx, (), prefix, x)
if map_non_graph_nodes or graph.is_graph_node(x)
else x,
tree,
)
leaf_prefixes = broadcast_prefix(
prefix,
tree,
Expand Down Expand Up @@ -373,6 +385,16 @@ def from_tree(
map_non_graph_nodes: bool = False,
ctxtag: str | None = None,
) -> tp.Any:
if prefix is Missing or prefix is None:
# fast path, no need for prefix broadcasting or consistent aliasing checks
with graph.merge_context(ctxtag) as merge_ctx:
return jax.tree.map(
lambda x: merge_fn(merge_ctx, (), prefix, x)
if map_non_graph_nodes or is_node_leaf(x)
else x,
tree,
is_leaf=is_leaf,
)
leaf_prefixes = broadcast_prefix(
prefix,
tree,
Expand All @@ -387,13 +409,9 @@ def from_tree(

with graph.merge_context(ctxtag) as merge_ctx:
for (keypath, leaf), leaf_prefix in zip(leaf_keys, leaf_prefixes):
if is_node_leaf(leaf):
leaf_out = merge_fn(merge_ctx, keypath, leaf_prefix, leaf)
leaves_out.append(leaf_out)
else:
if map_non_graph_nodes:
leaf = merge_fn(merge_ctx, keypath, leaf_prefix, leaf)
leaves_out.append(leaf)
if map_non_graph_nodes or is_node_leaf(leaf):
leaf = merge_fn(merge_ctx, keypath, leaf_prefix, leaf)
leaves_out.append(leaf)

pytree_out = jax.tree.unflatten(treedef, leaves_out)
return pytree_out
Expand Down
Loading

0 comments on commit e9c635d

Please sign in to comment.