Skip to content

Commit 1b6dc93

Browse files
jburnimGoogle-ML-Automation
authored andcommitted
Add TODO to run TPU interpret mode tests in parallel.
PiperOrigin-RevId: 762143977
1 parent c2c55ae commit 1b6dc93

File tree

2 files changed

+4
-0
lines changed

2 files changed

+4
-0
lines changed

tests/pallas/tpu_pallas_interpret_distributed_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
P = jax.sharding.PartitionSpec
3939

4040

41+
# TODO(jburnim): Figure out how to safely run different instance of TPU
42+
# interpret mode in parallel, and then remove this decorator.
4143
@jtu.thread_unsafe_test_class()
4244
class InterpretDistributedTest(jtu.JaxTestCase):
4345
def setUp(self):

tests/pallas/tpu_pallas_interpret_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ def grid_points(self):
8282
return self._grid_points
8383

8484

85+
# TODO(jburnim): Figure out how to safely run different instance of TPU
86+
# interpret mode in parallel, and then remove this decorator.
8587
@jtu.thread_unsafe_test_class()
8688
class InterpretTest(jtu.JaxTestCase):
8789

0 commit comments

Comments
 (0)