From e847613fb77af882c3f693281bf0f68dec4dfc13 Mon Sep 17 00:00:00 2001
From: Jake Harmon
Date: Wed, 4 Dec 2024 15:45:33 -0800
Subject: [PATCH] Update references to JAX's GitHub repo

JAX has moved from https://github.com/google/jax to https://github.com/jax-ml/jax

PiperOrigin-RevId: 702886640
---
 README.md          | 4 ++--
 docs/index.rst     | 4 ++--
 examples/losses.py | 2 +-
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/README.md b/README.md
index ca14c96..591106f 100644
--- a/README.md
+++ b/README.md
@@ -19,7 +19,7 @@ implementation of the [K-FAC] optimizer and curvature estimator.
 
 KFAC-JAX is written in pure Python, but depends on C++ code via JAX.
 
-First, follow [these instructions](https://github.com/google/jax#installation)
+First, follow [these instructions](https://github.com/jax-ml/jax#installation)
 to install JAX with the relevant accelerator support.
 
 Then, install KFAC-JAX using pip:
@@ -219,6 +219,6 @@
 and the year corresponds to the project's open-source release.
 
 [K-FAC]: https://arxiv.org/abs/1503.05671
-[JAX]: https://github.com/google/jax
+[JAX]: https://github.com/jax-ml/jax
 [Haiku]: https://github.com/google-deepmind/dm-haiku
 [documentation]: https://kfac-jax.readthedocs.io/
diff --git a/docs/index.rst b/docs/index.rst
index 8b03930..b5d0a18 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -3,7 +3,7 @@
 KFAC-JAX Documentation
 ======================
 
-KFAC-JAX is a library built on top of `JAX <https://github.com/google/jax>`_ for
+KFAC-JAX is a library built on top of `JAX <https://github.com/jax-ml/jax>`_ for
 second-order optimization of neural networks and for computing scalable
 curvature approximations. The main goal of the library is to provide researchers
 with an easy-to-use
@@ -16,7 +16,7 @@ Installation
 ============
 KFAC-JAX is written in pure Python, but depends on C++ code via JAX.
 
-First, follow `these instructions <https://github.com/google/jax#installation>`_
+First, follow `these instructions <https://github.com/jax-ml/jax#installation>`_
 to install JAX with the relevant accelerator support.
 
 Then, install KFAC-JAX using pip::
diff --git a/examples/losses.py b/examples/losses.py
index f7b7bfb..aa5bb68 100644
--- a/examples/losses.py
+++ b/examples/losses.py
@@ -104,7 +104,7 @@ def softmax_cross_entropy(
   max_logits = jnp.max(logits, keepdims=True, axis=-1)
 
   # It's unclear whether this stop_gradient is a good idea.
-  # See https://github.com/google/jax/issues/13529
+  # See https://github.com/jax-ml/jax/issues/13529
   max_logits = lax.stop_gradient(max_logits)
 
   logits = logits - max_logits
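
For context on the examples/losses.py hunk: the surrounding code uses the standard max-subtraction trick for numerically stable softmax cross-entropy, and the linked issue (https://github.com/jax-ml/jax/issues/13529) discusses whether the per-example max should be wrapped in stop_gradient. Below is a minimal, self-contained sketch of that pattern, assuming one-hot labels and a per-example loss; the function name and signature are illustrative, not KFAC-JAX's actual implementation.

# Minimal sketch of the stable softmax cross-entropy pattern referenced in
# the losses.py hunk above. Illustrative only; assumes one-hot `labels`.
import jax.numpy as jnp
from jax import lax


def stable_softmax_cross_entropy(logits, labels):
  """Per-example cross-entropy between softmax(logits) and one-hot labels."""
  max_logits = jnp.max(logits, axis=-1, keepdims=True)
  # Subtracting a per-example constant leaves softmax unchanged but prevents
  # overflow in exp(). Whether to block gradients through the max is the open
  # question discussed in https://github.com/jax-ml/jax/issues/13529.
  max_logits = lax.stop_gradient(max_logits)
  shifted = logits - max_logits
  log_normalizer = jnp.log(jnp.sum(jnp.exp(shifted), axis=-1))
  return log_normalizer - jnp.sum(labels * shifted, axis=-1)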