diff --git a/tests/pallas/tpu_pallas_interpret_distributed_test.py b/tests/pallas/tpu_pallas_interpret_distributed_test.py index a029b8094aa1..4e4776736cf1 100644 --- a/tests/pallas/tpu_pallas_interpret_distributed_test.py +++ b/tests/pallas/tpu_pallas_interpret_distributed_test.py @@ -38,6 +38,8 @@ P = jax.sharding.PartitionSpec +# TODO(jburnim): Figure out how to safely run different instance of TPU +# interpret mode in parallel, and then remove this decorator. @jtu.thread_unsafe_test_class() class InterpretDistributedTest(jtu.JaxTestCase): def setUp(self): diff --git a/tests/pallas/tpu_pallas_interpret_test.py b/tests/pallas/tpu_pallas_interpret_test.py index 9d6188cbc0cc..cfbf5d70e212 100644 --- a/tests/pallas/tpu_pallas_interpret_test.py +++ b/tests/pallas/tpu_pallas_interpret_test.py @@ -82,6 +82,8 @@ def grid_points(self): return self._grid_points +# TODO(jburnim): Figure out how to safely run different instance of TPU +# interpret mode in parallel, and then remove this decorator. @jtu.thread_unsafe_test_class() class InterpretTest(jtu.JaxTestCase):