Skip to content

Commit

Permalink
[nnx] fast jit
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Dec 22, 2024
1 parent 53bde74 commit 6cf5b7d
Show file tree
Hide file tree
Showing 9 changed files with 309 additions and 96 deletions.
57 changes: 37 additions & 20 deletions benchmarks/nnx_simple_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
from absl import app

FLAGS = flags.FLAGS
flags.DEFINE_enum('mode', 'nnx', ['nnx', 'jax'], 'Mode to run the script in')
flags.DEFINE_enum(
'mode', 'all', ['all', 'nnx', 'jax'], 'Mode to run the script in'
)
flags.DEFINE_integer('total_steps', 10_000, 'Total number of training steps')
flags.DEFINE_integer('batch_size', 32, 'Batch size')
flags.DEFINE_integer('width', 32, 'Hidden layer size')
Expand All @@ -46,6 +48,13 @@ def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
def __call__(self, x):
return x @ self.w + self.b

class Block(nnx.Module):
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
self.linear = Linear(din, dout, rngs=rngs)
self.bn = nnx.BatchNorm(dout, rngs=rngs)

def __call__(self, x):
return nnx.relu(self.bn(self.linear(x)))

class Count(nnx.Variable):
pass
Expand All @@ -54,11 +63,11 @@ class Count(nnx.Variable):
class MLP(nnx.Module):
def __init__(self, din, dhidden, dout, depth, *, rngs: nnx.Rngs):
self.count = Count(jnp.array(0))
self.linear_in = Linear(din, dhidden, rngs=rngs)
self.linear_in = Block(din, dhidden, rngs=rngs)
self.intermediates = [
Linear(dhidden, dhidden, rngs=rngs) for _ in range(depth - 2)
Block(dhidden, dhidden, rngs=rngs) for _ in range(depth - 2)
]
self.linear_out = Linear(dhidden, dout, rngs=rngs)
self.linear_out = Block(dhidden, dout, rngs=rngs)

def __call__(self, x):
self.count.value += 1
Expand All @@ -79,18 +88,14 @@ def main(argv):

print(f'{mode=}, {total_steps=}, {batch_size=}, {width=}')

if mode not in ['nnx', 'jax']:
raise ValueError(f'Invalid mode: {mode}')

X = np.linspace(0, 1, 100)[:, None]
Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape)

model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
tx = optax.sgd(1e-3)
optimizer = nnx.Optimizer(model, tx)
t0 = time()

if mode == 'nnx':
if mode == 'nnx' or mode == 'all':
model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
tx = optax.sgd(1e-3)
optimizer = nnx.Optimizer(model, tx)
t0 = time()

@nnx.jit
def train_step_nnx(model: MLP, optimizer: nnx.Optimizer, batch):
Expand All @@ -115,11 +120,22 @@ def test_step_nnx(model: MLP, batch):

if step % 1000 == 0:
logs = test_step_nnx(model, (X, Y))
print(f"step: {step}, loss: {logs['loss']}")

if step >= total_steps - 1:
break
else:

print('### NNX ###')
print(f"final loss: {logs['loss']}")
total_time = time() - t0
print('total time:', total_time)
print(f'time per step: {total_time / total_steps * 1e6:.2f} µs')
print('times called:', model.count.value)

if mode == 'jax' or mode == 'all':
model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
tx = optax.sgd(1e-3)
optimizer = nnx.Optimizer(model, tx)
t0 = time()

@jax.jit
def train_step_jax(graphdef, state, batch):
Expand Down Expand Up @@ -151,17 +167,18 @@ def test_step_jax(graphdef, state, batch):

if step % 1000 == 0:
state, logs = test_step_jax(graphdef, state, (X, Y))
print(f"step: {step}, loss: {logs['loss']}")

if step >= total_steps - 1:
break

model, optimizer = nnx.merge(graphdef, state)

total_time = time() - t0
print('total time:', total_time)
print(f'time per step: {total_time / total_steps * 1e6:.2f} µs')
print('times called:', model.count.value)
print('### JAX ###')
print(f"final loss: {logs['loss']}")
total_time = time() - t0
print('total time:', total_time)
print(f'time per step: {total_time / total_steps * 1e6:.2f} µs')
print('times called:', model.count.value)


if __name__ == '__main__':
Expand Down
14 changes: 9 additions & 5 deletions flax/nnx/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ class GraphDefState(struct.PyTreeNode):

class NodeStates(struct.PyTreeNode):
_graphdef: graph.GraphDef[tp.Any] | None
states: tuple[graph.GraphState, ...]
states: tuple[graph.GraphState | graph.GraphFlatState, ...]
metadata: tp.Any = struct.field(pytree_node=False)

@property
Expand All @@ -264,7 +264,7 @@ def graphdef(self) -> graph.GraphDef[tp.Any]:
return self._graphdef

@property
def state(self) -> graph.GraphState:
def state(self) -> graph.GraphState | graph.GraphFlatState:
if len(self.states) != 1:
raise ValueError(
f'Expected exactly one GraphDefState, got {len(self.states)}'
Expand All @@ -275,15 +275,19 @@ def state(self) -> graph.GraphState:
def from_split(
cls,
graphdef: graph.GraphDef[tp.Any],
state: graph.GraphState,
state: graph.GraphState | graph.GraphFlatState,
/,
*states: graph.GraphState,
*states: graph.GraphState | graph.GraphFlatState,
metadata: tp.Any = None,
):
return cls(_graphdef=graphdef, states=(state, *states), metadata=metadata)

@classmethod
def from_states(cls, state: graph.GraphState, *states: graph.GraphState):
def from_states(
cls,
state: graph.GraphState | graph.GraphFlatState,
*states: graph.GraphState | graph.GraphFlatState,
):
return cls(_graphdef=None, states=(state, *states), metadata=None)

@classmethod
Expand Down
115 changes: 100 additions & 15 deletions flax/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
CallableProxy,
DelayedAccessor,
)
from flax.nnx.statelib import State
from flax.nnx.statelib import FlatState, State
from flax.nnx import variablelib
from flax.nnx.variablelib import Variable, VariableState
from flax.typing import Key, PathParts, is_key_like
Expand All @@ -53,6 +53,7 @@
StateLeaf = VariableState[tp.Any]
NodeLeaf = Variable[tp.Any]
GraphState = State[Key, StateLeaf]
GraphFlatState = FlatState[StateLeaf]


