diff --git a/flax/nnx/summary.py b/flax/nnx/summary.py new file mode 100644 index 0000000000..9348079e20 --- /dev/null +++ b/flax/nnx/summary.py @@ -0,0 +1,395 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import functools +import inspect +import io +import typing as tp +from itertools import groupby +from types import MappingProxyType + +import jax +import rich.console +import rich.table +import rich.text +import yaml +import jax.numpy as jnp + +from flax import nnx +from flax import typing +from flax.nnx import graph, rnglib, statelib, variablelib + +try: + from IPython import get_ipython + + in_ipython = get_ipython() is not None +except ImportError: + in_ipython = False + + +class ObjectInfo(tp.NamedTuple): + path: statelib.PathParts + stats: dict[type[variablelib.Variable], typing.SizeBytes] + + +def _collect_stats( + path: statelib.PathParts, + node: tp.Any, + node_stats: dict[int, ObjectInfo], + object_types: set[type], +): + if not graph.is_node(node) and not isinstance(node, variablelib.Variable): + raise ValueError(f'Expected a graph node or Variable, got {type(node)!r}.') + + if id(node) in node_stats: + return + + stats: dict[type[variablelib.Variable], typing.SizeBytes] = {} + node_stats[id(node)] = ObjectInfo(path, stats) + + if isinstance(node, nnx.Object): + node._nnx_tabulate_id = id(node) # type: ignore + object_types.add(type(node)) + + if isinstance(node, variablelib.Variable): + var_type = type(node) + if issubclass(var_type, nnx.RngState): + var_type = nnx.RngState + size_bytes = typing.value_stats(node.value) + if size_bytes: + stats[var_type] = size_bytes + + else: + node_dict = graph.get_node_impl(node).node_dict(node) + for key, value in node_dict.items(): + if id(value) in node_stats: + continue + if graph.is_node(value) or isinstance(value, variablelib.Variable): + _collect_stats((*path, key), value, node_stats, object_types) + child_info = node_stats[id(value)] + for var_type, size_bytes in child_info.stats.items(): + if var_type in stats: + stats[var_type] += size_bytes + else: + stats[var_type] = size_bytes + + +class CallInfo(tp.NamedTuple): + object_id: int + inputs: tp.Any + outputs: tp.Any + + +def get_method_wrapper(method: tp.Callable) -> tp.Callable: + @functools.wraps(method) + def method_wrapper(obj, *args, **kwargs): + return method(obj, *args, **kwargs) + + return method_wrapper + + +def _call_obj(object_types: set[type], obj, *args, **kwargs): + original_methods: dict[type, dict[str, tp.Callable]] = {} + for obj_type in object_types: + methods: dict[str, tp.Callable] = {} + for name, method in inspect.getmembers(obj_type, inspect.isfunction): + if not name.startswith('_') or name == '__call__': + methods[name] = method + method_wrapper = get_method_wrapper(method) + setattr(obj_type, name, method_wrapper) + + original_methods[obj_type] = methods + + +def tabulate( + obj, + depth: int | None = None, + table_kwargs: tp.Mapping[str, tp.Any] = MappingProxyType({}), + column_kwargs: tp.Mapping[str, tp.Any] = MappingProxyType({}), + console_kwargs: tp.Mapping[str, tp.Any] = MappingProxyType({}), +) -> str: + """Creates a summary of the graph object represented as a table. + + The table summarizes the object's state and metadata. The table is + structured as follows: + + - The first column represents the path of the object in the graph. + - The second column represents the type of the object. + - The following columns provide information about the object's state, + grouped by Variable types. + + Example: + + >>> from flax import nnx + ... + >>> class Block(nnx.Module): + ... def __init__(self, din, dout, rngs: nnx.Rngs): + ... self.linear = nnx.Linear(din, dout, rngs=rngs) + ... self.bn = nnx.BatchNorm(dout, rngs=rngs) + ... self.dropout = nnx.Dropout(0.2, rngs=rngs) + ... + ... def __call__(self, x): + ... return nnx.relu(self.dropout(self.bn(self.linear(x)))) + ... + >>> class Foo(nnx.Module): + ... def __init__(self, rngs: nnx.Rngs): + ... self.block1 = Block(32, 128, rngs=rngs) + ... self.block2 = Block(128, 10, rngs=rngs) + ... + ... def __call__(self, x): + ... return self.block2(self.block1(x)) + ... + >>> foo = Foo(nnx.Rngs(0)) + >>> # print(nnx.tabulate(foo)) + + Foo Summary + ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓ + ┃ path ┃ type ┃ BatchStat ┃ Param ┃ RngState ┃ + ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩ + │ block1/bn │ BatchNorm │ mean: float32[128] │ bias: float32[128] │ │ + │ │ │ var: float32[128] │ scale: float32[128] │ │ + │ │ │ │ │ │ + │ │ │ 256 (1.0 KB) │ 256 (1.0 KB) │ │ + ├─────────────────────────────┼───────────┼────────────────────┼─────────────────────────┼─────────────────────┤ + │ block1/dropout/rngs/default │ RngStream │ │ │ count: │ + │ │ │ │ │ value: uint32[] │ + │ │ │ │ │ tag: default │ + │ │ │ │ │ key: │ + │ │ │ │ │ value: key[] │ + │ │ │ │ │ tag: default │ + │ │ │ │ │ │ + │ │ │ │ │ 2 (12 B) │ + ├─────────────────────────────┼───────────┼────────────────────┼─────────────────────────┼─────────────────────┤ + │ block1/linear │ Linear │ │ bias: float32[128] │ │ + │ │ │ │ kernel: float32[32,128] │ │ + │ │ │ │ │ │ + │ │ │ │ 4,224 (16.9 KB) │ │ + ├─────────────────────────────┼───────────┼────────────────────┼─────────────────────────┼─────────────────────┤ + │ block2/bn │ BatchNorm │ mean: float32[10] │ bias: float32[10] │ │ + │ │ │ var: float32[10] │ scale: float32[10] │ │ + │ │ │ │ │ │ + │ │ │ 20 (80 B) │ 20 (80 B) │ │ + ├─────────────────────────────┼───────────┼────────────────────┼─────────────────────────┼─────────────────────┤ + │ block2/linear │ Linear │ │ bias: float32[10] │ │ + │ │ │ │ kernel: float32[128,10] │ │ + │ │ │ │ │ │ + │ │ │ │ 1,290 (5.2 KB) │ │ + ├─────────────────────────────┼───────────┼────────────────────┼─────────────────────────┼─────────────────────┤ + │ │ Total │ 276 (1.1 KB) │ 5,790 (23.2 KB) │ 2 (12 B) │ + └─────────────────────────────┴───────────┴────────────────────┴─────────────────────────┴─────────────────────┘ + + Total Parameters: 6,068 (24.3 KB) + + + Note that ``block2/dropout`` is not shown in the table because it shares the + same ``RngState`` with ``block1/dropout``. + + Args: + obj: A object to summarize. It can a pytree or a graph objects + such as nnx.Module or nnx.Optimizer. + depth: The depth of the table. + table_kwargs: An optional dictionary with additional keyword arguments + that are passed to ``rich.table.Table`` constructor. + column_kwargs: An optional dictionary with additional keyword arguments + that are passed to ``rich.table.Table.add_column`` when adding columns to + the table. + console_kwargs: An optional dictionary with additional keyword arguments + that are passed to `rich.console.Console` when rendering the table. + Default arguments are ``'force_terminal': True``, and ``'force_jupyter'`` + is set to ``True`` if the code is running in a Jupyter notebook, otherwise + it is set to ``False``. + + Returns: + A string summarizing the object. + """ + _console_kwargs = {'force_terminal': True, 'force_jupyter': in_ipython} + _console_kwargs.update(console_kwargs) + state = graph.state(obj) + graph_map = dict(graph.iter_graph(obj)) + flat_state = sorted(state.flat_state()) + + def key_fn( + path_state: tuple[graph.PathParts, variablelib.VariableState[tp.Any]], + ): + path, _ = path_state + if depth is None or len(path) <= depth: + return path[:-1] + else: + return path[:depth] + + rows = groupby(flat_state, key_fn) + table = sorted((path, list(flat_states)) for path, flat_states in rows) + + state_types_set = {variable_state.type for _, variable_state in flat_state} + # replace RngKey and RngCount with RngState + if rnglib.RngKey in state_types_set: + state_types_set.remove(rnglib.RngKey) + state_types_set.add(rnglib.RngState) + if rnglib.RngCount in state_types_set: + state_types_set.remove(rnglib.RngCount) + state_types_set.add(rnglib.RngState) + # sort based on MRO + state_types = _sort_variable_types(state_types_set) + + rich_table = rich.table.Table( + show_header=True, + show_lines=True, + show_footer=True, + title=f'{type(obj).__name__} Summary', + **table_kwargs, + ) + + rich_table.add_column('path', **column_kwargs) + rich_table.add_column('type', **column_kwargs) + + for state_type in state_types: + rich_table.add_column(state_type.__name__, **column_kwargs) + + for key_path, row_states in table: + row: list[str] = [] + node = graph_map[key_path] + type_state_groups = variablelib.split_flat_state(row_states, state_types) + path_str = '/'.join(map(str, key_path)) + node_type = type(node).__name__ + row.extend([path_str, node_type]) + + for state_type, type_path_and_states in zip(state_types, type_state_groups): + attributes = {} + for state_path, variable_state in type_path_and_states: + if len(state_path) == len(key_path) + 1: + name = str(state_path[-1]) + value = variable_state.value + value_repr = _render_array(value) if _has_shape_dtype(value) else '' + metadata = variable_state.get_metadata() + + if metadata: + attributes[name] = { + 'value': value_repr, + **metadata, + } + elif value_repr: + attributes[name] = value_repr + + if attributes: + col_repr = _as_yaml_str(attributes) + '\n\n' + else: + col_repr = '' + + type_states = [state for _, state in type_path_and_states] + size_, bytes_ = _size_and_bytes(type_states) + col_repr += f'[bold]{_size_and_bytes_repr(size_, bytes_)}[/bold]' + row.append(col_repr) + + rich_table.add_row(*row) + + rich_table.columns[1].footer = rich.text.Text.from_markup( + 'Total', justify='right' + ) + flat_states = variablelib.split_flat_state(flat_state, state_types) + + for i, (state_type, type_path_and_states) in enumerate( + zip(state_types, flat_states) + ): + type_states = [state for _, state in type_path_and_states] + size_, bytes_ = _size_and_bytes(type_states) + size_repr = _size_and_bytes_repr(size_, bytes_) + rich_table.columns[i + 2].footer = size_repr + + rich_table.caption_style = 'bold' + rich_table.caption = ( + f'\nTotal Parameters: {_size_and_bytes_repr(*_size_and_bytes(state))}' + ) + + return _get_rich_repr(rich_table, _console_kwargs) + + +def _get_rich_repr(obj, console_kwargs): + f = io.StringIO() + console = rich.console.Console(file=f, **console_kwargs) + console.print(obj) + return f.getvalue() + + +def _size_and_bytes(pytree: tp.Any) -> tuple[int, int]: + leaves = jax.tree.leaves(pytree) + size = sum(x.size for x in leaves if hasattr(x, 'size')) + num_bytes = sum( + x.size * x.dtype.itemsize for x in leaves if hasattr(x, 'size') + ) + return size, num_bytes + + +def _size_and_bytes_repr(size: int, num_bytes: int) -> str: + if not size: + return '' + bytes_repr = _bytes_repr(num_bytes) + return f'{size:,} [dim]({bytes_repr})[/dim]' + + +def _bytes_repr(num_bytes): + count, units = ( + (f'{num_bytes / 1e9:,.1f}', 'GB') + if num_bytes > 1e9 + else (f'{num_bytes / 1e6:,.1f}', 'MB') + if num_bytes > 1e6 + else (f'{num_bytes / 1e3:,.1f}', 'KB') + if num_bytes > 1e3 + else (f'{num_bytes:,}', 'B') + ) + + return f'{count} {units}' + + +def _has_shape_dtype(value): + return hasattr(value, 'shape') and hasattr(value, 'dtype') + + +def _normalize_values(x): + if isinstance(x, type): + return f'type[{x.__name__}]' + else: + return x + + +def _as_yaml_str(value) -> str: + if (hasattr(value, '__len__') and len(value) == 0) or value is None: + return '' + + value = jax.tree.map(_normalize_values, value) + + file = io.StringIO() + yaml.safe_dump( + value, + file, + default_flow_style=False, + indent=2, + sort_keys=False, + explicit_end=False, + ) + return file.getvalue().replace('\n...', '').replace("'", '').strip() + + +def _render_array(x): + shape, dtype = jnp.shape(x), jnp.result_type(x) + shape_repr = ','.join(str(x) for x in shape) + return f'[dim]{dtype}[/dim][{shape_repr}]' + + +def _sort_variable_types(types: tp.Iterable[type]) -> list[type]: + def _variable_parents_count(t: type): + return sum(1 for p in t.mro() if issubclass(p, variablelib.Variable)) + + type_sort_key = {t: (-_variable_parents_count(t), t.__name__) for t in types} + return sorted(types, key=lambda t: type_sort_key[t])