Skip to content

Commit

Permalink
add LinearGeneral and MultiHeadAttention
Browse files: browse the repository at this point in the history
  • Loading branch information
cgarciae committed Nov 16, 2023
1 parent 091df13 commit a9e1902
Show file tree
Hide file tree
Showing 11 changed files with 920 additions and 24 deletions.
6 changes: 4 additions & 2 deletions flax/experimental/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from .nnx.dataclasses import treenode_field as treenode_field
from .nnx.dataclasses import variable_field as variable_field
from .nnx.errors import TraceContextError as TraceContextError
from .nnx.filterlib import All as All
from .nnx.filterlib import Not as Not
from .nnx.flaglib import flags as flags
from .nnx.helpers import Dict as Dict
from .nnx.helpers import Sequence as Sequence
Expand Down Expand Up @@ -58,14 +60,14 @@
from .nnx.nn.activations import standardize as standardize
from .nnx.nn.activations import swish as swish
from .nnx.nn.activations import tanh as tanh
from .nnx.nn.attention import MultiHeadAttention as MultiHeadAttention
from .nnx.nn.linear import Conv as Conv
from .nnx.nn.linear import Embed as Embed
from .nnx.nn.linear import Linear as Linear
from .nnx.nn.linear import LinearGeneral as LinearGeneral
from .nnx.nn.normalization import BatchNorm as BatchNorm
from .nnx.nn.normalization import LayerNorm as LayerNorm
from .nnx.nn.stochastic import Dropout as Dropout
from .nnx.filterlib import All as All
from .nnx.filterlib import Not as Not
from .nnx.pytreelib import Pytree as Pytree
from .nnx.pytreelib import TreeNode as TreeNode
from .nnx.rnglib import Rngs as Rngs
Expand Down
12 changes: 5 additions & 7 deletions flax/experimental/nnx/nnx/flaglib.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,18 @@

@dataclasses.dataclass
class FlagsContext(threading.local):
flags_stack: tp.List[MappingProxyType[str, tp.Hashable]] = dataclasses.field(
flags_stack: tp.List[MappingProxyType[str, tp.Any]] = dataclasses.field(
default_factory=lambda: [MappingProxyType({})]
)


FLAGS_CONTEXT = FlagsContext()


class Flags(tp.Mapping[str, tp.Hashable]):
class Flags(tp.Mapping[str, tp.Any]):
__slots__ = ()

def __getitem__(self, name: str) -> tp.Hashable:
def __getitem__(self, name: str) -> tp.Any:
current_flags = FLAGS_CONTEXT.flags_stack[-1]
if name not in current_flags:
raise ValueError(f'Unknown Flag: {name}')
Expand All @@ -50,7 +50,7 @@ def __contains__(self, name: tp.Any) -> bool:
return name in FLAGS_CONTEXT.flags_stack[-1]

@contextmanager
def __call__(self, **kwargs: tp.Hashable):
def __call__(self, **kwargs: tp.Any):
current_flags = FLAGS_CONTEXT.flags_stack[-1]
FLAGS_CONTEXT.flags_stack.append(
MappingProxyType(dict(current_flags, **kwargs))
Expand All @@ -60,9 +60,7 @@ def __call__(self, **kwargs: tp.Hashable):
finally:
FLAGS_CONTEXT.flags_stack.pop()

def get(
self, name: str, default: tp.Hashable = None
) -> tp.Optional[tp.Hashable]:
def get(self, name: str, default: tp.Any = None) -> tp.Optional[tp.Any]:
return FLAGS_CONTEXT.flags_stack[-1].get(name, default)


Expand Down
4 changes: 2 additions & 2 deletions flax/experimental/nnx/nnx/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ def __len__(self) -> int:


class Sequence(Module, tp.Generic[A]):
def __init__(self, iterable: tp.Iterable[A]):
def __init__(self, layers: tp.Iterable[A]):
i = 0
for i, value in enumerate(iterable):
for i, value in enumerate(layers):
setattr(self, str(i), value)
self._length = i + 1

Expand Down
12 changes: 3 additions & 9 deletions flax/experimental/nnx/nnx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,7 @@
import numpy as np
import typing_extensions as tpe

from flax.experimental.nnx.nnx import (
errors,
filterlib,
ids,
reprlib,
tracers,
)
from flax.experimental.nnx.nnx import errors, filterlib, ids, reprlib, tracers
from flax.experimental.nnx.nnx import variables as variableslib
from flax.experimental.nnx.nnx.rnglib import Rngs
from flax.experimental.nnx.nnx.state import State
Expand Down Expand Up @@ -917,12 +911,12 @@ def _update_module_static_state_recursive(
setattr(module, name, value)


def first_from(*args: tp.Optional[A]) -> A:
def first_from(arg_name: str, *args: tp.Optional[A]) -> A:
"""Return the first non-None argument."""
for arg in args:
if arg is not None:
return arg
raise ValueError('No non-None arguments found.')
raise ValueError(f'No non-None arguments found for {arg_name!r}')


def merge(
Expand Down
Loading

0 comments on commit a9e1902

Please sign in to comment.