From 4d8676745f10ab8c9f01a853771bfdc1149c45c7 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Sat, 28 Oct 2023 10:27:27 +0000 Subject: [PATCH] simplify Rngs --- flax/experimental/nnx/nnx/rngslib.py | 3 +++ flax/experimental/nnx/nnx/variables.py | 28 ++++++-------------------- 2 files changed, 9 insertions(+), 22 deletions(-) diff --git a/flax/experimental/nnx/nnx/rngslib.py b/flax/experimental/nnx/nnx/rngslib.py index be60f02b5f..9ccd92f178 100644 --- a/flax/experimental/nnx/nnx/rngslib.py +++ b/flax/experimental/nnx/nnx/rngslib.py @@ -110,6 +110,9 @@ def __getitem__(self, name: str) -> tp.Callable[[], jax.Array]: __getattr__ = __getitem__ + def __call__(self): + return self.params() + def __iter__(self) -> tp.Iterator[str]: return iter(self._rngs) diff --git a/flax/experimental/nnx/nnx/variables.py b/flax/experimental/nnx/nnx/variables.py index a9d8c998c2..4ad4ff1e59 100644 --- a/flax/experimental/nnx/nnx/variables.py +++ b/flax/experimental/nnx/nnx/variables.py @@ -289,7 +289,7 @@ def replace(self, **kwargs) -> 'Variable[tp.Any]': if kwargs: return value.replace(**kwargs) else: - return value + return value # get and update attributes attributes = vars(self).copy() @@ -392,29 +392,13 @@ class Intermediate(Variable[A]): class Rng(Variable[jax.Array]): tag: str - - def __init__( - self, - value: jax.Array, - *, - tag: str, - get_value_hooks: tp.Union[ - GetValueHook[jax.Array], tp.Sequence[GetValueHook[jax.Array]] - ] = (), - **metadata: tp.Any, - ): - def split_key_hook(variable: 'Variable[jax.Array]', key: jax.Array): - variable.value, key = jax.random.split(key) - return key - if callable(get_value_hooks): - get_value_hooks = (get_value_hooks, split_key_hook) - else: - get_value_hooks = (*get_value_hooks, split_key_hook) + def __init__(self, value: jax.Array, *, tag: str, **metadata: tp.Any): + super().__init__(value, tag=tag, **metadata) - super().__init__( - value, tag=tag, get_value_hooks=get_value_hooks, **metadata - ) + def on_get_value(self, value: jax.Array): + self.value, value = jax.random.split(value) + return value def with_metadata(