Skip to content

Commit

Permalink
Merge pull request #2987 from IvyZX:push
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 520124550
  • Loading branch information
Flax Authors committed Mar 28, 2023
2 parents 8957389 + b8c89e8 commit e061c69
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 5 deletions.
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,19 @@ vNext
-
-

0.6.8
-----
- The automatic checkpoint migration was temporarily rolled back due to legacy compatibility issues.
- We still recommend you to use the [upgrade guide](https://flax.readthedocs.io/en/latest/guides/orbax_upgrade_guide.html) and migrate completely to the Orbax API to ensure stability.
- Or alternatively, add `flax.config.update('flax_use_orbax_checkpointing', True)` to your project to avoid being impacted by the automatic migration process.
- Added utility functions to frozen_dict api.
- Migrated Flax away from `register_keypaths`.
- Fixes kwargs in convert_to_graphs_tuple_fn.
- Fixed examples in a few ways:
- Bumped the TF version
- Used latest checkpoint formats
- Other misc fixes.

0.6.7
-----
- New checkpoints will be saved using Orbax! Please check out [upgrade guide](https://flax.readthedocs.io/en/latest/guides/orbax_upgrade_guide.html) and consider migrating completely to the Orbax API.
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ To cite this repository:
author = {Jonathan Heek and Anselm Levskaya and Avital Oliver and Marvin Ritter and Bertrand Rondepierre and Andreas Steiner and Marc van {Z}ee},
title = {{F}lax: A neural network library and ecosystem for {JAX}},
url = {http://github.com/google/flax},
version = {0.6.7},
version = {0.6.8},
year = {2023},
}
```
Expand Down
3 changes: 2 additions & 1 deletion flax/core/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,8 @@ def unbox(self, apply_constraint=True) -> Any:
if apply_constraint and (_global_mesh_defined() or self.mesh is not None):
axis_resource = self.get_partition_spec()
if self.mesh is not None:
axis_resource = jax.sharding.NamedSharding(self.mesh, axis_resource)
sharding = jax.sharding.NamedSharding(self.mesh, axis_resource)
return jax.lax.with_sharding_constraint(self.value, sharding)
return jax.lax.with_sharding_constraint(
self.value, axis_resource)
else:
Expand Down
5 changes: 3 additions & 2 deletions flax/linen/spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,9 @@ def _with_sharding_constraint(
if jax.devices()[0].platform == 'cpu' or (not _global_mesh_defined() and mesh is None):
return x
else:
if mesh is not None:
axis_resources = jax.sharding.NamedSharding(mesh, axis_resources)
if mesh is not None and axis_resources is not None:
sharding = jax.sharding.NamedSharding(mesh, axis_resources)
return pjit.with_sharding_constraint(x, sharding)
return pjit.with_sharding_constraint(x, axis_resources)


Expand Down
7 changes: 6 additions & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,9 @@ filterwarnings =
# Deprecated sharding symbol
ignore:jax.experimental.maps.Mesh is deprecated. Use jax.sharding.Mesh.*:DeprecationWarning
# Deprecated legacy checkpoint - just want to keep the tests running for a while
ignore:Flax Checkpointing will soon be deprecated in favor of Orbax.*:DeprecationWarning
ignore:Flax Checkpointing will soon be deprecated in favor of Orbax.*:DeprecationWarning
# Some Tensorflow IO error on 3/27/2023
ignore:file system plugins are not loaded.*:UserWarning
ignore:unable to load libtensorflow_io_plugins.so.*:UserWarning
# Remove this after next Optax release after 3/27/2023
ignore:jax.numpy.DeviceArray is deprecated. Use jax.Array.*:DeprecationWarning
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
tests_require = [
"atari-py==0.2.5", # Last version does not have the ROMs we test on pre-packaged
"clu", # All examples.
"einops",
"gym==0.18.3",
"jaxlib",
"jraph>=0.0.6dev0",
Expand Down
58 changes: 58 additions & 0 deletions tests/import_test.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Test Import in Colab\n",
"\n",
"\"Run all\" to test that all the Flax imports work in head.\n",
"\n",
"Change runtime type as needed."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"skip-execution"
]
},
"outputs": [],
"source": [
"# Install from head\n",
"!pip install -q git+https://github.com/google/flax.git"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"import flax\n",
"from flax.training import (checkpoints, dynamic_scale, early_stopping, lr_schedule,\n",
" orbax_utils, prefetch_iterator, train_state, common_utils)\n",
"from flax.metrics import tensorboard"
]
}
],
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
1 change: 1 addition & 0 deletions tests/run_all_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ if $RUN_PYTEST; then
PYTEST_IGNORE+=" --ignore=$file"
done
# Run battery of core FLAX API tests.
echo "pytest -n auto tests $PYTEST_OPTS $PYTEST_IGNORE"
pytest -n auto tests $PYTEST_OPTS $PYTEST_IGNORE

# Per-example tests.
Expand Down

0 comments on commit e061c69

Please sign in to comment.