From 5d24628ce70722de361bf874240cc0bc171a6267 Mon Sep 17 00:00:00 2001 From: Jacob Burnim Date: Fri, 23 May 2025 09:32:43 -0700 Subject: [PATCH] Add TODO to run TPU interpret mode tests in parallel. PiperOrigin-RevId: 762456274 --- tests/pallas/tpu_pallas_interpret_distributed_test.py | 2 ++ tests/pallas/tpu_pallas_interpret_test.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/tests/pallas/tpu_pallas_interpret_distributed_test.py b/tests/pallas/tpu_pallas_interpret_distributed_test.py index a029b8094aa1..4e4776736cf1 100644 --- a/tests/pallas/tpu_pallas_interpret_distributed_test.py +++ b/tests/pallas/tpu_pallas_interpret_distributed_test.py @@ -38,6 +38,8 @@ P = jax.sharding.PartitionSpec +# TODO(jburnim): Figure out how to safely run different instance of TPU +# interpret mode in parallel, and then remove this decorator. @jtu.thread_unsafe_test_class() class InterpretDistributedTest(jtu.JaxTestCase): def setUp(self): diff --git a/tests/pallas/tpu_pallas_interpret_test.py b/tests/pallas/tpu_pallas_interpret_test.py index 9d6188cbc0cc..cfbf5d70e212 100644 --- a/tests/pallas/tpu_pallas_interpret_test.py +++ b/tests/pallas/tpu_pallas_interpret_test.py @@ -82,6 +82,8 @@ def grid_points(self): return self._grid_points +# TODO(jburnim): Figure out how to safely run different instance of TPU +# interpret mode in parallel, and then remove this decorator. @jtu.thread_unsafe_test_class() class InterpretTest(jtu.JaxTestCase):