From 342a5e8322f684a7ba7b9107b7c73cd34854f7c0 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Thu, 16 Jan 2025 03:35:43 +0000 Subject: [PATCH] [nnx] fix ToNNX linen_attributes update --- flax/nnx/bridge/wrappers.py | 6 +++++- flax/nnx/rnglib.py | 2 +- tests/nnx/bridge/wrappers_test.py | 22 ++++++++++++++++++++++ 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/flax/nnx/bridge/wrappers.py b/flax/nnx/bridge/wrappers.py index eed4ba2f7a..ab673644c8 100644 --- a/flax/nnx/bridge/wrappers.py +++ b/flax/nnx/bridge/wrappers.py @@ -148,7 +148,7 @@ def __call__( out, variables = self.module.init_with_output(_rngs, *args, method=method, **kwargs) nnx_attrs = bv.linen_vars_to_nnx_attrs(variables) - linen_attributes = set() + linen_attributes = set(self.linen_attributes) for attr_name, value in nnx_attrs.items(): setattr(self, attr_name, value) linen_attributes.add(attr_name) @@ -167,13 +167,17 @@ def __call__( if kwargs.get('mutable', False) != False: out, updates = out nnx_attrs = bv.linen_vars_to_nnx_attrs(updates) + linen_attributes = set(self.linen_attributes) for attr_name, value in nnx_attrs.items(): + linen_attributes.add(attr_name) if hasattr(self, attr_name) and isinstance(value, dict): original_tree = getattr(self, attr_name) setattr(self, attr_name, original_tree | value) else: setattr(self, attr_name, value) + self.linen_attributes = tuple(linen_attributes) # make it hashable + return out diff --git a/flax/nnx/rnglib.py b/flax/nnx/rnglib.py index ab9817acaa..a100740afa 100644 --- a/flax/nnx/rnglib.py +++ b/flax/nnx/rnglib.py @@ -80,7 +80,7 @@ def __call__(self) -> jax.Array: ] -class Rngs(Object, tp.Mapping[str, tp.Callable[[], jax.Array]]): +class Rngs(Object, tp.Mapping[str, RngStream]): """NNX rng container class. To instantiate the ``Rngs``, pass in an integer, specifying the starting seed. ``Rngs`` can have different "streams", allowing the user to generate different diff --git a/tests/nnx/bridge/wrappers_test.py b/tests/nnx/bridge/wrappers_test.py index 5b65603a24..8c58491261 100644 --- a/tests/nnx/bridge/wrappers_test.py +++ b/tests/nnx/bridge/wrappers_test.py @@ -217,6 +217,28 @@ def get_weights(model): self.assertEqual(jax.tree.structure(from_top_weights), jax.tree.structure(from_middle_weights)) + def test_adding_new_attributes(self): + class LinenModule(nn.Module): + @nn.compact + def __call__(self): + if not self.is_initializing() and self.is_mutable_collection('cache'): + self.put_variable('cache', 'x', 0) + res = self.get_variable('cache', 'x') + return res + + class NNXModule(nnx.Module): + def __init__(self): + self.module = nnx.bridge.ToNNX(LinenModule()).lazy_init() + + def __call__(self): + result1 = self.module(mutable=['cache']) + assert result1 == 0 + result2 = self.module() + assert result2 == 0, result2 # fails: result2 is None + + module = NNXModule() + module() + ################## ### NNXToLinen ### ##################