diff --git a/flax/core/axes_scan.py b/flax/core/axes_scan.py
index c495186d72..bc64221667 100644
--- a/flax/core/axes_scan.py
+++ b/flax/core/axes_scan.py
@@ -13,16 +13,17 @@
 # limitations under the License.
 
 """Wrapper around jax.lax.scan with in_axes/out_axes API."""
+from collections.abc import Callable
 import functools
 from typing import Any, Optional
-from collections.abc import Callable
 
 import jax
-import jax.numpy as jnp
-import numpy as np
-from jax import core, lax
+from jax import core
+from jax import lax
 from jax.extend import linear_util as lu
 from jax.interpreters import partial_eval as pe
+import jax.numpy as jnp
+import numpy as np
 
 ScanAxis = Optional[int]
 
@@ -35,13 +36,14 @@ class _Broadcast:
 
 
 def scan(
-    fn: Callable[..., Any],
-    in_axes: Any,
-    out_axes: Any,
-    length: int | None = None,
-    reverse: bool = False,
-    unroll: int = 1,
-    _split_transpose: bool = False
+  fn: Callable[..., Any],
+  in_axes: Any,
+  out_axes: Any,
+  length: int | None = None,
+  reverse: bool = False,
+  unroll: int = 1,
+  _split_transpose: bool = False,
+  check_constancy_invariants: bool = True,
 ):
   """A wrapper around `jax.lax.scan` with in_axes/out_axes api.
 
@@ -78,6 +80,11 @@ def body_fn(b, c, x):
       iteration of a loop (default: 1).
     _split_transpose: An experimental feature to split the transpose of scan
       into a scan and a map, backed by an experimental Jax lax.scan() feature.
+    check_constancy_invariants: If true, the scan verifies that the broadcast
+      constants are true loop invariants, and it additionally supports
+      broadcast (non-carry) function outputs. This requires an extra jax
+      tracing step, however, so setting it to false can reduce trace time on
+      larger models.
   Returns:
     the function that performs the scan of the form:
     (broadcast_in, carry_in, *args) -> (broadcast_out, carry_out, scan_out).
@@ -114,39 +121,43 @@ def trans(x):
     return jax.tree_util.tree_map(trans, xs)
 
   def scan_fn(broadcast_in, init, *args):
+    # Requires one extra tracing pass to test the invariants: verifies that
+    # the broadcast constants are true loop invariants, and additionally
+    # supports broadcast (non-carry) function outputs.
+
     xs = jax.tree_util.tree_map(transpose_to_front, in_axes, args)
 
     def body_fn(c, xs, init_mode=False):
       # inject constants
       xs = jax.tree_util.tree_map(
-          lambda ax, arg, x: (arg if ax is broadcast else x), in_axes, args, xs
+        lambda ax, arg, x: (arg if ax is broadcast else x), in_axes, args, xs
       )
       broadcast_out, c, ys = fn(broadcast_in, c, *xs)
 
       if init_mode:
         ys = jax.tree_util.tree_map(
-            lambda ax, y: (y if ax is broadcast else ()), out_axes, ys
+          lambda ax, y: (y if ax is broadcast else ()), out_axes, ys
         )
         return broadcast_out, ys
       else:
         ys = jax.tree_util.tree_map(
-            lambda ax, y: (() if ax is broadcast else y), out_axes, ys
+          lambda ax, y: (() if ax is broadcast else y), out_axes, ys
         )
         return c, ys
 
     broadcast_body = functools.partial(body_fn, init_mode=True)
 
     carry_avals = jax.tree_util.tree_map(
-        lambda x: core.ShapedArray(jnp.shape(x), jnp.result_type(x)), init
+      lambda x: core.ShapedArray(jnp.shape(x), jnp.result_type(x)), init
     )
     scan_avals = jax.tree_util.tree_map(
-        lambda x: core.ShapedArray(jnp.shape(x)[1:], jnp.result_type(x)), xs
+      lambda x: core.ShapedArray(jnp.shape(x)[1:], jnp.result_type(x)), xs
     )
     input_avals = (carry_avals, scan_avals)
 
     in_avals, in_tree = jax.tree_util.tree_flatten(input_avals)
     f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(
-        lu.wrap_init(broadcast_body), in_tree
+      lu.wrap_init(broadcast_body), in_tree
     )
     in_pvals = list(map(pe.PartialVal.unknown, in_avals))
     _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
@@ -155,29 +166,63 @@ def body_fn(c, xs, init_mode=False):
     for pv, const in out_pvals:
       if pv is not None:
         raise ValueError(
-            'broadcasted variable has a data dependency on the scan body.'
+          'broadcasted variable has a data dependency on the scan body.'
         )
       out_flat.append(const)
     broadcast_in, constants_out = jax.tree_util.tree_unflatten(
-        out_tree(), out_flat
+      out_tree(), out_flat
     )
 
     if jax.version.__version_info__ > (0, 4, 25):
       c, ys = lax.scan(
-          body_fn, init, xs, length=length, reverse=reverse, unroll=unroll,
-          _split_transpose=_split_transpose
+        body_fn, init, xs, length=length, reverse=reverse, unroll=unroll,
+        _split_transpose=_split_transpose
       )
     else:
       c, ys = lax.scan(
-          body_fn, init, xs, length=length, reverse=reverse, unroll=unroll
+        body_fn, init, xs, length=length, reverse=reverse, unroll=unroll
       )
     ys = jax.tree_util.tree_map(transpose_from_front, out_axes, ys)
     ys = jax.tree_util.tree_map(
-        lambda ax, const, y: (const if ax is broadcast else y),
-        out_axes,
-        constants_out,
-        ys,
+      lambda ax, const, y: (const if ax is broadcast else y),
+      out_axes,
+      constants_out,
+      ys,
     )
     return broadcast_in, c, ys
 
-  return scan_fn
+  def simple_scan_fn(broadcast_in, init, *args):
+    # Saves an extra tracing pass: no verification of constancy, and no
+    # support for broadcast (non-carry) function outputs.
+    xs = jax.tree_util.tree_map(transpose_to_front, in_axes, args)
+
+    if broadcast in jax.tree_util.tree_leaves(out_axes):
+      raise ValueError('nn.scan run with check_constancy_invariants=False '
+                       'does not support broadcast non-carry function '
+                       f'outputs. out_axes was given as {out_axes}')
+
+    def body_fn(c, xs):
+      # inject constants
+      xs = jax.tree_util.tree_map(
+        lambda ax, arg, x: (arg if ax is broadcast else x), in_axes, args, xs
+      )
+      _, c, ys = fn(broadcast_in, c, *xs)
+      return c, ys
+
+    if jax.version.__version_info__ > (0, 4, 25):
+      c, ys = lax.scan(
+        body_fn, init, xs, length=length, reverse=reverse, unroll=unroll,
+        _split_transpose=_split_transpose
+      )
+    else:
+      c, ys = lax.scan(
+        body_fn, init, xs, length=length, reverse=reverse, unroll=unroll
+      )
+    ys = jax.tree_util.tree_map(transpose_from_front, out_axes, ys)
+    return broadcast_in, c, ys
+
+  if check_constancy_invariants:
+    return scan_fn
+  else:
+    return simple_scan_fn
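For orientation, a minimal sketch (not part of the patch) of the core `axes_scan.scan` API that the new flag extends; `body_fn` and all values below are illustrative assumptions:

import jax.numpy as jnp
from flax.core import axes_scan

# fn has signature (broadcast_in, carry, *xs) -> (broadcast_out, carry, ys).
def body_fn(b, c, x):
  # b is loop-invariant, c is the carry, x is one slice of the scanned input.
  return b, c + b * x, c

loop = axes_scan.scan(
  body_fn,
  in_axes=(0,),  # scan the single positional argument over its leading axis
  out_axes=0,
  check_constancy_invariants=False,  # selects the simple_scan_fn path above
)
b_out, c_out, ys = loop(2.0, 0.0, jnp.arange(4.0))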
diff --git a/flax/core/lift.py b/flax/core/lift.py
index f7b7bfb739..98a929be0c 100644
--- a/flax/core/lift.py
+++ b/flax/core/lift.py
@@ -879,6 +879,7 @@ def scan(
   _split_transpose: bool = False,
   data_transform: Callable[..., Any] | None = None,
   metadata_params: dict[Any, Any] = {},
+  check_constancy_invariants: bool = True,
 ) -> Callable[..., Any]:
   """A lifted version of ``jax.lax.scan``.
 
@@ -946,6 +947,11 @@ def body_fn(scope, c, x):
       intended for inline SPMD annotations.
     metadata_params: arguments dict passed to AxisMetadata instances in the
       variable tree.
+    check_constancy_invariants: If true, the scan verifies that the broadcast
+      constants are true loop invariants, and it additionally supports
+      broadcast (non-carry) function outputs. This requires an extra jax
+      tracing step, however, so setting it to false can reduce trace time on
+      larger models.
 
   Returns:
     The scan function with the signature
@@ -1000,7 +1006,8 @@ def find_length(axis, x):
       length=length,
       reverse=reverse,
       unroll=unroll,
-      _split_transpose=_split_transpose
+      _split_transpose=_split_transpose,
+      check_constancy_invariants=check_constancy_invariants,
     )
     def scanned(broadcast_vars, carry, scan_variable_groups, rng_groups, args):
       carry_vars, c = carry
diff --git a/flax/linen/transforms.py b/flax/linen/transforms.py
index 80c44f9946..06a130f10e 100644
--- a/flax/linen/transforms.py
+++ b/flax/linen/transforms.py
@@ -1165,6 +1165,7 @@ def scan(
   metadata_params: Mapping[Any, Any] = {},
   methods=None,
   _split_transpose: bool = False,
+  check_constancy_invariants: bool = True,
 ) -> Target:
   """A lifted version of ``jax.lax.scan``.
 
@@ -1304,6 +1305,11 @@ def scan(
     methods: If ``target`` is a ``Module``, the methods of ``Module`` to scan
       over.
     _split_transpose: An experimental feature to split the transpose of a scan
       into a scan and a map, backed by an experimental Jax lax.scan() feature.
+    check_constancy_invariants: If true, the scan verifies that the broadcast
+      constants are true loop invariants, and it additionally supports
+      broadcast (non-carry) function outputs. This requires an extra jax
+      tracing step, however, so setting it to false can reduce trace time on
+      larger models.
   Returns:
     The scan function with the signature ``(module, carry, *xs) -> (carry,
@@ -1326,6 +1332,7 @@ def scan(
     data_transform=data_transform,
     metadata_params=metadata_params,
     methods=methods,
+    check_constancy_invariants=check_constancy_invariants,
   )
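At the user level the flag is threaded through `nn.scan`. A minimal usage sketch, assuming a hypothetical `Layer` module that is not part of this diff:

import jax
import jax.numpy as jnp
import flax.linen as nn

class Layer(nn.Module):
  @nn.compact
  def __call__(self, carry, x):
    # The carry is threaded through the stacked layers; x passes through.
    carry = nn.Dense(carry.shape[-1])(carry)
    return carry, x

ScanLayers = nn.scan(
  Layer,
  variable_axes={'params': 0},  # stack per-layer params along axis 0
  split_rngs={'params': True},  # give each layer its own init rng
  length=4,
  check_constancy_invariants=False,  # skip the extra verification trace
)

carry = jnp.ones((8,))
xs = jnp.zeros((4, 8))  # leading axis must match length
variables = ScanLayers().init(jax.random.key(0), carry, xs)
carry_out, ys = ScanLayers().apply(variables, carry, xs)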
diff --git a/tests/linen/linen_transforms_test.py b/tests/linen/linen_transforms_test.py
index d5634a011f..8154a2c349 100644
--- a/tests/linen/linen_transforms_test.py
+++ b/tests/linen/linen_transforms_test.py
@@ -2715,6 +2715,53 @@ def __call__(self, x):
     params = foo.init(key, x)
     foo.apply(params, x)
 
+  @parameterized.named_parameters(
+      ('retracing scan', True), ('simple scan', False)
+  )
+  def test_jit_scan_retracing(self, retracing_scan: bool):
+    num_blocks = 4
+    num_patterns = 4
+    features = 4
+    trace_counts = [0, 0]  # [BlockSequence traces, Block traces]
+
+    class Block(nn.Module):
+      def setup(self):
+        self.dense = nn.Dense(features, use_bias=False)
+      @nn.jit
+      def __call__(self, x):
+        nonlocal trace_counts
+        trace_counts[1] += 1
+        return self.dense(x)
+
+    class BlockSequence(nn.Module):
+      def setup(self):
+        self.blocks = [Block() for _ in range(num_blocks)]
+      @nn.jit
+      def __call__(self, carry, inputs):
+        nonlocal trace_counts
+        trace_counts[0] += 1
+        for block in self.blocks:
+          carry = block(carry)
+        return carry, inputs
+
+    class Transformer(nn.Module):
+      retracing_scan: bool = True
+      def setup(self):
+        self.scan = nn.scan(
+          BlockSequence,
+          variable_axes={'params': 0},
+          split_rngs={'params': False},
+          length=num_patterns,
+          check_constancy_invariants=self.retracing_scan,
+        )()
+      def __call__(self, inputs):
+        return self.scan(jnp.zeros_like(inputs), inputs)
+
+    model = Transformer(retracing_scan=retracing_scan)
+    _ = model.init(random.key(0), jnp.ones((num_patterns, features)))
+    self.assertEqual(trace_counts[0], 2 if retracing_scan else 1)
+    self.assertEqual(trace_counts[1], 2 if retracing_scan else 1)
+
 
 if __name__ == '__main__':
   absltest.main()
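The new ValueError path in `simple_scan_fn` is not exercised by the test above; a small sketch of the constraint it enforces, with illustrative values only:

import jax.numpy as jnp
from flax.core import axes_scan

def body_fn(b, c, x):
  return b, c, b  # attempts to emit a broadcast (non-carry) output

loop = axes_scan.scan(
  body_fn,
  in_axes=(0,),
  out_axes=axes_scan.broadcast,      # broadcast output requested...
  check_constancy_invariants=False,  # ...but the verifying trace is disabled
)
# Calling loop(...) raises ValueError: broadcast non-carry function outputs
# require check_constancy_invariants=True.
loop(1.0, 0.0, jnp.arange(3.0))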