Skip to content

Commit

Permalink
Change relative imports into absolute imports.
Browse files Browse the repository at this point in the history
Remove redundant as <...> clauses from imports.

PiperOrigin-RevId: 583426671
  • Loading branch information
DrMarcII authored and Flax Authors committed Nov 17, 2023
1 parent 70214f4 commit 6bee7bd
Show file tree
Hide file tree
Showing 15 changed files with 345 additions and 358 deletions.
68 changes: 34 additions & 34 deletions flax/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,42 +12,42 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .axes_scan import broadcast as broadcast
from .frozen_dict import (
FrozenDict as FrozenDict,
copy as copy,
freeze as freeze,
pop as pop,
pretty_repr as pretty_repr,
unfreeze as unfreeze,
from flax.core.axes_scan import broadcast
from flax.core.frozen_dict import (
FrozenDict,
copy,
freeze,
pop,
pretty_repr,
unfreeze,
)
from .lift import (
custom_vjp as custom_vjp,
jit as jit,
jvp as jvp,
remat_scan as remat_scan,
remat as remat,
scan as scan,
vjp as vjp,
vmap as vmap,
while_loop as while_loop,
from flax.core.lift import (
custom_vjp,
jit,
jvp,
remat,
remat_scan,
scan,
vjp,
vmap,
while_loop,
)
from .meta import (
AxisMetadata as AxisMetadata,
map_axis_meta as map_axis_meta,
unbox as unbox,
from flax.core.meta import (
AxisMetadata,
map_axis_meta,
unbox,
)
from .scope import (
Array as Array,
DenyList as DenyList,
Scope as Scope,
apply as apply,
bind as bind,
init as init,
lazy_init as lazy_init,
from flax.core.scope import (
Array,
DenyList,
Scope,
apply,
bind,
init,
lazy_init,
)
from .tracers import (
check_trace_level as check_trace_level,
current_trace as current_trace,
trace_level as trace_level,
from flax.core.tracers import (
check_trace_level,
current_trace,
trace_level,
)
36 changes: 17 additions & 19 deletions flax/core/lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import collections
import dataclasses
import functools
import warnings
from typing import (
Any,
Callable,
Expand All @@ -32,27 +31,26 @@
TypeVar,
Union,
)

import jax
from jax import random
import warnings

from flax import traceback_util

from . import axes_scan, meta
from .frozen_dict import freeze, unfreeze
from .scope import (
CollectionFilter,
DenyList, # pylint: disable=g-multiple-import
Filter,
PRNGSequenceFilter,
Scope,
group_collections,
in_filter,
intersect_filters,
is_filter_empty,
subtract_filters,
union_filters,
from flax.core import axes_scan, meta
from flax.core.frozen_dict import freeze, unfreeze
from flax.core.scope import (
CollectionFilter,
DenyList, # pylint: disable=g-multiple-import
Filter,
PRNGSequenceFilter,
Scope,
group_collections,
in_filter,
intersect_filters,
is_filter_empty,
subtract_filters,
union_filters,
)
import jax
from jax import random

traceback_util.register_exclusion(__file__)

Expand Down
68 changes: 34 additions & 34 deletions flax/core/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,43 +16,43 @@

# pylint: disable=g-multiple-import
# re-export commonly used modules and functions
from flax.linen import activation as activation
from flax.linen import initializers as initializers
from flax.linen.activation import (
celu as celu,
elu as elu,
gelu as gelu,
glu as glu,
leaky_relu as leaky_relu,
log_sigmoid as log_sigmoid,
log_softmax as log_softmax,
relu as relu,
sigmoid as sigmoid,
silu as silu,
soft_sign as soft_sign,
softmax as softmax,
softplus as softplus,
swish as swish,
tanh as tanh,
from flax.core.nn.attention import (
dot_product_attention,
multi_head_dot_product_attention,
)
from flax.linen.pooling import (avg_pool as avg_pool, max_pool as max_pool)
from .attention import (
dot_product_attention as dot_product_attention,
multi_head_dot_product_attention as multi_head_dot_product_attention,
from flax.core.nn.linear import (
Embedding,
conv,
conv_transpose,
dense,
dense_general,
embedding,
)
from .linear import (
Embedding as Embedding,
conv_transpose as conv_transpose,
conv as conv,
dense_general as dense_general,
dense as dense,
embedding as embedding,
from flax.core.nn.normalization import (
batch_norm,
group_norm,
layer_norm,
)
from .normalization import (
batch_norm as batch_norm,
group_norm as group_norm,
layer_norm as layer_norm,
from flax.core.nn.stochastic import dropout
from flax.linen import activation
from flax.linen import initializers
from flax.linen.activation import (
celu,
elu,
gelu,
glu,
leaky_relu,
log_sigmoid,
log_softmax,
relu,
sigmoid,
silu,
soft_sign,
softmax,
softplus,
swish,
tanh,
)
from .stochastic import dropout as dropout
from flax.linen.pooling import (avg_pool, max_pool)

# pylint: enable=g-multiple-import
14 changes: 6 additions & 8 deletions flax/core/nn/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,18 @@

"""Attention core modules for Flax."""

import functools
from collections.abc import Iterable # pylint: disable=g-importing-member
import functools
from typing import Any, Callable, Union

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax, random

from flax import struct
from flax.core import Scope
from flax.core.nn.linear import default_kernel_init, dense_general
from flax.linen import initializers

from .linear import default_kernel_init, dense_general
import jax
from jax import lax, random
import jax.numpy as jnp
import numpy as np


def dot_product_attention(
Expand Down
16 changes: 7 additions & 9 deletions flax/core/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,16 @@
overload,
)

import jax
import numpy as np
from jax import numpy as jnp
from jax import random, tree_util

from flax import config as config
from flax import config
from flax import configurations as legacy_config # only for flax_lazy_rng
from flax import errors, struct, traceback_util
from flax.core import meta, partial_eval, tracers
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.ids import uuid

from . import meta, partial_eval, tracers
from .frozen_dict import FrozenDict, freeze, unfreeze
import jax
from jax import numpy as jnp
from jax import random, tree_util
import numpy as np

traceback_util.register_exclusion(__file__)

Expand Down
3 changes: 1 addition & 2 deletions flax/core/tracers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@

"""Functionality for inspecting jax tracers."""

from flax import errors
import jax

from .. import errors


def current_trace():
"""Returns the innermost Jax tracer."""
Expand Down
2 changes: 1 addition & 1 deletion flax/core/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,4 @@
TODO: Make "variable dict" design note, and link to it from here.
"""

from .scope import Variable
from flax.core.scope import Variable
Loading

0 comments on commit 6bee7bd

Please sign in to comment.