Merge pull request #4386 from 8bitmp3:update-flax-readme
PiperOrigin-RevId: 698508184
Flax Authors committed Nov 20, 2024
2 parents 4e25898 + bddf5f2 commit be26138
Showing 1 changed file with 66 additions and 108 deletions: README.md

![Build](https://github.com/google/flax/workflows/Build/badge.svg?branch=main) [![coverage](https://badgen.net/codecov/c/gh/google/flax)](https://codecov.io/gh/google/flax)


[**Overview**](#overview)
| [**Quick install**](#quick-install)
| [**What does Flax look like?**](#what-does-flax-look-like)
| [**Documentation**](https://flax.readthedocs.io/)

**📣 NEW**: Check out the [**NNX**](https://flax.readthedocs.io/en/latest/nnx/index.html) API!
Released in 2024, Flax NNX is a new simplified Flax API that is designed to make
it easier to create, inspect, debug, and analyze neural networks in
[JAX](https://jax.readthedocs.io/). It achieves this by adding first-class support
for Python reference semantics. This allows users to express their models using
regular Python objects, enabling reference sharing and mutability.
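
As a small illustration of what reference semantics means in practice, the sketch
below (ours, not from the official docs) shares a single `nnx.Linear` instance
between two attributes of a module; the names `Block`, `first` and `second` are
made up for the example:

```py
from flax import nnx

class Block(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    shared = nnx.Linear(4, 4, rngs=rngs)
    self.first = shared   # both attributes reference the same Linear object
    self.second = shared  # shared parameters, no duplication

block = Block(nnx.Rngs(0))
print(block.first is block.second)  # True: plain Python object identity
```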

Flax NNX evolved from the [Flax Linen API](https://flax-linen.readthedocs.io/), which
was released in 2020 by engineers and researchers at Google Brain in close collaboration
with the JAX team.

You can learn more about Flax NNX on the [dedicated Flax documentation site](https://flax.readthedocs.io/). Make sure you check out:

* [Flax NNX basics](https://flax.readthedocs.io/en/latest/nnx_basics.html)
* [MNIST tutorial](https://flax.readthedocs.io/en/latest/mnist_tutorial.html)
* [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html)
* [Evolution from Flax Linen to Flax NNX](https://flax.readthedocs.io/en/latest/guides/linen_to_nnx.html)

**Note:** Flax Linen's [documentation has its own site](https://flax-linen.readthedocs.io/).

The Flax team's mission is to serve the growing JAX neural network
research ecosystem - both within Alphabet and with the broader community,
and to explore the use-cases where JAX shines. We use GitHub for almost
all of our coordination and planning, as well as where we discuss
upcoming design changes. We welcome feedback on any of our discussion,
issue and pull request threads.

You can make feature requests, let us know what you are working on,
report issues, and ask questions in our [Flax GitHub discussion
forum](https://github.com/google/flax/discussions).

We expect to improve Flax, but we don't anticipate significant
breaking changes to the core API. We use [Changelog](https://github.com/google/flax/tree/main/CHANGELOG.md)
entries and deprecation warnings when possible.

In case you want to reach us directly, we're at [email protected].

## Overview

Flax is a high-performance neural network library and ecosystem for
JAX that is **designed for flexibility**:
Try new forms of training by forking an example and by modifying the training
loop, not adding features to a framework.

Flax is being developed in close collaboration with the JAX team and
comes with everything you need to start your research, including:

* **Neural network API** (`flax.nnx`): Including [`Linear`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/linear.html#flax.nnx.Linear), [`Conv`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/linear.html#flax.nnx.Conv), [`BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm), [`LayerNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.LayerNorm), [`GroupNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.GroupNorm), [Attention](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/attention.html) ([`MultiHeadAttention`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/attention.html#flax.nnx.MultiHeadAttention)), [`LSTMCell`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/recurrent.html#flax.nnx.nn.recurrent.LSTMCell), [`GRUCell`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/recurrent.html#flax.nnx.nn.recurrent.GRUCell), [`Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout).

* **Utilities and patterns**: replicated training, serialization and checkpointing, metrics, prefetching on device.

* **Educational examples**: [MNIST](https://flax.readthedocs.io/en/latest/mnist_tutorial.html), [Inference/sampling with the Gemma language model (transformer)](https://github.com/google/flax/tree/main/examples/gemma), [Transformer LM1B](https://github.com/google/flax/tree/main/examples/lm1b_nnx).

## Quick install

Flax uses JAX, so check out the [JAX installation instructions on CPUs, GPUs and TPUs](https://jax.readthedocs.io/en/latest/installation.html).

You will need Python 3.8 or later. Install Flax from PyPI:

```
pip install flax
```

To upgrade to the latest version of Flax, you can use:

```
pip install --upgrade git+https://github.com/google/flax.git
```
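
As an optional sanity check (our suggestion, not an official install step), you can
confirm that JAX and Flax import correctly and print their versions:

```py
import jax
import flax

print(jax.__version__, flax.__version__)
```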

To install additional dependencies (like `matplotlib`) that some parts of Flax
require but that are not installed by default, you can use:

## What does Flax look like?

To learn more about the `Module` abstraction, check out our [docs](https://flax.readthedocs.io/), our
[guides](https://flax.readthedocs.io/en/latest/guides/index.html) and
[developer notes](https://flax.readthedocs.io/en/latest/developer_notes/index.html).

Example of an MLP:

```py
import jax
from flax import nnx

class MLP(nnx.Module):
  def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
    self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x: jax.Array):
    x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
    return self.linear2(x)
```
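
A brief usage sketch (ours, not from the official docs); the layer sizes and batch
shape below are arbitrary:

```py
import jax.numpy as jnp

model = MLP(din=4, dmid=8, dout=2, rngs=nnx.Rngs(0))
y = model(jnp.ones((16, 4)))  # forward pass on a dummy batch
print(y.shape)                # (16, 2)
```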

Example of a CNN:

```py
from functools import partial  # for pre-configuring the pooling operation

class CNN(nnx.Module):
  def __init__(self, *, rngs: nnx.Rngs):
    self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
    self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
    self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
    self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
    self.linear2 = nnx.Linear(256, 10, rngs=rngs)

  def __call__(self, x):
    x = self.avg_pool(nnx.relu(self.conv1(x)))
    x = self.avg_pool(nnx.relu(self.conv2(x)))
    x = x.reshape(x.shape[0], -1)  # flatten
    x = nnx.relu(self.linear1(x))
    x = self.linear2(x)
    return x
```
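
A brief usage sketch (ours, not from the official docs), reusing the imports above
and assuming 28x28 grayscale NHWC inputs (e.g. MNIST), so that two rounds of 2x2
pooling leave a 7x7x64 feature map matching the 3136-feature `Linear` layer:

```py
model = CNN(rngs=nnx.Rngs(0))
y = model(jnp.ones((8, 28, 28, 1)))  # dummy batch of eight 28x28x1 images
print(y.shape)                       # (8, 10)
```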

Example of an autoencoder:


```py
Encoder = lambda rngs: nnx.Linear(2, 10, rngs=rngs)
Decoder = lambda rngs: nnx.Linear(10, 2, rngs=rngs)

class AutoEncoder(nnx.Module):
  def __init__(self, rngs):
    self.encoder = Encoder(rngs)
    self.decoder = Decoder(rngs)

  def __call__(self, x) -> jax.Array:
    return self.decoder(self.encoder(x))

  def encode(self, x) -> jax.Array:
    return self.encoder(x)
```
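
A brief usage sketch (ours, not from the official docs), reusing the imports above;
the 2-dimensional inputs match the `Encoder` definition:

```py
model = AutoEncoder(nnx.Rngs(0))
x = jnp.ones((16, 2))  # dummy batch of 2-dimensional inputs
z = model.encode(x)    # latent codes, shape (16, 10)
y = model(x)           # reconstructions, shape (16, 2)
print(z.shape, y.shape)
```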

## 🤗 Hugging Face

In-detail examples to train and evaluate a variety of Flax models for
Natural Language Processing, Computer Vision, and Speech Recognition are
actively maintained in the [🤗 Transformers repository](https://github.com/huggingface/transformers/tree/main/examples/flax).

As of October 2021, the [19 most-used Transformer architectures](https://huggingface.co/transformers/#supported-frameworks) are supported in Flax
and over 5000 pretrained checkpoints in Flax have been uploaded to the [🤗 Hub](https://huggingface.co/models?library=jax&sort=downloads).

## Citing Flax

To cite this repository:
is intended to be that from [flax/version.py](https://github.com/google/flax/blob/main/flax/version.py).

## Note

Flax is an open source project maintained by a dedicated team at Google DeepMind, but is not an official Google product.
