-
Notifications
You must be signed in to change notification settings - Fork 661
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #4386 from 8bitmp3:update-flax-readme
PiperOrigin-RevId: 698508184
- Loading branch information
Showing
1 changed file
with
66 additions
and
108 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,39 +6,40 @@ | |
|
||
![Build](https://github.com/google/flax/workflows/Build/badge.svg?branch=main) [![coverage](https://badgen.net/codecov/c/gh/google/flax)](https://codecov.io/gh/google/flax) | ||
|
||
|
||
[**Overview**](#overview) | ||
| [**Quick install**](#quick-install) | ||
| [**What does Flax look like?**](#what-does-flax-look-like) | ||
| [**Documentation**](https://flax.readthedocs.io/) | ||
|
||
**📣 NEW**: Check out the [**NNX**](https://flax.readthedocs.io/en/latest/nnx/index.html) API! | ||
Released in 2024, Flax NNX is a new simplified Flax API that is designed to make | ||
it easier to create, inspect, debug, and analyze neural networks in | ||
[JAX](https://jax.readthedocs.io/). It achieves this by adding first class support | ||
for Python reference semantics. This allows users to express their models using | ||
regular Python objects, enabling reference sharing and mutability. | ||
|
||
Flax NNX evolved from the [Flax Linen API](https://flax-linen.readthedocs.io/), which | ||
was released in 2020 by engineers and researchers at Google Brain in close collaboration | ||
with the JAX team. | ||
|
||
This README is a very short intro. **To learn everything you need to know about Flax, refer to our [full documentation](https://flax.readthedocs.io/).** | ||
You can learn more about Flax NNX on the [dedicated Flax documentation site](https://flax.readthedocs.io/). Make sure you check out: | ||
|
||
Flax was originally started by engineers and researchers within the Brain Team in Google Research (in close collaboration with the JAX team), and is now developed jointly with the open source community. | ||
* [Flax NNX basics](https://flax.readthedocs.io/en/latest/nnx_basics.html) | ||
* [MNIST tutorial](https://flax.readthedocs.io/en/latest/mnist_tutorial.html) | ||
* [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html) | ||
* [Evolution from Flax Linen to Flax NNX](https://flax.readthedocs.io/en/latest/guides/linen_to_nnx.html) | ||
|
||
Flax is being used by a growing | ||
community of hundreds of folks in various Alphabet research departments | ||
for their daily work, as well as a [growing community | ||
of open source | ||
projects](https://github.com/google/flax/network/dependents?dependent_type=REPOSITORY). | ||
**Note:** Flax Linen's [documentation has its own site](https://flax-linen.readthedocs.io/). | ||
|
||
The Flax team's mission is to serve the growing JAX neural network | ||
research ecosystem -- both within Alphabet and with the broader community, | ||
research ecosystem - both within Alphabet and with the broader community, | ||
and to explore the use-cases where JAX shines. We use GitHub for almost | ||
all of our coordination and planning, as well as where we discuss | ||
upcoming design changes. We welcome feedback on any of our discussion, | ||
issue and pull request threads. We are in the process of moving some | ||
remaining internal design docs and conversation threads to GitHub | ||
discussions, issues and pull requests. We hope to increasingly engage | ||
with the needs and clarifications of the broader ecosystem. Please let | ||
us know how we can help! | ||
issue and pull request threads. | ||
|
||
Please report any feature requests, | ||
issues, questions or concerns in our [discussion | ||
forum](https://github.com/google/flax/discussions), or just let us | ||
know what you're working on! | ||
You can make feature requests, let us know what you are working on, | ||
report issues, ask questions in our [Flax GitHub discussion | ||
forum](https://github.com/google/flax/discussions). | ||
|
||
We expect to improve Flax, but we don't anticipate significant | ||
breaking changes to the core API. We use [Changelog](https://github.com/google/flax/tree/main/CHANGELOG.md) | ||
|
@@ -51,31 +52,22 @@ In case you want to reach us directly, we're at [email protected]. | |
Flax is a high-performance neural network library and ecosystem for | ||
JAX that is **designed for flexibility**: | ||
Try new forms of training by forking an example and by modifying the training | ||
loop, not by adding features to a framework. | ||
loop, not adding features to a framework. | ||
|
||
Flax is being developed in close collaboration with the JAX team and | ||
comes with everything you need to start your research, including: | ||
|
||
* **Neural network API** (`flax.linen`): Dense, Conv, {Batch|Layer|Group} Norm, Attention, Pooling, {LSTM|GRU} Cell, Dropout | ||
|
||
* **Utilities and patterns**: replicated training, serialization and checkpointing, metrics, prefetching on device | ||
* **Neural network API** (`flax.nnx`): Including [`Linear`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/linear.html#flax.nnx.Linear), [`Conv`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/linear.html#flax.nnx.Conv), [`BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm), [`LayerNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.LayerNorm), [`GroupNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.GroupNorm), [Attention](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/attention.html) ([`MultiHeadAttention`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/attention.html#flax.nnx.MultiHeadAttention)), [`LSTMCell`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/recurrent.html#flax.nnx.nn.recurrent.LSTMCell), [`GRUCell`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/recurrent.html#flax.nnx.nn.recurrent.GRUCell), [`Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout). | ||
|
||
* **Educational examples** that work out of the box: MNIST, LSTM seq2seq, Graph Neural Networks, Sequence Tagging | ||
* **Utilities and patterns**: replicated training, serialization and checkpointing, metrics, prefetching on device. | ||
|
||
* **Fast, tuned large-scale end-to-end examples**: CIFAR10, ResNet on ImageNet, Transformer LM1b | ||
* **Educational examples**: [MNIST](https://flax.readthedocs.io/en/latest/mnist_tutorial.html), [Inference/sampling with the Gemma language model (transformer)](https://github.com/google/flax/tree/main/examples/gemma), [Transformer LM1B](https://github.com/google/flax/tree/main/examples/lm1b_nnx). | ||
|
||
## Quick install | ||
|
||
You will need Python 3.6 or later, and a working [JAX](https://github.com/google/jax/blob/main/README.md) | ||
installation (with or without GPU support - refer to [the instructions](https://github.com/google/jax/blob/main/README.md)). | ||
For a CPU-only version of JAX: | ||
|
||
``` | ||
pip install --upgrade pip # To support manylinux2010 wheels. | ||
pip install --upgrade jax jaxlib # CPU-only | ||
``` | ||
Flax uses JAX, so do check out [JAX installation instructions on CPUs, GPUs and TPUs](https://jax.readthedocs.io/en/latest/installation.html). | ||
|
||
Then, install Flax from PyPi: | ||
You will need Python 3.8 or later. Install Flax from PyPi: | ||
|
||
``` | ||
pip install flax | ||
|
@@ -86,6 +78,7 @@ To upgrade to the latest version of Flax, you can use: | |
``` | ||
pip install --upgrade git+https://github.com/google/flax.git | ||
``` | ||
|
||
To install some additional dependencies (like `matplotlib`) that are required but not included | ||
by some dependencies, you can use: | ||
|
||
|
@@ -101,95 +94,60 @@ To learn more about the `Module` abstraction, check out our [docs](https://flax. | |
[guides](https://flax.readthedocs.io/en/latest/guides/index.html) and | ||
[developer notes](https://flax.readthedocs.io/en/latest/developer_notes/index.html). | ||
|
||
Example of an MLP: | ||
|
||
```py | ||
from typing import Sequence | ||
class MLP(nnx.Module): | ||
def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs): | ||
self.linear1 = Linear(din, dmid, rngs=rngs) | ||
self.dropout = nnx.Dropout(rate=0.1, rngs=rngs) | ||
self.bn = nnx.BatchNorm(dmid, rngs=rngs) | ||
self.linear2 = Linear(dmid, dout, rngs=rngs) | ||
|
||
def __call__(self, x: jax.Array): | ||
x = nnx.gelu(self.dropout(self.bn(self.linear1(x)))) | ||
return self.linear2(x) | ||
``` | ||
|
||
import numpy as np | ||
import jax | ||
import jax.numpy as jnp | ||
import flax.linen as nn | ||
Example of a CNN: | ||
|
||
class MLP(nn.Module): | ||
features: Sequence[int] | ||
```py | ||
class CNN(nnx.Module): | ||
def __init__(self, *, rngs: nnx.Rngs): | ||
self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs) | ||
self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs) | ||
self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2)) | ||
self.linear1 = nnx.Linear(3136, 256, rngs=rngs) | ||
self.linear2 = nnx.Linear(256, 10, rngs=rngs) | ||
|
||
@nn.compact | ||
def __call__(self, x): | ||
for feat in self.features[:-1]: | ||
x = nn.relu(nn.Dense(feat)(x)) | ||
x = nn.Dense(self.features[-1])(x) | ||
x = self.avg_pool(nnx.relu(self.conv1(x))) | ||
x = self.avg_pool(nnx.relu(self.conv2(x))) | ||
x = x.reshape(x.shape[0], -1) # flatten | ||
x = nnx.relu(self.linear1(x)) | ||
x = self.linear2(x) | ||
return x | ||
|
||
model = MLP([12, 8, 4]) | ||
batch = jnp.ones((32, 10)) | ||
variables = model.init(jax.random.key(0), batch) | ||
output = model.apply(variables, batch) | ||
``` | ||
|
||
```py | ||
class CNN(nn.Module): | ||
@nn.compact | ||
def __call__(self, x): | ||
x = nn.Conv(features=32, kernel_size=(3, 3))(x) | ||
x = nn.relu(x) | ||
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) | ||
x = nn.Conv(features=64, kernel_size=(3, 3))(x) | ||
x = nn.relu(x) | ||
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) | ||
x = x.reshape((x.shape[0], -1)) # flatten | ||
x = nn.Dense(features=256)(x) | ||
x = nn.relu(x) | ||
x = nn.Dense(features=10)(x) | ||
x = nn.log_softmax(x) | ||
return x | ||
Example of an autoencoder: | ||
|
||
model = CNN() | ||
batch = jnp.ones((32, 64, 64, 10)) # (N, H, W, C) format | ||
variables = model.init(jax.random.key(0), batch) | ||
output = model.apply(variables, batch) | ||
``` | ||
|
||
```py | ||
class AutoEncoder(nn.Module): | ||
encoder_widths: Sequence[int] | ||
decoder_widths: Sequence[int] | ||
input_shape: Sequence[int] | ||
|
||
def setup(self): | ||
input_dim = np.prod(self.input_shape) | ||
self.encoder = MLP(self.encoder_widths) | ||
self.decoder = MLP(self.decoder_widths + (input_dim,)) | ||
Encoder = lambda rngs: nnx.Linear(2, 10, rngs=rngs) | ||
Decoder = lambda rngs: nnx.Linear(10, 2, rngs=rngs) | ||
|
||
def __call__(self, x): | ||
return self.decode(self.encode(x)) | ||
class AutoEncoder(nnx.Module): | ||
def __init__(self, rngs): | ||
self.encoder = Encoder(rngs) | ||
self.decoder = Decoder(rngs) | ||
|
||
def encode(self, x): | ||
assert x.shape[1:] == self.input_shape | ||
return self.encoder(jnp.reshape(x, (x.shape[0], -1))) | ||
def __call__(self, x) -> jax.Array: | ||
return self.decoder(self.encoder(x)) | ||
|
||
def decode(self, z): | ||
z = self.decoder(z) | ||
x = nn.sigmoid(z) | ||
x = jnp.reshape(x, (x.shape[0],) + self.input_shape) | ||
return x | ||
|
||
model = AutoEncoder(encoder_widths=[20, 10, 5], | ||
decoder_widths=[5, 10, 20], | ||
input_shape=(12,)) | ||
batch = jnp.ones((16, 12)) | ||
variables = model.init(jax.random.key(0), batch) | ||
encoded = model.apply(variables, batch, method=model.encode) | ||
decoded = model.apply(variables, encoded, method=model.decode) | ||
def encode(self, x) -> jax.Array: | ||
return self.encoder(x) | ||
``` | ||
|
||
## 🤗 Hugging Face | ||
|
||
In-detail examples to train and evaluate a variety of Flax models for | ||
Natural Language Processing, Computer Vision, and Speech Recognition are | ||
actively maintained in the [🤗 Transformers repository](https://github.com/huggingface/transformers/tree/main/examples/flax). | ||
|
||
As of October 2021, the [19 most-used Transformer architectures](https://huggingface.co/transformers/#supported-frameworks) are supported in Flax | ||
and over 5000 pretrained checkpoints in Flax have been uploaded to the [🤗 Hub](https://huggingface.co/models?library=jax&sort=downloads). | ||
|
||
## Citing Flax | ||
|
||
To cite this repository: | ||
|
@@ -209,4 +167,4 @@ is intended to be that from [flax/version.py](https://github.com/google/flax/blo | |
|
||
## Note | ||
|
||
Flax is an open source project maintained by a dedicated team in Google Research, but is not an official Google product. | ||
Flax is an open source project maintained by a dedicated team at Google DeepMind, but is not an official Google product. |