Skip to content

Commit

Permalink
Merge pull request #1055 from avital:0.3.2
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 360138181
  • Loading branch information
Flax Authors committed Mar 1, 2021
2 parents 8bc5b52 + eea3420 commit c5223d0
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 18 deletions.
74 changes: 62 additions & 12 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,35 +8,85 @@ vNext
-
-
-
- Bug Fix `flax.core.apply` and `Module.apply`. Now it returns a tuple containing the output and a frozen empty collection when `mutable` is specified as an empty list.
-
-
- Add `sow` method to `Module` and `capture_intermediates` argument to `Module.apply`.
See [howto](https://flax.readthedocs.io/en/latest/howtos/extracting_intermediates.html) for usage patterns.
-
-
-
- Some Module arguments can now be passed either as dataclass attribute or
as argument to `__call__`. See [design note](https://flax.readthedocs.io/en/latest/design_notes/arguments.html)
- `use_running_average` and `deterministic` no longer have a default. They should be passed explicitly
- `broadcast_dims` is now a attribute to `Dropout` instead of a `__call__` argument.
-
-
-
-
-
-
- Added OptimizedLSTM: ~33% faster than the original LSTM when using <=1024 units
- Bug Fix `Scope.variable` mutability check, before a variable could only be initialized
if the 'params' collection was mutable.
- Linen `Module` instances are now Frozen after `setup` has been called.
-
-
-
-
-
-
-
-
-

0.3.2
------

`flax.nn` deprecation message no longer appears if you import flax directly.

NOTE: You must now explicitly import `flax.nn` if you want to use the old
pre-Linen `flax.nn.Module`.

0.3.1
------

Many improvements to Linen, and the old `flax.nn` is officially reprecated!

Notably, there's a clean API for extracting intermediates from modules
defined using `@nn.compact`, a more ergonomic API for using Batch Norm and Dropout in modules
defined using `setup`, support for `MultiOptimizer` with Linen, and multiple safety, performance
and error message improvements.

Possible breaking changes:
- Call setup lazily. See #938 for motivation and more details.
- Linen `Module` instances are now frozen after `setup` has been called.
Previously mutations after setup could be dropped silently. Now the stateless requirement
is enforced by raising a TypeError in `__setattr__` after `setup`.
- Pytrees of dicts and lists are transformed into FrozenDict and tuples during attribute assignment.
- Pytrees of dicts and lists are transformed into FrozenDict and tuples during
attribute assignment.
This avoids undetected submodules and inner state.
- Bug Fix `flax.core.apply` and `Module.apply`. Now it returns a tuple
containing the output and a frozen empty
collection when `mutable` is specified as an empty list.
- `broadcast_dims` is now a attribute to `Dropout` instead of a `__call__`
argument.
- `use_running_average` and `deterministic` no longer have a default. They
should be passed explicitly
- Bug Fix `Scope.variable` mutability check, before a variable could only be
initialized if the 'params' collection was mutable.

Other Improvements:
- Re-introduced the `lm1b` language modeling example
- Recognizes batch free inputs in pooling layers. (for use with vmap)
- Add Adadelta optimizer
- Fully deprecate all "pre-Linen" `flax.nn` classes and methods.
- Some Module arguments can now be passed either as dataclass attribute or
as argument to `__call__`. See [design note](https://flax.readthedocs.io/en/latest/design_notes/arguments.html)
- Add `sow` method to `Module` and `capture_intermediates` argument to `Module.apply`.
See [howto](https://flax.readthedocs.io/en/latest/howtos/extracting_intermediates.html) for usage patterns.
- Support passing in modules directly as attributes to other modules, and
deal with them correctly both in top-level modules and in submodules.
- Don't require the `variable` argument to `Module.apply` to be a FrozenDict
- Add support for dict/FrozenDict when using `ModelParamTraversal`
As a result `MultiOptimizer` can be used properly with linen modules.
- Added OptimizedLSTM: ~33% faster than the original LSTM when using <=1024 units
- Fix dtype handling for Adam and LAMB optimizers in 64bit mode.
- Added `is_mutable()` method to `Variable` and `is_mutable_collection()` to `flax.linen.Module`.
- Add `axis_name` arg to `flax.linen.vmap`
- Enable broadcast in `flax.linen.scan`
- Fix behavior when inner module classes were defined in another module
- Add automatic giant array chunking in msgpack checkpoints.
- Log info message when a checkpoint is not found in the directory.

v0.3
-----
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ To cite this repository:
author = {Jonathan Heek and Anselm Levskaya and Avital Oliver and Marvin Ritter and Bertrand Rondepierre and Andreas Steiner and Marc van {Z}ee},
title = {{F}lax: A neural network library and ecosystem for {JAX}},
url = {http://github.com/google/flax},
version = {0.3.0},
version = {0.3.2},
year = {2020},
}
```
Expand Down
2 changes: 1 addition & 1 deletion contributing.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ git clone https://github.com/google/flax
cd flax
python3.6 -m virtualenv env
. env/bin/activate
pip install -e . .[testing]
pip install -e . '.[testing]'
./tests/run_all_tests.sh
```

Expand Down
2 changes: 1 addition & 1 deletion flax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

from .version import __version__

# Allow `import flax`; `flax.nn.[...]`, and the same for `flax.optim.[...]`
# Allow `import flax`; `flax.optim.[...]`, etc
from . import core
from . import linen
from . import nn
Expand Down
1 change: 0 additions & 1 deletion flax/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,3 @@
import warnings
# Makes sure the user sees the warning, as deprecation warnings are silent by default
warnings.filterwarnings("default", category=DeprecationWarning, module=__name__)
warnings.warn("The `flax.nn` module is Deprecated, use `flax.linen` instead. Learn more and find an upgrade guide at https://github.com/google/flax/blob/master/flax/linen/README.md", DeprecationWarning)
4 changes: 3 additions & 1 deletion flax/nn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ class Module(metaclass=_ModuleMeta):
Functional modules."""

def __new__(cls, *args, name=None, **kwargs):
warnings.warn("The `flax.nn` module is Deprecated, use `flax.linen` instead. Learn more and find an upgrade guide at https://github.com/google/flax/blob/master/flax/linen/README.md", DeprecationWarning)
if not _module_stack:
raise ValueError('A Module should only be instantiated directly inside'
' another module.')
Expand Down Expand Up @@ -953,7 +954,8 @@ class Model:
"""DEPRECATION WARNING:
The `flax.nn` module is Deprecated, use `flax.linen` instead.
Learn more and find an upgrade guide at
https://github.com/google/flax/blob/master/flax/linen/README.md"
https://github.com/google/flax/blob/master/flax/linen/README.md
A Model contains the model parameters, state and definition."""

module: Type[Module] = struct.field(pytree_node=False)
Expand Down
2 changes: 1 addition & 1 deletion flax/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "0.3.0"
__version__ = "0.3.2"

0 comments on commit c5223d0

Please sign in to comment.