From 4f5d6fb9283b3cd17e4cd78118356ed243d09950 Mon Sep 17 00:00:00 2001
From: IvyZX
Date: Thu, 31 Oct 2024 15:23:51 -0700
Subject: [PATCH] Add conditionals `nnx.while_loop` and `nnx.switch`

---
 .../api_reference/flax.nnx/transforms.rst |   4 +-
 flax/nnx/__init__.py                      |   2 +
 flax/nnx/transforms/iteration.py          | 145 +++++++++++++++
 flax/nnx/transforms/transforms.py         |  16 +-
 tests/nnx/transforms_test.py              | 168 +++++++++++++++++-
 5 files changed, 332 insertions(+), 3 deletions(-)

diff --git a/docs_nnx/api_reference/flax.nnx/transforms.rst b/docs_nnx/api_reference/flax.nnx/transforms.rst
index 179098b833..aead2f7841 100644
--- a/docs_nnx/api_reference/flax.nnx/transforms.rst
+++ b/docs_nnx/api_reference/flax.nnx/transforms.rst
@@ -20,5 +20,7 @@ transforms
 .. autofunction:: value_and_grad
 .. autofunction:: vmap
 .. autofunction:: eval_shape
-.. autofunction:: cond
 .. autofunction:: custom_vjp
+.. autofunction:: cond
+.. autofunction:: switch
+.. autofunction:: while_loop
diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py
index f6ef81a9ad..c670cc8556 100644
--- a/flax/nnx/__init__.py
+++ b/flax/nnx/__init__.py
@@ -150,6 +150,8 @@
 from .transforms.iteration import pmap as pmap
 from .transforms.transforms import eval_shape as eval_shape
 from .transforms.transforms import cond as cond
+from .transforms.transforms import switch as switch
+from .transforms.iteration import while_loop as while_loop
 from .transforms.iteration import StateAxes as StateAxes
 from .variablelib import A as A
 from .variablelib import BatchStat as BatchStat
diff --git a/flax/nnx/transforms/iteration.py b/flax/nnx/transforms/iteration.py
index 3aaaa6ee37..466e307e90 100644
--- a/flax/nnx/transforms/iteration.py
+++ b/flax/nnx/transforms/iteration.py
@@ -40,6 +40,7 @@
 M = tp.TypeVar('M', bound=Module)
 MA = tp.TypeVar('MA', bound=Module)
 N = tp.TypeVar('N', bound=Module)
+T = tp.TypeVar('T')
 StrInt = tp.TypeVar('StrInt', str, int)
 AxisName = tp.Hashable
 Leaves = tp.List[Leaf]
@@ -1304,3 +1305,147 @@ def scan_wrapper(*args, **kwargs):
     return out
 
   return scan_wrapper  # type: ignore
+
+
+# -------------------------------
+# while_loop
+# -------------------------------
+
+
+@dataclasses.dataclass(eq=False)
+class WhileLoopCondFn:
+  f: tp.Callable[..., tp.Any]
+
+  def __post_init__(self):
+    functools.update_wrapper(self, self.f)
+
+  def __call__(self, pure_val):
+    val = extract.from_tree(pure_val)
+    out = self.f(val)
+    return out
+
+
+def _add_fake_index_mapping(tree: tp.Any):
+  def per_node_state(ns: extract.NodeStates | tp.Any):
+    global_index_mapping = {}
+    if not isinstance(ns, extract.NodeStates) or not isinstance(
+      ns._graphdef, graph.NodeDef
+    ):
+      return ns
+
+    def per_node_def(nd: graph.NodeDef | tp.Any):
+      if nd.index >= 0:
+        global_index_mapping[nd.index] = nd.index
+      for sub_nd in nd.subgraphs.values():
+        per_node_def(sub_nd)
+      for l in nd.leaves.values():
+        if isinstance(l, graph.NodeRef) and l.index >= 0:
+          global_index_mapping[l.index] = l.index
+      return
+
+    per_node_def(ns._graphdef)
+    return dataclasses.replace(ns, _graphdef=dataclasses.replace(
+      ns._graphdef,
+      index_mapping=FrozenDict(global_index_mapping)
+    ))
+
+  return jax.tree.map(per_node_state, tree,
+                      is_leaf=lambda x: isinstance(x, extract.NodeStates))
+
+
+def _remove_index_mapping(tree: tp.Any):
+  """Remove the fake index_mapping so the input matches that of the output."""
+  def per_node_state(ns: extract.NodeStates | tp.Any):
+    if not isinstance(ns, extract.NodeStates) or not isinstance(
+      ns._graphdef, graph.NodeDef
+    ):
+      return ns
+    assert isinstance(ns._graphdef, graph.NodeDef)
+    return dataclasses.replace(ns, _graphdef=dataclasses.replace(
+      ns._graphdef, index_mapping=None
+    ))
+
+  return jax.tree.map(per_node_state, tree,
+                      is_leaf=lambda x: isinstance(x, extract.NodeStates))
+
+
+@dataclasses.dataclass(eq=False)
+class WhileLoopBodyFn:
+  f: tp.Callable[..., tp.Any]
+
+  def __post_init__(self):
+    functools.update_wrapper(self, self.f)
+
+  @graph.update_context('while_loop_body')
+  def __call__(self, pure_val):
+    # Remove the dummy index mapping that was added outside of the body
+    # function.
+    pure_val_in = _remove_index_mapping(pure_val)
+
+    val = extract.from_tree(pure_val_in, ctxtag='while_loop_body')
+    out = self.f(val)
+    pure_out = extract.to_tree(out, ctxtag='while_loop_body')
+
+    try:
+      jax.tree.map(lambda a, b: None, pure_val, pure_out)
+    except ValueError as e:
+      msg = ("nnx.while_loop requires the body function's input and output "
+             "to have the same reference and pytree structure, but they "
+             "differ. If the mismatch comes from the `index_mapping` field, "
+             "you might have modified the reference structure within the "
+             "body function, which is not allowed. "
+             f"Detail of the mismatch: \n {str(e)}")
+      raise ValueError(msg)
+
+    return pure_out
+
+
+@graph.update_context('while_loop')
+def while_loop(cond_fun: tp.Callable[[T], tp.Any],
+               body_fun: tp.Callable[[T], T],
+               init_val: T) -> T:
+  """NNX transform of `jax.lax.while_loop <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html>`_.
+
+  Caution: for the NNX internal reference tracing mechanism to work, you cannot
+  change the reference structure of `init_val` inside `body_fun`.
+
+  Example::
+
+    >>> import jax
+    >>> from flax import nnx
+    >>> def fwd_fn(input):
+    ...   module, x, count = input
+    ...   return module, module(x), count - 1.0
+
+    >>> module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
+    >>> x = jax.random.normal(jax.random.key(0), (10,))
+    >>> # `module` will be called three times
+    >>> _, y, _ = nnx.while_loop(
+    ...   lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0))
+
+  Args:
+    cond_fun: a function for the continue condition of the while loop, taking
+      a single input of type `T` and outputting a boolean.
+    body_fun: a function that takes an input of type `T` and outputs a `T`.
+      Note that both data and modules of `T` must have the same reference
+      structure between inputs and outputs.
+    init_val: the initial input for `cond_fun` and `body_fun`. Must be of
+      type `T`.
+  """
+
+  pure_init_val = extract.to_tree(init_val, ctxtag='while_loop')
+
+  # Add the expected reference mapping to `pure_init_val` to match
+  # `body_fun`'s output pytree structure, to make JAX's while_loop happy.
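+  # Illustrative note (hypothetical indices): if `init_val` carries a single
+  # module registered at graphdef index 0, `_add_fake_index_mapping` attaches
+  # the identity mapping {0: 0}, mirroring the `index_mapping` that
+  # `extract.to_tree` records on the body function's output.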
+  pure_init_val = _add_fake_index_mapping(pure_init_val)
+  pure_out = jax.lax.while_loop(
+    WhileLoopCondFn(cond_fun),
+    WhileLoopBodyFn(body_fun),
+    pure_init_val,
+  )
+  out = extract.from_tree(pure_out, ctxtag='while_loop')
+  return out
\ No newline at end of file
diff --git a/flax/nnx/transforms/transforms.py b/flax/nnx/transforms/transforms.py
index 558584dd9d..b74dd18c30 100644
--- a/flax/nnx/transforms/transforms.py
+++ b/flax/nnx/transforms/transforms.py
@@ -141,7 +141,7 @@ def _eval_shape_fn(*args, **kwargs):
 
 
 # -------------------------------
-# cond
+# cond and switch
 # -------------------------------
 
 
@@ -160,3 +160,17 @@ def cond(
     *operands,
     **kwargs,
   )
+
+
+@general.split_inputs(ctxtag='switch')
+def switch(
+  index,
+  branches: tp.Sequence[tp.Callable[..., A]],
+  *operands,
+) -> A:
+  """NNX transform of `jax.lax.switch <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.switch.html>`_.
+
+  `index` selects which callable in `branches` to apply to `*operands`; as in
+  `jax.lax.switch`, all branches must accept the same arguments and return
+  outputs with the same pytree structure.
+  """
+  return jax.lax.switch(
+    index,
+    [general.merge_inputs(f, ctxtag='switch') for f in branches],
+    *operands,
+  )
+
diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py
index 5f478c4328..0ca99696c4 100644
--- a/tests/nnx/transforms_test.py
+++ b/tests/nnx/transforms_test.py
@@ -1673,7 +1673,6 @@ def unroll(cell: RNNCell, carry, x) -> tuple[jax.Array, jax.Array]:
 
     x = jnp.ones((16, 10, 20))
     y = rnn_forward(cell, x)
-    print(y.shape)
 
 
 class TestRemat(absltest.TestCase):
@@ -2612,6 +2611,173 @@ def no_nothing(env: Env):
     )
 
 
+class TestSwitch(absltest.TestCase):
+  def test_basic(self):
+    class RoundTable(nnx.Module):
+      def __init__(self):
+        self.next_index = 0
+        self.linear = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
+        self.linear.kernel.value = jnp.identity(10)
+        self.rounds_count = nnx.Variable(jnp.array(0))
+
+      def __call__(self, x):
+        def fn0(m, x):
+          m.rounds_count += 1
+          return m.linear(x)
+        def fn1(m, x):
+          return m.linear(x) * 2
+        def fn2(m, x):
+          m.linear.kernel.value = jnp.zeros((10, 10))
+          return m.linear(x)
+
+        y = nnx.switch(self.next_index, (fn0, fn1, fn2), self, x)
+        self.next_index = (self.next_index + 1) % 3
+        return y
+
+    model = RoundTable()
+    x = jnp.ones((10,))
+    np.testing.assert_array_equal(model(x), x)
+    assert model.rounds_count.value == 1
+    assert model.next_index == 1
+    np.testing.assert_array_equal(model(x), x * 2)
+    assert model.rounds_count.value == 1
+    assert model.next_index == 2
+    np.testing.assert_array_equal(model(x), jnp.zeros((10,)))
+    assert model.rounds_count.value == 1
+    assert model.next_index == 0
+    np.testing.assert_array_equal(model(x), jnp.zeros((10,)))
+    assert model.rounds_count.value == 2
+    assert model.next_index == 1
+
+
+class TestWhileLoop(absltest.TestCase):
+  def test_basic(self):
+    def fwd_fn(input):
+      m, x, c = input
+      y = m(x)
+      return m, y, c - 1.0
+
+    module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
+    module.kernel.value = jnp.identity(10) * 2
+    x = 1e1 * jax.random.normal(jax.random.key(0), (10,))
+
+    _, y, _ = nnx.while_loop(
+      lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0))
+    np.testing.assert_array_equal(y, x * 8)
+
+  def test_multiple_objects(self):
+    def fwd_fn(input):
+      m1, (w2,), x, c = input
+      y = m1(x) @ w2
+      return m1, (w2,), y, c - 1.0
+
+    m1 = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
+    m1.kernel.value = jnp.identity(10) * 2
+    w2 = nnx.Variable(jnp.identity(10) * 0.5)
+    x = 1e1 * jax.random.normal(jax.random.key(0), (10,))
+
+    _, _, y, _ = nnx.while_loop(
+      lambda input: input[-1] > 0, fwd_fn, (m1, (w2,), x, 3.0))
+    np.testing.assert_allclose(y, x)
+
+  def test_nested_module(self):
+    def fwd_fn(input):
+      m, x, c = input
+      y = m(x)
+      return m, y, c - 1.0
+
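+    # Wrapping the Linear in an nnx.Sequential exercises reference tracing
+    # through a nested module graph inside the loop.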
+    module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
+    module.kernel.value = jnp.identity(10) * 2
+    module = nnx.Sequential(module)
+    x = 1e1 * jax.random.normal(jax.random.key(0), (10,))
+
+    _, y, _ = nnx.while_loop(
+      lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0))
+    np.testing.assert_array_equal(y, x * 8)
+
+  def test_shared_module(self):
+    m1 = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
+    m2 = nnx.Linear(10, 10, use_bias=False, rngs=nnx.Rngs(0))
+    m2.kernel = m1.kernel
+    module = nnx.Sequential(m1, m2)
+    self.assertLen(jax.tree.leaves(nnx.state(module)), 2)  # only m1 params
+
+    def fwd_fn(input):
+      m, x, c = input
+      y = m(x)
+      m.layers[0].kernel.value = jnp.zeros_like(m.layers[0].kernel.value)
+      return m, y, c - 1.0
+
+    x = 1e1 * jax.random.normal(jax.random.key(0), (10,))
+    _, y, _ = nnx.while_loop(
+      lambda input: input[-1] > 0, fwd_fn, (module, x, 2.0))
+    self.assertLen(jax.tree.leaves(nnx.state(module)), 2)  # only m1 params
+    np.testing.assert_array_equal(m1.kernel.value, jnp.zeros((10, 10,)))
+    np.testing.assert_array_equal(m2.kernel.value, jnp.zeros((10, 10,)))
+    np.testing.assert_array_equal(y, jnp.zeros((10,)))
+
+  def test_value_changed(self):
+    def fwd_fn(input):
+      m, x, c = input
+      m.kernel.value = jnp.zeros_like(m.kernel.value)
+      y = m(x)
+      return m, y, c - 1.0
+
+    module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
+    x = 1e1 * jax.random.normal(jax.random.key(0), (10,))
+
+    _, y, _ = nnx.while_loop(
+      lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0))
+    np.testing.assert_array_equal(module.kernel.value, jnp.zeros((10, 10,)))
+    np.testing.assert_array_equal(y, jnp.zeros((10,)))
+
+  def test_ref_changed(self):
+    def fwd_fn(input):
+      m, x, c = input
+      y = m(x)
+      m.kernel = nnx.Param(jnp.zeros_like(m.kernel.value))
+      return m, y, c - 1.0
+
+    module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
+    x = 1e1 * jax.random.normal(jax.random.key(0), (10,))
+
+    with self.assertRaises(ValueError):
+      _, y, _ = nnx.while_loop(
+        lambda input: input[-1] > 0, fwd_fn, (module, x, 2.0))
+
+  def test_structure_changed(self):
+    def fwd_fn(input):
+      m, x, c = input
+      m = nnx.Linear(10, 10, use_bias=False, rngs=nnx.Rngs(1))
+      m.kernel.value = jnp.identity(10) * 2
+      y = m(x)
+      return m, y, c - 1.0
+
+    module = nnx.Linear(10, 10, use_bias=True, rngs=nnx.Rngs(0))
+    x = 1e1 * jax.random.normal(jax.random.key(0), (10,))
+
+    with self.assertRaises(ValueError):
+      _, y, _ = nnx.while_loop(
+        lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0))
+
+  def test_repeated_object(self):
+    m = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
+
+    def body_fn(val):
+      count, m, _ = val
+      return count + 1, m, m
+
+    count, m, _ = nnx.while_loop(
+      lambda val: val[0] < 2,
+      body_fn,
+      (0, m, m),
+    )
+
+
 class TestSplitMergeInputs(absltest.TestCase):
   def test_split_inputs(self):
     class StatefulLinear(nnx.Linear):