diff --git a/docs/notebooks/How_JAX_primitives_work.ipynb b/docs/notebooks/How_JAX_primitives_work.ipynb index c605336969c4..54a51ee665a6 100644 --- a/docs/notebooks/How_JAX_primitives_work.ipynb +++ b/docs/notebooks/How_JAX_primitives_work.ipynb @@ -700,7 +700,7 @@ "\n", "# Now we register the XLA compilation rule with JAX\n", "# TODO: for GPU? and TPU?\n", - "from jax import xla\n", + "from jax.interpreters import xla\n", "xla.backend_specific_translations['cpu'][multiply_add_p] = multiply_add_xla_translation" ], "execution_count": 0, @@ -876,7 +876,7 @@ "colab": {} }, "source": [ - "from jax import ad\n", + "from jax.interpreters import ad\n", "\n", "\n", "@trace(\"multiply_add_value_and_jvp\")\n", @@ -1529,7 +1529,7 @@ "colab": {} }, "source": [ - "from jax import batching\n", + "from jax.interpreters import batching\n", "\n", "\n", "@trace(\"multiply_add_batch\")\n", diff --git a/jax/__init__.py b/jax/__init__.py index e846351e58e1..9eb3dad2756b 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -12,11 +12,77 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1') - from jax.version import __version__ -from jax.api import * +from .config import config +from .api import ( + ad, # TODO(phawkins): update users to avoid this. + argnums_partial, # TODO(phawkins): update Haiku to not use this. + checkpoint, + curry, # TODO(phawkins): update users to avoid this. + custom_gradient, + custom_jvp, + custom_vjp, + custom_transforms, + defjvp, + defjvp_all, + defvjp, + defvjp_all, + device_count, + device_get, + device_put, + devices, + disable_jit, + eval_shape, + flatten_fun_nokwargs, # TODO(phawkins): update users to avoid this. + grad, + hessian, + host_count, + host_id, + host_ids, + jacobian, + jacfwd, + jacrev, + jit, + jvp, + local_device_count, + local_devices, + linearize, + make_jaxpr, + mask, + partial, # TODO(phawkins): update callers to use functools.partial. + pmap, + pxla, # TODO(phawkins): update users to avoid this. + remat, + shapecheck, + ShapedArray, + ShapeDtypeStruct, + soft_pmap, + # TODO(phawkins): hide tree* functions from jax, update callers to use + # jax.tree_util. + treedef_is_leaf, + tree_flatten, + tree_leaves, + tree_map, + tree_multimap, + tree_structure, + tree_transpose, + tree_unflatten, + value_and_grad, + vjp, + vmap, + xla, # TODO(phawkins): update users to avoid this. + xla_computation, +) from jax import nn from jax import random -import jax.numpy as np # side-effecting import sets up operator overloads + +# TODO(phawkins): remove the `np` name. +import jax.numpy as np # side-effecting import sets up operator overloads + + +def _init(): + import os + os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1') + +_init() +del _init diff --git a/tests/polynomial_test.py b/tests/polynomial_test.py index ba298eaa0b9a..30ebdea2e0d4 100644 --- a/tests/polynomial_test.py +++ b/tests/polynomial_test.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. - +from functools import partial import numpy as np from absl.testing import absltest from absl.testing import parameterized from jax import numpy as jnp -from jax import test_util as jtu, jit, partial +from jax import test_util as jtu, jit from jax.config import config config.parse_flags_with_absl()