Skip to content

Commit

Permalink
Use a whitelist to restrict visibility in top-level jax namespace. (j…
Browse files Browse the repository at this point in the history
…ax-ml#2982)

* Use a whitelist to restrict visibility in top-level jax namespace.

The goal of this change is to capture the way the world is (i.e., not break users), and separately we will work on fixing users to avoid accidentally-exported APIs.
  • Loading branch information
hawkinsp authored May 7, 2020
1 parent 9f04d98 commit 0ea22b7
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 10 deletions.
6 changes: 3 additions & 3 deletions docs/notebooks/How_JAX_primitives_work.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,7 @@
"\n",
"# Now we register the XLA compilation rule with JAX\n",
"# TODO: for GPU? and TPU?\n",
"from jax import xla\n",
"from jax.interpreters import xla\n",
"xla.backend_specific_translations['cpu'][multiply_add_p] = multiply_add_xla_translation"
],
"execution_count": 0,
Expand Down Expand Up @@ -876,7 +876,7 @@
"colab": {}
},
"source": [
"from jax import ad\n",
"from jax.interpreters import ad\n",
"\n",
"\n",
"@trace(\"multiply_add_value_and_jvp\")\n",
Expand Down Expand Up @@ -1529,7 +1529,7 @@
"colab": {}
},
"source": [
"from jax import batching\n",
"from jax.interpreters import batching\n",
"\n",
"\n",
"@trace(\"multiply_add_batch\")\n",
Expand Down
76 changes: 71 additions & 5 deletions jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,77 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1')

from jax.version import __version__
from jax.api import *
from .config import config
from .api import (
ad, # TODO(phawkins): update users to avoid this.
argnums_partial, # TODO(phawkins): update Haiku to not use this.
checkpoint,
curry, # TODO(phawkins): update users to avoid this.
custom_gradient,
custom_jvp,
custom_vjp,
custom_transforms,
defjvp,
defjvp_all,
defvjp,
defvjp_all,
device_count,
device_get,
device_put,
devices,
disable_jit,
eval_shape,
flatten_fun_nokwargs, # TODO(phawkins): update users to avoid this.
grad,
hessian,
host_count,
host_id,
host_ids,
jacobian,
jacfwd,
jacrev,
jit,
jvp,
local_device_count,
local_devices,
linearize,
make_jaxpr,
mask,
partial, # TODO(phawkins): update callers to use functools.partial.
pmap,
pxla, # TODO(phawkins): update users to avoid this.
remat,
shapecheck,
ShapedArray,
ShapeDtypeStruct,
soft_pmap,
# TODO(phawkins): hide tree* functions from jax, update callers to use
# jax.tree_util.
treedef_is_leaf,
tree_flatten,
tree_leaves,
tree_map,
tree_multimap,
tree_structure,
tree_transpose,
tree_unflatten,
value_and_grad,
vjp,
vmap,
xla, # TODO(phawkins): update users to avoid this.
xla_computation,
)
from jax import nn
from jax import random
import jax.numpy as np # side-effecting import sets up operator overloads

# TODO(phawkins): remove the `np` name.
import jax.numpy as np # side-effecting import sets up operator overloads


def _init():
import os
os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1')

_init()
del _init
4 changes: 2 additions & 2 deletions tests/polynomial_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.


from functools import partial
import numpy as np

from absl.testing import absltest
from absl.testing import parameterized

from jax import numpy as jnp
from jax import test_util as jtu, jit, partial
from jax import test_util as jtu, jit

from jax.config import config
config.parse_flags_with_absl()
Expand Down

0 comments on commit 0ea22b7

Please sign in to comment.