Skip to content

Commit

Permalink
Import tpu_driver after xla_client (jax-ml#3064)
Browse files Browse the repository at this point in the history
This is a workaround until we build a new jaxlib with tensorflow/tensorflow@f462867
  • Loading branch information
skye authored May 12, 2020
1 parent 8008aa9 commit 0d97c3b
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions jax/lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,6 @@ def _check_jaxlib_version():
_check_jaxlib_version()


try:
from jaxlib import tpu_client # pytype: disable=import-error
except:
tpu_client = None
from jaxlib import xla_client
from jaxlib import lapack

Expand All @@ -58,3 +54,8 @@ def _check_jaxlib_version():
from jaxlib import cuda_prng
except ImportError:
cuda_prng = None

try:
from jaxlib import tpu_client # pytype: disable=import-error
except:
tpu_client = None

0 comments on commit 0d97c3b

Please sign in to comment.