Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Expose functools.reduce initializer argument to tree_util.tree_reduce (…
…jax-ml#2935) * Expose `functools.reduce` initializer argument to `tree_util.tree_reduce`. `functools.reduce` takes an optional `initializer` argument (default=None) which is currently not exposed by `tree_reduce'. This can be useful e.g. for computing an L2 penalty, where you would initialize with 0., and then sum the L2 for each parameter. Example: ``` def l2_sum(total, param): return total + jnp.sum(param**2) tree_reduce(l2_sum, params, 0.) ``` * Only call functools.reduce with initializer when it is not None. * Change logic to check for number of args to allow None value as initializer * Rename seq to tree, and add tree_leaves * Change reduce to functools.reduce. * Make tree_reduce self-documenting * Replace jax.tree_leaves with tree_leaves * Update to use custom sentinel instead of optional position argument * jax.tree_leaves -> tree_leaves
- Loading branch information