Skip to content

Commit

Permalink
Moved shift_right implementation from tfxla to jax_to_tf (jax-ml#3378)
Browse files Browse the repository at this point in the history
* Implemented shift_right without tfxla

These ops don't actually need XLA, they should not depend on tfxla.

* Small fixes
  • Loading branch information
gnecula authored Jun 10, 2020
1 parent e6a08e2 commit b3c348c
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 9 deletions.
41 changes: 37 additions & 4 deletions jax/experimental/jax_to_tf/jax_to_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def __init__(self, trace, val):
self._trace = trace
if not isinstance(val, (tf.Tensor, tf.Variable)):
aval = xla.abstractify(val)
val = tf.convert_to_tensor(np.array(val, aval.dtype), dtype=aval.dtype)
val = tf.convert_to_tensor(np.array(val, aval.dtype), dtype=aval.dtype) # type: ignore[attribute-error]
self.val = val

@property
Expand Down Expand Up @@ -321,12 +321,45 @@ def _rem(lhs, rhs):
tf_impl[lax.max_p] = wrap_binary_op(tf.math.maximum)
tf_impl[lax.min_p] = wrap_binary_op(tf.math.minimum)

# Map from TF signed types to TF unsigned types.
_SIGNED_TO_UNSIGNED_TABLE = {
tf.int8: tf.uint8,
tf.int16: tf.uint16,
tf.int32: tf.uint32,
tf.int64: tf.uint64,
}

# Map from TF unsigned types to TF signed types.
_UNSIGNED_TO_SIGNED_TABLE = {u: s for s, u in _SIGNED_TO_UNSIGNED_TABLE.items()}

# Note: Bitwise operations only yield identical results on unsigned integers!
tf_impl[lax.shift_left_p] = tf.bitwise.left_shift
# pylint: disable=protected-access
tf_impl[lax.shift_right_arithmetic_p] = tfxla._shift_right_arithmetic_helper
tf_impl[lax.shift_right_logical_p] = tfxla._shift_right_logical_helper
def _shift_right_arithmetic(x, y):
if x.dtype.is_unsigned:
assert x.dtype == y.dtype
orig_dtype = x.dtype
signed_dtype = _UNSIGNED_TO_SIGNED_TABLE[orig_dtype]
x = tf.cast(x, signed_dtype)
y = tf.cast(y, signed_dtype)
res = tf.bitwise.right_shift(x, y)
return tf.cast(res, orig_dtype)
else:
return tf.bitwise.right_shift(x, y)
tf_impl[lax.shift_right_arithmetic_p] = _shift_right_arithmetic

def _shift_right_logical(x, y):
if x.dtype.is_unsigned:
return tf.bitwise.right_shift(x, y)
else:
assert x.dtype == y.dtype
orig_dtype = x.dtype
unsigned_dtype = _SIGNED_TO_UNSIGNED_TABLE[orig_dtype]
x = tf.cast(x, unsigned_dtype)
y = tf.cast(y, unsigned_dtype)
res = tf.bitwise.right_shift(x, y)
return tf.cast(res, orig_dtype)
tf_impl[lax.shift_right_logical_p] = _shift_right_logical

tf_impl[lax.shift_left_p] = tf.bitwise.left_shift
tf_impl[lax.not_p] = tf.bitwise.invert

Expand Down
43 changes: 40 additions & 3 deletions jax/experimental/jax_to_tf/tests/primitive_harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,30 @@
from typing import Any, Callable, Dict, Iterable, Optional, NamedTuple, Sequence, Tuple, Union

from absl import testing
from jax import config
from jax import test_util as jtu
from jax import dtypes
from jax import lax

import numpy as np

FLAGS = config.FLAGS

# TODO: these are copied from tests/lax_test.py (make this source of truth)
# Do not run int64 tests unless FLAGS.jax_enable_x64, otherwise we get a
# mix of int32 and int64 operations.
def supported_dtypes(dtypes):
return [t for t in dtypes if t in jtu.supported_dtypes()]
return [t for t in dtypes if
t in jtu.supported_dtypes() and
(FLAGS.jax_enable_x64 or np.dtype(t).itemsize != 8)]

float_dtypes = supported_dtypes([dtypes.bfloat16, np.float16, np.float32,
np.float64])
complex_elem_dtypes = supported_dtypes([np.float32, np.float64])
complex_dtypes = supported_dtypes([np.complex64, np.complex128])
inexact_dtypes = float_dtypes + complex_dtypes
int_dtypes = supported_dtypes([np.int32, np.int64])
uint_dtypes = supported_dtypes([np.uint32, np.uint64])
int_dtypes = supported_dtypes([np.int8, np.int16, np.int32, np.int64])
uint_dtypes = supported_dtypes([np.uint8, np.uint16, np.uint32, np.uint64])
bool_dtypes = [np.bool_]
default_dtypes = float_dtypes + int_dtypes
all_dtypes = float_dtypes + complex_dtypes + int_dtypes + bool_dtypes
Expand Down Expand Up @@ -168,3 +175,33 @@ def parameterized(harness_group: Iterable[Harness],
]
for dtype in [np.float32]
)

shift_inputs = [
(arg, dtype, shift_amount)
for dtype in supported_dtypes(uint_dtypes + int_dtypes)
for arg in [
np.array([-250, -1, 0, 1, 250], dtype=dtype),
]
for shift_amount in [0, 1, 2, 3, 7]
]

lax_shift_left = jtu.cases_from_list(
Harness(f"_dtype={dtype.__name__}_shift_amount={shift_amount}", # type: ignore
lax.shift_left,
[arg, StaticArg(np.array([shift_amount], dtype=dtype))])
for arg, dtype, shift_amount in shift_inputs
)

lax_shift_right_logical = jtu.cases_from_list(
Harness(f"_dtype={dtype.__name__}_shift_amount={shift_amount}", # type: ignore
lax.shift_right_logical,
[arg, StaticArg(np.array([shift_amount], dtype=dtype))])
for arg, dtype, shift_amount in shift_inputs
)

lax_shift_right_arithmetic = jtu.cases_from_list(
Harness(f"_dtype={dtype.__name__}_shift_amount={shift_amount}", # type: ignore
lax.shift_right_arithmetic,
[arg, StaticArg(np.array([shift_amount], dtype=dtype))])
for arg, dtype, shift_amount in shift_inputs
)
19 changes: 17 additions & 2 deletions jax/experimental/jax_to_tf/tests/tf_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,6 @@
lax.bitwise_or,
lax.bitwise_xor,
lax.shift_left,
lax.shift_right_arithmetic,
lax.shift_right_logical,
)

REDUCE = (
Expand Down Expand Up @@ -245,6 +243,23 @@ def test_binary_logical_elementwise(self, f_jax):
self.assertAllClose(r_jax[np.isfinite(r_jax)],
r_tf[np.isfinite(r_tf)], atol=1e-4)

# TODO(necula): combine tests that are identical except for the harness
# wait until we get more experience with using harnesses.
@primitive_harness.parameterized(primitive_harness.lax_shift_left)
def test_shift_left(self, harness):
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
with_function=True)

@primitive_harness.parameterized(primitive_harness.lax_shift_right_logical)
def test_shift_right_logical(self, harness):
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
with_function=True)

@primitive_harness.parameterized(primitive_harness.lax_shift_right_arithmetic)
def test_shift_right_arithmetic(self, harness):
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
with_function=True)

@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"_{f_jax.__name__}",
f_jax=f_jax)
Expand Down

0 comments on commit b3c348c

Please sign in to comment.