Skip to content

Commit

Permalink
Renamed experimental/jax_to_tf to experimental/jax2tf (jax-ml#3404)
Browse files Browse the repository at this point in the history
* Renamed experimental/jax_to_tf to experimental/jax2tf

* Leave a trampoline behind, for backwards compatibility
  • Loading branch information
gnecula authored Jun 11, 2020
1 parent 832431d commit 27906ce
Show file tree
Hide file tree
Showing 16 changed files with 1,021 additions and 982 deletions.
12 changes: 6 additions & 6 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ jobs:
- python: "3.7"
env: JAX_ENABLE_X64=0 JAX_ONLY_DOCUMENTATION=true
- python: "3.7"
# TODO: enable x64 for JAX_TO_TF
env: JAX_ENABLE_X64=0 JAX_TO_TF=true
# TODO: enable x64 for JAX2TF
env: JAX_ENABLE_X64=0 JAX2TF=true

before_install:
- wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh;
Expand Down Expand Up @@ -44,8 +44,8 @@ install:
conda install --yes sphinx sphinx_rtd_theme nbsphinx sphinx-autodoc-typehints jupyter_client matplotlib;
pip install sklearn;
fi
# jax_to_tf needs some fixes that are not in tensorflow==2.2.0
- if [ "$JAX_TO_TF" = true ] ;then
# jax2tf needs some fixes that are not in tensorflow==2.2.0
- if [ "$JAX2TF" = true ] ;then
pip install tf-nightly==2.3.0.dev20200525 ;
fi
script:
Expand All @@ -58,8 +58,8 @@ script:
time mypy --config-file=mypy.ini jax &&
echo "===== Checking lint with flake8 ====" &&
time flake8 . ;
elif [ "$JAX_TO_TF" = true ]; then
pytest jax/experimental/jax_to_tf/tests ;
elif [ "$JAX2TF" = true ]; then
pytest jax/experimental/jax2tf/tests ;
else
pytest tests examples ;
fi
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,18 @@ Most commonly people want to use this tool in order to:
### Converting basic functions.

As a rule of thumb, if you can `jax.jit` your function then you should be able
to use `jax_to_tf.convert`:
to use `jax2tf.convert`:

```python
import jax
from jax.experimental import jax_to_tf
from jax.experimental import jax2tf
def some_jax_function(x, y, z):
return jax.something(x, y, z)

# tf_ops.from_jax is a higher order function that returns a wrapped function with
# the same signature as your input function but accepting TensorFlow tensors (or
# variables) as input.
tf_version = jax_to_tf.convert(some_jax_function)
tf_version = jax2tf.convert(some_jax_function)

# For example you can call tf_version with some TensorFlow tensors:
tf_version(tf_x, tf_y, tf_z)
Expand All @@ -59,7 +59,7 @@ is trivial.

```python
f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x)))
f_tf = jax_to_tf.convert(f_jax)
f_tf = jax2tf.convert(f_jax)

# You can save the model just like you would with any other TensorFlow function:
my_model = tf.Module()
Expand Down
16 changes: 16 additions & 0 deletions jax/experimental/jax2tf/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa: F401
from .jax2tf import enable_jit, convert
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from jax.config import config
config.config_with_absl()

from jax.experimental import jax_to_tf
from jax.experimental import jax2tf

TRAIN_EXAMPLES = 60000
BATCH_SIZE = 1000
Expand Down Expand Up @@ -86,7 +86,7 @@ def train_epoch_jax(params, train_dataset):


@tf.function
@jax_to_tf.convert
@jax2tf.convert
def accuracy_fn_tf(params, batch):
logits = predict_jax(params, batch["image"])
return jnp.mean(jnp.argmax(logits, axis=-1) == batch["label"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from absl import app
import jax
from jax.experimental import jax_to_tf
from jax.experimental import jax2tf

import tensorflow as tf

Expand Down Expand Up @@ -47,7 +47,7 @@ class StaxModule(tf.Module):

def __init__(self, apply_fn, params, name=None):
super().__init__(name=name)
self.apply_fn = jax_to_tf.convert(apply_fn)
self.apply_fn = jax2tf.convert(apply_fn)
self.params = tf.nest.map_structure(tf.Variable, params)

@tf.function(autograph=False)
Expand Down
Loading

0 comments on commit 27906ce

Please sign in to comment.