Skip to content

Commit

Permalink
Merge pull request #4438 from google:nnx-tabulate
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 713824899
  • Loading branch information
Flax Authors committed Jan 10, 2025
2 parents 9ebdbdc + c59bc1a commit adbad95
Show file tree
Hide file tree
Showing 24 changed files with 834 additions and 475 deletions.
22 changes: 11 additions & 11 deletions docs_nnx/guides/checkpointing.ipynb

Large diffs are not rendered by default.

235 changes: 59 additions & 176 deletions docs_nnx/mnist_tutorial.ipynb

Large diffs are not rendered by default.

49 changes: 16 additions & 33 deletions docs_nnx/mnist_tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ Let's put the CNN model to the test! Here, you’ll perform a forward pass with
import jax.numpy as jnp # JAX NumPy
y = model(jnp.ones((1, 28, 28, 1)))
nnx.display(y)
y
```

## 4. Create the optimizer and define some metrics
Expand Down Expand Up @@ -179,6 +179,9 @@ the accuracy) during the process. Typically this leads to the model achieving ar
```{code-cell} ipython3
:outputId: 258a2c76-2c8f-4a9e-d48b-dde57c342a87
from IPython.display import clear_output
import matplotlib.pyplot as plt
metrics_history = {
'train_loss': [],
'train_accuracy': [],
Expand Down Expand Up @@ -208,40 +211,20 @@ for step, batch in enumerate(train_ds.as_numpy_iterator()):
metrics_history[f'test_{metric}'].append(value)
metrics.reset() # Reset the metrics for the next training epoch.
print(
f"[train] step: {step}, "
f"loss: {metrics_history['train_loss'][-1]}, "
f"accuracy: {metrics_history['train_accuracy'][-1] * 100}"
)
print(
f"[test] step: {step}, "
f"loss: {metrics_history['test_loss'][-1]}, "
f"accuracy: {metrics_history['test_accuracy'][-1] * 100}"
)
```

## 7. Visualize the metrics

With Matplotlib, you can create plots for the loss and the accuracy:

```{code-cell} ipython3
:outputId: 431a2fcd-44fa-4202-f55a-906555f060ac
import matplotlib.pyplot as plt # Visualization
# Plot loss and accuracy in subplots
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
ax1.set_title('Loss')
ax2.set_title('Accuracy')
for dataset in ('train', 'test'):
ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss')
ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy')
ax1.legend()
ax2.legend()
plt.show()
clear_output(wait=True)
# Plot loss and accuracy in subplots
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
ax1.set_title('Loss')
ax2.set_title('Accuracy')
for dataset in ('train', 'test'):
ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss')
ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy')
ax1.legend()
ax2.legend()
plt.show()
```

## 10. Perform inference on the test set
## 7. Perform inference on the test set

Create a `jit`-compiled model inference function (with `nnx.jit`) - `pred_step` - to generate predictions on the test set using the learned model parameters. This will enable you to visualize test images alongside their predicted labels for a qualitative assessment of model performance.

Expand Down
129 changes: 88 additions & 41 deletions docs_nnx/nnx_basics.ipynb

Large diffs are not rendered by default.

15 changes: 2 additions & 13 deletions docs_nnx/nnx_basics.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,7 @@ jupytext:

Flax NNX is a new simplified API that is designed to make it easier to create, inspect, debug, and analyze neural networks in [JAX](https://jax.readthedocs.io/). It achieves this by adding first class support for Python reference semantics. This allows users to express their models using regular Python objects, which are modeled as PyGraphs (instead of pytrees), enabling reference sharing and mutability. Such API design should make PyTorch or Keras users feel at home.

In this guide you will learn about:

- The Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) system: An example of creating and initializing a custom `Linear` layer.
- Stateful computation: An example of creating a Flax [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) and updating its value (such as state updates needed during the forward pass).
- Nested [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s: An MLP example with `Linear`, [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout), and [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) layers.
- Model surgery: An example of replacing custom `Linear` layers inside a model with custom `LoraLinear` layers.
- Flax transformations: An example of using [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) for automatic state management.
- [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan) over layers.
- The Flax NNX Functional API: An example of a custom `StatefulLinear` layer with [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)s with fine-grained control over the state.
- [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) and [`GraphDef`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.GraphDef).
- [`split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), [`merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge), and `update`
- Fine-grained [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) control: An example of using [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) type `Filter`s ([`nnx.filterlib.Filter`](https://flax.readthedocs.io/en/latest/guides/filters_guide.html)) to split into multiple [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)s.
To begin, install Flax with `pip` and import necessary dependencies:

## Setup

Expand Down Expand Up @@ -106,7 +95,7 @@ to handle them, as demonstrated in later sections of this guide.

Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s can be used to compose other `Module`s in a nested structure. These can be assigned directly as attributes, or inside an attribute of any (nested) pytree type, such as a `list`, `dict`, `tuple`, and so on.

The example below shows how to define a simple `MLP` by subclassing [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html). The model consists of two `Linear` layers, an [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) layer, and an [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) layer:
The example below shows how to define a simple `MLP` Module consisting of two `Linear` layers, a [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) layer, and an [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) layer.

```{code-cell} ipython3
class MLP(nnx.Module):
Expand Down
18 changes: 11 additions & 7 deletions flax/linen/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@
LogicalNames,
)

