|
48 | 48 | config.parse_flags_with_absl()
|
49 | 49 |
|
50 | 50 |
|
51 |
| -@unittest.skip("Failing after jax 0.6.1 release") |
52 | 51 | class Jax2TfTest(tf_test_util.JaxToTfTestCase):
|
53 | 52 |
|
54 | 53 | def setUp(self):
|
55 | 54 | super().setUp()
|
| 55 | + versions = tf.version.VERSION.split(".") |
| 56 | + if versions < ["2", "19", "1"]: |
| 57 | + # StableHLO changed on March 18th, 2025 ,to version 1.10.0, and this |
| 58 | + # introduces ops like vhlo_sine_v2. These ops require a TF version |
| 59 | + # released after this date. |
| 60 | + self.skipTest("Need version of TensorFlow at least 2.19.1") |
| 61 | + |
56 | 62 | # One TF device of each device_type
|
57 | 63 | self.tf_devices = []
|
58 | 64 | for tf_device in (tf.config.list_logical_devices("TPU") +
|
@@ -1783,11 +1789,17 @@ def func():
|
1783 | 1789 | jax_result = func()
|
1784 | 1790 | self.assertEqual(tf_result, jax_result)
|
1785 | 1791 |
|
1786 |
| -@unittest.skip("Failing after jax 0.6.1 release") |
| 1792 | + |
1787 | 1793 | class Jax2TfVersioningTest(tf_test_util.JaxToTfTestCase):
|
1788 | 1794 | # Use a separate test case with the default jax_serialization_version
|
1789 | 1795 | def setUp(self):
|
1790 | 1796 | self.use_max_serialization_version = False
|
| 1797 | + versions = tf.version.VERSION.split(".") |
| 1798 | + if versions < ["2", "19", "1"]: |
| 1799 | + # StableHLO changed on March 18th, 2025 ,to version 1.10.0, and this |
| 1800 | + # introduces ops like vhlo_sine_v2. These ops require a TF version |
| 1801 | + # released after this date. |
| 1802 | + self.skipTest("Need version of TensorFlow at least 2.19.1") |
1791 | 1803 | super().setUp()
|
1792 | 1804 |
|
1793 | 1805 | @jtu.ignore_warning(
|
|
0 commit comments