[nnx] flatten returns FlatState #4458

Draft: wants to merge 1 commit into main
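This draft changes `flatten` in `flax/nnx/graph.py` to return a `FlatState` instead of a nested `State`, widens the `extract.py` annotations accordingly, adds `flatten`/`unflatten` methods to the split/merge contexts, and extends the training benchmark with BatchNorm blocks and an `all` mode that runs both loops. Below is a minimal sketch of the API change, assuming `flatten` is called directly from `flax.nnx.graph`; the `nnx.Linear` module is only illustrative. Call sites that still need the nested mapping can rebuild it with `State.from_flat_path`, as the updated internals in this diff do, so the hot path can stay in flat space instead of building the nested mapping on every split.

```python
from flax import nnx
from flax.nnx import graph

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))

# flatten now returns a FlatState of (path, VariableState) pairs
# instead of a nested State.
graphdef, flat_state = graph.flatten(model)

# Rebuild the nested mapping only where it is actually needed,
# mirroring the updated call sites in this diff.
state = nnx.State.from_flat_path(flat_state)
```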
57 changes: 37 additions & 20 deletions benchmarks/nnx_simple_training.py
@@ -25,7 +25,9 @@
from absl import app

FLAGS = flags.FLAGS
flags.DEFINE_enum('mode', 'nnx', ['nnx', 'jax'], 'Mode to run the script in')
flags.DEFINE_enum(
'mode', 'all', ['all', 'nnx', 'jax'], 'Mode to run the script in'
)
flags.DEFINE_integer('total_steps', 10_000, 'Total number of training steps')
flags.DEFINE_integer('batch_size', 32, 'Batch size')
flags.DEFINE_integer('width', 32, 'Hidden layer size')
@@ -46,6 +48,13 @@ def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
def __call__(self, x):
return x @ self.w + self.b

class Block(nnx.Module):
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
self.linear = Linear(din, dout, rngs=rngs)
self.bn = nnx.BatchNorm(dout, rngs=rngs)

def __call__(self, x):
return nnx.relu(self.bn(self.linear(x)))

class Count(nnx.Variable):
pass
@@ -54,11 +63,11 @@ class Count(nnx.Variable):
class MLP(nnx.Module):
def __init__(self, din, dhidden, dout, depth, *, rngs: nnx.Rngs):
self.count = Count(jnp.array(0))
self.linear_in = Linear(din, dhidden, rngs=rngs)
self.linear_in = Block(din, dhidden, rngs=rngs)
self.intermediates = [
Linear(dhidden, dhidden, rngs=rngs) for _ in range(depth - 2)
Block(dhidden, dhidden, rngs=rngs) for _ in range(depth - 2)
]
self.linear_out = Linear(dhidden, dout, rngs=rngs)
self.linear_out = Block(dhidden, dout, rngs=rngs)

def __call__(self, x):
self.count.value += 1
@@ -79,18 +88,14 @@ def main(argv):

print(f'{mode=}, {total_steps=}, {batch_size=}, {width=}')

if mode not in ['nnx', 'jax']:
raise ValueError(f'Invalid mode: {mode}')

X = np.linspace(0, 1, 100)[:, None]
Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape)

model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
tx = optax.sgd(1e-3)
optimizer = nnx.Optimizer(model, tx)
t0 = time()

if mode == 'nnx':
if mode == 'nnx' or mode == 'all':
model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
tx = optax.sgd(1e-3)
optimizer = nnx.Optimizer(model, tx)
t0 = time()

@nnx.jit
def train_step_nnx(model: MLP, optimizer: nnx.Optimizer, batch):
@@ -115,11 +120,22 @@ def test_step_nnx(model: MLP, batch):

if step % 1000 == 0:
logs = test_step_nnx(model, (X, Y))
print(f"step: {step}, loss: {logs['loss']}")

if step >= total_steps - 1:
break
else:

print('### NNX ###')
print(f"final loss: {logs['loss']}")
total_time = time() - t0
print('total time:', total_time)
print(f'time per step: {total_time / total_steps * 1e6:.2f} µs')
print('times called:', model.count.value)

if mode == 'jax' or mode == 'all':
model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
tx = optax.sgd(1e-3)
optimizer = nnx.Optimizer(model, tx)
t0 = time()

@jax.jit
def train_step_jax(graphdef, state, batch):
@@ -151,17 +167,18 @@ def test_step_jax(graphdef, state, batch):

if step % 1000 == 0:
state, logs = test_step_jax(graphdef, state, (X, Y))
print(f"step: {step}, loss: {logs['loss']}")

if step >= total_steps - 1:
break

model, optimizer = nnx.merge(graphdef, state)

total_time = time() - t0
print('total time:', total_time)
print(f'time per step: {total_time / total_steps * 1e6:.2f} µs')
print('times called:', model.count.value)
print('### JAX ###')
print(f"final loss: {logs['loss']}")
total_time = time() - t0
print('total time:', total_time)
print(f'time per step: {total_time / total_steps * 1e6:.2f} µs')
print('times called:', model.count.value)


if __name__ == '__main__':
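For context, the pure-JAX branch of the benchmark jits a step over the split `(graphdef, state)` pair and merges back outside the loop. The snippet below is a condensed, hypothetical sketch of that pattern, not the benchmark itself: toy data, a plain MSE loss, and `nnx.value_and_grad` stand in for the collapsed parts of the hunk above.

```python
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax import nnx

model = nnx.Linear(1, 1, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.sgd(1e-3))
# Split the (model, optimizer) pair once, outside the jitted step.
graphdef, state = nnx.split((model, optimizer))

@jax.jit
def train_step_jax(graphdef, state, batch):
  # Rebuild the objects, run a regular NNX update, return the new state.
  model, optimizer = nnx.merge(graphdef, state)
  x, y = batch
  loss, grads = nnx.value_and_grad(lambda m: jnp.mean((m(x) - y) ** 2))(model)
  optimizer.update(grads)
  return nnx.state((model, optimizer)), loss

X = np.linspace(0, 1, 32)[:, None]
Y = 0.8 * X**2 + 0.1
state, loss = train_step_jax(graphdef, state, (X, Y))
model, optimizer = nnx.merge(graphdef, state)  # final merge, as in the diff
```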
14 changes: 9 additions & 5 deletions flax/nnx/extract.py
@@ -254,7 +254,7 @@ class GraphDefState(struct.PyTreeNode):

class NodeStates(struct.PyTreeNode):
_graphdef: graph.GraphDef[tp.Any] | None
states: tuple[graph.GraphState, ...]
states: tuple[graph.GraphState | graph.GraphFlatState, ...]
metadata: tp.Any = struct.field(pytree_node=False)

@property
@@ -264,7 +264,7 @@ def graphdef(self) -> graph.GraphDef[tp.Any]:
return self._graphdef

@property
def state(self) -> graph.GraphState:
def state(self) -> graph.GraphState | graph.GraphFlatState:
if len(self.states) != 1:
raise ValueError(
f'Expected exactly one GraphDefState, got {len(self.states)}'
@@ -275,15 +275,19 @@ def state(self) -> graph.GraphState:
def from_split(
cls,
graphdef: graph.GraphDef[tp.Any],
state: graph.GraphState,
state: graph.GraphState | graph.GraphFlatState,
/,
*states: graph.GraphState,
*states: graph.GraphState | graph.GraphFlatState,
metadata: tp.Any = None,
):
return cls(_graphdef=graphdef, states=(state, *states), metadata=metadata)

@classmethod
def from_states(cls, state: graph.GraphState, *states: graph.GraphState):
def from_states(
cls,
state: graph.GraphState | graph.GraphFlatState,
*states: graph.GraphState | graph.GraphFlatState,
):
return cls(_graphdef=None, states=(state, *states), metadata=None)

@classmethod
115 changes: 100 additions & 15 deletions flax/nnx/graph.py
@@ -30,7 +30,7 @@
CallableProxy,
DelayedAccessor,
)
from flax.nnx.statelib import State
from flax.nnx.statelib import FlatState, State
from flax.nnx import variablelib
from flax.nnx.variablelib import Variable, VariableState
from flax.typing import Key, PathParts, is_key_like
@@ -53,6 +53,7 @@
StateLeaf = VariableState[tp.Any]
NodeLeaf = Variable[tp.Any]
GraphState = State[Key, StateLeaf]
GraphFlatState = FlatState[StateLeaf]


def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]:
@@ -377,7 +378,9 @@ def _apply(
module = merge(self, state, *states)
fn = accessor(module)
out = fn(*args, **kwargs)
return out, flatten(module)
graphdef, flat_state = flatten(module)
state_ = State.from_flat_path(flat_state)
return out, (graphdef, state_)

return CallableProxy(_apply, accessor) # type: ignore

@@ -389,7 +392,7 @@

def flatten(
node: Node, /, ref_index: RefMap[tp.Any, Index] | None = None
) -> tuple[GraphDef[Node], GraphState]:
) -> tuple[GraphDef[Node], FlatState[tp.Any]]:
"""Flattens a graph node into a (graphdef, state) pair.

Args:
Expand All @@ -402,7 +405,7 @@ def flatten(
ref_index = RefMap()
flat_state: list[tuple[PathParts, StateLeaf]] = []
graphdef = _graph_flatten((), ref_index, flat_state, node)
return graphdef, GraphState.from_flat_path(flat_state)
return graphdef, FlatState(flat_state)


def _graph_flatten(
@@ -811,8 +814,11 @@ def split(
ctx = (
current_update_context(self.ctxtag) if self.ctxtag is not None else None
)
graphdef, state = flatten(node, self.ref_index)
states = _split_state(state, filters)
graphdef, flat_state = flatten(node, self.ref_index)
flat_states = _split_state(flat_state, filters)
states = tuple(
State.from_flat_path(flat_state) for flat_state in flat_states
)
if ctx is not None:
if ctx.index_ref is not None and isinstance(graphdef, NodeDef):
index_to_index = compose_mapping(ctx.index_ref, self.ref_index)
@@ -822,6 +828,47 @@ def split(

return graphdef, *states

@tp.overload
def flatten(
self, graph_node: A, /
) -> tuple[GraphDef[A], FlatState[VariableState[tp.Any]]]: ...
@tp.overload
def flatten(
self, graph_node: A, first: filterlib.Filter, /
) -> tuple[GraphDef[A], FlatState[VariableState[tp.Any]]]: ...
@tp.overload
def flatten(
self,
graph_node: A,
first: filterlib.Filter,
second: filterlib.Filter,
/,
*filters: filterlib.Filter,
) -> tuple[
GraphDef[A],
FlatState[VariableState[tp.Any]],
tpe.Unpack[tuple[FlatState[VariableState[tp.Any]], ...]],
]: ...
def flatten(
self, node: A, *filters: filterlib.Filter
) -> tuple[
GraphDef[A], tpe.Unpack[tuple[FlatState[VariableState[tp.Any]], ...]]
]:
ctx = (
current_update_context(self.ctxtag) if self.ctxtag is not None else None
)
graphdef, flat_state = flatten(node, self.ref_index)
flat_states = _split_state(flat_state, filters)

if ctx is not None:
if ctx.index_ref is not None and isinstance(graphdef, NodeDef):
index_to_index = compose_mapping(ctx.index_ref, self.ref_index)
graphdef = dataclasses.replace(
graphdef, index_mapping=HashableMapping(index_to_index, copy=False)
)

return graphdef, *flat_states


@contextlib.contextmanager
def split_context(ctxtag: str | None = None):
@@ -874,6 +921,39 @@ def merge(
)
return node

def unflatten(
self,
graphdef: GraphDef[A],
flat_state: GraphFlatState,
/,
*flat_states: GraphFlatState,
) -> A:
ctx = (
current_update_context(self.ctxtag) if self.ctxtag is not None else None
)
if (
ctx is not None
and isinstance(graphdef, NodeDef)
and graphdef.index_mapping is not None
):
# outer merge (4), create index_ref_cache
assert ctx.ref_index is not None
index_ref_cache = compose_mapping_reversed(
ctx.ref_index, graphdef.index_mapping
)
else:
# inner merge (2)
index_ref_cache = None

state = FlatState.merge(flat_state, *flat_states).to_nested_state()
node = unflatten(
graphdef,
state,
index_ref=self.index_ref,
index_ref_cache=index_ref_cache,
)
return node


@contextlib.contextmanager
def merge_context(ctxtag: str | None = None):
@@ -1001,9 +1081,11 @@ def split(
filters are passed, a single :class:`State` is returned.
"""
ref_index: RefMap[tp.Any, Index] = RefMap()
graphdef, state = flatten(node, ref_index)
states = _split_state(state, filters)

graphdef, flat_state = flatten(node, ref_index)
states = tuple(
State.from_flat_path(flat_state)
for flat_state in _split_state(flat_state, filters)
)
if self.index_ref is not None and isinstance(graphdef, NodeDef):
index_to_index = compose_mapping(self.index_ref, ref_index)
graphdef = dataclasses.replace(
@@ -1195,13 +1277,13 @@ def current_update_context(tag: str) -> UpdateContext:
# --------------------------------------------------------

def _split_state(
state: GraphState,
state: FlatState[tp.Any],
filters: tuple[filterlib.Filter, ...],
) -> tuple[GraphState, tpe.Unpack[tuple[GraphState, ...]]]:
) -> tuple[FlatState[tp.Any], tpe.Unpack[tuple[FlatState[tp.Any], ...]]]:
if not filters:
return (state,)
states = state.split(*filters)
if isinstance(states, State):
if not isinstance(states, tuple):
return (states,)
assert len(states) > 0
return states # type: ignore[return-value]
@@ -1292,9 +1374,11 @@ def split(
``GraphDef`` and one or more ``States`` equal to the number of filters passed. If no
filters are passed, a single ``State`` is returned.
"""
graphdef, state = flatten(node)
states = _split_state(state, filters)
return graphdef, *states
graphdef, flat_state = flatten(node)
flat_states = _split_state(flat_state, filters)
states = tuple(State.from_flat_path(flat_state) for flat_state in flat_states)
return graphdef, *states # type: ignore[return-value]


def merge(
graphdef: GraphDef[A],
@@ -1486,6 +1570,7 @@ def state(
One or more :class:`State` mappings.
"""
_, state = flatten(node)
state = state.to_nested_state()

states: GraphState | tuple[GraphState, ...]
if len(filters) == 0:
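The `graph.py` changes also add `flatten`/`unflatten` methods to the split and merge contexts so the transform internals can stay in flat space. A minimal sketch of how the pair might round-trip a module, assuming the internal `split_context`/`merge_context` managers are used directly from `flax.nnx.graph` and that no update-context tag is involved:

```python
from flax import nnx
from flax.nnx import graph

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))

# Flatten inside a split context: the new method returns FlatStates.
with graph.split_context() as ctx:
  graphdef, flat_state = ctx.flatten(model)

# Rebuild the node inside a merge context; unflatten merges the FlatStates
# and nests them internally before reconstructing the graph node.
with graph.merge_context() as ctx:
  model2 = ctx.unflatten(graphdef, flat_state)
```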
8 changes: 8 additions & 0 deletions flax/nnx/reprlib.py
@@ -111,6 +111,14 @@ def __nnx_repr__(self):
for key, value in self.items():
yield Attr(repr(key), value)

class SequenceReprMixin(tp.Sequence[A], Representable):
def __nnx_repr__(self):
yield Object(type='', value_sep='', start='[', end=']')

for value in self:
yield Attr('', value)


@dataclasses.dataclass(repr=False)
class PrettyMapping(Representable):
mapping: tp.Mapping
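For reference, a hypothetical sketch of how the new `SequenceReprMixin` could be mixed into a sequence type to get the bracketed repr; nothing in this diff shows a concrete user yet, and `IntList` is an invented name:

```python
import typing as tp
from flax.nnx import reprlib

class IntList(reprlib.SequenceReprMixin[int]):  # hypothetical user
  def __init__(self, values: tp.Iterable[int]):
    self._values = list(values)

  def __getitem__(self, index):
    return self._values[index]

  def __len__(self) -> int:
    return len(self._values)
```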