From 2546724c148e54afd5e60004e0ff61f9455bc901 Mon Sep 17 00:00:00 2001 From: jheek Date: Wed, 20 Oct 2021 11:57:16 +0000 Subject: [PATCH 1/2] Add "smart constructor" example to datataclass docstring --- flax/struct.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/flax/struct.py b/flax/struct.py index b677df145d..c95d03fb9b 100644 --- a/flax/struct.py +++ b/flax/struct.py @@ -58,7 +58,7 @@ def dataclass(clz: _T) -> _T: from flax import struct @struct.dataclass - class Model(): + class Model: params: Any # use pytree_node=False to indicate an attribute should not be touched # by Jax transformations. @@ -77,6 +77,30 @@ def __apply__(self, *args): model = Model(params, apply_fn) model_grad = jax.grad(some_loss_fn)(model) + Note that dataclasses have an auto-generated ``__init__`` where + the arguments of the constructor and the attributed of the created + instance match 1:1. This correspondance is what makes these objects + valid containers that work with JAX transformations and + more generally the `jax.tree_util` library. + + Somtimes a "smart constructor" is desired, for example because + some of the attributes can be (optionally) derived from others. + The way to do this with Flax dataclasses is to make a static or + class method that provides the smart constructor. + This way the simple constructor used by `jax.tree_util` is + preserved. Consider the following example:: + + @struct.dataclass + class DirectionAndScaleKernel: + direction: Array + scale: Array + + @classmethod + def create(cls, kernel): + scale = jax.numpy.linalg.norm(kernel, axis=0, keepdims=True) + directin = direction / scale + return cls(direction, scale) + Args: clz: the class that will be transformed by the decorator. Returns: From 9228f8a194f7786c006b3d5272a567e98c08d675 Mon Sep 17 00:00:00 2001 From: jheek Date: Wed, 27 Oct 2021 10:27:50 +0200 Subject: [PATCH 2/2] Update flax/struct.py Co-authored-by: Marc van Zee --- flax/struct.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flax/struct.py b/flax/struct.py index c95d03fb9b..b4da6847a2 100644 --- a/flax/struct.py +++ b/flax/struct.py @@ -83,7 +83,7 @@ def __apply__(self, *args): valid containers that work with JAX transformations and more generally the `jax.tree_util` library. - Somtimes a "smart constructor" is desired, for example because + Sometimes a "smart constructor" is desired, for example because some of the attributes can be (optionally) derived from others. The way to do this with Flax dataclasses is to make a static or class method that provides the smart constructor.