def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]:
Expand Down Expand Up @@ -377,7 +378,9 @@ def _apply(
module = merge(self, state, *states)
fn = accessor(module)
out = fn(*args, **kwargs)
return out, flatten(module)
graphdef, flat_state = flatten(module)
state_ = State.from_flat_path(flat_state)
return out, (graphdef, state_)

return CallableProxy(_apply, accessor) # type: ignore

Expand All @@ -389,7 +392,7 @@ def _apply(

def flatten(
node: Node, /, ref_index: RefMap[tp.Any, Index] | None = None
) -> tuple[GraphDef[Node], GraphState]:
) -> tuple[GraphDef[Node], FlatState[tp.Any]]:
"""Flattens a graph node into a (graphdef, state) pair.
Args:
Expand All @@ -402,7 +405,7 @@ def flatten(
ref_index = RefMap()
flat_state: list[tuple[PathParts, StateLeaf]] = []
graphdef = _graph_flatten((), ref_index, flat_state, node)
return graphdef, GraphState.from_flat_path(flat_state)
return graphdef, FlatState(flat_state)


def _graph_flatten(
Expand Down Expand Up @@ -811,8 +814,11 @@ def split(
ctx = (
current_update_context(self.ctxtag) if self.ctxtag is not None else None
)
graphdef, state = flatten(node, self.ref_index)
states = _split_state(state, filters)
graphdef, flat_state = flatten(node, self.ref_index)
flat_states = _split_state(flat_state, filters)
states = tuple(
State.from_flat_path(flat_state) for flat_state in flat_states
)
if ctx is not None:
if ctx.index_ref is not None and isinstance(graphdef, NodeDef):
index_to_index = compose_mapping(ctx.index_ref, self.ref_index)
Expand All @@ -822,6 +828,47 @@ def split(

return graphdef, *states

@tp.overload
def flatten(
self, graph_node: A, /
) -> tuple[GraphDef[A], FlatState[VariableState[tp.Any]]]: ...
@tp.overload
def flatten(
self, graph_node: A, first: filterlib.Filter, /
) -> tuple[GraphDef[A], FlatState[VariableState[tp.Any]]]: ...
@tp.overload
def flatten(
self,
graph_node: A,
first: filterlib.Filter,
second: filterlib.Filter,
/,
*filters: filterlib.Filter,
) -> tuple[
GraphDef[A],
FlatState[VariableState[tp.Any]],
tpe.Unpack[tuple[FlatState[VariableState[tp.Any]], ...]],
]: ...
def flatten(
self, node: A, *filters: filterlib.Filter
) -> tuple[
GraphDef[A], tpe.Unpack[tuple[FlatState[VariableState[tp.Any]], ...]]
]:
ctx = (
current_update_context(self.ctxtag) if self.ctxtag is not None else None
)
graphdef, flat_state = flatten(node, self.ref_index)
flat_states = _split_state(flat_state, filters)

if ctx is not None:
if ctx.index_ref is not None and isinstance(graphdef, NodeDef):
index_to_index = compose_mapping(ctx.index_ref, self.ref_index)
graphdef = dataclasses.replace(
graphdef, index_mapping=HashableMapping(index_to_index, copy=False)
)

return graphdef, *flat_states


@contextlib.contextmanager
def split_context(ctxtag: str | None = None):
Expand Down Expand Up @@ -874,6 +921,39 @@ def merge(
)
return node

def unflatten(
self,
graphdef: GraphDef[A],
flat_state: GraphFlatState,
/,
*flat_states: GraphFlatState,
) -> A:
ctx = (
current_update_context(self.ctxtag) if self.ctxtag is not None else None
)
if (
ctx is not None
and isinstance(graphdef, NodeDef)
and graphdef.index_mapping is not None
):
# outer merge (4), create index_ref_cache
assert ctx.ref_index is not None
index_ref_cache = compose_mapping_reversed(
ctx.ref_index, graphdef.index_mapping
)
else:
# inner merge (2)
index_ref_cache = None

state = FlatState.merge(flat_state, *flat_states).to_nested_state()
node = unflatten(
graphdef,
state,
index_ref=self.index_ref,
index_ref_cache=index_ref_cache,
)
return node


@contextlib.contextmanager
def merge_context(ctxtag: str | None = None):
Expand Down Expand Up @@ -1001,9 +1081,11 @@ def split(
filters are passed, a single :class:`State` is returned.
"""
ref_index: RefMap[tp.Any, Index] = RefMap()
graphdef, state = flatten(node, ref_index)
states = _split_state(state, filters)

graphdef, flat_state = flatten(node, ref_index)
states = tuple(
State.from_flat_path(flat_state)
for flat_state in _split_state(flat_state, filters)
)
if self.index_ref is not None and isinstance(graphdef, NodeDef):
index_to_index = compose_mapping(self.index_ref, ref_index)
graphdef = dataclasses.replace(
Expand Down Expand Up @@ -1195,13 +1277,13 @@ def current_update_context(tag: str) -> UpdateContext:
# --------------------------------------------------------

def _split_state(
state: GraphState,
state: FlatState[tp.Any],
filters: tuple[filterlib.Filter, ...],
) -> tuple[GraphState, tpe.Unpack[tuple[GraphState, ...]]]:
) -> tuple[FlatState[tp.Any], tpe.Unpack[tuple[FlatState[tp.Any], ...]]]:
if not filters:
return (state,)
states = state.split(*filters)
if isinstance(states, State):
if not isinstance(states, tuple):
return (states,)
assert len(states) > 0
return states # type: ignore[return-value]
Expand Down Expand Up @@ -1292,9 +1374,11 @@ def split(
``GraphDef`` and one or more ``States`` equal to the number of filters passed. If no
filters are passed, a single ``State`` is returned.
"""
graphdef, state = flatten(node)
states = _split_state(state, filters)
return graphdef, *states
graphdef, flat_state = flatten(node)
flat_states = _split_state(flat_state, filters)
states = tuple(State.from_flat_path(flat_state) for flat_state in flat_states)
return graphdef, *states # type: ignore[return-value]


def merge(
graphdef: GraphDef[A],
Expand Down Expand Up @@ -1486,6 +1570,7 @@ def state(
One or more :class:`State` mappings.
"""
_, state = flatten(node)
state = state.to_nested_state()

states: GraphState | tuple[GraphState, ...]
if len(filters) == 0:
Expand Down
8 changes: 8 additions & 0 deletions flax/nnx/reprlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,14 @@ def __nnx_repr__(self):
for key, value in self.items():
yield Attr(repr(key), value)

class SequenceReprMixin(tp.Sequence[A], Representable):
def __nnx_repr__(self):
yield Object(type='', value_sep='', start='[', end=']')

for value in self:
yield Attr('', value)


@dataclasses.dataclass(repr=False)
class PrettyMapping(Representable):
mapping: tp.Mapping
Expand Down
Loading

0 comments on commit 6cf5b7d

Please sign in to comment.