diff --git a/CHANGELOG.md b/CHANGELOG.md index e8db88776a..8a79bcb6fb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,19 @@ vNext - - +0.6.8 +----- +- The automatic checkpoint migration was temporarily rolled back due to legacy compatibility issues. + - We still recommend you to use the [upgrade guide](https://flax.readthedocs.io/en/latest/guides/orbax_upgrade_guide.html) and migrate completely to the Orbax API to ensure stability. + - Or alternatively, add `flax.config.update('flax_use_orbax_checkpointing', True)` to your project to avoid being impacted by the automatic migration process. +- Added utility functions to frozen_dict api. +- Migrated Flax away from `register_keypaths`. +- Fixes kwargs in convert_to_graphs_tuple_fn. +- Fixed examples in a few ways: + - Bumped the TF version + - Used latest checkpoint formats + - Other misc fixes. + 0.6.7 ----- - New checkpoints will be saved using Orbax! Please check out [upgrade guide](https://flax.readthedocs.io/en/latest/guides/orbax_upgrade_guide.html) and consider migrating completely to the Orbax API. diff --git a/README.md b/README.md index 01f920c30e..4e2ad46db1 100644 --- a/README.md +++ b/README.md @@ -197,7 +197,7 @@ To cite this repository: author = {Jonathan Heek and Anselm Levskaya and Avital Oliver and Marvin Ritter and Bertrand Rondepierre and Andreas Steiner and Marc van {Z}ee}, title = {{F}lax: A neural network library and ecosystem for {JAX}}, url = {http://github.com/google/flax}, - version = {0.6.7}, + version = {0.6.8}, year = {2023}, } ``` diff --git a/flax/core/meta.py b/flax/core/meta.py index f8e7f18f11..7c84592cba 100644 --- a/flax/core/meta.py +++ b/flax/core/meta.py @@ -238,7 +238,8 @@ def unbox(self, apply_constraint=True) -> Any: if apply_constraint and (_global_mesh_defined() or self.mesh is not None): axis_resource = self.get_partition_spec() if self.mesh is not None: - axis_resource = jax.sharding.NamedSharding(self.mesh, axis_resource) + sharding = jax.sharding.NamedSharding(self.mesh, axis_resource) + return jax.lax.with_sharding_constraint(self.value, sharding) return jax.lax.with_sharding_constraint( self.value, axis_resource) else: diff --git a/flax/linen/spmd.py b/flax/linen/spmd.py index 001c0947b9..c604ba6f66 100644 --- a/flax/linen/spmd.py +++ b/flax/linen/spmd.py @@ -210,8 +210,9 @@ def _with_sharding_constraint( if jax.devices()[0].platform == 'cpu' or (not _global_mesh_defined() and mesh is None): return x else: - if mesh is not None: - axis_resources = jax.sharding.NamedSharding(mesh, axis_resources) + if mesh is not None and axis_resources is not None: + sharding = jax.sharding.NamedSharding(mesh, axis_resources) + return pjit.with_sharding_constraint(x, sharding) return pjit.with_sharding_constraint(x, axis_resources) diff --git a/pytest.ini b/pytest.ini index f7f5833355..d6edc13edd 100644 --- a/pytest.ini +++ b/pytest.ini @@ -20,4 +20,9 @@ filterwarnings = # Deprecated sharding symbol ignore:jax.experimental.maps.Mesh is deprecated. Use jax.sharding.Mesh.*:DeprecationWarning # Deprecated legacy checkpoint - just want to keep the tests running for a while - ignore:Flax Checkpointing will soon be deprecated in favor of Orbax.*:DeprecationWarning \ No newline at end of file + ignore:Flax Checkpointing will soon be deprecated in favor of Orbax.*:DeprecationWarning +# Some Tensorflow IO error on 3/27/2023 + ignore:file system plugins are not loaded.*:UserWarning + ignore:unable to load libtensorflow_io_plugins.so.*:UserWarning +# Remove this after next Optax release after 3/27/2023 + ignore:jax.numpy.DeviceArray is deprecated. Use jax.Array.*:DeprecationWarning \ No newline at end of file diff --git a/setup.py b/setup.py index 19e350a7a8..eaa27ef864 100644 --- a/setup.py +++ b/setup.py @@ -42,6 +42,7 @@ tests_require = [ "atari-py==0.2.5", # Last version does not have the ROMs we test on pre-packaged "clu", # All examples. + "einops", "gym==0.18.3", "jaxlib", "jraph>=0.0.6dev0", diff --git a/tests/import_test.ipynb b/tests/import_test.ipynb new file mode 100644 index 0000000000..02bb884a0b --- /dev/null +++ b/tests/import_test.ipynb @@ -0,0 +1,58 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Test Import in Colab\n", + "\n", + "\"Run all\" to test that all the Flax imports work in head.\n", + "\n", + "Change runtime type as needed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "skip-execution" + ] + }, + "outputs": [], + "source": [ + "# Install from head\n", + "!pip install -q git+https://github.com/google/flax.git" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "import flax\n", + "from flax.training import (checkpoints, dynamic_scale, early_stopping, lr_schedule,\n", + " orbax_utils, prefetch_iterator, train_state, common_utils)\n", + "from flax.metrics import tensorboard" + ] + } + ], + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.15" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/run_all_tests.sh b/tests/run_all_tests.sh index 6b6692b44d..45820a4f76 100755 --- a/tests/run_all_tests.sh +++ b/tests/run_all_tests.sh @@ -106,6 +106,7 @@ if $RUN_PYTEST; then PYTEST_IGNORE+=" --ignore=$file" done # Run battery of core FLAX API tests. + echo "pytest -n auto tests $PYTEST_OPTS $PYTEST_IGNORE" pytest -n auto tests $PYTEST_OPTS $PYTEST_IGNORE # Per-example tests.