From 3b1a8073586b21adc9087bff3f1d50a89545c522 Mon Sep 17 00:00:00 2001 From: Flax Team Date: Mon, 22 Jan 2024 17:18:01 -0800 Subject: [PATCH] add compact_name_scope v2 PiperOrigin-RevId: 600615557 --- flax/linen/__init__.py | 1 - flax/linen/module.py | 114 +------------------------------ tests/linen/linen_module_test.py | 70 ------------------- 3 files changed, 1 insertion(+), 184 deletions(-) diff --git a/flax/linen/__init__.py b/flax/linen/__init__.py index 1802a102b1..3106546cc3 100644 --- a/flax/linen/__init__.py +++ b/flax/linen/__init__.py @@ -106,7 +106,6 @@ Variable as Variable, apply as apply, compact as compact, - compact_name_scope as compact_name_scope, disable_named_call as disable_named_call, enable_named_call as enable_named_call, init_with_output as init_with_output, diff --git a/flax/linen/module.py b/flax/linen/module.py index 508198923e..8fb4497d9c 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -538,84 +538,6 @@ def nowrap(fun: _CallableT) -> _CallableT: return fun -def compact_name_scope(fun: _CallableT) -> _CallableT: - """Creates compact submodules from a method. - - This is a decorator that allows you to define compact submodules from a - method. It's intention is to make it easier to port code Haiku code to Flax - by providing the same functionality. - - Example:: - - >>> import flax.linen as nn - >>> import jax - >>> import jax.numpy as jnp - >>> from flax.core import pretty_repr - ... - >>> class Foo(nn.Module): - ... @nn.compact_name_scope - ... def up(self, x): - ... return nn.Dense(3)(x) - ... - ... @nn.compact_name_scope - ... def down(self, x): - ... return nn.Dense(3)(x) - ... - ... def __call__(self, x): - ... return self.up(x) + self.down(x) - ... - >>> module = Foo() - >>> variables = module.init(jax.random.PRNGKey(0), jnp.ones((1, 2))) - >>> params = variables['params'] - >>> print(pretty_repr(jax.tree_map(jnp.shape, params))) - { - down: { - Dense_0: { - bias: (3,), - kernel: (2, 3), - }, - }, - up: { - Dense_0: { - bias: (3,), - kernel: (2, 3), - }, - }, - } - - You can also use ``compact_name_scope`` inside ``@compact`` methods or even other - ``compact_name_scope`` methods. Methods that are decorated with ``compact_name_scope`` - can also be called directly from ``init`` or ``apply`` via the ``method`` argument:: - - >>> y_down = module.apply({'params': params}, jnp.ones((1, 2)), method='down') - >>> y_down.shape - (1, 3) - - Args: - fun: The Module method to mark as compact_name_scope. - - Returns: - The given function ``fun`` marked as compact_name_scope. - """ - - @functools.wraps(fun) - def compact_name_scope_wrapper(self: nn.Module, *args, **kwargs): - name = fun.__name__ - if not hasattr(self, '_compact_name_scope_modules'): - raise ValueError( - f'Cannot call compact_name_scope method {name!r} on a Module that has not been ' - f'setup. This is likely because you are calling {name!r} ' - 'from outside of init or apply.' - ) - module = self._compact_name_scope_modules[name] - return module(*args, **kwargs) - - compact_name_scope_wrapper.compact_name_scope = True # type: ignore[attr-defined] - compact_name_scope_wrapper.inner_fun = fun # type: ignore[attr-defined] - compact_name_scope_wrapper.nowrap = True # type: ignore[attr-defined] - return compact_name_scope_wrapper # type: ignore[return-value] - - def _get_local_method_names( cls: Any, exclude: Iterable[str] = () ) -> Tuple[str, ...]: @@ -1033,7 +955,6 @@ def __init_subclass__(cls, kw_only: bool = False, **kwargs: Any) -> None: # We wrap user-defined methods including setup and __call__ to enforce # a number of different checks and to provide clear error messages. cls._verify_single_or_no_compact() - cls._find_compact_name_scope_methods() cls._wrap_module_attributes() # Set empty class defaults. cls._state = _uninitialized_module_internal_state # type: ignore[attr-defined] @@ -1125,17 +1046,6 @@ def _verify_single_or_no_compact(cls): if n_compact_fns > 1: raise errors.MultipleMethodsCompactError() - @classmethod - def _find_compact_name_scope_methods(cls): - """Finds all compact_name_scope methods in the class.""" - methods = [m[0] for m in inspect.getmembers(cls, predicate=callable)] - compact_name_scope_fns = tuple( - method_name - for method_name in methods - if hasattr(getattr(cls, method_name), 'compact_name_scope') - ) - cls._compact_name_scope_methods = compact_name_scope_fns - @classmethod def _wrap_module_attributes(cls): """Wraps user-defined non-inherited methods and descriptors with state @@ -1437,7 +1347,6 @@ def _register_submodules(self, name, val): def adopt_attr_modules(cache, queue, suffix, subvalue): if isinstance(subvalue, Module): - current_name = subvalue.name adopted_name = None if subvalue.parent is None: # Preserve sharing-by-reference relationships during adoption @@ -1457,11 +1366,7 @@ def adopt_attr_modules(cache, queue, suffix, subvalue): if subvalue.name is None: object.__setattr__(subvalue, 'parent', self) if adopted_name is None: - adopted_name = ( - f'{name}{suffix}' - if not isinstance(subvalue, NonTransparent) - else current_name - ) + adopted_name = f'{name}{suffix}' object.__setattr__(subvalue, 'name', adopted_name) queue.append(subvalue) return subvalue @@ -1492,14 +1397,6 @@ def _try_setup(self, shallow: bool = False) -> None: self._register_submodules(field.name, getattr(self, field.name)) if not shallow: self.setup() - # create NonTransparent Modules - self._compact_name_scope_modules = { - name: NonTransparent( - getattr(type(self), name).inner_fun, lambda: self, name=name - ) - for name in self._compact_name_scope_methods - } - # We run static checks abstractly once for setup before any transforms # to detect name collisions and other python errors. elif self._state.setup_called == SetupState.NEW: @@ -2938,12 +2835,3 @@ def init_wrapper(*args, **kwargs): return init_fn(*args, **kwargs)[1] return init_wrapper - - -class NonTransparent(Module): - fn: Callable - module_fn: Callable[[], Module] - - @compact - def __call__(self, *args, **kwargs) -> Any: - return self.fn(self.module_fn(), *args, **kwargs) diff --git a/tests/linen/linen_module_test.py b/tests/linen/linen_module_test.py index af4df220b2..a5d0416ada 100644 --- a/tests/linen/linen_module_test.py +++ b/tests/linen/linen_module_test.py @@ -2487,76 +2487,6 @@ def my_property(self): self.assertEqual(obj_loaded.b, 'ok') self.assertEqual(obj_loaded.my_property, 'okok') - def test_compact_name_scope(self): - class Foo(nn.Module): - @nn.compact_name_scope - def up(self, x): - return nn.Dense(3)(x) - - @nn.compact_name_scope - def down(self, x): - return nn.Dense(3)(x) - - @nn.compact - def __call__(self, x): - return self.up(x) + self.down(x) + nn.Dense(3)(x) - - m = Foo() - x = jnp.ones((1, 2)) - - self.assertEqual(set(m._compact_name_scope_methods), {'up', 'down'}) - - variables = m.init(random.key(0), x) - params = variables['params'] - - self.assertIn('Dense_0', params) - self.assertIn('down', params) - self.assertIn('up', params) - self.assertIn('Dense_0', params['down']) - self.assertIn('Dense_0', params['up']) - - y = m.apply(variables, x) - y_up = m.apply(variables, x, method='up') - y_down = m.apply(variables, x, method='down') - - assert y.shape == (1, 3) - assert y_up.shape == (1, 3) - assert y_down.shape == (1, 3) - - def test_compact_name_scope_outside_compact(self): - class Foo(nn.Module): - @nn.compact_name_scope - def up(self, x): - return nn.Dense(3)(x) - - @nn.compact_name_scope - def down(self, x): - return nn.Dense(3)(x) - - def __call__(self, x): - return self.up(x) + self.down(x) - - m = Foo() - x = jnp.ones((1, 2)) - - self.assertEqual(set(m._compact_name_scope_methods), {'up', 'down'}) - - variables = m.init(random.key(0), x) - params = variables['params'] - - self.assertIn('down', params) - self.assertIn('up', params) - self.assertIn('Dense_0', params['down']) - self.assertIn('Dense_0', params['up']) - - y = m.apply(variables, x) - y_up = m.apply(variables, x, method='up') - y_down = m.apply(variables, x, method='down') - - assert y.shape == (1, 3) - assert y_up.shape == (1, 3) - assert y_down.shape == (1, 3) - class LeakTests(absltest.TestCase): def test_tracer_leaks(self):