Skip to content

Commit

Permalink
Merge pull request #2809 from cgarciae:string-methods
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 504400092
  • Loading branch information
Flax Authors committed Jan 25, 2023
2 parents e83711b + 05053d9 commit a309273
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 9 deletions.
37 changes: 29 additions & 8 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1364,7 +1364,7 @@ def apply(self,
variables: VariableDict,
*args,
rngs: Optional[RNGSequences] = None,
method: Optional[Callable[..., Any]] = None,
method: Union[Callable[..., Any], str, None] = None,
mutable: CollectionFilter = False,
capture_intermediates: Union[bool, Callable[['Module', str], bool]] = False,
**kwargs) -> Union[Any, Tuple[Any, FrozenVariableDict]]:
Expand All @@ -1383,6 +1383,11 @@ def apply(self,
encoded = model.apply({'params': params}, x, method=model.encode)
You can also pass a string to a callable attribute of the module. For
example, the previous can be written as::
encoded = model.apply({'params': params}, x, method='encode')
Note ``method`` can also be a function that is not defined in
``Transformer``. In that case, the function should have at least one
argument representing an instance of the Module class::
Expand All @@ -1402,7 +1407,8 @@ def other_fn(instance, ...):
The "params" PRNG sequence is used to initialize parameters.
method: A function to call apply on. This is generally a function in the
module. If provided, applies this method. If not provided, applies the
``__call__`` method of the module.
``__call__`` method of the module. A string can also be provided to
specify a method by name.
mutable: Can be bool, str, or list. Specifies which collections should be
treated as mutable: ``bool``: all/no collections are mutable.
``str``: The name of a single mutable collection. ``list``: A
Expand All @@ -1421,7 +1427,13 @@ def other_fn(instance, ...):
"""
Module._module_checks(self)

if method is None:
if isinstance(method, str):
attribute_name = method
method = getattr(self, attribute_name)
if not callable(method):
class_name = type(self).__name__
raise TypeError(f'\'{class_name}.{attribute_name}\' must be a callable, got {type(method)}.')
elif method is None:
method = self.__call__
method = _get_unbound_fn(method)
return apply(
Expand All @@ -1434,7 +1446,7 @@ def other_fn(instance, ...):
def init_with_output(self,
rngs: Union[PRNGKey, RNGSequences],
*args,
method: Optional[Callable[..., Any]] = None,
method: Union[Callable[..., Any], str, None] = None,
mutable: CollectionFilter = DenyList('intermediates'),
capture_intermediates: Union[bool, Callable[['Module', str], bool]] = False,
**kwargs) -> Tuple[Any, FrozenVariableDict]:
Expand All @@ -1444,7 +1456,8 @@ def init_with_output(self,
rngs: The rngs for the variable collections.
*args: Named arguments passed to the init function.
method: An optional method. If provided, applies this method. If not
provided, applies the ``__call__`` method.
provided, applies the ``__call__`` method. A string can also be'
provided to specify a method by name.
mutable: Can be bool, str, or list. Specifies which collections should be
treated as mutable: ``bool``: all/no collections are mutable.
``str``: The name of a single mutable collection. ``list``: A
Expand All @@ -1469,7 +1482,14 @@ def init_with_output(self,
'RNGs should be of shape (2,) or KeyArray in Module '
f'{self.__class__.__name__}, but rngs are: {rngs}')
rngs = {'params': rngs}
if method is None:

if isinstance(method, str):
attribute_name = method
method = getattr(self, attribute_name)
if not callable(method):
class_name = type(self).__name__
raise TypeError(f'\'{class_name}.{attribute_name}\' must be a callable, got {type(method)}.')
elif method is None:
method = self.__call__
method = _get_unbound_fn(method)
return init_with_output(
Expand All @@ -1483,7 +1503,7 @@ def init_with_output(self,
def init(self,
rngs: Union[PRNGKey, RNGSequences],
*args,
method: Optional[Callable[..., Any]] = None,
method: Union[Callable[..., Any], str, None] = None,
mutable: CollectionFilter = DenyList('intermediates'),
capture_intermediates: Union[bool, Callable[['Module', str], bool]] = False,
**kwargs) -> FrozenVariableDict:
Expand Down Expand Up @@ -1548,7 +1568,8 @@ def init(self,
rngs: The rngs for the variable collections.
*args: Named arguments passed to the init function.
method: An optional method. If provided, applies this method. If not
provided, applies the ``__call__`` method.
provided, applies the ``__call__`` method. A string can also be
provided to specify a method by name.
mutable: Can be bool, str, or list. Specifies which collections should be
treated as mutable: ``bool``: all/no collections are mutable.
``str``: The name of a single mutable collection. ``list``: A
Expand Down
18 changes: 17 additions & 1 deletion tests/linen/linen_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,7 @@ def __call__(self, x):
def test_module_apply_method(self):

class Foo(nn.Module):
not_callable: int = 1

@nn.compact
def __call__(self):
Expand All @@ -849,8 +850,23 @@ def test(self):
with self.assertRaisesRegex(errors.ApplyModuleInvalidMethodError, msg):
Foo().apply({}, method=lambda: True)

with self.assertRaisesRegex(errors.ApplyModuleInvalidMethodError, msg):
# string method names are also allowed.
Foo().apply({}, method='test')
# test same for init.
Foo().init({}, method='test')

# non-existent attribute names will yield AttributeError.
with self.assertRaisesRegex(AttributeError, "allowed_apply_fn"):
Foo().apply({}, method='allowed_apply_fn')
# test same for init.
Foo().init({}, method='allowed_apply_fn')

# attributes which are not callables yield TypeError.
with self.assertRaisesRegex(TypeError, "'Foo.not_callable' must be a callable"):
Foo().apply({}, method='not_callable')
# test same for init.
Foo().init({}, method='not_callable')


def test_call_unbound_compact_module_methods(self):
dense = Dense(3)
Expand Down

0 comments on commit a309273

Please sign in to comment.