try:
from IPython import get_ipython

in_ipython = get_ipython() is not None
except ImportError:
in_ipython = False


class _ValueRepresentation(ABC):
"""A class that represents a value in the summary table."""
Expand Down Expand Up @@ -242,11 +249,6 @@ def tabulate(
Total Parameters: 50 (200 B)
**Note**: rows order in the table does not represent execution order,
instead it aligns with the order of keys in `variables` which are sorted
alphabetically.
**Note**: `vjp_flops` returns `0` if the module is not differentiable.
Args:
Expand All @@ -267,7 +269,9 @@ def tabulate(
mutable.
console_kwargs: An optional dictionary with additional keyword arguments
that are passed to `rich.console.Console` when rendering the table.
Default arguments are `{'force_terminal': True, 'force_jupyter': False}`.
Default arguments are ``'force_terminal': True``, and ``'force_jupyter'``
is set to ``True`` if the code is running in a Jupyter notebook, otherwise
it is set to ``False``.
table_kwargs: An optional dictionary with additional keyword arguments that
are passed to `rich.table.Table` constructor.
column_kwargs: An optional dictionary with additional keyword arguments that
Expand Down Expand Up @@ -564,7 +568,7 @@ def _render_table(
non_params_cols: list[str],
) -> str:
"""A function that renders a Table to a string representation using rich."""
console_kwargs = {'force_terminal': True, 'force_jupyter': False}
console_kwargs = {'force_terminal': True, 'force_jupyter': in_ipython}
if console_extras is not None:
console_kwargs.update(console_extras)

Expand Down
4 changes: 3 additions & 1 deletion flax/nnx/filterlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def to_predicate(filter: Filter) -> Predicate:
else:
raise TypeError(f'Invalid collection filter: {filter:!r}. ')

def filters_to_predicates(filters: tuple[Filter, ...]) -> tuple[Predicate, ...]:
def filters_to_predicates(
filters: tp.Sequence[Filter],
) -> tuple[Predicate, ...]:
for i, filter_ in enumerate(filters):
if filter_ in (..., True) and i != len(filters) - 1:
remaining_filters = filters[i + 1 :]
Expand Down
14 changes: 5 additions & 9 deletions flax/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import numpy as np
import typing_extensions as tpe

from flax.nnx import filterlib, reprlib
from flax.nnx import filterlib, reprlib, visualization
from flax.nnx.proxy_caller import (
ApplyCaller,
CallableProxy,
Expand Down Expand Up @@ -63,7 +63,7 @@ def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[NodeLeaf]:
return isinstance(x, Variable)


class RefMap(tp.MutableMapping[A, B], reprlib.MappingReprMixin[A, B]):
class RefMap(tp.MutableMapping[A, B], reprlib.MappingReprMixin):
"""A mapping that uses object id as the hash for the keys."""

def __init__(
Expand Down Expand Up @@ -248,8 +248,7 @@ def __nnx_repr__(self):
yield reprlib.Attr('index', self.index)

def __treescope_repr__(self, path, subtree_renderer):
import treescope # type: ignore[import-not-found,import-untyped]
return treescope.repr_lib.render_object_constructor(
return visualization.render_object_constructor(
object_type=type(self),
attributes={'type': self.type, 'index': self.index},
path=path,
Expand All @@ -272,9 +271,7 @@ def __nnx_repr__(self):
yield reprlib.Attr('metadata', reprlib.PrettyMapping(self.metadata))

def __treescope_repr__(self, path, subtree_renderer):
import treescope # type: ignore[import-not-found,import-untyped]

return treescope.repr_lib.render_object_constructor(
return visualization.render_object_constructor(
object_type=type(self),
attributes={
'type': self.type,
Expand Down Expand Up @@ -353,8 +350,7 @@ def __nnx_repr__(self):
)

def __treescope_repr__(self, path, subtree_renderer):
import treescope # type: ignore[import-not-found,import-untyped]
return treescope.repr_lib.render_object_constructor(
return visualization.render_object_constructor(
object_type=type(self),
attributes={
'type': self.type,
Expand Down
17 changes: 0 additions & 17 deletions flax/nnx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,23 +403,6 @@ def __init_subclass__(cls, experimental_pytree: bool = False) -> None:
flatten_func=partial(_module_flatten, with_keys=False),
)

def __treescope_repr__(self, path, subtree_renderer):
import treescope # type: ignore[import-not-found,import-untyped]
children = {}
for name, value in vars(self).items():
if name.startswith('_'):
continue
children[name] = value
return treescope.repr_lib.render_object_constructor(
object_type=type(self),
attributes=children,
path=path,
subtree_renderer=subtree_renderer,
color=treescope.formatting_util.color_from_string(
type(self).__qualname__
)
)

# -------------------------
# Pytree Definition
# -------------------------
Expand Down
2 changes: 1 addition & 1 deletion flax/nnx/nn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -1063,7 +1063,7 @@ class Embed(Module):
>>> layer = nnx.Embed(num_embeddings=5, features=3, rngs=nnx.Rngs(0))
>>> nnx.state(layer)
State({
'embedding': VariableState(
'embedding': VariableState( # 15 (60 B)
type=Param,
value=Array([[-0.90411377, -0.3648777 , -1.1083648 ],
[ 0.01070483, 0.27923733, 1.7487359 ],
Expand Down
10 changes: 5 additions & 5 deletions flax/nnx/nn/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,11 +395,11 @@ class LayerNorm(Module):
>>> nnx.state(layer)
State({
'bias': VariableState(
'bias': VariableState( # 6 (24 B)
type=Param,
value=Array([0., 0., 0., 0., 0., 0.], dtype=float32)
),
'scale': VariableState(
'scale': VariableState( # 6 (24 B)
type=Param,
value=Array([1., 1., 1., 1., 1., 1.], dtype=float32)
)
Expand Down Expand Up @@ -531,7 +531,7 @@ class RMSNorm(Module):
>>> nnx.state(layer)
State({
'scale': VariableState(
'scale': VariableState( # 6 (24 B)
type=Param,
value=Array([1., 1., 1., 1., 1., 1.], dtype=float32)
)
Expand Down Expand Up @@ -655,11 +655,11 @@ class GroupNorm(Module):
>>> layer = nnx.GroupNorm(num_features=6, num_groups=3, rngs=nnx.Rngs(0))
>>> nnx.state(layer)
State({
'bias': VariableState(
'bias': VariableState( # 6 (24 B)
type=Param,
value=Array([0., 0., 0., 0., 0., 0.], dtype=float32)
),
'scale': VariableState(
'scale': VariableState( # 6 (24 B)
type=Param,
value=Array([1., 1., 1., 1., 1., 1.], dtype=float32)
)
Expand Down
2 changes: 1 addition & 1 deletion flax/nnx/nn/stochastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from flax.nnx.module import Module, first_from


@dataclasses.dataclass
@dataclasses.dataclass(repr=False)
class Dropout(Module):
"""Create a dropout layer.
Expand Down
Loading

0 comments on commit adbad95

Please sign in to comment.