From e79a100976eda1136bec1a7c8a90edd9c2a40cbc Mon Sep 17 00:00:00 2001 From: Flax Team Date: Wed, 27 Oct 2021 02:24:53 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 405841167 --- .../linen_design_test/attention_simple.py | 23 ++++++++----------- examples/linen_design_test/mlp_explicit.py | 12 +++++----- flax/__init__.py | 8 +++---- flax/core/nn/__init__.py | 13 ++++++----- flax/core/nn/attention.py | 7 ++---- flax/core/nn/linear.py | 8 ++----- flax/core/nn/normalization.py | 4 ++-- flax/deprecated/__init__.py | 14 +++++++++++ flax/{ => deprecated}/nn/__init__.py | 0 flax/{ => deprecated}/nn/activation.py | 0 flax/{ => deprecated}/nn/attention.py | 11 ++++----- flax/{ => deprecated}/nn/base.py | 0 flax/{ => deprecated}/nn/initializers.py | 0 flax/{ => deprecated}/nn/linear.py | 0 flax/{ => deprecated}/nn/normalization.py | 0 flax/{ => deprecated}/nn/pooling.py | 0 flax/{ => deprecated}/nn/recurrent.py | 0 flax/{ => deprecated}/nn/stochastic.py | 0 flax/{ => deprecated}/nn/utils.py | 0 flax/optim/base.py | 2 -- tests/nn_attention_test.py | 4 +--- tests/nn_linear_test.py | 2 +- tests/nn_test.py | 12 +++++----- tests/optim_test.py | 4 ++-- tests/serialization_test.py | 5 ++-- 25 files changed, 64 insertions(+), 65 deletions(-) create mode 100644 flax/deprecated/__init__.py rename flax/{ => deprecated}/nn/__init__.py (100%) rename flax/{ => deprecated}/nn/activation.py (100%) rename flax/{ => deprecated}/nn/attention.py (98%) rename flax/{ => deprecated}/nn/base.py (100%) rename flax/{ => deprecated}/nn/initializers.py (100%) rename flax/{ => deprecated}/nn/linear.py (100%) rename flax/{ => deprecated}/nn/normalization.py (100%) rename flax/{ => deprecated}/nn/pooling.py (100%) rename flax/{ => deprecated}/nn/recurrent.py (100%) rename flax/{ => deprecated}/nn/stochastic.py (100%) rename flax/{ => deprecated}/nn/utils.py (100%) diff --git a/examples/linen_design_test/attention_simple.py b/examples/linen_design_test/attention_simple.py index fb5a9068f6..a883a29eb5 100644 --- a/examples/linen_design_test/attention_simple.py +++ b/examples/linen_design_test/attention_simple.py @@ -14,17 +14,14 @@ import functools from pprint import pprint -from typing import Any, Callable, Iterable, Sequence, List, Optional, Tuple, Type, Union - -import jax -from jax import numpy as jnp, random, lax -import numpy as np - -from flax.nn import initializers -from flax.core.frozen_dict import freeze, unfreeze +from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Type, Union from flax.core import Scope - +from flax.core.frozen_dict import freeze, unfreeze +from flax.deprecated.nn import initializers from flax.linen import Module, compact, vmap +import jax +from jax import lax, numpy as jnp, random +import numpy as np @@ -170,10 +167,10 @@ def __call__(self, inputs_q, inputs_kv, bias=None, dtype=jnp.float32): dropout=(None, not self.broadcast_dropout), axis_size=self.num_heads) for axis in reversed(sorted(self.batch_axes)): - Attn = concise_vmap(Attn, - (axis, axis, axis), axis, - param=(None, False), - dropout=(None, not self.broadcast_dropout)) + Attn = concise_vmap(Attn, + (axis, axis, axis), axis, + param=(None, False), + dropout=(None, not self.broadcast_dropout)) attn = Attn(attn_module=self.attn_module, qkv_features=qkv_features // self.num_heads, diff --git a/examples/linen_design_test/mlp_explicit.py b/examples/linen_design_test/mlp_explicit.py index eb0f97dfc3..dec624b003 100644 --- a/examples/linen_design_test/mlp_explicit.py +++ b/examples/linen_design_test/mlp_explicit.py @@ -12,15 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import jax -from jax import numpy as jnp, random, lax -from flax import nn -from flax.nn import initializers +from pprint import pprint from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union +from flax.deprecated import nn +from flax.deprecated.nn import initializers +from dense import Dense from flax.linen import Module +import jax +from jax import lax, numpy as jnp, random import numpy as np -from pprint import pprint -from dense import Dense # Add `in_features` to the built-in Dense layer that normally works diff --git a/flax/__init__.py b/flax/__init__.py index 6804569765..23a871db96 100644 --- a/flax/__init__.py +++ b/flax/__init__.py @@ -30,11 +30,11 @@ """Flax API.""" -from .version import __version__ - -# Allow `import flax`; `flax.optim.[...]`, etc from . import core from . import linen -from . import nn from . import optim +from .deprecated import nn # DO NOT REMOVE - Marker for internal logging. +from .version import __version__ + +# Allow `import flax`; `flax.optim.[...]`, etc diff --git a/flax/core/nn/__init__.py b/flax/core/nn/__init__.py index 98fb5f0930..cd1467b37c 100644 --- a/flax/core/nn/__init__.py +++ b/flax/core/nn/__init__.py @@ -17,12 +17,13 @@ # pylint: disable=g-multiple-import # re-export commonly used modules and functions from .attention import (dot_product_attention, multi_head_dot_product_attention) -from flax.nn import activation -from flax.nn import initializers -from flax.nn.activation import (celu, elu, gelu, glu, leaky_relu, log_sigmoid, - log_softmax, relu, sigmoid, soft_sign, softmax, - softplus, swish, silu, tanh) -from flax.nn.pooling import avg_pool, max_pool +from flax.deprecated.nn import activation +from flax.deprecated.nn import initializers +from flax.deprecated.nn.activation import (celu, elu, gelu, glu, leaky_relu, + log_sigmoid, log_softmax, relu, + sigmoid, silu, soft_sign, softmax, + softplus, swish, tanh) +from flax.deprecated.nn.pooling import avg_pool, max_pool from .linear import Embedding, conv, conv_transpose, dense, dense_general, embedding from .normalization import batch_norm, group_norm, layer_norm from .stochastic import dropout diff --git a/flax/core/nn/attention.py b/flax/core/nn/attention.py index a1d73dbf15..0cd3c12626 100644 --- a/flax/core/nn/attention.py +++ b/flax/core/nn/attention.py @@ -16,17 +16,14 @@ from collections.abc import Iterable # pylint: disable=g-importing-member import functools -import warnings from typing import Any - +import warnings from . import stochastic from flax import jax_utils from flax import struct -from flax.nn import initializers - from flax.core import Scope - +from flax.deprecated.nn import initializers import jax from jax import lax from jax import random diff --git a/flax/core/nn/linear.py b/flax/core/nn/linear.py index 8a34f3e0cc..ae57840a2f 100644 --- a/flax/core/nn/linear.py +++ b/flax/core/nn/linear.py @@ -15,13 +15,9 @@ """Linear modules.""" from collections.abc import Iterable # pylint: disable=g-importing-member - -from flax.nn import initializers - -from flax.core import Scope - from flax import struct - +from flax.core import Scope +from flax.deprecated.nn import initializers from jax import lax import jax.numpy as jnp diff --git a/flax/core/nn/normalization.py b/flax/core/nn/normalization.py index 3a12f165fa..a5f5ea143c 100644 --- a/flax/core/nn/normalization.py +++ b/flax/core/nn/normalization.py @@ -14,10 +14,10 @@ """Normalization modules for Flax.""" +from flax.core import Scope +from flax.deprecated.nn import initializers from jax import lax -from flax.nn import initializers import jax.numpy as jnp -from flax.core import Scope def _absolute_dims(rank, dims): diff --git a/flax/deprecated/__init__.py b/flax/deprecated/__init__.py new file mode 100644 index 0000000000..decb786913 --- /dev/null +++ b/flax/deprecated/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2021 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/flax/nn/__init__.py b/flax/deprecated/nn/__init__.py similarity index 100% rename from flax/nn/__init__.py rename to flax/deprecated/nn/__init__.py diff --git a/flax/nn/activation.py b/flax/deprecated/nn/activation.py similarity index 100% rename from flax/nn/activation.py rename to flax/deprecated/nn/activation.py diff --git a/flax/nn/attention.py b/flax/deprecated/nn/attention.py similarity index 98% rename from flax/nn/attention.py rename to flax/deprecated/nn/attention.py index 40265f3f6e..e520a96cd3 100644 --- a/flax/nn/attention.py +++ b/flax/deprecated/nn/attention.py @@ -19,13 +19,12 @@ import warnings from flax import jax_utils -from flax.nn.activation import softmax -from flax.nn.base import Collection, Module, collection_from_iterable, iterate_collection -from flax.nn.initializers import zeros -from flax.nn.stochastic import make_rng -from flax.nn.linear import DenseGeneral, default_kernel_init from flax import struct - +from flax.deprecated.nn.activation import softmax +from flax.deprecated.nn.base import Collection, Module, collection_from_iterable, iterate_collection +from flax.deprecated.nn.initializers import zeros +from flax.deprecated.nn.linear import DenseGeneral, default_kernel_init +from flax.deprecated.nn.stochastic import make_rng import jax from jax import lax from jax import random diff --git a/flax/nn/base.py b/flax/deprecated/nn/base.py similarity index 100% rename from flax/nn/base.py rename to flax/deprecated/nn/base.py diff --git a/flax/nn/initializers.py b/flax/deprecated/nn/initializers.py similarity index 100% rename from flax/nn/initializers.py rename to flax/deprecated/nn/initializers.py diff --git a/flax/nn/linear.py b/flax/deprecated/nn/linear.py similarity index 100% rename from flax/nn/linear.py rename to flax/deprecated/nn/linear.py diff --git a/flax/nn/normalization.py b/flax/deprecated/nn/normalization.py similarity index 100% rename from flax/nn/normalization.py rename to flax/deprecated/nn/normalization.py diff --git a/flax/nn/pooling.py b/flax/deprecated/nn/pooling.py similarity index 100% rename from flax/nn/pooling.py rename to flax/deprecated/nn/pooling.py diff --git a/flax/nn/recurrent.py b/flax/deprecated/nn/recurrent.py similarity index 100% rename from flax/nn/recurrent.py rename to flax/deprecated/nn/recurrent.py diff --git a/flax/nn/stochastic.py b/flax/deprecated/nn/stochastic.py similarity index 100% rename from flax/nn/stochastic.py rename to flax/deprecated/nn/stochastic.py diff --git a/flax/nn/utils.py b/flax/deprecated/nn/utils.py similarity index 100% rename from flax/nn/utils.py rename to flax/deprecated/nn/utils.py diff --git a/flax/optim/base.py b/flax/optim/base.py index 20c34c5688..a79a079ccd 100644 --- a/flax/optim/base.py +++ b/flax/optim/base.py @@ -26,8 +26,6 @@ import jax import jax.numpy as jnp -from ..nn import base - from ..core import FrozenDict, unfreeze # Backwards compatibility symbol import. diff --git a/tests/nn_attention_test.py b/tests/nn_attention_test.py index ccfd2b9c76..19a7b6a166 100644 --- a/tests/nn_attention_test.py +++ b/tests/nn_attention_test.py @@ -16,10 +16,8 @@ from absl.testing import absltest from absl.testing import parameterized - -from flax import nn from flax import jax_utils - +from flax.deprecated import nn import jax from jax import lax from jax import random diff --git a/tests/nn_linear_test.py b/tests/nn_linear_test.py index 7288556ca5..fe3c0c2fd9 100644 --- a/tests/nn_linear_test.py +++ b/tests/nn_linear_test.py @@ -19,7 +19,7 @@ from absl.testing import absltest from absl.testing import parameterized -from flax import nn +from flax.deprecated import nn import jax from jax import random diff --git a/tests/nn_test.py b/tests/nn_test.py index 4347733223..3f1d89d543 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -17,7 +17,7 @@ import threading from absl.testing import absltest -from flax import nn +from flax.deprecated import nn import jax from jax import random @@ -288,7 +288,7 @@ def apply(self, x): _, params = MultiMethod.init(random.PRNGKey(0), x) model = nn.Model(MultiMethod, params) self.assertEqual(model.l2(), 2.) - + y, _ = MultiMethodModel.init(random.PRNGKey(0), x) self.assertEqual(y, 2.) @@ -695,8 +695,8 @@ def test_optimized_lstm_cell_matches_regular(self): self.assertEqual(c0.shape, (2, 4)) self.assertEqual(h0.shape, (2, 4)) (carry, y), initial_params = nn.LSTMCell.init(key2, (c0, h0), x) - lstm = nn.Model(nn.LSTMCell, initial_params) - + lstm = nn.Model(nn.LSTMCell, initial_params) + # Create OptimizedLSTMCell. rng = random.PRNGKey(0) key1, key2 = random.split(rng) @@ -706,9 +706,9 @@ def test_optimized_lstm_cell_matches_regular(self): self.assertEqual(h0.shape, (2, 4)) (carry, y_opt), initial_params = nn.OptimizedLSTMCell.partial( name='LSTMCell').init(key2, (c0, h0), x) - lstm_opt = nn.Model(nn.OptimizedLSTMCell.partial(name='LSTMCell'), + lstm_opt = nn.Model(nn.OptimizedLSTMCell.partial(name='LSTMCell'), initial_params) - + np.testing.assert_allclose(y, y_opt, rtol=1e-6) jtu.check_eq(lstm.params, lstm_opt.params) diff --git a/tests/optim_test.py b/tests/optim_test.py index e383a9a372..c34fb39c9a 100644 --- a/tests/optim_test.py +++ b/tests/optim_test.py @@ -16,15 +16,15 @@ from functools import partial from absl.testing import absltest -from flax import nn from flax import optim from flax import traverse_util from flax.core.frozen_dict import FrozenDict +from flax.deprecated import nn +from flax.optim.adabelief import _AdaBeliefHyperParams, _AdaBeliefParamState from flax.optim.adadelta import _AdadeltaHyperParams, _AdadeltaParamState from flax.optim.adafactor import _AdafactorHyperParams, _AdafactorParamState from flax.optim.adagrad import _AdagradHyperParams, _AdagradParamState from flax.optim.adam import _AdamHyperParams, _AdamParamState -from flax.optim.adabelief import _AdaBeliefHyperParams, _AdaBeliefParamState from flax.optim.momentum import _MomentumHyperParams, _MomentumParamState from flax.optim.rmsprop import _RMSPropHyperParams, _RMSPropParamState from flax.optim.sgd import _GradientDescentHyperParams diff --git a/tests/serialization_test.py b/tests/serialization_test.py index f53955ddba..9de4bd247d 100644 --- a/tests/serialization_test.py +++ b/tests/serialization_test.py @@ -19,11 +19,10 @@ from typing import Any from absl.testing import absltest -from flax import nn from flax import optim from flax import serialization from flax import struct - +from flax.deprecated import nn import jax from jax import random import jax.numpy as jnp @@ -245,7 +244,7 @@ def test_namedtuple_serialization(self): restored_x1 = serialization.from_bytes(x2, x1_serialized) self.assertEqual(type(x1), type(restored_x1)) self.assertEqual(x1, restored_x1) - + def test_namedtuple_restore_legacy(self): foo_class = collections.namedtuple('Foo', 'a b c') x1 = foo_class(a=1, b=2, c=3)