Update references to JAX's GitHub repo
JAX has moved from https://github.com/google/jax to https://github.com/jax-ml/jax

PiperOrigin-RevId: 702886640
jakeharmon8 authored and OptaxDev committed Dec 5, 2024
1 parent 3f0a64b commit 7af9810
Showing 4 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion docs/index.rst
@@ -35,7 +35,7 @@ can be used to obtain the most recent version of Optax::
pip install git+git://github.com/google-deepmind/optax.git

Note that Optax is built on top of JAX.
-See `here <https://github.com/google/jax?tab=readme-ov-file#installation>`_
+See `here <https://github.com/jax-ml/jax?tab=readme-ov-file#installation>`_
for instructions on installing JAX.


2 changes: 1 addition & 1 deletion examples/contrib/differentially_private_sgd.ipynb
@@ -12,7 +12,7 @@
"\n",
"A large portion of this code is forked from the differentially private SGD\n",
"example in the [JAX repo](\n",
"https://github.com/google/jax/blob/main/examples/differentially_private_sgd.py).\n",
"https://github.com/jax-ml/jax/blob/main/examples/differentially_private_sgd.py).\n",
"\n",
"[Differentially Private Stochastic Gradient Descent](https://arxiv.org/abs/1607.00133) requires clipping the per-example parameter\n",
"gradients, which is non-trivial to implement efficiently for convolutional\n",
4 changes: 2 additions & 2 deletions optax/_src/alias_test.py
@@ -264,7 +264,7 @@ def test_preserve_dtype(self, opt_name, opt_kwargs, dtype):
# x = 0.5**jnp.asarray(1, dtype=jnp.int32)
# (appearing in e.g. optax.tree_utils.tree_bias_correction)
# are promoted (strictly) to float32 when jitted
-# see https://github.com/google/jax/issues/23337
+# see https://github.com/jax-ml/jax/issues/23337
# This may end up letting updates have a dtype different from params.
# The solution is to fix the dtype of the result to the desired dtype
# (just as done in optax.tree_utils.tree_bias_correction).
@@ -851,7 +851,7 @@ def test_minimize_bad_initialization(self):
chex.assert_trees_all_close(jnp_fun(optax_sol), minimum, atol=tol, rtol=tol)

def test_steep_objective(self):
-# See jax related issue https://github.com/google/jax/issues/4594
+# See jax related issue https://github.com/jax-ml/jax/issues/4594
tol = 1e-5
n = 2
mat = jnp.eye(n) * 1e4
2 changes: 1 addition & 1 deletion optax/contrib/_common_test.py
@@ -301,7 +301,7 @@ def test_preserve_dtype(
# x = 0.5**jnp.asarray(1, dtype=jnp.int32)
# (appearing in e.g. optax.tree_utils.tree_bias_correction)
# are promoted (strictly) to float32 when jitted
-# see https://github.com/google/jax/issues/23337
+# see https://github.com/jax-ml/jax/issues/23337
# This may end up letting updates have a dtype different from params.
# The solution is to fix the dtype of the result to the desired dtype
# (just as done in optax.tree_utils.tree_bias_correction).
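
The comment preserved in the alias_test.py and _common_test.py hunks above describes a real pitfall: expressions with an integer exponent argument can be promoted to float32 under jit, letting update dtypes drift from parameter dtypes. The sketch below is illustrative only, not Optax's implementation; the bias_correction helper and the sample values are hypothetical, but the fix mirrors what the comment attributes to optax.tree_utils.tree_bias_correction: cast the result back to the desired dtype.

import jax
import jax.numpy as jnp

def bias_correction(decay, count, dtype):
  # Under jit, 0.5**jnp.asarray(1, dtype=jnp.int32) may be promoted
  # (strictly) to float32; casting pins the result to the target dtype.
  correction = 1 - decay ** jnp.asarray(count, dtype=jnp.int32)
  return correction.astype(dtype)

params = jnp.ones(3, dtype=jnp.bfloat16)
scale = jax.jit(bias_correction, static_argnames='dtype')(0.9, 10, dtype=params.dtype)
assert scale.dtype == params.dtype  # no float32 drift, eager or jitted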
