diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 09f169548796..edd6e8fc29e3 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -175,7 +175,7 @@ jobs: run: | pip install uv~=0.5.30 uv pip install --system .[minimum-jaxlib] -r build/test-requirements.txt - uv pip install --system --pre tensorflow==2.19.0rc0 + uv pip install --system --pre tensorflow==2.19.0 - name: Run tests env: diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index ece88841fdc5..bde148cb514e 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -48,11 +48,17 @@ config.parse_flags_with_absl() -@unittest.skip("Failing after jax 0.6.1 release") class Jax2TfTest(tf_test_util.JaxToTfTestCase): def setUp(self): super().setUp() + versions = tf.version.VERSION.split(".") + if versions < ["2", "19", "1"]: + # StableHLO changed on March 18th, 2025 ,to version 1.10.0, and this + # introduces ops like vhlo_sine_v2. These ops require a TF version + # released after this date. + self.skipTest("Need version of TensorFlow at least 2.19.1") + # One TF device of each device_type self.tf_devices = [] for tf_device in (tf.config.list_logical_devices("TPU") + @@ -1783,11 +1789,17 @@ def func(): jax_result = func() self.assertEqual(tf_result, jax_result) -@unittest.skip("Failing after jax 0.6.1 release") + class Jax2TfVersioningTest(tf_test_util.JaxToTfTestCase): # Use a separate test case with the default jax_serialization_version def setUp(self): self.use_max_serialization_version = False + versions = tf.version.VERSION.split(".") + if versions < ["2", "19", "1"]: + # StableHLO changed on March 18th, 2025 ,to version 1.10.0, and this + # introduces ops like vhlo_sine_v2. These ops require a TF version + # released after this date. + self.skipTest("Need version of TensorFlow at least 2.19.1") super().setUp() @jtu.ignore_warning(