diff --git a/CHANGELOG.md b/CHANGELOG.md index 38e8265d7..80afd3357 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,8 @@ vNext - Bug Fix `flax.core.apply` and `Module.apply`. Now it returns a tuple containing the output and a frozen empty collection when `mutable` is specified as an empty list. - - - - + - Add `sow` method to `Module` and `capture_intermediates` argument to `Module.apply`. + See [howto](https://flax.readthedocs.io/en/latest/howtos/extracting_intermediates.html) for usage patterns. - - - diff --git a/docs/howtos/extracting_intermediates.rst b/docs/howtos/extracting_intermediates.rst new file mode 100644 index 000000000..d1b57276d --- /dev/null +++ b/docs/howtos/extracting_intermediates.rst @@ -0,0 +1,278 @@ +Extracting intermediate values +============================== + +This pattern will show you how to extract intermediate values from a module. +Let's start with this simple CNN that uses :code:`nn.compact`. + +.. testsetup:: + + import flax.linen as nn + import jax + import jax.numpy as jnp + from flax.core import FrozenDict + from typing import Sequence + + batch = jnp.ones((4, 32, 32, 3)) + + class SowCNN(nn.Module): + @nn.compact + def __call__(self, x): + x = nn.Conv(features=32, kernel_size=(3, 3))(x) + self.sow('intermediates', 'conv1', x) + x = nn.relu(x) + x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) + x = nn.Conv(features=64, kernel_size=(3, 3))(x) + self.sow('intermediates', 'conv2', x) + x = nn.relu(x) + x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) + x = x.reshape((x.shape[0], -1)) # flatten + self.sow('intermediates', 'features', x) + x = nn.Dense(features=256)(x) + self.sow('intermediates', 'conv3', x) + x = nn.relu(x) + x = nn.Dense(features=10)(x) + self.sow('intermediates', 'dense', x) + x = nn.log_softmax(x) + return x + +.. testcode:: + + class CNN(nn.Module): + @nn.compact + def __call__(self, x): + x = nn.Conv(features=32, kernel_size=(3, 3))(x) + x = nn.relu(x) + x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) + x = nn.Conv(features=64, kernel_size=(3, 3))(x) + x = nn.relu(x) + x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) + x = x.reshape((x.shape[0], -1)) # flatten + x = nn.Dense(features=256)(x) + x = nn.relu(x) + x = nn.Dense(features=10)(x) + x = nn.log_softmax(x) + return x + +Because this module uses ``nn.compact``, we don't have direct access to +intermediate values. There are a few ways to expose them: + + +Store intermediate values in a new variable collection +------------------------------------------------------ + +The CNN can be augmented with calls to ``sow`` to store intermediates as following: + + +.. codediff:: + :title_left: Default CNN + :title_right: CNN using sow API + + class CNN(nn.Module): + @nn.compact + def __call__(self, x): + x = nn.Conv(features=32, kernel_size=(3, 3))(x) + + x = nn.relu(x) + x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) + x = nn.Conv(features=64, kernel_size=(3, 3))(x) + + x = nn.relu(x) + x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) + x = x.reshape((x.shape[0], -1)) # flatten + + x = nn.Dense(features=256)(x) + + x = nn.relu(x) + x = nn.Dense(features=10)(x) + + x = nn.log_softmax(x) + return x + --- + class SowCNN(nn.Module): + @nn.compact + def __call__(self, x): + x = nn.Conv(features=32, kernel_size=(3, 3))(x) + self.sow('intermediates', 'conv1', x) #! + x = nn.relu(x) + x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) + x = nn.Conv(features=64, kernel_size=(3, 3))(x) + self.sow('intermediates', 'conv2', x) #! + x = nn.relu(x) + x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) + x = x.reshape((x.shape[0], -1)) # flatten + self.sow('intermediates', 'features', x) #! + x = nn.Dense(features=256)(x) + self.sow('intermediates', 'conv3', x) #! + x = nn.relu(x) + x = nn.Dense(features=10)(x) + self.sow('intermediates', 'dense', x) #! + x = nn.log_softmax(x) + return x + +``sow`` only stores a value if the given variable collection is passed in +as "mutable" in the call to :code:`Module.apply`. + +.. testcode:: + + @jax.jit + def init(key, x): + variables = SowCNN().init(key, x) + return variables + + @jax.jit + def predict(variables, x): + return SowCNN().apply(variables, x) + + @jax.jit + def features(variables, x): + # `mutable=['intermediates']` specified which collections are treated as + # mutable during `apply`. The variables aren't actually mutated, instead + # `apply` returns a second value, which is a dictionary of the modified + # collections. + output, modified_variables = SowCNN().apply(variables, x, mutable=['intermediates']) + return modified_variables['intermediates']['features'] + + variables = init(jax.random.PRNGKey(0), batch) + predict(variables, batch) + features(variables, batch) + +Refactor module into submodules +------------------------------- + +This is a useful pattern for cases where it's clear in what particular +way you want to split your submodules. Any submodule you expose in ``setup`` can be used directly. In the limit, you +can define all submodules in ``setup`` and avoid using ``nn.compact`` altogether. + +.. testcode:: + + class RefactoredCNN(nn.Module): + def setup(self): + self.features = Features() + self.classifier = Classifier() + + def __call__(self, x): + x = self.features(x) + x = self.classifier(x) + return x + + class Features(nn.Module): + @nn.compact + def __call__(self, x): + x = nn.Conv(features=32, kernel_size=(3, 3))(x) + x = nn.relu(x) + x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) + x = nn.Conv(features=64, kernel_size=(3, 3))(x) + x = nn.relu(x) + x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) + x = x.reshape((x.shape[0], -1)) # flatten + return x + + class Classifier(nn.Module): + @nn.compact + def __call__(self, x): + x = nn.Dense(features=256)(x) + x = nn.relu(x) + x = nn.Dense(features=10)(x) + x = nn.log_softmax(x) + return x + + @jax.jit + def init(key, x): + variables = RefactoredCNN().init(key, x) + return variables['params'] + + @jax.jit + def features(params, x): + return RefactoredCNN().apply({"params": params}, x, + method=lambda module, x: module.features(x)) + + params = init(jax.random.PRNGKey(0), batch) + + features(params, batch) + + +Use `capture_intermediates` +--------------------------- + +Linen supports the capture of intermediate return values from submodules automatically without any code changes. +This pattern should be considered the "sledge hammer" approach to capturing intermediates. +As a debugging and inspection tool it is very useful but using the other patterns described in this howto. + +In the following code example we check if any intermediate activations are non-finite (NaN or infinite): + +.. testcode:: + + @jax.jit + def init(key, x): + variables = CNN().init(key, x) + return variables + + @jax.jit + def predict(variables, x): + y, state = CNN().apply(variables, x, capture_intermediates=True, mutable=["intermediates"]) + intermediates = state['intermediates'] + fin = jax.tree_map(lambda xs: jnp.all(jnp.isfinite(xs)), intermediates) + return y, fin + + variables = init(jax.random.PRNGKey(0), batch) + y, is_finite = predict(variables, batch) + all_finite = all(jax.tree_leaves(is_finite)) + assert all_finite, "non finite intermediate detected!" + +By default only the intermediates of ``__call__`` methods are collected. +Alternatively, you can pass a custom filter based on the ``Module`` instance and the method name. + +.. testcode:: + + filter_Dense = lambda mdl, method_name: isinstance(mdl, nn.Dense) + filter_encodings = lambda mdl, method_name: method_name == "encode" + + y, state = CNN().apply(variables, batch, capture_intermediates=filter_Dense, mutable=["intermediates"]) + dense_intermediates = state['intermediates'] + + +Use ``Sequential`` +--------------------- + +You could also define ``CNN`` using a simple implementation of a ``Sequential`` combinator (this is quite common in more stateful approaches). This may be useful +for very simple models and gives you arbitrary model +surgery. But it can be very limiting -- if you even want to add one conditional, you are +forced to refactor away from ``Sequential`` and structure +your model more explicitly. + +.. testcode:: + + class Sequential(nn.Module): + layers: Sequence[nn.Module] + + def __call__(self, x): + for layer in self.layers: + x = layer(x) + return x + + def SeqCNN(): + return Sequential([ + nn.Conv(features=32, kernel_size=(3, 3)), + nn.relu, + lambda x: nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)), + nn.Conv(features=64, kernel_size=(3, 3)), + nn.relu, + lambda x: nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)), + lambda x: x.reshape((x.shape[0], -1)), # flatten + nn.Dense(features=256), + nn.relu, + nn.Dense(features=10), + nn.log_softmax, + ]) + + @jax.jit + def init(key, x): + variables = SeqCNN().init(key, x) + return variables['params'] + + @jax.jit + def features(params, x): + return Sequential(SeqCNN().layers[0:7]).apply({"params": params}, x) + + params = init(jax.random.PRNGKey(0), batch) + features(params, batch) diff --git a/docs/index.rst b/docs/index.rst index cc07f0a19..14cc7d5f2 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -38,6 +38,7 @@ For a quick introduction and short example snippets, see our `README howtos/state_params howtos/ensembling howtos/lr_schedule + howtos/extracting_intermediates .. toctree:: :maxdepth: 1 diff --git a/flax/linen/module.py b/flax/linen/module.py index 262a36e9f..fa796df59 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -33,7 +33,7 @@ from flax import traverse_util from flax import serialization from flax.core import Scope, apply -from flax.core.scope import CollectionFilter, Variable, VariableDict, FrozenVariableDict +from flax.core.scope import CollectionFilter, Variable, VariableDict, FrozenVariableDict, union_filters from flax.core.frozen_dict import FrozenDict, freeze # from .dotgetter import DotGetter @@ -41,9 +41,13 @@ PRNGKey = Any # pylint: disable=invalid-name RNGSequences = Dict[str, PRNGKey] Array = Any # pylint: disable=invalid-name + + T = TypeVar('T') +K = TypeVar('K') _CallableT = TypeVar('_CallableT', bound=Callable) + # pylint: disable=protected-access,attribute-defined-outside-init def _check_omnistaging(): @@ -108,6 +112,13 @@ def module_stack(self): if not hasattr(self._thread_data, 'module_stack'): self._thread_data.module_stack = [None,] return self._thread_data.module_stack + + @property + def capture_stack(self): + """Keeps track of the active capture_intermediates filter functions.""" + if not hasattr(self._thread_data, 'capture_stack'): + self._thread_data.capture_stack = [] + return self._thread_data.capture_stack # The global context _context = _DynamicContext() @@ -263,7 +274,12 @@ def wrapped_module_method(*args, **kwargs): self._state.in_compact_method = True _context.module_stack.append(self) try: - return fun(self, *args, **kwargs) + y = fun(self, *args, **kwargs) + if _context.capture_stack: + filter_fn = _context.capture_stack[-1] + if filter_fn and filter_fn(self, fun.__name__): + self.sow('intermediates', fun.__name__, y) + return y finally: _context.module_stack.pop() if is_compact_method: @@ -347,8 +363,18 @@ def reimport(self, other): _caches = weakref.WeakKeyDictionary() + +tuple_reduce = lambda xs, x: xs + (x,) +tuple_init = lambda: () + + +capture_call_intermediates = lambda _, method_name: method_name == '__call__' + + # Base Module definition. # ----------------------------------------------------------------------------- + + class Module: """Base class for all neural network modules. Layers and models should subclass this class. @@ -767,16 +793,17 @@ def make_rng(self, name: str) -> PRNGKey: def apply(self, variables: VariableDict, *args, rngs: RNGSequences = None, method: Callable[..., Any] = None, mutable: Union[bool, str, Sequence[str]] = False, + capture_intermediates: Union[bool, Callable[['Module', str], bool]] = False, **kwargs) -> Union[Any, Tuple[Any, FrozenVariableDict]]: """Applies a module method to variables and returns output and modified variables. - Note that `method` should be set if one would like to call `apply` on a + Note that `method` should be set if one would like to call `apply` on a different class method than `_call__`. For instance, suppose a Transformer modules has a method called `encode`, then the following calls `apply` on that method:: model = models.Transformer(config) - encoded = model.apply({'params': params}, inputs, method=model.encode) + encoded = model.apply({'params': params}, inputs, method=model.encode) Args: variables: A dictionary containing variables keyed by variable @@ -789,6 +816,12 @@ def apply(self, variables: VariableDict, *args, rngs: RNGSequences = None, treated as mutable: ``bool``: all/no collections are mutable. ``str``: The name of a single mutable collection. ``list``: A list of names of mutable collections. + capture_intermediates: If `True`, captures intermediate return values + of all Modules inside the "intermediates" collection. By default only + the return values of all `__call__` methods are stored. A function can + be passed to change the filter behavior. The filter function takes + the Module instance and method name and returns a bool indicating + whether the output of that method invocation should be stored. Returns: If ``mutable`` is False, returns output. If any collections are mutable, returns ``(output, vars)``, where ``vars`` are is a dict @@ -799,7 +832,16 @@ def apply(self, variables: VariableDict, *args, rngs: RNGSequences = None, else: method = _get_unbound_fn(method) fn = lambda scope: method(self.clone(parent=scope), *args, **kwargs) - return apply(fn, mutable=mutable)(variables, rngs=rngs) + if capture_intermediates is True: + capture_intermediates = capture_call_intermediates + if capture_intermediates: + mutable = union_filters(mutable, 'intermediates') + _context.capture_stack.append(capture_intermediates) + try: + return apply(fn, mutable=mutable)(variables, rngs=rngs) + finally: + _context.capture_stack.pop() + def init_with_output(self, rngs: Union[PRNGKey, RNGSequences], *args, method: Optional[Callable[..., Any]] = None, @@ -848,9 +890,92 @@ def variables(self) -> VariableDict: if self.scope is None: raise ValueError("Can't access variables on unbound modules") return self.scope.variables() + + def get_variable(self, col: str, name: str, default: T = None) -> T: + """Retrieves the value of a Variable. + + Args: + col: the variable collection. + name: the name of the variable. + default: the default value to return if the variable does not exist in + this scope. + + Returns: + The value of the input variable, of the default value if the variable + doesn't exist in this scope. + """ + if self.scope is None: + raise ValueError("Can't access variables on unbound modules") + return self.scope.get_variable(col, name, default) + + def sow(self, col: str, name: str, value: T, + reduce_fn: Callable[[K, T], K] = tuple_reduce, + init_fn: Callable[[], K] = tuple_init) -> bool: + """Stores a value in a collection. + + Collections can be used to collect intermediate values without + the overhead of explicitly passing a container through each Module call. + + If the target collection is not mutable `sow` behaves like a no-op + and returns `False`. + + Example:: + + class Foo(nn.Module): + @nn.compact + def __call__(self, x): + h = nn.Dense(4)(x) + self.sow('intermediates', 'h', h) + return nn.Dense(2)(h) + y, state = Foo.apply(params, x, mutable=['intermediates']) + print(state['intermediates']) # {'h': (...,)} + + By default the values are stored in a tuple and each stored value + is appended at the end. This way all intermediates can be tracked when + the same module is called multiple times. Alternatively, a custom + init/reduce function can be passed:: + + class Foo(nn.Module): + @nn.compact + def __call__(self, x): + init_fn = lambda: 0 + reduce_fn = lambda a, b: a + b + self.sow('intermediates', x, h, + init_fn=init_fn, reduce_fn=reduce_fn) + self.sow('intermediates', x * 2, h, + init_fn=init_fn, reduce_fn=reduce_fn) + return x + y, state = Foo.apply(params, 1, mutable=['intermediates']) + print(state['intermediates']) # ==> {'h': 3} + + Args: + col: the variable collection. + name: the name of the variable. + reduce_fn: The function used to combine the existing value with + the new value the default is to append the value to a tuple. + init_fn: For the first value stored reduce_fn will be passed + the result of `init_fn` together with the value to be stored. + The default is an empty tuple. + + Returns: + `True` if the value has been stored succesfully, `False` otherwise. + """ + if self.scope is None: + raise ValueError("Can't store variables on unbound modules") + if not self.scope.is_mutable_collection(col): + return False + if self.scope.has_variable(col, name): + xs = self.scope.get_variable(col, name) + else: + self.scope.reserve(name) + self._state.children[name] = col + xs = init_fn() + xs = reduce_fn(xs, value) + self.scope.put_variable(col, name, xs) + return True + -T = TypeVar('T') def merge_param(name: str, a: Optional[T], b: Optional[T]) -> T: """Merges construction and call time argument. diff --git a/tests/linen/module_test.py b/tests/linen/module_test.py index 4103c1ef8..bb4264138 100644 --- a/tests/linen/module_test.py +++ b/tests/linen/module_test.py @@ -1024,6 +1024,48 @@ def __call__(self, x): return x+1 self.assertTrue(isinstance(X.Hyper(a=1), X.Hyper)) + def test_sow(self): + class Foo(nn.Module): + @nn.compact + def __call__(self, x, **sow_args): + self.sow('intermediates', 'h', x, **sow_args) + self.sow('intermediates', 'h', 2 * x, **sow_args) + return 3 * x + + _, state = Foo().apply({}, 1, mutable=['intermediates']) + self.assertEqual(state, { + 'intermediates': {'h': (1, 2)} + }) + _, state = Foo().apply( + {}, 1, + init_fn=lambda: 0, + reduce_fn=lambda a, b: a + b, + mutable=['intermediates']) + self.assertEqual(state, { + 'intermediates': {'h': 3} + }) + self.assertEqual(Foo().apply({}, 1), 3) + + def test_capture_intermediates(self): + class Bar(nn.Module): + def test(self, x): + return x + 1 + + class Foo(nn.Module): + @nn.compact + def __call__(self, x): + return Bar().test(x) + 1 + + _, state = Foo().apply({}, 1, capture_intermediates=True) + self.assertEqual(state, { + 'intermediates': {'__call__': (3,)} + }) + fn = lambda mdl, _: isinstance(mdl, Bar) + _, state = Foo().apply({}, 1, capture_intermediates=fn) + self.assertEqual(state, { + 'intermediates': {'Bar_0': {'test': (2,)}} + }) + if __name__ == '__main__': absltest.main()