diff --git a/docs/jax.rst b/docs/jax.rst index b281ae428af2..fe54a6b8cbe8 100644 --- a/docs/jax.rst +++ b/docs/jax.rst @@ -17,6 +17,7 @@ Subpackages jax.ops jax.random jax.tree_util + jax.flatten_util jax.dlpack jax.profiler diff --git a/jax/flatten_util.py b/jax/flatten_util.py index 26ba0f0727b7..4d26faf6cf3a 100644 --- a/jax/flatten_util.py +++ b/jax/flatten_util.py @@ -24,17 +24,21 @@ def ravel_pytree(pytree): + """Ravel (i.e. flatten) a pytree of arrays down to a 1D array. + + Args: + pytree: a pytree to ravel. + + Returns: + A pair where the first element is a 1D array representing the flattened and + concatenated leaf values, and the second element is a callable for + unflattening a 1D vector of the same length back to a pytree of of the same + structure as the input ``pytree``. + """ leaves, treedef = tree_flatten(pytree) - flat, unravel_list = vjp(ravel_list, *leaves) + flat, unravel_list = vjp(_ravel_list, *leaves) unravel_pytree = lambda flat: tree_unflatten(treedef, unravel_list(flat)) return flat, unravel_pytree -def ravel_list(*lst): +def _ravel_list(*lst): return jnp.concatenate([jnp.ravel(elt) for elt in lst]) if lst else jnp.array([]) - - -@lu.transformation_with_aux -def ravel_fun(unravel_inputs, flat_in, **kwargs): - pytree_args = unravel_inputs(flat_in) - ans = yield pytree_args, {} - yield ravel_pytree(ans) diff --git a/jax/tree_util.py b/jax/tree_util.py index 9f4d76799dd6..6a407270d830 100644 --- a/jax/tree_util.py +++ b/jax/tree_util.py @@ -49,23 +49,26 @@ def tree_flatten(tree): Args: tree: a pytree to flatten. + Returns: - a pair with a list of leaves and the corresponding treedef. + A pair where the first element is a list of leaf values and the second + element is a treedef representing the structure of the flattened tree. """ return pytree.flatten(tree) def tree_unflatten(treedef, leaves): """Reconstructs a pytree from the treedef and the leaves. - The inverse of `tree_flatten`. + The inverse of :func:`tree_flatten`. Args: treedef: the treedef to reconstruct - leaves: the list of leaves to use for reconstruction. The list must - match the leaves of the treedef. + leaves: the list of leaves to use for reconstruction. The list must match + the leaves of the treedef. + Returns: - The reconstructed pytree, containing the `leaves` placed in the - structure described by `treedef`. + The reconstructed pytree, containing the ``leaves`` placed in the structure + described by ``treedef``. """ return treedef.unflatten(leaves) @@ -102,7 +105,7 @@ def all_leaves(iterable): iterable: Iterable of leaves. Returns: - True if all elements in the input are leaves false if not. + A boolean indicating if all elements in the input are leaves. """ return pytree.all_leaves(iterable) @@ -113,14 +116,14 @@ def register_pytree_node(nodetype, flatten_func, unflatten_func): Args: nodetype: a Python type to treat as an internal pytree node. - flatten_func: a function to be used during flattening, taking a value - of type `nodetype` and returning a pair, with (1) an iterable for - the children to be flattened recursively, and (2) some auxiliary data - to be stored in the treedef and to be passed to the `unflatten_func`. - unflatten_func: a function taking two arguments: the auxiliary data that - was returned by `flatten_func` and stored in the treedef, and the + flatten_func: a function to be used during flattening, taking a value of + type ``nodetype`` and returning a pair, with (1) an iterable for the + children to be flattened recursively, and (2) some auxiliary data to be + stored in the treedef and to be passed to the ``unflatten_func``. + unflatten_func: a function taking two arguments: the auxiliary data that was + returned by ``flatten_func`` and stored in the treedef, and the unflattened children. The function should return an instance of - `nodetype`. + ``nodetype``. """ pytree.register_node(nodetype, flatten_func, unflatten_func) _registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func) @@ -149,13 +152,13 @@ def tree_map(f, tree): """Maps a function over a pytree to produce a new pytree. Args: - f: function to be applied at each leaf. + f: unary function to be applied at each leaf. tree: a pytree to be mapped over. Returns: A new pytree with the same structure as `tree` but with the value at each - leaf given by `f(x)` where `x` is the value at the corresponding leaf in - `tree`. + leaf given by ``f(x)`` where ``x`` is the value at the corresponding leaf in + the input ``tree``. """ leaves, treedef = pytree.flatten(tree) return treedef.unflatten(map(f, leaves)) @@ -164,17 +167,18 @@ def tree_multimap(f, tree, *rest): """Maps a multi-input function over pytree args to produce a new pytree. Args: - f: function that takes `1 + len(rest)` arguments, to be applied at the + f: function that takes ``1 + len(rest)`` arguments, to be applied at the corresponding leaves of the pytrees. tree: a pytree to be mapped over, with each leaf providing the first - positional argument to `f`. + positional argument to ``f``. *rest: a tuple of pytrees, each of which has the same structure as tree or or has tree as a prefix. + Returns: - A new pytree with the same structure as `tree` but with the value at each - leaf given by `f(x, *xs)` where `x` is the value at the corresponding leaf - in `tree` and `xs` is the tuple of values at corresponding nodes in - `rest`. + A new pytree with the same structure as ``tree`` but with the value at each + leaf given by ``f(x, *xs)`` where ``x`` is the value at the corresponding + leaf in ``tree`` and ``xs`` is the tuple of values at corresponding nodes in + ``rest``. """ leaves, treedef = pytree.flatten(tree) all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest] @@ -214,7 +218,7 @@ def tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose): type(None): _RegistryEntry(lambda z: ((), None), lambda _, xs: None), } def _replace_nones(sentinel, tree): - """Replaces `None` in `tree` with `sentinel`.""" + """Replaces ``None`` in ``tree`` with ``sentinel``.""" if tree is None: return sentinel else: