Skip to content

Commit

Permalink
[nnx] test fori_loop output
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Jan 6, 2025
1 parent 53bde74 commit d8accc9
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions tests/nnx/transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2985,6 +2985,18 @@ def loop_fn(inputs):
nnx.while_loop(lambda input: input[-1] > 0, while_loop_fn, (a, b, 2))
nnx.fori_loop(0, 2, fori_loop_fn, (a, b))

def test_fori_output(self):
model = nnx.Linear(2, 2, rngs=nnx.Rngs(jax.random.PRNGKey(0)))
model2 = nnx.Linear(2, 2, rngs=nnx.Rngs(jax.random.PRNGKey(1)))

def f(i, x):
return x

model_out, model2_out = nnx.fori_loop(0, 10, f, (model, model2))

self.assertIs(model, model_out)
self.assertIs(model2, model2_out)


class TestSplitMergeInputs(absltest.TestCase):
def test_split_inputs(self):
Expand Down

0 comments on commit d8accc9

Please sign in to comment.