diff --git a/flax/struct.py b/flax/struct.py index b677df145d..b4da6847a2 100644 --- a/flax/struct.py +++ b/flax/struct.py @@ -58,7 +58,7 @@ def dataclass(clz: _T) -> _T: from flax import struct @struct.dataclass - class Model(): + class Model: params: Any # use pytree_node=False to indicate an attribute should not be touched # by Jax transformations. @@ -77,6 +77,30 @@ def __apply__(self, *args): model = Model(params, apply_fn) model_grad = jax.grad(some_loss_fn)(model) + Note that dataclasses have an auto-generated ``__init__`` where + the arguments of the constructor and the attributed of the created + instance match 1:1. This correspondance is what makes these objects + valid containers that work with JAX transformations and + more generally the `jax.tree_util` library. + + Sometimes a "smart constructor" is desired, for example because + some of the attributes can be (optionally) derived from others. + The way to do this with Flax dataclasses is to make a static or + class method that provides the smart constructor. + This way the simple constructor used by `jax.tree_util` is + preserved. Consider the following example:: + + @struct.dataclass + class DirectionAndScaleKernel: + direction: Array + scale: Array + + @classmethod + def create(cls, kernel): + scale = jax.numpy.linalg.norm(kernel, axis=0, keepdims=True) + directin = direction / scale + return cls(direction, scale) + Args: clz: the class that will be transformed by the decorator. Returns: