From 6bee7bd239cdfc03e392acb219502a4782414f2c Mon Sep 17 00:00:00 2001 From: Marc Fisher Date: Fri, 17 Nov 2023 10:27:08 -0800 Subject: [PATCH] Change relative imports into absolute imports. Remove redundant as <...> clauses from imports. PiperOrigin-RevId: 583426671 --- flax/core/__init__.py | 68 +++--- flax/core/lift.py | 36 ++- flax/core/nn/__init__.py | 68 +++--- flax/core/nn/attention.py | 14 +- flax/core/scope.py | 16 +- flax/core/tracers.py | 3 +- flax/core/variables.py | 2 +- flax/experimental/nnx/__init__.py | 171 +++++++------- flax/experimental/nnx/nnx/nn/initializers.py | 32 +-- flax/io.py | 5 +- flax/linen/__init__.py | 232 +++++++++---------- flax/linen/initializers.py | 40 ++-- flax/struct.py | 5 +- flax/testing/__init__.py | 3 +- flax/traverse_util.py | 8 +- 15 files changed, 345 insertions(+), 358 deletions(-) diff --git a/flax/core/__init__.py b/flax/core/__init__.py index 105f0f6c..f9cd6867 100644 --- a/flax/core/__init__.py +++ b/flax/core/__init__.py @@ -12,42 +12,42 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .axes_scan import broadcast as broadcast -from .frozen_dict import ( - FrozenDict as FrozenDict, - copy as copy, - freeze as freeze, - pop as pop, - pretty_repr as pretty_repr, - unfreeze as unfreeze, +from flax.core.axes_scan import broadcast +from flax.core.frozen_dict import ( + FrozenDict, + copy, + freeze, + pop, + pretty_repr, + unfreeze, ) -from .lift import ( - custom_vjp as custom_vjp, - jit as jit, - jvp as jvp, - remat_scan as remat_scan, - remat as remat, - scan as scan, - vjp as vjp, - vmap as vmap, - while_loop as while_loop, +from flax.core.lift import ( + custom_vjp, + jit, + jvp, + remat, + remat_scan, + scan, + vjp, + vmap, + while_loop, ) -from .meta import ( - AxisMetadata as AxisMetadata, - map_axis_meta as map_axis_meta, - unbox as unbox, +from flax.core.meta import ( + AxisMetadata, + map_axis_meta, + unbox, ) -from .scope import ( - Array as Array, - DenyList as DenyList, - Scope as Scope, - apply as apply, - bind as bind, - init as init, - lazy_init as lazy_init, +from flax.core.scope import ( + Array, + DenyList, + Scope, + apply, + bind, + init, + lazy_init, ) -from .tracers import ( - check_trace_level as check_trace_level, - current_trace as current_trace, - trace_level as trace_level, +from flax.core.tracers import ( + check_trace_level, + current_trace, + trace_level, ) diff --git a/flax/core/lift.py b/flax/core/lift.py index 3f2cb90f..faa4f7a6 100644 --- a/flax/core/lift.py +++ b/flax/core/lift.py @@ -17,7 +17,6 @@ import collections import dataclasses import functools -import warnings from typing import ( Any, Callable, @@ -32,27 +31,26 @@ TypeVar, Union, ) - -import jax -from jax import random +import warnings from flax import traceback_util - -from . 
import axes_scan, meta -from .frozen_dict import freeze, unfreeze -from .scope import ( - CollectionFilter, - DenyList, # pylint: disable=g-multiple-import - Filter, - PRNGSequenceFilter, - Scope, - group_collections, - in_filter, - intersect_filters, - is_filter_empty, - subtract_filters, - union_filters, +from flax.core import axes_scan, meta +from flax.core.frozen_dict import freeze, unfreeze +from flax.core.scope import ( + CollectionFilter, + DenyList, # pylint: disable=g-multiple-import + Filter, + PRNGSequenceFilter, + Scope, + group_collections, + in_filter, + intersect_filters, + is_filter_empty, + subtract_filters, + union_filters, ) +import jax +from jax import random traceback_util.register_exclusion(__file__) diff --git a/flax/core/nn/__init__.py b/flax/core/nn/__init__.py index 7a7d430e..e00bac7f 100644 --- a/flax/core/nn/__init__.py +++ b/flax/core/nn/__init__.py @@ -16,43 +16,43 @@ # pylint: disable=g-multiple-import # re-export commonly used modules and functions -from flax.linen import activation as activation -from flax.linen import initializers as initializers -from flax.linen.activation import ( - celu as celu, - elu as elu, - gelu as gelu, - glu as glu, - leaky_relu as leaky_relu, - log_sigmoid as log_sigmoid, - log_softmax as log_softmax, - relu as relu, - sigmoid as sigmoid, - silu as silu, - soft_sign as soft_sign, - softmax as softmax, - softplus as softplus, - swish as swish, - tanh as tanh, +from flax.core.nn.attention import ( + dot_product_attention, + multi_head_dot_product_attention, ) -from flax.linen.pooling import (avg_pool as avg_pool, max_pool as max_pool) -from .attention import ( - dot_product_attention as dot_product_attention, - multi_head_dot_product_attention as multi_head_dot_product_attention, +from flax.core.nn.linear import ( + Embedding, + conv, + conv_transpose, + dense, + dense_general, + embedding, ) -from .linear import ( - Embedding as Embedding, - conv_transpose as conv_transpose, - conv as conv, - dense_general as dense_general, - dense as dense, - embedding as embedding, +from flax.core.nn.normalization import ( + batch_norm, + group_norm, + layer_norm, ) -from .normalization import ( - batch_norm as batch_norm, - group_norm as group_norm, - layer_norm as layer_norm, +from flax.core.nn.stochastic import dropout +from flax.linen import activation +from flax.linen import initializers +from flax.linen.activation import ( + celu, + elu, + gelu, + glu, + leaky_relu, + log_sigmoid, + log_softmax, + relu, + sigmoid, + silu, + soft_sign, + softmax, + softplus, + swish, + tanh, ) -from .stochastic import dropout as dropout +from flax.linen.pooling import (avg_pool, max_pool) # pylint: enable=g-multiple-import diff --git a/flax/core/nn/attention.py b/flax/core/nn/attention.py index d39aec87..3c999d1b 100644 --- a/flax/core/nn/attention.py +++ b/flax/core/nn/attention.py @@ -14,20 +14,18 @@ """Attention core modules for Flax.""" -import functools from collections.abc import Iterable # pylint: disable=g-importing-member +import functools from typing import Any, Callable, Union -import jax -import jax.numpy as jnp -import numpy as np -from jax import lax, random - from flax import struct from flax.core import Scope +from flax.core.nn.linear import default_kernel_init, dense_general from flax.linen import initializers - -from .linear import default_kernel_init, dense_general +import jax +from jax import lax, random +import jax.numpy as jnp +import numpy as np def dot_product_attention( diff --git a/flax/core/scope.py b/flax/core/scope.py index 
00f65547..ad7006e3 100644 --- a/flax/core/scope.py +++ b/flax/core/scope.py @@ -38,18 +38,16 @@ overload, ) -import jax -import numpy as np -from jax import numpy as jnp -from jax import random, tree_util - -from flax import config as config +from flax import config from flax import configurations as legacy_config # only for flax_lazy_rng from flax import errors, struct, traceback_util +from flax.core import meta, partial_eval, tracers +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.ids import uuid - -from . import meta, partial_eval, tracers -from .frozen_dict import FrozenDict, freeze, unfreeze +import jax +from jax import numpy as jnp +from jax import random, tree_util +import numpy as np traceback_util.register_exclusion(__file__) diff --git a/flax/core/tracers.py b/flax/core/tracers.py index ea831955..d7633507 100644 --- a/flax/core/tracers.py +++ b/flax/core/tracers.py @@ -14,10 +14,9 @@ """Functionality for inspecting jax tracers.""" +from flax import errors import jax -from .. import errors - def current_trace(): """Returns the innermost Jax tracer.""" diff --git a/flax/core/variables.py b/flax/core/variables.py index 4aef63d4..54be87ae 100644 --- a/flax/core/variables.py +++ b/flax/core/variables.py @@ -39,4 +39,4 @@ TODO: Make "variable dict" design note, and link to it from here. """ -from .scope import Variable +from flax.core.scope import Variable diff --git a/flax/experimental/nnx/__init__.py b/flax/experimental/nnx/__init__.py index efca7ba6..8c6d0554 100644 --- a/flax/experimental/nnx/__init__.py +++ b/flax/experimental/nnx/__init__.py @@ -12,89 +12,88 @@ # See the License for the specific language governing permissions and # limitations under the License. -from flax.linen.pooling import avg_pool as avg_pool -from flax.linen.pooling import max_pool as max_pool -from flax.linen.pooling import min_pool as min_pool -from flax.linen.pooling import pool as pool - -from .nnx import compatibility as compatibility -from .nnx import graph_utils -from .nnx.dataclasses import dataclass as dataclass -from .nnx.dataclasses import field as field -from .nnx.dataclasses import param_field as param_field -from .nnx.dataclasses import treenode_field as treenode_field -from .nnx.dataclasses import variable_field as variable_field -from .nnx.errors import TraceContextError as TraceContextError -from .nnx.filterlib import All as All -from .nnx.filterlib import Not as Not -from .nnx.flaglib import flags as flags -from .nnx.graph_utils import GraphDef as GraphDef -from .nnx.helpers import Dict as Dict -from .nnx.helpers import Sequence as Sequence -from .nnx.helpers import TrainState as TrainState -from .nnx.module import GraphDef as GraphDef -from .nnx.module import M as M -from .nnx.module import Module as Module -from .nnx.module import merge as merge -from .nnx.nn import initializers as initializers -from .nnx.nn.activations import celu as celu -from .nnx.nn.activations import elu as elu -from .nnx.nn.activations import gelu as gelu -from .nnx.nn.activations import glu as glu -from .nnx.nn.activations import hard_sigmoid as hard_sigmoid -from .nnx.nn.activations import hard_silu as hard_silu -from .nnx.nn.activations import hard_swish as hard_swish -from .nnx.nn.activations import hard_tanh as hard_tanh -from .nnx.nn.activations import leaky_relu as leaky_relu -from .nnx.nn.activations import log_sigmoid as log_sigmoid -from .nnx.nn.activations import log_softmax as log_softmax -from .nnx.nn.activations import logsumexp as logsumexp -from .nnx.nn.activations import 
normalize as normalize -from .nnx.nn.activations import one_hot as one_hot -from .nnx.nn.activations import relu as relu -from .nnx.nn.activations import relu6 as relu6 -from .nnx.nn.activations import selu as selu -from .nnx.nn.activations import sigmoid as sigmoid -from .nnx.nn.activations import silu as silu -from .nnx.nn.activations import soft_sign as soft_sign -from .nnx.nn.activations import softmax as softmax -from .nnx.nn.activations import softplus as softplus -from .nnx.nn.activations import standardize as standardize -from .nnx.nn.activations import swish as swish -from .nnx.nn.activations import tanh as tanh -from .nnx.nn.linear import Conv as Conv -from .nnx.nn.linear import Embed as Embed -from .nnx.nn.linear import Linear as Linear -from .nnx.nn.normalization import BatchNorm as BatchNorm -from .nnx.nn.normalization import LayerNorm as LayerNorm -from .nnx.nn.stochastic import Dropout as Dropout -from .nnx.pytreelib import Pytree as Pytree -from .nnx.pytreelib import TreeNode as TreeNode -from .nnx.rnglib import Rngs as Rngs -from .nnx.rnglib import RngStream as RngStream -from .nnx.spmd import PARTITION_NAME as PARTITION_NAME -from .nnx.spmd import get_partition_spec as get_partition_spec -from .nnx.spmd import with_partitioning as with_partitioning -from .nnx.spmd import with_sharding_constraint as with_sharding_constraint -from .nnx.state import State as State -from .nnx.transforms import JIT as JIT -from .nnx.transforms import Remat as Remat -from .nnx.transforms import Scan as Scan -from .nnx.transforms import Vmap as Vmap -from .nnx.transforms import grad as grad -from .nnx.transforms import jit as jit -from .nnx.transforms import remat as remat -from .nnx.transforms import scan as scan -from .nnx.transforms import value_and_grad as value_and_grad -from .nnx.transforms import vmap as vmap -from .nnx.variables import EMPTY as EMPTY -from .nnx.variables import A as A -from .nnx.variables import BatchStat as BatchStat -from .nnx.variables import Cache as Cache -from .nnx.variables import Empty as Empty -from .nnx.variables import Intermediate as Intermediate -from .nnx.variables import Param as Param -from .nnx.variables import Rng as Rng -from .nnx.variables import Variable as Variable -from .nnx.variables import VariableMetadata as VariableMetadata -from .nnx.variables import with_metadata as with_metadata +from flax.experimental.nnx.nnx import compatibility +from flax.experimental.nnx.nnx import graph_utils +from flax.experimental.nnx.nnx.dataclasses import dataclass +from flax.experimental.nnx.nnx.dataclasses import field +from flax.experimental.nnx.nnx.dataclasses import param_field +from flax.experimental.nnx.nnx.dataclasses import treenode_field +from flax.experimental.nnx.nnx.dataclasses import variable_field +from flax.experimental.nnx.nnx.errors import TraceContextError +from flax.experimental.nnx.nnx.filterlib import All +from flax.experimental.nnx.nnx.filterlib import Not +from flax.experimental.nnx.nnx.flaglib import flags +from flax.experimental.nnx.nnx.graph_utils import GraphDef +from flax.experimental.nnx.nnx.helpers import Dict +from flax.experimental.nnx.nnx.helpers import Sequence +from flax.experimental.nnx.nnx.helpers import TrainState +from flax.experimental.nnx.nnx.module import GraphDef +from flax.experimental.nnx.nnx.module import M +from flax.experimental.nnx.nnx.module import merge +from flax.experimental.nnx.nnx.module import Module +from flax.experimental.nnx.nnx.nn import initializers +from flax.experimental.nnx.nnx.nn.activations import 
celu +from flax.experimental.nnx.nnx.nn.activations import elu +from flax.experimental.nnx.nnx.nn.activations import gelu +from flax.experimental.nnx.nnx.nn.activations import glu +from flax.experimental.nnx.nnx.nn.activations import hard_sigmoid +from flax.experimental.nnx.nnx.nn.activations import hard_silu +from flax.experimental.nnx.nnx.nn.activations import hard_swish +from flax.experimental.nnx.nnx.nn.activations import hard_tanh +from flax.experimental.nnx.nnx.nn.activations import leaky_relu +from flax.experimental.nnx.nnx.nn.activations import log_sigmoid +from flax.experimental.nnx.nnx.nn.activations import log_softmax +from flax.experimental.nnx.nnx.nn.activations import logsumexp +from flax.experimental.nnx.nnx.nn.activations import normalize +from flax.experimental.nnx.nnx.nn.activations import one_hot +from flax.experimental.nnx.nnx.nn.activations import relu +from flax.experimental.nnx.nnx.nn.activations import relu6 +from flax.experimental.nnx.nnx.nn.activations import selu +from flax.experimental.nnx.nnx.nn.activations import sigmoid +from flax.experimental.nnx.nnx.nn.activations import silu +from flax.experimental.nnx.nnx.nn.activations import soft_sign +from flax.experimental.nnx.nnx.nn.activations import softmax +from flax.experimental.nnx.nnx.nn.activations import softplus +from flax.experimental.nnx.nnx.nn.activations import standardize +from flax.experimental.nnx.nnx.nn.activations import swish +from flax.experimental.nnx.nnx.nn.activations import tanh +from flax.experimental.nnx.nnx.nn.linear import Conv +from flax.experimental.nnx.nnx.nn.linear import Embed +from flax.experimental.nnx.nnx.nn.linear import Linear +from flax.experimental.nnx.nnx.nn.normalization import BatchNorm +from flax.experimental.nnx.nnx.nn.normalization import LayerNorm +from flax.experimental.nnx.nnx.nn.stochastic import Dropout +from flax.experimental.nnx.nnx.pytreelib import Pytree +from flax.experimental.nnx.nnx.pytreelib import TreeNode +from flax.experimental.nnx.nnx.rnglib import Rngs +from flax.experimental.nnx.nnx.rnglib import RngStream +from flax.experimental.nnx.nnx.spmd import get_partition_spec +from flax.experimental.nnx.nnx.spmd import PARTITION_NAME +from flax.experimental.nnx.nnx.spmd import with_partitioning +from flax.experimental.nnx.nnx.spmd import with_sharding_constraint +from flax.experimental.nnx.nnx.state import State +from flax.experimental.nnx.nnx.transforms import grad +from flax.experimental.nnx.nnx.transforms import JIT +from flax.experimental.nnx.nnx.transforms import jit +from flax.experimental.nnx.nnx.transforms import Remat +from flax.experimental.nnx.nnx.transforms import remat +from flax.experimental.nnx.nnx.transforms import Scan +from flax.experimental.nnx.nnx.transforms import scan +from flax.experimental.nnx.nnx.transforms import value_and_grad +from flax.experimental.nnx.nnx.transforms import Vmap +from flax.experimental.nnx.nnx.transforms import vmap +from flax.experimental.nnx.nnx.variables import A +from flax.experimental.nnx.nnx.variables import BatchStat +from flax.experimental.nnx.nnx.variables import Cache +from flax.experimental.nnx.nnx.variables import EMPTY +from flax.experimental.nnx.nnx.variables import Empty +from flax.experimental.nnx.nnx.variables import Intermediate +from flax.experimental.nnx.nnx.variables import Param +from flax.experimental.nnx.nnx.variables import Rng +from flax.experimental.nnx.nnx.variables import Variable +from flax.experimental.nnx.nnx.variables import VariableMetadata +from flax.experimental.nnx.nnx.variables 
import with_metadata +from flax.linen.pooling import avg_pool +from flax.linen.pooling import max_pool +from flax.linen.pooling import min_pool +from flax.linen.pooling import pool diff --git a/flax/experimental/nnx/nnx/nn/initializers.py b/flax/experimental/nnx/nnx/nn/initializers.py index 0e989c80..642ee9a0 100644 --- a/flax/experimental/nnx/nnx/nn/initializers.py +++ b/flax/experimental/nnx/nnx/nn/initializers.py @@ -15,23 +15,23 @@ import typing as tp import jax +from jax.nn.initializers import constant +from jax.nn.initializers import delta_orthogonal +from jax.nn.initializers import glorot_normal +from jax.nn.initializers import glorot_uniform +from jax.nn.initializers import he_normal +from jax.nn.initializers import he_uniform +from jax.nn.initializers import kaiming_normal +from jax.nn.initializers import kaiming_uniform +from jax.nn.initializers import lecun_normal +from jax.nn.initializers import lecun_uniform +from jax.nn.initializers import normal +from jax.nn.initializers import orthogonal +from jax.nn.initializers import uniform +from jax.nn.initializers import variance_scaling +from jax.nn.initializers import xavier_normal +from jax.nn.initializers import xavier_uniform import jax.numpy as jnp -from jax.nn.initializers import constant as constant -from jax.nn.initializers import delta_orthogonal as delta_orthogonal -from jax.nn.initializers import glorot_normal as glorot_normal -from jax.nn.initializers import glorot_uniform as glorot_uniform -from jax.nn.initializers import he_normal as he_normal -from jax.nn.initializers import he_uniform as he_uniform -from jax.nn.initializers import kaiming_normal as kaiming_normal -from jax.nn.initializers import kaiming_uniform as kaiming_uniform -from jax.nn.initializers import lecun_normal as lecun_normal -from jax.nn.initializers import lecun_uniform as lecun_uniform -from jax.nn.initializers import normal as normal -from jax.nn.initializers import orthogonal as orthogonal -from jax.nn.initializers import uniform as uniform -from jax.nn.initializers import variance_scaling as variance_scaling -from jax.nn.initializers import xavier_normal as xavier_normal -from jax.nn.initializers import xavier_uniform as xavier_uniform Shape = tp.Sequence[int] DTypeLikeInexact = tp.Any diff --git a/flax/io.py b/flax/io.py index 4cca5c87..17589952 100644 --- a/flax/io.py +++ b/flax/io.py @@ -17,15 +17,14 @@ as an open-source dependency solely for its tensorflow.io.gfile functions. """ import contextlib +from enum import Enum import glob as glob_module import importlib import os import shutil -from enum import Enum from absl import logging - -from . import errors +from flax import errors # Global Modes and selective import of tensorflow.io gfile. 
diff --git a/flax/linen/__init__.py b/flax/linen/__init__.py index 84188931..6e63306f 100644 --- a/flax/linen/__init__.py +++ b/flax/linen/__init__.py @@ -17,136 +17,136 @@ # pylint: disable=g-multiple-import,useless-import-alias # re-export commonly used modules and functions -from ..core import ( - DenyList as DenyList, - FrozenDict as FrozenDict, - broadcast as broadcast, - meta as meta, +from flax.core import ( + DenyList, + FrozenDict, + broadcast, + meta, ) -from ..core.meta import ( - PARTITION_NAME as PARTITION_NAME, - Partitioned as Partitioned, - get_partition_spec as get_partition_spec, - get_sharding as get_sharding, - unbox as unbox, - with_partitioning as with_partitioning, +from flax.core.meta import ( + PARTITION_NAME, + Partitioned, + get_partition_spec, + get_sharding, + unbox, + with_partitioning, ) -from .activation import ( - PReLU as PReLU, - celu as celu, - elu as elu, - gelu as gelu, - glu as glu, - hard_sigmoid as hard_sigmoid, - hard_silu as hard_silu, - hard_swish as hard_swish, - hard_tanh as hard_tanh, - leaky_relu as leaky_relu, - log_sigmoid as log_sigmoid, - log_softmax as log_softmax, - logsumexp as logsumexp, - normalize as normalize, - one_hot as one_hot, - relu6 as relu6, - relu as relu, - selu as selu, - sigmoid as sigmoid, - silu as silu, - soft_sign as soft_sign, - softmax as softmax, - softplus as softplus, - standardize as standardize, - swish as swish, - tanh as tanh, +from flax.linen.activation import ( + PReLU, + celu, + elu, + gelu, + glu, + hard_sigmoid, + hard_silu, + hard_swish, + hard_tanh, + leaky_relu, + log_sigmoid, + log_softmax, + logsumexp, + normalize, + one_hot, + relu, + relu6, + selu, + sigmoid, + silu, + soft_sign, + softmax, + softplus, + standardize, + swish, + tanh, ) -from .attention import ( - MultiHeadDotProductAttention as MultiHeadDotProductAttention, - SelfAttention as SelfAttention, - combine_masks as combine_masks, - dot_product_attention_weights as dot_product_attention_weights, - dot_product_attention as dot_product_attention, - make_attention_mask as make_attention_mask, - make_causal_mask as make_causal_mask, +from flax.linen.attention import ( + MultiHeadDotProductAttention, + SelfAttention, + combine_masks, + dot_product_attention, + dot_product_attention_weights, + make_attention_mask, + make_causal_mask, ) -from .combinators import Sequential as Sequential -from .fp8_ops import Fp8DotGeneralOp as Fp8DotGeneralOp -from .initializers import ( - ones_init as ones_init, - ones as ones, - zeros_init as zeros_init, - zeros as zeros, +from flax.linen.combinators import Sequential +from flax.linen.fp8_ops import Fp8DotGeneralOp +from flax.linen.initializers import ( + ones, + ones_init, + zeros, + zeros_init, ) -from .linear import ( - ConvLocal as ConvLocal, - ConvTranspose as ConvTranspose, - Conv as Conv, - DenseGeneral as DenseGeneral, - Dense as Dense, - Embed as Embed, +from flax.linen.linear import ( + Conv, + ConvLocal, + ConvTranspose, + Dense, + DenseGeneral, + Embed, ) -from .module import ( - Module as Module, - Variable as Variable, - apply as apply, - compact as compact, - disable_named_call as disable_named_call, - enable_named_call as enable_named_call, - init_with_output as init_with_output, - init as init, - intercept_methods as intercept_methods, - merge_param as merge_param, - nowrap as nowrap, - override_named_call as override_named_call, +from flax.linen.module import ( + Module, + Variable, + apply, + compact, + disable_named_call, + enable_named_call, + init, + init_with_output, + 
intercept_methods, + merge_param, + nowrap, + override_named_call, ) -from .normalization import ( - BatchNorm as BatchNorm, - GroupNorm as GroupNorm, - LayerNorm as LayerNorm, - RMSNorm as RMSNorm, - SpectralNorm as SpectralNorm, - WeightNorm as WeightNorm +from flax.linen.normalization import ( + BatchNorm, + GroupNorm, + LayerNorm, + RMSNorm, + SpectralNorm, + WeightNorm, ) -from .pooling import (avg_pool as avg_pool, max_pool as max_pool, pool as pool) -from .recurrent import ( - Bidirectional as Bidirectional, - ConvLSTMCell as ConvLSTMCell, - GRUCell as GRUCell, - MGUCell as MGUCell, - LSTMCell as LSTMCell, - OptimizedLSTMCell as OptimizedLSTMCell, - RNNCellBase as RNNCellBase, - RNN as RNN, +from flax.linen.pooling import (avg_pool, max_pool, pool) +from flax.linen.recurrent import ( + Bidirectional, + ConvLSTMCell, + GRUCell, + LSTMCell, + MGUCell, + OptimizedLSTMCell, + RNN, + RNNCellBase, ) -from .spmd import ( - LogicallyPartitioned as LogicallyPartitioned, - get_logical_axis_rules as get_logical_axis_rules, - logical_axis_rules as logical_axis_rules, +from flax.linen.spmd import ( + LogicallyPartitioned, + get_logical_axis_rules, + logical_axis_rules, logical_to_mesh, logical_to_mesh_axes, logical_to_mesh_sharding, - set_logical_axis_rules as set_logical_axis_rules, + set_logical_axis_rules, with_logical_constraint, - with_logical_partitioning as with_logical_partitioning, + with_logical_partitioning, ) -from .stochastic import Dropout as Dropout -from .summary import tabulate -from .transforms import ( +from flax.linen.stochastic import Dropout +from flax.linen.summary import tabulate +from flax.linen.transforms import ( add_metadata_axis, - checkpoint as checkpoint, - cond as cond, - custom_vjp as custom_vjp, - jit as jit, - jvp as jvp, - map_variables as map_variables, - named_call as named_call, - remat_scan as remat_scan, - remat as remat, - scan as scan, - switch as switch, - vjp as vjp, - grad as grad, - value_and_grad as value_and_grad, - vmap as vmap, - while_loop as while_loop, + checkpoint, + cond, + custom_vjp, + grad, + jit, + jvp, + map_variables, + named_call, + remat, + remat_scan, + scan, + switch, + value_and_grad, + vjp, + vmap, + while_loop, ) # pylint: enable=g-multiple-import diff --git a/flax/linen/initializers.py b/flax/linen/initializers.py index 412ab1cc..441268a3 100644 --- a/flax/linen/initializers.py +++ b/flax/linen/initializers.py @@ -16,26 +16,26 @@ # pylint: disable=unused-import # re-export initializer functions from jax.nn -from jax.nn.initializers import Initializer as Initializer -from jax.nn.initializers import constant as constant -from jax.nn.initializers import delta_orthogonal as delta_orthogonal -from jax.nn.initializers import glorot_normal as glorot_normal -from jax.nn.initializers import glorot_uniform as glorot_uniform -from jax.nn.initializers import he_normal as he_normal -from jax.nn.initializers import he_uniform as he_uniform -from jax.nn.initializers import kaiming_normal as kaiming_normal -from jax.nn.initializers import kaiming_uniform as kaiming_uniform -from jax.nn.initializers import lecun_normal as lecun_normal -from jax.nn.initializers import lecun_uniform as lecun_uniform -from jax.nn.initializers import normal as normal -from jax.nn.initializers import ones as ones -from jax.nn.initializers import orthogonal as orthogonal -from jax.nn.initializers import truncated_normal as truncated_normal -from jax.nn.initializers import uniform as uniform -from jax.nn.initializers import variance_scaling as variance_scaling -from 
jax.nn.initializers import xavier_normal as xavier_normal -from jax.nn.initializers import xavier_uniform as xavier_uniform -from jax.nn.initializers import zeros as zeros +from jax.nn.initializers import Initializer +from jax.nn.initializers import constant +from jax.nn.initializers import delta_orthogonal +from jax.nn.initializers import glorot_normal +from jax.nn.initializers import glorot_uniform +from jax.nn.initializers import he_normal +from jax.nn.initializers import he_uniform +from jax.nn.initializers import kaiming_normal +from jax.nn.initializers import kaiming_uniform +from jax.nn.initializers import lecun_normal +from jax.nn.initializers import lecun_uniform +from jax.nn.initializers import normal +from jax.nn.initializers import ones +from jax.nn.initializers import orthogonal +from jax.nn.initializers import truncated_normal +from jax.nn.initializers import uniform +from jax.nn.initializers import variance_scaling +from jax.nn.initializers import xavier_normal +from jax.nn.initializers import xavier_uniform +from jax.nn.initializers import zeros # pylint: enable=unused-import diff --git a/flax/struct.py b/flax/struct.py index 522b8163..11a23cae 100644 --- a/flax/struct.py +++ b/flax/struct.py @@ -17,13 +17,12 @@ import dataclasses from typing import TypeVar +from flax import serialization import jax from typing_extensions import ( - dataclass_transform, # pytype: disable=not-supported-yet + dataclass_transform, # pytype: disable=not-supported-yet ) -from . import serialization - _T = TypeVar('_T') diff --git a/flax/testing/__init__.py b/flax/testing/__init__.py index 794cc8d1..ea4ac715 100644 --- a/flax/testing/__init__.py +++ b/flax/testing/__init__.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - """Flax testing utilities.""" -from .benchmark import Benchmark +from flax.testing.benchmark import Benchmark diff --git a/flax/traverse_util.py b/flax/traverse_util.py index 1ff02c1f..835a5807 100644 --- a/flax/traverse_util.py +++ b/flax/traverse_util.py @@ -56,15 +56,13 @@ import abc import copy import dataclasses -import warnings from typing import Any, Callable, Tuple - -import jax +import warnings import flax +from flax import struct from flax.core.scope import VariableDict - -from . import struct +import jax Path = Tuple[str, ...]
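
The rewrite above is the same mechanical pattern in every touched file: drop the relative form and the self-aliasing `as <name>` clause in favor of a plain absolute import. A minimal sketch of the pattern, using the frozen_dict import from the flax/core/__init__.py hunk as the example (import style only; no APIs change):

    # before: relative import with a redundant self-alias
    from .frozen_dict import freeze as freeze

    # after: absolute import, alias dropped
    from flax.core.frozen_dict import freeze

The hunks also re-sort imports within each group and move jax imports after the flax ones (e.g. Conv before ConvLocal in flax/linen/__init__.py), which is why some otherwise unchanged lines move.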