From 0d97c3ba0193430dff21293a85cb93266e17a2de Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Tue, 12 May 2020 11:05:03 -0700 Subject: [PATCH] Import tpu_driver after xla_client (#3064) This is a workaround until we build a new jaxlib with https://github.com/tensorflow/tensorflow/commit/f4628678066c72309d3fd121af1aaf54d9905ca3 --- jax/lib/__init__.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/jax/lib/__init__.py b/jax/lib/__init__.py index 924af52823d0..951873010b69 100644 --- a/jax/lib/__init__.py +++ b/jax/lib/__init__.py @@ -45,10 +45,6 @@ def _check_jaxlib_version(): _check_jaxlib_version() -try: - from jaxlib import tpu_client # pytype: disable=import-error -except: - tpu_client = None from jaxlib import xla_client from jaxlib import lapack @@ -58,3 +54,8 @@ def _check_jaxlib_version(): from jaxlib import cuda_prng except ImportError: cuda_prng = None + +try: + from jaxlib import tpu_client # pytype: disable=import-error +except: + tpu_client = None