From 89923c66793d56fd0dcdd5032a1ab7289247fba0 Mon Sep 17 00:00:00 2001
From: Jonathan Heek
Date: Wed, 24 Feb 2021 12:39:59 +0000
Subject: [PATCH 01/12] add sow and capture_intermediates

---
 flax/linen/module.py       | 121 +++++++++++++++++++++++++++++++++++--
 tests/linen/module_test.py |  34 +++++++++++
 2 files changed, 150 insertions(+), 5 deletions(-)

diff --git a/flax/linen/module.py b/flax/linen/module.py
index 5bf11837c..1b0d23779 100644
--- a/flax/linen/module.py
+++ b/flax/linen/module.py
@@ -33,7 +33,7 @@
 from flax import traverse_util
 from flax import serialization
 from flax.core import Scope, apply
-from flax.core.scope import CollectionFilter, Variable, VariableDict, FrozenVariableDict
+from flax.core.scope import CollectionFilter, Variable, VariableDict, FrozenVariableDict, union_filters
 from flax.core.frozen_dict import FrozenDict, freeze

 # from .dotgetter import DotGetter
@@ -41,9 +41,13 @@
 PRNGKey = Any  # pylint: disable=invalid-name
 RNGSequences = Dict[str, PRNGKey]
 Array = Any  # pylint: disable=invalid-name
+
+
 T = TypeVar('T')
+K = TypeVar('T')
 _CallableT = TypeVar('_CallableT', bound=Callable)
+

 # pylint: disable=protected-access,attribute-defined-outside-init

 def _check_omnistaging():
@@ -108,6 +112,12 @@ def module_stack(self):
     if not hasattr(self._thread_data, 'module_stack'):
       self._thread_data.module_stack = [None,]
     return self._thread_data.module_stack
+
+  @property
+  def capture_stack(self):
+    if not hasattr(self._thread_data, 'capture_stack'):
+      self._thread_data.capture_stack = []
+    return self._thread_data.capture_stack

 # The global context
 _context = _DynamicContext()
@@ -263,7 +273,12 @@ def wrapped_module_method(*args, **kwargs):
       self._state.in_compact_method = True
     _context.module_stack.append(self)
     try:
-      return fun(self, *args, **kwargs)
+      y = fun(self, *args, **kwargs)
+      if _context.capture_stack:
+        filter_fn = _context.capture_stack[-1]
+        if filter_fn and filter_fn(self, fun.__name__):
+          self.sow('intermediates', fun.__name__, y)
+      return y
     finally:
       _context.module_stack.pop()
       if is_compact_method:
@@ -345,10 +360,20 @@ def reimport(self, other):
   '__reduce__', '__reduce_ex__', '__copy__', '__deepcopy__')


-_caches = weakref.WeakKeyDictionary()
+tuple_reduce = lambda xs, x: xs + (x,)
+tuple_init = lambda: ()
+
+
+capture_call_intermediates = lambda _, method_name: method_name == '__call__'
+

 # Base Module definition.
 # -----------------------------------------------------------------------------
+
+
+_caches = weakref.WeakKeyDictionary()
+
+
 class Module:
   """Base class for all neural network modules. Layers and models should subclass this class.
@@ -767,6 +792,7 @@ def make_rng(self, name: str) -> PRNGKey:
   def apply(self, variables: VariableDict, *args, rngs: RNGSequences = None,
             method: Callable[..., Any] = None,
             mutable: Union[bool, str, Sequence[str]] = False,
+            capture_intermediates: Union[bool, Callable[['Module', str], bool]] = False,
             **kwargs) -> Union[Any, Tuple[Any, FrozenVariableDict]]:
     """Applies a module method to variables and returns output and modified variables.
@@ -789,6 +815,12 @@ def apply(self, variables: VariableDict, *args, rngs: RNGSequences = None,
         treated as mutable: ``bool``: all/no collections are mutable.
         ``str``: The name of a single mutable collection. ``list``: A
         list of names of mutable collections.
+      capture_intermediates: If `True`, captures intermediate return values
+        of all Modules inside the "intermediates" collection. By default only
+        the return value of the `__call__` method is stored. A function can
+        be passed to change the filter behavior. The filter function takes
+        the Module instance and method name and returns a bool indicating
+        whether the output of that method invocation should be stored.
     Returns:
       If ``mutable`` is False, returns output. If any collections are
       mutable, returns ``(output, vars)``, where ``vars`` is a dict
@@ -799,7 +831,16 @@ def apply(self, variables: VariableDict, *args, rngs: RNGSequences = None,
     else:
       method = _get_unbound_fn(method)
     fn = lambda scope: method(self.clone(parent=scope), *args, **kwargs)
-    return apply(fn, mutable=mutable)(variables, rngs=rngs)
+    if capture_intermediates is True:
+      capture_intermediates = capture_call_intermediates
+    if capture_intermediates:
+      mutable = union_filters(mutable, 'intermediates')
+    _context.capture_stack.append(capture_intermediates)
+    try:
+      return apply(fn, mutable=mutable)(variables, rngs=rngs)
+    finally:
+      _context.capture_stack.pop()
+

   def init_with_output(self, rngs: Union[PRNGKey, RNGSequences], *args,
                        method: Optional[Callable[..., Any]] = None,
@@ -848,9 +889,79 @@ def variables(self) -> VariableDict:
     if self.scope is None:
       raise ValueError("Can't access variables on unbound modules")
     return self.scope.variables()
+
+  def get_variable(self, col: str, name: str, default: T = None) -> T:
+    """Retrieves the value of a Variable.
+
+    Args:
+      col: the variable collection.
+      name: the name of the variable.
+      default: the default value to return if the variable does not exist in
+        this scope.
+
+    Returns:
+      The value of the input variable, or the default value if the variable
+      doesn't exist in this scope.
+    """
+    if self.scope is None:
+      raise ValueError("Can't access variables on unbound modules")
+    return self.scope.get_variable(col, name, default)
+
+  def sow(self, col: str, name: str, value: T,
+          reduce_fn: Callable[[K, T], K] = tuple_reduce,
+          init_fn: Callable[[], K] = tuple_init) -> bool:
+    """Stores a value in a collection.
+
+    Collections can be used to collect intermediate values without
+    the overhead of explicitly passing a container through each Module call.
+
+    By default the value is stored in a tuple and each stored value
+    is appended at the end. This way all intermediates can be tracked when
+    the same module is called multiple times.
+
+    If the target collection is not mutable `sow` behaves like a no-op
+    and returns `False`.
+
+    Example::
+
+      class Foo(nn.Module):
+        @nn.compact
+        def __call__(self, x):
+          h = nn.Dense(4)(x)
+          self.sow('intermediates', 'h', h)
+          return nn.Dense(2)(h)
+      y, state = Foo().apply(variables, x, mutable=['intermediates'])
+      print(state['intermediates']) # {'h': (...,)}
+
+    Args:
+      col: the variable collection.
+      name: the name of the variable.
+      reduce_fn: The function used to combine the existing value with
+        the new value. The default is to append the value to a tuple.
+      init_fn: For the first value stored, `reduce_fn` will be passed
+        the result of `init_fn` together with the value to be stored.
+        The default is an empty tuple.
+
+    Returns:
+      `True` if the value has been stored successfully, `False` otherwise.
+    """
+    if self.scope is None:
+      raise ValueError("Can't store variables on unbound modules")
+    if not self.scope.is_mutable_collection(col):
+      return False
+    if self.scope.has_variable(col, name):
+      xs = self.scope.get_variable(col, name)
+    else:
+      self.scope.reserve(name)
+      self._state.children[name] = col
+      xs = init_fn()
+    xs = reduce_fn(xs, value)
+    self.scope.put_variable(col, name, xs)
+    return True
+

-T = TypeVar('T')
 def merge_param(name: str, a: Optional[T], b: Optional[T]) -> T:
   """Merges construction and call time argument.
diff --git a/tests/linen/module_test.py b/tests/linen/module_test.py
index 4103c1ef8..943510415 100644
--- a/tests/linen/module_test.py
+++ b/tests/linen/module_test.py
@@ -1024,6 +1024,40 @@ def __call__(self, x):
         return x+1

     self.assertTrue(isinstance(X.Hyper(a=1), X.Hyper))

+  def test_sow(self):
+    class Foo(nn.Module):
+      @nn.compact
+      def __call__(self, x):
+        self.sow('intermediates', 'h', x)
+        self.sow('intermediates', 'h', 2 * x)
+        return 3 * x
+
+    _, state = Foo().apply({}, 1, mutable=['intermediates'])
+    self.assertEqual(state, {
+      'intermediates': {'h': (1, 2)}
+    })
+    self.assertEqual(Foo().apply({}, 1), 3)
+
+  def test_capture_intermediates(self):
+    class Bar(nn.Module):
+      def test(self, x):
+        return x + 1
+
+    class Foo(nn.Module):
+      @nn.compact
+      def __call__(self, x):
+        return Bar().test(x) + 1
+
+    _, state = Foo().apply({}, 1, capture_intermediates=True)
+    self.assertEqual(state, {
+      'intermediates': {'__call__': (3,)}
+    })
+    fn = lambda mdl, _: isinstance(mdl, Bar)
+    _, state = Foo().apply({}, 1, capture_intermediates=fn)
+    self.assertEqual(state, {
+      'intermediates': {'Bar_0': {'test': (2,)}}
+    })
+
 if __name__ == '__main__':
   absltest.main()

From f5a3f52c84b8d7411d55075b067f49ee91cb97e0 Mon Sep 17 00:00:00 2001
From: Jonathan Heek
Date: Wed, 24 Feb 2021 12:43:45 +0000
Subject: [PATCH 02/12] Add extracting intermediates howto

Co-authored-by: Avital Oliver
---
 .github/workflows/build.yml              |   1 +
 docs/howtos/extracting_intermediates.rst | 198 +++++++++++++++++++++++
 docs/index.rst                           |   1 +
 3 files changed, 200 insertions(+)
 create mode 100644 docs/howtos/extracting_intermediates.rst

diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index c702a76b9..3828a6c62 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -36,6 +36,7 @@ jobs:
         sphinx-build -M doctest docs docs/_build
     - name: Build documentation
       run: |
+        sphinx-build -M doctest docs docs/_build
         sphinx-build -M html docs docs/_build
     - name: Test with pytest and generate coverage report
       run: |
diff --git a/docs/howtos/extracting_intermediates.rst b/docs/howtos/extracting_intermediates.rst
new file mode 100644
index 000000000..236587f1d
--- /dev/null
+++ b/docs/howtos/extracting_intermediates.rst
@@ -0,0 +1,198 @@
Extracting intermediate values
==============================

This howto will show you how to extract intermediate values from a module.
Let's start with this simple CNN that uses :code:`nn.compact`.

.. testsetup::

   import flax.linen as nn
   import jax
   import jax.numpy as jnp
   from flax.core import FrozenDict
   from typing import Sequence

   batch = jnp.ones((4, 32, 32, 3))
.. testcode::

   class CNN(nn.Module):
     @nn.compact
     def __call__(self, x):
       x = nn.Conv(features=32, kernel_size=(3, 3))(x)
       x = nn.relu(x)
       x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
       x = nn.Conv(features=64, kernel_size=(3, 3))(x)
       x = nn.relu(x)
       x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
       x = x.reshape((x.shape[0], -1))  # flatten
       x = nn.Dense(features=256)(x)
       x = nn.relu(x)
       x = nn.Dense(features=10)(x)
       x = nn.log_softmax(x)
       return x

Because this module uses ``nn.compact``, we don't have direct access to
intermediate values. There are a few ways to expose them:


Store intermediate values in a new variable collection
------------------------------------------------------

You can augment any module with calls to ``sow``, which stores
intermediate values. ``sow`` only stores a value if the given variable
collection is passed in as "mutable" in the call to :code:`Module.apply`.

.. testcode::

   class CNN(nn.Module):
     @nn.compact
     def __call__(self, x):
       x = nn.Conv(features=32, kernel_size=(3, 3))(x)
       self.sow('intermediates', 'conv1', x)
       x = nn.relu(x)
       x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
       x = nn.Conv(features=64, kernel_size=(3, 3))(x)
       self.sow('intermediates', 'conv2', x)
       x = nn.relu(x)
       x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
       x = x.reshape((x.shape[0], -1))  # flatten
       self.sow('intermediates', 'features', x)
       x = nn.Dense(features=256)(x)
       self.sow('intermediates', 'conv3', x)
       x = nn.relu(x)
       x = nn.Dense(features=10)(x)
       self.sow('intermediates', 'dense', x)
       x = nn.log_softmax(x)
       return x

   @jax.jit
   def init(key, x):
     variables = CNN().init(key, x)
     return variables

   @jax.jit
   def predict(variables, x):
     return CNN().apply(variables, x)

   @jax.jit
   def features(variables, x):
     # `mutable=['intermediates']` specifies which collections are treated as
     # mutable during `apply`. The variables aren't actually mutated; instead,
     # `apply` returns a second value, which is a dictionary of the modified
     # collections.
     output, modified_variables = CNN().apply(variables, x, mutable=['intermediates'])
     return modified_variables['intermediates']['features']

   variables = init(jax.random.PRNGKey(0), batch)
   predict(variables, batch)
   features(variables, batch)
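
``sow`` behaves like a no-op when the target collection is not mutable, so
the same module works for plain prediction and for inspection. A minimal
check of this behavior (a short sketch added for illustration, reusing the
``variables`` from above):

.. testcode::

   # With no mutable collections, each sow call is a no-op and apply
   # returns just the output instead of an (output, state) pair.
   output = CNN().apply(variables, batch)
   assert output.shape == (4, 10)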

Refactor module into submodules
-------------------------------

This is a useful pattern for cases where it's clear in what particular
way you want to split your submodules. Any submodule you expose in ``setup`` can be used directly. In the limit, you
can define all submodules in ``setup`` and avoid using ``nn.compact`` altogether.

.. testcode::

   class CNN(nn.Module):
     def setup(self):
       self.features = Features()
       self.classifier = Classifier()

     def __call__(self, x):
       x = self.features(x)
       x = self.classifier(x)
       return x

   class Features(nn.Module):
     @nn.compact
     def __call__(self, x):
       x = nn.Conv(features=32, kernel_size=(3, 3))(x)
       x = nn.relu(x)
       x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
       x = nn.Conv(features=64, kernel_size=(3, 3))(x)
       x = nn.relu(x)
       x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
       x = x.reshape((x.shape[0], -1))  # flatten
       return x

   class Classifier(nn.Module):
     @nn.compact
     def __call__(self, x):
       x = nn.Dense(features=256)(x)
       x = nn.relu(x)
       x = nn.Dense(features=10)(x)
       x = nn.log_softmax(x)
       return x

   @jax.jit
   def init(key, x):
     variables = CNN().init(key, x)
     return variables['params']

   @jax.jit
   def features(params, x):
     # Proposal #686 should allow for this alternative:
     # return CNN().features.apply({"params": params['features']})
     return CNN().apply({"params": params}, x,
       method=lambda module, x: module.features(x))

   params = init(jax.random.PRNGKey(0), batch)

   features(params, batch)


Use ``nn.Sequential``
---------------------

You could also define ``CNN`` using a simple implementation of a ``Sequential``
combinator (this is quite common in more stateful approaches). This may be useful
for very simple models and gives you arbitrary model
surgery. But it can be very limiting -- if you even want to add one conditional, you are
forced to refactor away from ``nn.Sequential`` and structure
your model more explicitly.

.. testcode::

   class Sequential(nn.Module):
     submodules: Sequence[nn.Module]

     def setup(self):
       # Bind layers to `self` -- `__setattr__` gives submodules
       # names and connects their variables through `self.variables`
       self.layers = self.submodules

     def __call__(self, x):
       for layer in self.layers:
         x = layer(x)
       return x

   def CNN():
     return Sequential([
       nn.Conv(features=32, kernel_size=(3, 3)),
       nn.relu,
       lambda x: nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)),
       nn.Conv(features=64, kernel_size=(3, 3)),
       nn.relu,
       lambda x: nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)),
       lambda x: x.reshape((x.shape[0], -1)),  # flatten
       nn.Dense(features=256),
       nn.relu,
       nn.Dense(features=10),
       nn.log_softmax,
     ])

   @jax.jit
   def init(key, x):
     variables = CNN().init(key, x)
     return variables['params']

   @jax.jit
   def features(params, x):
     return Sequential(CNN().submodules[0:7]).apply({"params": params}, x)

   params = init(jax.random.PRNGKey(0), batch)
   features(params, batch)
diff --git a/docs/index.rst b/docs/index.rst
index cc07f0a19..14cc7d5f2 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -38,6 +38,7 @@ For a quick introduction and short example snippets, see our `README
    howtos/state_params
    howtos/ensembling
    howtos/lr_schedule
+   howtos/extracting_intermediates
.. toctree::
   :maxdepth: 1

From 7211dee7a125b7b75166203724540db33285afed Mon Sep 17 00:00:00 2001
From: Jonathan Heek
Date: Wed, 24 Feb 2021 12:44:13 +0000
Subject: [PATCH 03/12] add capture_intermediates to howto

---
 docs/howtos/extracting_intermediates.rst | 45 ++++++++++++++++++++++++
 1 file changed, 45 insertions(+)

diff --git a/docs/howtos/extracting_intermediates.rst b/docs/howtos/extracting_intermediates.rst
index 236587f1d..cf1b6069e 100644
--- a/docs/howtos/extracting_intermediates.rst
+++ b/docs/howtos/extracting_intermediates.rst
@@ -146,6 +146,51 @@ can define all submodules in ``setup`` and avoid using ``nn.compact`` altogether
   features(params, batch)

Use `capture_intermediates`
---------------------------

Linen supports the capture of intermediate return values from submodules automatically, without any code changes.
This pattern should be considered the "sledgehammer" approach to capturing intermediates.
As a debugging and inspection tool it is very useful, but for more fine-grained control you should prefer the other patterns described in this howto.

In the following code example we check if any intermediate activations are non-finite (NaN or infinite):

.. testcode::

   class CNN(nn.Module):
     @nn.compact
     def __call__(self, x):
       x = nn.Conv(features=32, kernel_size=(3, 3))(x)
       x = nn.relu(x)
       x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
       x = nn.Conv(features=64, kernel_size=(3, 3))(x)
       x = nn.relu(x)
       x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
       x = x.reshape((x.shape[0], -1))  # flatten
       x = nn.Dense(features=256)(x)
       x = nn.relu(x)
       x = nn.Dense(features=10)(x)
       x = nn.log_softmax(x)
       return x

   @jax.jit
   def init(key, x):
     variables = CNN().init(key, x)
     return variables

   @jax.jit
   def predict(variables, x):
     y, state = CNN().apply(variables, x, capture_intermediates=True, mutable=["intermediates"])
     intermediates = state['intermediates']
     fin = jax.tree_map(lambda xs: jnp.all(jnp.isfinite(xs)), intermediates)
     return y, fin

   variables = init(jax.random.PRNGKey(0), batch)
   y, is_finite = predict(variables, batch)
   all_finite = all(jax.tree_leaves(is_finite))
   assert all_finite, "non-finite intermediate detected!"

From 3113ad0d17e14c0b8e8ad1cb5f6a7e7dd0696639 Mon Sep 17 00:00:00 2001
From: Jonathan Heek
Date: Wed, 24 Feb 2021 13:49:03 +0000
Subject: [PATCH 04/12] add changelog entry

---
 CHANGELOG.md | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 38e8265d7..80afd3357 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,7 +11,8 @@ vNext
 - Bug Fix `flax.core.apply` and `Module.apply`. Now it returns a tuple
   containing the output and a frozen empty collection when `mutable` is
   specified as an empty list.
- -
+ - Add `sow` method to `Module` and `capture_intermediates` argument to `Module.apply`.
+   See [howto](https://flax.readthedocs.io/en/latest/howtos/extracting_intermediates.html) for usage patterns.
 -
 -
 -

From daaeef5917745a58c75c93098878a8c17dd6dbc9 Mon Sep 17 00:00:00 2001
From: Jonathan Heek
Date: Wed, 24 Feb 2021 14:12:31 +0000
Subject: [PATCH 05/12] Fix TypeVar

---
 flax/linen/module.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flax/linen/module.py b/flax/linen/module.py
index 1b0d23779..b670611ed 100644
--- a/flax/linen/module.py
+++ b/flax/linen/module.py
@@ -44,7 +44,7 @@

 T = TypeVar('T')
-K = TypeVar('T')
+K = TypeVar('K')
 _CallableT = TypeVar('_CallableT', bound=Callable)

From 0a21ce840a38bd30fab696461116bfdfefc18f16 Mon Sep 17 00:00:00 2001
From: Jonathan Heek
Date: Wed, 24 Feb 2021 14:38:25 +0000
Subject: [PATCH 06/12] Add codediff

---
 docs/howtos/extracting_intermediates.rst | 103 +++++++++++++++--------
 1 file changed, 66 insertions(+), 37 deletions(-)

diff --git a/docs/howtos/extracting_intermediates.rst b/docs/howtos/extracting_intermediates.rst
index cf1b6069e..4d50fec0c 100644
--- a/docs/howtos/extracting_intermediates.rst
+++ b/docs/howtos/extracting_intermediates.rst
@@ -14,6 +14,27 @@ Let's start with this simple CNN that uses :code:`nn.compact`.

    batch = jnp.ones((4, 32, 32, 3))

+   class SowCNN(nn.Module):
+     @nn.compact
+     def __call__(self, x):
+       x = nn.Conv(features=32, kernel_size=(3, 3))(x)
+       self.sow('intermediates', 'conv1', x)
+       x = nn.relu(x)
+       x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
+       x = nn.Conv(features=64, kernel_size=(3, 3))(x)
+       self.sow('intermediates', 'conv2', x)
+       x = nn.relu(x)
+       x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
+       x = x.reshape((x.shape[0], -1))  # flatten
+       self.sow('intermediates', 'features', x)
+       x = nn.Dense(features=256)(x)
+       self.sow('intermediates', 'conv3', x)
+       x = nn.relu(x)
+       x = nn.Dense(features=10)(x)
+       self.sow('intermediates', 'dense', x)
+       x = nn.log_softmax(x)
+       return x
+
 .. testcode::

    class CNN(nn.Module):
@@ -39,42 +60,68 @@ intermediate values. There are a few ways to expose them:

 Store intermediate values in a new variable collection
 ------------------------------------------------------

-You can augment any module with calls to ``sow``, which stores
-intermediate values. ``sow`` only stores a value if the given variable
-collection is passed in as "mutable" in the call to :code:`Module.apply`.
+The CNN can be augmented with calls to ``sow`` to store intermediates as follows:

-.. testcode::
+.. codediff::
+  :title_left: Default CNN
+  :title_right: CNN using sow API

   class CNN(nn.Module):
     @nn.compact
     def __call__(self, x):
       x = nn.Conv(features=32, kernel_size=(3, 3))(x)
-      self.sow('intermediates', 'conv1', x)
+
       x = nn.relu(x)
       x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
       x = nn.Conv(features=64, kernel_size=(3, 3))(x)
-      self.sow('intermediates', 'conv2', x)
+
       x = nn.relu(x)
       x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
       x = x.reshape((x.shape[0], -1))  # flatten
-      self.sow('intermediates', 'features', x)
+
       x = nn.Dense(features=256)(x)
-      self.sow('intermediates', 'conv3', x)
+
       x = nn.relu(x)
       x = nn.Dense(features=10)(x)
-      self.sow('intermediates', 'dense', x)
+
       x = nn.log_softmax(x)
       return x
+
+  ---
+
+  class SowCNN(nn.Module):
+    @nn.compact
+    def __call__(self, x):
+      x = nn.Conv(features=32, kernel_size=(3, 3))(x)
+      self.sow('intermediates', 'conv1', x) #!
+      x = nn.relu(x)
+      x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
+      x = nn.Conv(features=64, kernel_size=(3, 3))(x)
+      self.sow('intermediates', 'conv2', x) #!
+      x = nn.relu(x)
+      x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
+      x = x.reshape((x.shape[0], -1))  # flatten
+      self.sow('intermediates', 'features', x) #!
+      x = nn.Dense(features=256)(x)
+      self.sow('intermediates', 'conv3', x) #!
+      x = nn.relu(x)
+      x = nn.Dense(features=10)(x)
+      self.sow('intermediates', 'dense', x) #!
+      x = nn.log_softmax(x)
+      return x
+
+``sow`` only stores a value if the given variable collection is passed in
+as "mutable" in the call to :code:`Module.apply`.
+
+.. testcode::

   @jax.jit
   def init(key, x):
-    variables = CNN().init(key, x)
+    variables = SowCNN().init(key, x)
     return variables

   @jax.jit
   def predict(variables, x):
-    return CNN().apply(variables, x)
+    return SowCNN().apply(variables, x)

   @jax.jit
   def features(variables, x):
     # `mutable=['intermediates']` specifies which collections are treated as
     # mutable during `apply`. The variables aren't actually mutated; instead,
     # `apply` returns a second value, which is a dictionary of the modified
     # collections.
-    output, modified_variables = CNN().apply(variables, x, mutable=['intermediates'])
+    output, modified_variables = SowCNN().apply(variables, x, mutable=['intermediates'])
     return modified_variables['intermediates']['features']

   variables = init(jax.random.PRNGKey(0), batch)
   predict(variables, batch)
   features(variables, batch)
@@ -98,7 +145,7 @@ can define all submodules in ``setup`` and avoid using ``nn.compact`` altogether

 .. testcode::

-   class CNN(nn.Module):
+   class RefactoredCNN(nn.Module):
     def setup(self):
       self.features = Features()
       self.classifier = Classifier()
@@ -131,14 +178,12 @@ can define all submodules in ``setup`` and avoid using ``nn.compact`` altogether

   @jax.jit
   def init(key, x):
-    variables = CNN().init(key, x)
+    variables = RefactoredCNN().init(key, x)
     return variables['params']

   @jax.jit
   def features(params, x):
-    # Proposal #686 should allow for this alternative:
-    # return CNN().features.apply({"params": params['features']})
-    return CNN().apply({"params": params}, x,
+    return RefactoredCNN().apply({"params": params}, x,
       method=lambda module, x: module.features(x))

   params = init(jax.random.PRNGKey(0), batch)

   features(params, batch)
@@ -157,22 +202,6 @@ In the following code example we check if any intermediate activations are non-finite (NaN or infinite):

 .. testcode::

-   class CNN(nn.Module):
-     @nn.compact
-     def __call__(self, x):
-       x = nn.Conv(features=32, kernel_size=(3, 3))(x)
-       x = nn.relu(x)
-       x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
-       x = nn.Conv(features=64, kernel_size=(3, 3))(x)
-       x = nn.relu(x)
-       x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
-       x = x.reshape((x.shape[0], -1))  # flatten
-       x = nn.Dense(features=256)(x)
-       x = nn.relu(x)
-       x = nn.Dense(features=10)(x)
-       x = nn.log_softmax(x)
-       return x
-
   @jax.jit
   def init(key, x):
     variables = CNN().init(key, x)
     return variables
@@ -244,7 +273,7 @@ your model more explicitly.
       x = layer(x)
     return x

-  def CNN():
+  def SeqCNN():
     return Sequential([
       nn.Conv(features=32, kernel_size=(3, 3)),
       nn.relu,
       lambda x: nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)),
       nn.Conv(features=64, kernel_size=(3, 3)),
       nn.relu,
       lambda x: nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)),
       lambda x: x.reshape((x.shape[0], -1)),  # flatten
       nn.Dense(features=256),
       nn.relu,
       nn.Dense(features=10),
       nn.log_softmax,
     ])
  @jax.jit
  def init(key, x):
-    variables = CNN().init(key, x)
+    variables = SeqCNN().init(key, x)
     return variables['params']

  @jax.jit
  def features(params, x):
-    return Sequential(CNN().submodules[0:7]).apply({"params": params}, x)
+    return Sequential(SeqCNN().submodules[0:7]).apply({"params": params}, x)

  params = init(jax.random.PRNGKey(0), batch)
  features(params, batch)

From c884b4b0bc6836781aea901fd04b6dac988165af Mon Sep 17 00:00:00 2001
From: Jonathan Heek
Date: Wed, 24 Feb 2021 16:09:22 +0000
Subject: [PATCH 07/12] Fix issues & add custom filter example

---
 docs/howtos/extracting_intermediates.rst | 13 ++++++++++++-
 flax/linen/module.py                     | 13 +++++++------
 2 files changed, 19 insertions(+), 7 deletions(-)

diff --git a/docs/howtos/extracting_intermediates.rst b/docs/howtos/extracting_intermediates.rst
index 4d50fec0c..9da048168 100644
--- a/docs/howtos/extracting_intermediates.rst
+++ b/docs/howtos/extracting_intermediates.rst
@@ -71,7 +71,7 @@ The CNN can be augmented with calls to ``sow`` to store intermediates as follows:
     @nn.compact
     def __call__(self, x):
       x = nn.Conv(features=32, kernel_size=(3, 3))(x)
-
+
       x = nn.relu(x)
@@ -219,6 +219,17 @@ In the following code example we check if any intermediate activations are non-finite (NaN or infinite):
   all_finite = all(jax.tree_leaves(is_finite))
   assert all_finite, "non-finite intermediate detected!"

+By default only the intermediates of `__call__` methods are collected.
+Alternatively, you can pass a custom filter based on the ``Module`` instance and the method name.
+
+.. testcode::
+
+   filter_Dense = lambda mdl, method_name: isinstance(mdl, nn.Dense)
+   filter_encodings = lambda mdl, method_name: method_name == "encode"
+
+   y, state = CNN().apply(variables, batch, capture_intermediates=filter_Dense, mutable=["intermediates"])
+   dense_intermediates = state['intermediates']
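+
+Both criteria can also be combined in a single filter. A minimal sketch
+(``filter_dense_calls`` is just an illustrative name):
+
+.. testcode::
+
+   # Capture only the `__call__` outputs of Dense submodules.
+   filter_dense_calls = lambda mdl, method_name: (
+       isinstance(mdl, nn.Dense) and method_name == '__call__')
+
+   y, state = CNN().apply(variables, batch,
+                          capture_intermediates=filter_dense_calls,
+                          mutable=["intermediates"])
+   dense_call_intermediates = state['intermediates']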

 Use ``nn.Sequential``
 ---------------------
diff --git a/flax/linen/module.py b/flax/linen/module.py
index b670611ed..b893245d5 100644
--- a/flax/linen/module.py
+++ b/flax/linen/module.py
@@ -115,6 +115,7 @@ def module_stack(self):

   @property
   def capture_stack(self):
+    """Keeps track of the active capture_intermediates filter functions."""
     if not hasattr(self._thread_data, 'capture_stack'):
       self._thread_data.capture_stack = []
     return self._thread_data.capture_stack
@@ -360,6 +361,9 @@ def reimport(self, other):
   '__reduce__', '__reduce_ex__', '__copy__', '__deepcopy__')


+_caches = weakref.WeakKeyDictionary()
+
+
 tuple_reduce = lambda xs, x: xs + (x,)
 tuple_init = lambda: ()
@@ -371,9 +375,6 @@ def reimport(self, other):
 # Base Module definition.
 # -----------------------------------------------------------------------------


-_caches = weakref.WeakKeyDictionary()
-
-
 class Module:
   """Base class for all neural network modules. Layers and models should subclass this class.
@@ -796,13 +797,13 @@ def apply(self, variables: VariableDict, *args, rngs: RNGSequences = None,
     """Applies a module method to variables and returns output and modified variables.

-    Note that `method` should be set if one would like to call `apply` on a
+    Note that `method` should be set if one would like to call `apply` on a
     different class method than `__call__`. For instance, suppose a
     Transformer module has a method called `encode`, then the following
     calls `apply` on that method::

       model = models.Transformer(config)
-      encoded = model.apply({'params': params}, inputs, method=model.encode)
+      encoded = model.apply({'params': params}, inputs, method=model.encode)
@@ -817,7 +818,7 @@ def apply(self, variables: VariableDict, *args, rngs: RNGSequences = None,
       capture_intermediates: If `True`, captures intermediate return values
         of all Modules inside the "intermediates" collection. By default only
-        the return value of the `__call__` method is stored. A function can
+        the return values of all `__call__` methods are stored. A function can
         be passed to change the filter behavior. The filter function takes

From 73568b3e5823e329dc0ad584a8de6d6bdbe802ff Mon Sep 17 00:00:00 2001
From: Jonathan Heek
Date: Wed, 24 Feb 2021 16:10:19 +0000
Subject: [PATCH 08/12] Remove redundant doc test

---
 .github/workflows/build.yml | 1 -
 1 file changed, 1 deletion(-)

diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 3828a6c62..c702a76b9 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -36,7 +36,6 @@ jobs:
         sphinx-build -M doctest docs docs/_build
     - name: Build documentation
       run: |
-        sphinx-build -M doctest docs docs/_build
         sphinx-build -M html docs docs/_build
     - name: Test with pytest and generate coverage report
       run: |

From 0ad966390d7045debc49ac7840ca4b5f09888b48 Mon Sep 17 00:00:00 2001
From: Jonathan Heek
Date: Wed, 24 Feb 2021 16:12:59 +0000
Subject: [PATCH 09/12] Remove redundant setup in Sequential

---
 docs/howtos/extracting_intermediates.rst | 9 ++-------
 1 file changed, 2 insertions(+), 7 deletions(-)

diff --git a/docs/howtos/extracting_intermediates.rst b/docs/howtos/extracting_intermediates.rst
index 9da048168..5f1695095 100644
--- a/docs/howtos/extracting_intermediates.rst
+++ b/docs/howtos/extracting_intermediates.rst
@@ -243,12 +243,7 @@ your model more explicitly.
 .. testcode::

   class Sequential(nn.Module):
-    submodules: Sequence[nn.Module]
-
-    def setup(self):
-      # Bind layers to `self` -- `__setattr__` gives submodules
-      # names and connects their variables through `self.variables`
-      self.layers = self.submodules
+    layers: Sequence[nn.Module]

     def __call__(self, x):
       for layer in self.layers:
         x = layer(x)
       return x
@@ -277,7 +272,7 @@ your model more explicitly.

   @jax.jit
   def features(params, x):
-    return Sequential(SeqCNN().submodules[0:7]).apply({"params": params}, x)
+    return Sequential(SeqCNN().layers[0:7]).apply({"params": params}, x)

   params = init(jax.random.PRNGKey(0), batch)
   features(params, batch)

From 04e1162f8a30d6bba9247ec7d6ba515e77c8c14f Mon Sep 17 00:00:00 2001
From: Jonathan Heek
Date: Wed, 24 Feb 2021 16:23:35 +0000
Subject: [PATCH 10/12] Improve howto

---
 docs/howtos/extracting_intermediates.rst | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/docs/howtos/extracting_intermediates.rst b/docs/howtos/extracting_intermediates.rst
index 5f1695095..d1b57276d 100644
--- a/docs/howtos/extracting_intermediates.rst
+++ b/docs/howtos/extracting_intermediates.rst
@@ -219,7 +219,7 @@ In the following code example we check if any intermediate activations are non-finite (NaN or infinite):
   variables = init(jax.random.PRNGKey(0), batch)
   y, is_finite = predict(variables, batch)
   all_finite = all(jax.tree_leaves(is_finite))
   assert all_finite, "non-finite intermediate detected!"

-By default only the intermediates of `__call__` methods are collected.
+By default only the intermediates of ``__call__`` methods are collected.
 Alternatively, you can pass a custom filter based on the ``Module`` instance and the method name.

 .. testcode::
@@ -231,13 +231,13 @@ Alternatively, you can pass a custom filter based on the ``Module`` instance and
   dense_intermediates = state['intermediates']

-Use ``nn.Sequential``
+Use ``Sequential``
 ---------------------

 You could also define ``CNN`` using a simple implementation of a ``Sequential``
 combinator (this is quite common in more stateful approaches). This may be useful
 for very simple models and gives you arbitrary model
 surgery. But it can be very limiting -- if you even want to add one conditional, you are
-forced to refactor away from ``nn.Sequential`` and structure
+forced to refactor away from ``Sequential`` and structure
 your model more explicitly.

From 0cf99e6492b26d19ec73d95e0ce919ac452fd32e Mon Sep 17 00:00:00 2001
From: Jonathan Heek
Date: Thu, 25 Feb 2021 10:55:44 +0000
Subject: [PATCH 11/12] test custom reduce

---
 tests/linen/module_test.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/tests/linen/module_test.py b/tests/linen/module_test.py
index 943510415..bb4264138 100644
--- a/tests/linen/module_test.py
+++ b/tests/linen/module_test.py
@@ -1027,15 +1027,23 @@ def __call__(self, x):
   def test_sow(self):
     class Foo(nn.Module):
       @nn.compact
-      def __call__(self, x):
-        self.sow('intermediates', 'h', x)
-        self.sow('intermediates', 'h', 2 * x)
+      def __call__(self, x, **sow_args):
+        self.sow('intermediates', 'h', x, **sow_args)
+        self.sow('intermediates', 'h', 2 * x, **sow_args)
         return 3 * x

     _, state = Foo().apply({}, 1, mutable=['intermediates'])
     self.assertEqual(state, {
       'intermediates': {'h': (1, 2)}
     })
+    _, state = Foo().apply(
+        {}, 1,
+        init_fn=lambda: 0,
+        reduce_fn=lambda a, b: a + b,
+        mutable=['intermediates'])
+    self.assertEqual(state, {
+      'intermediates': {'h': 3}
+    })
     self.assertEqual(Foo().apply({}, 1), 3)

   def test_capture_intermediates(self):

From ed93615cadd5b911af5f96015283b142b38954be Mon Sep 17 00:00:00 2001
From: Jonathan Heek
Date: Thu, 25 Feb 2021 11:05:28 +0000
Subject: [PATCH 12/12] Update sow docstring with custom reducer

---
 flax/linen/module.py | 23 ++++++++++++++++++-----
 1 file changed, 18 insertions(+), 5 deletions(-)

diff --git a/flax/linen/module.py b/flax/linen/module.py
index b893245d5..a804e4c72 100644
--- a/flax/linen/module.py
+++ b/flax/linen/module.py
@@ -916,10 +916,6 @@ def sow(self, col: str, name: str, value: T,
     Collections can be used to collect intermediate values without
     the overhead of explicitly passing a container through each Module call.

-    By default the value is stored in a tuple and each stored value
-    is appended at the end. This way all intermediates can be tracked when
-    the same module is called multiple times.
-
     If the target collection is not mutable `sow` behaves like a no-op
     and returns `False`.

@@ -931,9 +927,26 @@ def __call__(self, x):
           h = nn.Dense(4)(x)
           self.sow('intermediates', 'h', h)
           return nn.Dense(2)(h)
-      y, state = Foo().apply(variables, x, mutable=['intermediates'])
+
+      y, state = Foo().apply(variables, x, mutable=['intermediates'])
       print(state['intermediates']) # {'h': (...,)}
+
+    By default the values are stored in a tuple and each stored value
+    is appended at the end. This way all intermediates can be tracked when
+    the same module is called multiple times.
+    Alternatively, a custom
+    init/reduce function can be passed::

+      class Foo(nn.Module):
+        @nn.compact
+        def __call__(self, x):
+          init_fn = lambda: 0
+          reduce_fn = lambda a, b: a + b
+          self.sow('intermediates', 'h', x,
+                   init_fn=init_fn, reduce_fn=reduce_fn)
+          self.sow('intermediates', 'h', x * 2,
+                   init_fn=init_fn, reduce_fn=reduce_fn)
+          return x
+
+      y, state = Foo().apply({}, 1, mutable=['intermediates'])
+      print(state['intermediates']) # ==> {'h': 3}

     Args:
       col: the variable collection.