From d8accc9f7bbdc1c24d560adf97efcd387821957c Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Mon, 6 Jan 2025 11:44:11 -0500 Subject: [PATCH] [nnx] test fori_loop output --- tests/nnx/transforms_test.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py index 736da9acf0..bfa461be39 100644 --- a/tests/nnx/transforms_test.py +++ b/tests/nnx/transforms_test.py @@ -2985,6 +2985,18 @@ def loop_fn(inputs): nnx.while_loop(lambda input: input[-1] > 0, while_loop_fn, (a, b, 2)) nnx.fori_loop(0, 2, fori_loop_fn, (a, b)) + def test_fori_output(self): + model = nnx.Linear(2, 2, rngs=nnx.Rngs(jax.random.PRNGKey(0))) + model2 = nnx.Linear(2, 2, rngs=nnx.Rngs(jax.random.PRNGKey(1))) + + def f(i, x): + return x + + model_out, model2_out = nnx.fori_loop(0, 10, f, (model, model2)) + + self.assertIs(model, model_out) + self.assertIs(model2, model2_out) + class TestSplitMergeInputs(absltest.TestCase): def test_split_inputs(self):