Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] add tabulate #4493

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs_nnx/api_reference/flax.nnx/summary.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
summary
------------------------

.. automodule:: flax.nnx
.. currentmodule:: flax.nnx

.. autofunction:: tabulate
1 change: 1 addition & 0 deletions flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,4 +167,5 @@
from .extract import to_tree as to_tree
from .extract import from_tree as from_tree
from .extract import NodeStates as NodeStates
from .summary import tabulate as tabulate
from . import traversals as traversals
2 changes: 1 addition & 1 deletion flax/nnx/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def check_consistent_aliasing(
for path, value in graph.iter_graph(node):
if graph.is_graph_node(value) or isinstance(value, graph.Variable):
if isinstance(value, Object):
value.check_valid_context(
value._check_valid_context(
lambda: f'Trying to extract graph node from different trace level, got {value!r}'
)
if isinstance(value, graph.Variable):
Expand Down
12 changes: 8 additions & 4 deletions flax/nnx/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
visualization,
)
from flax.nnx.variablelib import Variable, VariableState
from flax.typing import SizeBytes, value_stats
from flax.typing import SizeBytes

G = tp.TypeVar('G', bound='Object')

Expand All @@ -55,7 +55,7 @@ def _collect_stats(
var_type = type(node)
if issubclass(var_type, nnx.RngState):
var_type = nnx.RngState
size_bytes = value_stats(node.value)
size_bytes = SizeBytes.from_any(node.value)
if size_bytes:
stats[var_type] = size_bytes

Expand Down Expand Up @@ -134,6 +134,10 @@ class Array(reprlib.Representable):
shape: tp.Tuple[int, ...]
dtype: tp.Any

@staticmethod
def from_array(array: jax.Array | np.ndarray) -> Array:
return Array(array.shape, array.dtype)

def __nnx_repr__(self):
yield reprlib.Object(type='Array', same_line=True)
yield reprlib.Attr('shape', self.shape)
Expand Down Expand Up @@ -165,12 +169,12 @@ def __setattr__(self, name: str, value: Any) -> None:
self._setattr(name, value)

def _setattr(self, name: str, value: tp.Any) -> None:
self.check_valid_context(
self._check_valid_context(
lambda: f"Cannot mutate '{type(self).__name__}' from different trace level"
)
object.__setattr__(self, name, value)

def check_valid_context(self, error_msg: tp.Callable[[], str]) -> None:
def _check_valid_context(self, error_msg: tp.Callable[[], str]) -> None:
if not self._object__state.trace_state.is_valid():
raise errors.TraceContextError(error_msg())

Expand Down
2 changes: 1 addition & 1 deletion flax/nnx/rnglib.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __post_init__(self):
raise TypeError(f'key must be a jax.Array, got {type(self.key)}')

def __call__(self) -> jax.Array:
self.check_valid_context(
self._check_valid_context(
lambda: 'Cannot call RngStream from a different trace level'
)
key = jax.random.fold_in(self.key.value, self.count.value)
Expand Down
553 changes: 553 additions & 0 deletions flax/nnx/summary.py

Large diffs are not rendered by default.

14 changes: 5 additions & 9 deletions flax/nnx/variablelib.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,7 @@

from flax import errors
from flax.nnx import filterlib, reprlib, tracers, visualization
from flax.typing import (
Missing,
PathParts,
value_stats,
)
from flax.typing import Missing, PathParts, SizeBytes
import jax.tree_util as jtu

A = tp.TypeVar('A')
Expand Down Expand Up @@ -317,7 +313,7 @@ def to_state(self: Variable[A]) -> VariableState[A]:
return VariableState(type(self), self.raw_value, **self._var_metadata)

def __nnx_repr__(self):
stats = value_stats(self.value)
stats = SizeBytes.from_any(self.value)
if stats:
comment = f' # {stats}'
else:
Expand All @@ -329,7 +325,7 @@ def __nnx_repr__(self):
yield reprlib.Attr(name, repr(value))

def __treescope_repr__(self, path, subtree_renderer):
size_bytes = value_stats(self.value)
size_bytes = SizeBytes.from_any(self.value)
if size_bytes:
stats_repr = f' # {size_bytes}'
first_line_annotation = treescope.rendering_parts.comment_color(
Expand Down Expand Up @@ -784,7 +780,7 @@ def __delattr__(self, name: str) -> None:
del self._var_metadata[name]

def __nnx_repr__(self):
stats = value_stats(self.value)
stats = SizeBytes.from_any(self.value)
if stats:
comment = f' # {stats}'
else:
Expand All @@ -798,7 +794,7 @@ def __nnx_repr__(self):
yield reprlib.Attr(name, value)

def __treescope_repr__(self, path, subtree_renderer):
size_bytes = value_stats(self.value)
size_bytes = SizeBytes.from_any(self.value)
if size_bytes:
stats_repr = f' # {size_bytes}'
first_line_annotation = treescope.rendering_parts.comment_color(
Expand Down
26 changes: 13 additions & 13 deletions flax/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,19 +194,19 @@ class SizeBytes: # type: ignore[misc]
size: int
bytes: int

@staticmethod
def from_array(x: ShapeDtype) -> SizeBytes:
@classmethod
def from_array(cls, x: ShapeDtype):
size = int(np.prod(x.shape))
dtype: jnp.dtype
if isinstance(x.dtype, str):
dtype = jnp.dtype(x.dtype)
else:
dtype = x.dtype # type: ignore
bytes = size * dtype.itemsize # type: ignore
return SizeBytes(size, bytes)
return cls(size, bytes)

def __add__(self, other: SizeBytes) -> SizeBytes:
return SizeBytes(self.size + other.size, self.bytes + other.bytes)
def __add__(self, other: SizeBytes):
return type(self)(self.size + other.size, self.bytes + other.bytes)

def __bool__(self) -> bool:
return bool(self.size)
Expand All @@ -215,12 +215,12 @@ def __repr__(self) -> str:
bytes_repr = _bytes_repr(self.bytes)
return f'{self.size:,} ({bytes_repr})'

@classmethod
def from_any(cls, x):
leaves = jax.tree.leaves(x)
size_bytes = cls(0, 0)
for leaf in leaves:
if has_shape_dtype(leaf):
size_bytes += cls.from_array(leaf)

def value_stats(x):
leaves = jax.tree.leaves(x)
size_bytes = SizeBytes(0, 0)
for leaf in leaves:
if has_shape_dtype(leaf):
size_bytes += SizeBytes.from_array(leaf)

return size_bytes
return size_bytes
74 changes: 74 additions & 0 deletions tests/nnx/summary_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax.numpy as jnp
from absl.testing import absltest

from flax import nnx

CONSOLE_TEST_KWARGS = dict(force_terminal=False, no_color=True, width=10_000)


class SummaryTest(absltest.TestCase):
def test_tabulate(self):
class Block(nnx.Module):
def __init__(self, din, dout, rngs: nnx.Rngs):
self.linear = nnx.Linear(din, dout, rngs=rngs)
self.bn = nnx.BatchNorm(dout, rngs=rngs)
self.dropout = nnx.Dropout(0.2, rngs=rngs)

def forward(self, x):
return nnx.relu(self.dropout(self.bn(self.linear(x))))

class Foo(nnx.Module):
def __init__(self, rngs: nnx.Rngs):
self.block1 = Block(32, 128, rngs=rngs)
self.block2 = Block(128, 10, rngs=rngs)

def __call__(self, x):
return self.block2.forward(self.block1.forward(x))

foo = Foo(nnx.Rngs(0))
x = jnp.ones((1, 32))
table_repr = nnx.tabulate(
foo, x, console_kwargs=CONSOLE_TEST_KWARGS
).splitlines()

self.assertIn('Foo Summary', table_repr[0])
self.assertIn('path', table_repr[2])
self.assertIn('type', table_repr[2])
self.assertIn('BatchStat', table_repr[2])
self.assertIn('Param', table_repr[2])
self.assertIn('block1/forward', table_repr[6])
self.assertIn('Block', table_repr[6])
self.assertIn('block1/linear', table_repr[8])
self.assertIn('Linear', table_repr[8])
self.assertIn('block1/bn', table_repr[13])
self.assertIn('BatchNorm', table_repr[13])
self.assertIn('block1/dropout', table_repr[18])
self.assertIn('Dropout', table_repr[18])
self.assertIn('block2/forward', table_repr[20])
self.assertIn('Block', table_repr[20])
self.assertIn('block2/linear', table_repr[22])
self.assertIn('Linear', table_repr[22])
self.assertIn('block2/bn', table_repr[27])
self.assertIn('BatchNorm', table_repr[27])
self.assertIn('block2/dropout', table_repr[32])
self.assertIn('Dropout', table_repr[32])

self.assertIn('Total', table_repr[34])
self.assertIn('276 (1.1 KB)', table_repr[34])
self.assertIn('5,790 (23.2 KB)', table_repr[34])
self.assertIn('2 (12 B)', table_repr[34])
self.assertIn('Total Parameters: 6,068 (24.3 KB)', table_repr[37])
Loading