Skip to content

Commit

Permalink
Expose functools.reduce initializer argument to tree_util.tree_reduce (
Browse files Browse the repository at this point in the history
…jax-ml#2935)

* Expose `functools.reduce` initializer argument to `tree_util.tree_reduce`.

`functools.reduce` takes an optional `initializer` argument (default=None) which is currently not exposed by `tree_reduce'. This can be useful e.g. for computing an L2 penalty, where you would initialize with 0., and then sum the L2 for each parameter.

Example:
```
def l2_sum(total, param):
  return total + jnp.sum(param**2)

tree_reduce(l2_sum, params, 0.)
```

* Only call functools.reduce with initializer when it is not None.

* Change logic to check for number of args to allow None value as initializer

* Rename seq to tree, and add tree_leaves

* Change reduce to functools.reduce.

* Make tree_reduce self-documenting

* Replace jax.tree_leaves with tree_leaves

* Update to use custom sentinel instead of optional position argument

* jax.tree_leaves -> tree_leaves
  • Loading branch information
bastings authored May 5, 2020
1 parent e4d8cac commit dc234b6
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions jax/tree_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,12 @@ def _replace_nones(sentinel, tree):
else:
return tree

def tree_reduce(f, tree):
return functools.reduce(f, tree_leaves(tree))
no_initializer = object()
def tree_reduce(function, tree, initializer=no_initializer):
if initializer is no_initializer:
return functools.reduce(function, tree_leaves(tree))
else:
return functools.reduce(function, tree_leaves(tree), initializer)

def tree_all(tree):
return all(tree_leaves(tree))
Expand Down

0 comments on commit dc234b6

Please sign in to comment.