diff --git a/jax/tree_util.py b/jax/tree_util.py index 857c6c792f04..9f4d76799dd6 100644 --- a/jax/tree_util.py +++ b/jax/tree_util.py @@ -231,8 +231,12 @@ def _replace_nones(sentinel, tree): else: return tree -def tree_reduce(f, tree): - return functools.reduce(f, tree_leaves(tree)) +no_initializer = object() +def tree_reduce(function, tree, initializer=no_initializer): + if initializer is no_initializer: + return functools.reduce(function, tree_leaves(tree)) + else: + return functools.reduce(function, tree_leaves(tree), initializer) def tree_all(tree): return all(tree_leaves(tree))