
Commit e79a100

Internal change
PiperOrigin-RevId: 405841167
Flax Team committed Oct 27, 2021
1 parent a79dc33 commit e79a100
Showing 25 changed files with 64 additions and 65 deletions.
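The substance of the change is a package move: the legacy pre-Linen flax.nn API is relocated to flax.deprecated.nn, and the repository's internal imports are updated to the new path. As a rough, hypothetical sketch (not part of this commit), downstream code tracking this revision would adjust its imports along these lines; the symbols keep their names, only the module path changes:

    # Old spelling, removed throughout the diffs below:
    #   from flax import nn
    #   from flax.nn import initializers

    # New spelling, added throughout the diffs below:
    from flax.deprecated import nn
    from flax.deprecated.nn import initializers

    bias_init = initializers.zeros  # same symbol, new import path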
23 changes: 10 additions & 13 deletions examples/linen_design_test/attention_simple.py
@@ -14,17 +14,14 @@

import functools
from pprint import pprint
from typing import Any, Callable, Iterable, Sequence, List, Optional, Tuple, Type, Union

import jax
from jax import numpy as jnp, random, lax
import numpy as np

from flax.nn import initializers
from flax.core.frozen_dict import freeze, unfreeze
from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Type, Union
from flax.core import Scope

from flax.core.frozen_dict import freeze, unfreeze
from flax.deprecated.nn import initializers
from flax.linen import Module, compact, vmap
import jax
from jax import lax, numpy as jnp, random
import numpy as np



@@ -170,10 +167,10 @@ def __call__(self, inputs_q, inputs_kv, bias=None, dtype=jnp.float32):
dropout=(None, not self.broadcast_dropout),
axis_size=self.num_heads)
for axis in reversed(sorted(self.batch_axes)):
Attn = concise_vmap(Attn,
(axis, axis, axis), axis,
param=(None, False),
dropout=(None, not self.broadcast_dropout))
Attn = concise_vmap(Attn,
(axis, axis, axis), axis,
param=(None, False),
dropout=(None, not self.broadcast_dropout))

attn = Attn(attn_module=self.attn_module,
qkv_features=qkv_features // self.num_heads,
12 changes: 6 additions & 6 deletions examples/linen_design_test/mlp_explicit.py
@@ -12,15 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import jax
from jax import numpy as jnp, random, lax
from flax import nn
from flax.nn import initializers
from pprint import pprint
from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union
from flax.deprecated import nn
from flax.deprecated.nn import initializers
from dense import Dense
from flax.linen import Module
import jax
from jax import lax, numpy as jnp, random
import numpy as np
from pprint import pprint
from dense import Dense


# Add `in_features` to the built-in Dense layer that normally works
8 changes: 4 additions & 4 deletions flax/__init__.py
@@ -30,11 +30,11 @@

"""Flax API."""

from .version import __version__

# Allow `import flax`; `flax.optim.[...]`, etc
from . import core
from . import linen
from . import nn
from . import optim
from .deprecated import nn
# DO NOT REMOVE - Marker for internal logging.
from .version import __version__

# Allow `import flax`; `flax.optim.[...]`, etc
13 changes: 7 additions & 6 deletions flax/core/nn/__init__.py
@@ -17,12 +17,13 @@
# pylint: disable=g-multiple-import
# re-export commonly used modules and functions
from .attention import (dot_product_attention, multi_head_dot_product_attention)
from flax.nn import activation
from flax.nn import initializers
from flax.nn.activation import (celu, elu, gelu, glu, leaky_relu, log_sigmoid,
log_softmax, relu, sigmoid, soft_sign, softmax,
softplus, swish, silu, tanh)
from flax.nn.pooling import avg_pool, max_pool
from flax.deprecated.nn import activation
from flax.deprecated.nn import initializers
from flax.deprecated.nn.activation import (celu, elu, gelu, glu, leaky_relu,
log_sigmoid, log_softmax, relu,
sigmoid, silu, soft_sign, softmax,
softplus, swish, tanh)
from flax.deprecated.nn.pooling import avg_pool, max_pool
from .linear import Embedding, conv, conv_transpose, dense, dense_general, embedding
from .normalization import batch_norm, group_norm, layer_norm
from .stochastic import dropout
7 changes: 2 additions & 5 deletions flax/core/nn/attention.py
@@ -16,17 +16,14 @@

from collections.abc import Iterable # pylint: disable=g-importing-member
import functools
import warnings
from typing import Any

import warnings
from . import stochastic

from flax import jax_utils
from flax import struct
from flax.nn import initializers

from flax.core import Scope

from flax.deprecated.nn import initializers
import jax
from jax import lax
from jax import random
8 changes: 2 additions & 6 deletions flax/core/nn/linear.py
@@ -15,13 +15,9 @@
"""Linear modules."""

from collections.abc import Iterable # pylint: disable=g-importing-member

from flax.nn import initializers

from flax.core import Scope

from flax import struct

from flax.core import Scope
from flax.deprecated.nn import initializers
from jax import lax

import jax.numpy as jnp
4 changes: 2 additions & 2 deletions flax/core/nn/normalization.py
@@ -14,10 +14,10 @@

"""Normalization modules for Flax."""

from flax.core import Scope
from flax.deprecated.nn import initializers
from jax import lax
from flax.nn import initializers
import jax.numpy as jnp
from flax.core import Scope


def _absolute_dims(rank, dims):
14 changes: 14 additions & 0 deletions flax/deprecated/__init__.py
@@ -0,0 +1,14 @@
# Copyright 2021 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

File renamed without changes.
File renamed without changes.
11 changes: 5 additions & 6 deletions flax/nn/attention.py → flax/deprecated/nn/attention.py
@@ -19,13 +19,12 @@
import warnings

from flax import jax_utils
from flax.nn.activation import softmax
from flax.nn.base import Collection, Module, collection_from_iterable, iterate_collection
from flax.nn.initializers import zeros
from flax.nn.stochastic import make_rng
from flax.nn.linear import DenseGeneral, default_kernel_init
from flax import struct

from flax.deprecated.nn.activation import softmax
from flax.deprecated.nn.base import Collection, Module, collection_from_iterable, iterate_collection
from flax.deprecated.nn.initializers import zeros
from flax.deprecated.nn.linear import DenseGeneral, default_kernel_init
from flax.deprecated.nn.stochastic import make_rng
import jax
from jax import lax
from jax import random
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 0 additions & 2 deletions flax/optim/base.py
@@ -26,8 +26,6 @@
import jax
import jax.numpy as jnp

from ..nn import base

from ..core import FrozenDict, unfreeze

# Backwards compatibility symbol import.
4 changes: 1 addition & 3 deletions tests/nn_attention_test.py
@@ -16,10 +16,8 @@

from absl.testing import absltest
from absl.testing import parameterized

from flax import nn
from flax import jax_utils

from flax.deprecated import nn
import jax
from jax import lax
from jax import random
2 changes: 1 addition & 1 deletion tests/nn_linear_test.py
@@ -19,7 +19,7 @@
from absl.testing import absltest
from absl.testing import parameterized

from flax import nn
from flax.deprecated import nn

import jax
from jax import random
12 changes: 6 additions & 6 deletions tests/nn_test.py
@@ -17,7 +17,7 @@
import threading
from absl.testing import absltest

from flax import nn
from flax.deprecated import nn

import jax
from jax import random
@@ -288,7 +288,7 @@ def apply(self, x):
_, params = MultiMethod.init(random.PRNGKey(0), x)
model = nn.Model(MultiMethod, params)
self.assertEqual(model.l2(), 2.)

y, _ = MultiMethodModel.init(random.PRNGKey(0), x)
self.assertEqual(y, 2.)

@@ -695,8 +695,8 @@ def test_optimized_lstm_cell_matches_regular(self):
self.assertEqual(c0.shape, (2, 4))
self.assertEqual(h0.shape, (2, 4))
(carry, y), initial_params = nn.LSTMCell.init(key2, (c0, h0), x)
lstm = nn.Model(nn.LSTMCell, initial_params)
lstm = nn.Model(nn.LSTMCell, initial_params)

# Create OptimizedLSTMCell.
rng = random.PRNGKey(0)
key1, key2 = random.split(rng)
@@ -706,9 +706,9 @@
self.assertEqual(h0.shape, (2, 4))
(carry, y_opt), initial_params = nn.OptimizedLSTMCell.partial(
name='LSTMCell').init(key2, (c0, h0), x)
lstm_opt = nn.Model(nn.OptimizedLSTMCell.partial(name='LSTMCell'),
lstm_opt = nn.Model(nn.OptimizedLSTMCell.partial(name='LSTMCell'),
initial_params)

np.testing.assert_allclose(y, y_opt, rtol=1e-6)
jtu.check_eq(lstm.params, lstm_opt.params)

4 changes: 2 additions & 2 deletions tests/optim_test.py
@@ -16,15 +16,15 @@

from functools import partial
from absl.testing import absltest
from flax import nn
from flax import optim
from flax import traverse_util
from flax.core.frozen_dict import FrozenDict
from flax.deprecated import nn
from flax.optim.adabelief import _AdaBeliefHyperParams, _AdaBeliefParamState
from flax.optim.adadelta import _AdadeltaHyperParams, _AdadeltaParamState
from flax.optim.adafactor import _AdafactorHyperParams, _AdafactorParamState
from flax.optim.adagrad import _AdagradHyperParams, _AdagradParamState
from flax.optim.adam import _AdamHyperParams, _AdamParamState
from flax.optim.adabelief import _AdaBeliefHyperParams, _AdaBeliefParamState
from flax.optim.momentum import _MomentumHyperParams, _MomentumParamState
from flax.optim.rmsprop import _RMSPropHyperParams, _RMSPropParamState
from flax.optim.sgd import _GradientDescentHyperParams
5 changes: 2 additions & 3 deletions tests/serialization_test.py
@@ -19,11 +19,10 @@
from typing import Any

from absl.testing import absltest
from flax import nn
from flax import optim
from flax import serialization
from flax import struct

from flax.deprecated import nn
import jax
from jax import random
import jax.numpy as jnp
@@ -245,7 +244,7 @@ def test_namedtuple_serialization(self):
restored_x1 = serialization.from_bytes(x2, x1_serialized)
self.assertEqual(type(x1), type(restored_x1))
self.assertEqual(x1, restored_x1)

def test_namedtuple_restore_legacy(self):
foo_class = collections.namedtuple('Foo', 'a b c')
x1 = foo_class(a=1, b=2, c=3)
