diff --git a/CHANGELOG.md b/CHANGELOG.md index a9c7e77074..bdb625c280 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,8 +18,8 @@ vNext - - - -- -- +- Add copy() method to Module. This is a user-friendly version of the internal clone() method with better + defaults for common use cases. - - - diff --git a/docs/api_reference/flax.linen/module.rst b/docs/api_reference/flax.linen/module.rst index 8760fbe0e2..ff2a2037fb 100644 --- a/docs/api_reference/flax.linen/module.rst +++ b/docs/api_reference/flax.linen/module.rst @@ -5,4 +5,4 @@ Module .. currentmodule:: flax.linen .. autoclass:: Module - :members: setup, variable, param, bind, unbind, apply, init, init_with_output, make_rng, sow, variables, Variable, __setattr__, tabulate, is_initializing, perturb \ No newline at end of file + :members: setup, variable, param, bind, unbind, apply, init, init_with_output, copy, make_rng, sow, variables, Variable, __setattr__, tabulate, is_initializing, perturb diff --git a/flax/linen/module.py b/flax/linen/module.py index 0770391e6e..37175e7df8 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -1434,13 +1434,17 @@ def path(self): def clone( self: M, *, - parent: Optional[Union[Scope, 'Module']] = None, + parent: Optional[Union[Scope, 'Module', _Sentinel]] = None, _deep_clone: Union[bool, weakref.WeakValueDictionary] = False, _reset_names: bool = False, **updates, ) -> M: """Creates a clone of this Module, with optionally updated arguments. + NOTE: end users are encouraged to use the `copy` method. `clone` is used + primarily for internal routines, and `copy` offers simpler arguments and + better defaults. + Args: parent: The parent of the clone. The clone will have no parent if no explicit parent is specified. @@ -1503,6 +1507,29 @@ def clone_fn(m: Module) -> Module: return module + def copy( + self: M, + *, + parent: Optional[Union[Scope, 'Module', _Sentinel]] = _unspecified_parent, + name: Optional[str] = None, + **updates, + ) -> M: + """Creates a copy of this Module, with optionally updated arguments. + + Args: + parent: The parent of the copy. By default the current module is taken + as parent if not explicitly specified. + name: A new name for the copied Module, by default a new automatic name + will be given. + **updates: Attribute updates. + + Returns: + A copy of the this Module with the updated name, parent, and attributes. + """ + return self.clone( + parent=parent, name=name, _deep_clone=True, _reset_names=False, **updates + ) + @overload def variable( self, diff --git a/tests/linen/linen_module_test.py b/tests/linen/linen_module_test.py index c58331629f..8cd792859b 100644 --- a/tests/linen/linen_module_test.py +++ b/tests/linen/linen_module_test.py @@ -2736,11 +2736,62 @@ def __call__(self, x): variables = model.init(jax.random.key(0), x) output = model.apply(variables, x) self.assertTrue( - jnp.all( - variables['params']['Child_0']['w'] - == variables['params']['Child_1']['w'] - ) + variables['params']['Child_0']['w'].shape + == variables['params']['Child_1']['w'].shape + ) + + def test_copy_method(self): + class Parent(nn.Module): + @nn.compact + def __call__(self, x): + child = nn.Dense( + 2, + ) + x = child(x) + x = child.copy()(x) + return x + + model = Parent() + x = jnp.ones((2, 2)) + variables = model.init(jax.random.key(0), x) + output = model.apply(variables, x) + self.assertTrue( + variables['params']['Dense_0']['kernel'].shape + == variables['params']['Dense_1']['kernel'].shape + ) + + def test_copy_from_template(self): + class Child(nn.Module): + @nn.compact + def __call__(self, x): + w = self.param('w', nn.initializers.zeros, (5, x.shape[1])) + return x @ w + + class Parent(nn.Module): + num_layers: int + child_template: Child + + @nn.compact + def __call__(self, x): + for i in range(self.num_layers): + x = self.child_template.copy()(x) + for i in range(self.num_layers): + x = self.child_template.copy(name=f'next_layer_{i}')(x) + return x + + model = Parent(num_layers=2, child_template=Child()) + x = jnp.ones((32, 5)) + variables = model.init(jax.random.key(0), x) + output = model.apply(variables, x) + self.assertTrue( + variables['params']['Child_0']['w'].shape + == variables['params']['Child_1']['w'].shape ) + self.assertIn('Child_0', variables['params']) + self.assertIn('Child_1', variables['params']) + self.assertIn('next_layer_0', variables['params']) + self.assertIn('next_layer_1', variables['params']) + self.assertNotIn('child_template', variables['params']) class FrozenDictTests(absltest.TestCase):