Skip to content

Commit 5dba0cf

Browse files
Merge pull request #28935 from jburnim:jburnim_interpret_test
PiperOrigin-RevId: 761974061
2 parents 125f817 + be0ed4a commit 5dba0cf

2 files changed

Lines changed: 2 additions & 0 deletions

File tree

tests/pallas/tpu_pallas_interpret_distributed_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
P = jax.sharding.PartitionSpec
3939

4040

41+
@jtu.thread_unsafe_test_class()
4142
class InterpretDistributedTest(jtu.JaxTestCase):
4243
def setUp(self):
4344
super().setUp()

tests/pallas/tpu_pallas_interpret_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def grid_points(self):
8282
return self._grid_points
8383

8484

85+
@jtu.thread_unsafe_test_class()
8586
class InterpretTest(jtu.JaxTestCase):
8687

8788
def setUp(self):

0 commit comments

Comments
 (0)