Skip to content

Commit

Permalink
jax.numpy.clip: update use of deprecated arguments.
Browse files Browse the repository at this point in the history
- a is now positional-only
- a_min is now min
- a_max is now max

The old argument names have been deprecated since JAX v0.4.27.

PiperOrigin-RevId: 714483196
  • Loading branch information
Jake VanderPlas authored and KfacJaxDev committed Jan 11, 2025
1 parent 23155a8 commit 9182734
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion kfac_jax/_src/utils/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,7 +1005,7 @@ def safe_psd_eigh(

# The matrix is PSD by construction, but numerical inaccuracies can produce
# slightly negative eigenvalues. Hence, clip at zero.
return jnp.clip(s, a_min=0.0), q
return jnp.clip(s, min=0.0), q


def tnt_scale(factors: Sequence[Array]) -> Numeric:
Expand Down

0 comments on commit 9182734

Please sign in to comment.