Skip to content

Commit f7f0757

Browse files
committed
[jax2tf] Refine the disabling of jax2tf_test, for versions <= 2.19.1
Previously we disabled the jax2tf_test for older versions of TF. Re-enable for 2.19.1 and higher.
1 parent ba8120d commit f7f0757

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

jax/experimental/jax2tf/tests/jax2tf_test.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,17 @@
4848
config.parse_flags_with_absl()
4949

5050

51-
@unittest.skip("Failing after jax 0.6.1 release")
5251
class Jax2TfTest(tf_test_util.JaxToTfTestCase):
5352

5453
def setUp(self):
5554
super().setUp()
55+
versions = tf.version.VERSION.split(".")
56+
if versions < ["2", "19", "1"]:
57+
# StableHLO changed on March 18th, 2025 ,to version 1.10.0, and this
58+
# introduces ops like vhlo_sine_v2. These ops require a TF version
59+
# released after this date.
60+
self.skipTest("Need version of TensorFlow at least 2.19.1")
61+
5662
# One TF device of each device_type
5763
self.tf_devices = []
5864
for tf_device in (tf.config.list_logical_devices("TPU") +
@@ -1783,11 +1789,17 @@ def func():
17831789
jax_result = func()
17841790
self.assertEqual(tf_result, jax_result)
17851791

1786-
@unittest.skip("Failing after jax 0.6.1 release")
1792+
17871793
class Jax2TfVersioningTest(tf_test_util.JaxToTfTestCase):
17881794
# Use a separate test case with the default jax_serialization_version
17891795
def setUp(self):
17901796
self.use_max_serialization_version = False
1797+
versions = tf.version.VERSION.split(".")
1798+
if versions < ["2", "19", "1"]:
1799+
# StableHLO changed on March 18th, 2025 ,to version 1.10.0, and this
1800+
# introduces ops like vhlo_sine_v2. These ops require a TF version
1801+
# released after this date.
1802+
self.skipTest("Need version of TensorFlow at least 2.19.1")
17911803
super().setUp()
17921804

17931805
@jtu.ignore_warning(

0 commit comments

Comments
 (0)