From 0ea22b7e191a8cfabefeae18834ced2eec80670b Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 7 May 2020 17:24:19 -0400 Subject: [PATCH] Use a whitelist to restrict visibility in top-level jax namespace. (#2982) * Use a whitelist to restrict visibility in top-level jax namespace. The goal of this change is to capture the way the world is (i.e., not break users), and separately we will work on fixing users to avoid accidentally-exported APIs. --- docs/notebooks/How_JAX_primitives_work.ipynb | 6 +- jax/__init__.py | 76 ++++++++++++++++++-- tests/polynomial_test.py | 4 +- 3 files changed, 76 insertions(+), 10 deletions(-) diff --git a/docs/notebooks/How_JAX_primitives_work.ipynb b/docs/notebooks/How_JAX_primitives_work.ipynb index c605336969c4..54a51ee665a6 100644 --- a/docs/notebooks/How_JAX_primitives_work.ipynb +++ b/docs/notebooks/How_JAX_primitives_work.ipynb @@ -700,7 +700,7 @@ "\n", "# Now we register the XLA compilation rule with JAX\n", "# TODO: for GPU? and TPU?\n", - "from jax import xla\n", + "from jax.interpreters import xla\n", "xla.backend_specific_translations['cpu'][multiply_add_p] = multiply_add_xla_translation" ], "execution_count": 0, @@ -876,7 +876,7 @@ "colab": {} }, "source": [ - "from jax import ad\n", + "from jax.interpreters import ad\n", "\n", "\n", "@trace(\"multiply_add_value_and_jvp\")\n", @@ -1529,7 +1529,7 @@ "colab": {} }, "source": [ - "from jax import batching\n", + "from jax.interpreters import batching\n", "\n", "\n", "@trace(\"multiply_add_batch\")\n", diff --git a/jax/__init__.py b/jax/__init__.py index e846351e58e1..9eb3dad2756b 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -12,11 +12,77 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1') - from jax.version import __version__ -from jax.api import * +from .config import config +from .api import ( + ad, # TODO(phawkins): update users to avoid this. + argnums_partial, # TODO(phawkins): update Haiku to not use this. + checkpoint, + curry, # TODO(phawkins): update users to avoid this. + custom_gradient, + custom_jvp, + custom_vjp, + custom_transforms, + defjvp, + defjvp_all, + defvjp, + defvjp_all, + device_count, + device_get, + device_put, + devices, + disable_jit, + eval_shape, + flatten_fun_nokwargs, # TODO(phawkins): update users to avoid this. + grad, + hessian, + host_count, + host_id, + host_ids, + jacobian, + jacfwd, + jacrev, + jit, + jvp, + local_device_count, + local_devices, + linearize, + make_jaxpr, + mask, + partial, # TODO(phawkins): update callers to use functools.partial. + pmap, + pxla, # TODO(phawkins): update users to avoid this. + remat, + shapecheck, + ShapedArray, + ShapeDtypeStruct, + soft_pmap, + # TODO(phawkins): hide tree* functions from jax, update callers to use + # jax.tree_util. + treedef_is_leaf, + tree_flatten, + tree_leaves, + tree_map, + tree_multimap, + tree_structure, + tree_transpose, + tree_unflatten, + value_and_grad, + vjp, + vmap, + xla, # TODO(phawkins): update users to avoid this. + xla_computation, +) from jax import nn from jax import random -import jax.numpy as np # side-effecting import sets up operator overloads + +# TODO(phawkins): remove the `np` name. +import jax.numpy as np # side-effecting import sets up operator overloads + + +def _init(): + import os + os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1') + +_init() +del _init diff --git a/tests/polynomial_test.py b/tests/polynomial_test.py index ba298eaa0b9a..30ebdea2e0d4 100644 --- a/tests/polynomial_test.py +++ b/tests/polynomial_test.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. - +from functools import partial import numpy as np from absl.testing import absltest from absl.testing import parameterized from jax import numpy as jnp -from jax import test_util as jtu, jit, partial +from jax import test_util as jtu, jit from jax.config import config config.parse_flags_with_absl()