Skip to content

Commit

Permalink
add LinearGeneral and MultiHeadAttention
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Nov 23, 2023
1 parent e172c76 commit 7ff6e4b
Show file tree
Hide file tree
Showing 13 changed files with 926 additions and 21 deletions.
6 changes: 3 additions & 3 deletions examples/lm1b/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def fill_unspecified_mesh_axes(
" parallelism axis. At most one axis can be unspecified."
)

determined_val = target_product / np.product(parallelism_vals) * -1
determined_val = target_product / np.prod(parallelism_vals) * -1

assert determined_val >= 1 and determined_val.is_integer, (
"Unspecified value unable to be determined with the given "
Expand All @@ -97,9 +97,9 @@ def fill_unspecified_mesh_axes(

target_type = "slices" if parallelism_type == "DCN" else "devices per slice"

assert np.product(parallelism_vals) == target_product, (
assert np.prod(parallelism_vals) == target_product, (
f"Number of {target_type} {target_product} does not match the product"
f" of the {parallelism_type} parallelism {np.product(parallelism_vals)}"
f" of the {parallelism_type} parallelism {np.prod(parallelism_vals)}"
)

return parallelism_vals
Expand Down
2 changes: 2 additions & 0 deletions flax/experimental/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,11 @@
from .nnx.nn.activations import standardize as standardize
from .nnx.nn.activations import swish as swish
from .nnx.nn.activations import tanh as tanh
from .nnx.nn.attention import MultiHeadAttention as MultiHeadAttention
from .nnx.nn.linear import Conv as Conv
from .nnx.nn.linear import Embed as Embed
from .nnx.nn.linear import Linear as Linear
from .nnx.nn.linear import LinearGeneral as LinearGeneral
from .nnx.nn.normalization import BatchNorm as BatchNorm
from .nnx.nn.normalization import LayerNorm as LayerNorm
from .nnx.nn.stochastic import Dropout as Dropout
Expand Down
12 changes: 5 additions & 7 deletions flax/experimental/nnx/nnx/flaglib.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,18 @@

@dataclasses.dataclass
class FlagsContext(threading.local):
  """Thread-local holder for the stack of flag scopes.

  Each entry is an immutable mapping of flag name -> value; the innermost
  (last) entry is the currently active set of flags. Subclassing
  ``threading.local`` gives every thread its own independent stack.

  NOTE(review): the diff rendering left both the old (``tp.Hashable``) and
  new (``tp.Any``) field lines in place, which is a syntax error; this is
  the post-commit version, which widens flag values to ``tp.Any``.
  """

  # Starts with a single empty scope so flag lookups never hit an empty stack.
  flags_stack: tp.List[MappingProxyType[str, tp.Any]] = dataclasses.field(
    default_factory=lambda: [MappingProxyType({})]
  )


FLAGS_CONTEXT = FlagsContext()


class Flags(tp.Mapping[str, tp.Hashable]):
class Flags(tp.Mapping[str, tp.Any]):
__slots__ = ()

def __getitem__(self, name: str) -> tp.Hashable:
def __getitem__(self, name: str) -> tp.Any:
current_flags = FLAGS_CONTEXT.flags_stack[-1]
if name not in current_flags:
raise ValueError(f'Unknown Flag: {name}')
Expand All @@ -50,7 +50,7 @@ def __contains__(self, name: tp.Any) -> bool:
return name in FLAGS_CONTEXT.flags_stack[-1]

@contextmanager
def __call__(self, **kwargs: tp.Hashable):
def __call__(self, **kwargs: tp.Any):
current_flags = FLAGS_CONTEXT.flags_stack[-1]
FLAGS_CONTEXT.flags_stack.append(
MappingProxyType(dict(current_flags, **kwargs))
Expand All @@ -60,9 +60,7 @@ def __call__(self, **kwargs: tp.Hashable):
finally:
FLAGS_CONTEXT.flags_stack.pop()

def get(self, name: str, default: tp.Any = None) -> tp.Optional[tp.Any]:
  """Return the value of flag ``name`` in the innermost scope.

  Falls back to ``default`` (``None`` unless given) when the flag is not
  set, mirroring ``Mapping.get`` — unlike ``__getitem__``, which raises
  ``ValueError`` for unknown flags.

  NOTE(review): reconstructed from a diff that showed both the old
  three-line and new one-line signatures; this is the post-commit version.
  """
  # Only the top of the stack (the current scope) is consulted; outer
  # scopes were already merged in by __call__ when the scope was pushed.
  return FLAGS_CONTEXT.flags_stack[-1].get(name, default)


Expand Down
4 changes: 2 additions & 2 deletions flax/experimental/nnx/nnx/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ def __len__(self) -> int:


class Sequence(Module, tp.Generic[A]):
def __init__(self, layers: tp.Iterable[A]):
  """Store each element of ``layers`` as a numbered attribute.

  Elements become attributes named ``'0'``, ``'1'``, ... so the module
  system can discover them, and ``self._length`` records how many were
  stored.

  Args:
    layers: the values (typically sub-modules) to hold, in order.
  """
  # Start at -1 so that an empty iterable yields _length == 0.
  # (The original initialized i = 0, which made an empty Sequence
  # report a length of 1 — off-by-one on the empty edge case.)
  i = -1
  for i, value in enumerate(layers):
    setattr(self, str(i), value)
  self._length = i + 1

Expand Down
7 changes: 2 additions & 5 deletions flax/experimental/nnx/nnx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,15 +519,12 @@ def _module_graph_init(node: Module, items: tuple[tuple[str, tp.Any], ...]):
vars(node).update(items)


# -------------------------
# utils
# -------------------------
def first_from(arg_name: str, *args: tp.Optional[A]) -> A:
  """Return the first non-None value among ``args``.

  Used to resolve a setting that may be supplied at several levels
  (call-site argument, module attribute, global default), taking the
  highest-priority one that was actually provided.

  Args:
    arg_name: name of the logical argument being resolved; only used to
      produce an informative error message.
    *args: candidate values in priority order; ``None`` means "not given".

  Returns:
    The first element of ``args`` that is not ``None``.

  Raises:
    ValueError: if every element of ``args`` is ``None``.
  """
  for arg in args:
    if arg is not None:
      return arg
  raise ValueError(f'No non-None arguments found for {arg_name!r}')


def merge(
Expand Down
Loading

0 comments on commit 7ff6e4b

Please sign in to comment.