Skip to content

Commit

Permalink
Merge pull request #1630 from jheek:improve-dataclass-doc
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 405841421
  • Loading branch information
Flax Authors committed Oct 27, 2021
2 parents e79a100 + 9228f8a commit d23a577
Showing 1 changed file with 25 additions and 1 deletion.
26 changes: 25 additions & 1 deletion flax/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def dataclass(clz: _T) -> _T:
from flax import struct
@struct.dataclass
class Model():
class Model:
params: Any
# use pytree_node=False to indicate an attribute should not be touched
# by Jax transformations.
Expand All @@ -77,6 +77,30 @@ def __apply__(self, *args):
model = Model(params, apply_fn)
model_grad = jax.grad(some_loss_fn)(model)
Note that dataclasses have an auto-generated ``__init__`` where
the arguments of the constructor and the attributed of the created
instance match 1:1. This correspondance is what makes these objects
valid containers that work with JAX transformations and
more generally the `jax.tree_util` library.
Sometimes a "smart constructor" is desired, for example because
some of the attributes can be (optionally) derived from others.
The way to do this with Flax dataclasses is to make a static or
class method that provides the smart constructor.
This way the simple constructor used by `jax.tree_util` is
preserved. Consider the following example::
@struct.dataclass
class DirectionAndScaleKernel:
direction: Array
scale: Array
@classmethod
def create(cls, kernel):
scale = jax.numpy.linalg.norm(kernel, axis=0, keepdims=True)
directin = direction / scale
return cls(direction, scale)
Args:
clz: the class that will be transformed by the decorator.
Returns:
Expand Down

0 comments on commit d23a577

Please sign in to comment.