Skip to content

Commit

Permalink
Update references to JAX's GitHub repo
Browse files Browse the repository at this point in the history
JAX has moved from https://github.com/google/jax to https://github.com/jax-ml/jax

PiperOrigin-RevId: 702886824
  • Loading branch information
jakeharmon8 authored and Flax Authors committed Dec 5, 2024
1 parent 30b438c commit d82e214
Show file tree
Hide file tree
Showing 23 changed files with 25 additions and 25 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ vNext
to keyword arguments. See more details in [#3389](https://github.com/google/flax/discussions/3389).
- Use new typed PRNG keys throughout flax: this essentially involved changing
uses of `jax.random.PRNGKey` to `jax.random.key`.
(See [JEP 9263](https://github.com/google/jax/pull/17297) for details).
(See [JEP 9263](https://github.com/jax-ml/jax/pull/17297) for details).
If you notice dispatch performance regressions after this change, be sure
you update `jax` to version 0.4.16 or newer.
- Added `has_improved` field to EarlyStopping and changed the return signature of
Expand Down
2 changes: 1 addition & 1 deletion docs/contributing.md
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ section above to keep the contents of both Markdown and Jupyter Notebook files i
Some of the notebooks are built automatically as part of the pre-submit checks and
as part of the [Read the Docs](https://flax.readthedocs.io/en/latest) build.
The build will fail if cells raise errors. If the errors are intentional, you can either catch them,
or tag the cell with `raises-exceptions` metadata ([example PR](https://github.com/google/jax/pull/2402/files)).
or tag the cell with `raises-exceptions` metadata ([example PR](https://github.com/jax-ml/jax/pull/2402/files)).
You have to add this metadata by hand in the `.ipynb` file. It will be preserved when somebody else
re-saves the notebook.

Expand Down
2 changes: 1 addition & 1 deletion docs/flip/1777-default-dtype.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ The current behavior is problematic and results in silent bugs, especially for d

### Dtypes in JAX

JAX uses a NumPy-inspired [dtype promotion](https://github.com/google/jax/blob/main/jax/_src/dtypes.py) mechanism as explained [here](https://jax.readthedocs.io/en/latest/type_promotion.html?highlight=lattice#type-promotion-semantics). The type promotion rules are summarized by the following type lattice:
JAX uses a NumPy-inspired [dtype promotion](https://github.com/jax-ml/jax/blob/main/jax/_src/dtypes.py) mechanism as explained [here](https://jax.readthedocs.io/en/latest/type_promotion.html?highlight=lattice#type-promotion-semantics). The type promotion rules are summarized by the following type lattice:

![JAX type promotion lattice](https://jax.readthedocs.io/en/latest/_images/type_lattice.svg)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ To load ``torch.nn.ConvTranspose2d`` parameters into Flax, we need to use the ``
np.testing.assert_almost_equal(j_out, t_out, decimal=6)


.. _`pull request`: https://github.com/google/jax/pull/5772
.. _`pull request`: https://github.com/jax-ml/jax/pull/5772

.. |nn.ConvTranspose| replace:: ``nn.ConvTranspose``
.. _nn.ConvTranspose: https://flax.readthedocs.io/en/latest/api_reference/flax.linen/layers.html#flax.linen.ConvTranspose
Expand Down
2 changes: 1 addition & 1 deletion docs/guides/flax_fundamentals/flax_basics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -951,7 +951,7 @@
"source": [
"### Exporting to Tensorflow's SavedModel with jax2tf\n",
"\n",
"JAX released an experimental converter called [jax2tf](https://github.com/google/jax/tree/main/jax/jax2tf), which allows converting trained Flax models into Tensorflow's SavedModel format (so it can be used for [TF Hub](https://www.tensorflow.org/hub), [TF.lite](https://www.tensorflow.org/lite), [TF.js](https://www.tensorflow.org/js), or other downstream applications). The repository contains more documentation and has various examples for Flax."
"JAX released an experimental converter called [jax2tf](https://github.com/jax-ml/jax/tree/main/jax/jax2tf), which allows converting trained Flax models into Tensorflow's SavedModel format (so it can be used for [TF Hub](https://www.tensorflow.org/hub), [TF.lite](https://www.tensorflow.org/lite), [TF.js](https://www.tensorflow.org/js), or other downstream applications). The repository contains more documentation and has various examples for Flax."
]
}
],
Expand Down
2 changes: 1 addition & 1 deletion docs/guides/flax_fundamentals/flax_basics.md
Original file line number Diff line number Diff line change
Expand Up @@ -469,4 +469,4 @@ Flax provides a handy wrapper - `TrainState` - that simplifies the above code. C

### Exporting to Tensorflow's SavedModel with jax2tf

JAX released an experimental converter called [jax2tf](https://github.com/google/jax/tree/main/jax/jax2tf), which allows converting trained Flax models into Tensorflow's SavedModel format (so it can be used for [TF Hub](https://www.tensorflow.org/hub), [TF.lite](https://www.tensorflow.org/lite), [TF.js](https://www.tensorflow.org/js), or other downstream applications). The repository contains more documentation and has various examples for Flax.
JAX released an experimental converter called [jax2tf](https://github.com/jax-ml/jax/tree/main/jax/jax2tf), which allows converting trained Flax models into Tensorflow's SavedModel format (so it can be used for [TF Hub](https://www.tensorflow.org/hub), [TF.lite](https://www.tensorflow.org/lite), [TF.js](https://www.tensorflow.org/js), or other downstream applications). The repository contains more documentation and has various examples for Flax.
4 changes: 2 additions & 2 deletions docs/guides/flax_fundamentals/rng_guide.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Set the JAX config variable `jax_threefry_partitionable` to `True`. This will be the default value in the future and makes the PRNG more efficiently auto-parallelizable under `jax.jit`. Refer to [JAX discussion](https://github.com/google/jax/discussions/18480) for more details."
"Set the JAX config variable `jax_threefry_partitionable` to `True`. This will be the default value in the future and makes the PRNG more efficiently auto-parallelizable under `jax.jit`. Refer to [JAX discussion](https://github.com/jax-ml/jax/discussions/18480) for more details."
]
},
{
Expand Down Expand Up @@ -1467,7 +1467,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"[Flax lifted transforms](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/transformations.html) allow you to use [JAX transforms](https://github.com/google/jax#transformations) with `Module` arguments. This section will show you how to control how PRNG keys are split in Flax lifted transforms.\n",
"[Flax lifted transforms](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/transformations.html) allow you to use [JAX transforms](https://github.com/jax-ml/jax#transformations) with `Module` arguments. This section will show you how to control how PRNG keys are split in Flax lifted transforms.\n",
"\n",
"Refer to [Lifted transformations](https://flax.readthedocs.io/en/latest/developer_notes/lift.html) for more detail."
]
Expand Down
4 changes: 2 additions & 2 deletions docs/guides/flax_fundamentals/rng_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ import hashlib
jax.devices()
```

Set the JAX config variable `jax_threefry_partitionable` to `True`. This will be the default value in the future and makes the PRNG more efficiently auto-parallelizable under `jax.jit`. Refer to [JAX discussion](https://github.com/google/jax/discussions/18480) for more details.
Set the JAX config variable `jax_threefry_partitionable` to `True`. This will be the default value in the future and makes the PRNG more efficiently auto-parallelizable under `jax.jit`. Refer to [JAX discussion](https://github.com/jax-ml/jax/discussions/18480) for more details.

```{code-cell}
jax.config.update('jax_threefry_partitionable', True)
Expand Down Expand Up @@ -647,7 +647,7 @@ jax.debug.visualize_array_sharding(out)

+++

[Flax lifted transforms](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/transformations.html) allow you to use [JAX transforms](https://github.com/google/jax#transformations) with `Module` arguments. This section will show you how to control how PRNG keys are split in Flax lifted transforms.
[Flax lifted transforms](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/transformations.html) allow you to use [JAX transforms](https://github.com/jax-ml/jax#transformations) with `Module` arguments. This section will show you how to control how PRNG keys are split in Flax lifted transforms.

Refer to [Lifted transformations](https://flax.readthedocs.io/en/latest/developer_notes/lift.html) for more detail.

Expand Down
2 changes: 1 addition & 1 deletion docs/guides/flax_sharp_bits.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
"\n",
"### Background \n",
"\n",
"The [dropout](https://jmlr.org/papers/volume15/srivastava14a/srivastava14a.pdf) stochastic regularization technique randomly removes hidden and visible units in a network. Dropout is a random operation, requiring a PRNG state, and Flax (like JAX) uses [Threefry](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) PRNG that is splittable. \n",
"The [dropout](https://jmlr.org/papers/volume15/srivastava14a/srivastava14a.pdf) stochastic regularization technique randomly removes hidden and visible units in a network. Dropout is a random operation, requiring a PRNG state, and Flax (like JAX) uses [Threefry](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) PRNG that is splittable. \n",
"\n",
"> Note: Recall that JAX has an explicit way of giving you PRNG keys: you can fork the main PRNG state (such as `key = jax.random.key(seed=0)`) into multiple new PRNG keys with `key, subkey = jax.random.split(key)`. Refresh your memory in [🔪 JAX - The Sharp Bits 🔪 Randomness and PRNG keys](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers).\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion docs/guides/flax_sharp_bits.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ Check out a full example below.

### Background

The [dropout](https://jmlr.org/papers/volume15/srivastava14a/srivastava14a.pdf) stochastic regularization technique randomly removes hidden and visible units in a network. Dropout is a random operation, requiring a PRNG state, and Flax (like JAX) uses [Threefry](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) PRNG that is splittable.
The [dropout](https://jmlr.org/papers/volume15/srivastava14a/srivastava14a.pdf) stochastic regularization technique randomly removes hidden and visible units in a network. Dropout is a random operation, requiring a PRNG state, and Flax (like JAX) uses [Threefry](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) PRNG that is splittable.

> Note: Recall that JAX has an explicit way of giving you PRNG keys: you can fork the main PRNG state (such as `key = jax.random.key(seed=0)`) into multiple new PRNG keys with `key, subkey = jax.random.split(key)`. Refresh your memory in [🔪 JAX - The Sharp Bits 🔪 Randomness and PRNG keys](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers).
Expand Down
2 changes: 1 addition & 1 deletion docs/guides/parallel_training/flax_on_pjit.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that device axis names like `'data'`, `'model'` or `None` are passed into both [`flax.linen.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_partitioning) and [`jax.lax.with_sharding_constraint`](https://github.com/google/jax/blob/main/jax/_src/pjit.py#L1516) API calls. This refers to how each dimension of this data should be sharded — either across one of the device mesh dimensions, or not sharded at all.\n",
"Note that device axis names like `'data'`, `'model'` or `None` are passed into both [`flax.linen.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_partitioning) and [`jax.lax.with_sharding_constraint`](https://github.com/jax-ml/jax/blob/main/jax/_src/pjit.py#L1516) API calls. This refers to how each dimension of this data should be sharded — either across one of the device mesh dimensions, or not sharded at all.\n",
"\n",
"For example:\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion docs/guides/parallel_training/flax_on_pjit.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class DotReluDot(nn.Module):
return z, None
```

Note that device axis names like `'data'`, `'model'` or `None` are passed into both [`flax.linen.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_partitioning) and [`jax.lax.with_sharding_constraint`](https://github.com/google/jax/blob/main/jax/_src/pjit.py#L1516) API calls. This refers to how each dimension of this data should be sharded — either across one of the device mesh dimensions, or not sharded at all.
Note that device axis names like `'data'`, `'model'` or `None` are passed into both [`flax.linen.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_partitioning) and [`jax.lax.with_sharding_constraint`](https://github.com/jax-ml/jax/blob/main/jax/_src/pjit.py#L1516) API calls. This refers to how each dimension of this data should be sharded — either across one of the device mesh dimensions, or not sharded at all.

For example:

Expand Down
2 changes: 1 addition & 1 deletion docs/guides/training_techniques/transfer_learning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
"# Note that the Transformers library doesn't use the latest Flax version.\n",
"! pip install -q \"transformers[flax]\"\n",
"# Install/upgrade Flax and JAX. For JAX installation with GPU/TPU support,\n",
"# visit https://github.com/google/jax#installation.\n",
"# visit https://github.com/jax-ml/jax#installation.\n",
"! pip install -U -q flax jax jaxlib"
]
},
Expand Down
2 changes: 1 addition & 1 deletion docs/guides/training_techniques/transfer_learning.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ Depending on your task, some of the content in this guide may be suboptimal. For
# Note that the Transformers library doesn't use the latest Flax version.
! pip install -q "transformers[flax]"
# Install/upgrade Flax and JAX. For JAX installation with GPU/TPU support,
# visit https://github.com/google/jax#installation.
# visit https://github.com/jax-ml/jax#installation.
! pip install -U -q flax jax jaxlib
```

Expand Down
2 changes: 1 addition & 1 deletion docs/guides/training_techniques/use_checkpointing.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
"source": [
"## Setup\n",
"\n",
"Install/upgrade Flax and [Orbax](https://github.com/google/orbax). For JAX installation with GPU/TPU support, visit [this section on GitHub](https://github.com/google/jax#installation)."
"Install/upgrade Flax and [Orbax](https://github.com/google/orbax). For JAX installation with GPU/TPU support, visit [this section on GitHub](https://github.com/jax-ml/jax#installation)."
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion docs/guides/training_techniques/use_checkpointing.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ If you need to learn more about `orbax.checkpoint`, refer to the [Orbax docs](ht

## Setup

Install/upgrade Flax and [Orbax](https://github.com/google/orbax). For JAX installation with GPU/TPU support, visit [this section on GitHub](https://github.com/google/jax#installation).
Install/upgrade Flax and [Orbax](https://github.com/google/orbax). For JAX installation with GPU/TPU support, visit [this section on GitHub](https://github.com/jax-ml/jax#installation).


Note: Before running `import jax`, create eight fake devices to mimic a [multi-host environment](https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html?#aside-hosts-and-devices-in-jax) in this notebook. Note that the order of imports is important here. The `os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'` command works only with the CPU backend, which means it won't work with GPU/TPU acceleration on if you're running this notebook in Google Colab. If you are already running the code on multiple devices (for example, in a 4x2 TPU environment), you can skip running the next cell.
Expand Down
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ Installation
# or to install the latest version of Flax:
pip install --upgrade git+https://github.com/google/flax.git
Flax installs the vanilla CPU version of JAX, if you need a custom version please check out `JAX's installation page <https://github.com/google/jax#installation>`__.
Flax installs the vanilla CPU version of JAX, if you need a custom version please check out `JAX's installation page <https://github.com/jax-ml/jax#installation>`__.

Basic usage
^^^^^^^^^^^^
Expand Down
2 changes: 1 addition & 1 deletion docs/quick_start.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
"\n",
"Welcome to Flax!\n",
"\n",
"Flax is an open source Python neural network library built on top of [JAX](https://github.com/google/jax). This tutorial demonstrates how to construct a simple convolutional neural\n",
"Flax is an open source Python neural network library built on top of [JAX](https://github.com/jax-ml/jax). This tutorial demonstrates how to construct a simple convolutional neural\n",
"network (CNN) using the [Flax](https://flax.readthedocs.io) Linen API and train\n",
"the network for image classification on the MNIST dataset."
]
Expand Down
2 changes: 1 addition & 1 deletion docs/quick_start.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jupytext:

Welcome to Flax!

Flax is an open source Python neural network library built on top of [JAX](https://github.com/google/jax). This tutorial demonstrates how to construct a simple convolutional neural
Flax is an open source Python neural network library built on top of [JAX](https://github.com/jax-ml/jax). This tutorial demonstrates how to construct a simple convolutional neural
network (CNN) using the [Flax](https://flax.readthedocs.io) Linen API and train
the network for image classification on the MNIST dataset.

Expand Down
2 changes: 1 addition & 1 deletion docs_nnx/contributing.md
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ section above to keep the contents of both Markdown and Jupyter Notebook files i
Some of the notebooks are built automatically as part of the pre-submit checks and
as part of the [Read the Docs](https://flax.readthedocs.io/en/latest) build.
The build will fail if cells raise errors. If the errors are intentional, you can either catch them,
or tag the cell with `raises-exceptions` metadata ([example PR](https://github.com/google/jax/pull/2402/files)).
or tag the cell with `raises-exceptions` metadata ([example PR](https://github.com/jax-ml/jax/pull/2402/files)).
You have to add this metadata by hand in the `.ipynb` file. It will be preserved when somebody else
re-saves the notebook.

Expand Down
2 changes: 1 addition & 1 deletion docs_nnx/flip/1777-default-dtype.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ The current behavior is problematic and results in silent bugs, especially for d

### Dtypes in JAX

JAX uses a NumPy-inspired [dtype promotion](https://github.com/google/jax/blob/main/jax/_src/dtypes.py) mechanism as explained [here](https://jax.readthedocs.io/en/latest/type_promotion.html?highlight=lattice#type-promotion-semantics). The type promotion rules are summarized by the following type lattice:
JAX uses a NumPy-inspired [dtype promotion](https://github.com/jax-ml/jax/blob/main/jax/_src/dtypes.py) mechanism as explained [here](https://jax.readthedocs.io/en/latest/type_promotion.html?highlight=lattice#type-promotion-semantics). The type promotion rules are summarized by the following type lattice:

![JAX type promotion lattice](https://jax.readthedocs.io/en/latest/_images/type_lattice.svg)

Expand Down
2 changes: 1 addition & 1 deletion docs_nnx/mnist_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
"\n",
"Welcome to Flax NNX! In this tutorial you will learn how to build and train a simple convolutional neural network (CNN) to classify handwritten digits on the MNIST dataset using the Flax NNX API.\n",
"\n",
"Flax NNX is a Python neural network library built upon [JAX](https://github.com/google/jax). If you have used the Flax Linen API before, check out [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html). You should have some knowledge of the main concepts of deep learning.\n",
"Flax NNX is a Python neural network library built upon [JAX](https://github.com/jax-ml/jax). If you have used the Flax Linen API before, check out [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html). You should have some knowledge of the main concepts of deep learning.\n",
"\n",
"Let’s get started!"
]
Expand Down
2 changes: 1 addition & 1 deletion docs_nnx/mnist_tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jupytext:

Welcome to Flax NNX! In this tutorial you will learn how to build and train a simple convolutional neural network (CNN) to classify handwritten digits on the MNIST dataset using the Flax NNX API.

Flax NNX is a Python neural network library built upon [JAX](https://github.com/google/jax). If you have used the Flax Linen API before, check out [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html). You should have some knowledge of the main concepts of deep learning.
Flax NNX is a Python neural network library built upon [JAX](https://github.com/jax-ml/jax). If you have used the Flax Linen API before, check out [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html). You should have some knowledge of the main concepts of deep learning.

Let’s get started!

Expand Down

0 comments on commit d82e214

Please sign in to comment.