From dc234b6f11b25237fea1eb9c851add83812fc5f8 Mon Sep 17 00:00:00 2001 From: Joost Bastings Date: Tue, 5 May 2020 10:11:10 +0200 Subject: [PATCH] Expose functools.reduce initializer argument to tree_util.tree_reduce (#2935) * Expose `functools.reduce` initializer argument to `tree_util.tree_reduce`. `functools.reduce` takes an optional `initializer` argument (default=None) which is currently not exposed by `tree_reduce'. This can be useful e.g. for computing an L2 penalty, where you would initialize with 0., and then sum the L2 for each parameter. Example: ``` def l2_sum(total, param): return total + jnp.sum(param**2) tree_reduce(l2_sum, params, 0.) ``` * Only call functools.reduce with initializer when it is not None. * Change logic to check for number of args to allow None value as initializer * Rename seq to tree, and add tree_leaves * Change reduce to functools.reduce. * Make tree_reduce self-documenting * Replace jax.tree_leaves with tree_leaves * Update to use custom sentinel instead of optional position argument * jax.tree_leaves -> tree_leaves --- jax/tree_util.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/jax/tree_util.py b/jax/tree_util.py index 857c6c792f04..9f4d76799dd6 100644 --- a/jax/tree_util.py +++ b/jax/tree_util.py @@ -231,8 +231,12 @@ def _replace_nones(sentinel, tree): else: return tree -def tree_reduce(f, tree): - return functools.reduce(f, tree_leaves(tree)) +no_initializer = object() +def tree_reduce(function, tree, initializer=no_initializer): + if initializer is no_initializer: + return functools.reduce(function, tree_leaves(tree)) + else: + return functools.reduce(function, tree_leaves(tree), initializer) def tree_all(tree): return all(tree_leaves(tree))