Commit

Merge branch 'main' into docs
melissatan authored Mar 9, 2022
2 parents 02c74c5 + da5e204 commit 5f8c5f9
Showing 73 changed files with 3,625 additions and 1,040 deletions.
8 changes: 8 additions & 0 deletions .github/ISSUE_TEMPLATE/bug_report.md
@@ -9,6 +9,14 @@ assignees: ''

Provide as much information as possible. At least, this should include a description of your issue and steps to reproduce the problem. If possible also provide a summary of what steps or workarounds you have already tried.

### System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
- Flax, jax, jaxlib versions (obtain with `pip show flax jax jaxlib`):
- Python version:
- GPU/TPU model and memory:
- CUDA version (if applicable):


### Problem you have encountered:


4 changes: 2 additions & 2 deletions CHANGELOG.md
@@ -9,12 +9,12 @@ vNext
-
-
-
-
- Improved seq2seq example: Factored out model and input pipeline code.
- Added Optax update guide and deprecated `flax.optim`.
- Added `sep` argument to `flax.traverse_util.flatten_dict()`.
- Implemented Sequential module, in `flax.linen.combinators`.
-
-
- Added locally-connected (unshared CNN) layer `flax.linen.ConvLocal`.
-

0.4.0
28 changes: 22 additions & 6 deletions docs/contributing.md
@@ -24,8 +24,10 @@ We do all of our development using git, so basic knowledge is assumed.

Follow these steps to contribute code:

### Create a Pull Request in your own branch

1. Fork the Flax repository by clicking the 'Fork' button on the
-  [repository page](http://www.github.com/google/flax). This create a copy
+  [repository page](http://www.github.com/google/flax). This creates a copy
of the Flax repository in your own account.

2. Install Python >=3.6 and `svn` for running the tests (see below).
@@ -52,7 +54,7 @@ Follow these steps to contribute code:
pip install -e .
```

- 5. Add the Flax repo as an upstream remote, so you can use it to sync your
+ 5. Add the Google Flax repo (not your fork) as an upstream remote, so you can use it to sync your
changes.

```bash
@@ -66,10 +68,10 @@ Follow these steps to contribute code:
git checkout -b name-of-change
```

-    And implement your changes using your favorite editor (we recommend
+ 7. Implement your changes using your favorite editor (we recommend
[Visual Studio Code](https://code.visualstudio.com/)).

- 7. Make sure the tests pass by running the following command from the top of
+    Make sure the tests pass by running the following command from the top of
the repository:

```bash
@@ -90,19 +92,33 @@ Follow these steps to contribute code:
git rebase upstream/main
```

-    Finally push your commit on your development branch and create a remote
+ 9. Finally push your commit on your development branch and create a remote
branch in your fork that you can use to create a Pull Request from:

```bash
git push --set-upstream origin name-of-change
```

After running the command, you should see a GitHub link in your terminal output that you can click on to create a Pull Request.
If you do not see this link in the terminal after doing a `git push`, go to the GitHub web UI; there should be a button there that lets you turn the commit into a Pull Request yourself.

- 9. Make sure your PR passes the
+ 10. Make sure your PR passes the
[PR checklist](https://github.com/google/flax/blob/main/.github/pull_request_template.md#checklist).
If so, create a Pull Request from the Flax repository and send it for review.
Consult [GitHub Help](https://help.github.com/articles/about-pull-requests/)
for more information on using pull requests.

### Updating the Pull Request contents

Every Pull Request should ideally be limited to just one commit, so if you have multiple commits please squash them.

Assuming you now have only one commit in your Pull Request, and want to add changes requested during review:

1. Make the changes locally in your editor.
2. Run `git commit -a --amend`. This updates the commit contents and allows you to edit the commit message.
3. At this point, `git push` alone will result in an error. Instead, use `git push --force`.
4. Check that it's done: the changes to your commit should be immediately reflected in the GitHub web UI.

## Contributor License Agreement

Contributions to this project must be accompanied by a Contributor License
3 changes: 3 additions & 0 deletions docs/flax.linen.rst
@@ -75,6 +75,7 @@ Transformations
jvp
vjp
custom_vjp
while_loop


Linear modules
@@ -88,6 +89,7 @@ Linear modules
DenseGeneral
Conv
ConvTranspose
ConvLocal
Embed


@@ -111,6 +113,7 @@ Pooling

max_pool
avg_pool
pool


Activation functions
132 changes: 132 additions & 0 deletions docs/flip/1777-default-dtype.md
@@ -0,0 +1,132 @@
# FLIP: Default dtypes


- Start Date: 2022-01-11
- FLIP PR: [#1776](https://github.com/google/flax/pull/1776)
- FLIP Issue: [#1777](https://github.com/google/flax/issues/1777)


## Summary

This FLIP proposes to replace the default dtype, which is currently fixed to float32, and instead derive the default dtype from the inputs and parameters of a layer using JAX type promotion.


## Motivation

Currently, Linen Modules always produce outputs of `module.dtype` (default: float32), regardless of input and parameter dtypes. Half-precision types like float16 and bfloat16 are supported by explicitly passing the half-precision type to each Module. This is currently implemented by giving each Module a `dtype` argument with float32 as the default value, and the layer guarantees that this dtype will be the dtype of the result returned by `__call__`.

The current behavior is problematic and results in silent bugs, especially for dtypes that do not fit inside float32 (complex, float64). The Linen dtype behavior also differs significantly from how NumPy, and by extension JAX, handles dtypes.


### Dtypes in JAX

JAX uses a NumPy-inspired [dtype promotion](https://github.com/google/jax/blob/main/jax/_src/dtypes.py) mechanism as explained [here](https://jax.readthedocs.io/en/latest/type_promotion.html?highlight=lattice#type-promotion-semantics). The type promotion rules are summarized by the following type lattice:

![JAX type promotion lattice](https://jax.readthedocs.io/en/latest/_images/type_lattice.svg)
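
For instance, a few promotions implied by this lattice (a small illustration; expected results are shown as comments):

```python
import jax.numpy as jnp

jnp.result_type(jnp.bfloat16, jnp.float16)   # -> float32: the two half-precision types meet at float32
jnp.result_type(jnp.int32, jnp.bfloat16)     # -> bfloat16: integers defer to floating-point types
jnp.result_type(jnp.float32, jnp.complex64)  # -> complex64: real types widen to complex
```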


### Dtypes in Linen

Besides input arguments, state (in particular, parameters) can affect dtype promotion. For example, we might feed a float64 input to a Dense layer that has float32 parameters. Currently, the result would be truncated to float32. If the input is complex the result is even worse, because the imaginary part is silently dropped when casting to float32.

By using the dtype promotion rules already available in JAX we can avoid this issue. The public API `jax.numpy.result_type(*args)` returns the dtype that JAX would promote the given arguments to, in accordance with the type promotion lattice. For Linen layers the arguments would be the layer inputs together with the parameters. For example, for a linear layer these would be the inputs, kernel, and bias.
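
As a small sketch of that idea (the array names here are illustrative, not actual Flax internals):

```python
import jax.numpy as jnp

x = jnp.ones((2, 3), jnp.complex64)     # layer input
kernel = jnp.ones((3, 4), jnp.float32)  # float32 parameter
bias = jnp.ones((4,), jnp.float32)      # float32 parameter

jnp.result_type(x, kernel, bias)        # -> complex64, so the imaginary part is preserved
```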

Note that standard Linen Modules also have a `param_dtype` attribute, which defaults to float32. This behavior is left untouched and encodes the common case of having float32 parameters.
There are a few reasons why float32 is almost always the correct dtype for parameters:
1. Storing weights in half-precision often leads to underflow during optimization.
2. Double precision is rarely used because it severely slows down modern accelerators (GPU, TPU), so such a cost should require an explicit opt-in.
3. Complex Modules are relatively uncommon. Even within complex networks the complex inputs can be projected with a real matrix.


# Implementation

A simplified example implementation:


```python
from typing import Any, Callable, Optional

import jax
import jax.numpy as jnp
import flax.linen as nn

Dtype = Any


def promote_arrays(*xs, dtype):
  # Promote all arrays to a common dtype; if dtype is None, infer it from the inputs.
  if dtype is None:
    dtype = jnp.result_type(*jax.tree_leaves(xs))
  return jax.tree_map(lambda x: jnp.asarray(x, dtype), xs)


class Dense(nn.Module):
  features: int
  kernel_init: Callable
  bias_init: Callable
  dtype: Optional[Dtype] = None
  param_dtype: Dtype = jnp.float32

  @nn.compact
  def __call__(self, x):
    kernel = self.param("kernel", self.kernel_init,
                        (x.shape[-1], self.features), self.param_dtype)
    bias = self.param("bias", self.bias_init,
                      (self.features,), self.param_dtype)
    # Promote inputs and parameters to a common dtype before computing.
    x, kernel, bias = promote_arrays(x, kernel, bias, dtype=self.dtype)
    return x @ kernel + bias
```
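
For illustration, a usage sketch of the simplified `Dense` above (the initializers and shapes are arbitrary choices for this example, and the dtypes shown assume the default `param_dtype` of float32):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

model = Dense(features=4,
              kernel_init=nn.initializers.lecun_normal(),
              bias_init=nn.initializers.zeros)

x = jnp.ones((2, 3), jnp.bfloat16)
params = model.init(jax.random.PRNGKey(0), x)  # parameters are created in float32
y = model.apply(params, x)
# y.dtype == jnp.float32: the bfloat16 input is promoted to the float32 parameters
# rather than being forced to a fixed default dtype.
```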


## Half-precision dtypes

Some layers don't work well with half-precision dtypes internally. For example, the normalization layers currently compute mean and variance in float32 even when a half-precision dtype is specified, in order to avoid numerical issues. We can replicate this behavior by calling `jnp.result_type` with a dummy argument that has the minimum precision required for the sub-computation to work correctly.
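
A minimal sketch of this pattern (the helper name is hypothetical, not part of the proposal):

```python
import jax.numpy as jnp

def stats_dtype(x):
  # Promote with a float32 "dummy" argument so that statistics are computed in
  # at least float32, even when the layer itself runs in float16/bfloat16.
  return jnp.result_type(x, jnp.float32)

# stats_dtype(jnp.ones(3, jnp.float16))  -> float32 (upcast for numerical stability)
# a float64 input would keep float64 rather than being downcast
```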


## Backward compatibility

This proposal causes some layers to behave differently in cases where the dtype is not specified explicitly. By default, parameters are in float32, so passing half-precision or float32 inputs will still result in a float32 dtype and no functional difference from the current behavior.

When passing complex or float64 inputs, the result will no longer truncate the imaginary component or the precision. This silent truncation is problematic and has caused [user complaints](https://github.com/google/flax/issues/805#issuecomment-981468837); therefore, this change can be considered a bugfix.

Thus, although this proposal strictly speaking changes behavior, it is unlikely to cause problems for users. There are two exceptions, which should be rare and easy to fix:
1. A user relies on the enforced float32 to downcast a double-precision value.
2. A user relies on the enforced float32 to upcast a half-precision value even though the weights are in half precision.


## Corner cases

In this section we describe corner cases where the implementation of the proposal is not obvious. The two main concerns are how complex numbers are handled in existing layers and how to determine the dtype of state variables.

**Autoregressive decoding cache**

Currently, only attention implements autoregressive caching, and the stored key and value mirror the dtype of the key and value passed to the layer. Forcing the cache dtype to be the same as the output dtype could result in reduced precision during cached decoding versus uncached decoding, which seems undesirable. Decision: keep the current behavior.

**Batch statistics**

BatchNorm layers are often used with a half-precision output dtype. However, calculating statistics is by default always done in float32 to avoid numerical precision issues and over/underflow for float16. With float64 this would actually cause a downcast, so we should instead use `jnp.promote_types(jnp.float32, dtype)` such that the precision is at least float32. The running batch statistics will be stored in the same dtype for consistency.
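
A small illustration of that promotion rule, using `jnp.promote_types` directly (expected results as comments):

```python
import jax.numpy as jnp

jnp.promote_types(jnp.float32, jnp.float16)  # -> float32: half-precision statistics are upcast
jnp.promote_types(jnp.float32, jnp.float64)  # -> float64: double precision is not downcast
```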

**Complex number support**

Currently, our complex number support is brittle because the default behavior is to truncate the output to the real part. This issue will be fixed by the automatic type promotion proposed in this FLIP. However, some layers require additional thought to extend to complex numbers correctly:

1. Normalization layers: use the complex conjugate to calculate norms instead of plain squaring (a small sketch follows this list).
2. Attention: it's not exactly clear how the dot product and softmax are defined in this case, so we raise an error on complex inputs.
3. Recurrent layers: might require special gating / activation functions to work correctly, but these can be specified by the user.
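
For point 1, a small sketch of a complex-aware variance (a hypothetical helper, not an existing Flax function):

```python
import jax.numpy as jnp

def complex_safe_variance(x, axis=-1):
  mean = jnp.mean(x, axis=axis, keepdims=True)
  centered = x - mean
  # abs() uses the complex conjugate under the hood, so the result is real-valued
  # even for complex inputs (float32 for complex64).
  return jnp.mean(jnp.abs(centered) ** 2, axis=axis)
```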


# Discussion

Summarizing the main points from the discussion:


## Consider implicit complex truncation an error

Q:
I'm wondering if we should always raise an error if one of the xs tree leaves is complex but the dtype is not. Users should maybe remove the imaginary part themselves if that's really what they want to do.
(Maybe it's a contrived example, but I can imagine cases where layers have their dtype set by parent modules based on assumptions that don't have complex numbers in mind.)

A:
This is worth considering in a follow-up CL, but it might as well be solved in JAX directly, where the safeguard would apply more generally. In NumPy this was also considered but abandoned because it is not backwards compatible.


## Dtype attribute names

Q:
Are the `dtype` and `param_dtype` arguments confusing? In particular, should `dtype` perhaps be called `output_dtype` to make the difference between the two dtypes more explicit?

A:
This would be a large change orthogonal to this proposal, so it is left out for now.
It would also break with the standard `dtype` argument in NumPy/JAX.
Although `dtype` indeed constrains the output dtype, it is also a hint for the dtype we would like the computation to happen in.
