[nnx] add tabulate
cgarciae committed Jan 21, 2025
1 parent e4418e2 commit f0dcac7
Showing 8 changed files with 655 additions and 28 deletions.
1 change: 1 addition & 0 deletions flax/nnx/__init__.py
@@ -167,4 +167,5 @@
from .extract import to_tree as to_tree
from .extract import from_tree as from_tree
from .extract import NodeStates as NodeStates
+ from .summary import tabulate as tabulate
from . import traversals as traversals
2 changes: 1 addition & 1 deletion flax/nnx/extract.py
@@ -144,7 +144,7 @@ def check_consistent_aliasing(
for path, value in graph.iter_graph(node):
if graph.is_graph_node(value) or isinstance(value, graph.Variable):
if isinstance(value, Object):
- value.check_valid_context(
+ value._check_valid_context(
lambda: f'Trying to extract graph node from different trace level, got {value!r}'
)
if isinstance(value, graph.Variable):
12 changes: 8 additions & 4 deletions flax/nnx/object.py
@@ -34,7 +34,7 @@
)
from flax import nnx
from flax.nnx.variablelib import Variable, VariableState
- from flax.typing import SizeBytes, value_stats
+ from flax.typing import SizeBytes

G = tp.TypeVar('G', bound='Object')

@@ -55,7 +55,7 @@ def _collect_stats(
var_type = type(node)
if issubclass(var_type, nnx.RngState):
var_type = nnx.RngState
- size_bytes = value_stats(node.value)
+ size_bytes = SizeBytes.from_any(node.value)
if size_bytes:
stats[var_type] = size_bytes

@@ -134,6 +134,10 @@ class Array(reprlib.Representable):
shape: tp.Tuple[int, ...]
dtype: tp.Any

+ @staticmethod
+ def from_array(array: jax.Array | np.ndarray) -> Array:
+   return Array(array.shape, array.dtype)

def __nnx_repr__(self):
yield reprlib.Object(type='Array', same_line=True)
yield reprlib.Attr('shape', self.shape)
@@ -163,12 +167,12 @@ def __setattr__(self, name: str, value: Any) -> None:
self._setattr(name, value)

def _setattr(self, name: str, value: tp.Any) -> None:
- self.check_valid_context(
+ self._check_valid_context(
lambda: f"Cannot mutate '{type(self).__name__}' from different trace level"
)
object.__setattr__(self, name, value)

- def check_valid_context(self, error_msg: tp.Callable[[], str]) -> None:
+ def _check_valid_context(self, error_msg: tp.Callable[[], str]) -> None:
if not self._object__state.trace_state.is_valid():
raise errors.TraceContextError(error_msg())

2 changes: 1 addition & 1 deletion flax/nnx/rnglib.py
@@ -64,7 +64,7 @@ def __post_init__(self):
raise TypeError(f'key must be a jax.Array, got {type(self.key)}')

def __call__(self) -> jax.Array:
- self.check_valid_context(
+ self._check_valid_context(
lambda: 'Cannot call RngStream from a different trace level'
)
key = jax.random.fold_in(self.key.value, self.count.value)
552 changes: 552 additions & 0 deletions flax/nnx/summary.py

Large diffs are not rendered by default.
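
Since the 552-line flax/nnx/summary.py is not rendered, here is a minimal usage sketch of the new nnx.tabulate entry point, inferred from tests/nnx/summary_test.py further down. The MLP module and the width value are illustrative assumptions, not part of the commit; console_kwargs appear to be forwarded to the rich Console used for rendering, as in Linen's nn.tabulate.

import jax.numpy as jnp
from flax import nnx

class MLP(nnx.Module):  # toy module for illustration only
  def __init__(self, rngs: nnx.Rngs):
    self.linear = nnx.Linear(32, 10, rngs=rngs)

  def __call__(self, x):
    return self.linear(x)

model = MLP(nnx.Rngs(0))
x = jnp.ones((1, 32))
# nnx.tabulate returns the rendered summary table as a string.
print(nnx.tabulate(model, x, console_kwargs=dict(width=120)))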

14 changes: 5 additions & 9 deletions flax/nnx/variablelib.py
@@ -25,11 +25,7 @@

from flax import errors
from flax.nnx import filterlib, reprlib, tracers, visualization
- from flax.typing import (
-   Missing,
-   PathParts,
-   value_stats,
- )
+ from flax.typing import Missing, PathParts, SizeBytes
import jax.tree_util as jtu

A = tp.TypeVar('A')
Expand Down Expand Up @@ -317,7 +313,7 @@ def to_state(self: Variable[A]) -> VariableState[A]:
return VariableState(type(self), self.raw_value, **self._var_metadata)

def __nnx_repr__(self):
- stats = value_stats(self.value)
+ stats = SizeBytes.from_any(self.value)
if stats:
comment = f' # {stats}'
else:
@@ -329,7 +325,7 @@ def __nnx_repr__(self):
yield reprlib.Attr(name, repr(value))

def __treescope_repr__(self, path, subtree_renderer):
- size_bytes = value_stats(self.value)
+ size_bytes = SizeBytes.from_any(self.value)
if size_bytes:
stats_repr = f' # {size_bytes}'
first_line_annotation = treescope.rendering_parts.comment_color(
@@ -784,7 +780,7 @@ def __delattr__(self, name: str) -> None:
del self._var_metadata[name]

def __nnx_repr__(self):
- stats = value_stats(self.value)
+ stats = SizeBytes.from_any(self.value)
if stats:
comment = f' # {stats}'
else:
@@ -798,7 +794,7 @@ def __nnx_repr__(self):
yield reprlib.Attr(name, value)

def __treescope_repr__(self, path, subtree_renderer):
- size_bytes = value_stats(self.value)
+ size_bytes = SizeBytes.from_any(self.value)
if size_bytes:
stats_repr = f' # {size_bytes}'
first_line_annotation = treescope.rendering_parts.comment_color(
26 changes: 13 additions & 13 deletions flax/typing.py
@@ -194,19 +194,19 @@ class SizeBytes: # type: ignore[misc]
size: int
bytes: int

- @staticmethod
- def from_array(x: ShapeDtype) -> SizeBytes:
+ @classmethod
+ def from_array(cls, x: ShapeDtype):
size = int(np.prod(x.shape))
dtype: jnp.dtype
if isinstance(x.dtype, str):
dtype = jnp.dtype(x.dtype)
else:
dtype = x.dtype # type: ignore
bytes = size * dtype.itemsize # type: ignore
- return SizeBytes(size, bytes)
+ return cls(size, bytes)

- def __add__(self, other: SizeBytes) -> SizeBytes:
-   return SizeBytes(self.size + other.size, self.bytes + other.bytes)
+ def __add__(self, other: SizeBytes):
+   return type(self)(self.size + other.size, self.bytes + other.bytes)

def __bool__(self) -> bool:
return bool(self.size)
@@ -215,12 +215,12 @@ def __repr__(self) -> str:
bytes_repr = _bytes_repr(self.bytes)
return f'{self.size:,} ({bytes_repr})'

+ @classmethod
+ def from_any(cls, x):
+   leaves = jax.tree.leaves(x)
+   size_bytes = cls(0, 0)
+   for leaf in leaves:
+     if has_shape_dtype(leaf):
+       size_bytes += cls.from_array(leaf)
+
+   return size_bytes

- def value_stats(x):
-   leaves = jax.tree.leaves(x)
-   size_bytes = SizeBytes(0, 0)
-   for leaf in leaves:
-     if has_shape_dtype(leaf):
-       size_bytes += SizeBytes.from_array(leaf)
-
-   return size_bytes
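
To make the refactor above concrete, here is a small sketch of SizeBytes.from_any aggregating parameter counts and byte sizes over a pytree; the parameter dict is a made-up example and the printed string assumes the __repr__ shown above.

import jax.numpy as jnp
from flax.typing import SizeBytes

params = {
  'w': jnp.ones((32, 128), jnp.float32),  # 4,096 parameters
  'b': jnp.ones((128,), jnp.float32),     # 128 parameters
}
stats = SizeBytes.from_any(params)  # sums size and bytes across all array leaves
print(stats)  # roughly "4,224 (16.9 KB)": 4,224 float32 values, 16,896 bytes
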
74 changes: 74 additions & 0 deletions tests/nnx/summary_test.py
@@ -0,0 +1,74 @@
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax.numpy as jnp
from absl.testing import absltest

from flax import nnx

CONSOLE_TEST_KWARGS = dict(force_terminal=False, no_color=True, width=10_000)


class SummaryTest(absltest.TestCase):
  def test_tabulate(self):
    class Block(nnx.Module):
      def __init__(self, din, dout, rngs: nnx.Rngs):
        self.linear = nnx.Linear(din, dout, rngs=rngs)
        self.bn = nnx.BatchNorm(dout, rngs=rngs)
        self.dropout = nnx.Dropout(0.2, rngs=rngs)

      def forward(self, x):
        return nnx.relu(self.dropout(self.bn(self.linear(x))))

    class Foo(nnx.Module):
      def __init__(self, rngs: nnx.Rngs):
        self.block1 = Block(32, 128, rngs=rngs)
        self.block2 = Block(128, 10, rngs=rngs)

      def __call__(self, x):
        return self.block2.forward(self.block1.forward(x))

    foo = Foo(nnx.Rngs(0))
    x = jnp.ones((1, 32))
    table_repr = nnx.tabulate(
      foo, x, console_kwargs=CONSOLE_TEST_KWARGS
    ).splitlines()

    self.assertIn('Foo Summary', table_repr[0])
    self.assertIn('path', table_repr[2])
    self.assertIn('type', table_repr[2])
    self.assertIn('BatchStat', table_repr[2])
    self.assertIn('Param', table_repr[2])
    self.assertIn('block1/forward', table_repr[6])
    self.assertIn('Block', table_repr[6])
    self.assertIn('block1/linear', table_repr[8])
    self.assertIn('Linear', table_repr[8])
    self.assertIn('block1/bn', table_repr[13])
    self.assertIn('BatchNorm', table_repr[13])
    self.assertIn('block1/dropout', table_repr[18])
    self.assertIn('Dropout', table_repr[18])
    self.assertIn('block2/forward', table_repr[20])
    self.assertIn('Block', table_repr[20])
    self.assertIn('block2/linear', table_repr[22])
    self.assertIn('Linear', table_repr[22])
    self.assertIn('block2/bn', table_repr[27])
    self.assertIn('BatchNorm', table_repr[27])
    self.assertIn('block2/dropout', table_repr[32])
    self.assertIn('Dropout', table_repr[32])

    self.assertIn('Total', table_repr[34])
    self.assertIn('276 (1.1 KB)', table_repr[34])
    self.assertIn('5,790 (23.2 KB)', table_repr[34])
    self.assertIn('2 (12 B)', table_repr[34])
    self.assertIn('Total Parameters: 6,068 (24.3 KB)', table_repr[37])
