diff --git a/jax/experimental/jax_to_tf/jax_to_tf.py b/jax/experimental/jax_to_tf/jax_to_tf.py index 90bb6ef2f704..be06160a6067 100644 --- a/jax/experimental/jax_to_tf/jax_to_tf.py +++ b/jax/experimental/jax_to_tf/jax_to_tf.py @@ -175,7 +175,7 @@ def __init__(self, trace, val): self._trace = trace if not isinstance(val, (tf.Tensor, tf.Variable)): aval = xla.abstractify(val) - val = tf.convert_to_tensor(np.array(val, aval.dtype), dtype=aval.dtype) + val = tf.convert_to_tensor(np.array(val, aval.dtype), dtype=aval.dtype) # type: ignore[attribute-error] self.val = val @property @@ -321,12 +321,45 @@ def _rem(lhs, rhs): tf_impl[lax.max_p] = wrap_binary_op(tf.math.maximum) tf_impl[lax.min_p] = wrap_binary_op(tf.math.minimum) +# Map from TF signed types to TF unsigned types. +_SIGNED_TO_UNSIGNED_TABLE = { + tf.int8: tf.uint8, + tf.int16: tf.uint16, + tf.int32: tf.uint32, + tf.int64: tf.uint64, +} + +# Map from TF unsigned types to TF signed types. +_UNSIGNED_TO_SIGNED_TABLE = {u: s for s, u in _SIGNED_TO_UNSIGNED_TABLE.items()} # Note: Bitwise operations only yield identical results on unsigned integers! -tf_impl[lax.shift_left_p] = tf.bitwise.left_shift # pylint: disable=protected-access -tf_impl[lax.shift_right_arithmetic_p] = tfxla._shift_right_arithmetic_helper -tf_impl[lax.shift_right_logical_p] = tfxla._shift_right_logical_helper +def _shift_right_arithmetic(x, y): + if x.dtype.is_unsigned: + assert x.dtype == y.dtype + orig_dtype = x.dtype + signed_dtype = _UNSIGNED_TO_SIGNED_TABLE[orig_dtype] + x = tf.cast(x, signed_dtype) + y = tf.cast(y, signed_dtype) + res = tf.bitwise.right_shift(x, y) + return tf.cast(res, orig_dtype) + else: + return tf.bitwise.right_shift(x, y) +tf_impl[lax.shift_right_arithmetic_p] = _shift_right_arithmetic + +def _shift_right_logical(x, y): + if x.dtype.is_unsigned: + return tf.bitwise.right_shift(x, y) + else: + assert x.dtype == y.dtype + orig_dtype = x.dtype + unsigned_dtype = _SIGNED_TO_UNSIGNED_TABLE[orig_dtype] + x = tf.cast(x, unsigned_dtype) + y = tf.cast(y, unsigned_dtype) + res = tf.bitwise.right_shift(x, y) + return tf.cast(res, orig_dtype) +tf_impl[lax.shift_right_logical_p] = _shift_right_logical + tf_impl[lax.shift_left_p] = tf.bitwise.left_shift tf_impl[lax.not_p] = tf.bitwise.invert diff --git a/jax/experimental/jax_to_tf/tests/primitive_harness.py b/jax/experimental/jax_to_tf/tests/primitive_harness.py index 9cd6959b1825..dea13eeb5033 100644 --- a/jax/experimental/jax_to_tf/tests/primitive_harness.py +++ b/jax/experimental/jax_to_tf/tests/primitive_harness.py @@ -21,23 +21,30 @@ from typing import Any, Callable, Dict, Iterable, Optional, NamedTuple, Sequence, Tuple, Union from absl import testing +from jax import config from jax import test_util as jtu from jax import dtypes from jax import lax import numpy as np +FLAGS = config.FLAGS + # TODO: these are copied from tests/lax_test.py (make this source of truth) +# Do not run int64 tests unless FLAGS.jax_enable_x64, otherwise we get a +# mix of int32 and int64 operations. def supported_dtypes(dtypes): - return [t for t in dtypes if t in jtu.supported_dtypes()] + return [t for t in dtypes if + t in jtu.supported_dtypes() and + (FLAGS.jax_enable_x64 or np.dtype(t).itemsize != 8)] float_dtypes = supported_dtypes([dtypes.bfloat16, np.float16, np.float32, np.float64]) complex_elem_dtypes = supported_dtypes([np.float32, np.float64]) complex_dtypes = supported_dtypes([np.complex64, np.complex128]) inexact_dtypes = float_dtypes + complex_dtypes -int_dtypes = supported_dtypes([np.int32, np.int64]) -uint_dtypes = supported_dtypes([np.uint32, np.uint64]) +int_dtypes = supported_dtypes([np.int8, np.int16, np.int32, np.int64]) +uint_dtypes = supported_dtypes([np.uint8, np.uint16, np.uint32, np.uint64]) bool_dtypes = [np.bool_] default_dtypes = float_dtypes + int_dtypes all_dtypes = float_dtypes + complex_dtypes + int_dtypes + bool_dtypes @@ -168,3 +175,33 @@ def parameterized(harness_group: Iterable[Harness], ] for dtype in [np.float32] ) + +shift_inputs = [ + (arg, dtype, shift_amount) + for dtype in supported_dtypes(uint_dtypes + int_dtypes) + for arg in [ + np.array([-250, -1, 0, 1, 250], dtype=dtype), + ] + for shift_amount in [0, 1, 2, 3, 7] +] + +lax_shift_left = jtu.cases_from_list( + Harness(f"_dtype={dtype.__name__}_shift_amount={shift_amount}", # type: ignore + lax.shift_left, + [arg, StaticArg(np.array([shift_amount], dtype=dtype))]) + for arg, dtype, shift_amount in shift_inputs +) + +lax_shift_right_logical = jtu.cases_from_list( + Harness(f"_dtype={dtype.__name__}_shift_amount={shift_amount}", # type: ignore + lax.shift_right_logical, + [arg, StaticArg(np.array([shift_amount], dtype=dtype))]) + for arg, dtype, shift_amount in shift_inputs +) + +lax_shift_right_arithmetic = jtu.cases_from_list( + Harness(f"_dtype={dtype.__name__}_shift_amount={shift_amount}", # type: ignore + lax.shift_right_arithmetic, + [arg, StaticArg(np.array([shift_amount], dtype=dtype))]) + for arg, dtype, shift_amount in shift_inputs +) diff --git a/jax/experimental/jax_to_tf/tests/tf_ops_test.py b/jax/experimental/jax_to_tf/tests/tf_ops_test.py index ae51a6c36dc8..93b71a9dcb3d 100644 --- a/jax/experimental/jax_to_tf/tests/tf_ops_test.py +++ b/jax/experimental/jax_to_tf/tests/tf_ops_test.py @@ -86,8 +86,6 @@ lax.bitwise_or, lax.bitwise_xor, lax.shift_left, - lax.shift_right_arithmetic, - lax.shift_right_logical, ) REDUCE = ( @@ -245,6 +243,23 @@ def test_binary_logical_elementwise(self, f_jax): self.assertAllClose(r_jax[np.isfinite(r_jax)], r_tf[np.isfinite(r_tf)], atol=1e-4) + # TODO(necula): combine tests that are identical except for the harness + # wait until we get more experience with using harnesses. + @primitive_harness.parameterized(primitive_harness.lax_shift_left) + def test_shift_left(self, harness): + self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()), + with_function=True) + + @primitive_harness.parameterized(primitive_harness.lax_shift_right_logical) + def test_shift_right_logical(self, harness): + self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()), + with_function=True) + + @primitive_harness.parameterized(primitive_harness.lax_shift_right_arithmetic) + def test_shift_right_arithmetic(self, harness): + self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()), + with_function=True) + @parameterized.named_parameters(jtu.cases_from_list( dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax)