Skip to content

Commit

Permalink
Merge pull request #3461 from levskaya:copymethod
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 579888321
  • Loading branch information
Flax Authors committed Nov 6, 2023
2 parents 055e28f + 40e2b93 commit 85245ad
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 8 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ vNext
-
-
-
-
-
- Add copy() method to Module. This is a user-friendly version of the internal clone() method with better
defaults for common use cases.
-
-
-
Expand Down
2 changes: 1 addition & 1 deletion docs/api_reference/flax.linen/module.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ Module
.. currentmodule:: flax.linen

.. autoclass:: Module
:members: setup, variable, param, bind, unbind, apply, init, init_with_output, make_rng, sow, variables, Variable, __setattr__, tabulate, is_initializing, perturb
:members: setup, variable, param, bind, unbind, apply, init, init_with_output, copy, make_rng, sow, variables, Variable, __setattr__, tabulate, is_initializing, perturb
29 changes: 28 additions & 1 deletion flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1434,13 +1434,17 @@ def path(self):
def clone(
self: M,
*,
parent: Optional[Union[Scope, 'Module']] = None,
parent: Optional[Union[Scope, 'Module', _Sentinel]] = None,
_deep_clone: Union[bool, weakref.WeakValueDictionary] = False,
_reset_names: bool = False,
**updates,
) -> M:
"""Creates a clone of this Module, with optionally updated arguments.
NOTE: end users are encouraged to use the `copy` method. `clone` is used
primarily for internal routines, and `copy` offers simpler arguments and
better defaults.
Args:
parent: The parent of the clone. The clone will have no parent if no
explicit parent is specified.
Expand Down Expand Up @@ -1503,6 +1507,29 @@ def clone_fn(m: Module) -> Module:

return module

def copy(
self: M,
*,
parent: Optional[Union[Scope, 'Module', _Sentinel]] = _unspecified_parent,
name: Optional[str] = None,
**updates,
) -> M:
"""Creates a copy of this Module, with optionally updated arguments.
Args:
parent: The parent of the copy. By default the current module is taken
as parent if not explicitly specified.
name: A new name for the copied Module, by default a new automatic name
will be given.
**updates: Attribute updates.
Returns:
A copy of the this Module with the updated name, parent, and attributes.
"""
return self.clone(
parent=parent, name=name, _deep_clone=True, _reset_names=False, **updates
)

@overload
def variable(
self,
Expand Down
59 changes: 55 additions & 4 deletions tests/linen/linen_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2736,11 +2736,62 @@ def __call__(self, x):
variables = model.init(jax.random.key(0), x)
output = model.apply(variables, x)
self.assertTrue(
jnp.all(
variables['params']['Child_0']['w']
== variables['params']['Child_1']['w']
)
variables['params']['Child_0']['w'].shape
== variables['params']['Child_1']['w'].shape
)

def test_copy_method(self):
class Parent(nn.Module):
@nn.compact
def __call__(self, x):
child = nn.Dense(
2,
)
x = child(x)
x = child.copy()(x)
return x

model = Parent()
x = jnp.ones((2, 2))
variables = model.init(jax.random.key(0), x)
output = model.apply(variables, x)
self.assertTrue(
variables['params']['Dense_0']['kernel'].shape
== variables['params']['Dense_1']['kernel'].shape
)

def test_copy_from_template(self):
class Child(nn.Module):
@nn.compact
def __call__(self, x):
w = self.param('w', nn.initializers.zeros, (5, x.shape[1]))
return x @ w

class Parent(nn.Module):
num_layers: int
child_template: Child

@nn.compact
def __call__(self, x):
for i in range(self.num_layers):
x = self.child_template.copy()(x)
for i in range(self.num_layers):
x = self.child_template.copy(name=f'next_layer_{i}')(x)
return x

model = Parent(num_layers=2, child_template=Child())
x = jnp.ones((32, 5))
variables = model.init(jax.random.key(0), x)
output = model.apply(variables, x)
self.assertTrue(
variables['params']['Child_0']['w'].shape
== variables['params']['Child_1']['w'].shape
)
self.assertIn('Child_0', variables['params'])
self.assertIn('Child_1', variables['params'])
self.assertIn('next_layer_0', variables['params'])
self.assertIn('next_layer_1', variables['params'])
self.assertNotIn('child_template', variables['params'])


class FrozenDictTests(absltest.TestCase):
Expand Down

0 comments on commit 85245ad

Please sign in to comment.