Skip to content

Commit

Permalink
simplify Rngs
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Oct 28, 2023
1 parent 7bddd53 commit 4d86767
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 22 deletions.
3 changes: 3 additions & 0 deletions flax/experimental/nnx/nnx/rngslib.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ def __getitem__(self, name: str) -> tp.Callable[[], jax.Array]:

__getattr__ = __getitem__

def __call__(self):
return self.params()

def __iter__(self) -> tp.Iterator[str]:
return iter(self._rngs)

Expand Down
28 changes: 6 additions & 22 deletions flax/experimental/nnx/nnx/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def replace(self, **kwargs) -> 'Variable[tp.Any]':
if kwargs:
return value.replace(**kwargs)
else:
return value
return value

# get and update attributes
attributes = vars(self).copy()
Expand Down Expand Up @@ -392,29 +392,13 @@ class Intermediate(Variable[A]):

class Rng(Variable[jax.Array]):
tag: str

def __init__(
self,
value: jax.Array,
*,
tag: str,
get_value_hooks: tp.Union[
GetValueHook[jax.Array], tp.Sequence[GetValueHook[jax.Array]]
] = (),
**metadata: tp.Any,
):
def split_key_hook(variable: 'Variable[jax.Array]', key: jax.Array):
variable.value, key = jax.random.split(key)
return key

if callable(get_value_hooks):
get_value_hooks = (get_value_hooks, split_key_hook)
else:
get_value_hooks = (*get_value_hooks, split_key_hook)
def __init__(self, value: jax.Array, *, tag: str, **metadata: tp.Any):
super().__init__(value, tag=tag, **metadata)

super().__init__(
value, tag=tag, get_value_hooks=get_value_hooks, **metadata
)
def on_get_value(self, value: jax.Array):
self.value, value = jax.random.split(value)
return value


def with_metadata(
Expand Down

0 comments on commit 4d86767

Please sign in to comment.