Skip to content

Commit

Permalink
[nnx] fix ToNNX linen_attributes update
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Jan 16, 2025
1 parent 1961c12 commit 342a5e8
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 2 deletions.
6 changes: 5 additions & 1 deletion flax/nnx/bridge/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def __call__(
out, variables = self.module.init_with_output(_rngs, *args, method=method, **kwargs)

nnx_attrs = bv.linen_vars_to_nnx_attrs(variables)
linen_attributes = set()
linen_attributes = set(self.linen_attributes)
for attr_name, value in nnx_attrs.items():
setattr(self, attr_name, value)
linen_attributes.add(attr_name)
Expand All @@ -167,13 +167,17 @@ def __call__(
if kwargs.get('mutable', False) != False:
out, updates = out
nnx_attrs = bv.linen_vars_to_nnx_attrs(updates)
linen_attributes = set(self.linen_attributes)
for attr_name, value in nnx_attrs.items():
linen_attributes.add(attr_name)
if hasattr(self, attr_name) and isinstance(value, dict):
original_tree = getattr(self, attr_name)
setattr(self, attr_name, original_tree | value)
else:
setattr(self, attr_name, value)

self.linen_attributes = tuple(linen_attributes) # make it hashable

return out


Expand Down
2 changes: 1 addition & 1 deletion flax/nnx/rnglib.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __call__(self) -> jax.Array:
]


class Rngs(Object, tp.Mapping[str, tp.Callable[[], jax.Array]]):
class Rngs(Object, tp.Mapping[str, RngStream]):
"""NNX rng container class. To instantiate the ``Rngs``, pass
in an integer, specifying the starting seed. ``Rngs`` can have
different "streams", allowing the user to generate different
Expand Down
22 changes: 22 additions & 0 deletions tests/nnx/bridge/wrappers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,28 @@ def get_weights(model):
self.assertEqual(jax.tree.structure(from_top_weights),
jax.tree.structure(from_middle_weights))

def test_adding_new_attributes(self):
class LinenModule(nn.Module):
@nn.compact
def __call__(self):
if not self.is_initializing() and self.is_mutable_collection('cache'):
self.put_variable('cache', 'x', 0)
res = self.get_variable('cache', 'x')
return res

class NNXModule(nnx.Module):
def __init__(self):
self.module = nnx.bridge.ToNNX(LinenModule()).lazy_init()

def __call__(self):
result1 = self.module(mutable=['cache'])
assert result1 == 0
result2 = self.module()
assert result2 == 0, result2 # fails: result2 is None

module = NNXModule()
module()

##################
### NNXToLinen ###
##################
Expand Down

0 comments on commit 342a5e8

Please sign in to comment.