diff --git a/examples/linen_design_test/attention_simple.py b/examples/linen_design_test/attention_simple.py
index a883a29eb..fb5a9068f 100644
--- a/examples/linen_design_test/attention_simple.py
+++ b/examples/linen_design_test/attention_simple.py
@@ -14,15 +14,18 @@

 import functools
 from pprint import pprint
-from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Type, Union
-from flax.core import Scope
-from flax.core.frozen_dict import freeze, unfreeze
-from flax.deprecated.nn import initializers
-from flax.linen import Module, compact, vmap
+from typing import Any, Callable, Iterable, Sequence, List, Optional, Tuple, Type, Union
+
 import jax
-from jax import lax, numpy as jnp, random
+from jax import numpy as jnp, random, lax
 import numpy as np
+from flax.nn import initializers
+from flax.core.frozen_dict import freeze, unfreeze
+from flax.core import Scope
+
+from flax.linen import Module, compact, vmap
+

 class Dense(Module):
@@ -167,10 +170,10 @@ def __call__(self, inputs_q, inputs_kv, bias=None, dtype=jnp.float32):
                         dropout=(None, not self.broadcast_dropout),
                         axis_size=self.num_heads)
     for axis in reversed(sorted(self.batch_axes)):
-      Attn = concise_vmap(Attn,
-                          (axis, axis, axis), axis,
-                          param=(None, False),
-                          dropout=(None, not self.broadcast_dropout))
+      Attn = concise_vmap(Attn,
+                          (axis, axis, axis), axis,
+                          param=(None, False),
+                          dropout=(None, not self.broadcast_dropout))

     attn = Attn(attn_module=self.attn_module,
                 qkv_features=qkv_features // self.num_heads,
diff --git a/examples/linen_design_test/mlp_explicit.py b/examples/linen_design_test/mlp_explicit.py
index dec624b00..eb0f97dfc 100644
--- a/examples/linen_design_test/mlp_explicit.py
+++ b/examples/linen_design_test/mlp_explicit.py
@@ -12,15 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from pprint import pprint
+import jax
+from jax import numpy as jnp, random, lax
+from flax import nn
+from flax.nn import initializers
 from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union
-from flax.deprecated import nn
-from flax.deprecated.nn import initializers
-from dense import Dense
 from flax.linen import Module
-import jax
-from jax import lax, numpy as jnp, random
 import numpy as np
+from pprint import pprint
+from dense import Dense


 # Add `in_features` to the built-in Dense layer that normally works
diff --git a/flax/__init__.py b/flax/__init__.py
index 23a871db9..680456976 100644
--- a/flax/__init__.py
+++ b/flax/__init__.py
@@ -30,11 +30,11 @@

 """Flax API."""

+from .version import __version__
+
+# Allow `import flax`; `flax.optim.[...]`, etc
 from . import core
 from . import linen
+from . import nn
 from . import optim
-from .deprecated import nn
 # DO NOT REMOVE - Marker for internal logging.
-from .version import __version__
-
-# Allow `import flax`; `flax.optim.[...]`, etc
diff --git a/flax/core/nn/__init__.py b/flax/core/nn/__init__.py
index cd1467b37..98fb5f093 100644
--- a/flax/core/nn/__init__.py
+++ b/flax/core/nn/__init__.py
@@ -17,13 +17,12 @@
 # pylint: disable=g-multiple-import
 # re-export commonly used modules and functions
 from .attention import (dot_product_attention, multi_head_dot_product_attention)
-from flax.deprecated.nn import activation
-from flax.deprecated.nn import initializers
-from flax.deprecated.nn.activation import (celu, elu, gelu, glu, leaky_relu,
-                                           log_sigmoid, log_softmax, relu,
-                                           sigmoid, silu, soft_sign, softmax,
-                                           softplus, swish, tanh)
-from flax.deprecated.nn.pooling import avg_pool, max_pool
+from flax.nn import activation
+from flax.nn import initializers
+from flax.nn.activation import (celu, elu, gelu, glu, leaky_relu, log_sigmoid,
+                                log_softmax, relu, sigmoid, soft_sign, softmax,
+                                softplus, swish, silu, tanh)
+from flax.nn.pooling import avg_pool, max_pool
 from .linear import Embedding, conv, conv_transpose, dense, dense_general, embedding
 from .normalization import batch_norm, group_norm, layer_norm
 from .stochastic import dropout
diff --git a/flax/core/nn/attention.py b/flax/core/nn/attention.py
index 0cd3c1262..a1d73dbf1 100644
--- a/flax/core/nn/attention.py
+++ b/flax/core/nn/attention.py
@@ -16,14 +16,17 @@
 from collections.abc import Iterable  # pylint: disable=g-importing-member
 import functools
-from typing import Any
 import warnings
+from typing import Any
+
 from . import stochastic

 from flax import jax_utils
 from flax import struct
+from flax.nn import initializers
+
 from flax.core import Scope
-from flax.deprecated.nn import initializers
+
 import jax
 from jax import lax
 from jax import random
diff --git a/flax/core/nn/linear.py b/flax/core/nn/linear.py
index ae57840a2..8a34f3e0c 100644
--- a/flax/core/nn/linear.py
+++ b/flax/core/nn/linear.py
@@ -15,9 +15,13 @@
 """Linear modules."""

 from collections.abc import Iterable  # pylint: disable=g-importing-member
-from flax import struct
+
+from flax.nn import initializers
+
 from flax.core import Scope
-from flax.deprecated.nn import initializers
+
+from flax import struct
+
 from jax import lax
 import jax.numpy as jnp
diff --git a/flax/core/nn/normalization.py b/flax/core/nn/normalization.py
index a5f5ea143..3a12f165f 100644
--- a/flax/core/nn/normalization.py
+++ b/flax/core/nn/normalization.py
@@ -14,10 +14,10 @@
 """Normalization modules for Flax."""

-from flax.core import Scope
-from flax.deprecated.nn import initializers
 from jax import lax
+from flax.nn import initializers
 import jax.numpy as jnp
+from flax.core import Scope


 def _absolute_dims(rank, dims):
diff --git a/flax/deprecated/__init__.py b/flax/deprecated/__init__.py
deleted file mode 100644
index decb78691..000000000
--- a/flax/deprecated/__init__.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# Copyright 2021 The Flax Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
diff --git a/flax/deprecated/nn/__init__.py b/flax/nn/__init__.py
similarity index 100%
rename from flax/deprecated/nn/__init__.py
rename to flax/nn/__init__.py
diff --git a/flax/deprecated/nn/activation.py b/flax/nn/activation.py
similarity index 100%
rename from flax/deprecated/nn/activation.py
rename to flax/nn/activation.py
diff --git a/flax/deprecated/nn/attention.py b/flax/nn/attention.py
similarity index 98%
rename from flax/deprecated/nn/attention.py
rename to flax/nn/attention.py
index e520a96cd..40265f3f6 100644
--- a/flax/deprecated/nn/attention.py
+++ b/flax/nn/attention.py
@@ -19,12 +19,13 @@
 import warnings

 from flax import jax_utils
+from flax.nn.activation import softmax
+from flax.nn.base import Collection, Module, collection_from_iterable, iterate_collection
+from flax.nn.initializers import zeros
+from flax.nn.stochastic import make_rng
+from flax.nn.linear import DenseGeneral, default_kernel_init
 from flax import struct
-from flax.deprecated.nn.activation import softmax
-from flax.deprecated.nn.base import Collection, Module, collection_from_iterable, iterate_collection
-from flax.deprecated.nn.initializers import zeros
-from flax.deprecated.nn.linear import DenseGeneral, default_kernel_init
-from flax.deprecated.nn.stochastic import make_rng
+
 import jax
 from jax import lax
 from jax import random
diff --git a/flax/deprecated/nn/base.py b/flax/nn/base.py
similarity index 100%
rename from flax/deprecated/nn/base.py
rename to flax/nn/base.py
diff --git a/flax/deprecated/nn/initializers.py b/flax/nn/initializers.py
similarity index 100%
rename from flax/deprecated/nn/initializers.py
rename to flax/nn/initializers.py
diff --git a/flax/deprecated/nn/linear.py b/flax/nn/linear.py
similarity index 100%
rename from flax/deprecated/nn/linear.py
rename to flax/nn/linear.py
diff --git a/flax/deprecated/nn/normalization.py b/flax/nn/normalization.py
similarity index 100%
rename from flax/deprecated/nn/normalization.py
rename to flax/nn/normalization.py
diff --git a/flax/deprecated/nn/pooling.py b/flax/nn/pooling.py
similarity index 100%
rename from flax/deprecated/nn/pooling.py
rename to flax/nn/pooling.py
diff --git a/flax/deprecated/nn/recurrent.py b/flax/nn/recurrent.py
similarity index 100%
rename from flax/deprecated/nn/recurrent.py
rename to flax/nn/recurrent.py
diff --git a/flax/deprecated/nn/stochastic.py b/flax/nn/stochastic.py
similarity index 100%
rename from flax/deprecated/nn/stochastic.py
rename to flax/nn/stochastic.py
diff --git a/flax/deprecated/nn/utils.py b/flax/nn/utils.py
similarity index 100%
rename from flax/deprecated/nn/utils.py
rename to flax/nn/utils.py
diff --git a/flax/optim/base.py b/flax/optim/base.py
index a79a079cc..20c34c568 100644
--- a/flax/optim/base.py
+++ b/flax/optim/base.py
@@ -26,6 +26,8 @@
 import jax
 import jax.numpy as jnp

+from ..nn import base
+
 from ..core import FrozenDict, unfreeze

 # Backwards compatibility symbol import.
diff --git a/tests/nn_attention_test.py b/tests/nn_attention_test.py
index 19a7b6a16..ccfd2b9c7 100644
--- a/tests/nn_attention_test.py
+++ b/tests/nn_attention_test.py
@@ -16,8 +16,10 @@
 from absl.testing import absltest
 from absl.testing import parameterized
+
+from flax import nn
 from flax import jax_utils
-from flax.deprecated import nn
+
 import jax
 from jax import lax
 from jax import random
diff --git a/tests/nn_linear_test.py b/tests/nn_linear_test.py
index fe3c0c2fd..7288556ca 100644
--- a/tests/nn_linear_test.py
+++ b/tests/nn_linear_test.py
@@ -19,7 +19,7 @@
 from absl.testing import absltest
 from absl.testing import parameterized

-from flax.deprecated import nn
+from flax import nn

 import jax
 from jax import random
diff --git a/tests/nn_test.py b/tests/nn_test.py
index 3f1d89d54..434773322 100644
--- a/tests/nn_test.py
+++ b/tests/nn_test.py
@@ -17,7 +17,7 @@
 import threading

 from absl.testing import absltest
-from flax.deprecated import nn
+from flax import nn

 import jax
 from jax import random
@@ -288,7 +288,7 @@ def apply(self, x):
     _, params = MultiMethod.init(random.PRNGKey(0), x)
     model = nn.Model(MultiMethod, params)
     self.assertEqual(model.l2(), 2.)
-    
+
     y, _ = MultiMethodModel.init(random.PRNGKey(0), x)
     self.assertEqual(y, 2.)
@@ -695,8 +695,8 @@ def test_optimized_lstm_cell_matches_regular(self):
     self.assertEqual(c0.shape, (2, 4))
     self.assertEqual(h0.shape, (2, 4))
     (carry, y), initial_params = nn.LSTMCell.init(key2, (c0, h0), x)
-    lstm = nn.Model(nn.LSTMCell, initial_params)
-    
+    lstm = nn.Model(nn.LSTMCell, initial_params)
+
     # Create OptimizedLSTMCell.
     rng = random.PRNGKey(0)
     key1, key2 = random.split(rng)
@@ -706,9 +706,9 @@ def test_optimized_lstm_cell_matches_regular(self):
     self.assertEqual(h0.shape, (2, 4))
     (carry, y_opt), initial_params = nn.OptimizedLSTMCell.partial(
         name='LSTMCell').init(key2, (c0, h0), x)
-    lstm_opt = nn.Model(nn.OptimizedLSTMCell.partial(name='LSTMCell'),
+    lstm_opt = nn.Model(nn.OptimizedLSTMCell.partial(name='LSTMCell'),
                         initial_params)
-    
+
     np.testing.assert_allclose(y, y_opt, rtol=1e-6)
     jtu.check_eq(lstm.params, lstm_opt.params)
diff --git a/tests/optim_test.py b/tests/optim_test.py
index c34fb39c9..e383a9a37 100644
--- a/tests/optim_test.py
+++ b/tests/optim_test.py
@@ -16,15 +16,15 @@
 from functools import partial

 from absl.testing import absltest
+from flax import nn
 from flax import optim
 from flax import traverse_util
 from flax.core.frozen_dict import FrozenDict
-from flax.deprecated import nn
-from flax.optim.adabelief import _AdaBeliefHyperParams, _AdaBeliefParamState
 from flax.optim.adadelta import _AdadeltaHyperParams, _AdadeltaParamState
 from flax.optim.adafactor import _AdafactorHyperParams, _AdafactorParamState
 from flax.optim.adagrad import _AdagradHyperParams, _AdagradParamState
 from flax.optim.adam import _AdamHyperParams, _AdamParamState
+from flax.optim.adabelief import _AdaBeliefHyperParams, _AdaBeliefParamState
 from flax.optim.momentum import _MomentumHyperParams, _MomentumParamState
 from flax.optim.rmsprop import _RMSPropHyperParams, _RMSPropParamState
 from flax.optim.sgd import _GradientDescentHyperParams
diff --git a/tests/serialization_test.py b/tests/serialization_test.py
index 9de4bd247..f53955ddb 100644
--- a/tests/serialization_test.py
+++ b/tests/serialization_test.py
@@ -19,10 +19,11 @@
 from typing import Any

 from absl.testing import absltest
+from flax import nn
 from flax import optim
 from flax import serialization
 from flax import struct
-from flax.deprecated import nn
+
 import jax
 from jax import random
 import jax.numpy as jnp
@@ -244,7 +245,7 @@ def test_namedtuple_serialization(self):
     restored_x1 = serialization.from_bytes(x2, x1_serialized)
     self.assertEqual(type(x1), type(restored_x1))
     self.assertEqual(x1, restored_x1)
-    
+
   def test_namedtuple_restore_legacy(self):
     foo_class = collections.namedtuple('Foo', 'a b c')
     x1 = foo_class(a=1, b=2, c=3)
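
Net effect of the patch, as a quick smoke test: the pre-Linen API imports from `flax.nn` (and as `flax.nn` on the top-level `flax` package) again, rather than from `flax.deprecated.nn`. Below is a minimal sketch, assuming the old `Module.init` / `nn.Model` API that the updated tests above exercise; the layer size, input shape, and variable names are illustrative and not part of this diff.

    # Minimal sketch (assumed usage, not part of the patch): the pre-Linen API
    # resolves from its original import path after this change.
    import jax.numpy as jnp
    from jax import random
    from flax import nn              # previously: from flax.deprecated import nn
    from flax.nn import initializers  # same path updated throughout this patch

    x = jnp.ones((1, 4))
    # Old-style functional init: returns (output, params) for the module call.
    y, params = nn.Dense.init(random.PRNGKey(0), x, features=3)
    model = nn.Model(nn.Dense.partial(features=3), params)
    assert model(x).shape == (1, 3)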