Skip to content

Commit

Permalink
add docstring to ravel_pytree
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Jun 12, 2020
1 parent 4a836ff commit ae9df75
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 33 deletions.
1 change: 1 addition & 0 deletions docs/jax.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Subpackages
jax.ops
jax.random
jax.tree_util
jax.flatten_util
jax.dlpack
jax.profiler

Expand Down
22 changes: 13 additions & 9 deletions jax/flatten_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,21 @@


def ravel_pytree(pytree):
"""Ravel (i.e. flatten) a pytree of arrays down to a 1D array.
Args:
pytree: a pytree to ravel.
Returns:
A pair where the first element is a 1D array representing the flattened and
concatenated leaf values, and the second element is a callable for
unflattening a 1D vector of the same length back to a pytree of of the same
structure as the input ``pytree``.
"""
leaves, treedef = tree_flatten(pytree)
flat, unravel_list = vjp(ravel_list, *leaves)
flat, unravel_list = vjp(_ravel_list, *leaves)
unravel_pytree = lambda flat: tree_unflatten(treedef, unravel_list(flat))
return flat, unravel_pytree

def ravel_list(*lst):
def _ravel_list(*lst):
return jnp.concatenate([jnp.ravel(elt) for elt in lst]) if lst else jnp.array([])


@lu.transformation_with_aux
def ravel_fun(unravel_inputs, flat_in, **kwargs):
pytree_args = unravel_inputs(flat_in)
ans = yield pytree_args, {}
yield ravel_pytree(ans)
52 changes: 28 additions & 24 deletions jax/tree_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,26 @@ def tree_flatten(tree):
Args:
tree: a pytree to flatten.
Returns:
a pair with a list of leaves and the corresponding treedef.
A pair where the first element is a list of leaf values and the second
element is a treedef representing the structure of the flattened tree.
"""
return pytree.flatten(tree)

def tree_unflatten(treedef, leaves):
"""Reconstructs a pytree from the treedef and the leaves.
The inverse of `tree_flatten`.
The inverse of :func:`tree_flatten`.
Args:
treedef: the treedef to reconstruct
leaves: the list of leaves to use for reconstruction. The list must
match the leaves of the treedef.
leaves: the list of leaves to use for reconstruction. The list must match
the leaves of the treedef.
Returns:
The reconstructed pytree, containing the `leaves` placed in the
structure described by `treedef`.
The reconstructed pytree, containing the ``leaves`` placed in the structure
described by ``treedef``.
"""
return treedef.unflatten(leaves)

Expand Down Expand Up @@ -102,7 +105,7 @@ def all_leaves(iterable):
iterable: Iterable of leaves.
Returns:
True if all elements in the input are leaves false if not.
A boolean indicating if all elements in the input are leaves.
"""
return pytree.all_leaves(iterable)

Expand All @@ -113,14 +116,14 @@ def register_pytree_node(nodetype, flatten_func, unflatten_func):
Args:
nodetype: a Python type to treat as an internal pytree node.
flatten_func: a function to be used during flattening, taking a value
of type `nodetype` and returning a pair, with (1) an iterable for
the children to be flattened recursively, and (2) some auxiliary data
to be stored in the treedef and to be passed to the `unflatten_func`.
unflatten_func: a function taking two arguments: the auxiliary data that
was returned by `flatten_func` and stored in the treedef, and the
flatten_func: a function to be used during flattening, taking a value of
type ``nodetype`` and returning a pair, with (1) an iterable for the
children to be flattened recursively, and (2) some auxiliary data to be
stored in the treedef and to be passed to the ``unflatten_func``.
unflatten_func: a function taking two arguments: the auxiliary data that was
returned by ``flatten_func`` and stored in the treedef, and the
unflattened children. The function should return an instance of
`nodetype`.
``nodetype``.
"""
pytree.register_node(nodetype, flatten_func, unflatten_func)
_registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func)
Expand Down Expand Up @@ -149,13 +152,13 @@ def tree_map(f, tree):
"""Maps a function over a pytree to produce a new pytree.
Args:
f: function to be applied at each leaf.
f: unary function to be applied at each leaf.
tree: a pytree to be mapped over.
Returns:
A new pytree with the same structure as `tree` but with the value at each
leaf given by `f(x)` where `x` is the value at the corresponding leaf in
`tree`.
leaf given by ``f(x)`` where ``x`` is the value at the corresponding leaf in
the input ``tree``.
"""
leaves, treedef = pytree.flatten(tree)
return treedef.unflatten(map(f, leaves))
Expand All @@ -164,17 +167,18 @@ def tree_multimap(f, tree, *rest):
"""Maps a multi-input function over pytree args to produce a new pytree.
Args:
f: function that takes `1 + len(rest)` arguments, to be applied at the
f: function that takes ``1 + len(rest)`` arguments, to be applied at the
corresponding leaves of the pytrees.
tree: a pytree to be mapped over, with each leaf providing the first
positional argument to `f`.
positional argument to ``f``.
*rest: a tuple of pytrees, each of which has the same structure as tree or
or has tree as a prefix.
Returns:
A new pytree with the same structure as `tree` but with the value at each
leaf given by `f(x, *xs)` where `x` is the value at the corresponding leaf
in `tree` and `xs` is the tuple of values at corresponding nodes in
`rest`.
A new pytree with the same structure as ``tree`` but with the value at each
leaf given by ``f(x, *xs)`` where ``x`` is the value at the corresponding
leaf in ``tree`` and ``xs`` is the tuple of values at corresponding nodes in
``rest``.
"""
leaves, treedef = pytree.flatten(tree)
all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
Expand Down Expand Up @@ -214,7 +218,7 @@ def tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose):
type(None): _RegistryEntry(lambda z: ((), None), lambda _, xs: None),
}
def _replace_nones(sentinel, tree):
"""Replaces `None` in `tree` with `sentinel`."""
"""Replaces ``None`` in ``tree`` with ``sentinel``."""
if tree is None:
return sentinel
else:
Expand Down

0 comments on commit ae9df75

Please sign in to comment.