From f8164dd4345f96b3fcf60665958bcf56eaeaa366 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Mon, 6 Jan 2025 11:44:11 -0500 Subject: [PATCH] [nnx] test fori_loop output --- tests/jax_utils_test.py | 16 +++++++++++++++- tests/nnx/transforms_test.py | 12 ++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/tests/jax_utils_test.py b/tests/jax_utils_test.py index c9cd9b3095..d54262413b 100644 --- a/tests/jax_utils_test.py +++ b/tests/jax_utils_test.py @@ -15,6 +15,8 @@ """Tests for flax.jax_utils.""" from functools import partial +import os +import re from absl.testing import absltest from absl.testing import parameterized @@ -26,9 +28,21 @@ NDEV = 4 +_xla_device_count_flag_regexp = ( + r'[-]{0,2}xla_force_host_platform_device_count=(\d+)?(\s|$)' +) + + +def set_n_cpu_devices(n: int): + xla_flags = os.getenv('XLA_FLAGS', '') + xla_flags = re.sub(_xla_device_count_flag_regexp, '', xla_flags) + os.environ['XLA_FLAGS'] = ' '.join( + [f'--xla_force_host_platform_device_count={n}'] + xla_flags.split() + ) + def setUpModule(): - chex.set_n_cpu_devices(NDEV) + set_n_cpu_devices(NDEV) class PadShardUnpadTest(chex.TestCase): diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py index 736da9acf0..bfa461be39 100644 --- a/tests/nnx/transforms_test.py +++ b/tests/nnx/transforms_test.py @@ -2985,6 +2985,18 @@ def loop_fn(inputs): nnx.while_loop(lambda input: input[-1] > 0, while_loop_fn, (a, b, 2)) nnx.fori_loop(0, 2, fori_loop_fn, (a, b)) + def test_fori_output(self): + model = nnx.Linear(2, 2, rngs=nnx.Rngs(jax.random.PRNGKey(0))) + model2 = nnx.Linear(2, 2, rngs=nnx.Rngs(jax.random.PRNGKey(1))) + + def f(i, x): + return x + + model_out, model2_out = nnx.fori_loop(0, 10, f, (model, model2)) + + self.assertIs(model, model_out) + self.assertIs(model2, model2_out) + class TestSplitMergeInputs(absltest.TestCase): def test_split_inputs(self):